# Load dependencies

In [1]:
import pandas as pd
import numpy as np
from sklearn.metrics import silhouette_score,confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import SelectFromModel
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.model_selection import KFold
import scipy
from scipy.signal import savgol_filter
from scipy.spatial.distance import pdist,squareform
from scipy.stats import mannwhitneyu
import torch
import torch.nn.functional as F
import math
from matplotlib import pyplot as plt
import umap

from models import SimpleEncoder,Decoder,PriorDiscriminator,LocalDiscriminator
from evaluationUtils import r_square,get_cindex,pearson_r,pseudoAccuracy

from IPython.display import clear_output
import seaborn as sns
sns.set()

In [2]:
# Run on the GPU when one is available; fall back to the CPU so the notebook
# still executes (slowly) on machines without CUDA instead of failing on the
# first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Mini-batch index helper used by the training loops.
def getSamples(N, batchSize):
    """Randomly partition the indices ``0..N-1`` into mini-batches.

    Parameters
    ----------
    N : int
        Number of samples to index (passed to ``np.random.permutation``).
    batchSize : int
        Maximum number of indices per batch; must be positive.

    Returns
    -------
    list of numpy.ndarray
        Consecutive chunks of a single random permutation of ``range(N)``.
        Every chunk has ``batchSize`` elements except possibly the last,
        which holds the remainder. An empty list is returned when ``N == 0``.

    Raises
    ------
    ValueError
        If ``batchSize`` is not positive. (The previous slicing loop never
        terminated in that case, because the remaining order never shrank.)
    """
    if batchSize <= 0:
        raise ValueError('batchSize must be a positive integer, got %r' % (batchSize,))
    # Draw exactly one permutation so the global RNG is consumed as before.
    order = np.random.permutation(N)
    return [order[start:start + batchSize] for start in range(0, len(order), batchSize)]

# Load Data

In [4]:
# Gene-expression (GEX) signatures restricted to the landmark genes:
# one row per signature id (index), one column per gene.
cmap = pd.read_csv('../preprocessing/preprocessed_data/all_cmap_landmarks.csv', index_col=0)
X = cmap.to_numpy()        # plain matrix used by the commented-out autoencoder cells
gene_size = cmap.shape[1]  # number of gene columns
display(cmap)

Unnamed: 0,16,23,25,30,39,47,102,128,142,154,...,94239,116832,124583,147179,148022,200081,200734,256364,375346,388650
PCL001_HT29_24H:BRD-K42991516:10,0.266452,-0.250874,-0.854204,-0.041545,0.204450,0.709800,-0.328601,-0.498116,-1.454481,0.506321,...,0.536235,0.024452,0.928558,-0.453246,-0.140290,0.205065,1.148706,-1.933820,1.966937,-0.159919
PCL001_HT29_24H:BRD-K50817946:10,6.074023,-0.524075,-0.635742,2.014629,-3.747274,2.109600,0.847576,-2.732549,-5.729352,2.164091,...,0.447939,1.543649,-3.775020,1.827991,-0.088051,0.382848,1.400255,-3.087269,1.392148,1.027263
HOG002_A549_24H:BRD-K28296557-005-14-6:3.33,3.092555,1.760324,0.045857,-0.267738,-5.237659,-1.254134,-1.197927,-2.120804,-2.096229,0.799317,...,0.253642,-0.461737,-2.344703,1.581582,4.007076,-0.203330,0.715596,1.502107,1.281574,0.450898
DOSBIO001_MCF7_24H:BRD-K77888550:9.5278,-1.680236,1.174203,0.295703,0.555778,0.136969,-1.507160,-0.068983,-0.468983,-1.894113,-0.035792,...,1.204646,-0.688365,-1.042315,2.571737,-0.085614,-3.472259,1.436653,-1.054814,1.873788,1.680525
DOSBIO001_NPC_24H:BRD-K09069264:10.2084,-1.401400,0.308703,1.178614,-2.114849,-0.020324,-0.393869,-2.599080,-0.983008,-0.063675,-0.549799,...,0.349096,0.017305,0.356195,0.638253,0.862676,-0.106953,1.115011,2.205899,-0.306434,1.101611
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
DOSVAL001_MCF7_24H:BRD-K30867024:10,-0.228277,-0.574911,0.545074,1.320753,-0.102422,1.058099,0.366014,-1.520461,-0.004073,-0.637326,...,-0.057771,-0.660938,-1.952505,0.395937,1.790094,1.072335,1.657834,1.445156,2.031223,0.026611
DOSVAL002_A375_24H:BRD-K10749593:20,4.415511,0.608378,1.604217,-0.911175,-2.611416,-1.742975,-2.500287,-2.503129,-3.472708,3.008501,...,0.916562,-1.403280,0.479218,4.528471,1.701896,0.141621,1.953133,-1.480089,1.549125,1.414482
DOSVAL002_MCF7_24H:BRD-K14002526:20,1.311693,1.834785,1.277888,-0.224320,-0.365258,0.209443,0.166746,-2.112468,-0.870127,-0.083894,...,0.894919,-0.707055,0.519019,0.916627,0.710227,0.126153,2.277303,-1.870382,1.021850,1.199542
DOSVAL001_HT29_24H:BRD-K11624501:9.99164,1.540175,-0.196926,-0.094410,-1.951286,-2.848082,-2.478519,-1.257487,-1.247405,-4.006328,-0.362494,...,-0.371798,3.735393,2.011243,1.693114,2.924200,2.535851,1.861230,-3.021530,0.127304,0.980487


In [5]:
# Baseline (control) expression profiles, indexed by control signature id,
# with the same gene columns as `cmap`.
cmap_controls = pd.read_csv(
    '../preprocessing/preprocessed_data/baselineCell/cmap_all_baselines_q1.csv',
    index_col=0,
)
display(cmap_controls)

Unnamed: 0,16,23,25,30,39,47,102,128,142,154,...,94239,116832,124583,147179,148022,200081,200734,256364,375346,388650
OFL001_A549_96H:G15,1.854175,1.868439,-0.140405,-0.278911,0.396597,0.334116,0.473704,-0.565553,1.372410,1.181299,...,1.252141,-0.291923,1.193942,0.978987,2.381282,-1.065447,1.174847,-0.885704,0.879203,0.216700
OFL001_MCF7_96H:J10,0.081511,0.651525,-0.205014,0.054704,0.726742,-0.126017,0.200712,0.915557,0.780285,0.007211,...,0.341261,0.405606,-0.054713,0.264261,-0.096964,0.752965,-0.249324,-1.176310,0.282062,-0.212717
ABY001_NCIH1975_XH:CMAP-000:-666:3,0.543459,1.647965,-1.731661,0.319534,1.078192,0.602553,0.323291,0.787790,0.888264,1.532468,...,0.704732,-1.326966,1.433667,-0.037051,1.016276,-0.481035,1.061352,1.616178,1.540468,-0.958139
ZTO.XPR001_THP1_408H:CMAP-000:-666,-0.054865,-0.085794,-0.319447,0.180520,0.124284,-0.117936,-0.267994,0.429114,-0.144781,0.190815,...,-0.114969,0.308555,0.055869,-0.450732,-0.394338,0.029793,0.046924,-0.231632,-0.186150,-0.309360
MOA001_A549_24H:N01,0.401776,1.197786,0.946556,0.794930,0.662958,0.473484,1.335021,0.338371,0.300303,0.690938,...,0.020668,0.171860,0.862337,0.525409,-0.029795,-0.263026,0.271724,0.934595,0.552001,-0.711617
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
DOSVAL001_HT29_24H:CMAP-000:-666,0.038320,-0.426547,0.183131,0.450992,-0.414180,-0.619587,-0.318295,0.066966,-0.618409,0.539847,...,0.085557,0.018541,0.174058,-0.101450,-0.279539,-0.303862,0.019368,1.111968,0.387193,-0.770082
DOSVAL001_HA1E_24H:CMAP-000:-666,0.319681,-0.182241,0.689418,0.542491,-0.124395,0.252069,-0.348502,0.145006,-0.018389,0.190280,...,-0.048443,0.188158,0.422073,0.123565,0.097611,-0.003442,0.924158,-0.212382,0.166562,0.142994
DOSVAL001_A375_24H:CMAP-000:-666,0.091151,-0.007194,0.360459,0.430177,-0.443078,-0.370296,-0.450974,0.616529,0.258591,0.111886,...,-0.571711,0.084373,0.240619,-0.372428,-0.168089,-0.137313,0.157594,0.256362,0.080780,-0.065995
DOSVAL001_HEPG2_24H:CMAP-000:-666,-0.276361,-0.321295,0.412983,0.040179,-0.144093,-0.374313,-0.488024,0.273988,-0.278131,-0.075510,...,-0.440074,0.220422,-0.144075,-0.162023,-0.328652,-0.300582,0.469960,-0.533808,0.158130,-0.492051


# Train autoencoder with all q1 controls

In [6]:
# model_params = {'encoder_hiddens':[640,384],
#                 'latent_dim': 292,
#                 'decoder_hiddens':[384,640],
#                 'dropout_decoder':0.2,
#                 'dropout_encoder':0.1,
#                 'encoder_activation':torch.nn.ELU(),
#                 'decoder_activation':torch.nn.ELU(),
#                 'lr':0.001,
#                 'schedule_step':200,
#                 'gamma':0.5,
#                 'batch_size':128,
#                 'epochs':500,
#                 'no_folds':10,
#                 'enc_l2_reg':0.01,
#                 'dec_l2_reg':0.01,
#                 'autoencoder_wd': 0.}

In [7]:
# bs=model_params['batch_size']
# k_folds=model_params['no_folds']
# NUM_EPOCHS=model_params['epochs']
# kfold=KFold(n_splits=k_folds,shuffle=True)

