In [1]:
from __future__ import absolute_import, division
from tqdm.auto import tqdm
import torch
import torch.nn.functional as F
from trainingUtils import MultipleOptimizer, MultipleScheduler, compute_kernel, compute_mmd
from models import Decoder,Encoder, SimpleEncoder,LocalDiscriminator,PriorDiscriminator,Classifier,SpeciesCovariate
from buildFrame import GlobalAEFramework
import math
import numpy as np
import pandas as pd
import sys
import random
import os
from sklearn.metrics import silhouette_score,confusion_matrix,f1_score
from scipy.stats import spearmanr
from evaluationUtils import r_square,get_cindex,pearson_r,pseudoAccuracy

from matplotlib import pyplot as plt
import seaborn as sns
sns.set()

# Load all data

In [2]:
device = torch.device('cuda')
# Initialize environment and seeds for reproducability
torch.backends.cudnn.benchmark = True
def seed_everything(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.benchmark = False

In [3]:
cmap1 = pd.read_csv('../preprocessing/preprocessed_data/cmap_landmarks_HT29_A375.csv',index_col = 0)
cmap2 = pd.read_csv('../preprocessing/preprocessed_data/cmap_landmarks_HA1E_PC3.csv',index_col = 0)
cmap = pd.concat((cmap1,cmap2),axis=0)
cmap

Unnamed: 0,16,23,25,30,39,47,102,128,142,154,...,94239,116832,124583,147179,148022,200081,200734,256364,375346,388650
PCL001_HT29_24H:BRD-K42991516:10,0.266452,-0.250874,-0.854204,-0.041545,0.204450,0.709800,-0.328601,-0.498116,-1.454481,0.506321,...,0.536235,0.024452,0.928558,-0.453246,-0.140290,0.205065,1.148706,-1.933820,1.966937,-0.159919
PCL001_HT29_24H:BRD-K50817946:10,6.074023,-0.524075,-0.635742,2.014629,-3.747274,2.109600,0.847576,-2.732549,-5.729352,2.164091,...,0.447939,1.543649,-3.775020,1.827991,-0.088051,0.382848,1.400255,-3.087269,1.392148,1.027263
PCL001_HT29_24H:BRD-K58479490:10,4.145089,-0.881727,-1.720977,1.636901,1.614980,0.092948,0.711952,-0.088671,-1.531390,-0.591393,...,-0.312573,0.095138,2.229333,0.250220,1.523056,-0.394704,-0.167089,0.833252,0.325481,-0.652675
DOSBIO001_A375_24H:BRD-K72343629:10.1316,1.545521,1.061800,1.165320,-1.052685,-3.449826,0.503872,1.850187,-0.426328,0.004190,1.948446,...,-1.748134,-0.636907,1.142301,1.178548,5.263110,-0.141872,-1.490323,0.526244,1.637315,-0.829246
PCL001_HT29_24H:BRD-A89859721:0.12,-0.063263,0.358551,-0.024186,0.695202,-2.394504,0.329883,-0.117662,-0.779904,0.439334,3.228897,...,-0.474038,-0.143447,1.764741,1.436673,0.602154,1.120865,-0.255665,0.316766,0.717193,-0.772269
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
DOSVAL003_HA1E_24H:BRD-K99481965:20,1.228643,2.201454,1.268494,-2.925642,-5.702180,-1.048154,-0.511689,-1.073713,-0.987551,0.088006,...,-1.521899,-1.388441,0.738075,-0.532254,-0.637348,-0.950856,1.951221,1.633684,0.036308,1.814479
DOSVAL001_HA1E_24H:BRD-K25712227:10.0073,1.240146,-0.854687,-0.066887,-0.824974,-3.184588,-1.167342,-1.893655,-0.301650,-2.429506,2.668036,...,-1.902034,-0.527167,1.294609,-0.512041,1.322372,1.399906,1.182881,0.106317,-0.632841,-1.876167
DOSVAL001_PC3_24H:BRD-K60763357:10.036,5.411014,-0.472251,-0.512716,1.804495,0.990145,0.694528,0.157606,-1.507224,-6.045609,-3.202184,...,-0.563177,-2.870153,1.764212,-0.991561,3.929906,1.254488,-0.011530,-1.532616,-0.706827,0.888209
DOSVAL002_PC3_24H:BRD-K58415436:20,4.105516,0.305956,2.209438,-0.029295,-1.663379,-0.174490,2.631207,-1.057758,-1.062944,1.510586,...,-0.394032,-0.496681,1.682302,-0.722316,0.301925,0.757042,1.770458,-0.691024,-0.490237,0.569134


In [4]:
gene_size = len(cmap.columns)
samples = cmap.index.values

In [5]:
# Create a train generators
def getSamples(N, batchSize):
    order = np.random.permutation(N)
    outList = []
    while len(order)>0:
        outList.append(order[0:batchSize])
        order = order[batchSize:]
    return outList

In [6]:
def compute_gradients(output, input):
    grads = torch.autograd.grad(output, input, create_graph=True)
    grads = grads[0].pow(2).mean()
    return grads

# Initialize parameters and models

In [7]:
model_params = {'encoder_hiddens':[640,384],
                'latent_dim': 292,
                'decoder_hiddens':[384,640],
                'dropout_decoder':0.2,
                'dropout_encoder':0.1,
                'encoder_activation':torch.nn.ELU(),
                'decoder_activation':torch.nn.ELU(),
                'V_dropout':0.25,
                'state_class_hidden':[256,128,64],
                'state_class_drop_in':0.5,
                'state_class_drop':0.25,
                'no_states':4,
                'adv_class_hidden':[256,128,64],
                'adv_class_drop_in':0.3,
                'adv_class_drop':0.1,
                'no_adv_class':4,
                'encoding_lr':0.01,
                'adv_lr':0.001,
                'schedule_step_adv':200,
                'gamma_adv':0.5,
                'schedule_step_enc':400,
                'gamma_enc':0.8,
                'batch_size_1':178,
                'batch_size_2':154,
                'batch_size_paired':90,
                'epochs':2000,
                'prior_beta':1.0,
                'no_folds':10,
                'v_reg':1e-04,
                'state_class_reg':1e-04,
                'enc_l2_reg':0.00001,
                'dec_l2_reg':0.01,
                'lambda_mi_loss':100,
                'effsize_reg': 100,
                'cosine_loss': 10,
                'adv_penalnty':100,
                'reg_adv':1000,
                'reg_classifier': 1000,
                'similarity_reg' : 10,
                'adversary_steps':4,
                'autoencoder_wd': 0.,
                'adversary_wd': 0.}

In [8]:
class_criterion = torch.nn.CrossEntropyLoss()
NUM_EPOCHS= model_params['epochs']
bs_1 = model_params['batch_size_1']
bs_2 =  model_params['batch_size_2']
bs_paired =  model_params['batch_size_paired']

## Pre-train adverse classifier

In [9]:
# globalFrame = GlobalAEFramework(model_params['no_states'],[[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]],
#                                 [gene_size,gene_size,gene_size,gene_size],
#                                 [model_params['encoder_hiddens'],model_params['encoder_hiddens'],model_params['encoder_hiddens'],model_params['encoder_hiddens']],
#                                 model_params['latent_dim'], 
#                                 [model_params['decoder_hiddens'],model_params['decoder_hiddens'],model_params['decoder_hiddens'],model_params['decoder_hiddens']] ,
#                                 dropRateEnc=model_params['dropout_encoder'], activationEnc=model_params['encoder_activation'],
#                                 dropRateDec=model_params['dropout_decoder'], activationDec=model_params['decoder_activation'],
#                                 covariateDrop = model_params['V_dropout']).to(device)

prior_d = PriorDiscriminator(model_params['latent_dim']).to(device)
local_d = LocalDiscriminator(model_params['latent_dim'],model_params['latent_dim']).to(device)

adverse_classifier = Classifier(in_channel=model_params['latent_dim'],
                                hidden_layers=model_params['adv_class_hidden'],
                                num_classes=model_params['no_adv_class'],
                                drop_in=0.1,
                                drop=0.1).to(device)

In [10]:
data = pd.read_csv('../preprocessing/preprocessed_data/AllfilteredDrugs.csv',index_col=0)
#data = data[(data['cell_iname'] == 'PC3')| (data['cell_iname'] == 'HA1E') | (data['cell_iname'] == 'A375') | (data['cell_iname'] == 'HT29')]
data = [data[data['cell_iname'] == 'PC3'],data[data['cell_iname'] == 'HA1E'],
        data[data['cell_iname'] == 'A375'],data[data['cell_iname'] == 'HT29']]

In [11]:
# num_of_modalities = globalFrame.num_of_modalities
# paired_modalities = globalFrame.paired_modalities

In [12]:
num_of_modalities = model_params['no_states']
paired_modalities = [[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]]

In [13]:
for i in range(len(paired_modalities)):
    sampleInfo1 = data[paired_modalities[i][0]].loc[:,['conditionId','sig_id','cell_iname']]
    sampleInfo2 = data[paired_modalities[i][1]].loc[:,['conditionId','sig_id','cell_iname']]
    paired = sampleInfo1.merge(sampleInfo2,how='inner',left_on='conditionId',right_on='conditionId')
    
    if i==0:
        pairedSamples = paired.reset_index(drop=True).copy()
    else:
        pairedSamples = pd.concat((pairedSamples,paired.reset_index(drop=True)),axis=0).reset_index(drop=True)

paired_sigs = np.union1d(pairedSamples.sig_id_x.values,pairedSamples.sig_id_y.values)
pairedSamples

Unnamed: 0,conditionId,sig_id_x,cell_iname_x,sig_id_y,cell_iname_y
0,SA-1921085_10 uM_24 h,DOSBIO001_A375_24H:BRD-K58292285:9.838,A375,DOSBIO001_HT29_24H:BRD-K58292285:9.838,HT29
1,BRD-K88822846_10 uM_24 h,DOSBIO001_A375_24H:BRD-K88822846:10.1074,A375,DOSBIO001_HT29_24H:BRD-K88822846:10.1074,HT29
2,BRD-K11169037_10 uM_24 h,DOSBIO001_A375_24H:BRD-K11169037:10.0519,A375,DOSBIO001_HT29_24H:BRD-K11169037:10.0519,HT29
3,UMB-32_10 uM_24 h,PCL001_A375_24H:BRD-K21532219:10,A375,PCL001_HT29_24H:BRD-K21532219:10,HT29
4,BRD-K58479490_10 uM_24 h,PCL001_A375_24H:BRD-K58479490:10,A375,PCL001_HT29_24H:BRD-K58479490:10,HT29
...,...,...,...,...,...
401,BRD-K48853221_10 uM_24 h,DOSVAL001_A375_24H:BRD-K48853221:10,A375,DOSVAL001_HT29_24H:BRD-K48853221:10,HT29
402,BRD-K66214645_10 uM_24 h,DOSVAL001_A375_24H:BRD-K66214645:10,A375,DOSVAL001_HT29_24H:BRD-K66214645:10,HT29
403,BRD-K63423329_20 uM_24 h,DOSVAL006_A375_24H:BRD-K63423329:20,A375,DOSVAL006_HT29_24H:BRD-K63423329:20,HT29
404,BRD-K05885357_10 uM_24 h,DOSVAL001_A375_24H:BRD-K05885357:10.0043,A375,DOSVAL001_HT29_24H:BRD-K05885357:10.0043,HT29


In [14]:
dataUnpaired = []
for i,df in enumerate(data):
    dataUnpaired.append(df[df.sig_id.isin(paired_sigs)].reset_index(drop=True))

In [15]:
N_paired = len(pairedSamples)
sample_sizes = [len(d) for d in data]

In [16]:
sample_sizes

[1189, 1087]

In [17]:
batchSizes = [256,256,256,256]
#bs_paired = 250

In [18]:
class_criterion = torch.nn.CrossEntropyLoss()
NUM_EPOCHS= model_params['epochs']

In [19]:
class MultiEncoder(torch.nn.Module):
    def __init__(self,num_of_modalities,paired_modalities,
                 in_channels, enc_hidden_layers, latent_dim ,
                 dropRateEnc=0.1, activationEnc=None, biasEnc=True,normalizeLatent=False,
                 variational=False):

        super(MultiEncoder, self).__init__()

        self.num_of_modalities = num_of_modalities
        self.paired_modalities = paired_modalities
        self.encoders = torch.nn.ModuleList()

        for i in range(self.num_of_modalities):
            if variational == True:
                self.encoders.append(Encoder(in_channels[i], enc_hidden_layers[i], latent_dim,dropRateEnc, activationEnc, biasEnc))
            else:
                self.encoders.append(SimpleEncoder(in_channels[i], enc_hidden_layers[i], latent_dim, dropRateEnc, activationEnc,normalizeLatent, biasEnc))

    def forward(self, x):
        z_latents = []
        for i in range(self.num_of_modalities):
            z = self.encoders[i](x[i])
            z_latents.append(z)
        return torch.cat(z_latents,0)

In [20]:
multi_enc = MultiEncoder(model_params['no_states'],[[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]],#[0,2],[0,3],[1,2],[1,3],[2,3]
                         [gene_size,gene_size,gene_size,gene_size],
                         [model_params['encoder_hiddens'],model_params['encoder_hiddens'],model_params['encoder_hiddens'],model_params['encoder_hiddens']],
                         model_params['latent_dim'], 
                         [model_params['decoder_hiddens'],model_params['decoder_hiddens'],model_params['decoder_hiddens'],model_params['decoder_hiddens']] ,
                         dropRateEnc=model_params['dropout_encoder'], activationEnc=model_params['encoder_activation']).to(device)

In [21]:
allParams = list(multi_enc.parameters())
allParams = allParams + list(prior_d.parameters()) + list(local_d.parameters())
allParams = allParams + list(adverse_classifier.parameters())
optimizer = torch.optim.Adam(allParams, lr=model_params['encoding_lr'], weight_decay=0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=model_params['schedule_step_enc'],
                                            gamma=model_params['gamma_enc'])

In [22]:
for e in range(0, NUM_EPOCHS):
    multi_enc.train()
    prior_d.train()
    local_d.train()
    adverse_classifier.train()
    
    
    trainLoadersList = []
    lenghts = []
    for load_id,bs in enumerate(batchSizes):
        trainLoad = getSamples(sample_sizes[load_id],bs)
        trainLoadersList.append(trainLoad)
        lenghts.append(len(trainLoad))
    #trainLoad = getSamples(N_paired,bs_paired)
    #trainLoadersList.append(trainLoad)
    #lenghts.append(len(trainLoad))
    
    maxLen = np.max(lenghts)
    
    for load_id,trainLoad in enumerate(trainLoadersList):
        if maxLen>lenghts[load_id]:
            trainloader_suppl = getSamples(sample_sizes[load_id],bs)
            for jj in range(maxLen-lenghts[load_id]):
                trainLoad.insert(jj,trainloader_suppl[jj])
        trainLoadersList[load_id] = trainLoad
            
    for j in range(maxLen):
        dataIndices = [trainLoad[j] for trainLoad in trainLoadersList]
        #dataIndex_paired = dataIndices[-1]
            
        #df_pairs = pairedSamples.iloc[dataIndex_paired,:]
        #paired_inds = len(df_pairs)
        dfs = [data[x].iloc[dataIndices[x],:] for x in range(len(dataIndices))]
        
        ### Create paired samples DURING TRAINING BELLOW
        #for pair_id in range(len(paired_modalities)):
        #    sampleInfo1 = dfs[paired_modalities[pair_id][0]].loc[:,['conditionId','sig_id','cell_iname']]
        #    sampleInfo2 = dfs[paired_modalities[pair_id][1]].loc[:,['conditionId','sig_id','cell_iname']]
        #    paired = sampleInfo1.merge(sampleInfo2,how='inner',left_on='conditionId',right_on='conditionId')
        #    if pair_id==0:
        #        conditions = np.concatenate((paired.conditionId.values,
        #                                     sampleInfo1.conditionId.values,
        #                                     paired.conditionId.values,
        #                                     sampleInfo2.conditionId.values))
        #    else:
        #        conditions = np.concatenate((conditions,
        #                                     paired.conditionId.values,
        #                                     sampleInfo1.conditionId.values,
        #                                     paired.conditionId.values,
        #                                     sampleInfo2.conditionId.values)) 
        X_s = []
        for df_id,df in enumerate(dfs):
            sampleInfo = df.loc[:,['conditionId','sig_id','cell_iname']]
            if df_id==0:
                conditions = df.conditionId.values
                true_labels = torch.ones(len(df)) * df_id
            else:
                conditions = np.concatenate((conditions,
                                             df.conditionId.values))
                true_labels = torch.cat((true_labels,torch.ones(len(df)) * df_id),0)
            X_s.append(torch.tensor(cmap.loc[df.sig_id.values].values).float().to(device))
        true_labels = true_labels.long().to(device)
        size = conditions.size
        conditions = conditions.reshape(size,1)
        conditions = conditions == conditions.transpose()
        conditions = conditions*1
        mask = torch.tensor(conditions).to(device).detach()
        pos_mask = mask
        neg_mask = 1 - mask
        log_2 = math.log(2.)
        #X_s = [torch.tensor(cmap.loc[df.sig_id.values].values).float().to(device) for df in dfs]     
        optimizer.zero_grad()
        
        #z_latents = []
        #for enc_id,enc in enumerate(encoders):
        #    z_latents.append(enc(X_s[enc_id]))
        #latent_vectors = torch.cat(z_latents, 0)
        latent_vectors = multi_enc(X_s)
        
        labels_adv = adverse_classifier(latent_vectors)
        _, predicted = torch.max(labels_adv, 1)
        predicted = predicted.cpu().numpy()
        f1 = f1_score(true_labels.cpu(), predicted,average='micro')
        adv_entropy = class_criterion(labels_adv,true_labels)
        
        z_un = local_d(latent_vectors)
        res_un = torch.matmul(z_un, z_un.t())
            
        p_samples = res_un * pos_mask.float()
        q_samples = res_un * neg_mask.float()

        Ep = log_2 - F.softplus(- p_samples)
        Eq = F.softplus(-q_samples) + q_samples - log_2

        Ep = (Ep * pos_mask.float()).sum() / pos_mask.float().sum()
        Eq = (Eq * neg_mask.float()).sum() / neg_mask.float().sum()
        mi_loss = Eq - Ep

        #prior = torch.rand_like(torch.cat((z_1, z_2), 0))
        prior = torch.rand_like(latent_vectors)

        term_a = torch.log(prior_d(prior)).mean()
        term_b = torch.log(1.0 - prior_d(latent_vectors)).mean()
        prior_loss = -(term_a + term_b) * model_params['prior_beta']
        
        #L2Loss = 0.
        #for enc in encoders:
        #    L2Loss = L2Loss + enc.L2Regularization(model_params['enc_l2_reg'])
        
        loss = mi_loss+prior_loss+adv_entropy+adverse_classifier.L2Regularization(model_params['state_class_reg'])#+L2Loss
        
        optimizer.step()
        
    scheduler.step()
    outString = 'Epoch={:.0f}/{:.0f}'.format(e+1,NUM_EPOCHS)
    outString += ', MI Loss={:.4f}'.format(mi_loss.item())
    outString += ', Prior Loss={:.4f}'.format(prior_loss.item())
    outString += ', Entropy Loss={:.4f}'.format(adv_entropy.item())
    outString += ', loss={:.4f}'.format(loss.item())
    outString += ', F1 score={:.4f}'.format(f1)
    if (e==0 or (e%250==0 and e>0)):
        print(outString)
print(outString)

Epoch=1/2000, MI Loss=3.7015, Prior Loss=1.3804, Entropy Loss=0.8781, loss=6.0130, F1 score=0.2553
Epoch=251/2000, MI Loss=3.6973, Prior Loss=1.3775, Entropy Loss=0.8868, loss=6.0145, F1 score=0.2624
Epoch=501/2000, MI Loss=3.6594, Prior Loss=1.3823, Entropy Loss=0.8850, loss=5.9796, F1 score=0.2553
Epoch=751/2000, MI Loss=3.7604, Prior Loss=1.3819, Entropy Loss=0.8811, loss=6.0763, F1 score=0.2979


KeyboardInterrupt: 

# Train the whole framework

In [None]:
globalFrame = GlobalAEFramework(model_params['no_states'],[[0,1],[0,2],[0,3],[1,2],[1,3],[2,3]],
                                [gene_size,gene_size,gene_size,gene_size],
                                [model_params['encoder_hiddens'],model_params['encoder_hiddens'],model_params['encoder_hiddens'],model_params['encoder_hiddens']],
                                model_params['latent_dim'], 
                                [model_params['decoder_hiddens'],model_params['decoder_hiddens'],model_params['decoder_hiddens'],model_params['decoder_hiddens']] ,
                                dropRateEnc=model_params['dropout_encoder'], activationEnc=model_params['encoder_activation'],
                                dropRateDec=model_params['dropout_decoder'], activationDec=model_params['decoder_activation'],
                                covariateDrop = model_params['V_dropout']).to(device)

prior_d = PriorDiscriminator(model_params['latent_dim']).to(device)
local_d = LocalDiscriminator(model_params['latent_dim'],model_params['latent_dim']).to(device)
    
classifier = Classifier(in_channel=model_params['latent_dim'],
                        hidden_layers=model_params['state_class_hidden'],
                        num_classes=model_params['no_states'],
                        drop_in=model_params['state_class_drop_in'],
                        drop=model_params['state_class_drop']).to(device)

adverse_classifier = Classifier(in_channel=model_params['latent_dim'],
                                hidden_layers=model_params['adv_class_hidden'],
                                num_classes=model_params['no_adv_class'],
                                drop_in=model_params['adv_class_drop_in'],
                                drop=model_params['adv_class_drop']).to(device)

In [None]:
allParams = list(globalFrame.parameters())
allParams = allParams + list(prior_d.parameters()) + list(local_d.parameters())
allParams = allParams + list(classifier.parameters())
optimizer = torch.optim.Adam(allParams, lr=model_params['encoding_lr'], weight_decay=0)
optimizer_adv = torch.optim.Adam(adverse_classifier.parameters(), lr=model_params['adv_lr'], weight_decay=0)
if model_params['schedule_step_adv'] is not None:
    scheduler_adv = torch.optim.lr_scheduler.StepLR(optimizer_adv,
                                                    step_size=model_params['schedule_step_adv'],
                                                    gamma=model_params['gamma_adv'])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=model_params['schedule_step_enc'],
                                                gamma=model_params['gamma_enc'])

