In [8]:
import torch
from torch.autograd import Variable
import torch.nn.functional as F
from torch.utils import data
from torch import nn 
import copy

from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from time import time
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, roc_curve, confusion_matrix, precision_score, recall_score, auc
from sklearn.model_selection import KFold
torch.manual_seed(1)    # reproducible torch:2 np:3
np.random.seed(1)

from config import BIN_config_DBPE
from models import BIN_Interaction_Flat
from stream import BIN_Data_Encoder

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [None]:
def test(data_generator, model):
    y_pred = []
    y_label = []
    model.eval()
    loss_accumulate = 0.0
    count = 0.0
    for i, (d, p, d_mask, p_mask, label) in enumerate(data_generator):
        # move inputs to device before forward
        score = model(d.long().to(device), p.long().to(device), d_mask.long().to(device), p_mask.long().to(device))
        
        m = torch.nn.Sigmoid()
        logits = torch.squeeze(m(score))
        loss_fct = torch.nn.BCELoss()            
        
        label = Variable(torch.from_numpy(np.array(label)).float()).to(device)

        loss = loss_fct(logits, label)
        
        loss_accumulate += loss
        count += 1
        
        logits = logits.detach().cpu().numpy()
        
        label_ids = label.to('cpu').numpy()
        y_label = y_label + label_ids.flatten().tolist()
        y_pred = y_pred + logits.flatten().tolist()
        
    loss = loss_accumulate/count
    
    fpr, tpr, thresholds = roc_curve(y_label, y_pred)

    try:
        precision = tpr / (tpr + fpr)
    except:
        print("Precision error: tpr: ",tpr,", fpr: ", fpr)
        precision = tpr / (tpr + fpr + 0.00001) 

    f1 = 2 * precision * tpr / (tpr + precision + 0.00001)

    thred_optim = thresholds[5:][np.argmax(f1[5:])]

    print("optimal threshold: " + str(thred_optim))

    y_pred_s = [1 if i else 0 for i in (y_pred >= thred_optim)]

    auc_k = auc(fpr, tpr)
    print("AUROC:" + str(auc_k))
    print("AUPRC: "+ str(average_precision_score(y_label, y_pred)))


    pre = precision_score(y_label, y_pred_s)
    rec = recall_score(y_label, y_pred_s)
    cm1 = confusion_matrix(y_label, y_pred_s)
    print('Confusion Matrix : \n-', cm1)
    print('Recall : ', recall_score(y_label, y_pred_s))
    print('Precision : ', precision_score(y_label, y_pred_s))

    total1=sum(sum(cm1))
    #####from confusion matrix calculate accuracy
    accuracy1=(cm1[0,0]+cm1[1,1])/total1
    print ('Accuracy : ', accuracy1)

    sensitivity1 = cm1[0,0]/(cm1[0,0]+cm1[0,1])
    print('Sensitivity : ', sensitivity1 )

    specificity1 = cm1[1,1]/(cm1[1,0]+cm1[1,1])
    print('Specificity : ', specificity1)

    outputs = np.asarray([1 if i else 0 for i in (np.asarray(y_pred) >= 0.5)])
    return accuracy1, pre, rec, roc_auc_score(y_label, y_pred), average_precision_score(y_label, y_pred), f1_score(y_label, outputs), y_pred, loss.item()