In [8]:
# dataset = torch.utils.data.TensorDataset(torch.tensor(X).float())

In [9]:
# valPear = []
# trainPear = []
# for fold,(train_idx,val_idx) in enumerate(kfold.split(dataset)):
    
#     X_train = torch.tensor(X[train_idx,:]).float().to(device)
#     X_val = torch.tensor(X[val_idx,:]).float().to(device)
#     N = X_train.shape[0]
    
#     decoder = Decoder(model_params['latent_dim'],model_params['decoder_hiddens'],gene_size,
#                       dropRate=model_params['dropout_decoder'], 
#                       activation=model_params['decoder_activation']).to(device)
#     encoder = SimpleEncoder(gene_size,model_params['encoder_hiddens'],model_params['latent_dim'],
#                             dropRate=model_params['dropout_encoder'], 
#                             activation=model_params['encoder_activation']).to(device)
    
#     allParams = list(decoder.parameters()) + list(encoder.parameters())
#     optimizer = torch.optim.Adam(allParams, lr=model_params['lr'])
#     scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
#                                                 step_size=model_params['schedule_step'],
#                                                 gamma=model_params['gamma'])
#     trainLoss = []
#     trainLossSTD = []
#     for e in range(NUM_EPOCHS):
#         encoder.train()
#         decoder.train()
        
#         trainloader = getSamples(N, bs)
#         trainLoss_ALL = []
#         for dataIndex in trainloader:
            
#             dataIn = X_train[dataIndex,:]
            
#             optimizer.zero_grad()
#             Xhat = decoder(encoder(dataIn))
#             L2Loss = encoder.L2Regularization(model_params['enc_l2_reg']) + decoder.L2Regularization(model_params['dec_l2_reg'])
#             loss = torch.mean(torch.sum((Xhat - dataIn)**2,dim=1)) + L2Loss
            
#             loss.backward()
#             optimizer.step()
            
#             pear = pearson_r(Xhat.detach(), dataIn.detach())
#             trainLoss_ALL.append(loss.item())
#         if e%100==0:
#             outString = 'Fold={:.0f}'.format(fold)
#             outString += ', Epoch={:.0f}/{:.0f}'.format(e+1,NUM_EPOCHS)
#             outString += ', loss={:.4f}'.format(loss.item())
#             outString += ', Pearson`s r={:.4f}'.format(pear.item())
#             print(outString)
#         scheduler.step()
#         trainLoss.append(np.mean(trainLoss_ALL))
#         trainLossSTD.append(np.std(trainLoss_ALL))
#     outString = 'Fold={:.0f}'.format(fold)
#     outString += ', Epoch={:.0f}/{:.0f}'.format(e+1,NUM_EPOCHS)
#     outString += ', loss={:.4f}'.format(loss.item())
#     outString += ', Pearson`s r={:.4f}'.format(pear.item())
#     print(outString)
    
#     encoder.eval()
#     decoder.eval()
    
#     Xhat = decoder(encoder(X_val))
#     pear = pearson_r(Xhat.detach(), X_val.detach())
    
#     outString = 'Validation performance: Fold={:.0f}'.format(fold)
#     outString += ', Pearson`s r={:.4f}'.format(pear.item())
#     print(outString)
#     valPear.append(pear.item())
    
#     Xhat = decoder(encoder(X_train))
#     pear = pearson_r(Xhat.detach(), X_train.detach())
#     trainPear.append(pear.item())
    
#     plt.plot(range(1,model_params['epochs']+1),np.array(trainLoss))
#     curColor = plt.gca().lines[-1].get_color()
#     plt.fill_between(range(1,model_params['epochs']+1), 
#                     np.array(trainLoss) - np.array(trainLossSTD), 
#                     np.array(trainLoss) + np.array(trainLossSTD),
#                     color=curColor, alpha=0.2)
#     plt.ylabel('Loss')
#     plt.xlabel('Epoch')

In [10]:
# results = pd.DataFrame({'r':trainPear,
#                         'set':'train'})
# results = results.append(pd.DataFrame({'r':valPear,
#                         'set':'validation'}))

In [11]:
# results
# results.to_csv('../results/BaselineCellsAnalysis/ae_results.csv')

In [12]:
# stats_results = mannwhitneyu(np.array(trainPear),np.array(valPear))
# print(stats_results)
# sns.set_theme(style="whitegrid")
# sns.boxplot(data=results,x="set",y = "r")
# plt.axhline(y=0.7, color='black', linestyle='--',linewidth=3)
# plt.show()

# Train translation model

In [13]:
# model_params = {'encoder_1_hiddens':[384,256],
#                 'encoder_2_hiddens':[384,256],
#                 'latent_dim': 128,
#                 'decoder_1_hiddens':[256,384],
#                 'decoder_2_hiddens':[256,384],
#                 'dropout_decoder':0.2,
#                 'dropout_encoder':0.1,
#                 'encoder_activation':torch.nn.ELU(),
#                 'decoder_activation':torch.nn.ELU(),
#                 'state_class_hidden':[256,128,64],
#                 'state_class_drop_in':0.5,
#                 'state_class_drop':0.25,
#                 'no_states':2,
#                 'encoding_lr':0.001,
#                 'schedule_step_enc':200,
#                 'gamma_enc':0.8,
#                 'batch_size':512,
#                 'epochs':1000,
#                 'prior_beta':1.0,
#                 'no_folds':5,
#                 'state_class_reg':1e-02,
#                 'enc_l2_reg':0.001,
#                 'dec_l2_reg':0.001,
#                 'lambda_mi_loss':100,
#                 'cosine_loss': 10,
#                 'reg_classifier': 10,
#                 'similarity_reg' : 10.,
#                 'autoencoder_wd': 0}

In [14]:
# Hyperparameters for the paired encoder/decoder ("translation") model.
# Written with dict(...) keyword syntax; key order and values match the
# original literal exactly.
# NOTE(review): keys without a comment below are not referenced by the
# training cell shown here; presumably they are used elsewhere — verify.
model_params = dict(
    # --- architecture of the two encoder/decoder pairs ---
    encoder_1_hiddens=[384, 256],   # hidden widths of encoder_1 (SimpleEncoder)
    encoder_2_hiddens=[384, 256],   # hidden widths of encoder_2
    latent_dim=128,                 # size of the latent vectors z_1 / z_2
    decoder_1_hiddens=[256, 384],   # hidden widths of decoder_1 (Decoder)
    decoder_2_hiddens=[256, 384],   # hidden widths of decoder_2
    dropout_decoder=0.2,            # dropRate passed to both decoders
    dropout_encoder=0.1,            # dropRate passed to both encoders
    # Two distinct activation module instances: one for encoders, one for decoders.
    encoder_activation=torch.nn.ELU(),
    decoder_activation=torch.nn.ELU(),
    V_dropout=0.25,
    # --- classifier heads ---
    state_class_hidden=[128, 64, 32],
    state_class_drop_in=0.5,
    state_class_drop=0.25,
    no_states=2,
    adv_class_hidden=[128, 64, 32],
    adv_class_drop_in=0.3,
    adv_class_drop=0.1,
    no_adv_class=2,
    # --- optimisation ---
    encoding_lr=0.001,              # Adam learning rate for the autoencoder parameters
    adv_lr=0.001,
    schedule_step_adv=200,
    gamma_adv=0.5,
    schedule_step_enc=200,          # StepLR step_size
    gamma_enc=0.8,                  # StepLR decay factor
    batch_size=512,                 # default; the training loop uses 256 for small training sets
    epochs=1000,
    prior_beta=1.0,                 # multiplies prior_loss
    no_folds=5,                     # number of pre-computed train/val splits per cell line
    # --- regularisation and loss weights ---
    v_reg=1e-04,
    state_class_reg=1e-02,
    enc_l2_reg=0.01,                # coefficient for encoder L2Regularization
    dec_l2_reg=0.01,                # coefficient for decoder L2Regularization
    lambda_mi_loss=100,             # weight of mi_loss
    effsize_reg=100,
    cosine_loss=10,                 # weight of cosineLoss (subtracted from the total loss)
    adv_penalnty=100,               # (sic) spelling kept; code may look the key up by this name
    reg_adv=1000,
    reg_classifier=1000,
    similarity_reg=10,              # weight of the latent-distance term
    adversary_steps=4,
    autoencoder_wd=0.,
    adversary_wd=0.,
)

In [15]:
# Training-loop defaults taken from model_params.
# `bs` is only a starting value: the training cell re-derives it for every split.
NUM_EPOCHS = model_params['epochs']
bs = model_params['batch_size']

In [16]:
# Input/output width of each encoder/decoder pair: half of the genes (floored).
# NOTE(review): both halves are sized with num_genes, but genes_2 is the complement
# of genes_1; if gene_size is odd, genes_2 has one gene more — confirm gene_size is even.
num_genes = gene_size // 2
# Number of gene splits (genes_1 subsets) evaluated per cell line.
random_iterations = 5

In [17]:
# lens = []
# for j in range(10000):
#     genes_1_all =[]
#     for i in range(5):
#         genes_1 = np.random.choice(cmap.columns.values, size=num_genes, replace=False)
#         genes_1_all = genes_1_all + list(genes_1)
#     genes_1_all = list(set(genes_1_all))
#     #print(len(genes_1_all))
#     lens.append(len(genes_1_all))
# lens = 100 * np.array(lens) / gene_size
# plt.hist(lens,20)
# plt.xlabel('% of landmarks seen in one part')

