In [1]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from scipy import sparse
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from torch_lr_finder import LRFinder
from tqdm import tqdm

from utils import *



In [2]:
# Run on the first CUDA device when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print('Using GPU')

Using GPU


## Data and output paths

In [3]:
DATASET = 'cite'  # competition dataset to model: 'multi' or 'cite'

In [4]:
# Root of the project checkout; every other path below is derived from it.
PROJECT_DIR = '/home/artemy/multimodal_proj/'

data_dir = PROJECT_DIR + 'data/competition/'

log_dir = PROJECT_DIR + 'cur_model_tb'

o_dir = PROJECT_DIR + 'data/AE_predictions/'
# One prediction file per dataset so a 'cite' run cannot overwrite the 'multi'
# predictions (a single hard-coded 'multi_gex_pred.npy' was shared by both).
PRED_FILENAMES = {'multi': 'multi_gex_pred.npy', 'cite': 'cite_adt_pred.npy'}
pred_file = o_dir + PRED_FILENAMES[DATASET]

In [5]:
# Training inputs/targets and test inputs for the selected dataset
# (per the file names: 'multi' maps ATAC -> GEX, 'cite' maps GEX -> ADT).
if DATASET == 'multi':
    inputs_train_fn = data_dir + "atac_train.sparse.npz"
    targets_train_fn = data_dir + "gex_train.sparse.npz"
    inputs_test_fn = data_dir + "atac_test.sparse.npz"
elif DATASET == 'cite':
    inputs_train_fn = data_dir + "cite_gex_train.sparse.npz"
    targets_train_fn = data_dir + "cite_adt_train.sparse.npz"
    inputs_test_fn = data_dir + "cite_gex_test.sparse.npz"
else:
    # Fail here rather than with a NameError several cells later.
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected 'multi' or 'cite'.")


## Model

