# In [None]:
import os
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, roc_curve
from sklearn.model_selection import StratifiedKFold, train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import datetime
from time import time

start_time = datetime.datetime.now()
torch.manual_seed(0)
np.random.seed(0)
torch.set_printoptions(precision=6)
np.set_printoptions(precision=6)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataset = 'PCBA_Metrics.txt'

# ---------------------------------------------------------------------------
# Load and preprocess <cwd>/tapAUC/<dataset>.
# Columns 0-7 are treated as metadata and everything from column 8 on as
# features (see the iloc[:, 8:] slices below).
# ---------------------------------------------------------------------------
dirname = os.path.abspath('')
path = os.path.join(dirname, 'tapAUC')
metric_df_include = pd.read_csv(os.path.join(path, dataset), index_col=False)

# Impute NaN values with the column mean. numeric_only=True keeps the pre-2.0
# pandas behaviour of skipping non-numeric (metadata) columns; pandas >= 2.0
# raises TypeError on string columns without it.
metric_df_include.fillna(metric_df_include.mean(numeric_only=True), inplace=True)

# Drop feature columns that hold a single distinct value.
feature_block = metric_df_include.iloc[:, 8:]
metric_df_include = metric_df_include.drop(
    feature_block.columns[feature_block.nunique() <= 1], axis=1)

# Replace literal ' inf' / ' -inf' strings with the float32 extremes.
# NOTE(review): columns containing such strings are presumably object-dtype, so
# they are skipped by the mean imputation above and the correlation filter
# below — confirm against the raw file.
metric_df_include = metric_df_include.replace(' inf', np.finfo(np.float32).max)
metric_df_include = metric_df_include.replace(' -inf', np.finfo(np.float32).min)

# Drop the later column of every feature pair whose |Pearson r| exceeds 0.95.
# select_dtypes mirrors the implicit numeric-only corr() of older pandas;
# pandas >= 2.0 raises on non-numeric columns instead.
corr_matrix = (metric_df_include.iloc[:, 8:]
               .select_dtypes(include=['number', 'bool'])
               .corr()
               .abs())
upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape, dtype=bool), k=1))
to_drop = [column for column in upper.columns if (upper[column] > 0.95).any()]
metric_df_include.drop(to_drop, axis=1, inplace=True)
col_names = metric_df_include.columns.values
# (A former set_axis(col_names, axis=1, inplace=False) call re-applied the same
# names — a no-op — and its `inplace` argument no longer exists in pandas 2.0.)
data_full = metric_df_include.reset_index(drop=True)

# Encode the label: "NG" -> 1, "OK" -> 0 (other values are left untouched).
data_full.loc[data_full["status"] == "NG", "status"] = 1
data_full.loc[data_full["status"] == "OK", "status"] = 0
X = data_full.iloc[:, 8:].to_numpy()
# NOTE(review): the label is read positionally from column 2 — presumably the
# "status" column encoded above; confirm against the file header.
y = data_full.iloc[:, 2].to_numpy().astype(np.float32)
test_files = data_full['test_file'].to_numpy()  # Extract test_file column
statuses = data_full['status'].astype(str).to_numpy()  # status as strings (after the encoding above)
# NOTE(review): the scaler is fit on every row before the CV split in cv(), so
# test-fold statistics leak into training; consider fitting it per fold.
scaler = StandardScaler()
X = scaler.fit_transform(X)

class CustomDataset(Dataset):
    """Map-style dataset over a feature matrix and a label vector.

    Features and labels are converted to float32 tensors up front. When both
    ``test_files`` and ``statuses`` are supplied, each item also carries the
    matching entries from them; otherwise items are ``(features, label)`` pairs.
    """

    def __init__(self, X, y, test_files=None, statuses=None):
        self.X = torch.FloatTensor(X)
        self.y = torch.FloatTensor(y)
        self.test_files = test_files
        self.statuses = statuses

    def __getitem__(self, index):
        features, label = self.X[index], self.y[index]
        # Metadata is only attached when *both* arrays were given.
        if self.test_files is None or self.statuses is None:
            return features, label
        return features, label, self.test_files[index], self.statuses[index]

    def __len__(self):
        return len(self.X)