In [18]:
df_result_all = pd.DataFrame({})
for cell in ["PC3","HT29","MCF7","A549","NPC","HEPG2","A375","YAPC","U2OS","MCF10A","HA1E","HCC515","ASC","VCAP","HUVEC","HELA"]:
    df_result_all = pd.DataFrame({})
    for j in range(random_iterations):
        #genes_1 = np.random.choice(cmap.columns.values, size=num_genes, replace=False)
        genes_1 = np.load('../results/SameCellimputationModel/genes_subsets/genes_1'+cell+'_iter'+str(j)+'.npy',allow_pickle=True)
        genes_2 = np.setdiff1d(cmap.columns.values,genes_1)
        valPear = []
        valPear_1 = []
        valPear_2 = []
        valPear_controls = []
        valPear_controls_1 = []
        valPear_controls_2 = []
        
        valPear_shuffled = []
        valPear_1_shuffled = []
        valPear_2_shuffled = []
        valPear_controls_shuffled = []
        valPear_controls_1_shuffled = []
        valPear_controls_2_shuffled = []
        
        for i in range(model_params['no_folds']):
            trainInfo = pd.read_csv('../preprocessing/preprocessed_data/SameCellimputationModel/'+cell+'/train_'+str(i)+'.csv',index_col=0)
            valInfo = pd.read_csv('../preprocessing/preprocessed_data/SameCellimputationModel/'+cell+'/val_'+str(i)+'.csv',index_col=0)
            
            if len(trainInfo)<950:
                bs = 256
            else:
                bs = model_params['batch_size']
            
            cmap_train = cmap.loc[trainInfo.sig_id,:]
            cols = cmap_train.columns.values
            cmap_train_shuffled = cmap_train.sample(frac=1, axis=1)
            cmap_train_shuffled.columns = cols
            cmap_val = cmap.loc[valInfo.sig_id,:]
            N = len(cmap_train)

            # Network
            decoder_1 = Decoder(model_params['latent_dim'],model_params['decoder_1_hiddens'],num_genes,
                                dropRate=model_params['dropout_decoder'], 
                                activation=model_params['decoder_activation']).to(device)
            decoder_2 = Decoder(model_params['latent_dim'],model_params['decoder_2_hiddens'],num_genes,
                                dropRate=model_params['dropout_decoder'], 
                                activation=model_params['decoder_activation']).to(device)
            encoder_1 = SimpleEncoder(num_genes,model_params['encoder_1_hiddens'],model_params['latent_dim'],
                                      dropRate=model_params['dropout_encoder'], 
                                      activation=model_params['encoder_activation'],
                                     normalizeOutput=False).to(device)
            encoder_2 = SimpleEncoder(num_genes,model_params['encoder_2_hiddens'],model_params['latent_dim'],
                                          dropRate=model_params['dropout_encoder'], 
                                          activation=model_params['encoder_activation'],
                                     normalizeOutput=False).to(device)
            prior_d = PriorDiscriminator(model_params['latent_dim']).to(device)
            local_d = LocalDiscriminator(model_params['latent_dim'],model_params['latent_dim']).to(device)

            allParams = list(decoder_1.parameters()) + list(encoder_1.parameters())
            allParams = allParams + list(decoder_2.parameters()) + list(encoder_2.parameters())
            allParams = allParams  + list(local_d.parameters())
            allParams = allParams + list(prior_d.parameters())
            optimizer = torch.optim.Adam(allParams, lr=model_params['encoding_lr'])
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                        step_size=model_params['schedule_step_enc'],
                                                        gamma=model_params['gamma_enc'])
            trainLoss = []
            trainLossSTD = []
            for e in range(NUM_EPOCHS):
                encoder_1.train()
                decoder_1.train()
                encoder_2.train()
                decoder_2.train()
                prior_d.train()
                local_d.train()

                trainloader = getSamples(N, bs)
                trainLoss_ALL = []
                for dataIndex in trainloader:

                    data_1 = torch.tensor(cmap_train.loc[:,genes_1].values).float()
                    data_1 = data_1[dataIndex,:].to(device)
                    data_2 = torch.tensor(cmap_train.loc[:,genes_2].values).float()
                    data_2 = data_2[dataIndex,:].to(device)
                    
                    conditions = trainInfo.conditionId.values[dataIndex]
                    conditions = np.concatenate((conditions,conditions))
                    size = conditions.size
                    conditions = conditions.reshape(size,1)
                    conditions = conditions == conditions.transpose()
                    conditions = conditions*1
                    mask = torch.tensor(conditions).to(device).detach()
                    pos_mask = mask
                    neg_mask = 1 - mask
                    log_2 = math.log(2.)
                
                    optimizer.zero_grad()
                    
                    z_1 = encoder_1(data_1)
                    z_2 = encoder_2(data_2)
                    
                    #print('Epoch %s'%e)
                    #print(z_1)
                    
                    latent_vectors = torch.cat((z_1, z_2), 0)
                    z_un = local_d(latent_vectors)
                    #z_un_1 = local_d(z_1)
                    #z_un_2 = local_d(z_2)
                    #res_un = torch.matmul(z_un_1, z_un_2.t())
                    res_un = torch.matmul(z_un, z_un.t())
                    
                    #print(res_un)
                    
                    Xhat_1 = decoder_1(z_1)
                    Xhat_2 = decoder_2(z_2)
                    loss_1 = torch.mean(torch.sum((Xhat_1 - data_1)**2,dim=1)) + encoder_1.L2Regularization(model_params['enc_l2_reg']) + decoder_1.L2Regularization(model_params['dec_l2_reg'])
                    loss_2 = torch.mean(torch.sum((Xhat_2 - data_2)**2,dim=1)) +encoder_2.L2Regularization(model_params['enc_l2_reg']) + decoder_2.L2Regularization(model_params['dec_l2_reg'])
                    
                    silimalityLoss = torch.sum(torch.cdist(latent_vectors, latent_vectors) * pos_mask.float()) / pos_mask.float().sum()
                    #silimalityLoss = torch.mean(torch.cdist(z_1, z_2))
                    w1 = latent_vectors.norm(p=2, dim=1, keepdim=True)
                    w2 = latent_vectors.norm(p=2, dim=1, keepdim=True)
                    cosineLoss = torch.mm(latent_vectors, latent_vectors.t()) / (w1 * w2.t()).clamp(min=1e-6)
                    cosineLoss = torch.sum(cosineLoss * pos_mask.float()) / pos_mask.float().sum()
                    #cosineLoss = torch.mean(cosineLoss)

                    p_samples = res_un * pos_mask.float()
                    q_samples = res_un * neg_mask.float()

                    Ep = log_2 - F.softplus(- p_samples)
                    Eq = F.softplus(-q_samples) + q_samples - log_2

                    Ep = (Ep * pos_mask.float()).sum() / pos_mask.float().sum()
                    Eq = (Eq * neg_mask.float()).sum() / neg_mask.float().sum()
                    mi_loss = Eq - Ep
                    #mi_loss = torch.nan_to_num(mi_loss, nan=1e-03)

                    prior = torch.rand_like(torch.cat((z_1, z_2), 0))
                    term_a = torch.log(prior_d(prior)).mean()
                    term_b = torch.log(1.0 - prior_d(torch.cat((z_1, z_2), 0))).mean()
                    prior_loss = -(term_a + term_b) * model_params['prior_beta']
                    
                    loss = loss_1 + loss_2 + model_params[
                        'similarity_reg']*silimalityLoss - model_params[
                        'cosine_loss'] * cosineLoss + prior_loss + model_params['lambda_mi_loss'] * mi_loss
                    #loss = loss_1 + loss_2 + model_params[
                    #    'similarity_reg'] * silimalityLoss + model_params[
                    #    'lambda_mi_loss'] * mi_loss + prior_loss 
                    loss.backward()
                    optimizer.step()
                    
                    #w1 = latent_vectors.norm(p=2, dim=1, keepdim=True)
                    #w2 = latent_vectors.norm(p=2, dim=1, keepdim=True)
                    #cosineLoss = torch.mm(latent_vectors, latent_vectors.t()) / (w1 * w2.t()).clamp(min=1e-6)
                    #cosineLoss = torch.sum(cosineLoss * pos_mask.float()) / pos_mask.float().sum()

                    pear_1 = pearson_r(Xhat_1.detach(), data_1.detach())
                    mse_1 = torch.mean(torch.mean((Xhat_1.detach() - data_1.detach()) ** 2, dim=1))
                    pear_2 = pearson_r(Xhat_2.detach(), data_2.detach())
                    mse_2 = torch.mean(torch.mean((Xhat_2.detach() - data_2.detach()) ** 2, dim=1))
                    trainLoss_ALL.append(loss.item())
                    
                if e%250==0:
                    outString = 'Cell-line : '+cell+', rand_iter {:.0f}/{:.0f}'.format(j + 1, random_iterations)
                    outString += ', Split {:.0f}: Epoch={:.0f}/{:.0f}'.format(i + 1, e + 1, NUM_EPOCHS)
                    outString += ', pearson_1={:.4f}'.format(pear_1.item())
                    outString += ', MSE_1={:.4f}'.format(mse_1.item())
                    outString += ', pearson_2={:.4f}'.format(pear_2.item())
                    outString += ', MSE_2={:.4f}'.format(mse_2.item())
                    outString += ', MI Loss={:.4f}'.format(mi_loss.item())
                    outString += ', Prior Loss={:.4f}'.format(prior_loss.item())
                    outString += ', Cosine ={:.4f}'.format(cosineLoss.item())
                    outString += ', silimalityLoss ={:.4f}'.format(silimalityLoss.item())
                    outString += ', loss={:.4f}'.format(loss.item())
                    print(outString)
                scheduler.step()
                trainLoss.append(np.mean(trainLoss_ALL))
                trainLossSTD.append(np.std(trainLoss_ALL))
            outString = 'Cell-line : '+cell+', rand_iter {:.0f}/{:.0f}'.format(j + 1, random_iterations)
            outString += ', Split {:.0f}: Epoch={:.0f}/{:.0f}'.format(i + 1, e + 1, NUM_EPOCHS)
            outString += ', pearson_1={:.4f}'.format(pear_1.item())
            outString += ', MSE_1={:.4f}'.format(mse_1.item())
            outString += ', pearson_2={:.4f}'.format(pear_2.item())
            outString += ', MSE_2={:.4f}'.format(mse_2.item())
            outString += ', MI Loss={:.4f}'.format(mi_loss.item())
            outString += ', Prior Loss={:.4f}'.format(prior_loss.item())
            outString += ', Cosine ={:.4f}'.format(cosineLoss.item())
            outString += ', silimalityLoss ={:.4f}'.format(silimalityLoss.item())
            outString += ', loss={:.4f}'.format(loss.item())
            print(outString)

            encoder_1.eval()
            decoder_1.eval()
            encoder_2.eval()
            decoder_2.eval()
            prior_d.eval()
            local_d.eval()
            
            print('Validation performance for cell %s for try %s for split %s'%(cell,j+1,i+1))


            X_1 = torch.tensor(cmap_val.loc[:,genes_1].values).float().to(device)
            X_2 = torch.tensor(cmap_val.loc[:,genes_2].values).float().to(device)
                    
            z_1 = encoder_1(X_1)
            z_2 = encoder_2(X_2)
            Xhat_1 = decoder_1(z_1)
            Xhat_2 = decoder_2(z_2)
            
            pear_1 = pearson_r(Xhat_1.detach(), X_1.detach())
            pear_2 = pearson_r(Xhat_2.detach(), X_2.detach())
            valPear_1.append(pear_1.item())
            valPear_2.append(pear_2.item())

            print('Pearson correlation 1: %s'%pear_1.item())
            print('Pearson correlation 2: %s'%pear_2.item())
    
    
            x_hat_2_equivalent = decoder_2(z_1).detach()
            pearson_2 = pearson_r(x_hat_2_equivalent.detach(), X_2.detach())
            print('Pearson correlation 1 to 2: %s'%pearson_2.item())
            x_hat_1_equivalent = decoder_1(z_2).detach()
            pearson_1 = pearson_r(x_hat_1_equivalent.detach(), X_1.detach())
            print('Pearson correlation 2 to 1: %s'%pearson_1.item())
            
            valPear.append([pearson_2.item(),pearson_1.item()])
            
            ### Controls analysis
            #if cell not in ["HUVEC","HELA","ASC","YAPC"]:
            #    controls_info = pd.read_csv('../preprocessing/preprocessed_data/SameCellimputationModel/'+cell+'/controls.csv',index_col=0)
            #    print('Controls Performance')
            #    X_1 = torch.tensor(cmap_controls.loc[controls_info.sig_id,genes_1].values).float().to(device)
            #    X_2 = torch.tensor(cmap_controls.loc[controls_info.sig_id,genes_2].values).float().to(device)
            #    z_1 = encoder_1(X_1)
            #    z_2 = encoder_2(X_2)
            #    Xhat_1 = decoder_1(z_1)
            #    Xhat_2 = decoder_2(z_2)
            #    pear_1 = pearson_r(Xhat_1.detach(), X_1.detach())
            #    pear_2 = pearson_r(Xhat_2.detach(), X_2.detach())
            #    valPear_controls_1.append(pear_1.item())
            #    valPear_controls_2.append(pear_2.item())
            #    print('Pearson correlation 1: %s'%pear_1.item())
            #    print('Pearson correlation 2: %s'%pear_2.item())
            #    x_hat_2_equivalent = decoder_2(z_1).detach()
            #    pearson_2 = pearson_r(x_hat_2_equivalent.detach(), X_2.detach())
            #   print('Pearson correlation 1 to 2: %s'%pearson_2.item())
            #   x_hat_1_equivalent = decoder_1(z_2).detach()
            #    pearson_1 = pearson_r(x_hat_1_equivalent.detach(), X_1.detach())
            #    print('Pearson correlation 2 to 1: %s'%pearson_1.item())
            #    valPear_controls.append([pearson_2.item(),pearson_1.item()])
            
            torch.save(decoder_1,'../results/test_stuff/models_AutoTransOp/'+cell+'/decoder_1_fold%s_iter%s.pt'%(i,j))
            torch.save(decoder_2,'../results/test_stuff/models_AutoTransOp/'+cell+'/decoder_2_fold%s_iter%s.pt'%(i,j))
            torch.save(prior_d,'../results/test_stuff/models_AutoTransOp/'+cell+'/priorDiscr_fold%s_iter%s.pt'%(i,j))
            torch.save(local_d,'../results/test_stuff/models_AutoTransOp/'+cell+'/localDiscr_fold%s_iter%s.pt'%(i,j))
            torch.save(encoder_1,'../results/test_stuff/models_AutoTransOp/'+cell+'/encoder_1_fold%s_iter%s.pt'%(i,j))
            torch.save(encoder_2,'../results/test_stuff/models_AutoTransOp/'+cell+'/encoder_2_fold%s_iter%s.pt'%(i,j))
            
            ###################### Re-do for shuffling ###################
            print('Begin training random shuffled model')
            decoder_1 = Decoder(model_params['latent_dim'],model_params['decoder_1_hiddens'],num_genes,
                                dropRate=model_params['dropout_decoder'], 
                                activation=model_params['decoder_activation']).to(device)
            decoder_2 = Decoder(model_params['latent_dim'],model_params['decoder_2_hiddens'],num_genes,
                                dropRate=model_params['dropout_decoder'], 
                                activation=model_params['decoder_activation']).to(device)
            encoder_1 = SimpleEncoder(num_genes,model_params['encoder_1_hiddens'],model_params['latent_dim'],
                                        dropRate=model_params['dropout_encoder'], 
                                        activation=model_params['encoder_activation'],
                                        normalizeOutput=False).to(device)
            encoder_2 = SimpleEncoder(num_genes,model_params['encoder_2_hiddens'],model_params['latent_dim'],
                                            dropRate=model_params['dropout_encoder'], 
                                            activation=model_params['encoder_activation'],
                                        normalizeOutput=False).to(device)
            prior_d = PriorDiscriminator(model_params['latent_dim']).to(device)
            local_d = LocalDiscriminator(model_params['latent_dim'],model_params['latent_dim']).to(device)
            allParams = list(decoder_1.parameters()) + list(encoder_1.parameters())
            allParams = allParams + list(decoder_2.parameters()) + list(encoder_2.parameters())
            allParams = allParams  + list(local_d.parameters())
            allParams = allParams + list(prior_d.parameters())
            optimizer = torch.optim.Adam(allParams, lr=model_params['encoding_lr'])
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                            step_size=model_params['schedule_step_enc'],
                                                            gamma=model_params['gamma_enc'])
            # Loop-invariant inputs, prepared once rather than on every mini-batch.
            # cmap_train_shuffled / trainInfo are not modified inside this loop, so converting
            # the whole DataFrame to a tensor per batch (as before) was pure overhead.
            train_1_full = torch.tensor(cmap_train_shuffled.loc[:,genes_1].values).float()
            train_2_full = torch.tensor(cmap_train_shuffled.loc[:,genes_2].values).float()
            train_condition_ids = trainInfo.conditionId.values
            log_2 = math.log(2.)
            for e in range(NUM_EPOCHS):
                encoder_1.train()
                decoder_1.train()
                encoder_2.train()
                decoder_2.train()
                prior_d.train()
                local_d.train()
                # Fresh random partition of the N training rows into batches of size bs.
                trainloader = getSamples(N, bs)
                trainLoss_ALL = []
                for dataIndex in trainloader:
                    data_1 = train_1_full[dataIndex,:].to(device)
                    data_2 = train_2_full[dataIndex,:].to(device)
                    # Pair mask over the stacked latent batch [z_1; z_2]: entry (a, b) is 1 when
                    # rows a and b share a conditionId. The diagonal is always 1, and each row is
                    # also paired with the same sample's row from the other gene set.
                    conditions = train_condition_ids[dataIndex]
                    conditions = np.concatenate((conditions,conditions))
                    size = conditions.size
                    conditions = conditions.reshape(size,1)
                    conditions = conditions == conditions.transpose()
                    conditions = conditions*1
                    mask = torch.tensor(conditions).to(device).detach()
                    pos_mask = mask
                    neg_mask = 1 - mask
                    optimizer.zero_grad()
                    z_1 = encoder_1(data_1)
                    z_2 = encoder_2(data_2)
                    latent_vectors = torch.cat((z_1, z_2), 0)
                    z_un = local_d(latent_vectors)
                    res_un = torch.matmul(z_un, z_un.t())
                    Xhat_1 = decoder_1(z_1)
                    Xhat_2 = decoder_2(z_2)
                    # Reconstruction loss per gene set: per-sample sum of squared errors averaged
                    # over the batch, plus L2 weight penalties on the encoder and decoder.
                    loss_1 = torch.mean(torch.sum((Xhat_1 - data_1)**2,dim=1)) + encoder_1.L2Regularization(model_params['enc_l2_reg']) + decoder_1.L2Regularization(model_params['dec_l2_reg'])
                    loss_2 = torch.mean(torch.sum((Xhat_2 - data_2)**2,dim=1)) +encoder_2.L2Regularization(model_params['enc_l2_reg']) + decoder_2.L2Regularization(model_params['dec_l2_reg'])
                    # Mean pairwise Euclidean distance between latent codes of positive pairs.
                    silimalityLoss = torch.sum(torch.cdist(latent_vectors, latent_vectors) * pos_mask.float()) / pos_mask.float().sum()
                    # Mean cosine similarity over positive pairs (subtracted in `loss`, i.e. maximised).
                    latent_norms = latent_vectors.norm(p=2, dim=1, keepdim=True)
                    cosineLoss = torch.mm(latent_vectors, latent_vectors.t()) / (latent_norms * latent_norms.t()).clamp(min=1e-6)
                    cosineLoss = torch.sum(cosineLoss * pos_mask.float()) / pos_mask.float().sum()
                    # Mutual-information term with the Jensen-Shannon form used in Deep InfoMax:
                    # minimising mi_loss raises local_d scores of positive pairs and lowers
                    # those of negative pairs.
                    p_samples = res_un * pos_mask.float()
                    q_samples = res_un * neg_mask.float()
                    Ep = log_2 - F.softplus(- p_samples)
                    Eq = F.softplus(-q_samples) + q_samples - log_2
                    Ep = (Ep * pos_mask.float()).sum() / pos_mask.float().sum()
                    Eq = (Eq * neg_mask.float()).sum() / neg_mask.float().sum()
                    mi_loss = Eq - Ep
                    # Prior-matching term: prior_d scores uniform [0, 1) samples against the codes.
                    # NOTE(review): the encoders and prior_d share one optimizer and all minimise this
                    # same prior_loss, so the encoders are not trained adversarially against prior_d
                    # (they are also pushed to lower prior_d(z)). Confirm this is intended.
                    prior = torch.rand_like(latent_vectors)
                    term_a = torch.log(prior_d(prior)).mean()
                    term_b = torch.log(1.0 - prior_d(latent_vectors)).mean()
                    prior_loss = -(term_a + term_b) * model_params['prior_beta']
                    loss = loss_1 + loss_2 + model_params[
                        'similarity_reg']*silimalityLoss - model_params[
                        'cosine_loss'] * cosineLoss + prior_loss + model_params['lambda_mi_loss'] * mi_loss
                    loss.backward()
                    optimizer.step()
                    # Diagnostics; only the values from the last batch of an epoch get reported.
                    pear_1 = pearson_r(Xhat_1.detach(), data_1.detach())
                    mse_1 = torch.mean(torch.mean((Xhat_1.detach() - data_1.detach()) ** 2, dim=1))
                    pear_2 = pearson_r(Xhat_2.detach(), data_2.detach())
                    mse_2 = torch.mean(torch.mean((Xhat_2.detach() - data_2.detach()) ** 2, dim=1))
                    trainLoss_ALL.append(loss.item())
                if e%250==0:
                    outString = 'Cell-line : '+cell+', rand_iter {:.0f}/{:.0f}'.format(j + 1, random_iterations)
                    outString += ', Split {:.0f}: Epoch={:.0f}/{:.0f}'.format(i + 1, e + 1, NUM_EPOCHS)
                    outString += ', pearson_1={:.4f}'.format(pear_1.item())
                    outString += ', MSE_1={:.4f}'.format(mse_1.item())
                    outString += ', pearson_2={:.4f}'.format(pear_2.item())
                    outString += ', MSE_2={:.4f}'.format(mse_2.item())
                    outString += ', MI Loss={:.4f}'.format(mi_loss.item())
                    outString += ', Prior Loss={:.4f}'.format(prior_loss.item())
                    outString += ', Cosine ={:.4f}'.format(cosineLoss.item())
                    outString += ', silimalityLoss ={:.4f}'.format(silimalityLoss.item())
                    outString += ', loss={:.4f}'.format(loss.item())
                    print(outString)
                # StepLR is stepped once per epoch.
                scheduler.step()
                trainLoss.append(np.mean(trainLoss_ALL))
                trainLossSTD.append(np.std(trainLoss_ALL))
            # End-of-training summary for this fold. It reports the same fields as the periodic
            # in-loop report, using the metrics of the last batch of the last epoch.
            summary_fields = [
                'Cell-line : ',
                cell,
                ', rand_iter {:.0f}/{:.0f}'.format(j + 1, random_iterations),
                ', Split {:.0f}: Epoch={:.0f}/{:.0f}'.format(i + 1, e + 1, NUM_EPOCHS),
                ', pearson_1={:.4f}'.format(pear_1.item()),
                ', MSE_1={:.4f}'.format(mse_1.item()),
                ', pearson_2={:.4f}'.format(pear_2.item()),
                ', MSE_2={:.4f}'.format(mse_2.item()),
                ', MI Loss={:.4f}'.format(mi_loss.item()),
                ', Prior Loss={:.4f}'.format(prior_loss.item()),
                ', Cosine ={:.4f}'.format(cosineLoss.item()),
                ', silimalityLoss ={:.4f}'.format(silimalityLoss.item()),
                ', loss={:.4f}'.format(loss.item()),
            ]
            outString = ''.join(summary_fields)
            print(outString)
            encoder_1.eval()
            decoder_1.eval()
            encoder_2.eval()
            decoder_2.eval()
            prior_d.eval()
            local_d.eval()
            print('Shuffled performance for cell %s for try %s for split %s'%(cell,j+1,i+1))
            # Pure inference on the validation split. torch.no_grad() stops autograd from
            # recording a graph over the whole validation set; previously one was built and only
            # discarded through .detach(). The reported values are unchanged.
            with torch.no_grad():
                X_1 = torch.tensor(cmap_val.loc[:,genes_1].values).float().to(device)
                X_2 = torch.tensor(cmap_val.loc[:,genes_2].values).float().to(device)
                z_1 = encoder_1(X_1)
                z_2 = encoder_2(X_2)
                # Within-gene-set reconstruction quality.
                Xhat_1 = decoder_1(z_1)
                Xhat_2 = decoder_2(z_2)
                pear_1 = pearson_r(Xhat_1, X_1)
                pear_2 = pearson_r(Xhat_2, X_2)
                valPear_1_shuffled.append(pear_1.item())
                valPear_2_shuffled.append(pear_2.item())
                print('Pearson correlation 1: %s'%pear_1.item())
                print('Pearson correlation 2: %s'%pear_2.item())
                # Cross translation: decode gene set 2 from gene-set-1 codes, and vice versa.
                x_hat_2_equivalent = decoder_2(z_1)
                pearson_2 = pearson_r(x_hat_2_equivalent, X_2)
                print('Pearson correlation 1 to 2: %s'%pearson_2.item())
                x_hat_1_equivalent = decoder_1(z_2)
                pearson_1 = pearson_r(x_hat_1_equivalent, X_1)
                print('Pearson correlation 2 to 1: %s'%pearson_1.item())
                # Stored as [1->2, 2->1].
                valPear_shuffled.append([pearson_2.item(),pearson_1.item()])
            ### Controls analysis
            #if cell not in ["HUVEC","HELA","ASC","YAPC"]:
            #    controls_info = pd.read_csv('../preprocessing/preprocessed_data/SameCellimputationModel/'+cell+'/controls.csv',index_col=0)
            #    print('Controls Performance shuffled')
            #    X_1 = torch.tensor(cmap_controls.loc[controls_info.sig_id,genes_1].values).float().to(device)
            #    X_2 = torch.tensor(cmap_controls.loc[controls_info.sig_id,genes_2].values).float().to(device)
            #   z_1 = encoder_1(X_1)
            #    z_2 = encoder_2(X_2)
            #    Xhat_1 = decoder_1(z_1)
            #    Xhat_2 = decoder_2(z_2)
            #    pear_1 = pearson_r(Xhat_1.detach(), X_1.detach())
            #    pear_2 = pearson_r(Xhat_2.detach(), X_2.detach())
            #    valPear_controls_1_shuffled.append(pear_1.item())
            #   valPear_controls_2_shuffled.append(pear_2.item())
            #    print('Pearson correlation 1: %s'%pear_1.item())
            #    print('Pearson correlation 2: %s'%pear_2.item())
            #    x_hat_2_equivalent = decoder_2(z_1).detach()
            #    pearson_2 = pearson_r(x_hat_2_equivalent.detach(), X_2.detach())
            #   print('Pearson correlation 1 to 2: %s'%pearson_2.item())
            #    x_hat_1_equivalent = decoder_1(z_2).detach()
            #    pearson_1 = pearson_r(x_hat_1_equivalent.detach(), X_1.detach())
            #    print('Pearson correlation 2 to 1: %s'%pearson_1.item())
            #    valPear_controls_shuffled.append([pearson_2.item(),pearson_1.item()])
            # Persist every network of the shuffled control model for this fold (i) and
            # iteration (j). torch.save pickles the whole module objects, not state_dicts.
            shuffled_model_prefix = '../results/test_stuff/models_AutoTransOp/'+cell
            for network, file_pattern in (
                (decoder_1, '/shuffled/decoder_1_fold%s_iter%s.pt'),
                (decoder_2, '/shuffled/decoder_2_fold%s_iter%s.pt'),
                (prior_d, '/shuffled/priorDiscr_fold%s_iter%s.pt'),
                (local_d, '/shuffled/localDiscr_fold%s_iter%s.pt'),
                (encoder_1, '/shuffled/encoder_1_fold%s_iter%s.pt'),
                (encoder_2, '/shuffled/encoder_2_fold%s_iter%s.pt'),
            ):
                torch.save(network, shuffled_model_prefix + file_pattern % (i, j))
            
        # Collect per-fold validation metrics of the trained and shuffled models into one frame.
        valPear = np.array(valPear)
        valPear_shuffled = np.array(valPear_shuffled)
        # NOTE(review): shuffled rows are appended as [1->2, 2->1] (decoder_2(z_1) first; see the
        # 'Pearson correlation 1 to 2' print), yet column 0 is labelled 'model_pearson2to1'.
        # Confirm the labels match how valPear is filled. They are left unchanged here so that
        # existing result files stay comparable.
        df_result = pd.DataFrame({'model_pearson2to1':valPear[:,0],'model_pearson1to2':valPear[:,1],
                                  'recon_pear_2':valPear_2 ,'recon_pear_1':valPear_1})
        df_result_shuffled = pd.DataFrame({'model_pearson2to1':valPear_shuffled[:,0],'model_pearson1to2':valPear_shuffled[:,1],
                                           'recon_pear_2':valPear_2_shuffled ,'recon_pear_1':valPear_1_shuffled})
        df_result['model'] = 'model'
        df_result_shuffled['model'] = 'shuffled'
        df_result['set'] = 'validation'
        df_result_shuffled['set'] = 'validation'
        # DataFrame.append was deprecated in pandas 1.4 and removed in 2.0. pd.concat is the
        # drop-in replacement: same row order, original indices kept.
        df_result = pd.concat([df_result, df_result_shuffled])
        
        #df_result_controls = pd.DataFrame({'model_pearson2to1':[],'model_pearson1to2':[],
        #                                   'recon_pear_2':[] ,'recon_pear_1':[]})
        #df_result_controls['model'] = []
        #df_result_controls['set'] = []
        #if cell not in ["HUVEC","HELA","ASC","YAPC"]:
        #    valPear_controls_shuffled = np.array(valPear_controls_shuffled)
        #    valPear_controls = np.array(valPear_controls)
        #    df_result_controls = pd.DataFrame({'model_pearson2to1':valPear_controls[:,0],'model_pearson1to2':valPear_controls[:,1],
        #                                   'recon_pear_2':valPear_controls_2 ,'recon_pear_1':valPear_controls_1})
        #    df_result_controls_shuffled = pd.DataFrame({'model_pearson2to1':valPear_controls_shuffled[:,0],'model_pearson1to2':valPear_controls_shuffled[:,1],
        #                                   'recon_pear_2':valPear_controls_2_shuffled ,'recon_pear_1':valPear_controls_1_shuffled})
        #    df_result_controls['model'] = 'model'
        #    df_result_controls_shuffled['model'] = 'shuffled'
        #    df_result_controls['set'] = 'controls'
        #    df_result_controls_shuffled['set'] = 'controls'
        #    df_result_controls = pd.concat([df_result_controls, df_result_controls_shuffled])
        #df_result = pd.concat([df_result, df_result_controls])
        df_result['cell'] = cell
        df_result['iteration'] = j
        df_result_all = pd.concat([df_result_all, df_result])
        # Saved on every pass of the enclosing loop, so completed iterations are kept even if a
        # later one is interrupted.
        df_result_all.to_csv('../results/SameCellimputationModel/translation_results_'+cell+'.csv')#SameCellimputationModel

