In [1]:
import os, tqdm
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from   torch.utils.data import TensorDataset, DataLoader
from   torch.utils.tensorboard import SummaryWriter


from scripts.Evaluator     import evaluator
from scripts.Generator     import generatorNet
from scripts.Discriminator import ensembleNet


In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('the device is %s' % device)

the device is cpu


## Data preprocess

In [3]:
shuffle_seed           = 37
oncogenic_variant_size = 500
benign_variant_size    = 500
batchsize              = 500

In [4]:
training_unlabeled = pd.read_csv("example/training_data/unlabeled_training.dat")
training_labeled   = pd.read_csv("example/training_data/labeled_training.dat")
training_labeled.head(2)

Unnamed: 0,SIFT_score,Polyphen2_HDIV_score,Polyphen2_HVAR_score,LRT_score,MutationTaster_score,MutationAssessor_score,FATHMM_score,PROVEAN_score,VEST3_score,CADD_raw,...,evs_10_2,evs_11_-1,evs_11_0,evs_11_1,evs_11_2,evs_12_-1,evs_12_0,evs_12_1,evs_12_2,true_label
0,0.012,0.996,0.877,0.0,0.967,0.631809,0.617818,0.497366,0.860861,0.42625,...,0.070834,0.098091,0.938011,0.073036,0.070834,0.098091,0.938011,0.073036,0.070834,0
1,0.003,0.79,0.365,0.694,1.0,0.491245,0.672773,0.542581,0.067067,0.442278,...,0.077686,0.106276,0.091509,0.916682,0.077686,0.106276,0.091509,0.916682,0.077686,0


In [5]:
## labeled
training_labeled = training_labeled.values

np.random.seed(shuffle_seed)
np.random.shuffle(training_labeled)

labeled_features, labeled_targets = training_labeled[:, :-1], training_labeled[:, -1]

training_features   = np.vstack( (labeled_features[labeled_targets == 1][:oncogenic_variant_size], 
                                  labeled_features[labeled_targets == 0][:benign_variant_size]) )

validation_features = np.vstack( (labeled_features[labeled_targets == 1][oncogenic_variant_size:], 
                                  labeled_features[labeled_targets == 0][benign_variant_size:]) )


training_targets    = np.hstack( (labeled_targets[labeled_targets == 1][:oncogenic_variant_size], 
                                  labeled_targets[labeled_targets == 0][:benign_variant_size]) )

validation_targets  = np.hstack( (labeled_targets[labeled_targets == 1][oncogenic_variant_size:], 
                                  labeled_targets[labeled_targets == 0][benign_variant_size:]) )


## unlabeled
unlabeled_features = training_unlabeled.values


In [6]:
tensor_dat       = TensorDataset(torch.Tensor(training_features[:, np.newaxis, :]), torch.Tensor(training_targets).long())
training_batch   = DataLoader(dataset=tensor_dat, batch_size = batchsize, shuffle=True)

tensor_dat       = TensorDataset(torch.Tensor(validation_features[:, np.newaxis, :]), torch.Tensor(validation_targets).long())
validation_batch = DataLoader(dataset=tensor_dat, batch_size = batchsize, shuffle=False)

tensor_dat       = TensorDataset(torch.Tensor(unlabeled_features[:, np.newaxis, :]))
unlabeled_batch  = DataLoader(dataset=tensor_dat, batch_size = batchsize, shuffle=False)

## Model settings

In [7]:
discrminator  = ensembleNet().to(device)
generator     = generatorNet().to(device)
cross_entropy = nn.CrossEntropyLoss()
optimizerDis  = optim.AdamW(discrminator.parameters(), lr = 0.01)
optimizerGen  = optim.AdamW(generator.parameters(), lr = 0.01)

schedulerDis  = optim.lr_scheduler.LambdaLR(optimizerDis, lambda epoch: 0.9**epoch)
schedulerGen  = optim.lr_scheduler.LambdaLR(optimizerGen, lambda epoch: 0.9**epoch)

In [8]:
writer = SummaryWriter(comment="my_test")

num_epochs  = 10
global_step = 1
iter_labeled = iter(training_batch)


