In [1]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import math
import sklearn.preprocessing as sk
import seaborn as sns
from sklearn import metrics
from sklearn.feature_selection import VarianceThreshold
from sklearn.model_selection import train_test_split
from utils import AllTripletSelector,HardestNegativeTripletSelector, RandomNegativeTripletSelector, SemihardNegativeTripletSelector # Strategies for selecting triplets within a minibatch
from metrics import AverageNonzeroTripletsMetric
from torch.utils.data.sampler import WeightedRandomSampler
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
import random
from random import randint
from sklearn.model_selection import StratifiedKFold

In [2]:
save_results_to = '/common/statsgeneral/gayara/MOLI/Pan_drug_2/results/'
torch.manual_seed(42)
random.seed(42)

GDSCE = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/GDSC_exprs.z.EGFRi.tsv", 
                    sep = "\t", index_col=0, decimal = ",")
GDSCE = pd.DataFrame.transpose(GDSCE)

GDSCM = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/GDSC_mutations.EGFRi.tsv", 
                    sep = "\t", index_col=0, decimal = ".")
GDSCM = pd.DataFrame.transpose(GDSCM)
GDSCM = GDSCM.loc[:,~GDSCM.columns.duplicated()]

GDSCC = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/GDSC_CNA.EGFRi.tsv", 
                    sep = "\t", index_col=0, decimal = ".")
GDSCC.drop_duplicates(keep='last')
GDSCC = pd.DataFrame.transpose(GDSCC)
GDSCC = GDSCC.loc[:,~GDSCC.columns.duplicated()]

PDXEerlo = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/PDX_exprs.Erlotinib.eb_with.GDSC_exprs.Erlotinib.tsv", 
                   sep = "\t", index_col=0, decimal = ",")
PDXEerlo = pd.DataFrame.transpose(PDXEerlo)
PDXMerlo = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/PDX_mutations.Erlotinib.tsv", 
                   sep = "\t", index_col=0, decimal = ",")
PDXMerlo = pd.DataFrame.transpose(PDXMerlo)
PDXCerlo = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/PDX_CNA.Erlotinib.tsv", 
                   sep = "\t", index_col=0, decimal = ",")
PDXCerlo.drop_duplicates(keep='last')
PDXCerlo = pd.DataFrame.transpose(PDXCerlo)
PDXCerlo = PDXCerlo.loc[:,~PDXCerlo.columns.duplicated()]

PDXEcet = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/PDX_exprs.Cetuximab.eb_with.GDSC_exprs.Cetuximab.tsv", 
                   sep = "\t", index_col=0, decimal = ",")
PDXEcet = pd.DataFrame.transpose(PDXEcet)
PDXMcet = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/PDX_mutations.Cetuximab.tsv", 
                   sep = "\t", index_col=0, decimal = ",")
PDXMcet = pd.DataFrame.transpose(PDXMcet)
PDXCcet = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/PDX_CNA.Cetuximab.tsv", 
                   sep = "\t", index_col=0, decimal = ",")
PDXCcet.drop_duplicates(keep='last')
PDXCcet = pd.DataFrame.transpose(PDXCcet)

In [3]:

PDXCcet = PDXCcet.loc[:,~PDXCcet.columns.duplicated()]

selector = VarianceThreshold(0.05)
selector.fit_transform(GDSCE)
GDSCE = GDSCE[GDSCE.columns[selector.get_support(indices=True)]]

GDSCM = GDSCM.fillna(0)
GDSCM[GDSCM != 0.0] = 1
GDSCC = GDSCC.fillna(0)
GDSCC[GDSCC != 0.0] = 1

ls = GDSCE.columns.intersection(GDSCM.columns)
ls = ls.intersection(GDSCC.columns)
ls = ls.intersection(PDXEerlo.columns)
ls = ls.intersection(PDXMerlo.columns)
ls = ls.intersection(PDXCerlo.columns)
ls = ls.intersection(PDXEcet.columns)
ls = ls.intersection(PDXMcet.columns)
ls = ls.intersection(PDXCcet.columns)
ls2 = GDSCE.index.intersection(GDSCM.index)
ls2 = ls2.intersection(GDSCC.index)
ls3 = PDXEerlo.index.intersection(PDXMerlo.index)
ls3 = ls3.intersection(PDXCerlo.index)
ls4 = PDXEcet.index.intersection(PDXMcet.index)
ls4 = ls4.intersection(PDXCcet.index)
ls = pd.unique(ls)