Cell-line : PC3, rand_iter 1/5, Split 1: Epoch=1/1000, pearson_1=0.0757, MSE_1=3.2391, pearson_2=0.0697, MSE_2=3.2477, MI Loss=1.1247, Prior Loss=1.2561, Cosine =0.5115, silimalityLoss =7.2794, loss=3381.1453
Cell-line : PC3, rand_iter 1/5, Split 1: Epoch=251/1000, pearson_1=0.7576, MSE_1=1.0358, pearson_2=0.7587, MSE_2=1.0355, MI Loss=-0.6234, Prior Loss=0.0002, Cosine =0.8188, silimalityLoss =2.7380, loss=995.0627
Cell-line : PC3, rand_iter 1/5, Split 1: Epoch=501/1000, pearson_1=0.7513, MSE_1=0.9129, pearson_2=0.7504, MSE_2=0.9065, MI Loss=-0.6427, Prior Loss=0.0000, Cosine =0.8159, silimalityLoss =2.1835, loss=862.1006
Cell-line : PC3, rand_iter 1/5, Split 1: Epoch=751/1000, pearson_1=0.7697, MSE_1=0.9017, pearson_2=0.7658, MSE_2=0.8981, MI Loss=-0.6516, Prior Loss=0.0000, Cosine =0.8282, silimalityLoss =1.8238, loss=846.0872
Cell-line : PC3, rand_iter 1/5, Split 1: Epoch=1000/1000, pearson_1=0.7750, MSE_1=0.8332, pearson_2=0.7652, MSE_2=0.8470, MI Loss=-0.6545, Prior Loss=0.0000, 

