In [1]:
import os
import sys

import random 
import numpy as np
import matplotlib.pyplot as plt

import scanpy as sc
import scvi
import scgen

import torch
from lightning.pytorch.loggers import CSVLogger

import pandas as pd

import session_info
import warnings
from pyprojroot.here import here

# Use the 'medium' float32 matmul precision setting; see the PyTorch
# documentation for the exact speed/accuracy trade-off this selects.
torch.set_float32_matmul_precision('medium')

# Seed handed to scvi-tools below.
random_seed = 42

# The data-loader worker override is disabled, so scvi-tools keeps its own default.
#scvi.settings.dl_num_workers = 8
# NOTE(review): presumably the source of the "Seed set to 42" message in the cell output.
scvi.settings.seed = random_seed
# Record the library version alongside the results for provenance.
print("scvi-tools version:", scvi.__version__)

Seed set to 42


scvi-tools version: 1.1.2


### Parameters

In [2]:
# Notebook-wide settings:
#   n_latent   - forwarded to the scGen model as its `n_latent` argument
#   batch_key  - obs column copied into `batch` before scGen setup
#   annotation - obs column copied into `cell_type` before scGen setup
n_latent, batch_key, annotation = 30, 'chemistry', 'Level1'

## Loading the main AnnData object

In [4]:
# Build the input path with pyprojroot's `here()` and read the log1p-normalised
# gene-universe object (file name indicates RBCs and platelets were removed).
main_h5ad_path = here("03_downstream_analysis/02_gene_universe_definition/results/04_MAIN_geneUniverse_noRBCnPlatelets.log1p.h5ad")
adataM = sc.read_h5ad(main_h5ad_path)
# Bare last expression so the notebook displays the AnnData summary.
adataM

AnnData object with n_obs × n_vars = 4279352 × 8253
    obs: 'studyID', 'libraryID', 'sampleID', 'chemistry', 'disease', 'sex', 'binned_age', 'Level1', 'Level2'
    var: 'hgnc_id', 'symbol', 'locus_group', 'HUGO_status', 'highly_variable'
    uns: 'log1p'

### scGen preprocessing

In [5]:
# scGen is registered below under the fixed column names `batch` and
# `cell_type`, so copy the configured obs columns under those names first.
# `.tolist()` hands pandas a plain Python list for each new column.
for scgen_column, source_column in (('batch', batch_key), ('cell_type', annotation)):
    adataM.obs[scgen_column] = adataM.obs[source_column].tolist()
scgen.SCGEN.setup_anndata(adataM, batch_key="batch", labels_key="cell_type")

#### scGen parameters

In [6]:
# Architecture settings unpacked into the scGen constructor.
scgen_model_params = dict(
    n_latent=n_latent,  # latent size comes from the Parameters cell
    n_hidden=800,
    n_layers=3,
    dropout_rate=0.2,
)

In [7]:
# Keyword arguments unpacked into `model_scGen.train(...)`.
# No `plan_kwargs` entry is supplied, so the learning rate is not overridden
# here (the disabled original entry was plan_kwargs=dict(lr=0.001), noted as
# the default learning rate).
train_params = dict(
    max_epochs=1000,
    batch_size=128,
    early_stopping=True,
    early_stopping_patience=5,
    log_every_n_steps=10000,
)

In [8]:
# Instantiate scGen on the AnnData registered via `setup_anndata`, using the
# architecture settings collected in `scgen_model_params`.
model_scGen = scgen.SCGEN(adataM, **scgen_model_params)
# Show the registered setup so the batch/label configuration can be checked before training.
model_scGen.view_anndata_setup()

### Training the model

In [9]:
# Fit the model with the settings collected in `train_params`.
# NOTE(review): the saved output below ends in a KeyboardInterrupt during epoch 1,
# and the later cells carry no execution count, so they appear not to have been
# run in this session; re-run training to completion before using them.
model_scGen.train(**train_params)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
  self.pid = os.fork()


Epoch 1/1000:   0%|                                                                                                                   | 0/1000 [00:00<?, ?it/s]

  self.pid = os.fork()
/scratch_isilon/groups/singlecell/shared/conda_env/scvi-v112/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


### Plotting loss functions

In [None]:
# Reconstruction loss on the training and validation sets over the course of training.
fig, ax = plt.subplots()
for metric in ('reconstruction_loss_train', 'reconstruction_loss_validation'):
    # [1:] skips the first recorded value, exactly as the original plot did
    # (presumably so an initial outlier does not dominate the y-axis; TODO confirm).
    ax.plot(model_scGen.history[metric][1:], label=metric)
# A figure must stand alone when skimmed: give it a real title and axis labels.
# NOTE(review): the x-axis is assumed to be the epoch index of the history frame; confirm.
ax.set(title='scGen reconstruction loss', xlabel='epoch', ylabel='reconstruction loss')
ax.legend()
plt.show()

In [None]:
# Local KL divergence term on the training and validation sets over the course of training.
fig, ax = plt.subplots()
for metric in ('kl_local_train', 'kl_local_validation'):
    # [1:] skips the first recorded value, exactly as the original plot did
    # (presumably so an initial outlier does not dominate the y-axis; TODO confirm).
    ax.plot(model_scGen.history[metric][1:], label=metric)
# A figure must stand alone when skimmed: give it a real title and axis labels.
# NOTE(review): the x-axis is assumed to be the epoch index of the history frame; confirm.
ax.set(title='scGen local KL divergence', xlabel='epoch', ylabel='KL (local)')
ax.legend()
plt.show()

In [None]:
# ELBO on the training and validation sets over the course of training.
fig, ax = plt.subplots()
for metric in ('elbo_train', 'elbo_validation'):
    # [1:] skips the first recorded value, exactly as the original plot did
    # (presumably so an initial outlier does not dominate the y-axis; TODO confirm).
    ax.plot(model_scGen.history[metric][1:], label=metric)
# A figure must stand alone when skimmed: give it a real title and axis labels.
# NOTE(review): the x-axis is assumed to be the epoch index of the history frame; confirm.
ax.set(title='scGen ELBO', xlabel='epoch', ylabel='ELBO')
ax.legend()
plt.show()

### Save the results

**scGen model**

In [None]:
# Write the model to the results folder. `save_anndata=False` keeps the
# AnnData out of the saved output, and `overwrite=True` is passed so an
# existing save at the same location does not block the write.
scgen_model_dir = here('03_downstream_analysis/04_integration_with_annotation/results/scGen_model_noRBCnPlat')
model_scGen.save(scgen_model_dir, overwrite=True, save_anndata=False)