In [1]:
import os
import sys

import json

import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt

import torch

from dotenv import load_dotenv

from lightning.pytorch.loggers import WandbLogger
import wandb

import session_info
import warnings
from pyprojroot.here import here

# --- Plotting configuration -------------------------------------------------
# Optional publication style, left disabled:
#plt.style.use(['science','nature','no-latex'])
dpi_fig_save = 300  # dpi passed to scanpy as `dpi_save` below
sc.set_figure_params(dpi=100, dpi_save=dpi_fig_save, vector_friendly=True)

# Setting some parameters
# Silence every Python warning from this point on (note: this also hides
# deprecation and convergence warnings emitted by later cells).
warnings.filterwarnings("ignore")

from sklearn.model_selection import StratifiedKFold

# Flags gating writes to disk. `overwriteData` guards the model save further
# down; `overwriteFigures` is presumably used by plotting cells — TODO confirm.
overwriteData = True
overwriteFigures = True

# Set random seed
random_seed = 42

# scvi-tools global settings: data-loader worker count and the global seed
# (the recorded cell output shows "Seed set to 42" after this assignment).
import scvi
scvi.settings.dl_num_workers = 8
scvi.settings.seed = random_seed

# NOTE(review): duplicate of the `import warnings` / filter above; it is
# redundant unless something imported in between reset the warning filters.
import warnings
warnings.filterwarnings('ignore')

# PyTorch runtime settings: float32 matmul precision, tensor print format, and
# the multiprocessing sharing strategy.
# NOTE(review): 'file_system' sharing is presumably chosen to avoid running out
# of file descriptors with several data-loader workers — confirm.
torch.set_float32_matmul_precision('high')
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)
torch.multiprocessing.set_sharing_strategy('file_system')