def main(fold_n, lr):
    config = BIN_config_DBPE()
    
    lr = lr
    BATCH_SIZE = config['batch_size']
    train_epoch = 100
    
    loss_history = []
    
    model = BIN_Interaction_Flat(**config)
    model = model.to(device)
    
    if use_cuda and torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model, dim = 0)
            
    opt = torch.optim.Adam(model.parameters(), lr = lr)
    #opt = torch.optim.SGD(model.parameters(), lr = lr, momentum=0.9)
    
    print('--- Data Preparation ---')
    
    params = {'batch_size': BATCH_SIZE,
              'shuffle': True,
              'num_workers': 6, 
              'drop_last': True}

    # dataFolder = './dataset/BindingDB'
    dataFolder = "./dataset/b_cancer/splitted"
    df_train = pd.read_csv(dataFolder + '/train.csv')
    df_val = pd.read_csv(dataFolder + '/val.csv')
    df_test = pd.read_csv(dataFolder + '/test.csv')
    
    training_set = BIN_Data_Encoder(df_train.index.values, df_train.Label.values, df_train)
    training_generator = data.DataLoader(training_set, **params)

    validation_set = BIN_Data_Encoder(df_val.index.values, df_val.Label.values, df_val)
    validation_generator = data.DataLoader(validation_set, **params)
    
    testing_set = BIN_Data_Encoder(df_test.index.values, df_test.Label.values, df_test)
    testing_generator = data.DataLoader(testing_set, **params)
    
    # early stopping
    max_auc = 0
    model_max = copy.deepcopy(model)
    
    print('--- Go for Training ---')
    torch.backends.cudnn.benchmark = True

    resultfile = open('results/b_cancer.txt', 'w')
    resultfile.write('Epoch,AUROC,Accuracy,Precision,Recall,AUPRC,F1\n')
    
    for epo in range(train_epoch):
        model.train()
        for i, (d, p, d_mask, p_mask, label) in enumerate(training_generator):
            # move inputs to device before forward
            score = model(d.long().to(device), p.long().to(device), d_mask.long().to(device), p_mask.long().to(device))

            label = Variable(torch.from_numpy(np.array(label)).float()).to(device)
            
            loss_fct = torch.nn.BCELoss()
            m = torch.nn.Sigmoid()
            n = torch.squeeze(m(score))
            
            loss = loss_fct(n, label)
            loss_history.append(loss)
            
            opt.zero_grad()
            loss.backward()
            opt.step()
            
            if (i % 100 == 0):
                print('Training at Epoch ' + str(epo + 1) + ' iteration ' + str(i) + ' with loss ' + str(loss.cpu().detach().numpy()))
            
        # every epoch test
        with torch.set_grad_enabled(False):
            acc, pre, rec, auc, auprc, f1, logits, loss = test(validation_generator, model)
            if auc > max_auc:
                model_max = copy.deepcopy(model)
                max_auc = auc
            
            print('Validation at Epoch '+ str(epo + 1) + ' :: , AUROC: '+ str(auc) + ' ',' Accuracy: ',acc,' , Precision: ',pre,' , Recall: ',rec,' , AUPRC: ' + str(auprc) + ' , F1: '+str(f1))
            resultfile.write(f'{epo + 1},{auc},{acc},{pre},{rec},{auprc},{f1}\n')
    resultfile.close()
    
    print('--- Go for Testing ---')
    try:
        with torch.set_grad_enabled(False):
            acc, pre, rec, auc, auprc, f1, logits, loss = test(testing_generator, model_max)
            print('Testing :: ','Accuracy: ',acc,' , Precision: ',pre,' , Recall: ',rec,' , AUROC: ' + str(auc) + ' , AUPRC: ' + str(auprc) + ' , F1: '+str(f1) + ' , Test loss: '+str(loss))
    except:
        print('testing failed')
    return model_max, loss_history

In [12]:
# fold 1
#biosnap interaction times 1e-6, flat, batch size 64, len 205, channel 3, epoch 50
s = time()
model_max, loss_history = main(1, 5e-6)
e = time()
print(e-s)
lh = list(filter(lambda x: x < 1, loss_history))
plt.plot(lh)

--- Data Preparation ---
--- Go for Training ---
Training at Epoch 1 iteration 0 with loss 0.7316459
Training at Epoch 1 iteration 0 with loss 0.7316459


  precision = tpr / (tpr + fpr)


optimal threshold: 0.42815282940864563
AUROC:0.7445312500000001
AUPRC: 0.7617311964885616
Confusion Matrix : 
- [[47 33]
 [17 63]]
Recall :  0.7875
Precision :  0.65625
Accuracy :  0.6875
Sensitivity :  0.5875
Specificity :  0.7875
Validation at Epoch 1 :: , AUROC: 0.7445312500000001   Accuracy:  0.6875  , Precision:  0.65625  , Recall:  0.7875  , AUPRC: 0.7617311964885616 , F1: 0.6266666666666667
Validation at Epoch 1 :: , AUROC: 0.7445312500000001   Accuracy:  0.6875  , Precision:  0.65625  , Recall:  0.7875  , AUPRC: 0.7617311964885616 , F1: 0.6266666666666667
Training at Epoch 2 iteration 0 with loss 0.6334496
Training at Epoch 2 iteration 0 with loss 0.6334496


  precision = tpr / (tpr + fpr)


optimal threshold: 0.550786554813385
AUROC:0.7590248476324425
AUPRC: 0.7787297335383943
Confusion Matrix : 
- [[61 18]
 [24 57]]
Recall :  0.7037037037037037
Precision :  0.76
Accuracy :  0.7375
Sensitivity :  0.7721518987341772
Specificity :  0.7037037037037037
Validation at Epoch 2 :: , AUROC: 0.7590248476324425   Accuracy:  0.7375  , Precision:  0.76  , Recall:  0.7037037037037037  , AUPRC: 0.7787297335383943 , F1: 0.7111111111111111
Validation at Epoch 2 :: , AUROC: 0.7590248476324425   Accuracy:  0.7375  , Precision:  0.76  , Recall:  0.7037037037037037  , AUPRC: 0.7787297335383943 , F1: 0.7111111111111111
Training at Epoch 3 iteration 0 with loss 0.7008484
Training at Epoch 3 iteration 0 with loss 0.7008484


KeyboardInterrupt: 