In [6]:
def init_weights(m):
    """Initialise ``nn.Linear`` layers in place; other module types are left untouched.

    Intended for ``module.apply(init_weights)``. Linear weights get Xavier/Glorot
    uniform initialisation and the bias, when the layer has one, is set to 0.01.

    Args:
        m: any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        # Layers built with bias=False have m.bias = None; skip rather than crash.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.01)

class Encoder(nn.Module, HyperParameters):
    """Stack of ``Linear -> BatchNorm1d -> LeakyReLU(0.2)`` blocks.

    Args:
        input_dim: width of the input features.
        out_dims: output width of each successive block; the encoder output
            width is ``out_dims[-1]``.
    """
    def __init__(self, input_dim: int, out_dims: list = [2000, 2000]):
        super().__init__()
        self.save_hyperparameters()
        self.n_layers = len(out_dims)

        layers = []
        fan_in = self.input_dim
        for fan_out in self.out_dims:
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(num_features=fan_out),
                nn.LeakyReLU(0.2),
            ]
            # The next block consumes this block's output.
            fan_in = fan_out

        self.Encoder = nn.Sequential(*layers)
        self.Encoder.apply(init_weights)

    def forward(self, x):
        """Apply the block stack to ``x``."""
        return self.Encoder(x)
    
    
class Decoder(nn.Module, HyperParameters):
    """MLP mapping ``input_dims[0]`` features to ``out_dim`` non-negative features.

    Layer ``i`` maps ``input_dims[i]`` to ``input_dims[i + 1]`` (or to ``out_dim`` for
    the last layer). Hidden layers are ``Linear -> BatchNorm1d -> LeakyReLU(0.2)``;
    the last layer is ``Linear -> Softplus``, which keeps outputs non-negative.

    NOTE(review): by default (``hidden_softplus=True``) a Softplus is also appended
    after every hidden layer's LeakyReLU. This reproduces the original layout, so
    existing saved state_dicts keep their ``nn.Sequential`` indices and behaviour.
    The stacked LeakyReLU -> Softplus looks unintended; pass
    ``hidden_softplus=False`` for new models.

    Args:
        out_dim: number of output features.
        input_dims: input width of each layer; ``input_dims[0]`` must match the
            width of the tensor passed to ``forward``.
        hidden_softplus: also apply Softplus after each hidden layer (legacy layout).
    """
    def __init__(self,
                out_dim: int,
                input_dims: list = [2000, 2000],
                hidden_softplus: bool = True,
    ):
        super().__init__()
        self.save_hyperparameters()

        self.n_layers = len(input_dims)

        modules = []
        for i in range(self.n_layers):
            is_last = i == self.n_layers - 1
            layer_out = self.out_dim if is_last else self.input_dims[i + 1]
            modules.append(nn.Linear(self.input_dims[i], layer_out))
            if not is_last:
                modules.append(nn.BatchNorm1d(num_features=layer_out))
                modules.append(nn.LeakyReLU(0.2))
            # Softplus on the output layer keeps predictions non-negative; on hidden
            # layers it is only added for the legacy layout (see class docstring).
            if is_last or hidden_softplus:
                modules.append(nn.Softplus())

        self.Decoder = nn.Sequential(*modules)
        self.Decoder.apply(init_weights)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.Decoder(x)

In [7]:
class AE(nn.Module, HyperParameters):
    """Encoder/decoder that maps input-modality features to target-modality features.

    Only the target ("rna") decoder is built. The input-reconstruction ("atac")
    branch is disabled: ``forward`` returns the placeholder ``0`` in its slot, and
    the step methods report a constant ``1`` as its loss.

    Args:
        n_atac_features: number of input features fed to the encoder.
        n_rna_features: number of target features produced by the decoder.
        encoder_dims: output width of each encoder block.
        decoder_dims: input width of each decoder layer; ``decoder_dims[0]`` must
            equal ``encoder_dims[-1]`` because the decoder consumes the encoder output.

    Raises:
        ValueError: if ``decoder_dims[0] != encoder_dims[-1]``.
    """
    def __init__(self,
                n_atac_features: int,
                n_rna_features: int = None,
                encoder_dims: list = [1000, 1000],
                decoder_dims: list = [1000, 1000]
                ):
        super().__init__()
        self.save_hyperparameters()
        # Fail at construction instead of with a shape error on the first forward pass.
        if encoder_dims and decoder_dims and decoder_dims[0] != encoder_dims[-1]:
            raise ValueError(
                f"decoder_dims[0] ({decoder_dims[0]}) must equal "
                f"encoder_dims[-1] ({encoder_dims[-1]})"
            )
        self.encoder = Encoder(n_atac_features, out_dims=encoder_dims)
        # The ATAC reconstruction decoder (Decoder(n_atac_features, ...)) is disabled.
        self.decoder_rna = Decoder(n_rna_features, input_dims=decoder_dims)

    def forward(self, x):
        """Encode ``x`` and decode the target modality.

        Returns:
            ``(atac_recon, rna_recon)``: ``atac_recon`` is the placeholder ``0``
            (input reconstruction is disabled); ``rna_recon`` is the decoder output.
        """
        latent = self.encoder(x)
        atac_recon = 0  # placeholder: ATAC reconstruction is disabled
        rna_recon = self.decoder_rna(latent)
        return atac_recon, rna_recon

    def loss(self, y, y_hat):
        """Mean squared error between ``y`` and ``y_hat`` (symmetric in its arguments)."""
        loss_fn = nn.MSELoss()
        return loss_fn(y_hat, y)

    def weighted_loss(self, y, y_hat):
        """Half squared error, weighted 1/10 where ``y_hat == 0`` and 9/10 elsewhere.

        Each group contributes ``mean((y_hat - y) ** 2 / 2 * weight)``. A group with
        no elements contributes 0; ``torch.mean`` of an empty tensor would make the
        whole loss NaN (e.g. Softplus outputs are never exactly 0).

        NOTE(review): the zero/non-zero split is taken on ``y_hat``, while
        ``training_step`` passes predictions first; confirm which tensor is the target.
        """
        y = y.flatten()
        y_hat = y_hat.flatten()
        zero_mask = torch.eq(y_hat, 0)
        nonzero_mask = ~zero_mask
        l = y_hat.new_zeros(())
        if zero_mask.any():
            l = l + torch.mean((y_hat[zero_mask] - y[zero_mask]) ** 2 / 2 * 1/10)
        if nonzero_mask.any():
            l = l + torch.mean((y_hat[nonzero_mask] - y[nonzero_mask]) ** 2 / 2 * 9/10)
        return l

    def correl_loss(self, y, y_hat):
        """Negative ``spearman_cor(y_hat, y)``, so minimising it maximises the correlation."""
        l = -spearman_cor(y_hat, y)
        return l

    def _shared_step(self, inputs, targets, calculate_cor):
        """Forward ``inputs`` and score the target reconstruction against ``targets``."""
        _, rna_recon = self.forward(inputs)
        loss_atac = 1  # placeholder: ATAC reconstruction loss is disabled
        loss_rna = self.loss(rna_recon, targets)
        cor = spearman_cor(rna_recon, targets) if calculate_cor else 0
        return loss_atac, loss_rna, cor

    def training_step(self, inputs, targets, calculate_cor=True):
        """Return ``(loss_atac, loss_rna, cor)`` for one training batch.

        ``cor`` is ``spearman_cor(prediction, targets)``, or 0 when
        ``calculate_cor`` is False.
        """
        return self._shared_step(inputs, targets, calculate_cor)

    def validation_step(self, inputs, targets, calculate_cor=True):
        """Return ``(loss_atac, loss_rna, cor)`` for one validation batch (same as training_step)."""
        return self._shared_step(inputs, targets, calculate_cor)

    def predict(self, inputs):
        """Return ``(atac_recon, rna_recon)`` as produced by ``forward``."""
        atac_recon, rna_recon = self.forward(inputs)
        return atac_recon, rna_recon


In [8]:
# Layer widths for AE; decoder_dims[0] matches encoder_dims[-1] so the decoder
# consumes the encoder output directly.
if DATASET == 'multi':
    model_params = {'encoder_dims': [2000, 2000],
                    'decoder_dims': [2000, 2000]}
elif DATASET == 'cite':
    model_params = {'encoder_dims': [5000, 3000],
                    'decoder_dims': [3000, 2000]}
else:
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected 'multi' or 'cite'.")

## Train model

In [9]:
# Settings passed to Trainer; epoch count and learning rate depend on the dataset.
trainer_params = {'batch_size': 2048,
                  'use_schedule': True,
                  'inputs_fn': inputs_train_fn,
                  'targets_fn': targets_train_fn,
                  'device': device,
                  'wd': 2e-2
                 }

if DATASET == 'multi':
    trainer_params['max_epochs'] = 20
    trainer_params['lr'] = 1e-3
elif DATASET == 'cite':
    trainer_params['max_epochs'] = 12
    trainer_params['lr'] = 5e-5
else:
    # Without this, Trainer would be built with no max_epochs/lr for this dataset.
    raise ValueError(f"Unknown DATASET {DATASET!r}; expected 'multi' or 'cite'.")

In [10]:
# Train AE with `trainer_params`, logging to TensorBoard under `log_dir`.
train_model = True
if train_model:
    writer = SummaryWriter(log_dir=log_dir)

    start_time = time.perf_counter()  # monotonic clock, unaffected by wall-clock changes
    try:
        trainer = Trainer(**trainer_params, writer=writer)
        trainer.fit(AE, model_params, do_validation=True, subset_train=-1)
    finally:
        # Flush pending events to disk even if training raises or is interrupted.
        # SummaryWriter re-opens its file writer if it is used again after close().
        writer.close()

    elapsed_time = time.perf_counter() - start_time
    hours, remainder = divmod(int(elapsed_time), 60 ** 2)
    print('Hours: %s' % hours,
          'Minutes: %s' % (remainder // 60), sep='\n')

Using TensorBoard for output


100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [01:25<00:00,  7.09s/it]

Training loss: 1.96
Training cor: 0.91
Validation loss: 2.68
Validation cor: 0.88
Hours: 0.0
Minutes: 1.0





In [11]:
# Set to True to write the trained model weights for the current DATASET to o_dir.
save_model = False
if save_model:
    checkpoint_path = f'{o_dir}model_{DATASET}_full.pt'
    trainer.save_model(checkpoint_path)

In [12]:
# Set to True to call plot_progress on the trainer from the training cell
# (presumably plots its recorded loss/correlation history).
plot_progress_True = False
if plot_progress_True:
    plot_progress(trainer)

In [13]:
# Set to True to reload the saved checkpoint for DATASET and plot its predictions
# against the originals on a `subset_train=2048` training split.
analyze_pred = False
if analyze_pred:
    checkpoint_path = f'{o_dir}model_{DATASET}_full.pt'
    trainer = Trainer(**trainer_params)
    loaders = trainer.load_data('train', subset_train=2048)
    trainer.train_loader, trainer.val_loader = loaders
    trainer.load_model(AE, checkpoint_path, model_params)
    atac_pred, rna_pred, atac_orig, rna_orig = trainer.analyze_model(AE)
    plot_model_analysis(atac_pred, rna_pred, atac_orig, rna_orig)
    
    


## Learning-rate range test (LRRT)

In [14]:
def run_lrrt(batch_size=2048, min_lr=1e-7, wd=2e-2, validation=False, device='cuda',
             model_kwargs=None):
    """Run a learning-rate range test with ``torch_lr_finder.LRFinder`` and plot loss vs. LR.

    The LR is swept from ``min_lr`` up to 10 over 100 iterations on the full
    training set.

    Args:
        batch_size: loader batch size.
        min_lr: starting learning rate for AdamW.
        wd: AdamW weight decay.
        validation: if True, pass the validation loader to ``range_test`` so the
            loss is measured on it after each step.
        device: device LRFinder runs the model on.
        model_kwargs: keyword arguments for ``AE``. Defaults to the notebook's
            ``model_params`` so the test probes the same architecture that is
            trained, rather than AE's default layer widths.
    """
    if model_kwargs is None:
        model_kwargs = model_params

    def criterion(predicted, orig):
        # AE.forward returns (atac_recon, rna_recon); only the target reconstruction is scored.
        loss_fn = nn.MSELoss()
        return loss_fn(predicted[1], orig)

    trainer = Trainer(batch_size=batch_size,
                      inputs_fn=inputs_train_fn,
                      targets_fn=targets_train_fn)
    train_loader, val_loader = trainer.load_data('train', subset_train=-1)
    model = AE(train_loader.n_input_features, train_loader.n_target_features, **model_kwargs)
    optimizer = torch.optim.AdamW(model.parameters(), lr=min_lr, weight_decay=wd)
    lr_finder = LRFinder(model, optimizer, criterion, device=device)
    lr_finder.range_test(train_loader,
                         val_loader=val_loader if validation else None,
                         end_lr=10, num_iter=100)

    lr_finder.plot(log_lr=True)

In [15]:
# Set to True to run the LR range test on the training loss (see run_lrrt).
lrrt_on_train = False
if lrrt_on_train:
    run_lrrt()

In [16]:
# Set to True to run the LR range test with the validation loader (see run_lrrt).
lrrt_on_val = False
if lrrt_on_val:
    run_lrrt(validation=True)


## Make predictions

In [17]:
# Set to True to write test-set predictions from the current `trainer` to `pred_file`.
make_predictions = False
if make_predictions:
    # `trainer` is created by the training or analysis cell above; give a clear
    # error instead of a NameError when neither has run in this kernel.
    if 'trainer' not in globals():
        raise RuntimeError('No trainer in scope: run the training or analysis cell before predicting.')
    outputs = trainer.test_model(inputs_test_fn)

    # Create the output directory if needed; np.save would otherwise fail on a fresh checkout.
    os.makedirs(os.path.dirname(pred_file), exist_ok=True)
    np.save(pred_file, outputs)