PDXEerlo = PDXEerlo.loc[ls3,ls]
PDXMerlo = PDXMerlo.loc[ls3,ls]
PDXCerlo = PDXCerlo.loc[ls3,ls]
PDXEcet = PDXEcet.loc[ls4,ls]
PDXMcet = PDXMcet.loc[ls4,ls]
PDXCcet = PDXCcet.loc[ls4,ls]
GDSCE = GDSCE.loc[:,ls]
GDSCM = GDSCM.loc[:,ls]
GDSCC = GDSCC.loc[:,ls]

GDSCR = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/GDSC_response.EGFRi.tsv", 
                    sep = "\t", index_col=0, decimal = ",")

GDSCR.rename(mapper = str, axis = 'index', inplace = True)

d = {"R":0,"S":1}
GDSCR["response"] = GDSCR.loc[:,"response"].apply(lambda x: d[x])

responses = GDSCR
drugs = set(responses["drug"].values)
exprs_z = GDSCE
cna = GDSCC
mut = GDSCM
expression_zscores = []
CNA=[]
mutations = []
for drug in drugs:
    samples = responses.loc[responses["drug"]==drug,:].index.values
    e_z = exprs_z.loc[samples,:]
    c = cna.loc[samples,:]
    m = mut.loc[samples,:]
    m = mut.loc[samples,:]
    # next 3 rows if you want non-unique sample names
    e_z.rename(lambda x : str(x)+"_"+drug, axis = "index", inplace=True)
    c.rename(lambda x : str(x)+"_"+drug, axis = "index", inplace=True)
    m.rename(lambda x : str(x)+"_"+drug, axis = "index", inplace=True)
    expression_zscores.append(e_z)
    CNA.append(c)
    mutations.append(m)
responses.index = responses.index.values +"_"+responses["drug"].values
GDSCEv2 = pd.concat(expression_zscores, axis =0 )
GDSCCv2 = pd.concat(CNA, axis =0 )
GDSCMv2 = pd.concat(mutations, axis =0 )
GDSCRv2 = responses

ls2 = GDSCEv2.index.intersection(GDSCMv2.index)
ls2 = ls2.intersection(GDSCCv2.index)
GDSCEv2 = GDSCEv2.loc[ls2,:]
GDSCMv2 = GDSCMv2.loc[ls2,:]
GDSCCv2 = GDSCCv2.loc[ls2,:]
GDSCRv2 = GDSCRv2.loc[ls2,:]

Y = GDSCRv2['response'].values

PDXRcet = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/PDX_response.Cetuximab.tsv", 
                       sep = "\t", index_col=0, decimal = ",")
PDXRcet.loc[PDXRcet.iloc[:,1] == 'R'] = 0
PDXRcet.loc[PDXRcet.iloc[:,1] == 'S'] = 1
PDXRcet = PDXRcet.loc[ls4,:]
Ytscet = PDXRcet['response'].values    

PDXRerlo = pd.read_csv("/common/statsgeneral/gayara/MOLI/Pan_drug_2/all_data/PDX_response.Erlotinib.tsv", 
                       sep = "\t", index_col=0, decimal = ",")
PDXRerlo.loc[PDXRerlo.iloc[:,1] == 'R'] = 0
PDXRerlo.loc[PDXRerlo.iloc[:,1] == 'S'] = 1
PDXRerlo = PDXRerlo.loc[ls3,:]
Ytserlo = PDXRerlo['response'].values  

hdm1 = 32
hdm2 = 16
hdm3 = 256
rate1 = 0.5
rate2 = 0.8
rate3 = 0.5
rate4 = 0.3
mbs = 16
mrg = 1.5
lre = 0.001
lrm = 0.0001
lrc = 5e-05
epch = 20
lam = 0.5
lrCL = 0.005
wd = 0.0001

In [4]:
scalerGDSC = sk.StandardScaler()
scalerGDSC.fit(GDSCEv2.values)
X_trainE = scalerGDSC.transform(GDSCEv2.values)
X_testEerlo = scalerGDSC.transform(PDXEerlo.values)    
X_testEcet = scalerGDSC.transform(PDXEcet.values) 
y_trainE = Y

X_trainM = np.nan_to_num(GDSCMv2.values)
X_trainC = np.nan_to_num(GDSCCv2.values)
X_testMerlo = np.nan_to_num(PDXMerlo.values)
X_testCerlo = np.nan_to_num(PDXCerlo.values)
X_testMcet = np.nan_to_num(PDXMcet.values)
X_testCcet = np.nan_to_num(PDXCcet.values)

TX_testEerlo = torch.FloatTensor(X_testEerlo)
TX_testMerlo = torch.FloatTensor(X_testMerlo)
TX_testCerlo = torch.FloatTensor(X_testCerlo)
ty_testEerlo = torch.FloatTensor(Ytserlo.astype(int))

