In [3]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from xml.dom import minidom

import gc
import scanpy as sc
import anndata
import umap
from sklearn.preprocessing import StandardScaler
from sklearn.feature_extraction.text import TfidfTransformer

import torch
import torch.nn.functional as F
from contrastive_loss_pytorch import ContrastiveLoss
import activations

from sklearn.metrics import mean_absolute_error
from scipy import stats

## Load training and validation data

In [2]:
# Load the pre-split training (_tr) and validation (_val) matrices.
# Rows are samples (the pairing code below draws whole rows); columns are features.

# Cancer cohort
cancer_mrna_tr = np.load('./lc_data/lc_cancer_mrna_tr.npy')
cancer_mrna_val = np.load('./lc_data/lc_cancer_mrna_val.npy')
cancer_exrna_tr = np.load('./lc_data/lc_cancer_exrna_tr.npy')
cancer_exrna_val = np.load('./lc_data/lc_cancer_exrna_val.npy')

# Healthy-control (hc) cohort
hc_mrna_tr = np.load('./lc_data/lc_hc_mrna_tr.npy')
hc_mrna_val = np.load('./lc_data/lc_hc_mrna_val.npy')
hc_exrna_tr = np.load('./lc_data/lc_hc_exrna_tr.npy')
hc_exrna_val = np.load('./lc_data/lc_hc_exrna_val.npy')

In [13]:
# Per-feature standard deviation over the pooled (cancer + healthy) training rows.
mrna_std = [
    np.std(np.concatenate([cancer_mrna_tr[:, col], hc_mrna_tr[:, col]]))
    for col in range(cancer_mrna_tr.shape[1])
]
exrna_std = [
    np.std(np.concatenate([cancer_exrna_tr[:, col], hc_exrna_tr[:, col]]))
    for col in range(cancer_exrna_tr.shape[1])
]

print(len(mrna_std))
print(len(exrna_std))

1927
2537


## Random Pairing

In [5]:
import random


# Label-matched random pairing: for each of N draws pick a label, then draw one
# exRNA row and one mRNA row independently from that label's training pool.
N = 10000
x = np.random.randint(0, 2, size=N)

# Position 0 -> cancer pools, position 1 -> healthy-control pools (the 0/1 labels in x).
_pools_by_label = [(cancer_exrna_tr, cancer_mrna_tr), (hc_exrna_tr, hc_mrna_tr)]
exrna_rows = []
mrna_rows = []

for i in range(N):
    ex_pool, m_pool = _pools_by_label[x[i]]
    # The exRNA index is drawn before the mRNA index; keeping this order keeps
    # the global NumPy RNG stream (and hence the pairs) unchanged.
    exrna_rows.append(ex_pool[int(np.random.choice(len(ex_pool)))])
    mrna_rows.append(m_pool[int(np.random.choice(len(m_pool)))])

dat_exrna = np.asarray(exrna_rows)
dat_mrna = np.asarray(mrna_rows)

In [6]:
def genDataRand(N=10000):
    """Draw ``N`` label-matched random (mRNA, exRNA) training pairs.

    For every draw a label is chosen uniformly (1 = healthy control,
    0 = cancer). One exRNA row and one mRNA row are then sampled
    independently, with replacement, from that label's training arrays.

    Parameters
    ----------
    N : int, default 10000
        Number of pairs to draw.

    Returns
    -------
    list
        ``[dat_mrna, dat_exrna]``, two arrays with ``N`` rows each.
    """
    labels = np.random.randint(0, 2, size=N)
    # Position 0 -> cancer pools, position 1 -> healthy-control pools.
    pools = ((cancer_exrna_tr, cancer_mrna_tr), (hc_exrna_tr, hc_mrna_tr))

    exrna_rows = []
    mrna_rows = []
    for label in labels:
        ex_pool, m_pool = pools[label]
        # Draw the exRNA index first, then the mRNA index, so a given global
        # NumPy seed reproduces the same pairs.
        exrna_rows.append(ex_pool[int(np.random.choice(len(ex_pool)))])
        mrna_rows.append(m_pool[int(np.random.choice(len(m_pool)))])

    return [np.asarray(mrna_rows), np.asarray(exrna_rows)]

In [7]:
# (n_pairs, n_exrna_features)
np.shape(dat_exrna)

(10000, 2537)

In [17]:
# (n_pairs, n_mrna_features)
np.shape(dat_mrna)

(10000, 1927)

## Encoder-Decoder (exRNA to mRNA)

In [13]:
# exRNA -> mRNA encoder-decoder
class AE(torch.nn.Module):
    """Fully connected encoder-decoder mapping an exRNA profile to an mRNA profile.

    The encoder compresses ``input_dim`` features to an 18-d code
    (input_dim -> 512 -> 256 -> 128 -> 64 -> 36 -> 18, ReLU between layers).
    The decoder mirrors it back out to ``output_dim`` features.

    Parameters
    ----------
    input_dim : int, optional
        Number of input (exRNA) features. When omitted, ``dat_exrna.shape[1]``
        is read at construction time.
    output_dim : int, optional
        Number of output (mRNA) features. When omitted, ``dat_mrna.shape[1]``
        is read at construction time.
    sigmoid_output : bool, default True
        Append a final ``Sigmoid`` so predictions lie in (0, 1).
        NOTE(review): that only fits targets scaled to (0, 1). Confirm the
        mRNA arrays are scaled that way; otherwise pass ``sigmoid_output=False``.
        The Sigmoid has no parameters, so state_dict keys are the same either way.
    """

    def __init__(self, input_dim=None, output_dim=None, sigmoid_output=True):
        super().__init__()
        # Resolve defaults here rather than in the signature. Signature defaults
        # are evaluated once, at class-definition time, and would require the
        # training arrays to exist just to define the class.
        if input_dim is None:
            input_dim = dat_exrna.shape[1]
        if output_dim is None:
            output_dim = dat_mrna.shape[1]

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
        )

        decoder_layers = [
            torch.nn.Linear(18, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, output_dim),
        ]
        if sigmoid_output:
            decoder_layers.append(torch.nn.Sigmoid())
        self.decoder = torch.nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the 18-d code and decode it to ``output_dim`` features."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


In [14]:
# Model initialization: exRNA -> mRNA encoder-decoder
model = AE()

# MSE objective; AE_trainer reads this global as its training loss
loss_function = torch.nn.MSELoss()

# NOTE(review): AE_trainer builds its own Adam(lr=5e-5) and does not use this
# optimizer (lr = 0.1).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1, weight_decay=1e-8)