# IPython autoreload in mode 2: imported modules are reloaded before each cell
# is executed, so edits to local .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm
 captum (see https://github.com/pytorch/captum).
INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


In [2]:
# Load environment variables from a .env file and fail loudly if none were set.
# The call is made outside of `assert` on purpose: under `python -O` an assert
# statement is stripped together with the call inside it, which would silently
# skip loading the environment.
if not load_dotenv():
    raise RuntimeError(
        "load_dotenv() did not load any variable: no .env file was found or it is empty."
    )

In [3]:
# Folder of this analysis step, resolved by `here` (pyprojroot) relative to the
# project root; used below to build the output paths.
workingDir = here('03_downstream_analysis/04_integration_with_annotation/')

In [4]:
class CustomWandbLogger(WandbLogger):
    """``WandbLogger`` whose ``save_dir`` points at the experiment's own directory.

    Only the ``save_dir`` property is overridden; everything else is inherited
    unchanged from ``WandbLogger``.
    """

    @property
    def save_dir(self):
        """Directory reported to callers as this logger's save location.

        Returns:
            The ``dir`` attribute of ``self.experiment`` (presumably the local
            folder of the active W&B run — confirm against the wandb docs).

        """
        active_run = self.experiment
        return active_run.dir

In [5]:
# Display the installed scvi-tools version so the run records which release was used.
scvi.__version__

'1.1.2'

# Loading data


In [6]:
# Load the h5ad file into `adata` (regular, non-backed read).
# Disabled alternative kept for reference: also pass backed='r+', chunk_size=50000
# to sc.read_h5ad.
adata = sc.read_h5ad(here("03_downstream_analysis/04_integration_with_annotation/04_MAIN_geneUniverse_noRBCnPlatelets.h5ad"))


In [7]:
# Cast the age bins to plain strings; `binned_age` is passed as one of the
# categorical covariate keys when the AnnData is set up for the model.
adata.obs['binned_age'] = adata.obs['binned_age'].astype(str)

In [8]:
# Shorthand for the per-gene annotation table; the cell below reads its
# "symbol" column.
genes = adata.var

In [9]:
# Build one mask per gene family from the "symbol" column of `genes` and print
# how many genes of each family are present. The `*_gene_idx` names are kept
# unchanged so later cells can keep using them.
gene_symbols = genes["symbol"]

# Retrieve MT and RB genes present in the dataset
MT_gene_idx = gene_symbols.str.startswith("MT-")
print(f"{np.sum(MT_gene_idx)} mitochondrial genes")

RB_gene_idx = gene_symbols.str.startswith(("RPS", "RPL"))
print(f"{np.sum(RB_gene_idx)} ribosomal genes")

# Retrieve TCR and BCR present in the dataset.
# Non-capturing groups (?:...) match exactly the same symbols as (...) but do
# not trigger the pandas "pattern has match groups" warning in str.contains.
TCR_gene_idx = gene_symbols.str.contains("^TRA(?:J|V)|^TRB(?:J|V|D)")
print(f"{np.sum(TCR_gene_idx)} TCR genes")

BCR_gene_idx = gene_symbols.str.contains("^IGH(?:J|V)")
# Fixed label: this count was previously printed as "TCR genes".
print(f"{np.sum(BCR_gene_idx)} BCR genes")

# Symbols starting with "HB" whose third character is not "P". The previous
# class [^(P)] also listed "(" and ")" as literal characters, which has no
# effect on ordinary gene symbols and only obscured the intent.
HB_gene_idx = gene_symbols.str.contains("^HB[^P]")
print(f"{np.sum(HB_gene_idx)} HB genes")

# Some of those genes will be included anyway because are part of curated gene sets.
MHC_gene_idx = gene_symbols.str.contains("^HLA-")
print(f"{np.sum(MHC_gene_idx)} MHC genes")

# Platelet-related genes are matched by exact symbol rather than by pattern.
PLT_gene_idx = gene_symbols.isin(["PPBP", "PDGF", "ANG1", "LAPTM4B", "WASF3", "TPM3", "PF4", "TAC1"])
print(f"{np.sum(PLT_gene_idx)} PLT genes")

0 mitochondrial genes
0 ribosomal genes
0 TCR genes
0 TCR genes
1 HB genes
20 MHC genes
3 PLT genes


**Parameters**

In [None]:
# Arguments for registering the AnnData with the model (setup_anndata).
setup_kwargs = {
    "layer": None,
    "batch_key": "chemistry",
    "categorical_covariate_keys": ["libraryID", "studyID", "sex", "binned_age"],
    "labels_key": "Level1",  # needed for the following scANVI fine tuning
}

# Model architecture / likelihood settings.
scvi_kwargs = {
    "n_hidden": 256,
    "n_latent": 30,
    "n_layers": 4,
    "gene_likelihood": "nb",
    "dispersion": "gene-batch",
}

# Trainer settings: monitors, early stopping and the epoch cap.
trainer_kwargs = {
    "checkpointing_monitor": "elbo_validation",
    "early_stopping_monitor": "reconstruction_loss_validation",
    "early_stopping_patience": 2,
    "early_stopping_min_delta": 0.1,
    "early_stopping": True,
    "max_epochs": 1000,
    # "logger" is intentionally not set here (wandb).
}

# Training-plan settings, see
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.train.TrainingPlan.html#scvi.train.TrainingPlan
plan_kwargs = {
    "lr": 5e-4,
    # "reduce_lr_on_plateau": True  (left disabled)
}

# Module-level options that are currently not used, see
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.module.VAE.html#scvi.module.VAE
#vae = dict(
#    use_layer_norm='both',
#    use_batch_norm='none',
#    encode_covariates=True,
#    deeply_inject_covariates=False
#)
datasplitter_kwargs = {"pin_memory": True}

# Flat view of every setting above, merged in the same left-to-right order
# (later dictionaries would win on a duplicated key; there are none here).
scvi_parameter_dict = {
    **setup_kwargs,
    **scvi_kwargs,
    **trainer_kwargs,
    **plan_kwargs,
    **datasplitter_kwargs,
}

In [None]:
# Name given to the logger for this pre-training run. Plain string literal:
# the previous f-string prefix had no placeholders to format.
run_name = "MAINobj_scVI_pretraining_noRBCnPlat"
run_name

In [None]:
# Register `adata` for the scArches SCVI model using the layer, batch key,
# categorical covariates and labels key collected in `setup_kwargs`.
sca.models.SCVI.setup_anndata(adata, 
                              **setup_kwargs)

In [None]:
# W&B logger for this run: named after `run_name`, filed under the
# 'inflammation_atlas_R1_scANVI' project, with the merged parameter dictionary
# attached as the run config.
logger = CustomWandbLogger(name = run_name, project='inflammation_atlas_R1_scANVI', config = scvi_parameter_dict)

In [None]:
# Instantiate the SCVI model on the registered `adata` with the architecture
# settings from `scvi_kwargs`.
scvi_model = sca.models.SCVI(adata, **scvi_kwargs)

In [None]:
# Train the model, logging through `logger`; data-splitter, training-plan and
# trainer options come from the dictionaries defined in the parameters cell.
# NOTE(review): `trainer_kwargs` sets `checkpointing_monitor` but never
# `enable_checkpointing` — confirm whether checkpoints are actually written.
scvi_model.train(logger=logger, datasplitter_kwargs=datasplitter_kwargs, plan_kwargs = plan_kwargs, **trainer_kwargs)

In [None]:
# Close the W&B run once training has returned.
wandb.finish()

In [None]:
# Save the trained model under <workingDir>/results, replacing any previous
# copy and without bundling the AnnData. Skipped when `overwriteData` is False.
# NOTE(review): `workingDir` already comes from here(...), so wrapping the
# f-string in here(...) again looks redundant — confirm before simplifying.
if overwriteData:
    scvi_model.save(here(f"{workingDir}/results/scVI_model_pretreined_noRBCnPlat/"), 
               overwrite = True, 
               save_anndata = False)

In [None]:
# Latent representation of the cells in `adata` from the trained model
# (presumably one row per cell and `n_latent` columns — TODO confirm).
scvi_emb = scvi_model.get_latent_representation(adata=adata)

In [None]:
# Store the latent embedding as a compressed .npz (array key: "arr") inside the
# model's results folder.
embedding_path = here(f"{workingDir}/results/scVI_model_pretreined_noRBCnPlat/scVI_embedding.npz")
# In this notebook the folder is otherwise only produced by the model save,
# which is skipped when `overwriteData` is False; np.savez_compressed does not
# create missing directories, so make sure the folder exists first.
embedding_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(file = str(embedding_path), arr=scvi_emb)