TX_testEcet = torch.FloatTensor(X_testEcet)
TX_testMcet = torch.FloatTensor(X_testMcet)
TX_testCcet = torch.FloatTensor(X_testCcet)
ty_testEcet = torch.FloatTensor(Ytscet.astype(int))

#Train
class_sample_count = np.array([len(np.where(y_trainE==t)[0]) for t in np.unique(y_trainE)])
weight = 1. / class_sample_count
samples_weight = np.array([weight[t] for t in y_trainE])

samples_weight = torch.from_numpy(samples_weight)
sampler = WeightedRandomSampler(samples_weight.type('torch.DoubleTensor'), len(samples_weight), replacement=True)

mb_size = mbs

trainDataset = torch.utils.data.TensorDataset(torch.FloatTensor(X_trainE), torch.FloatTensor(X_trainM), 
                                                      torch.FloatTensor(X_trainC), torch.FloatTensor(y_trainE.astype(int)))

trainLoader = torch.utils.data.DataLoader(dataset = trainDataset, batch_size=mb_size, shuffle=False, num_workers=1, sampler = sampler)

n_sampE, IE_dim = X_trainE.shape
n_sampM, IM_dim = X_trainM.shape
n_sampC, IC_dim = X_trainC.shape

h_dim1 = hdm1
h_dim2 = hdm2
h_dim3 = hdm3        
Z_in = h_dim1 + h_dim2 + h_dim3
marg = mrg
lrE = lre
lrM = lrm
lrC = lrc
epoch = epch

costtr = []
auctr = []
costts = []
aucts = []

triplet_selector = RandomNegativeTripletSelector(marg)
triplet_selector2 = AllTripletSelector()


class AEE(nn.Module):
    def __init__(self):
        super(AEE, self).__init__()
        self.EnE = torch.nn.Sequential(
            nn.Linear(IE_dim, h_dim1),
            nn.BatchNorm1d(h_dim1),
            nn.ReLU(),
            nn.Dropout(rate1))
    def forward(self, x):
        output = self.EnE(x)
        return output

class AEM(nn.Module):
    def __init__(self):
        super(AEM, self).__init__()
        self.EnM = torch.nn.Sequential(
            nn.Linear(IM_dim, h_dim2),
            nn.BatchNorm1d(h_dim2),
            nn.ReLU(),
            nn.Dropout(rate2))
    def forward(self, x):
        output = self.EnM(x)
        return output    


class AEC(nn.Module):
    def __init__(self):
        super(AEC, self).__init__()
        self.EnC = torch.nn.Sequential(
            nn.Linear(IM_dim, h_dim3),
            nn.BatchNorm1d(h_dim3),
            nn.ReLU(),
            nn.Dropout(rate3))
    def forward(self, x):
        output = self.EnC(x)
        return output       

class OnlineTriplet(nn.Module):
    def __init__(self, marg, triplet_selector):
        super(OnlineTriplet, self).__init__()
        self.marg = marg
        self.triplet_selector = triplet_selector
    def forward(self, embeddings, target):
        triplets = self.triplet_selector.get_triplets(embeddings, target)
        return triplets

class OnlineTestTriplet(nn.Module):
    def __init__(self, marg, triplet_selector):
        super(OnlineTestTriplet, self).__init__()
        self.marg = marg
        self.triplet_selector = triplet_selector
    def forward(self, embeddings, target):
        triplets = self.triplet_selector.get_triplets(embeddings, target)
        return triplets    

class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.FC = torch.nn.Sequential(
            nn.Linear(Z_in, 1),
            nn.Dropout(rate4),
            nn.Sigmoid())
    def forward(self, x):
        return self.FC(x)

torch.cuda.manual_seed_all(42)


AutoencoderE = AEE()
AutoencoderM = AEM()
AutoencoderC = AEC()

solverE = optim.Adagrad(AutoencoderE.parameters(), lr=lrE)
solverM = optim.Adagrad(AutoencoderM.parameters(), lr=lrM)
solverC = optim.Adagrad(AutoencoderC.parameters(), lr=lrC)

trip_criterion = torch.nn.TripletMarginLoss(margin=marg, p=2)
TripSel = OnlineTriplet(marg, triplet_selector)
TripSel2 = OnlineTestTriplet(marg, triplet_selector2)

Clas = Classifier()
SolverClass = optim.Adagrad(Clas.parameters(), lr=lrCL, weight_decay = wd)
C_loss = torch.nn.BCELoss()

