# Before you begin, make sure you have the appropriate folder structure and that the paths used for saving results during training are correct for your setup

# Load dependencies

In [1]:
import pandas as pd
import numpy as np
from sklearn.metrics import silhouette_score,confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import scipy
from scipy.signal import savgol_filter
from scipy.spatial.distance import pdist,squareform
from scipy.stats import mannwhitneyu
import torch
import torch.nn.functional as F
import math
from matplotlib import pyplot as plt
import umap

from models import SimpleEncoder,Decoder,PriorDiscriminator,LocalDiscriminator,Classifier,SpeciesCovariate
from evaluationUtils import r_square,get_cindex,pearson_r,pseudoAccuracy

from IPython.display import clear_output
import seaborn as sns
sns.set()

In [2]:
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Create shuffled mini-batch index arrays for training
def getSamples(N, batchSize):
    """Split a random permutation of ``range(N)`` into consecutive mini-batches.

    Parameters
    ----------
    N : int
        Number of samples to draw indices for.
    batchSize : int
        Maximum number of indices per batch; must be positive.

    Returns
    -------
    list of numpy.ndarray
        Index arrays that together cover every index in ``range(N)`` exactly
        once. All batches hold ``batchSize`` indices except possibly the last
        one, which holds the remainder. The list is empty when ``N`` is 0.

    Raises
    ------
    ValueError
        If ``batchSize`` is not positive (the permutation could never be
        consumed, which previously resulted in an endless loop).
    """
    if batchSize <= 0:
        raise ValueError('batchSize must be a positive integer, got %s' % (batchSize,))
    order = np.random.permutation(N)
    return [order[start:start + batchSize] for start in range(0, len(order), batchSize)]

In [4]:
def compute_gradients(output, input):
    """Return the mean squared gradient of ``output`` with respect to ``input``.

    The gradient is computed with ``create_graph=True`` so the returned scalar
    stays differentiable and can be used as a penalty term inside a loss.
    Only the first gradient returned by ``torch.autograd.grad`` is used.
    """
    first_grad = torch.autograd.grad(output, input, create_graph=True)[0]
    return torch.mean(first_grad ** 2)

# Load Data

In [5]:
# Gene expression (Gex) data: the first CSV column is used as the row index.
cmap = pd.read_csv('../preprocessing/preprocessed_data/all_cmap_landmarks.csv', index_col=0)
X = cmap.values
gene_size = cmap.shape[1]  # number of gene columns in the table
display(cmap)

Unnamed: 0,16,23,25,30,39,47,102,128,142,154,...,94239,116832,124583,147179,148022,200081,200734,256364,375346,388650
PCL001_HT29_24H:BRD-K42991516:10,0.266452,-0.250874,-0.854204,-0.041545,0.204450,0.709800,-0.328601,-0.498116,-1.454481,0.506321,...,0.536235,0.024452,0.928558,-0.453246,-0.140290,0.205065,1.148706,-1.933820,1.966937,-0.159919
PCL001_HT29_24H:BRD-K50817946:10,6.074023,-0.524075,-0.635742,2.014629,-3.747274,2.109600,0.847576,-2.732549,-5.729352,2.164091,...,0.447939,1.543649,-3.775020,1.827991,-0.088051,0.382848,1.400255,-3.087269,1.392148,1.027263
HOG002_A549_24H:BRD-K28296557-005-14-6:3.33,3.092555,1.760324,0.045857,-0.267738,-5.237659,-1.254134,-1.197927,-2.120804,-2.096229,0.799317,...,0.253642,-0.461737,-2.344703,1.581582,4.007076,-0.203330,0.715596,1.502107,1.281574,0.450898
DOSBIO001_MCF7_24H:BRD-K77888550:9.5278,-1.680236,1.174203,0.295703,0.555778,0.136969,-1.507160,-0.068983,-0.468983,-1.894113,-0.035792,...,1.204646,-0.688365,-1.042315,2.571737,-0.085614,-3.472259,1.436653,-1.054814,1.873788,1.680525
DOSBIO001_NPC_24H:BRD-K09069264:10.2084,-1.401400,0.308703,1.178614,-2.114849,-0.020324,-0.393869,-2.599080,-0.983008,-0.063675,-0.549799,...,0.349096,0.017305,0.356195,0.638253,0.862676,-0.106953,1.115011,2.205899,-0.306434,1.101611
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
DOSVAL001_MCF7_24H:BRD-K30867024:10,-0.228277,-0.574911,0.545074,1.320753,-0.102422,1.058099,0.366014,-1.520461,-0.004073,-0.637326,...,-0.057771,-0.660938,-1.952505,0.395937,1.790094,1.072335,1.657834,1.445156,2.031223,0.026611
DOSVAL002_A375_24H:BRD-K10749593:20,4.415511,0.608378,1.604217,-0.911175,-2.611416,-1.742975,-2.500287,-2.503129,-3.472708,3.008501,...,0.916562,-1.403280,0.479218,4.528471,1.701896,0.141621,1.953133,-1.480089,1.549125,1.414482
DOSVAL002_MCF7_24H:BRD-K14002526:20,1.311693,1.834785,1.277888,-0.224320,-0.365258,0.209443,0.166746,-2.112468,-0.870127,-0.083894,...,0.894919,-0.707055,0.519019,0.916627,0.710227,0.126153,2.277303,-1.870382,1.021850,1.199542
DOSVAL001_HT29_24H:BRD-K11624501:9.99164,1.540175,-0.196926,-0.094410,-1.951286,-2.848082,-2.478519,-1.257487,-1.247405,-4.006328,-0.362494,...,-0.371798,3.735393,2.011243,1.693114,2.924200,2.535851,1.861230,-3.021530,0.127304,0.980487


In [6]:
# Gex data for controls
# The first CSV column is used as the row index. The table is loaded and displayed
# for inspection; it is not referenced by the other cells visible in this notebook.
cmap_controls = pd.read_csv('../preprocessing/preprocessed_data/baselineCell/cmap_all_baselines_q1.csv',index_col=0)
display(cmap_controls)

Unnamed: 0,16,23,25,30,39,47,102,128,142,154,...,94239,116832,124583,147179,148022,200081,200734,256364,375346,388650
OFL001_A549_96H:G15,1.854175,1.868439,-0.140405,-0.278911,0.396597,0.334116,0.473704,-0.565553,1.372410,1.181299,...,1.252141,-0.291923,1.193942,0.978987,2.381282,-1.065447,1.174847,-0.885704,0.879203,0.216700
OFL001_MCF7_96H:J10,0.081511,0.651525,-0.205014,0.054704,0.726742,-0.126017,0.200712,0.915557,0.780285,0.007211,...,0.341261,0.405606,-0.054713,0.264261,-0.096964,0.752965,-0.249324,-1.176310,0.282062,-0.212717
ABY001_NCIH1975_XH:CMAP-000:-666:3,0.543459,1.647965,-1.731661,0.319534,1.078192,0.602553,0.323291,0.787790,0.888264,1.532468,...,0.704732,-1.326966,1.433667,-0.037051,1.016276,-0.481035,1.061352,1.616178,1.540468,-0.958139
ZTO.XPR001_THP1_408H:CMAP-000:-666,-0.054865,-0.085794,-0.319447,0.180520,0.124284,-0.117936,-0.267994,0.429114,-0.144781,0.190815,...,-0.114969,0.308555,0.055869,-0.450732,-0.394338,0.029793,0.046924,-0.231632,-0.186150,-0.309360
MOA001_A549_24H:N01,0.401776,1.197786,0.946556,0.794930,0.662958,0.473484,1.335021,0.338371,0.300303,0.690938,...,0.020668,0.171860,0.862337,0.525409,-0.029795,-0.263026,0.271724,0.934595,0.552001,-0.711617
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
DOSVAL001_HT29_24H:CMAP-000:-666,0.038320,-0.426547,0.183131,0.450992,-0.414180,-0.619587,-0.318295,0.066966,-0.618409,0.539847,...,0.085557,0.018541,0.174058,-0.101450,-0.279539,-0.303862,0.019368,1.111968,0.387193,-0.770082
DOSVAL001_HA1E_24H:CMAP-000:-666,0.319681,-0.182241,0.689418,0.542491,-0.124395,0.252069,-0.348502,0.145006,-0.018389,0.190280,...,-0.048443,0.188158,0.422073,0.123565,0.097611,-0.003442,0.924158,-0.212382,0.166562,0.142994
DOSVAL001_A375_24H:CMAP-000:-666,0.091151,-0.007194,0.360459,0.430177,-0.443078,-0.370296,-0.450974,0.616529,0.258591,0.111886,...,-0.571711,0.084373,0.240619,-0.372428,-0.168089,-0.137313,0.157594,0.256362,0.080780,-0.065995
DOSVAL001_HEPG2_24H:CMAP-000:-666,-0.276361,-0.321295,0.412983,0.040179,-0.144093,-0.374313,-0.488024,0.273988,-0.278131,-0.075510,...,-0.440074,0.220422,-0.144075,-0.162023,-0.328652,-0.300582,0.469960,-0.533808,0.158130,-0.492051


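# Train the CPA model without a separately pre-trained adversary
The adversarial classifier is instead trained on its own just for a few initial epochs, before the rest of the model starts training.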

In [7]:
# Hyper-parameters read by the training cell. Keys are looked up verbatim, so a
# key must not be renamed here without also updating every place that reads it.
# --- Encoder / decoder architecture (passed to SimpleEncoder and Decoder) ---
model_params = {'encoder_1_hiddens':[384,256],
                'encoder_2_hiddens':[384,256],
                'latent_dim': 128,
                'decoder_1_hiddens':[256,384],
                'decoder_2_hiddens':[256,384],
                'dropout_decoder':0.2,
                'dropout_encoder':0.1,
                'encoder_activation':torch.nn.ELU(),
                'decoder_activation':torch.nn.ELU(),
                # Dropout rate passed to SpeciesCovariate (Vsp).
                'V_dropout':0.25,
                # --- Classifier on the covariate-adjusted latent vectors ---
                'state_class_hidden':[64,32,16],#[128,64,32],
                'state_class_drop_in':0.5,
                'state_class_drop':0.25,
                'no_states':2,
                # --- Adversarial classifier on the base latent vectors ---
                'adv_class_hidden':[128,64,32],
                'adv_class_drop_in':0.3,
                'adv_class_drop':0.1,
                'no_adv_class':2,
                # --- Optimisation: Adam learning rates and StepLR schedules ---
                'encoding_lr':0.001,
                'adv_lr':0.001,
                'schedule_step_adv':200,
                'gamma_adv':0.5,
                'schedule_step_enc':200,
                'gamma_enc':0.8,
                'batch_size':512,
                'epochs':1000,
                # Scale applied to the prior-discriminator loss.
                'prior_beta':1.0,
                # Number of train/validation splits read per cell line.
                'no_folds':5,
                # --- Regularisation strengths ---
                'v_reg':1e-04,
                'state_class_reg':1e-02,
                'enc_l2_reg':0.01,
                'dec_l2_reg':0.01,
                # --- Weights of the individual terms in the combined loss ---
                'lambda_mi_loss':100,
                # NOTE(review): 'effsize_reg' is not read by the training cell visible in this notebook.
                'effsize_reg': 100,
                'cosine_loss': 40,
                # Weight of the gradient penalty in the adversary loss. The key is
                # misspelled ("penalnty") but is read with exactly this spelling.
                'adv_penalnty':100,
                'reg_adv':1000,
                'reg_classifier': 100,
                'similarity_reg' : 10,
                # NOTE(review): 'adversary_steps' is not read by the training cell visible in this notebook.
                'adversary_steps':4,
                # --- Adam weight decay for the main optimizer and the adversary optimizer ---
                'autoencoder_wd': 0.,
                'adversary_wd': 0.}


In [8]:
# Default mini-batch size; the training cell reassigns bs per fold (smaller for small training sets).
bs =  model_params['batch_size']
# Number of training epochs run for every fold.
NUM_EPOCHS=model_params['epochs']
# Cross-entropy criterion shared by the latent classifier and the adversarial classifier.
class_criterion = torch.nn.CrossEntropyLoss()

In [9]:
# Width of each gene subset: half of the gene columns, rounded down. It is used as the
# input/output size of every encoder/decoder.
# NOTE(review): assumes gene_size is even so that both gene subsets have num_genes columns - confirm.
num_genes = int(0.5*gene_size)
# Number of random gene splits evaluated per cell line.
random_iterations = 5

### Uncomment the last line of the cell to save df_result_all at every iteration

In [None]:
# Cross-validated training and evaluation.
#
# For every cell line and every random gene split, the gene columns are divided into
# two subsets (genes_1 loaded from disk, genes_2 = all remaining columns). One
# encoder/decoder pair is trained per subset together with a shared covariate module
# (Vsp), a classifier, discriminators and an adversarial classifier. After training a
# fold, reconstruction and cross-subset translation are scored with Pearson correlation
# on the validation samples and collected in df_result_all.


def _species_one_hot(n_samples, column):
    """Return an (n_samples, 2) float indicator matrix on `device` with ones in `column`.

    Column 0 marks inputs from gene subset 1 and column 1 marks inputs from gene
    subset 2; this is the covariate handed to `Vsp`.
    """
    one_hot = torch.zeros(n_samples, 2)
    one_hot[:, column] = 1.
    return one_hot.to(device)


def _domain_labels(n_first, n_second):
    """Return integer labels on `device`: `n_first` ones followed by `n_second` zeros."""
    return torch.cat((torch.ones(n_first), torch.zeros(n_second)), 0).long().to(device)


def _binary_scores(logits, targets):
    """Return ``(F1 of class 1, accuracy)`` of the arg-max of `logits` against 0/1 `targets`."""
    predicted = torch.max(logits, 1)[1].cpu().numpy()
    # labels=[0, 1] keeps the confusion matrix 2x2 regardless of which classes occur.
    tn, fp, fn, tp = confusion_matrix(targets.cpu().numpy(), predicted, labels=[0, 1]).ravel()
    return 2*tp/(2*tp+fp+fn), (tp+tn)/predicted.size


def _progress_string(cell, rand_iter, fold, epoch, metrics):
    """Format one progress line from an ordered list of ``(label, value)`` pairs."""
    out = 'Cell-line : '+cell+', rand_iter {:.0f}/{:.0f}'.format(rand_iter + 1, random_iterations)
    out += ', Split {:.0f}: Epoch={:.0f}/{:.0f}'.format(fold + 1, epoch + 1, NUM_EPOCHS)
    for label, value in metrics:
        out += ', ' + label + '={:.4f}'.format(float(value))
    return out


log_2 = math.log(2.)
result_frames = []
df_result_all = pd.DataFrame({})
for cell in ["PC3","HT29","MCF7","A549","NPC","HEPG2","A375","YAPC","U2OS","MCF10A","HA1E","HCC515","ASC","VCAP","HUVEC","HELA"]:
    for j in range(random_iterations):
        #genes_1 = np.random.choice(cmap.columns.values, size=num_genes, replace=False)
        # NOTE: allow_pickle=True unpickles arbitrary objects; only load files you trust.
        genes_1 = np.load('../results/SameCellimputationModel/genes_subsets/genes_1'+cell+'_iter'+str(j)+'.npy',allow_pickle=True)
        genes_2 = np.setdiff1d(cmap.columns.values,genes_1)
        valPear = []
        valPear_1 = []
        valPear_2 = []
        valF1 = []
        valClassAcc = []

        for i in range(model_params['no_folds']):
            trainInfo = pd.read_csv('../preprocessing/preprocessed_data/SameCellimputationModel/'+cell+'/train_'+str(i)+'.csv',index_col=0)
            valInfo = pd.read_csv('../preprocessing/preprocessed_data/SameCellimputationModel/'+cell+'/val_'+str(i)+'.csv',index_col=0)

            # Use a smaller batch size for small training sets.
            if len(trainInfo)<950:
                bs = 256
            else:
                bs = model_params['batch_size']

            cmap_train = cmap.loc[trainInfo.sig_id,:]
            cmap_val = cmap.loc[valInfo.sig_id,:]
            N = len(cmap_train)

            # Convert the training matrices once per fold instead of once per mini-batch.
            train_1 = torch.tensor(cmap_train.loc[:,genes_1].values).float()
            train_2 = torch.tensor(cmap_train.loc[:,genes_2].values).float()
            train_conditions = trainInfo.conditionId.values

            # Network
            decoder_1 = Decoder(model_params['latent_dim'],model_params['decoder_1_hiddens'],num_genes,
                                dropRate=model_params['dropout_decoder'],
                                activation=model_params['decoder_activation']).to(device)
            decoder_2 = Decoder(model_params['latent_dim'],model_params['decoder_2_hiddens'],num_genes,
                                dropRate=model_params['dropout_decoder'],
                                activation=model_params['decoder_activation']).to(device)
            encoder_1 = SimpleEncoder(num_genes,model_params['encoder_1_hiddens'],model_params['latent_dim'],
                                      dropRate=model_params['dropout_encoder'],
                                      activation=model_params['encoder_activation'],
                                      normalizeOutput=False).to(device)
            encoder_2 = SimpleEncoder(num_genes,model_params['encoder_2_hiddens'],model_params['latent_dim'],
                                      dropRate=model_params['dropout_encoder'],
                                      activation=model_params['encoder_activation'],
                                      normalizeOutput=False).to(device)
            prior_d = PriorDiscriminator(model_params['latent_dim']).to(device)
            local_d = LocalDiscriminator(model_params['latent_dim'],model_params['latent_dim']).to(device)
            classifier = Classifier(in_channel=model_params['latent_dim'],
                                    hidden_layers=model_params['state_class_hidden'],
                                    num_classes=model_params['no_states'],
                                    drop_in=model_params['state_class_drop_in'],
                                    drop=model_params['state_class_drop']).to(device)
            adverse_classifier = Classifier(in_channel=model_params['latent_dim'],
                                            hidden_layers=model_params['adv_class_hidden'],
                                            num_classes=model_params['no_adv_class'],
                                            drop_in=model_params['adv_class_drop_in'],
                                            drop=model_params['adv_class_drop']).to(device)
            Vsp = SpeciesCovariate(2,model_params['latent_dim'],dropRate=model_params['V_dropout']).to(device)
            all_modules = [encoder_1, decoder_1, encoder_2, decoder_2, prior_d, local_d,
                           classifier, adverse_classifier, Vsp]

            # Everything except the adversarial classifier is updated by the main optimizer.
            allParams = list(decoder_1.parameters()) + list(encoder_1.parameters())
            allParams = allParams + list(decoder_2.parameters()) + list(encoder_2.parameters())
            allParams = allParams + list(local_d.parameters())
            allParams = allParams + list(prior_d.parameters())
            allParams = allParams + list(Vsp.parameters())
            allParams = allParams + list(classifier.parameters())
            optimizer_adv = torch.optim.Adam(adverse_classifier.parameters(),
                                             lr= model_params['adv_lr'],
                                             weight_decay=model_params['adversary_wd'])
            optimizer = torch.optim.Adam(allParams,
                                         lr=model_params['encoding_lr'],
                                         weight_decay=model_params['autoencoder_wd'])
            if model_params['schedule_step_adv'] is not None:
                scheduler_adv = torch.optim.lr_scheduler.StepLR(optimizer_adv,
                                                                step_size=model_params['schedule_step_adv'],
                                                                gamma=model_params['gamma_adv'])
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                        step_size=model_params['schedule_step_enc'],
                                                        gamma=model_params['gamma_enc'])

            # Metrics of the most recent main-model update; stays None while only the
            # adversary has been trained (epochs 0-25).
            train_metrics = None
            for e in range(NUM_EPOCHS):
                for module in all_modules:
                    module.train()

                for dataIndex in getSamples(N, bs):
                    data_1 = train_1[dataIndex,:].to(device)
                    data_2 = train_2[dataIndex,:].to(device)

                    # pos_mask[a, b] == 1 when rows a and b of the stacked batch
                    # (subset-1 rows followed by subset-2 rows) share a conditionId.
                    conditions = train_conditions[dataIndex]
                    conditions = np.concatenate((conditions,conditions))
                    conditions = conditions.reshape(conditions.size,1)
                    mask = torch.tensor((conditions == conditions.transpose())*1).to(device)
                    pos_mask = mask.float()
                    neg_mask = (1 - mask).float()

                    # ---- Adversary update (runs in every epoch) ----
                    optimizer_adv.zero_grad()
                    optimizer.zero_grad()

                    z_base_1 = encoder_1(data_1)
                    z_base_2 = encoder_2(data_2)
                    latent_base_vectors = torch.cat((z_base_1, z_base_2), 0)
                    labels_adv = adverse_classifier(latent_base_vectors)
                    true_labels = _domain_labels(z_base_1.shape[0], z_base_2.shape[0])
                    f1_basal_trained, _ = _binary_scores(labels_adv, true_labels)
                    adv_entropy = class_criterion(labels_adv,true_labels)
                    adversary_drugs_penalty = compute_gradients(labels_adv.sum(), latent_base_vectors)
                    loss_adv = adv_entropy + model_params['adv_penalnty'] * adversary_drugs_penalty
                    loss_adv.backward()
                    optimizer_adv.step()

                    # ---- Main model update (skipped during the first 26 epochs) ----
                    if e>25:
                        # Discard the gradients the adversary loss left on the encoders.
                        optimizer.zero_grad()

                        z_species_1 = _species_one_hot(data_1.shape[0], 0)
                        z_species_2 = _species_one_hot(data_2.shape[0], 1)

                        z_base_1 = encoder_1(data_1)
                        z_base_2 = encoder_2(data_2)
                        latent_base_vectors = torch.cat((z_base_1, z_base_2), 0)

                        z_un = local_d(latent_base_vectors)
                        res_un = torch.matmul(z_un, z_un.t())

                        z_1 = Vsp(z_base_1,z_species_1)
                        z_2 = Vsp(z_base_2,z_species_2)
                        latent_vectors = torch.cat((z_1, z_2), 0)

                        # Reconstruction losses with L2 regularisation of each encoder/decoder.
                        Xhat_1 = decoder_1(z_1)
                        Xhat_2 = decoder_2(z_2)
                        loss_1 = torch.mean(torch.sum((Xhat_1 - data_1)**2,dim=1)) + encoder_1.L2Regularization(model_params['enc_l2_reg']) + decoder_1.L2Regularization(model_params['dec_l2_reg'])
                        loss_2 = torch.mean(torch.sum((Xhat_2 - data_2)**2,dim=1)) +encoder_2.L2Regularization(model_params['enc_l2_reg']) + decoder_2.L2Regularization(model_params['dec_l2_reg'])

                        # Mean Euclidean distance and mean cosine similarity over same-condition pairs.
                        silimalityLoss = torch.sum(torch.cdist(latent_vectors, latent_vectors) * pos_mask) / pos_mask.sum()
                        norms = latent_vectors.norm(p=2, dim=1, keepdim=True)
                        cosineLoss = torch.mm(latent_vectors, latent_vectors.t()) / (norms * norms.t()).clamp(min=1e-6)
                        cosineLoss = torch.sum(cosineLoss * pos_mask) / pos_mask.sum()

                        # Mutual-information term computed from the local discriminator scores.
                        p_samples = res_un * pos_mask
                        q_samples = res_un * neg_mask
                        Ep = log_2 - F.softplus(- p_samples)
                        Eq = F.softplus(-q_samples) + q_samples - log_2
                        Ep = (Ep * pos_mask).sum() / pos_mask.sum()
                        Eq = (Eq * neg_mask).sum() / neg_mask.sum()
                        mi_loss = Eq - Ep

                        # Prior-matching term against uniform random samples.
                        prior = torch.rand_like(latent_base_vectors)
                        term_a = torch.log(prior_d(prior)).mean()
                        # NOTE(review): the mean is taken inside the log here, whereas term_a
                        # averages the log values; confirm this asymmetry is intended before changing it.
                        term_b = torch.log(1.0 - prior_d(latent_base_vectors).mean())
                        prior_loss = -(term_a + term_b) * model_params['prior_beta']

                        # Classification loss
                        labels = classifier(latent_vectors)
                        true_labels = _domain_labels(z_1.shape[0], z_2.shape[0])
                        entropy = class_criterion(labels,true_labels)
                        f1_latent, _ = _binary_scores(labels, true_labels)

                        # Remove signal from z_basal
                        labels_adv = adverse_classifier(latent_base_vectors)
                        true_labels = _domain_labels(z_base_1.shape[0], z_base_2.shape[0])
                        adv_entropy = class_criterion(labels_adv,true_labels)
                        f1_basal, _ = _binary_scores(labels_adv, true_labels)

                        loss = (loss_1 + loss_2
                                + model_params['similarity_reg'] * silimalityLoss
                                - model_params['cosine_loss'] * cosineLoss
                                + prior_loss
                                + model_params['lambda_mi_loss'] * mi_loss
                                + model_params['reg_classifier'] * entropy
                                - model_params['reg_adv'] * adv_entropy
                                + classifier.L2Regularization(model_params['state_class_reg'])
                                + Vsp.Regularization(model_params['v_reg']))
                        loss.backward()
                        optimizer.step()

                        pear_1 = pearson_r(Xhat_1.detach(), data_1.detach())
                        mse_1 = torch.mean(torch.mean((Xhat_1.detach() - data_1.detach()) ** 2, dim=1))
                        pear_2 = pearson_r(Xhat_2.detach(), data_2.detach())
                        mse_2 = torch.mean(torch.mean((Xhat_2.detach() - data_2.detach()) ** 2, dim=1))

                        # Values are converted to floats only when a progress line is printed.
                        train_metrics = [('pearson_1', pear_1), ('MSE_1', mse_1),
                                         ('pearson_2', pear_2), ('MSE_2', mse_2),
                                         ('MI Loss', mi_loss), ('Prior Loss', prior_loss),
                                         ('Entropy Loss', entropy), ('Adverse Entropy', adv_entropy),
                                         ('Cosine ', cosineLoss), ('silimalityLoss ', silimalityLoss),
                                         ('loss', loss), ('F1 latent', f1_latent),
                                         ('F1 basal', f1_basal), ('F1 basal trained', f1_basal_trained)]

                if model_params['schedule_step_adv'] is not None:
                    scheduler_adv.step()
                if e>25:
                    scheduler.step()
                    if (e%250==0 or e==26) and train_metrics is not None:
                        print(_progress_string(cell, j, i, e, train_metrics))

            # End-of-training summary; skipped when the main model was never updated
            # (NUM_EPOCHS <= 26), in which case no training metrics exist.
            if train_metrics is not None:
                print(_progress_string(cell, j, i, NUM_EPOCHS - 1, train_metrics))

            for module in all_modules:
                module.eval()

            print('Validation performance for cell %s for try %s for split %s'%(cell,j+1,i+1))

            # No gradients are needed for evaluation.
            with torch.no_grad():
                X_1 = torch.tensor(cmap_val.loc[:,genes_1].values).float().to(device)
                X_2 = torch.tensor(cmap_val.loc[:,genes_2].values).float().to(device)
                z_species_1 = _species_one_hot(X_1.shape[0], 0)
                z_species_2 = _species_one_hot(X_2.shape[0], 1)

                # Reconstruction of each gene subset from itself.
                z_latent_base_1 = encoder_1(X_1)
                z_latent_base_2 = encoder_2(X_2)
                z_1 = Vsp(z_latent_base_1,z_species_1)
                z_2 = Vsp(z_latent_base_2,z_species_2)
                Xhat_1 = decoder_1(z_1)
                Xhat_2 = decoder_2(z_2)
                pear_1 = pearson_r(Xhat_1, X_1)
                pear_2 = pearson_r(Xhat_2, X_2)
                valPear_1.append(pear_1.item())
                valPear_2.append(pear_2.item())
                print('Pearson correlation 1: %s'%pear_1.item())
                print('Pearson correlation 2: %s'%pear_2.item())

                # Classification
                labels = classifier(torch.cat((z_1, z_2), 0))
                true_labels = _domain_labels(z_1.shape[0], z_2.shape[0])
                f1, class_acc = _binary_scores(labels, true_labels)
                valF1.append(f1)
                valClassAcc.append(class_acc)
                print('Classification accuracy: %s'%class_acc)
                print('Classification F1 score: %s'%f1)

                # Translation: encode one subset, flip the covariate, decode the other subset.
                z_1 = Vsp(z_latent_base_1,1.-z_species_1)
                x_hat_2_equivalent = decoder_2(z_1)
                pearson_2 = pearson_r(x_hat_2_equivalent, X_2)
                print('Pearson correlation 1 to 2: %s'%pearson_2.item())
                z_2 = Vsp(z_latent_base_2,1.-z_species_2)
                x_hat_1_equivalent = decoder_1(z_2)
                pearson_1 = pearson_r(x_hat_1_equivalent, X_1)
                print('Pearson correlation 2 to 1: %s'%pearson_1.item())

                valPear.append([pearson_2.item(),pearson_1.item()])

            #torch.save(decoder_1,'../results/test_stuff/models_AutoTransOp/'+cell+'/decoder_1_fold%s_iter%s.pt'%(i,j))
            #torch.save(decoder_2,'../results/test_stuff/models_AutoTransOp/'+cell+'/decoder_2_fold%s_iter%s.pt'%(i,j))
            #torch.save(prior_d,'../results/test_stuff/models_AutoTransOp/'+cell+'/priorDiscr_fold%s_iter%s.pt'%(i,j))
            #torch.save(local_d,'../results/test_stuff/models_AutoTransOp/'+cell+'/localDiscr_fold%s_iter%s.pt'%(i,j))
            #torch.save(encoder_1,'../results/test_stuff/models_AutoTransOp/'+cell+'/encoder_1_fold%s_iter%s.pt'%(i,j))
            #torch.save(encoder_2,'../results/test_stuff/models_AutoTransOp/'+cell+'/encoder_2_fold%s_iter%s.pt'%(i,j))
            #torch.save(classifier,'../results/test_stuff/models_AutoTransOp/'+cell+'/classifier_fold%s_iter%s.pt'%(i,j))
            #torch.save(Vsp,'../results/test_stuff/models_AutoTransOp/'+cell+'/Vsp_fold%s_iter%s.pt'%(i,j))
            #torch.save(adverse_classifier,'../results/test_stuff/models_AutoTransOp/'+cell+'/adverse_classifier_fold%s_iter%s.pt'%(i,j))

        valPear = np.array(valPear)
        # NOTE(review): valPear[:,0] holds the score printed as "1 to 2" and valPear[:,1]
        # the one printed as "2 to 1"; the column names below read the other way round.
        # They are left unchanged because saved results use them - confirm the convention.
        df_result = pd.DataFrame({'model_pearson2to1':valPear[:,0],'model_pearson1to2':valPear[:,1],
                                  'recon_pear_2':valPear_2 ,'recon_pear_1':valPear_1,
                                  'F1':valF1,'Accuracy':valClassAcc})
        df_result['model'] = 'model'
        df_result['set'] = 'validation'
        df_result['cell'] = cell
        df_result['iteration'] = j
        # DataFrame.append was removed in pandas 2.0; rebuild the combined table with
        # pd.concat so df_result_all is up to date after every iteration.
        result_frames.append(df_result)
        df_result_all = pd.concat(result_frames)
        #df_result_all.to_csv('../results/test_stuff/translation_results.csv') #SameCellimputationModel

# Remember to save df_result_all after all training runs have finished