for epoch in tqdm.tqdm(range(num_epochs)):
    for step, x_unlabeled in enumerate(unlabeled_batch):

        #################################################################################  Classifier/Discriminator 
        discrminator.train()
        generator.eval()
        
        optimizerDis.zero_grad()
        
        ## label
        try:
            x_labeled, y_labeled = next(iter_labeled)
            x_labeled, y_labeled = x_labeled.to(device), y_labeled.to(device)
        except StopIteration:
            iter_labeled = iter(training_batch)
            x_labeled, y_labeled = next(iter_labeled)
            x_labeled, y_labeled = x_labeled.to(device), y_labeled.to(device)
        
        
        _, outClassLabeled  = discrminator(x_labeled)
        lossLabeled      = cross_entropy(outClassLabeled, y_labeled)
        
        
        ## unlabel
        x_unlabeled = x_unlabeled[0].to(device)
        _, outClassUnlabeled  = discrminator(x_unlabeled)
        
        logz_unlabeled = torch.logsumexp(outClassUnlabeled, dim=1)
        lossUnlabeled  = -0.5 * torch.mean(logz_unlabeled) + 0.5 * torch.mean(F.softplus(logz_unlabeled))
        
        ## Fake
        fakeNoise1       = torch.randn(x_unlabeled.size(0), 30, device=device)
        x_Fake1          = ( generator(fakeNoise1) + 1.0 ) / 2
        _, outClassFake1 = discrminator(x_Fake1)

        logz_fake1 = torch.logsumexp(outClassFake1, dim=1)
        lossFake  = 0.5 * torch.mean(F.softplus(logz_fake1))
        

        ## loss
        totalLoss = lossLabeled + lossUnlabeled + lossFake
        
        ## optimization
        writer.add_scalar("training_loss/supervised", lossLabeled, global_step)
        writer.add_scalar("training_loss/unsupervised", lossUnlabeled+lossFake, global_step)
        writer.add_scalar("training_loss/Discriminator", totalLoss, global_step)

        totalLoss.backward()
        optimizerDis.step()
        
                
        #################################################################################  Generator
        discrminator.eval()
        generator.train()
        optimizerGen.zero_grad()
        
        fakeNoise2 = torch.randn(x_unlabeled.size(0), 30, device=device)
        x_Fake2    = ( generator(fakeNoise2) + 1.0 ) / 2
        
        ## loss
        y_pred_unlabeled, _ = discrminator(x_unlabeled)
        y_pred_fake, _      = discrminator(x_Fake2)
        mom_real = torch.mean(y_pred_unlabeled, dim=0)
        mom_fake = torch.mean(y_pred_fake, dim=0)
        diff = mom_fake * 100 - mom_real * 100
        lossG = torch.mean(diff * diff)
        
        
        ## optimization
        writer.add_scalar("training_loss/Generator", lossG, global_step)
        lossG.backward()        
        optimizerGen.step()

        global_step += 1
  
    
    
    training_loss, training_accuracy = evaluator(discrminator, cross_entropy, device, training_batch, False)
    writer.add_scalar("accuracy/training",  training_accuracy, epoch)
    
    validation_loss, validation_accuracy = evaluator(discrminator, cross_entropy, device, validation_batch, False)
    writer.add_scalar("accuracy/validation", validation_accuracy, epoch)
    writer.flush()

writer.close()

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [08:40<00:00, 52.00s/it]


In [10]:
# torch.save(discrminator.state_dict(), './model_save/%s.pt' % prefix)

### 

## the performance of testing data

In [9]:
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.metrics import f1_score 
import math

def myEval(model, device, test_loader, display = False):
    model.eval()
    
    target_list = []
    output_list = []
    
    with torch.no_grad():
        for testdata in test_loader:
            data, target = testdata
            data, target = data.to(device), target.to(device)
            _, output = model(data)
            softmax2_score = [ math.exp(i[1]) / ( math.exp(i[0]) + math.exp(i[1]) ) for i in output.cpu().numpy() ]
            target_list += target.cpu().tolist()
            output_list += softmax2_score

    return target_list, output_list


def evaluation_df(pred_score, labeled_y):
    def TP_table(pred_score, labeled_y, threshold):
        y_pred = [0 if i < threshold else 1 for i in pred_score]
        y_true = list(labeled_y)


        fpr, tpr, _ = roc_curve(y_true, pred_score)
        auc_val = auc(fpr, tpr)


        tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
        # TP TN FP FN sensitivity specificity Accuracy
        sensitivity = tp/(tp+fn)

        specificity = tn/(tn+fp)
        accuracy    = (tp+tn)/(tp+tn+fp+fn)
        
        F1 = f1_score(y_true, y_pred)
        
        try:
            MCC = ((tp*tn)-(fp*fn)) / ((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn))**0.5
        except:
            MCC = np.nan

        return [threshold, tp, fp, tn, fn, sensitivity, specificity, accuracy, auc_val, MCC, F1]

    res = []
    for i in range(1,20):
        threshold = i / 20
        res.append(TP_table(pred_score, labeled_y, threshold))

    res = pd.DataFrame(res, columns=['threshold', 'TP', 'FP', 'TN', 'FN', 'sen', 'spe', 'Acc', 'AUC', 'MCC', 'F1'])
    return res


In [11]:
testing_labeled = pd.read_csv("example/training_data/labeled_testing.dat").values
labeled_features, labeled_targets = testing_labeled[:, :-1], testing_labeled[:, -1]


tensor_dat = TensorDataset(torch.Tensor(labeled_features[:, np.newaxis, :]), torch.Tensor(labeled_targets).long())
testing_batch = DataLoader(dataset=tensor_dat, batch_size = batchsize, shuffle=False)

label, pred = myEval(discrminator, device, testing_batch)
evaluation_df(pred, label)

Unnamed: 0,threshold,TP,FP,TN,FN,sen,spe,Acc,AUC,MCC,F1
0,0.05,1284,3721,1170,51,0.961798,0.239215,0.394154,0.800737,0.20778,0.405047
1,0.1,1278,3518,1373,57,0.957303,0.28072,0.425795,0.800737,0.232247,0.416898
2,0.15,1265,3375,1516,70,0.947566,0.309957,0.446675,0.800737,0.242574,0.423431
3,0.2,1255,3266,1625,80,0.940075,0.332243,0.462576,0.800737,0.250632,0.42862
4,0.25,1250,3181,1710,85,0.93633,0.349622,0.475426,0.800737,0.259089,0.433576
5,0.3,1244,3115,1776,91,0.931835,0.363116,0.485063,0.800737,0.264195,0.436951
6,0.35,1239,3053,1838,96,0.92809,0.375792,0.494218,0.800737,0.269517,0.440377
7,0.4,1234,3001,1890,101,0.924345,0.386424,0.501767,0.800737,0.273473,0.443088
8,0.45,1226,2955,1936,109,0.918352,0.395829,0.50787,0.800737,0.274557,0.444525
9,0.5,1219,2904,1987,116,0.913109,0.406256,0.514937,0.800737,0.277141,0.446684
