In [1]:
import torch
import numpy as np
import pandas as pd 
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
import seaborn as sns

In [2]:
%run focalloss.ipynb
%run mfeloss.ipynb
%run msfeloss.ipynb

In [11]:
#https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html

class NeuralN(nn.Module):
    """Configurable fully connected network with training/evaluation helpers.

    The network is a stack of ``nn.Linear`` layers.  Every hidden layer is
    followed by an activation chosen by name; the output layer is followed by
    ``nn.Sigmoid`` when ``loss_method == "BCE"`` and by ``nn.Identity``
    otherwise (so every other loss receives raw outputs).

    Args:
        input_dimension: Number of input features.
        output_dimension: Number of output units.
        hidden_layers: Explicit list of hidden layer widths.  When given it
            takes precedence over ``num_hidden_layers`` / ``hidden_dim``.
        num_hidden_layers: Number of hidden layers of width ``hidden_dim``;
            only used when ``hidden_layers`` is ``None``.
        hidden_dim: Width of each hidden layer when ``hidden_layers`` is None.
        activation_default: Activation name used for every hidden layer when
            ``activations`` is ``None``.
        threshold: Cut-off applied to probabilities in :meth:`predict`.
        activations: Per-hidden-layer activation names.  A list that is too
            short is padded with ``"identity"``.
        loss_method: One of ``"BCE"``, ``"L1"``, ``"MSE"``, ``"CE"``,
            ``"BCEwLogit"``, ``"focal_loss"``, ``"MFE"``, ``"MSFE"``.
        opt_method: One of ``"SGD"``, ``"Adam"``, ``"RMSprop"``.
        lr: Learning rate handed to the optimizer.
        class_weights: Two values indexable with ``[0]`` and ``[1]``.  They
            are stored in swapped order (see the note in ``__init__``).
        alpha: Forwarded to ``FocalLoss``.
        data_type: Stored on the instance; not used by this class.
        gamma: Forwarded to ``FocalLoss``.
        epochs: Number of training epochs.

    Raises:
        ValueError: If an activation name is not recognised.
    """

    def __init__(self, input_dimension=None, output_dimension=None, hidden_layers=None, num_hidden_layers=None, hidden_dim=64,
                 activation_default="relu", threshold=0.3,
                 activations=None, loss_method="BCE", opt_method="SGD", lr=0.01, class_weights=None, alpha=None, data_type=None, gamma=None, epochs=None):
        super().__init__()

        self.loss_method = loss_method
        self.opt_method = opt_method
        self.lr = lr
        self.alpha = alpha
        self.gamma = gamma
        self.epochs = epochs
        self.input_dimension = input_dimension
        self.output_dimension = output_dimension
        self.hidden_layers = hidden_layers  # list of hidden widths, or None
        self.num_hidden_layers = num_hidden_layers
        self.hidden_dim = hidden_dim
        self.activation_default = activation_default
        self.activations = activations  # list of activation names, or None
        self.data_type = data_type
        self.threshold = threshold

        if class_weights is not None:
            # NOTE(review): the two entries are deliberately stored swapped.
            # Presumably the caller passes class frequencies so that each
            # class ends up weighted by the other's frequency -- confirm.
            self.class_weights = torch.tensor(
                [class_weights[1], class_weights[0]], dtype=torch.float32
            )
        else:
            self.class_weights = None

        # Layer widths from input to output.
        if self.hidden_layers is not None:
            hidden_sizes = list(self.hidden_layers)
        else:
            hidden_sizes = [self.hidden_dim] * self.num_hidden_layers
        layer_sizes = [self.input_dimension] + hidden_sizes + [self.output_dimension]

        # One activation name per hidden layer.
        hidden_count = len(layer_sizes) - 2
        if self.activations is not None:
            act_names = list(self.activations)
            act_names += ["identity"] * max(0, hidden_count - len(act_names))
        else:
            act_names = [self.activation_default] * hidden_count

        self.process = nn.ModuleList()
        last = len(layer_sizes) - 1
        for i in range(1, len(layer_sizes)):
            self.process.append(nn.Linear(layer_sizes[i - 1], layer_sizes[i]))
            if i < last:
                self.process.append(self.get_activation(act_names[i - 1]))
            elif self.loss_method == "BCE":
                # nn.BCELoss expects probabilities, so squash the output here.
                self.process.append(nn.Sigmoid())
            else:
                self.process.append(nn.Identity())

    def forward(self, x):
        """Cast ``x`` to float and pass it through every module in order."""
        # https://medium.com/data-scientists-diary/advanced-guide-to-using-nn-modulelist-in-pytorch-da4d49c109fc
        x = x.float()
        for module in self.process:
            x = module(x)
        return x

    @staticmethod
    def _squeeze_keep_batch(out):
        """Drop size-1 dimensions from ``out`` without losing the batch axis.

        A plain ``squeeze()`` also removes the batch dimension when a batch
        holds a single sample (typically the last, partial batch), which
        breaks loss shape checks and ``torch.cat``.  For batches with more
        than one sample the result equals ``out.squeeze()``.
        """
        squeezed = out.squeeze()
        if out.dim() >= 2 and out.shape[0] == 1:
            squeezed = squeezed.unsqueeze(0)
        return squeezed

    def get_activation(self, activation_):
        """Return the activation module registered under ``activation_``.

        Raises:
            ValueError: If the name is not ``"relu"``, ``"tanh"`` or
                ``"identity"``.  (Returning ``None`` here would only fail
                later, inside ``forward``.)
        """
        if activation_ == "relu":
            return nn.ReLU()

        elif activation_ == "tanh":
            return nn.Tanh()

        elif activation_ == "identity":
            return nn.Identity()

        else:
            raise ValueError(f"{activation_} is not valid!")

    def get_loss(self):
        """Build the loss object selected by ``self.loss_method``.

        ``FocalLoss``, ``MFELoss`` and ``MSFELoss`` are not defined in this
        cell; they are expected to already exist in the notebook namespace.

        Raises:
            ValueError: If ``self.loss_method`` is not recognised.
        """
        if self.loss_method == "BCE":
            return nn.BCELoss()

        elif self.loss_method == "L1":
            return nn.L1Loss()

        elif self.loss_method == "MSE":
            return nn.MSELoss()

        elif self.loss_method == "CE":
            return nn.CrossEntropyLoss(weight=self.class_weights)

        elif self.loss_method == "BCEwLogit":
            if self.class_weights is None:
                return nn.BCEWithLogitsLoss()
            # Ratio of the stored (already swapped) weights, as a 1-element tensor.
            ratio = self.class_weights[1] / self.class_weights[0]
            return nn.BCEWithLogitsLoss(pos_weight=ratio.reshape(1))

        elif self.loss_method == "focal_loss":  # https://medium.com/visionwizard/understanding-focal-loss-a-quick-read-b914422913e7
            return FocalLoss(alpha=self.alpha, gamma=self.gamma)

        elif self.loss_method == "MFE":
            return MFELoss()

        elif self.loss_method == "MSFE":
            return MSFELoss()

        else:
            raise ValueError(f"{self.loss_method} is not valid!")

    def get_optimizer(self):
        """Build the optimizer selected by ``self.opt_method`` over this model's parameters.

        Raises:
            ValueError: If ``self.opt_method`` is not recognised.
        """
        if self.opt_method == "SGD":
            return torch.optim.SGD(params=self.parameters(), lr=self.lr)

        elif self.opt_method == "Adam":
            return torch.optim.Adam(params=self.parameters(), lr=self.lr)

        elif self.opt_method == "RMSprop":
            return torch.optim.RMSprop(params=self.parameters(), lr=self.lr)

        else:
            raise ValueError(f"{self.opt_method} is not valid!")

    def _mean_loss(self, loader, loss_fn):
        """Return the per-sample average of ``loss_fn`` over ``loader`` in eval mode.

        Each batch loss is weighted by its batch size and the total is divided
        by ``len(loader.dataset)``; this is an exact average only if
        ``loss_fn`` returns a per-batch mean.
        """
        self.eval()
        total = 0
        with torch.inference_mode():
            for X, y in loader:
                outputs = self._squeeze_keep_batch(self(X))
                total += loss_fn(outputs, y).item() * y.size(0)
        return total / len(loader.dataset)

    def train_model(self, train_loader, val_loader, optimizer=None, loss_fn=None):  # https://www.geeksforgeeks.org/how-to-implement-neural-networks-in-pytorch/
        """Train for ``self.epochs`` epochs and track train/validation loss.

        Args:
            train_loader: Iterable of ``(X, y)`` batches exposing ``.dataset``.
            val_loader: Same, evaluated after every epoch.
            optimizer: Optimizer to use; defaults to :meth:`get_optimizer`.
            loss_fn: Loss callable; defaults to :meth:`get_loss`.

        Returns:
            ``(train_losses, val_losses)``: two lists with one average loss
            per epoch.
        """
        # https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html
        if loss_fn is None:
            loss_fn = self.get_loss()
        if optimizer is None:
            optimizer = self.get_optimizer()

        t_loss = []
        val_loss = []
        for _ in range(self.epochs):
            self.train()
            running = 0
            for X, y in train_loader:
                outputs = self._squeeze_keep_batch(self(X))
                loss = loss_fn(outputs, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Weight by batch size so a smaller final batch is not over-counted.
                running += loss.item() * X.size(0)
            t_loss.append(running / len(train_loader.dataset))
            val_loss.append(self._mean_loss(val_loader, loss_fn))

        return t_loss, val_loss

    # https://flock-io.medium.com/credit-card-fraud-detection-build-your-own-model-part-1-9b6cac3c991c
    # https://discuss.pytorch.org/t/correct-way-to-calculate-train-and-valid-loss/178974 --> For loss visualization
    def test_model(self, test_loader, loss_fn=None):
        """Evaluate the average loss on ``test_loader``.

        The model is not updated here, so every epoch would produce the same
        value; the loss is computed once and repeated ``self.epochs`` times so
        the result keeps the length the plotting helper expects.

        Args:
            test_loader: Iterable of ``(X, y)`` batches exposing ``.dataset``.
            loss_fn: Loss callable; defaults to :meth:`get_loss`.

        Returns:
            A list of ``self.epochs`` identical average-loss values.
        """
        if loss_fn is None:
            loss_fn = self.get_loss()

        repeats = len(range(self.epochs))
        if repeats == 0:
            return []
        return [self._mean_loss(test_loader, loss_fn)] * repeats

    def store(self, operation=None, path=None):  # https://pytorch.org/tutorials/beginner/saving_loading_models.html
        """Save or load the model's ``state_dict``.

        Args:
            operation: ``"save"`` writes the state dict to ``path``;
                ``"load"`` reads it back (with ``weights_only=True``).
            path: File location handed to ``torch.save`` / ``torch.load``.

        Raises:
            ValueError: If ``operation`` is neither ``"save"`` nor ``"load"``.
        """
        if operation == "save":
            torch.save(self.state_dict(), path)

        elif operation == "load":
            self.load_state_dict(torch.load(path, weights_only=True))
            print("Loading successfull ! ")

        else:
            raise ValueError(f"{operation} is not valid!")

    def analysis(self, l1, l2):
        """Plot two per-epoch loss curves (``l1`` = train, ``l2`` = validation).

        NOTE(review): ``plt`` is not imported in this cell; it must already be
        available in the notebook namespace (presumably matplotlib.pyplot).
        """
        # https://www.geeksforgeeks.org/how-to-create-a-multiline-plot-using-seaborn/
        l1_arr = np.array(l1)
        l2_arr = np.array(l2)

        # Derive the epoch axis from the data so it always matches in length.
        sns.lineplot(x=np.arange(1, len(l1_arr) + 1), y=l1_arr, label="Train Error")
        sns.lineplot(x=np.arange(1, len(l2_arr) + 1), y=l2_arr, label="Valid Error")
        plt.legend()
        plt.show()

    def predict(self, test):
        """Run the model over ``test`` and collect probabilities, predictions and labels.

        How the probability is obtained depends on ``self.loss_method``:

        * ``"BCE"``: the model output is used directly (it already went
          through the sigmoid output layer).
        * ``"CE"``: softmax over dimension 1.  With two classes the
          probability of class 1 is thresholded; with any other class count
          the prediction is the argmax and the reported probability is that
          of the predicted class.
        * anything else: sigmoid of the model output.

        Args:
            test: Iterable of ``(X, y)`` batches.

        Returns:
            ``(all_probs, all_preds, all_labels)`` as flat tensors.
        """
        # https://discuss.pytorch.org/t/how-to-use-pytorch-to-output-the-probability-of-binary-classfication/101043/2
        # https://codesignal.com/learn/courses/building-a-neural-network-in-pytorch/lessons/making-predictions-with-a-trained-pytorch-model
        # https://discuss.pytorch.org/t/how-to-make-pytorch-model-predict/167950
        predictions = []
        labels = []
        probs = []

        self.eval()
        with torch.inference_mode():
            for X, y in test:
                output = self._squeeze_keep_batch(self(X))

                if self.loss_method == "CE":
                    class_probs = torch.softmax(output, dim=1)
                    if class_probs.size(1) == 2:
                        prob = class_probs[:, 1]
                        preds = (prob > self.threshold).int()
                    else:
                        prob, winners = class_probs.max(dim=1)
                        preds = winners.int()
                else:
                    prob = output if self.loss_method == "BCE" else torch.sigmoid(output)
                    preds = (prob > self.threshold).int()

                predictions.append(preds)
                labels.append(y)
                probs.append(prob)

        all_preds = torch.cat(predictions).ravel()
        all_labels = torch.cat(labels).ravel()
        all_probs = torch.cat(probs).ravel()

        return all_probs, all_preds, all_labels

    def report(self, test, pred=None, labels=None):
        """Predict on ``test`` and display the confusion matrix as a heatmap.

        ``pred`` and ``labels`` are accepted for backward compatibility but
        are not used; predictions are recomputed from ``test``.

        NOTE(review): ``plt`` and ``confusion_matrix`` are not imported in
        this cell; they must already be available in the notebook namespace
        (presumably matplotlib.pyplot and sklearn.metrics).
        """
        _, all_preds, all_labels = self.predict(test)
        all_preds_arr = all_preds.detach().cpu().numpy().ravel()
        all_labels_arr = all_labels.detach().cpu().numpy().ravel()

        cm = confusion_matrix(all_labels_arr, all_preds_arr)
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt="d")
        plt.title("Confusion Matrix")
        plt.ylabel("Actual Class")
        plt.xlabel("Predicted Class")

        plt.show()
            
        