In [70]:
valR2 = []
valPear = []
valMSE =[]
valSpear = []
valAccuracy = []


valPearDirect = []
valSpearDirect = []
valAccDirect = []

valR2_1 = []
valPear_1 = []
valMSE_1 =[]
valSpear_1 = []
valAccuracy_1 = []

valR2_2 = []
valPear_2 = []
valMSE_2 =[]
valSpear_2 = []
valAccuracy_2 = []

crossCorrelation = []

valF1 = []
valClassAcc = []

In [None]:
for e in range(0, NUM_EPOCHS):
    decoder_1.train()
    decoder_2.train()
    encoder_1.train()
    encoder_2.train()
    prior_d.train()
    local_d.train()
    classifier.train()
    adverse_classifier.train()
    Vsp.train()
    
    trainloader_1 = getSamples(N_1, bs_1)
    len_1 = len(trainloader_1)
    trainloader_2 = getSamples(N_2, bs_2)
    len_2 = len(trainloader_2)
    trainloader_paired = getSamples(N_paired, bs_paired)
    len_paired = len(trainloader_paired)

    lens = [len_1,len_2,len_paired]
    maxLen = np.max(lens)
    
    if maxLen>lens[0]:
        trainloader_suppl = getSamples(N_1, bs_1)
        for jj in range(maxLen-lens[0]):
            trainloader_1.insert(jj,trainloader_suppl[jj])
        
    if maxLen>lens[1]:
        trainloader_suppl = getSamples(N_2, bs_2)
        for jj in range(maxLen-lens[1]):
            trainloader_2.insert(jj,trainloader_suppl[jj])
        
    if maxLen>lens[2]:
        trainloader_suppl = getSamples(N_paired, bs_paired)
        for jj in range(maxLen-lens[2]):
            trainloader_paired.insert(jj,trainloader_suppl[jj])
            
    for j in range(maxLen):
        dataIndex_1 = trainloader_1[j]
        dataIndex_2 = trainloader_2[j]
        dataIndex_paired = trainloader_paired[j]
            
        df_pairs = trainInfo_paired.iloc[dataIndex_paired,:]
        df_1 = trainInfo_1.iloc[dataIndex_1,:]
        df_2 = trainInfo_2.iloc[dataIndex_2,:]
        paired_inds = len(df_pairs)
            
            
        X_1 = torch.tensor(np.concatenate((cmap.loc[df_pairs['sig_id.x']].values,
                                                 cmap.loc[df_1.sig_id].values))).float().to(device)
        X_2 = torch.tensor(np.concatenate((cmap.loc[df_pairs['sig_id.y']].values,
                                                 cmap.loc[df_2.sig_id].values))).float().to(device)
            
        z_species_1 = torch.cat((torch.ones(X_1.shape[0],1),
                                     torch.zeros(X_1.shape[0],1)),1).to(device)
        z_species_2 = torch.cat((torch.zeros(X_2.shape[0],1),
                                     torch.ones(X_2.shape[0],1)),1).to(device)
            
            
        conditions = np.concatenate((df_pairs.conditionId.values,
                                            df_1.conditionId.values,
                                            df_pairs.conditionId.values,
                                            df_2.conditionId.values))
        size = conditions.size
        conditions = conditions.reshape(size,1)
        conditions = conditions == conditions.transpose()
        conditions = conditions*1
        mask = torch.tensor(conditions).to(device).detach()
        pos_mask = mask
        neg_mask = 1 - mask
        log_2 = math.log(2.)
        optimizer.zero_grad()
        optimizer_adv.zero_grad()
                        
        #if e % model_params['adversary_steps']==0:
        z_base_1 = encoder_1(X_1)
        z_base_2 = encoder_2(X_2)
        latent_base_vectors = torch.cat((z_base_1, z_base_2), 0)
        labels_adv = adverse_classifier(latent_base_vectors)
        true_labels = torch.cat((torch.ones(z_base_1.shape[0]),
                                 torch.zeros(z_base_2.shape[0])),0).long().to(device)
        _, predicted = torch.max(labels_adv, 1)
        predicted = predicted.cpu().numpy()
        cf_matrix = confusion_matrix(true_labels.cpu().numpy(),predicted)
        tn, fp, fn, tp = cf_matrix.ravel()
        f1_basal_trained = 2*tp/(2*tp+fp+fn)
        adv_entropy = class_criterion(labels_adv,true_labels)
        adversary_drugs_penalty = compute_gradients(labels_adv.sum(), latent_base_vectors)
        loss_adv = adv_entropy + model_params['adv_penalnty'] * adversary_drugs_penalty
        loss_adv.backward()
        optimizer_adv.step()
        #print(f1_basal_trained)
        #else:
        optimizer.zero_grad()
        #f1_basal_trained = None
        z_base_1 = encoder_1(X_1)
        z_base_2 = encoder_2(X_2)
        latent_base_vectors = torch.cat((z_base_1, z_base_2), 0)
            
        #z_un = local_d(torch.cat((z_1, z_2), 0))
        z_un = local_d(latent_base_vectors)
        res_un = torch.matmul(z_un, z_un.t())
            
        z_1 = Vsp(z_base_1,z_species_1)
        z_2 = Vsp(z_base_2,z_species_2)
        latent_vectors = torch.cat((z_1, z_2), 0)
            
        y_pred_1 = decoder_1(z_1)
        fitLoss_1 = torch.mean(torch.sum((y_pred_1 - X_1)**2,dim=1))
        L2Loss_1 = decoder_1.L2Regularization(model_params['dec_l2_reg']) + encoder_1.L2Regularization(model_params['enc_l2_reg'])
        loss_1 = fitLoss_1 + L2Loss_1
            
        y_pred_2 = decoder_2(z_2)
        fitLoss_2 = torch.mean(torch.sum((y_pred_2 - X_2)**2,dim=1))
        L2Loss_2 = decoder_2.L2Regularization(model_params['dec_l2_reg']) + encoder_2.L2Regularization(model_params['enc_l2_reg'])
        loss_2 = fitLoss_2 + L2Loss_2

        silimalityLoss = torch.mean(torch.sum((z_base_1[0:paired_inds,:] - z_base_2[0:paired_inds,:])**2,dim=-1))
        cosineLoss = torch.nn.functional.cosine_similarity(z_base_1[0:paired_inds,:],z_base_2[0:paired_inds,:],dim=-1).mean()
            
        p_samples = res_un * pos_mask.float()
        q_samples = res_un * neg_mask.float()
    
        Ep = log_2 - F.softplus(- p_samples)
        Eq = F.softplus(-q_samples) + q_samples - log_2

        Ep = (Ep * pos_mask.float()).sum() / pos_mask.float().sum()
        Eq = (Eq * neg_mask.float()).sum() / neg_mask.float().sum()
        mi_loss = Eq - Ep

        prior = torch.rand_like(latent_base_vectors)

        term_a = torch.log(prior_d(prior)).mean()
        term_b = torch.log(1.0 - prior_d(latent_base_vectors)).mean()
        prior_loss = -(term_a + term_b) * model_params['prior_beta']

        # Classification loss
        labels = classifier(latent_vectors)
        true_labels = torch.cat((torch.ones(z_1.shape[0]),
                                 torch.zeros(z_2.shape[0])),0).long().to(device)
        entropy = class_criterion(labels,true_labels)
        _, predicted = torch.max(labels, 1)
        predicted = predicted.cpu().numpy()
        cf_matrix = confusion_matrix(true_labels.cpu().numpy(),predicted)
        tn, fp, fn, tp = cf_matrix.ravel()
        f1_latent = 2*tp/(2*tp+fp+fn)
            
        # Remove signal from z_basal
        labels_adv = adverse_classifier(latent_base_vectors)
        true_labels = torch.cat((torch.ones(z_base_1.shape[0]),
                                 torch.zeros(z_base_2.shape[0])),0).long().to(device)
        adv_entropy = class_criterion(labels_adv,true_labels)
        _, predicted = torch.max(labels_adv, 1)
        predicted = predicted.cpu().numpy()
        cf_matrix = confusion_matrix(true_labels.cpu().numpy(),predicted)
        tn, fp, fn, tp = cf_matrix.ravel()
        f1_basal = 2*tp/(2*tp+fp+fn)
            
        loss = loss_1 + loss_2 + model_params['similarity_reg'] * silimalityLoss +model_params['lambda_mi_loss']*mi_loss + prior_loss  + model_params['reg_classifier'] * entropy - model_params['reg_adv']*adv_entropy +classifier.L2Regularization(model_params['state_class_reg']) +Vsp.Regularization(model_params['v_reg'])  - model_params['cosine_loss'] * cosineLoss

        loss.backward()
        optimizer.step()
            
        
        pearson_1 = pearson_r(y_pred_1.detach().flatten(), X_1.detach().flatten())
        r2_1 = r_square(y_pred_1.detach().flatten(), X_1.detach().flatten())
        mse_1 = torch.mean(torch.mean((y_pred_1.detach() - X_1.detach())**2,dim=1))
        
        pearson_2 = pearson_r(y_pred_2.detach().flatten(), X_2.detach().flatten())
        r2_2 = r_square(y_pred_2.detach().flatten(), X_2.detach().flatten())
        mse_2 = torch.mean(torch.mean((y_pred_2.detach() - X_2.detach())**2,dim=1))            
            
    if model_params['schedule_step_adv'] is not None:
        scheduler_adv.step()
    if (e>=0):
        scheduler.step()
        outString = 'Try {:.0f}: Split {:.0f}: Epoch={:.0f}/{:.0f}'.format(fold_id,i,e+1,NUM_EPOCHS)
        outString += ', r2_1={:.4f}'.format(r2_1.item())
        outString += ', pearson_1={:.4f}'.format(pearson_1.item())
        outString += ', MSE_1={:.4f}'.format(mse_1.item())
        outString += ', r2_2={:.4f}'.format(r2_2.item())
        outString += ', pearson_2={:.4f}'.format(pearson_2.item())
        outString += ', MSE_2={:.4f}'.format(mse_2.item())
        outString += ', MI Loss={:.4f}'.format(mi_loss.item())
        outString += ', Prior Loss={:.4f}'.format(prior_loss.item())
        outString += ', Entropy Loss={:.4f}'.format(entropy.item())
        outString += ', Adverse Entropy={:.4f}'.format(adv_entropy.item())
        outString += ', Cosine Loss={:.4f}'.format(cosineLoss.item())
        outString += ', loss={:.4f}'.format(loss.item())
        outString += ', F1 latent={:.4f}'.format(f1_latent)
        outString += ', F1 basal={:.4f}'.format(f1_basal)
        #if e % model_params["adversary_steps"] == 0 and e>0:
        outString += ', F1 basal trained={:.4f}'.format(f1_basal_trained)
        #else:
        #    outString += ', F1 basal trained= %s'%f1_basal_trained
    if (e==0 or (e%250==0 and e>0)):
        print2log(outString)
print2log(outString)