for it in range(epoch):

    epoch_cost4 = 0
    epoch_cost3 = []
    num_minibatches = int(n_sampE / mb_size) 

    for i, (dataE, dataM, dataC, target) in enumerate(trainLoader):
        flag = 0
        AutoencoderE.train()
        AutoencoderM.train()
        AutoencoderC.train()
        Clas.train()

        if torch.mean(target)!=0. and torch.mean(target)!=1.: 
            ZEX = AutoencoderE(dataE)
            ZMX = AutoencoderM(dataM)
            ZCX = AutoencoderC(dataC)

            ZT = torch.cat((ZEX, ZMX, ZCX), 1)
            ZT = F.normalize(ZT, p=2, dim=0)
            Pred = Clas(ZT)

            Triplets = TripSel2(ZT, target)
            loss = lam * trip_criterion(ZT[Triplets[:,0],:],ZT[Triplets[:,1],:],ZT[Triplets[:,2],:]) + C_loss(Pred,target.view(-1,1))     

            y_true = target.view(-1,1)
            y_pred = Pred
            AUC = roc_auc_score(y_true.detach().numpy(),y_pred.detach().numpy()) 

            solverE.zero_grad()
            solverM.zero_grad()
            solverC.zero_grad()
            SolverClass.zero_grad()

            loss.backward()

            solverE.step()
            solverM.step()
            solverC.step()
            SolverClass.step()

            epoch_cost4 = epoch_cost4 + (loss / num_minibatches)
            epoch_cost3.append(AUC)
            flag = 1

    if flag == 1:
        costtr.append(torch.mean(epoch_cost4))
        auctr.append(np.mean(epoch_cost3))
        print('Iter-{}; Total loss: {:.4}'.format(it, loss))

Iter-0; Total loss: 1.317
Iter-1; Total loss: 1.304
Iter-2; Total loss: 1.504
Iter-3; Total loss: 0.879
Iter-4; Total loss: 1.19
Iter-5; Total loss: 0.9562
Iter-6; Total loss: 1.114
Iter-7; Total loss: 0.9934
Iter-8; Total loss: 1.384
Iter-9; Total loss: 0.8135
Iter-10; Total loss: 1.603
Iter-11; Total loss: 1.105
Iter-12; Total loss: 1.181
Iter-13; Total loss: 1.341
Iter-14; Total loss: 1.111
Iter-15; Total loss: 1.034
Iter-16; Total loss: 1.149
Iter-17; Total loss: 0.9583
Iter-18; Total loss: 1.432
Iter-19; Total loss: 1.166


In [5]:
with torch.no_grad():

    AutoencoderE.eval()
    AutoencoderM.eval()
    AutoencoderC.eval()
    Clas.eval()
    
    ZEX = AutoencoderE(torch.FloatTensor(X_trainE))
    ZMX = AutoencoderM(torch.FloatTensor(X_trainM))
    ZCX = AutoencoderC(torch.FloatTensor(X_trainC))
    ZTX = torch.cat((ZEX, ZMX, ZCX), 1)
    ZTX = F.normalize(ZTX, p=2, dim=0)
    PredX = Clas(ZTX)
    AUCt = roc_auc_score(Y, PredX.detach().numpy())
    print(AUCt)

0.95274047404142


In [6]:
with torch.no_grad():

    AutoencoderE.eval()
    AutoencoderM.eval()
    AutoencoderC.eval()
    Clas.eval()

    ZETerlo = AutoencoderE(TX_testEerlo)
    ZMTerlo = AutoencoderM(TX_testMerlo)
    ZCTerlo = AutoencoderC(TX_testCerlo)
    ZTTerlo = torch.cat((ZETerlo, ZMTerlo, ZCTerlo), 1)
    ZTTerlo = F.normalize(ZTTerlo, p=2, dim=0)
    PredTerlo = Clas(ZTTerlo)
    Ytserlo = Ytserlo.astype('int64')
    AUCterlo = roc_auc_score(Ytserlo, PredTerlo.detach().numpy())
    print(AUCterlo)


0.7222222222222223


In [7]:
with torch.no_grad():

    AutoencoderE.eval()
    AutoencoderM.eval()
    AutoencoderC.eval()
    Clas.eval()

    ZETcet = AutoencoderE(TX_testEcet)
    ZMTcet = AutoencoderM(TX_testMcet)
    ZCTcet = AutoencoderC(TX_testCcet)
    ZTTcet = torch.cat((ZETcet, ZMTcet, ZCTcet), 1)
    ZTTcet = F.normalize(ZTTcet, p=2, dim=0)
    PredTcet = Clas(ZTTcet)
    Ytscet = Ytscet.astype('int64')
    AUCtcet = roc_auc_score(Ytscet, PredTcet.detach().numpy())
    print(AUCtcet)

0.7854545454545454
