In [None]:
import os
import time
import datetime
import numpy as np
from collections import OrderedDict
import tqdm
import argparse
import pickle
import warnings

import scanpy as sc

import torch
import torch.nn as nn
import torch.nn.functional as F

import sys

# The local `drscax` package is imported from the project root, so `prj_path`
# must exist before it is appended to sys.path. It is defined here because the
# filepaths cell that also sets it runs later in the notebook, which made this
# cell raise NameError on a fresh kernel.
prj_path = ''  # project root containing the drscax package; keep in sync with the filepaths cell
sys.path.append(prj_path)
from drscax.scripts import models as tcgamodels
from drscax.scripts import eval_ as tcgaeval
from drscax.scripts import data as tcgadata
from drscax.scripts import train as tcgatrain

In [None]:
# --- run identity / data ---
job_name = 'v062_woimmune_bst8layer50k_221012_165814'
data_version = 'cancer_only'
approach = '2'            # '1': network input is X_div_cnv, '2': network input is X
binarizeinput = False     # passed to the model as sigmoid_out

# --- architecture ---
model = 'cAE'             # 'cAE' or 'cVAE' (selects the loss in the training cell)
block = 'bst8layer50k'    # preset name, resolved to a list of layer sizes in the training cell
layer_norm = True
nocondition = True        # True: network is called without the batch labels y
inject_c1_eachlayer = False
beta = 0.0                # only used by the cVAE loss

# --- optimisation ---
initial_lr = 0.001
batch_size = 1024
nolrscheduler = False

## Run

Set the file paths, load the data, build the model, and train.

In [None]:
# filepaths
# Fill these in before running. Subpaths must be relative (no leading '/'):
# os.path.join discards every component before an absolute one, which previously
# made mfp ignore scratch_path entirely.
scratch_path = ''
prj_path = ''

mfp = os.path.join(scratch_path, 'model_zoo', 'dev')  # directory holding the .pt checkpoints

In [None]:

# Path of the checkpoint for this job, e.g. ~/model_zoo/dev/v062_woimmune_bst8layer50k_221012_165814.pt
bst_mdl_chkpt = os.path.join(mfp, job_name + '.pt')

In [1]:


################ loading the training data ################

# Subpaths must be relative: os.path.join drops everything before an absolute
# component, so '/model_zoo/dev/' would ignore scratch_path entirely.
mfp = os.path.join(scratch_path, 'model_zoo', 'dev')
lfp = os.path.join(prj_path, 'experiments', 'run_logs')  # log filepath

# Named layer-size presets. The first entry is also used as `topn` below.
BLOCK_PRESETS = {
    'default': [50000, 4096, 2048, 1024, 512],
    'bst8layer50k': [50000] + [1024] * 8,
}

# `block` starts as a preset name (config cell) and is replaced by its layer list.
# The isinstance guard keeps this cell idempotent: re-running it with `block`
# already resolved to a list must not silently fall back to 'default'.
if isinstance(block, str):
    if block not in BLOCK_PRESETS:
        warnings.warn('invalid block architecture selected. using default')
        block = 'default'
    block = list(BLOCK_PRESETS[block])  # copy so the preset itself is never mutated


# Load the preprocessed AnnData object.
adata = sc.read(os.path.join(scratch_path, 'PublicationPage/tcga_canceronly_top50klogtfidf_221011.h5ad'))

# Loss: the project's VAELoss for the variational model, summed MSE for the autoencoder.
if model == 'cVAE':
    criterion = tcgatrain.VAELoss(beta=beta, reconstruction='cont')
else:
    criterion = nn.MSELoss(reduction='sum')

# Input width of the network (first entry of the layer list).
topn = block[0]


############ Initializing the model ############

# Only the conditional autoencoder has a constructor here. Fail loudly for any
# other value instead of leaving `net` undefined (or stale from an earlier run).
if model == 'cAE':
    net = tcgamodels.cAE_v05(
        layer_io=block,
        layer_norm=layer_norm,
        # with nocondition, 0 condition classes and a 0-dim condition embedding are requested
        n_c1_class=0 if nocondition else len(adata.obs['batch'].unique()),
        c1_embed_dim=0 if nocondition else 8,
        inject_c1_eachlayer=inject_c1_eachlayer,
        sigmoid_out=binarizeinput,
        return_latent=False,
    )
else:
    raise NotImplementedError(
        'model {!r} has no constructor in this notebook; only "cAE" is supported'.format(model))


# Use the GPU when available, releasing cached allocations left over from earlier runs.
if torch.cuda.is_available():
    device = torch.device('cuda')
    torch.cuda.empty_cache()
else:
    device = torch.device('cpu')

