In [None]:

import torch
from torch import nn
import torch.nn.functional as F
from utils import get_performance, get_confusion_matrix

from load_dataset import GetDataSet, iDataSet ,iFunction
import os
from torch.utils import data
from torchkeras import KerasModel
import pandas as pd
from models import *

In [None]:
class Softmax_Accuracy(nn.Module):
    """Running accuracy for a binary classifier that outputs logits.

    ``forward`` scores one batch and returns its accuracy:
      1. it applies a softmax over the last dimension of ``preds``;
      2. it asks ``get_confusion_matrix`` for the confusion counts;
      3. it adds the batch's correct / total counts to the running totals.

    ``compute`` returns the accuracy accumulated since the last ``reset``.
    """

    def __init__(self):
        super().__init__()
        # Frozen integer counters, registered as Parameters so they are part of
        # the module state and follow ``.to(device)``.
        self.correct = nn.Parameter(torch.tensor(0), requires_grad=False)
        self.total = nn.Parameter(torch.tensor(0), requires_grad=False)

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        confusion = get_confusion_matrix(F.softmax(preds, -1), targets)
        # Entries [1, 1] (TP) and [0, 0] (TN) are the correct predictions.
        batch_correct = (confusion[1, 1] + confusion[0, 0]).long()
        batch_total = targets.numel()
        self.correct += batch_correct
        self.total += batch_total
        return batch_correct.float() / batch_total

    def compute(self):
        """Accuracy over everything seen since the last reset."""
        return self.correct.float() / self.total

    def reset(self):
        """Zero both running counters in place."""
        self.correct.zero_()
        self.total.zero_()

class Recall(nn.Module):
    """Recall (sensitivity) for a binary-classification task.

    A sample counts as predicted-positive when its softmax probability for
    class 1 is >= 0.5, and as actually-positive when its target is >= 0.5.
    ``forward`` returns the batch recall and accumulates the counts;
    ``compute`` returns the recall since the last ``reset`` (NaN if no
    positives have been seen).
    """

    def __init__(self):
        super().__init__()
        # Frozen integer counters kept in the module state.
        self.true_positive = nn.Parameter(torch.tensor(0), requires_grad=False)
        self.total_positive = nn.Parameter(torch.tensor(0), requires_grad=False)

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        positive_score = torch.softmax(preds, -1)[:, 1].reshape(-1)
        labels = targets.reshape(-1)
        assert positive_score.shape == labels.shape

        actually_positive = labels >= 0.5
        hits = torch.sum((positive_score >= 0.5) & actually_positive)
        positives = torch.sum(actually_positive)

        self.true_positive += hits
        self.total_positive += positives
        return torch.true_divide(hits, positives)

    def compute(self):
        """Recall over all batches since the last reset."""
        return torch.true_divide(self.true_positive, self.total_positive)

    def reset(self):
        """Zero both running counters in place."""
        self.true_positive.zero_()
        self.total_positive.zero_()

class Mcc(nn.Module):
    """Matthews correlation coefficient (MCC) for a binary-classification task.

    A running 2x2 confusion matrix is kept as ``matrix[predicted, actual]``,
    i.e. ``[[TN, FN], [FP, TP]]``.

    - ``forward`` returns the MCC of the current batch (per-batch, like the
      other metrics in this notebook) and adds the batch counts to the
      running matrix.
    - ``compute`` returns the MCC over every batch since the last ``reset``.

    The predicted class is ``preds.argmax(dim=1)``. A target counts as
    positive when it is >= 0.5, so integer, float and bool 0/1 labels all
    work.

    Args:
        eps: added under the square root so a degenerate matrix (e.g. a single
            predicted class) yields 0 instead of NaN.
    """

    def __init__(self, eps: float = 1e-06) -> None:
        super().__init__()
        # Frozen Parameter so the counts are part of the module state and move with it.
        self.matrix = nn.Parameter(torch.zeros(2, 2), requires_grad=False)
        self.eps = eps

    def _mcc(self, matrix: torch.Tensor) -> torch.Tensor:
        """MCC of a ``[[TN, FN], [FP, TP]]`` confusion matrix."""
        TP, FP, FN, TN = matrix[[1, 1, 0, 0], [1, 0, 1, 0]]
        return ((TP * TN) - (FP * FN)) / torch.sqrt(
            (TP + FP) * (TP + FN) * (TN + FP) * (TN + FN) + self.eps
        )

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        pred_pos = preds.argmax(dim=1).reshape(-1) == 1
        true_pos = targets.reshape(-1) >= 0.5
        assert pred_pos.shape == true_pos.shape

        # Boolean masks instead of bitwise/arithmetic tricks on the labels, so
        # the counts do not depend on the label dtype.
        tn = (~pred_pos & ~true_pos).sum()
        fn = (~pred_pos & true_pos).sum()
        fp = (pred_pos & ~true_pos).sum()
        tp = (pred_pos & true_pos).sum()
        # Match the running matrix's dtype/device before accumulating.
        batch = torch.stack([torch.stack([tn, fn]), torch.stack([fp, tp])]).to(self.matrix)

        self.matrix.add_(batch)
        return self._mcc(batch)

    def compute(self):
        """MCC of the accumulated confusion matrix."""
        return self._mcc(self.matrix)

    def reset(self):
        """Clear the accumulated confusion matrix."""
        self.matrix.zero_()