class binaryClassification(nn.Module):
    """Two-hidden-layer MLP that outputs one sigmoid probability per row.

    Layout: Linear(F->H) -> ReLU -> BatchNorm -> Linear(H->H) -> ReLU -> BatchNorm
    -> Dropout(0.5) -> Linear(H->1) -> Sigmoid, where H = F // 2 (at least 1).

    Args:
        n_features: input width F. When omitted (the original zero-argument
            construction), the module-level ``feat_count`` is used.
    """

    def __init__(self, n_features=None):
        super().__init__()
        if n_features is None:
            n_features = feat_count
        n_features = int(n_features)
        # Guard against a zero-width hidden layer when only one feature survives.
        hidden = max(1, n_features // 2)
        self.layer_1 = nn.Linear(n_features, hidden)
        self.layer_2 = nn.Linear(hidden, hidden)
        self.layer_out = nn.Linear(hidden, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)
        self.batchnorm1 = nn.BatchNorm1d(hidden)
        self.batchnorm2 = nn.BatchNorm1d(hidden)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        """Map a (batch, F) float tensor to (batch, 1) probabilities."""
        x = self.relu(self.layer_1(inputs))
        x = self.batchnorm1(x)
        x = self.relu(self.layer_2(x))
        x = self.batchnorm2(x)
        x = self.dropout(x)
        x = self.layer_out(x)
        return self.sigmoid(x)

class tapAUCLoss(nn.modules.loss._Loss):
    """Pairwise squared-hinge AUC surrogate.

    For every (positive, negative) score pair the penalty is
    ``max(0, s_neg + margin - s_pos) ** 2``; the loss is the mean over all pairs.
    (The clamp must happen before squaring: ``relu(x ** 2)`` is identical to
    ``x ** 2`` and would also penalise pairs that are already well separated.)

    Args:
        margin: hinge margin. ``None`` (default, matching the original
            zero-argument construction) reads the module-level ``gamma`` at
            call time.
        **kwargs: forwarded to ``_Loss``; ``reduction`` is not used by ``forward``.
    """

    def __init__(self, margin=None, **kwargs):
        super().__init__(**kwargs)
        self.margin = margin

    def forward(self, y_pred, y_true):
        """Return the scalar loss for scores ``y_pred`` and 0/1 labels ``y_true``.

        ``y_true`` is masked against ``y_pred`` element-wise, so the two must
        have matching shapes. If either class is absent there are no pairs, and
        a zero that stays attached to the autograd graph is returned instead of
        the NaN produced by averaging an empty tensor.
        """
        if not isinstance(y_true, torch.Tensor):
            y_true = torch.tensor(y_true)
        if isinstance(y_pred, tuple):
            # torch.return_types (e.g. from torch.sort) are tuple subclasses that
            # expose `.values`; a plain tuple falls back to its first element.
            y_pred = getattr(y_pred, "values", y_pred[0])
        y_true = y_true.to(y_pred.device)
        margin = gamma if self.margin is None else self.margin
        positive = y_pred[y_true == 1]
        negative = y_pred[y_true == 0]
        if positive.numel() == 0 or negative.numel() == 0:
            return y_pred.sum() * 0.0
        # Broadcast to a (n_pos, n_neg) matrix of pairwise hinge violations.
        violation = negative.view(1, -1) + margin - positive.view(-1, 1)
        return torch.nn.functional.relu(violation).pow(2).mean()

def _f_beta(precision, recall, beta):
    """F-beta score; 0 when the denominator is 0."""
    b2 = beta ** 2
    denom = b2 * precision + recall
    return (b2 + 1) * precision * recall / denom if denom != 0 else 0


def _accuracy_optimal_threshold(y_true, y_pred, positive, negative):
    """Return the ROC-curve threshold with the highest accuracy (first one on ties).

    Accuracy at threshold t is (#positives >= t + #negatives < t) / n. Counting
    with binary search over the sorted scores gives the same exact counts as
    scoring every threshold with accuracy_score, without the O(n * T) loop.
    """
    _, _, thresholds = roc_curve(y_true, y_pred)
    pos_sorted = np.sort(positive)
    neg_sorted = np.sort(negative)
    tp = pos_sorted.size - np.searchsorted(pos_sorted, thresholds, side='left')
    tn = np.searchsorted(neg_sorted, thresholds, side='left')
    accuracies = (tp + tn) / y_true.size
    return thresholds[accuracies.argmax()]


def get_metric(mode, y_pred, y_true, th=None):
    """Confusion-matrix metrics for scores thresholded at ``th`` (positive iff score >= th).

    Args:
        mode: only consulted when ``th`` is None, to choose the threshold.
            'std' picks the ROC threshold with the highest accuracy. Any other
            value ('zfn', zero false negatives) uses the lowest positive score,
            so every positive is predicted positive; ``np.inf`` if there are no
            positives.
        y_pred: tensor of scores (any device; flattened).
        y_true: tensor of 0/1 labels aligned with ``y_pred`` (flattened).
        th: threshold to apply; computed from ``mode`` when None.

    Returns:
        ``(th, TP, FP, TN, FN, acc, sensitivity, specificity, gmean, precision,
        f1score, f2score)``. ``acc`` is a fraction in [0, 1]; ratios with an
        empty denominator fall back to 0 (specificity falls back to 1).
    """
    y_pred = y_pred.cpu().detach().numpy().ravel()
    y_true = y_true.cpu().detach().numpy().ravel()
    positive = y_pred[y_true == 1]
    negative = y_pred[y_true == 0]
    if th is None:
        if mode == 'std':
            th = _accuracy_optimal_threshold(y_true, y_pred, positive, negative)
        else:
            th = positive.min() if positive.size else np.inf
    TP = int(np.count_nonzero(positive >= th))
    FN = int(np.count_nonzero(positive < th))
    FP = int(np.count_nonzero(negative >= th))
    TN = int(np.count_nonzero(negative < th))
    total = TP + FN + FP + TN
    acc = (TP + TN) / total if total != 0 else 0
    specificity = 1 - (FP / (FP + TN)) if (FP + TN) != 0 else 1
    sensitivity = TP / (FN + TP) if (FN + TP) != 0 else 0
    gmean = np.sqrt(sensitivity * specificity)
    precision = TP / (TP + FP) if (TP + FP) != 0 else 0
    f1score = _f_beta(precision, sensitivity, 1)
    f2score = _f_beta(precision, sensitivity, 2)
    return th, TP, FP, TN, FN, acc, sensitivity, specificity, gmean, precision, f1score, f2score

def _select_tail_alpha(y_pred_train, y_batch_train, neg_count):
    """Keep every positive plus the lowest-scored share of negatives from one batch.

    Args:
        y_pred_train: model scores for the batch (flattened).
        y_batch_train: 0/1 labels aligned with the flattened scores.
        neg_count: ``1`` keeps exactly one negative; any other value is the
            fraction of the batch's negatives to keep (at least one if any exist).

    Returns:
        ``(scores, labels)``: 1-D tensors in descending-score order. The scores
        stay attached to the autograd graph.
    """
    scores = y_pred_train.flatten()
    order = torch.argsort(-scores)
    scores_sorted = scores[order]
    labels_sorted = y_batch_train.flatten()[order]
    neg_idx = torch.where(labels_sorted == 0)[0]
    pos_idx = torch.where(labels_sorted == 1)[0]
    if neg_count == 1:
        n_keep = 1
    else:
        # max(1, ...) matters: int(...) is 0 for a small batch, and the slice
        # `[-0:]` would then have selected *every* negative.
        n_keep = max(1, int(len(neg_idx) * neg_count))
    # Positions are in descending-score order, so the tail holds the
    # lowest-scored negatives.
    neg_keep = neg_idx[max(len(neg_idx) - n_keep, 0):]
    keep, _ = torch.sort(torch.cat((neg_keep, pos_idx), axis=0))
    return scores_sorted[keep], labels_sorted[keep]


def cv(X, y, test_files, statuses, repeat, n_splits=5):
    """Train and evaluate a fresh model on every fold of one stratified K-fold split.

    Reads the module-level settings ``total_epochs``, ``e_warmup``,
    ``neg_count`` and ``device``. It appends per-epoch records to the
    module-level accumulators ``pred_*_lst``, ``label_*_lst`` and
    ``*_zfn_th_lst``, indexed ``[repeat-1][fold-1][epoch-1]``; tensors are
    stored detached, on the CPU.

    From epoch ``int(total_epochs * e_warmup)`` on, the loss only sees the
    positives plus a ``neg_count`` share of the lowest-scored negatives of each
    batch. Before that it sees the whole batch.

    Threshold metrics use the predictions of the whole epoch (training) or the
    whole test fold (validation), not just the last mini-batch. Rows are
    recorded for the final epoch only.

    Returns:
        ``(train_metrics, test_metrics, pred_train_lst, pred_test_lst,
        label_train_lst, label_test_lst, train_zfn_th_lst, test_zfn_th_lst)``.
        Each metrics row is ``[fold, repeat, loss, <12 'std' values>, <12 'zfn'
        values>]`` as returned by ``get_metric``. For train, loss is the mean
        training-batch loss of the final epoch; for test it is ``'-'``.
    """
    train_metrics = [[] for _ in range(n_splits)]
    test_metrics = [[] for _ in range(n_splits)]

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=repeat)
    for fold, (train_index, test_index) in enumerate(skf.split(X, y), 1):
        ri, fi = repeat - 1, fold - 1  # 0-based indices into the accumulators

        train_dataset = CustomDataset(X[train_index], y[train_index],
                                      test_files=test_files[train_index],
                                      statuses=statuses[train_index])
        test_dataset = CustomDataset(X[test_index], y[test_index],
                                     test_files=test_files[test_index],
                                     statuses=statuses[test_index])
        train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

        # Fresh model and optimizer per fold.
        model = binaryClassification()
        model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        criterion = tapAUCLoss()
        warmup_epoch = int(total_epochs * e_warmup)

        for epoch in range(1, total_epochs + 1):
            ei = epoch - 1

            # ---- training ----
            model.train()
            epoch_pred, epoch_true, epoch_losses = [], [], []
            for X_batch_train, y_batch_train, _, _ in train_loader:
                X_batch_train, y_batch_train = X_batch_train.to(device), y_batch_train.to(device)
                optimizer.zero_grad()
                y_pred_train = model(X_batch_train)
                if epoch >= warmup_epoch:
                    loss_pred, loss_true = _select_tail_alpha(y_pred_train, y_batch_train, neg_count)
                else:
                    loss_pred, loss_true = y_pred_train, y_batch_train.unsqueeze(1)
                loss = criterion(loss_pred, loss_true)
                # A fresh graph is built for every batch, so it need not be retained.
                loss.backward()
                optimizer.step()
                # Store detached copies: graph-attached tensors kept for every
                # epoch would hold every autograd graph alive.
                pred_train_lst[ri][fi][ei].append(loss_pred.detach().cpu())
                label_train_lst[ri][fi][ei].append(loss_true.detach().cpu())
                epoch_pred.append(y_pred_train.detach().cpu())
                epoch_true.append(y_batch_train.unsqueeze(1).cpu())
                epoch_losses.append(loss.item())

            train_pred = torch.cat(epoch_pred)
            train_true = torch.cat(epoch_true)
            # The zero-false-negative threshold is cheap; it is needed every
            # epoch for the threshold logs.
            train_zfn = get_metric('zfn', train_pred, train_true)
            train_zfn_th = train_zfn[0]
            train_zfn_th_lst[ri][fi][ei].append(train_zfn_th)

            # ---- validation ----
            model.eval()
            test_pred, test_true = [], []
            with torch.no_grad():
                for X_batch_test, y_batch_test, _, _ in test_loader:
                    y_pred_test = model(X_batch_test.to(device)).cpu()
                    pred_test_lst[ri][fi][ei].append(y_pred_test)
                    label_test_lst[ri][fi][ei].append(y_batch_test)
                    # The test zfn threshold is the training one.
                    test_zfn_th_lst[ri][fi][ei].append(train_zfn_th)
                    test_pred.append(y_pred_test)
                    test_true.append(y_batch_test.unsqueeze(1))

            if epoch == total_epochs:
                # Only the final epoch is reported, so the accuracy-optimal
                # ('std') threshold search runs once per fold.
                train_std = get_metric('std', train_pred, train_true)
                all_test_pred = torch.cat(test_pred)
                all_test_true = torch.cat(test_true)
                test_std = get_metric('std', all_test_pred, all_test_true, train_std[0])
                test_zfn = get_metric('zfn', all_test_pred, all_test_true, train_zfn_th)
                train_metrics[fi] = [fold, repeat, float(np.mean(epoch_losses)), *train_std, *train_zfn]
                test_metrics[fi] = [fold, repeat, '-', *test_std, *test_zfn]

    return (train_metrics, test_metrics, pred_train_lst, pred_test_lst,
            label_train_lst, label_test_lst, train_zfn_th_lst, test_zfn_th_lst)
        