In [15]:
def AE_trainer(X_a, X_b, model, batch_size=64, num_epoch=5,
               f_temp=0.1, p_temp=1.0):
    """Train an exRNA -> mRNA model with MSE and keep the best validation checkpoint.

    Each epoch draws a fresh set of label-matched random pairs with
    ``genDataRand()``. ``X_a`` and ``X_b`` are kept for call compatibility but
    are overwritten by that resampling. ``f_temp`` and ``p_temp`` are unused.

    Validation score, computed after every epoch: the Pearson r between two
    per-feature means. The first is the mean of the model's mRNA predictions
    for the cancer + healthy validation exRNA. The second is the mean of the
    true validation mRNA. Whenever the score improves, the weights are saved
    to ``./models/lc_compare_extom.pt``. Training stops early once the score
    reaches 0.9.

    Reads notebook globals ``genDataRand``, ``loss_function`` and the
    ``cancer_*_val`` / ``hc_*_val`` arrays.

    Parameters
    ----------
    X_a, X_b : array-like
        Ignored (see above).
    model : torch.nn.Module
        Maps exRNA rows to mRNA rows. Moved to CUDA when available.
    batch_size : int, default 64
        Mini-batch size. A trailing partial batch is dropped.
    num_epoch : int, default 5
        Maximum number of epochs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98),
                           eps=1e-6, weight_decay=0.2)

    # Validation inputs must live on the model's device. The targets matter
    # only through their per-feature mean, so reduce them once up front.
    ex_val = torch.tensor(np.concatenate([cancer_exrna_val, hc_exrna_val])).float().to(device)
    true_m_mean = np.mean(np.concatenate([cancer_mrna_val, hc_mrna_val]), axis=0)

    s_m_max = 0

    for k in range(num_epoch):
        # The previous epoch's validation left the model in eval mode.
        model.train()

        X_a, X_b = genDataRand()  # X_a: mRNA targets, X_b: exRNA inputs
        X_tensor_A = torch.tensor(X_a).float()
        X_tensor_B = torch.tensor(X_b).float()
        n = X_a.shape[0]

        losses = 0
        for j in range(n // batch_size):
            rows = slice(j * batch_size, (j + 1) * batch_size)
            inputs_a = X_tensor_A[rows].to(device)
            inputs_b = X_tensor_B[rows].to(device)

            opt.zero_grad()
            output_a = model(inputs_b)
            loss = loss_function(inputs_a, output_a)
            loss.backward()
            opt.step()

            losses += loss.item()

        del X_tensor_A, X_tensor_B

        # Validate without building an autograd graph. Move predictions back to
        # host memory before .numpy(), which is required when on CUDA.
        model.eval()
        with torch.no_grad():
            pred_m = model(ex_val).cpu().numpy()
        score_m = stats.pearsonr(np.mean(pred_m, axis=0), true_m_mean)[0]

        print("Epoch: {}, Total loss: {}".format(k, str(round(losses, 4))))

        if score_m > s_m_max:
            torch.save({'model_state_dict': model.state_dict()}, './models/lc_compare_extom.pt')
            print('Saved!')
            print('\t', score_m)
            s_m_max = score_m

        if score_m >= 0.9:
            # A score >= 0.9 must beat the running best (training would have
            # stopped earlier otherwise), so the checkpoint was just written above.
            print('Final Saved!')
            print('\t', score_m)
            break

        gc.collect()

In [None]:
# Train exRNA -> mRNA. AE_trainer writes its best-validation weights to
# ./models/lc_compare_extom.pt.
model.train()
AE_trainer(dat_mrna, dat_exrna, model, num_epoch=50)

In [None]:
# NOTE(review): AE_trainer already saves its best-validation weights to this
# same path. Saving here replaces them with the model's current weights.
torch.save(
    {'model_state_dict': model.state_dict()},
    './models/lc_compare_extom.pt',
)

## Encoder-Decoder (mRNA to exRNA)

In [18]:
# mRNA -> exRNA encoder-decoder
class AE_mtoex(torch.nn.Module):
    """Fully connected encoder-decoder mapping an mRNA profile to an exRNA profile.

    The encoder compresses ``input_dim`` features to an 18-d code
    (input_dim -> 512 -> 256 -> 128 -> 64 -> 36 -> 18, ReLU between layers).
    The decoder mirrors it back out to ``output_dim`` features.

    Parameters
    ----------
    input_dim : int, optional
        Number of input (mRNA) features. When omitted, ``dat_mrna.shape[1]``
        is read at construction time.
    output_dim : int, optional
        Number of output (exRNA) features. When omitted, ``dat_exrna.shape[1]``
        is read at construction time.
    sigmoid_output : bool, default True
        Append a final ``Sigmoid`` so predictions lie in (0, 1).
        NOTE(review): the logged MSE totals for this model (~1.8e7 per epoch)
        suggest the exRNA targets are not scaled to (0, 1), which a sigmoid
        head cannot fit. Confirm the scaling, or pass ``sigmoid_output=False``.
        The Sigmoid has no parameters, so state_dict keys are the same either way.
    """

    def __init__(self, input_dim=None, output_dim=None, sigmoid_output=True):
        super().__init__()
        # Resolve defaults at construction time instead of class-definition time.
        if input_dim is None:
            input_dim = dat_mrna.shape[1]
        if output_dim is None:
            output_dim = dat_exrna.shape[1]

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 18),
        )

        decoder_layers = [
            torch.nn.Linear(18, 36),
            torch.nn.ReLU(),
            torch.nn.Linear(36, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, output_dim),
        ]
        if sigmoid_output:
            decoder_layers.append(torch.nn.Sigmoid())
        self.decoder = torch.nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode ``x`` to the 18-d code and decode it to ``output_dim`` features."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


In [19]:
# Model initialization: mRNA -> exRNA encoder-decoder
model_mtoex = AE_mtoex()

# MSE objective; AE_mtoex_trainer reads this global as its training loss
loss_function = torch.nn.MSELoss()

# NOTE(review): AE_mtoex_trainer builds its own Adam(lr=5e-5) and does not use
# this optimizer (lr = 0.1).
optimizer = torch.optim.Adam(model_mtoex.parameters(), lr=1e-1, weight_decay=1e-8)


In [20]:
def AE_mtoex_trainer(X_a, X_b, model, batch_size=64, num_epoch=5,
                     f_temp=0.1, p_temp=1.0):
    """Train an mRNA -> exRNA model with MSE and keep the best validation checkpoint.

    Each epoch draws a fresh set of label-matched random pairs with
    ``genDataRand()``. ``X_a`` and ``X_b`` are kept for call compatibility but
    are overwritten by that resampling. ``f_temp`` and ``p_temp`` are unused.

    Validation score, computed after every epoch: the Pearson r between two
    per-feature means. The first is the mean of the model's exRNA predictions
    for the cancer + healthy validation mRNA. The second is the mean of the
    true validation exRNA. Whenever the score improves, the weights are saved
    to ``./models/lc_compare_mtoex.pt``, and the Spearman and Pearson r are
    printed. Training stops early once the score reaches 0.9.

    Reads notebook globals ``genDataRand``, ``loss_function`` and the
    ``cancer_*_val`` / ``hc_*_val`` arrays.

    Parameters
    ----------
    X_a, X_b : array-like
        Ignored (see above).
    model : torch.nn.Module
        Maps mRNA rows to exRNA rows. This is the model that is moved to the
        device and trained.
    batch_size : int, default 64
        Mini-batch size. A trailing partial batch is dropped.
    num_epoch : int, default 5
        Maximum number of epochs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model that was passed in, not a notebook global.
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98),
                           eps=1e-6, weight_decay=0.2)

    # Validation inputs must live on the model's device. The targets matter
    # only through their per-feature mean, so reduce them once up front.
    m_val = torch.tensor(np.concatenate([cancer_mrna_val, hc_mrna_val])).float().to(device)
    true_ex_mean = np.mean(np.concatenate([cancer_exrna_val, hc_exrna_val]), axis=0)

    s_ex_max = 0

    for k in range(num_epoch):
        # The previous epoch's validation left the model in eval mode.
        model.train()

        X_a, X_b = genDataRand()  # X_a: mRNA inputs, X_b: exRNA targets
        X_tensor_A = torch.tensor(X_a).float()
        X_tensor_B = torch.tensor(X_b).float()
        n = X_a.shape[0]

        losses = 0
        for j in range(n // batch_size):
            rows = slice(j * batch_size, (j + 1) * batch_size)
            inputs_a = X_tensor_A[rows].to(device)
            inputs_b = X_tensor_B[rows].to(device)

            opt.zero_grad()
            output_b = model(inputs_a)
            loss = loss_function(inputs_b, output_b)
            loss.backward()
            opt.step()

            losses += loss.item()

        del X_tensor_A, X_tensor_B

        # Validate without building an autograd graph. Move predictions back to
        # host memory before .numpy(), which is required when on CUDA.
        model.eval()
        with torch.no_grad():
            pred_ex = model(m_val).cpu().numpy()
        pred_ex_mean = np.mean(pred_ex, axis=0)
        score_ex = stats.pearsonr(pred_ex_mean, true_ex_mean)[0]

        print("Epoch: {}, Total loss: {}".format(k, str(round(losses, 4))))

        if score_ex > s_ex_max:
            torch.save({'model_state_dict': model.state_dict()}, './models/lc_compare_mtoex.pt')
            print('Saved!')
            print('\t', stats.spearmanr(pred_ex_mean, true_ex_mean)[0])
            print('\t', score_ex)
            s_ex_max = score_ex

        if score_ex >= 0.9:
            # A score >= 0.9 must beat the running best (training would have
            # stopped earlier otherwise), so the checkpoint was just written above.
            print('Final Saved!')
            print('\t', score_ex)
            break

        gc.collect()

In [229]:
# Train mRNA -> exRNA. AE_mtoex_trainer writes its best-validation weights to
# ./models/lc_compare_mtoex.pt.
model_mtoex.train()
AE_mtoex_trainer(dat_mrna, dat_exrna, model_mtoex, num_epoch=50)

Epoch: 0, Total loss: 18727273.793
Saved!
	 0.12448049264874982
	 0.4668589308803987
Epoch: 1, Total loss: 18297565.6562
Saved!
	 0.5243088003702284
	 0.19209185727348066
Epoch: 2, Total loss: 18052712.6602
Epoch: 3, Total loss: 17526781.8828
Epoch: 4, Total loss: 18418107.0898
Epoch: 5, Total loss: 17654145.6973
Epoch: 6, Total loss: 18160411.7617
Epoch: 7, Total loss: 17662953.5625
Epoch: 8, Total loss: 18200707.0449
Epoch: 9, Total loss: 19167956.2402
Epoch: 10, Total loss: 17482621.3301
Epoch: 11, Total loss: 18113990.752
Epoch: 12, Total loss: 18225185.584
Epoch: 13, Total loss: 18537195.8711
Epoch: 14, Total loss: 17001341.2422
Epoch: 15, Total loss: 17457974.9531
Epoch: 16, Total loss: 18061046.5908
Epoch: 17, Total loss: 17962874.8887
Epoch: 18, Total loss: 17377728.2617
Epoch: 19, Total loss: 17928324.2852
Epoch: 20, Total loss: 18329716.8906
Epoch: 21, Total loss: 18387381.2148
Epoch: 22, Total loss: 18863593.8867
Epoch: 23, Total loss: 18085594.4316
Epoch: 24, Total loss: 17

KeyboardInterrupt: 

In [21]:
# NOTE(review): AE_mtoex_trainer already saves its best-validation weights to
# this same path. Saving here replaces them with the model's current weights.
torch.save(
    {'model_state_dict': model_mtoex.state_dict()},
    './models/lc_compare_mtoex.pt',
)

AE_mtoex(
  (encoder): Sequential(
    (0): Linear(in_features=1927, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=256, bias=True)
    (3): ReLU()
    (4): Linear(in_features=256, out_features=128, bias=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=64, bias=True)
    (7): ReLU()
    (8): Linear(in_features=64, out_features=36, bias=True)
    (9): ReLU()
    (10): Linear(in_features=36, out_features=18, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=18, out_features=36, bias=True)
    (1): ReLU()
    (2): Linear(in_features=36, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=128, bias=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=256, bias=True)
    (7): ReLU()
    (8): Linear(in_features=256, out_features=512, bias=True)
    (9): ReLU()
    (10): Linear(in_features=512, out_features=2537, bias=True)
    (11): Sigmoid()
  )
)

## Contrastive encoders trained on exRNA-exRNA pairs

In [18]:
## random pairing
import random

# Label-matched pairs of exRNA training rows: both members of a pair share the
# label but are drawn independently.
# Sample from the *_tr training arrays loaded at the top of the notebook. The
# un-suffixed hc_exrna / cancer_exrna names are not defined earlier in the
# notebook, and the index range must come from the array actually being indexed.
dat_exrna = []
dat_exrna1 = []
N = 10000
x = np.random.randint(0, 2, size=N)

for i in range(N):
    if x[i] == 1:
        # select healthy
        dat_exrna.append(hc_exrna_tr[int(np.random.choice(len(hc_exrna_tr)))])
        dat_exrna1.append(hc_exrna_tr[int(np.random.choice(len(hc_exrna_tr)))])
    else:
        # select cancer
        dat_exrna.append(cancer_exrna_tr[int(np.random.choice(len(cancer_exrna_tr)))])
        dat_exrna1.append(cancer_exrna_tr[int(np.random.choice(len(cancer_exrna_tr)))])

dat_exrna = np.asarray(dat_exrna)
dat_exrna1 = np.asarray(dat_exrna1)

In [19]:
## Random pairing function
def genexRand(N=10000, hc_pool=None, cancer_pool=None):
    """Draw ``N`` label-matched random pairs of exRNA rows.

    For every draw a label is chosen uniformly (1 = healthy control,
    0 = cancer). Two rows are then sampled independently, with replacement,
    from that label's pool.

    Parameters
    ----------
    N : int, default 10000
        Number of pairs to draw.
    hc_pool, cancer_pool : array-like, optional
        Row pools for the healthy and cancer labels. They default to the
        ``hc_exrna_tr`` / ``cancer_exrna_tr`` training arrays. The original
        code referenced undefined ``hc_exrna`` / ``cancer_exrna`` names, and
        indexed them with the length of a different array.

    Returns
    -------
    list
        ``[dat_exrna, dat_exrna1]``, two arrays with ``N`` rows each. Row ``i``
        of both arrays comes from the same label's pool.
    """
    if hc_pool is None:
        hc_pool = hc_exrna_tr
    if cancer_pool is None:
        cancer_pool = cancer_exrna_tr

    x = np.random.randint(0, 2, size=N)
    dat_exrna = []
    dat_exrna1 = []

    for i in range(N):
        pool = hc_pool if x[i] == 1 else cancer_pool
        # Indices are drawn from the pool being indexed, so they are always in range.
        dat_exrna.append(pool[int(np.random.choice(len(pool)))])
        dat_exrna1.append(pool[int(np.random.choice(len(pool)))])

    return [np.asarray(dat_exrna), np.asarray(dat_exrna1)]

In [20]:
class CrossPred_exex(torch.nn.Module):
    """Two independent MLP encoders that embed a pair of inputs into ``emb_dim``-d vectors.

    Each tower is Linear(in, 1000) -> BatchNorm -> ReLU -> Linear(1000, 128)
    -> BatchNorm -> ReLU -> Linear(128, emb_dim). ``forward`` returns the two
    embeddings as a tuple.
    """

    def __init__(self, input_dim_a=2000, input_dim_b=2000, emb_dim=128):
        super().__init__()
        self.input_dim_a = input_dim_a
        self.input_dim_b = input_dim_b
        # Tower a is built before tower b, which fixes the order in which
        # parameters are initialised.
        self.encoder_a = self._build_encoder(input_dim_a, emb_dim)
        self.encoder_b = self._build_encoder(input_dim_b, emb_dim)

    @staticmethod
    def _build_encoder(in_features, emb_dim):
        """Return one encoder tower; its layer order defines state_dict keys 0..6."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, 1000),
            torch.nn.BatchNorm1d(1000),
            torch.nn.ReLU(),
            torch.nn.Linear(1000, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, emb_dim),
        )

    def forward(self, x_a, x_b):
        """Embed ``x_a`` with tower a and ``x_b`` with tower b."""
        return self.encoder_a(x_a), self.encoder_b(x_b)
    
def CrossPred_trainer_exex(X_a, X_b, model, batch_size=64, num_epoch=5,
                           f_temp=0.1, p_temp=1.0):
    """Contrastively train the two towers of ``model`` on label-matched exRNA pairs.

    Each epoch draws fresh pairs with ``genexRand()``. ``X_a`` and ``X_b`` are
    kept for call compatibility but are overwritten by that resampling.
    ``p_temp`` is unused. The loss is ``ContrastiveLoss`` with temperature
    ``f_temp`` between the two towers' embeddings. The weights of ``model``
    are saved to ``./models/lc_model_exex.pt`` every 20 epochs
    (k = 0, 20, 40, ...).

    Parameters
    ----------
    X_a, X_b : array-like
        Ignored (see above).
    model : torch.nn.Module
        Called as ``model(a, b)`` and returning two embeddings.
    batch_size : int, default 64
        ContrastiveLoss is constructed for this fixed batch size. Only full
        batches (``n // batch_size``) are used; the remainder is dropped.
    num_epoch : int, default 5
        Number of epochs.
    f_temp : float, default 0.1
        ContrastiveLoss temperature.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    f_con = ContrastiveLoss(batch_size=batch_size, temperature=f_temp).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98),
                           eps=1e-6, weight_decay=0.2)

    for k in range(num_epoch):
        X_a, X_b = genexRand()
        X_tensor_A = torch.tensor(X_a).float()
        X_tensor_B = torch.tensor(X_b).float()
        n = X_a.shape[0]

        losses = 0
        for j in range(n // batch_size):
            rows = slice(j * batch_size, (j + 1) * batch_size)
            inputs_a = X_tensor_A[rows].to(device)
            inputs_b = X_tensor_B[rows].to(device)

            opt.zero_grad()
            feas, nfeas = model(inputs_a, inputs_b)
            loss = f_con(feas, nfeas)
            loss.backward()
            opt.step()

            losses += loss.item()
        print("Epoch: {}, Total loss: {}".format(k, str(round(losses, 4))))

        del X_tensor_A, X_tensor_B

        if k % 20 == 0:
            # Save the model being trained, not the notebook-global net_exex.
            torch.save({'model_state_dict': model.state_dict()}, './models/lc_model_exex.pt')
            print('Saved!')

        gc.collect()

In [21]:
def weights_init(m):
    """Xavier-uniform init for ``Linear`` weights and zero biases; other modules untouched.

    Intended for ``module.apply(weights_init)``.
    """
    if isinstance(m, torch.nn.Linear):
        # ``xavier_uniform`` (no trailing underscore) is deprecated and warns;
        # the in-place ``xavier_uniform_`` is the supported equivalent.
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # Linear(..., bias=False) has no bias tensor
            torch.nn.init.zeros_(m.bias)

In [22]:
# One tower per exRNA input, sized to the exRNA feature count, then weights_init.
net_exex = CrossPred_exex(
    input_dim_a=dat_exrna.shape[1],
    input_dim_b=dat_exrna1.shape[1],
)
net_exex.apply(weights_init)

  torch.nn.init.xavier_uniform(m.weight.data)


CrossPred_exex(
  (encoder_a): Sequential(
    (0): Linear(in_features=2537, out_features=1000, bias=True)
    (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1000, out_features=128, bias=True)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
  )
  (encoder_b): Sequential(
    (0): Linear(in_features=2537, out_features=1000, bias=True)
    (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1000, out_features=128, bias=True)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
  )
)

In [23]:
## 3. Training process
net_exex.train()
CrossPred_trainer_exex(dat_exrna, dat_exrna1, net_exex, num_epoch=50)

Epoch: 0, Total loss: 1025.1583
Saved!
Epoch: 1, Total loss: 759.1899
Epoch: 2, Total loss: 695.9267
Epoch: 3, Total loss: 675.3307
Epoch: 4, Total loss: 665.1807
Epoch: 5, Total loss: 661.0111
Epoch: 6, Total loss: 657.5414
Epoch: 7, Total loss: 655.7489
Epoch: 8, Total loss: 654.8288
Epoch: 9, Total loss: 653.2891
Epoch: 10, Total loss: 652.7399
Epoch: 11, Total loss: 652.0308
Epoch: 12, Total loss: 650.9802
Epoch: 13, Total loss: 650.2155
Epoch: 14, Total loss: 650.9889
Epoch: 15, Total loss: 649.9143
Epoch: 16, Total loss: 649.8277
Epoch: 17, Total loss: 649.7807
Epoch: 18, Total loss: 649.0556
Epoch: 19, Total loss: 650.0605
Epoch: 20, Total loss: 649.4861
Saved!
Epoch: 21, Total loss: 649.1426
Epoch: 22, Total loss: 648.9372
Epoch: 23, Total loss: 648.8659
Epoch: 24, Total loss: 649.7182
Epoch: 25, Total loss: 649.6467
Epoch: 26, Total loss: 648.962
Epoch: 27, Total loss: 649.0123
Epoch: 28, Total loss: 648.8417
Epoch: 29, Total loss: 649.0931
Epoch: 30, Total loss: 648.8071
Epoc

In [24]:
# Store the final exRNA-exRNA encoder weights.
torch.save(
    {'model_state_dict': net_exex.state_dict()},
    './models/lc_model_exex.pt',
)

## Contrastive encoders trained on mRNA-mRNA pairs

In [27]:
## Random pairing
import random

# Label-matched pairs of mRNA training rows: both members of a pair share the
# label but are drawn independently.
# Sample from the *_tr training arrays loaded at the top of the notebook. The
# un-suffixed hc_mrna / cancer_mrna names are not defined earlier in the
# notebook, and the index range must come from the array actually being indexed.
dat_mrna = []
dat_mrna1 = []
N = 10000
x = np.random.randint(0, 2, size=N)

for i in range(N):
    if x[i] == 1:
        # select healthy
        dat_mrna.append(hc_mrna_tr[int(np.random.choice(len(hc_mrna_tr)))])
        dat_mrna1.append(hc_mrna_tr[int(np.random.choice(len(hc_mrna_tr)))])
    else:
        # select cancer
        dat_mrna.append(cancer_mrna_tr[int(np.random.choice(len(cancer_mrna_tr)))])
        dat_mrna1.append(cancer_mrna_tr[int(np.random.choice(len(cancer_mrna_tr)))])

dat_mrna = np.asarray(dat_mrna)
dat_mrna1 = np.asarray(dat_mrna1)

In [28]:
## Random pairing function
def genmRand(N=10000, hc_pool=None, cancer_pool=None):
    """Draw ``N`` label-matched random pairs of mRNA rows.

    For every draw a label is chosen uniformly (1 = healthy control,
    0 = cancer). Two rows are then sampled independently, with replacement,
    from that label's pool.

    Parameters
    ----------
    N : int, default 10000
        Number of pairs to draw.
    hc_pool, cancer_pool : array-like, optional
        Row pools for the healthy and cancer labels. They default to the
        ``hc_mrna_tr`` / ``cancer_mrna_tr`` training arrays. The original
        code referenced undefined ``hc_mrna`` / ``cancer_mrna`` names, and
        indexed them with the length of a different array.

    Returns
    -------
    list
        ``[dat_mrna, dat_mrna1]``, two arrays with ``N`` rows each. Row ``i``
        of both arrays comes from the same label's pool.
    """
    if hc_pool is None:
        hc_pool = hc_mrna_tr
    if cancer_pool is None:
        cancer_pool = cancer_mrna_tr

    x = np.random.randint(0, 2, size=N)
    dat_mrna = []
    dat_mrna1 = []

    for i in range(N):
        pool = hc_pool if x[i] == 1 else cancer_pool
        # Indices are drawn from the pool being indexed, so they are always in range.
        dat_mrna.append(pool[int(np.random.choice(len(pool)))])
        dat_mrna1.append(pool[int(np.random.choice(len(pool)))])

    return [np.asarray(dat_mrna), np.asarray(dat_mrna1)]

In [32]:
class CrossPred_mm(torch.nn.Module):
    """Two independent MLP encoders that embed a pair of inputs into ``emb_dim``-d vectors.

    Each tower is Linear(in, 1000) -> BatchNorm -> ReLU -> Linear(1000, 128)
    -> BatchNorm -> ReLU -> Linear(128, emb_dim). ``forward`` returns the two
    embeddings as a tuple.
    """

    def __init__(self, input_dim_a=2000, input_dim_b=2000, emb_dim=128):
        super().__init__()
        self.input_dim_a = input_dim_a
        self.input_dim_b = input_dim_b
        # Tower a is built before tower b, which fixes the order in which
        # parameters are initialised.
        self.encoder_a = self._build_encoder(input_dim_a, emb_dim)
        self.encoder_b = self._build_encoder(input_dim_b, emb_dim)

    @staticmethod
    def _build_encoder(in_features, emb_dim):
        """Return one encoder tower; its layer order defines state_dict keys 0..6."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_features, 1000),
            torch.nn.BatchNorm1d(1000),
            torch.nn.ReLU(),
            torch.nn.Linear(1000, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, emb_dim),
        )

    def forward(self, x_a, x_b):
        """Embed ``x_a`` with tower a and ``x_b`` with tower b."""
        return self.encoder_a(x_a), self.encoder_b(x_b)
    
def CrossPred_trainer_mm(X_a, X_b, model, batch_size=64, num_epoch=5,
                         f_temp=0.1, p_temp=1.0):
    """Contrastively train the two towers of ``model`` on label-matched mRNA pairs.

    Each epoch draws fresh pairs with ``genmRand()``. ``X_a`` and ``X_b`` are
    kept for call compatibility but are overwritten by that resampling.
    ``p_temp`` is unused. The loss is ``ContrastiveLoss`` with temperature
    ``f_temp`` between the two towers' embeddings. The weights of ``model``
    are saved to ``./models/lc_model_mm.pt`` every 20 epochs
    (k = 0, 20, 40, ...).

    Parameters
    ----------
    X_a, X_b : array-like
        Ignored (see above).
    model : torch.nn.Module
        Called as ``model(a, b)`` and returning two embeddings.
    batch_size : int, default 64
        ContrastiveLoss is constructed for this fixed batch size. Only full
        batches (``n // batch_size``) are used; the remainder is dropped.
    num_epoch : int, default 5
        Number of epochs.
    f_temp : float, default 0.1
        ContrastiveLoss temperature.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    f_con = ContrastiveLoss(batch_size=batch_size, temperature=f_temp).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98),
                           eps=1e-6, weight_decay=0.2)

    for k in range(num_epoch):
        X_a, X_b = genmRand()
        X_tensor_A = torch.tensor(X_a).float()
        X_tensor_B = torch.tensor(X_b).float()
        n = X_a.shape[0]

        losses = 0
        for j in range(n // batch_size):
            rows = slice(j * batch_size, (j + 1) * batch_size)
            inputs_a = X_tensor_A[rows].to(device)
            inputs_b = X_tensor_B[rows].to(device)

            opt.zero_grad()
            feas, nfeas = model(inputs_a, inputs_b)
            loss = f_con(feas, nfeas)
            loss.backward()
            opt.step()

            losses += loss.item()
        print("Epoch: {}, Total loss: {}".format(k, str(round(losses, 4))))

        del X_tensor_A, X_tensor_B

        if k % 20 == 0:
            # Save the model being trained, not the notebook-global net_mm.
            torch.save({'model_state_dict': model.state_dict()}, './models/lc_model_mm.pt')
            print('Saved!')

        gc.collect()

In [33]:
def weights_init(m):
    """Xavier-uniform init for ``Linear`` weights and zero biases; other modules untouched.

    Intended for ``module.apply(weights_init)``.
    """
    if isinstance(m, torch.nn.Linear):
        # ``xavier_uniform`` (no trailing underscore) is deprecated and warns;
        # the in-place ``xavier_uniform_`` is the supported equivalent.
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # Linear(..., bias=False) has no bias tensor
            torch.nn.init.zeros_(m.bias)

In [35]:
# One tower per mRNA input, sized to the mRNA feature count, then weights_init.
net_mm = CrossPred_mm(
    input_dim_a=dat_mrna.shape[1],
    input_dim_b=dat_mrna1.shape[1],
)
net_mm.apply(weights_init)

  torch.nn.init.xavier_uniform(m.weight.data)


CrossPred_mm(
  (encoder_a): Sequential(
    (0): Linear(in_features=1927, out_features=1000, bias=True)
    (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1000, out_features=128, bias=True)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
  )
  (encoder_b): Sequential(
    (0): Linear(in_features=1927, out_features=1000, bias=True)
    (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1000, out_features=128, bias=True)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
  )
)

In [36]:
## 3. Training process
net_mm.train()
CrossPred_trainer_mm(dat_mrna, dat_mrna1, net_mm, num_epoch=50)

Epoch: 0, Total loss: 962.2707
Saved!
Epoch: 1, Total loss: 722.3637
Epoch: 2, Total loss: 681.2195
Epoch: 3, Total loss: 668.2164
Epoch: 4, Total loss: 661.3396
Epoch: 5, Total loss: 657.5871
Epoch: 6, Total loss: 655.5517
Epoch: 7, Total loss: 653.729
Epoch: 8, Total loss: 651.6925
Epoch: 9, Total loss: 651.5757
Epoch: 10, Total loss: 650.9219
Epoch: 11, Total loss: 649.9463
Epoch: 12, Total loss: 649.6664
Epoch: 13, Total loss: 649.4958
Epoch: 14, Total loss: 648.9606
Epoch: 15, Total loss: 648.586
Epoch: 16, Total loss: 648.2118
Epoch: 17, Total loss: 648.2322
Epoch: 18, Total loss: 648.2641
Epoch: 19, Total loss: 648.3507
Epoch: 20, Total loss: 648.0683
Saved!
Epoch: 21, Total loss: 648.237
Epoch: 22, Total loss: 648.2074
Epoch: 23, Total loss: 647.9865
Epoch: 24, Total loss: 647.684
Epoch: 25, Total loss: 647.8979
Epoch: 26, Total loss: 647.8786
Epoch: 27, Total loss: 648.367
Epoch: 28, Total loss: 648.2547
Epoch: 29, Total loss: 647.9615
Epoch: 30, Total loss: 648.322
Epoch: 31,

In [37]:
# Store the final mRNA-mRNA encoder weights.
torch.save(
    {'model_state_dict': net_mm.state_dict()},
    './models/lc_model_mm.pt',
)

## Cross-prediction encoder model (CrossPred)

In [23]:
def weights_init(m):
    """Xavier-uniform init for ``Linear`` weights and zero biases; other modules untouched.

    Intended for ``module.apply(weights_init)``.
    """
    if isinstance(m, torch.nn.Linear):
        # ``xavier_uniform`` (no trailing underscore) is deprecated and warns;
        # the in-place ``xavier_uniform_`` is the supported equivalent.
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:  # Linear(..., bias=False) has no bias tensor
            torch.nn.init.zeros_(m.bias)

In [24]:
from scipy import stats
from sklearn.neighbors import NearestNeighbors

# Cosine-metric 10-nearest-neighbour indices, presumably one per modality
# (mRNA / exRNA). They must be fitted before the *_multiple_tr helpers call
# .kneighbors on them.
knn_m = NearestNeighbors(metric='cosine', n_neighbors=10)
knn_ex = NearestNeighbors(metric='cosine', n_neighbors=10)


In [25]:
# Cross prediction

# Added to every kNN distance before 1/d is taken in the *_multiple_tr
# helpers, so an exact match (distance 0) does not divide by zero.
eps = 1e-10


def mToEx_multiple_tr(embs, X_knn_ex, knn=None, dist_eps=None):
    """Predict exRNA profiles as inverse-distance-weighted averages of neighbour rows.

    For each query embedding, the neighbours returned by ``knn.kneighbors`` are
    combined as ``sum_i w_i * X_knn_ex[idx_i] / sum_i w_i`` with
    ``w_i = 1 / (d_i + eps)``. The computation is vectorised and replaces the
    former per-query, per-neighbour Python loop.

    Parameters
    ----------
    embs : array-like, shape (n_queries, emb_dim)
        Query embeddings passed to ``knn.kneighbors``.
    X_knn_ex : array-like, shape (n_reference, n_features)
        Profiles indexed by the neighbour indices. Presumably aligned
        row-for-row with the data ``knn`` was fitted on; verify at the call site.
    knn : fitted object with ``kneighbors``, optional
        Defaults to the notebook-global ``knn_ex``.
    dist_eps : float, optional
        Offset added to every distance; defaults to the notebook-global ``eps``.

    Returns
    -------
    numpy.ndarray, float64, shape (n_queries, n_features)

    Notes
    -----
    Builds an intermediate array of shape (n_queries, k, n_features).
    """
    if knn is None:
        knn = knn_ex
    offset = eps if dist_eps is None else dist_eps

    dists, idxs = knn.kneighbors(embs)
    weights = 1.0 / (np.asarray(dists, dtype=np.float64) + offset)   # (q, k)
    neighbours = np.asarray(X_knn_ex)[np.asarray(idxs)]              # (q, k, f)
    weighted_sum = np.einsum('qk,qkf->qf', weights, neighbours)
    return (weighted_sum / weights.sum(axis=1, keepdims=True)).astype(np.float64)

def exToM_multiple_tr(embs, X_knn_m, knn=None, dist_eps=None):
    """Predict mRNA profiles as inverse-distance-weighted averages of neighbour rows.

    For each query embedding, the neighbours returned by ``knn.kneighbors`` are
    combined as ``sum_i w_i * X_knn_m[idx_i] / sum_i w_i`` with
    ``w_i = 1 / (d_i + eps)``. The computation is vectorised and replaces the
    former per-query, per-neighbour Python loop.

    Parameters
    ----------
    embs : array-like, shape (n_queries, emb_dim)
        Query embeddings passed to ``knn.kneighbors``.
    X_knn_m : array-like, shape (n_reference, n_features)
        Profiles indexed by the neighbour indices. Presumably aligned
        row-for-row with the data ``knn`` was fitted on; verify at the call site.
    knn : fitted object with ``kneighbors``, optional
        Defaults to the notebook-global ``knn_m``.
    dist_eps : float, optional
        Offset added to every distance; defaults to the notebook-global ``eps``.

    Returns
    -------
    numpy.ndarray, float64, shape (n_queries, n_features)

    Notes
    -----
    Builds an intermediate array of shape (n_queries, k, n_features).
    """
    if knn is None:
        knn = knn_m
    offset = eps if dist_eps is None else dist_eps

    dists, idxs = knn.kneighbors(embs)
    weights = 1.0 / (np.asarray(dists, dtype=np.float64) + offset)   # (q, k)
    neighbours = np.asarray(X_knn_m)[np.asarray(idxs)]               # (q, k, f)
    weighted_sum = np.einsum('qk,qkf->qf', weights, neighbours)
    return (weighted_sum / weights.sum(axis=1, keepdims=True)).astype(np.float64)

In [26]:
class CrossPred(torch.nn.Module):
    """Two-tower encoder projecting two modalities into a shared embedding space.

    Each tower is
    ``Linear(in, 1000) -> BatchNorm1d -> ReLU -> Linear(1000, 128) -> BatchNorm1d -> ReLU
    -> Linear(128, emb_dim)``.
    The towers share no weights. In this notebook, tower A takes mRNA features and tower B
    takes exRNA features.

    Args:
        input_dim_a: number of input features for tower A.
        input_dim_b: number of input features for tower B.
        emb_dim: width of both towers' output embedding.
    """

    def __init__(self, input_dim_a=2000, input_dim_b=2000, emb_dim=128):
        super().__init__()
        self.input_dim_a = input_dim_a
        self.input_dim_b = input_dim_b
        # Tower A is built first, so parameter initialisation consumes the RNG in A-then-B order.
        self.encoder_a = self._make_encoder(input_dim_a, emb_dim)
        self.encoder_b = self._make_encoder(input_dim_b, emb_dim)

    @staticmethod
    def _make_encoder(in_features, emb_dim):
        """Build one tower; the layer order fixes the state_dict keys ``0`` ... ``6``."""
        layers = [
            torch.nn.Linear(in_features, 1000),
            torch.nn.BatchNorm1d(1000),
            torch.nn.ReLU(),
            torch.nn.Linear(1000, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, emb_dim),
        ]
        return torch.nn.Sequential(*layers)

    def forward(self, x_a, x_b):
        """Encode ``x_a`` with tower A and ``x_b`` with tower B; returns ``(emb_a, emb_b)``."""
        return self.encoder_a(x_a), self.encoder_b(x_b)
    
def _add_feature_noise(X, stds, scale=0.01):
    """Perturb ``X`` in place with per-feature Gaussian noise and return it.

    Column ``j`` receives ``scale * N(0, stds[j])`` noise, i.e. noise proportional to that
    feature's spread.
    """
    X += scale * np.random.normal(0.0, np.asarray(stds, dtype=float), size=X.shape)
    return X


def _embed(encoder, data, device):
    """Encode the array ``data`` with ``encoder`` on ``device`` (no autograd); returns a CPU NumPy array."""
    with torch.no_grad():
        emb = encoder(torch.tensor(data).float().to(device))
    return emb.cpu().numpy()


def _mean_profile_correlations(pred, true):
    """Return (Spearman, Pearson) correlation between the column means of ``pred`` and ``true``."""
    pred_mean = np.mean(pred, axis=0)
    true_mean = np.mean(true, axis=0)
    return (stats.spearmanr(pred_mean, true_mean)[0],
            stats.pearsonr(pred_mean, true_mean)[0])


def CrossPred_trainer(X_a, X_b, model, batch_size = 64, num_epoch=5,
                        f_temp = 0.1, p_temp = 1.0, noise_scale=0.01,
                        target_score=1.8, save_path='./models/lc_crosspred.pt'):
    """Contrastively train ``model``'s two encoders and checkpoint on kNN cross-prediction quality.

    Each epoch:
      1. Draws fresh random mRNA/exRNA training pairs with ``genDataRand()``. These are
         presumably the same class-matched random pairing as the "Random Pairing" cell.
      2. Perturbs every feature with ``noise_scale * N(0, feature_std)``, using
         ``mrna_std`` / ``exrna_std``.
      3. Trains on full mini-batches with ``ContrastiveLoss`` (Adam, lr 5e-5).
      4. Fits ``knn_m`` / ``knn_ex`` on embeddings of the training sets, then cross-predicts
         the validation sets in both directions (exRNA -> mRNA and mRNA -> exRNA).
      5. Scores each direction as Spearman + Pearson correlation between the mean predicted
         profile and the mean true profile (maximum 2.0).

    A checkpoint is written whenever both scores beat their previous best. Training stops
    early once both scores reach ``target_score``.

    Args:
        X_a, X_b: unused; kept for backward compatibility. Training pairs are re-sampled
            every epoch by ``genDataRand()``.
        model: module with ``encoder_a`` / ``encoder_b`` whose ``forward(a, b)`` returns both
            embeddings (e.g. ``CrossPred``).
        batch_size: rows per contrastive batch. The trailing partial batch is dropped because
            ``ContrastiveLoss`` is constructed for this fixed ``batch_size``.
        num_epoch: maximum number of epochs.
        f_temp: temperature of the feature-level contrastive loss.
        p_temp: unused (leftover from a disabled prediction-level loss).
        noise_scale: multiplier on the per-feature noise std (previously hard-coded 0.01).
        target_score: early-stopping threshold applied to both scores.
        save_path: checkpoint file, written as ``{'model_state_dict': ...}``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the model before building the optimiser so Adam tracks the on-device parameters.
    model.to(device)
    f_con = ContrastiveLoss(batch_size=batch_size, temperature=f_temp).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6,
                           weight_decay=0.2)

    # Loop-invariant evaluation data: kNN reference sets (training) and validation sets.
    ref_m = np.concatenate([cancer_mrna_tr, hc_mrna_tr])
    ref_ex = np.concatenate([cancer_exrna_tr, hc_exrna_tr])
    val_m = np.concatenate([cancer_mrna_val, hc_mrna_val])
    val_ex = np.concatenate([cancer_exrna_val, hc_exrna_val])

    s_m_max = 0
    s_ex_max = 0

    for k in range(num_epoch):
        # ---- training ----
        model.train()

        X_a, X_b = genDataRand()
        _add_feature_noise(X_a, mrna_std, noise_scale)
        _add_feature_noise(X_b, exrna_std, noise_scale)

        X_tensor_A = torch.tensor(X_a).float()
        X_tensor_B = torch.tensor(X_b).float()

        losses = 0
        for j in range(X_a.shape[0] // batch_size):
            rows = slice(j * batch_size, (j + 1) * batch_size)
            inputs_a = X_tensor_A[rows].to(device)
            inputs_b = X_tensor_B[rows].to(device)

            opt.zero_grad()
            feas, nfeas = model(inputs_a, inputs_b)
            loss = f_con(feas, nfeas)
            loss.backward()
            opt.step()

            losses += loss.item()

        del X_tensor_A, X_tensor_B

        # ---- validation: kNN cross-prediction ----
        # Inputs go to ``device``; CPU tensors would fail against a CUDA model.
        # In eval mode BatchNorm uses running stats, so encoding concatenated sets is equivalent
        # to encoding cancer and healthy separately.
        model.eval()
        knn_m.fit(_embed(model.encoder_a, ref_m, device))
        knn_ex.fit(_embed(model.encoder_b, ref_ex, device))

        pred_m = exToM_multiple_tr(_embed(model.encoder_b, val_ex, device), ref_m)
        pred_ex = mToEx_multiple_tr(_embed(model.encoder_a, val_m, device), ref_ex)

        sp_m, p_m = _mean_profile_correlations(pred_m, val_m)
        sp_ex, p_ex = _mean_profile_correlations(pred_ex, val_ex)
        score_m = sp_m + p_m
        score_ex = sp_ex + p_ex
        print(score_m)
        print(score_ex)

        print("Epoch: {}, Total loss: {}".format(k,str(round(losses,4))))

        if score_m > s_m_max and score_ex > s_ex_max:
            torch.save({'model_state_dict': model.state_dict()}, save_path)
            print('Saved!')
            for value in (sp_m, p_m, sp_ex, p_ex):
                print('\t', value)
            s_m_max = score_m
            s_ex_max = score_ex

        if score_m >= target_score and score_ex >= target_score:
            torch.save({'model_state_dict': model.state_dict()}, save_path)
            print('Final Saved!')
            for value in (sp_m, p_m, sp_ex, p_ex):
                print('\t', value)
            break

        gc.collect()

In [27]:
## 2. Build the CrossPred model used for modality prediction.
# Tower A is sized to the mRNA feature count and tower B to the exRNA feature count.
net = CrossPred(
    input_dim_a=dat_mrna.shape[1],
    input_dim_b=dat_exrna.shape[1],
)
net.apply(weights_init)  # re-initialise every Linear layer; the returned module is displayed

  torch.nn.init.xavier_uniform(m.weight.data)


CrossPred(
  (encoder_a): Sequential(
    (0): Linear(in_features=1927, out_features=1000, bias=True)
    (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1000, out_features=128, bias=True)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
  )
  (encoder_b): Sequential(
    (0): Linear(in_features=2537, out_features=1000, bias=True)
    (1): BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=1000, out_features=128, bias=True)
    (4): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=128, out_features=128, bias=True)
  )
)

In [34]:
## 3. Training process.
# dat_mrna / dat_exrna fill the X_a / X_b slots, but CrossPred_trainer re-samples its own
# training pairs every epoch via genDataRand().
net.train()
CrossPred_trainer(dat_mrna, dat_exrna, net, num_epoch=500)

1.8227679468350957
1.2266627111537232
Epoch: 0, Total loss: 954.2531
Saved!
	 0.9901286400497061
	 0.8326393067853896
	 0.5584052982839088
	 0.6682574128698144
1.8894344898055233
1.5435772340066123
Epoch: 1, Total loss: 731.8544
Saved!
	 0.9953565125664862
	 0.8940779772390371
	 0.6500339096928764
	 0.8935433243137357
1.8800565605391863
1.5487384278685523
Epoch: 2, Total loss: 687.0644
1.9581191084550615
1.6182111396021555
Epoch: 3, Total loss: 670.7786
Saved!
	 0.9979695407227763
	 0.9601495677322852
	 0.6826079481901445
	 0.935603191412011
1.9735126187723342
1.5367172118778663
Epoch: 4, Total loss: 663.1051
1.9165588979522834
1.5877580657703645
Epoch: 5, Total loss: 658.4034
1.9587150807756524
1.386659106723978
Epoch: 6, Total loss: 655.9229
1.894581252916891
1.6647043034188294
Epoch: 7, Total loss: 654.4669
1.8204225551063224
1.5674444285043099
Epoch: 8, Total loss: 652.9937
1.9720054603945747
1.4823111820511818
Epoch: 9, Total loss: 652.2509
1.8083160460453098
1.4699716427855694
Ep

KeyboardInterrupt: 

In [102]:
# Persist the current CrossPred parameters, wrapped in a checkpoint dict.
torch.save(dict(model_state_dict=net.state_dict()), './models/lc_crosspred.pt')