class AUC(nn.Module):
    """Approximate ROC-AUC for a binary-classification task.

    The softmax probability of class 1 is bucketed into ``num_bins + 1``
    integer bins (``int(num_bins * p)``). Per-bin counts of positive
    (``tp``) and negative (``fp``) samples are accumulated, and the AUC is
    computed from those histograms. Samples sharing a bin count as half a win.

    - ``forward`` returns the batch AUC and accumulates the histograms.
    - ``compute`` returns the AUC since the last ``reset``.

    Either value is NaN if only one class has been seen.

    Args:
        num_bins: score resolution; the default 10000 reproduces the original
            10001-bucket histograms.
    """

    def __init__(self, num_bins: int = 10000):
        super().__init__()
        self.num_bins = num_bins
        # Frozen histograms of positive / negative samples per score bucket.
        self.tp = nn.Parameter(torch.zeros(num_bins + 1), requires_grad=False)
        self.fp = nn.Parameter(torch.zeros(num_bins + 1), requires_grad=False)

    def eval_auc(self, tp, fp):
        """AUC from per-bucket positive (``tp``) and negative (``fp``) counts.

        For every negative, take the fraction of positives in strictly higher
        buckets plus half of those in the same bucket. Then average over
        negatives.
        """
        tp_total = torch.sum(tp)
        fp_total = torch.sum(fp)
        # Walk buckets from the highest score to the lowest.
        tp_reverse = torch.flip(tp, dims=[0])
        fp_reverse = torch.flip(fp, dims=[0])
        tp_reverse_cum = torch.cumsum(tp_reverse, dim=0) - tp_reverse / 2.0
        return torch.sum(torch.true_divide(tp_reverse_cum, tp_total)
                         * torch.true_divide(fp_reverse, fp_total))

    def forward(self, preds: torch.Tensor, targets: torch.Tensor):
        # Softmax output is in [0, 1], so bins fall in [0, num_bins].
        bins = (self.num_bins * torch.softmax(preds, -1)[:, 1]).reshape(-1).long()
        y_true = targets.reshape(-1)
        assert bins.shape == y_true.shape

        # Vectorised histogram (replaces a per-sample Python loop that also
        # forced a device sync per element on GPU).
        is_pos = y_true >= 0.5
        length = self.num_bins + 1
        tpi = torch.bincount(bins[is_pos], minlength=length).to(self.tp)
        fpi = torch.bincount(bins[~is_pos], minlength=length).to(self.fp)

        self.tp.add_(tpi)
        self.fp.add_(fpi)
        return self.eval_auc(tpi, fpi)

    def compute(self):
        """AUC of the accumulated histograms."""
        return self.eval_auc(self.tp, self.fp)

    def reset(self):
        """Clear the accumulated histograms."""
        self.tp.zero_()
        self.fp.zero_()

In [None]:
# Use the GPU when available; fall back to the CPU so Restart & Run All works without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Feature-encoding scheme passed to the GetDataSet-built datasets, and the dataset root
# directory (a single-argument os.path.join is just the string itself).
feature_encoding, path = iFunction.fe1, "data"