Cell-line : PC3, rand_iter 1/5, Split 4: Epoch=501/1000, pearson_1=0.7468, MSE_1=0.9120, pearson_2=0.7398, MSE_2=0.9249, MI Loss=-0.6386, Prior Loss=0.0000, Cosine =0.8252, silimalityLoss =2.3176, loss=872.8799
Cell-line : PC3, rand_iter 1/5, Split 4: Epoch=751/1000, pearson_1=0.8037, MSE_1=0.9594, pearson_2=0.7907, MSE_2=0.9780, MI Loss=-0.6498, Prior Loss=0.0000, Cosine =0.8422, silimalityLoss =1.8142, loss=913.9617
Cell-line : PC3, rand_iter 1/5, Split 4: Epoch=1000/1000, pearson_1=0.7907, MSE_1=0.8799, pearson_2=0.7844, MSE_2=0.9064, MI Loss=-0.6539, Prior Loss=0.0000, Cosine =0.8467, silimalityLoss =1.6274, loss=836.4551
Validation performance for cell PC3 for try 1 for split 4
Pearson correlation 1: 0.7560822367668152
Pearson correlation 2: 0.7473733425140381
Pearson correlation 1 to 2: 0.675345778465271
Pearson correlation 2 to 1: 0.6910034418106079
Begin training random shuffled model
Cell-line : PC3, rand_iter 1/5, Split 4: Epoch=1/1000, pearson_1=0.0649, MSE_1=2.7416, pearson