def run_code(i_loop, current_file, total_epochs, neg_count, gamma, e_warmup, n_repeats=5):
    """Run ``n_repeats`` rounds of 5-fold CV for one hyper-parameter setting and log them.

    Each round calls ``cv`` on the module-level ``X``/``y``/``test_files``/
    ``statuses`` with ``repeat`` = 1..n_repeats. For the train split and then
    the test split:
    - every per-fold row is appended to ``detailed_file``;
    - the ``np.mean`` and ``np.std`` of each metric column are appended to
      ``mean_file`` and ``stdev_file``;
    - the mean line is printed.

    ``current_file`` is accepted but unused. ``total_epochs``, ``gamma``,
    ``e_warmup`` and ``neg_count`` are only written into the log lines here.

    Returns:
        The six per-epoch accumulators returned by the last ``cv`` call.
    """
    n_splits = 5
    pos_gt = np.sum(y == 1)
    neg_gt = np.sum(y == 0)
    train_metrics_repeat = []
    test_metrics_repeat = []

    for repeat in range(1, n_repeats + 1):
        cv_out = cv(X, y, test_files, statuses, repeat, n_splits=n_splits)
        fold_train, fold_test = cv_out[0], cv_out[1]
        (pred_train_lst, pred_test_lst, label_train_lst, label_test_lst,
         train_zfn_th_lst, test_zfn_th_lst) = cv_out[2:]
        train_metrics_repeat.append(fold_train)
        test_metrics_repeat.append(fold_test)

    # Log label for neg_count: 1 is written as-is, other values get a '%' suffix.
    if neg_count != '-':
        neg_count = str(neg_count) if neg_count == 1 else str(neg_count) + '%'

    def log_results(metric_list, split):
        """Append one split's per-fold rows plus their column mean/std to the log files."""
        # Flatten [repeat][fold] -> one row per fold.
        rows = [row for repeat_rows in metric_list for row in repeat_rows]

        def column(k):
            return [row[k] for row in rows]

        if split == 'train':
            losses = column(2)
            loss_mean, loss_std = np.mean(losses), np.std(losses)
        else:
            loss_mean = loss_std = '-'
        # Columns 3..26 hold the 12 'std' and 12 'zfn' metrics.
        stat_columns = range(3, 27)
        averages = [loss_mean] + [np.mean(column(k)) for k in stat_columns]
        std_devs = [loss_std] + [np.std(column(k)) for k in stat_columns]

        prefix = ",".join(str(value) for value in (
            i_loop, dataset, pos_gt, neg_gt, total_epochs,
            gamma, e_warmup, neg_count, split)) + ","

        def to_line(values):
            return prefix + ",".join(map(str, values)) + "\n"

        mean_line = to_line(averages)
        with open(mean_file, "a") as m, open(stdev_file, "a") as s, open(detailed_file, "a") as d:
            m.write(mean_line)
            s.write(to_line(std_devs))
            for row in rows:
                d.write(to_line(row))
        print(mean_line)

    log_results(train_metrics_repeat, 'train')
    log_results(test_metrics_repeat, 'test')
    return pred_train_lst, pred_test_lst, label_train_lst, label_test_lst, train_zfn_th_lst, test_zfn_th_lst