In [None]:
def one_trial(lr, embedding_dim, batchsize, dataset_name="mouse", ckpt_path="checkpoint.pt"):
    """Train one classifier on the train split and score it on the test split.

    Parameters
    ----------
    lr : float
        Adam learning rate.
    embedding_dim : int
        Passed to ``classifier`` / ``classifier2`` from ``models``.
    batchsize : int
        Training mini-batch size.
    dataset_name : str, optional
        Sub-directory of ``path`` holding the dataset (default ``"mouse"``).
    ckpt_path : str, optional
        File ``KerasModel.fit`` writes the best weights to. It is reloaded
        before scoring.

    Returns
    -------
    The result of ``get_performance`` on the test split.

    NOTE(review): the test split is also the validation set used for early
    stopping and checkpoint selection, so these scores are optimistically
    biased. Prefer ``kfold_valid`` or a separate held-out split for reporting.
    """
    datasets = GetDataSet(os.path.join(path, dataset_name), iFunction.read_txt_to_pd2, dataset_name)
    # Presumably flag 1 selects the test split and the default the train split -- verify in load_dataset.
    test_data = datasets(feature_encoding, 1)
    train_data = datasets(feature_encoding)
    metric_dict = {"Mcc": Mcc()}
    train_iter = data.DataLoader(train_data, batchsize, True)
    # Whole test split as one batch, so the scoring below is a single forward pass.
    valid_iter = data.DataLoader(test_data, len(test_data))

    net = classifier(4, embedding_dim, True).to(device)
    net2 = classifier2(embedding_dim, True).to(device)
    target_net = nn.Sequential(net, net2).to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)
    opti = torch.optim.Adam(target_net.parameters(), lr)
    model = KerasModel(target_net, loss_fn, optimizer=opti, metrics_dict=metric_dict)
    model.fit(train_iter, valid_iter, 200, ckpt_path=ckpt_path, patience=40,
              monitor='val_Mcc', mode="max")

    # Reload the best checkpoint explicitly, mapped onto the current device.
    target_net.load_state_dict(torch.load(ckpt_path, map_location=device))
    # Score in inference mode (affects any dropout / batch-norm layers) regardless of
    # the mode fit() left the network in.
    target_net.eval()
    with torch.no_grad():
        X, y = next(iter(valid_iter))
        X, y = X.to(device), y.to(device)
        y_hat = target_net(X)
        perfor = get_performance(y_hat.softmax(-1), y)

    return perfor

In [None]:
from sklearn.model_selection import KFold
def kfold_valid(num_folds, batchsize, embedding_dim, lr, random_state=None,
                dataset_name="mouse", ckpt_path="checkpoint.pt"):
    """Cross-validate the classifier on the training data and average the scores.

    Positives and negatives are split into ``num_folds`` folds independently,
    and their folds are paired, so each fold keeps roughly the overall class
    ratio. For every fold:

    1. a fresh ``classifier`` + ``classifier2`` stack is trained with
       ``KerasModel.fit``, early-stopping on ``val_Mcc``;
    2. the best checkpoint is reloaded;
    3. the model is scored with ``get_performance`` on that fold's
       validation set.

    Parameters
    ----------
    num_folds : int
        Number of folds for ``sklearn.model_selection.KFold``.
    batchsize : int
        Training mini-batch size.
    embedding_dim : int
        Passed to ``classifier`` / ``classifier2``.
    lr : float
        Adam learning rate.
    random_state : int or None, optional
        Seed for the fold shuffling. ``None`` (default) keeps the previous,
        non-reproducible behaviour.
    dataset_name : str, optional
        Sub-directory of ``path`` holding the dataset (default ``"mouse"``).
    ckpt_path : str, optional
        File ``KerasModel.fit`` writes the best weights to. It is reused
        (overwritten) by every fold.

    Returns
    -------
    pandas.Series
        Mean of each ``get_performance`` entry across folds.
    """
    kfd = KFold(num_folds, shuffle=True, random_state=random_state)
    datasets = GetDataSet(os.path.join(path, dataset_name), iFunction.read_txt_to_pd2, dataset_name)
    # Presumably the positive and negative training samples -- verify against load_dataset.
    train_pos, train_neg = datasets(feature_encoding, 0)

    fold_frames = []
    for (pos_train, pos_valid), (neg_train, neg_valid) in zip(kfd.split(train_pos), kfd.split(train_neg)):
        # Fresh metric per fold so no confusion-matrix counts leak between folds.
        metric_dict = {"Mcc": Mcc()}
        train_set = data.Subset(train_pos, pos_train) + data.Subset(train_neg, neg_train)
        valid_set = data.Subset(train_pos, pos_valid) + data.Subset(train_neg, neg_valid)
        train_iter = data.DataLoader(train_set, batchsize, True)
        # Whole validation fold as one batch for the single-pass scoring below.
        valid_iter = data.DataLoader(valid_set, len(valid_set))

        net = classifier(4, embedding_dim, True).to(device)
        net2 = classifier2(embedding_dim, True).to(device)
        target_net = nn.Sequential(net, net2).to(device)
        loss_fn = nn.CrossEntropyLoss().to(device)
        opti = torch.optim.Adam(target_net.parameters(), lr)
        model = KerasModel(target_net, loss_fn, optimizer=opti, metrics_dict=metric_dict)
        model.fit(train_iter, valid_iter, epochs=200, ckpt_path=ckpt_path, patience=40,
                  monitor='val_Mcc', mode="max", quiet=True, plot=False)

        # Reload this fold's best weights onto the current device and score in inference mode.
        target_net.load_state_dict(torch.load(ckpt_path, map_location=device))
        target_net.eval()
        with torch.no_grad():
            X, y = next(iter(valid_iter))
            X, y = X.to(device), y.to(device)
            y_hat = target_net(X)
            perfor = get_performance(y_hat.softmax(-1), y)
        fold_frames.append(pd.DataFrame.from_dict(perfor, orient='index').T)

    return pd.concat(fold_frames, ignore_index=True).mean(axis=0)