Cell-line : PC3, rand_iter 2/5, Split 2: Epoch=1000/1000, pearson_1=0.7860, MSE_1=0.8868, pearson_2=0.7868, MSE_2=0.8380, MI Loss=-0.6569, Prior Loss=0.0000, Cosine =0.8411, silimalityLoss =1.5693, loss=804.5872
Validation performance for cell PC3 for try 2 for split 2
Pearson correlation 1: 0.7621927857398987
Pearson correlation 2: 0.7675774693489075
Pearson correlation 1 to 2: 0.7004943490028381
Pearson correlation 2 to 1: 0.6958286166191101
Begin training random shuffled model
Cell-line : PC3, rand_iter 2/5, Split 2: Epoch=1/1000, pearson_1=0.0621, MSE_1=3.0386, pearson_2=0.0624, MSE_2=3.1628, MI Loss=1.2741, Prior Loss=1.2155, Cosine =0.5124, silimalityLoss =7.4107, loss=3257.8462
Cell-line : PC3, rand_iter 2/5, Split 2: Epoch=251/1000, pearson_1=0.7676, MSE_1=1.0410, pearson_2=0.7652, MSE_2=1.0491, MI Loss=-0.6257, Prior Loss=0.0001, Cosine =0.8230, silimalityLoss =2.7014, loss=1003.5240
Cell-line : PC3, rand_iter 2/5, Split 2: Epoch=501/1000, pearson_1=0.7694, MSE_1=0.8887, pears

Cell-line : PC3, rand_iter 2/5, Split 5: Epoch=251/1000, pearson_1=0.7160, MSE_1=1.0050, pearson_2=0.7150, MSE_2=1.0021, MI Loss=-0.6275, Prior Loss=0.0001, Cosine =0.8188, silimalityLoss =2.9381, loss=965.2155
Cell-line : PC3, rand_iter 2/5, Split 5: Epoch=501/1000, pearson_1=0.7685, MSE_1=0.9219, pearson_2=0.7685, MSE_2=0.9293, MI Loss=-0.6462, Prior Loss=0.0000, Cosine =0.8272, silimalityLoss =2.1921, loss=877.1718
Cell-line : PC3, rand_iter 2/5, Split 5: Epoch=751/1000, pearson_1=0.7754, MSE_1=0.9434, pearson_2=0.7802, MSE_2=0.9274, MI Loss=-0.6428, Prior Loss=0.0000, Cosine =0.8339, silimalityLoss =1.7431, loss=880.7794
Cell-line : PC3, rand_iter 2/5, Split 5: Epoch=1000/1000, pearson_1=0.7515, MSE_1=0.8636, pearson_2=0.7561, MSE_2=0.8558, MI Loss=-0.6534, Prior Loss=0.0000, Cosine =0.8310, silimalityLoss =1.6489, loss=803.3461
Shuffled performance for cell PC3 for try 2 for split 5
Pearson correlation 1: 0.33406129479408264
Pearson correlation 2: 0.3286129832267761
Pearson correl

Cell-line : PC3, rand_iter 3/5, Split 3: Epoch=751/1000, pearson_1=0.7501, MSE_1=0.8612, pearson_2=0.7474, MSE_2=0.8961, MI Loss=-0.6496, Prior Loss=0.0000, Cosine =0.8113, silimalityLoss =1.9278, loss=826.6763
Cell-line : PC3, rand_iter 3/5, Split 3: Epoch=1000/1000, pearson_1=0.7833, MSE_1=0.8483, pearson_2=0.7774, MSE_2=0.8933, MI Loss=-0.6509, Prior Loss=0.0000, Cosine =0.8363, silimalityLoss =1.6135, loss=814.1008
Shuffled performance for cell PC3 for try 3 for split 3
Pearson correlation 1: 0.33691585063934326
Pearson correlation 2: 0.3254726231098175
Pearson correlation 1 to 2: -0.004306446760892868
Pearson correlation 2 to 1: 0.002012926619499922
Cell-line : PC3, rand_iter 3/5, Split 4: Epoch=1/1000, pearson_1=0.0537, MSE_1=2.9351, pearson_2=0.0734, MSE_2=2.9093, MI Loss=1.2421, Prior Loss=1.2517, Cosine =0.5071, silimalityLoss =7.4219, loss=3080.2786
Cell-line : PC3, rand_iter 3/5, Split 4: Epoch=251/1000, pearson_1=0.7200, MSE_1=0.9910, pearson_2=0.7171, MSE_2=1.0069, MI Loss