# ---------------------------------------------------------------------------
# Experiment driver: create the three log files, then run every
# hyper-parameter combination of the grid below.
# ---------------------------------------------------------------------------
i_loop = 0
feat_count = len(X[0])  # input width read by binaryClassification
# One timestamp for all three log names: three separate time() calls could
# straddle a second boundary and give the files mismatched names.
run_stamp = str(round(time()))
detailed_file = 'log_tapAUC_detail_' + run_stamp + '_' + dataset
mean_file = 'log_tapAUC_mean_' + run_stamp + '_' + dataset
stdev_file = 'log_tapAUC_stdev_' + run_stamp + '_' + dataset
detail_header = 'idx, dataset,pos,neg,total_epochs,gamma,e_warmup,neg_count,split,fold,repeat,loss,std_ths,std_TPs,std_FPs,std_TNs,std_FNs,std_accs,std_sensitivitys,std_specificitys,std_gmeans,std_precisions,std_f1scores,std_f2scores,zfn_ths,zfn_TPs,zfn_FPs,zfn_TNs,zfn_FNs,zfn_accs,zfn_sensitivitys,zfn_specificitys,zfn_gmeans,zfn_precisions,zfn_f1scores,zfn_f2scores,' + '\n'
summary_header = 'idx, dataset,pos,neg,total_epochs,gamma,e_warmup,neg_count,split,loss,std_ths,std_TPs,std_FPs,std_TNs,std_FNs,std_accs,std_sensitivitys,std_specificitys,std_gmeans,std_precisions,std_f1scores,std_f2scores,zfn_ths,zfn_TPs,zfn_FPs,zfn_TNs,zfn_FNs,zfn_accs,zfn_sensitivitys,zfn_specificitys,zfn_gmeans,zfn_precisions,zfn_f1scores,zfn_f2scores,' + '\n'
for current_file in [mean_file, stdev_file, detailed_file]:
    with open(current_file, "w") as f:
        f.write(detail_header if current_file == detailed_file else summary_header)


