In [11]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

import scanpy as sc
import scvi
import torch

from pytorch_lightning.loggers import WandbLogger

import wandb

import session_info
import warnings
from pyprojroot.here import here

from dotenv import load_dotenv

# Put the project's `bin` directory on the import path.
sys.path.insert(1, str(here('bin')))

# Select PyTorch's 'high' float32 matmul precision mode.
torch.set_float32_matmul_precision('high')

#plt.style.use(['science','nature','no-latex'])
# Scanpy figure defaults: 100 dpi on screen, `dpi_fig_save` dpi when saving.
dpi_fig_save = 300
sc.set_figure_params(dpi=100, dpi_save=dpi_fig_save, vector_friendly=True)

# Setting some parameters
# NOTE(review): this suppresses every Python warning for the whole session.
warnings.filterwarnings("ignore")

# `overwriteData` guards the model save at the end of the notebook;
# `overwriteFigures` is not referenced in the visible cells.
overwriteData = True
overwriteFigures = True

# Load environment variables from the project-root .env file; the return
# value of this call is what the cell displays.
load_dotenv(here('.env'))

True

In [12]:
class CustomWandbLogger(WandbLogger):
    """WandbLogger variant whose ``save_dir`` is the W&B run directory.

    Only the ``save_dir`` property is overridden; everything else is
    inherited unchanged from ``WandbLogger``.
    """

    @property
    def save_dir(self):
        """Directory of the underlying W&B run.

        Returns:
            The ``dir`` attribute of ``self.experiment``.

        """
        # NOTE(review): presumably overridden so that consumers of
        # `logger.save_dir` write into the run folder - confirm intent.
        wandb_run = self.experiment
        return wandb_run.dir

In [3]:
# Fix the scvi-tools global seed and record the library version used for this run.
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)

Seed set to 0


Last run with scvi-tools version: 1.1.2


**Load data**

In [4]:
# Open the input h5ad in backed read-only mode (backed="r") rather than
# reading it eagerly, then display the object's summary.
adata_path = here("01_data_processing/SCGT00_CentralizedDataset/results/2_SCGT00_MAIN_normalized_HVGsubset.h5ad")
adata = sc.read_h5ad(adata_path, backed="r")
adata

AnnData object with n_obs × n_vars = 855417 × 3126 backed at '/scratch_isilon/groups/singlecell/shared/projects/Inflammation-PBMCs-Atlas/01_data_processing/SCGT00_CentralizedDataset/results/2_SCGT00_MAIN_normalized_HVGsubset.h5ad'
    obs: 'studyID', 'libraryID', 'sampleID', 'chemistry', 'technology', 'patientID', 'disease', 'timepoint_replicate', 'treatmentStatus', 'therapyResponse', 'sex', 'age', 'BMI', 'binned_age', 'diseaseStatus', 'smokingStatus', 'ethnicity', 'institute', 'diseaseGroup', 'batches', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'total_counts_plt', 'log1p_total_counts_plt', 'pct_counts_plt', 'S_score', 'G2M_score', 'phase'
    var: 'hgnc_id', 'symbol', 'locus_group', 'HUGO_status', 'mt', 'ribo', 'hb', 'plt', 'n

In [7]:
# Display the per-cell `studyID` annotation column.
adata.obs.studyID

cellID
SCGT00_L051_I57.3P_T0_AAACCCAAGGTGAGAA           SCGT00
SCGT00_L051_I56.3P_T0_AAACCCAAGTCCGTCG           SCGT00
SCGT00_L051_I53.3P_T0_AAACCCAAGTGCACTT           SCGT00
SCGT00_L051_I52.3P_T0_AAACCCACAACTGTGT           SCGT00
SCGT00_L051_I56.3P_T0_AAACCCACAAGAATGT           SCGT00
                                                ...    
SCGT00val_L003_I0362_T0_TTTGTTGTCCGTCAAA      SCGT00val
SCGT00val_L003_I036018_T0_TTTGTTGTCGGTTGTA    SCGT00val
SCGT00val_L003_I0361_T0_TTTGTTGTCTACGCAA      SCGT00val
SCGT00val_L003_I036018_T0_TTTGTTGTCTCTGACC    SCGT00val
SCGT00val_L003_I036031_T0_TTTGTTGTCTTTGGAG    SCGT00val
Name: studyID, Length: 855417, dtype: category
Categories (2, object): ['SCGT00', 'SCGT00val']

## scVI integration

#### Parameters

In [8]:
# Keyword arguments forwarded to scvi.model.SCVI.setup_anndata.
setup_kwargs = {
    "layer": "counts",
    "batch_key": "libraryID",
    "categorical_covariate_keys": ["disease", "sampleID"],
}

# Keyword arguments forwarded to the scvi.model.SCVI constructor.
scvi_kwargs = {
    "n_hidden": 512,
    "n_latent": 30,
    "n_layers": 2,
    "gene_likelihood": "nb",
}

# Keyword arguments forwarded to model.train (trainer options).
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.train.Trainer.html#scvi.train.Trainer
# NOTE(review): `checkpointing_monitor` is set but `enable_checkpointing`
# is not - confirm against the Trainer docs whether the monitor has any
# effect without it.
# The W&B logger is not part of this dict; it is passed separately to train().
trainer_kwargs = {
    "checkpointing_monitor": "elbo_validation",
    "early_stopping_monitor": "reconstruction_loss_validation",
    "early_stopping_patience": 10,
    "early_stopping_min_delta": 0.1,
    "early_stopping": True,
    "max_epochs": 1000,
}

# Keyword arguments for the training plan.
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.train.TrainingPlan.html#scvi.train.TrainingPlan
plan_kwargs = {
    "lr": 5e-4,
    #"reduce_lr_on_plateau": True,
}

# Currently unused VAE module options, kept for reference.
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.module.VAE.html#scvi.module.VAE
#vae = dict(
#    use_layer_norm='both',
#    use_batch_norm='none',
#    encode_covariates=True,
#    deeply_inject_covariates=False
#)

# Flat view of every setting above, in declaration order; later groups
# would win on a key clash (there is none here).
parameter_dict = {
    **setup_kwargs,
    **scvi_kwargs,
    **trainer_kwargs,
    **plan_kwargs,
}

In [9]:
# Display the merged settings (the same dict is passed to W&B as the run config).
parameter_dict

{'layer': 'counts',
 'batch_key': 'libraryID',
 'categorical_covariate_keys': ['disease', 'sampleID'],
 'n_hidden': 512,
 'n_latent': 30,
 'n_layers': 2,
 'gene_likelihood': 'nb',
 'checkpointing_monitor': 'elbo_validation',
 'early_stopping_monitor': 'reconstruction_loss_validation',
 'early_stopping_patience': 10,
 'early_stopping_min_delta': 0.1,
 'early_stopping': True,
 'max_epochs': 1000,
 'lr': 0.0005}

In [10]:
# Name the run after the categorical covariates used in the model setup.
covariate_tag = "_".join(parameter_dict["categorical_covariate_keys"])
run_name = f"Step00_COV{covariate_tag}"
run_name

'Step00_COVdisease_sampleID'

In [None]:
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI.train


In [14]:
# Register the AnnData with scVI using the layer / batch / covariate keys
# collected in `setup_kwargs`.
scvi.model.SCVI.setup_anndata(adata, **setup_kwargs)

In [15]:
# Create the W&B logger for this run; the full parameter set is stored as
# the run config.
logger = CustomWandbLogger(
    name=run_name,
    project="inflammation_atlas_R1",
    config=parameter_dict,
)

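We decided to use gene-label because... (TODO: complete this rationale. Note that `scvi_kwargs` above does not pass `dispersion`, so confirm whether gene-label was actually applied.)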
https://discourse.scverse.org/t/what-model-to-use-when-integrating-batches-of-scrna-seq-matrices-containing-150-000-t-and-innate-lymphoid-cell-ilc-sub-populations/454/7

In [16]:
# Build the SCVI model on `adata` with the architecture settings in `scvi_kwargs`.
model = scvi.model.SCVI(adata, **scvi_kwargs)

In [17]:
# Train with the W&B logger attached; training-plan and trainer options
# come from the parameter cell.
model.train(
    logger=logger,
    plan_kwargs=plan_kwargs,
    **trainer_kwargs,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Currently logged in as: [33mdav1989[0m ([33minflammation[0m). Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 38/1000:   4%|█▊                                              | 38/1000 [40:06<16:55:10, 63.32s/it, v_num=humh, train_loss_step=730, train_loss_epoch=716]
Monitored metric reconstruction_loss_validation did not improve in the last 10 records. Best score: 708.102. Signaling Trainer to stop.


In [18]:
# Close the active W&B run; the cell output shows its history and summary.
wandb.finish()

0,1
elbo_train,█▇▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
elbo_validation,█▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
kl_global_train,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
kl_global_validation,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
kl_local_train,██▅▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
kl_local_validation,█▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
kl_weight,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
reconstruction_loss_train,█▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
reconstruction_loss_validation,█▅▄▃▃▂▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
elbo_train,768.26147
elbo_validation,764.43066
epoch,37.0
kl_global_train,0.0
kl_global_validation,0.0
kl_local_train,57.95324
kl_local_validation,56.35727
kl_weight,0.0925
reconstruction_loss_train,710.30823
reconstruction_loss_validation,708.07336


## Save the results

In [20]:
# Persist the trained model only when the notebook-level `overwriteData`
# switch is on. The AnnData itself is not written alongside the model
# (save_anndata=False), and an existing output directory is replaced.
if overwriteData:
    # Plain string literal: the path has no placeholders, so no f-prefix.
    model_dir = here("01_data_processing/SCGT00_CentralizedDataset/results/3_SCGT00_MAIN_HVGsubset_scVI_step00/")
    model.save(model_dir, overwrite=True, save_anndata=False)

In [21]:
# Report the versions of the loaded packages, leaving out 'google3'.
session_info.show(excludes=["google3"])