Cell-line : PC3, rand_iter 4/5, Split 2: Epoch=1/1000, pearson_1=0.0667, MSE_1=3.0130, pearson_2=0.0640, MSE_2=3.0419, MI Loss=1.4668, Prior Loss=1.2789, Cosine =0.5143, silimalityLoss =7.3372, loss=3204.7146
Cell-line : PC3, rand_iter 4/5, Split 2: Epoch=251/1000, pearson_1=0.7293, MSE_1=0.9899, pearson_2=0.7331, MSE_2=0.9943, MI Loss=-0.6224, Prior Loss=0.0001, Cosine =0.8190, silimalityLoss =2.9119, loss=954.1561
Cell-line : PC3, rand_iter 4/5, Split 2: Epoch=501/1000, pearson_1=0.7751, MSE_1=0.9326, pearson_2=0.7746, MSE_2=0.9406, MI Loss=-0.6416, Prior Loss=0.0000, Cosine =0.8274, silimalityLoss =2.0723, loss=887.1287
Cell-line : PC3, rand_iter 4/5, Split 2: Epoch=751/1000, pearson_1=0.7690, MSE_1=0.8891, pearson_2=0.7675, MSE_2=0.9202, MI Loss=-0.6537, Prior Loss=0.0000, Cosine =0.8301, silimalityLoss =1.8244, loss=850.3524
Cell-line : PC3, rand_iter 4/5, Split 2: Epoch=1000/1000, pearson_1=0.7663, MSE_1=0.8617, pearson_2=0.7698, MSE_2=0.8510, MI Loss=-0.6524, Prior Loss=0.0000, 

Cell-line : PC3, rand_iter 4/5, Split 5: Epoch=501/1000, pearson_1=0.7619, MSE_1=0.9354, pearson_2=0.7624, MSE_2=0.9477, MI Loss=-0.6430, Prior Loss=0.0000, Cosine =0.8171, silimalityLoss =2.1847, loss=893.1481
Cell-line : PC3, rand_iter 4/5, Split 5: Epoch=751/1000, pearson_1=0.7580, MSE_1=0.8859, pearson_2=0.7597, MSE_2=0.8832, MI Loss=-0.6524, Prior Loss=0.0000, Cosine =0.8316, silimalityLoss =1.7919, loss=830.6113
Cell-line : PC3, rand_iter 4/5, Split 5: Epoch=1000/1000, pearson_1=0.7805, MSE_1=0.8369, pearson_2=0.7785, MSE_2=0.8516, MI Loss=-0.6505, Prior Loss=0.0000, Cosine =0.8380, silimalityLoss =1.5696, loss=787.7471
Validation performance for cell PC3 for try 4 for split 5
Pearson correlation 1: 0.7721824645996094
Pearson correlation 2: 0.7745112180709839
Pearson correlation 1 to 2: 0.7069967985153198
Pearson correlation 2 to 1: 0.7052341103553772
Begin training random shuffled model
Cell-line : PC3, rand_iter 4/5, Split 5: Epoch=1/1000, pearson_1=0.0607, MSE_1=3.2222, pearso

Cell-line : PC3, rand_iter 5/5, Split 3: Epoch=1000/1000, pearson_1=0.7702, MSE_1=0.8595, pearson_2=0.7610, MSE_2=0.8515, MI Loss=-0.6513, Prior Loss=0.0000, Cosine =0.8358, silimalityLoss =1.6126, loss=799.0441
Validation performance for cell PC3 for try 5 for split 3
Pearson correlation 1: 0.7690044045448303
Pearson correlation 2: 0.7640540599822998
Pearson correlation 1 to 2: 0.702646791934967
Pearson correlation 2 to 1: 0.7098784446716309
Begin training random shuffled model
Cell-line : PC3, rand_iter 5/5, Split 3: Epoch=1/1000, pearson_1=0.0511, MSE_1=2.9986, pearson_2=0.0655, MSE_2=2.9515, MI Loss=1.1522, Prior Loss=1.2548, Cosine =0.5294, silimalityLoss =7.3105, loss=3121.5535
Cell-line : PC3, rand_iter 5/5, Split 3: Epoch=251/1000, pearson_1=0.7170, MSE_1=0.9337, pearson_2=0.7187, MSE_2=0.9403, MI Loss=-0.6280, Prior Loss=0.0000, Cosine =0.7992, silimalityLoss =3.1996, loss=902.8594
Cell-line : PC3, rand_iter 5/5, Split 3: Epoch=501/1000, pearson_1=0.7739, MSE_1=0.9643, pearson

Cell-line : HT29, rand_iter 1/5, Split 1: Epoch=1/1000, pearson_1=0.1060, MSE_1=3.1794, pearson_2=0.1126, MSE_2=3.3156, MI Loss=1.1357, Prior Loss=1.1827, Cosine =0.5321, silimalityLoss =7.2172, loss=3385.3511
Cell-line : HT29, rand_iter 1/5, Split 1: Epoch=251/1000, pearson_1=0.7728, MSE_1=1.0500, pearson_2=0.7813, MSE_2=1.0627, MI Loss=-0.6423, Prior Loss=0.0000, Cosine =0.8372, silimalityLoss =2.8579, loss=1015.2806
Cell-line : HT29, rand_iter 1/5, Split 1: Epoch=501/1000, pearson_1=0.7714, MSE_1=0.8862, pearson_2=0.7757, MSE_2=0.9111, MI Loss=-0.6501, Prior Loss=0.0000, Cosine =0.8439, silimalityLoss =2.2394, loss=852.2362
Cell-line : HT29, rand_iter 1/5, Split 1: Epoch=751/1000, pearson_1=0.7808, MSE_1=0.8902, pearson_2=0.7862, MSE_2=0.8902, MI Loss=-0.6645, Prior Loss=0.0000, Cosine =0.8464, silimalityLoss =1.9207, loss=837.6716
Cell-line : HT29, rand_iter 1/5, Split 1: Epoch=1000/1000, pearson_1=0.7860, MSE_1=0.8454, pearson_2=0.7921, MSE_2=0.8425, MI Loss=-0.6557, Prior Loss=0.

Cell-line : HT29, rand_iter 1/5, Split 4: Epoch=501/1000, pearson_1=0.7745, MSE_1=0.9111, pearson_2=0.7767, MSE_2=0.9407, MI Loss=-0.6535, Prior Loss=0.0000, Cosine =0.8422, silimalityLoss =2.2377, loss=878.8021
Cell-line : HT29, rand_iter 1/5, Split 4: Epoch=751/1000, pearson_1=0.8154, MSE_1=0.9004, pearson_2=0.8134, MSE_2=0.9182, MI Loss=-0.6620, Prior Loss=0.0000, Cosine =0.8553, silimalityLoss =1.8280, loss=855.8920
Cell-line : HT29, rand_iter 1/5, Split 4: Epoch=1000/1000, pearson_1=0.7858, MSE_1=0.8348, pearson_2=0.7805, MSE_2=0.8824, MI Loss=-0.6544, Prior Loss=0.0000, Cosine =0.8471, silimalityLoss =1.7702, loss=805.1987
Shuffled performance for cell HT29 for try 1 for split 4
Pearson correlation 1: 0.30730313062667847
Pearson correlation 2: 0.31091776490211487
Pearson correlation 1 to 2: 0.0028215497732162476
Pearson correlation 2 to 1: -0.0008073781500570476
Cell-line : HT29, rand_iter 1/5, Split 5: Epoch=1/1000, pearson_1=0.1021, MSE_1=3.3771, pearson_2=0.1231, MSE_2=3.2746,

Cell-line : HT29, rand_iter 2/5, Split 2: Epoch=1000/1000, pearson_1=0.8110, MSE_1=0.8332, pearson_2=0.8135, MSE_2=0.7919, MI Loss=-0.6591, Prior Loss=0.0000, Cosine =0.8606, silimalityLoss =1.6840, loss=758.2990
Shuffled performance for cell HT29 for try 2 for split 2
Pearson correlation 1: 0.3100980222225189
Pearson correlation 2: 0.31404125690460205
Pearson correlation 1 to 2: 0.004155758768320084
Pearson correlation 2 to 1: -0.00019505515228956938
Cell-line : HT29, rand_iter 2/5, Split 3: Epoch=1/1000, pearson_1=0.0925, MSE_1=3.1671, pearson_2=0.1137, MSE_2=3.2710, MI Loss=1.1349, Prior Loss=1.2195, Cosine =0.5333, silimalityLoss =7.1922, loss=3357.2734
Cell-line : HT29, rand_iter 2/5, Split 3: Epoch=251/1000, pearson_1=0.7414, MSE_1=1.0187, pearson_2=0.7485, MSE_2=1.0232, MI Loss=-0.6297, Prior Loss=0.0001, Cosine =0.8355, silimalityLoss =2.9121, loss=982.2238
Cell-line : HT29, rand_iter 2/5, Split 3: Epoch=501/1000, pearson_1=0.7809, MSE_1=0.8580, pearson_2=0.7865, MSE_2=0.8757, 