def _epoch_store(n_repeats, n_splits, total_epochs):
    """Return an empty [repeat][fold][epoch] -> list accumulator for cv()."""
    return [[[[] for _ in range(total_epochs)] for _ in range(n_splits)] for _ in range(n_repeats)]


for total_epochs in [500]:
    for gamma in [0.1]:
        for e_warmup in [0.25]:
            for neg_count in [0.05]:
                i_loop += 1
                start_hyper_time = datetime.datetime.now()
                n_repeats = 5
                n_splits = 5  # must match the n_splits hard-coded in run_code
                pred_train_lst = _epoch_store(n_repeats, n_splits, total_epochs)
                label_train_lst = _epoch_store(n_repeats, n_splits, total_epochs)
                train_zfn_th_lst = _epoch_store(n_repeats, n_splits, total_epochs)
                pred_test_lst = _epoch_store(n_repeats, n_splits, total_epochs)
                label_test_lst = _epoch_store(n_repeats, n_splits, total_epochs)
                test_zfn_th_lst = _epoch_store(n_repeats, n_splits, total_epochs)
                # Pass n_repeats explicitly so the run length matches the accumulators.
                (pred_train_lst, pred_test_lst, label_train_lst, label_test_lst,
                 train_zfn_th_lst, test_zfn_th_lst) = run_code(
                    i_loop, current_file, total_epochs, neg_count, gamma, e_warmup,
                    n_repeats=n_repeats)
                print('elapsed time :', datetime.datetime.now() - start_hyper_time)
exec_time = datetime.datetime.now() - start_time
print('exec time :', exec_time)