net = net.to(device)
print(net)

# Checkpoint path: passed to EarlyStop and recorded in the log.
# Hyperparameters are recorded in the log's 'hyperparameters' dict rather than
# encoded in the file name. (An encoded name built here referenced an undefined
# `trial` variable and was immediately overwritten, so it has been removed.)
save_name = os.path.join(mfp, '{}.pt'.format(job_name))

############### Model training ###################

# Adam over the parameters that require gradients only.
trainable_params = [p for p in net.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=initial_lr, weight_decay=0.0001)

# The plateau scheduler exists only when scheduling is enabled; the epoch loop
# checks the same flag before stepping it.
if not nolrscheduler:
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

def adjust_lr(optimizer, decay_rate=0.95):
    """Scale the learning rate of every parameter group of `optimizer` in place.

    Args:
        optimizer: any object exposing `param_groups` (a list of dicts with an 'lr' key).
        decay_rate: multiplicative factor applied to each group's 'lr'.
    """
    for group in optimizer.param_groups:
        group.update(lr=group['lr'] * decay_rate)

max_epochs = 500
early_stop = tcgatrain.EarlyStop(patience=50, save_name=save_name)
print("training on ", device)

# Per-epoch losses plus the settings needed to reproduce this run.
log = {
    'save_name': save_name,
    '<loss_train>': [],
    '<loss_val>': [],
    'hyperparameters': {
        'beta': beta,
        'layer_norm': layer_norm,
        'blocks': block,
        'noconditioning': nocondition,
        'approach': approach,
        'model': model,
        'initial_lr': initial_lr,
        'no_lr_scheduler': nolrscheduler,
        'batch_size': batch_size,
        # settings that also change the model / data and must be logged to reproduce the run
        'inject_c1_eachlayer': inject_c1_eachlayer,
        'binarizeinput': binarizeinput,
        'topn': topn,
        'data_version': data_version,
        'max_epochs': max_epochs,
    },
}
def _batch_loss(X, y, X_div_cnv):
    """Run `net` on one batch and return (loss, number of samples counted).

    `approach` selects the network input ('1' -> X_div_cnv, '2' -> X); the
    reconstruction target is X_div_cnv in both cases. `y` is passed to the
    network only when conditioning is enabled (`nocondition` is False).
    """
    X = X.squeeze().to(device)  # squeeze may not be needed
    X_div_cnv = X_div_cnv.squeeze().to(device)
    y = y.squeeze().to(device)

    if approach == '1':
        net_input = X_div_cnv
    elif approach == '2':
        net_input = X
    else:
        raise ValueError("approach must be '1' or '2', got {!r}".format(approach))
    output = net(net_input) if nocondition else net(net_input, y)

    if model == 'cVAE':
        Xhat, mean, logvar = output
        loss = criterion(X_div_cnv, Xhat, mean, logvar)
    elif model == 'cAE':
        # (input, target) order as in torch.nn.MSELoss; used for both train and val.
        loss = criterion(output, X_div_cnv)
    else:
        raise ValueError("model must be 'cAE' or 'cVAE', got {!r}".format(model))

    # NOTE(review): if a batch holds a single sample, squeeze() also drops the batch
    # dimension and X.shape[0] becomes the feature count -- confirm loader shapes.
    return loss, X.shape[0]


def _run_epoch(loader, training):
    """Iterate once over `loader`, updating `net` when `training` is True.

    Returns the summed batch losses divided by the number of samples seen.
    """
    net.train(training)
    total_loss, n_samples = 0.0, 0
    # Autograd is disabled for validation so no graph is built or kept per batch.
    with torch.set_grad_enabled(training):
        for X, y, X_div_cnv, _ in tqdm.tqdm(loader):
            loss, batch_n = _batch_loss(X, y, X_div_cnv)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            n_samples += batch_n
    if n_samples == 0:
        raise ValueError('data loader yielded no samples')
    return total_loss / n_samples


for epoch in range(max_epochs):
    start = time.time()

    train_loss = _run_epoch(train_dl, training=True)
    log['<loss_train>'].append(train_loss)

    val_loss = _run_epoch(val_dl, training=False)
    log['<loss_val>'].append(val_loss)

    print('epoch %d\ttrain loss: %.4f\tval loss: %.4f\ttime: %.1f-s'
          % (epoch, train_loss, val_loss, time.time() - start))

    if not nolrscheduler:
        # Two schedules are combined: a fixed per-epoch decay (adjust_lr, default
        # factor 0.95) plus ReduceLROnPlateau on the validation loss.
        adjust_lr(optimizer)
        scheduler.step(val_loss)

    if early_stop(val_loss, net, optimizer):
        break