Cell-line : HT29, rand_iter 3/5, Split 1: Epoch=251/1000, pearson_1=0.7547, MSE_1=0.9666, pearson_2=0.7657, MSE_2=0.9390, MI Loss=-0.6336, Prior Loss=0.0001, Cosine =0.8392, silimalityLoss =2.8142, loss=914.5237
Cell-line : HT29, rand_iter 3/5, Split 1: Epoch=501/1000, pearson_1=0.7893, MSE_1=0.9331, pearson_2=0.7889, MSE_2=0.9086, MI Loss=-0.6394, Prior Loss=0.0000, Cosine =0.8409, silimalityLoss =2.1573, loss=874.1804
Cell-line : HT29, rand_iter 3/5, Split 1: Epoch=751/1000, pearson_1=0.7837, MSE_1=0.8616, pearson_2=0.7920, MSE_2=0.8310, MI Loss=-0.6551, Prior Loss=0.0000, Cosine =0.8493, silimalityLoss =1.8733, loss=795.1815
Cell-line : HT29, rand_iter 3/5, Split 1: Epoch=1000/1000, pearson_1=0.7965, MSE_1=0.8724, pearson_2=0.7984, MSE_2=0.8527, MI Loss=-0.6559, Prior Loss=0.0000, Cosine =0.8518, silimalityLoss =1.7476, loss=808.3113
Validation performance for cell HT29 for try 3 for split 1
Pearson correlation 1: 0.7733427286148071
Pearson correlation 2: 0.7729620933532715
Pearson 

Cell-line : HT29, rand_iter 3/5, Split 4: Epoch=751/1000, pearson_1=0.7727, MSE_1=0.9137, pearson_2=0.7733, MSE_2=0.9262, MI Loss=-0.6592, Prior Loss=0.0000, Cosine =0.8384, silimalityLoss =1.9826, loss=868.2889
Cell-line : HT29, rand_iter 3/5, Split 4: Epoch=1000/1000, pearson_1=0.8078, MSE_1=0.8498, pearson_2=0.8142, MSE_2=0.8020, MI Loss=-0.6552, Prior Loss=0.0000, Cosine =0.8468, silimalityLoss =1.7771, loss=773.1456
Validation performance for cell HT29 for try 3 for split 4
Pearson correlation 1: 0.7749460935592651
Pearson correlation 2: 0.7727134823799133
Pearson correlation 1 to 2: 0.7212855219841003
Pearson correlation 2 to 1: 0.7271867394447327
Begin training random shuffled model
Cell-line : HT29, rand_iter 3/5, Split 4: Epoch=1/1000, pearson_1=0.1192, MSE_1=3.7094, pearson_2=0.1100, MSE_2=3.7830, MI Loss=1.0154, Prior Loss=1.2217, Cosine =0.5325, silimalityLoss =7.1319, loss=3860.2598
Cell-line : HT29, rand_iter 3/5, Split 4: Epoch=251/1000, pearson_1=0.7697, MSE_1=1.0059, p

Cell-line : HT29, rand_iter 4/5, Split 2: Epoch=251/1000, pearson_1=0.7378, MSE_1=0.9718, pearson_2=0.7301, MSE_2=0.9938, MI Loss=-0.6397, Prior Loss=0.0001, Cosine =0.8285, silimalityLoss =2.9581, loss=944.7173
Cell-line : HT29, rand_iter 4/5, Split 2: Epoch=501/1000, pearson_1=0.7726, MSE_1=0.8866, pearson_2=0.7754, MSE_2=0.8860, MI Loss=-0.6619, Prior Loss=0.0000, Cosine =0.8400, silimalityLoss =2.1854, loss=838.4497
Cell-line : HT29, rand_iter 4/5, Split 2: Epoch=751/1000, pearson_1=0.7763, MSE_1=0.8347, pearson_2=0.7772, MSE_2=0.8382, MI Loss=-0.6559, Prior Loss=0.0000, Cosine =0.8453, silimalityLoss =1.8829, loss=785.5630
Cell-line : HT29, rand_iter 4/5, Split 2: Epoch=1000/1000, pearson_1=0.7988, MSE_1=0.8010, pearson_2=0.8016, MSE_2=0.7978, MI Loss=-0.6565, Prior Loss=0.0000, Cosine =0.8557, silimalityLoss =1.6672, loss=745.6425
Shuffled performance for cell HT29 for try 4 for split 2
Pearson correlation 1: 0.30871617794036865
Pearson correlation 2: 0.30458444356918335
Pearson 

Cell-line : HT29, rand_iter 4/5, Split 5: Epoch=751/1000, pearson_1=0.7954, MSE_1=0.8364, pearson_2=0.7901, MSE_2=0.8524, MI Loss=-0.6528, Prior Loss=0.0000, Cosine =0.8390, silimalityLoss =1.8252, loss=792.6458
Cell-line : HT29, rand_iter 4/5, Split 5: Epoch=1000/1000, pearson_1=0.7802, MSE_1=0.8192, pearson_2=0.7797, MSE_2=0.8257, MI Loss=-0.6570, Prior Loss=0.0000, Cosine =0.8550, silimalityLoss =1.6603, loss=767.5474
Shuffled performance for cell HT29 for try 4 for split 5
Pearson correlation 1: 0.31477090716362
Pearson correlation 2: 0.31160062551498413
Pearson correlation 1 to 2: -0.013905717059969902
Pearson correlation 2 to 1: -0.013395796529948711
Cell-line : HT29, rand_iter 5/5, Split 1: Epoch=1/1000, pearson_1=0.1113, MSE_1=3.0641, pearson_2=0.1066, MSE_2=3.2423, MI Loss=1.1146, Prior Loss=1.2043, Cosine =0.5287, silimalityLoss =7.2256, loss=3291.1956
Cell-line : HT29, rand_iter 5/5, Split 1: Epoch=251/1000, pearson_1=0.7715, MSE_1=0.9561, pearson_2=0.7759, MSE_2=0.9514, MI 

Cell-line : HT29, rand_iter 5/5, Split 4: Epoch=251/1000, pearson_1=0.7361, MSE_1=0.9980, pearson_2=0.7473, MSE_2=0.9948, MI Loss=-0.6396, Prior Loss=0.0001, Cosine =0.8326, silimalityLoss =2.9170, loss=957.8120
Cell-line : HT29, rand_iter 5/5, Split 4: Epoch=501/1000, pearson_1=0.7983, MSE_1=0.9120, pearson_2=0.7984, MSE_2=0.9323, MI Loss=-0.6505, Prior Loss=0.0000, Cosine =0.8451, silimalityLoss =2.1664, loss=874.7108
Cell-line : HT29, rand_iter 5/5, Split 4: Epoch=751/1000, pearson_1=0.7800, MSE_1=0.9262, pearson_2=0.7814, MSE_2=0.9148, MI Loss=-0.6525, Prior Loss=0.0000, Cosine =0.8411, silimalityLoss =1.9665, loss=869.3046
Cell-line : HT29, rand_iter 5/5, Split 4: Epoch=1000/1000, pearson_1=0.8024, MSE_1=0.8143, pearson_2=0.8046, MSE_2=0.8392, MI Loss=-0.6618, Prior Loss=0.0000, Cosine =0.8581, silimalityLoss =1.6836, loss=772.4140
Validation performance for cell HT29 for try 5 for split 4
Pearson correlation 1: 0.7694718241691589
Pearson correlation 2: 0.7800646424293518
Pearson 

Cell-line : MCF7, rand_iter 1/5, Split 2: Epoch=751/1000, pearson_1=0.7374, MSE_1=0.8209, pearson_2=0.7319, MSE_2=0.8330, MI Loss=-0.6489, Prior Loss=0.0000, Cosine =0.8284, silimalityLoss =2.0853, loss=779.4378
Cell-line : MCF7, rand_iter 1/5, Split 2: Epoch=1000/1000, pearson_1=0.7714, MSE_1=0.8598, pearson_2=0.7734, MSE_2=0.8852, MI Loss=-0.6612, Prior Loss=0.0000, Cosine =0.8360, silimalityLoss =1.7946, loss=818.5253
Validation performance for cell MCF7 for try 1 for split 2
Pearson correlation 1: 0.7500357031822205
Pearson correlation 2: 0.7560182213783264
Pearson correlation 1 to 2: 0.6864781379699707
Pearson correlation 2 to 1: 0.6831030249595642
Begin training random shuffled model
Cell-line : MCF7, rand_iter 1/5, Split 2: Epoch=1/1000, pearson_1=0.0599, MSE_1=2.8215, pearson_2=0.0755, MSE_2=2.6153, MI Loss=1.3272, Prior Loss=1.2450, Cosine =0.5108, silimalityLoss =7.4118, loss=2889.2610
Cell-line : MCF7, rand_iter 1/5, Split 2: Epoch=251/1000, pearson_1=0.7348, MSE_1=1.1038, p

Cell-line : MCF7, rand_iter 1/5, Split 5: Epoch=1/1000, pearson_1=0.0720, MSE_1=3.0570, pearson_2=0.0614, MSE_2=2.8836, MI Loss=1.1846, Prior Loss=1.2422, Cosine =0.5202, silimalityLoss =7.2178, loss=3119.2761
Cell-line : MCF7, rand_iter 1/5, Split 5: Epoch=251/1000, pearson_1=0.7339, MSE_1=0.9491, pearson_2=0.7437, MSE_2=0.8946, MI Loss=-0.6272, Prior Loss=0.0001, Cosine =0.8260, silimalityLoss =2.7543, loss=883.4979
Cell-line : MCF7, rand_iter 1/5, Split 5: Epoch=501/1000, pearson_1=0.7617, MSE_1=0.9372, pearson_2=0.7598, MSE_2=0.9121, MI Loss=-0.6426, Prior Loss=0.0000, Cosine =0.8268, silimalityLoss =2.1271, loss=876.2111
Cell-line : MCF7, rand_iter 1/5, Split 5: Epoch=751/1000, pearson_1=0.7595, MSE_1=0.8308, pearson_2=0.7618, MSE_2=0.8070, MI Loss=-0.6406, Prior Loss=0.0000, Cosine =0.8189, silimalityLoss =1.9346, loss=769.3917
Cell-line : MCF7, rand_iter 1/5, Split 5: Epoch=1000/1000, pearson_1=0.7362, MSE_1=0.8397, pearson_2=0.7463, MSE_2=0.8025, MI Loss=-0.6487, Prior Loss=0.0

KeyboardInterrupt: 