In [None]:
# Sanity-check the notebook parameters that the rest of the notebook relies on.
# They are not assigned anywhere in this notebook, so they are presumably injected
# by a parameterised runner before this cell executes — TODO confirm.
#
# Values are read straight from the module namespace rather than through eval();
# at notebook top level locals() and globals() are the same mapping, so a single
# lookup is enough.
# Every defined parameter is echoed, and all missing ones are reported together
# in one ValueError so a misconfigured run can be fixed in a single pass.
_missing_params = []
for v in ['n_latent', 'batch', 'annotation']:
    if v in globals():
        print(f"{v} = {globals()[v]}")
    else:
        _missing_params.append(v)

if _missing_params:
    # With exactly one missing name the message is identical to the historical
    # "<name> is not defined".
    _verb = "is" if len(_missing_params) == 1 else "are"
    raise ValueError(f"{', '.join(_missing_params)} {_verb} not defined")

In [1]:
import os
import sys

import random 
import numpy as np
import matplotlib.pyplot as plt

import scanpy as sc
import scvi
import scgen

import torch
from lightning.pytorch.loggers import CSVLogger

import pandas as pd

import session_info
import warnings
from pyprojroot.here import here

# Global runtime configuration for the notebook.
# Select PyTorch's 'medium' float32 matmul precision setting
# (presumably chosen to speed up training at slightly reduced precision — TODO confirm).
torch.set_float32_matmul_precision('medium')

# Single seed for the whole run; handed to scvi-tools below.
random_seed = 42

# Disabled override for the number of data-loader workers; kept for reference.
#scvi.settings.dl_num_workers = 8
scvi.settings.seed = random_seed
# Record the library version next to the results for provenance.
print("scvi-tools version:", scvi.__version__)

Seed set to 42


scvi-tools version: 1.1.2


## Loading main adata

In [None]:
# Display the project root resolved by pyprojroot; the data and result paths
# used in this notebook are given relative to it.
here()

In [None]:
# Read the main AnnData object written by the gene-universe definition step.
# The file name suggests log1p-transformed expression values — TODO confirm.
adataM = sc.read_h5ad(here("03_downstream_analysis/02_gene_universe_definition/results/04_MAIN_geneUniverse.log1p.h5ad"))
# Bare expression so the notebook renders the object's summary.
adataM

### scGen preprocessing

In [None]:
# Mirror the user-selected obs columns under the fixed names 'batch' and
# 'cell_type', which are the keys registered with scGen below.
# The batch column is copied first, then the annotation column, so the result is
# the same even when the selected columns share a name with the targets.
for fixed_name, chosen_column in (("batch", batch), ("cell_type", annotation)):
    adataM.obs[fixed_name] = adataM.obs[chosen_column].tolist()

# Register the AnnData object with scGen using the two fixed column names.
scgen.SCGEN.setup_anndata(adataM, labels_key="cell_type", batch_key="batch")

#### scGen parameters

In [None]:
# Architecture settings forwarded as keyword arguments to the scGen model
# constructor. Only the latent size is a notebook parameter; the other values
# are fixed here.
scgen_model_params = {
    'n_latent': n_latent,   # supplied from outside the notebook
    'n_hidden': 800,
    'n_layers': 3,
    'dropout_rate': 0.2,
}

In [None]:
# Keyword arguments later unpacked into model_scGen.train(...).
train_params = {
    'max_epochs': 1000,
    'batch_size': 128,
    'early_stopping': True,
    'early_stopping_patience': 5,
    'log_every_n_steps': 10000,
    # 'plan_kwargs': dict({'lr': 0.001})  # default learning rate
}

In [None]:
# Instantiate the scGen model on the registered AnnData object, unpacking the
# architecture settings from scgen_model_params.
model_scGen = scgen.SCGEN(adataM, **scgen_model_params)
# Show which AnnData fields the model has registered, as a quick visual check.
model_scGen.view_anndata_setup()

### Training the model

In [None]:
# Fit the model; every training option comes from train_params.
model_scGen.train(**train_params)

### Plotting loss functions

In [None]:
# Training vs. validation reconstruction loss from the model's training history.
# The figure gets a title and axis labels so it can be read on its own.
fig, ax = plt.subplots()
for history_key in ('reconstruction_loss_train', 'reconstruction_loss_validation'):
    # [1:] drops the first recorded entry — presumably so the initial, much
    # larger value does not flatten the rest of the curve (TODO confirm intent).
    ax.plot(model_scGen.history[history_key][1:], label=history_key)
# x-axis label assumes the history entries are indexed by epoch — TODO confirm.
ax.set(title='scGen reconstruction loss', xlabel='epoch', ylabel='reconstruction loss')
ax.legend()
plt.show()

In [None]:
# Training vs. validation local KL term from the model's training history.
# The figure gets a title and axis labels so it can be read on its own.
fig, ax = plt.subplots()
for history_key in ('kl_local_train', 'kl_local_validation'):
    # [1:] drops the first recorded entry — presumably so the initial value
    # does not dominate the y-axis scale (TODO confirm intent).
    ax.plot(model_scGen.history[history_key][1:], label=history_key)
# x-axis label assumes the history entries are indexed by epoch — TODO confirm.
ax.set(title='scGen local KL divergence', xlabel='epoch', ylabel='kl_local')
ax.legend()
plt.show()

In [None]:
# Training vs. validation ELBO from the model's training history.
# The figure gets a title and axis labels so it can be read on its own.
fig, ax = plt.subplots()
for history_key in ('elbo_train', 'elbo_validation'):
    # [1:] drops the first recorded entry — presumably so the initial value
    # does not dominate the y-axis scale (TODO confirm intent).
    ax.plot(model_scGen.history[history_key][1:], label=history_key)
# x-axis label assumes the history entries are indexed by epoch — TODO confirm.
ax.set(title='scGen ELBO', xlabel='epoch', ylabel='ELBO')
ax.legend()
plt.show()

### Save the results

**scGen model**

In [None]:
# Persist the trained model under the project's results folder. The directory
# name encodes the latent size and the batch / annotation columns used for this run.
# overwrite=True replaces a previous save at the same path; save_anndata=False
# keeps the AnnData object out of the saved model directory.
model_scGen.save(here(f"03_downstream_analysis/08_PatientClassifier/scGen/results/01_scGen_train_nLat{n_latent}_{batch}_{annotation}"), 
                 overwrite = True, save_anndata = False)