In [None]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

import scanpy as sc
import scvi
import torch

from pytorch_lightning.loggers import WandbLogger

import wandb

import session_info
import warnings
from pyprojroot.here import here

from dotenv import load_dotenv

# Make the project's `bin` directory importable (inserted at index 1 of sys.path).
sys.path.insert(1, str(here('bin')))

# Set PyTorch's float32 matmul precision to 'high'.
torch.set_float32_matmul_precision('high')

#plt.style.use(['science','nature','no-latex'])
# Scanpy figure defaults: 100 dpi on screen, `dpi_fig_save` dpi for saved figures.
dpi_fig_save = 300
sc.set_figure_params(dpi=100, dpi_save=dpi_fig_save, vector_friendly=True)

# Setting some parameters
# NOTE: this silences every warning emitted from this point on.
warnings.filterwarnings("ignore")

# `overwriteData` gates saving the trained model at the end of the notebook;
# `overwriteFigures` is not used in the visible cells.
overwriteData = True
overwriteFigures = True

# Load environment variables from the project-level `.env` file
# (presumably the W&B credentials — TODO confirm).
load_dotenv(here('.env'))

In [None]:
class CustomWandbLogger(WandbLogger):
    """WandbLogger whose ``save_dir`` is taken from the underlying experiment."""

    @property
    def save_dir(self):
        """Return the save directory of this logger.

        Returns:
            The ``dir`` attribute of ``self.experiment``.

        """
        experiment = self.experiment
        return experiment.dir

In [None]:
# Fix the scvi-tools global seed and record the library version used for this run.
scvi.settings.seed = 0
print("Last run with scvi-tools version:", scvi.__version__)

**Load data**

In [None]:
# Load the h5ad file in backed, read-only mode (backed='r'); the path is resolved
# relative to the project root via `here`.
adata = sc.read_h5ad(here("01_data_processing/results/03_INFLAMMATION_main_normalized_HVGsubset.h5ad"), backed='r')
# Display the AnnData summary as the cell output.
adata

## scVI integration

#### Parameters

In [None]:
# Arguments for scvi.model.SCVI.setup_anndata (AnnData registration).
setup_kwargs = {
    "layer": "counts",
    "batch_key": "libraryID",
    "categorical_covariate_keys": ["chemistry", "studyID", "disease", "sampleID"],
    "labels_key": "chemistry",
}

# Arguments for the scvi.model.SCVI constructor (model architecture).
scvi_kwargs = {
    "n_hidden": 512,
    "n_latent": 30,
    "n_layers": 2,
    "gene_likelihood": "nb",
    "dispersion": "gene-label",
}

# Arguments unpacked into model.train().
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.train.Trainer.html#scvi.train.Trainer
trainer_kwargs = {
    "checkpointing_monitor": "elbo_validation",
    "early_stopping_monitor": "reconstruction_loss_validation",
    "early_stopping_patience": 10,
    "early_stopping_min_delta": 0.1,
    "early_stopping": True,
    "max_epochs": 1000,
    #logger = # wandb
}

# Arguments passed to model.train() as `plan_kwargs`.
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.train.TrainingPlan.html#scvi.train.TrainingPlan
plan_kwargs = {
    "lr": 5e-4,
    #reduce_lr_on_plateau = True
}

# https://docs.scvi-tools.org/en/stable/api/reference/scvi.module.VAE.html#scvi.module.VAE
#vae = dict(
#    use_layer_norm='both',
#    use_batch_norm='none',
#    encode_covariates=True,
#    deeply_inject_covariates=False
#)

# Flat merge of all option groups, in the order setup -> model -> trainer -> plan.
parameter_dict = {
    **setup_kwargs,
    **scvi_kwargs,
    **trainer_kwargs,
    **plan_kwargs,
}

In [None]:
# Display the merged configuration that is handed to the W&B logger as `config`.
parameter_dict

In [None]:
# Name the run after the categorical covariates registered for this step.
run_name = "Step00_COV" + "_".join(parameter_dict["categorical_covariate_keys"])
run_name

In [None]:
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.model.SCVI.html#scvi.model.SCVI.train


In [None]:
# Register the AnnData fields (layer, batch, covariates, labels) with scvi-tools.
scvi.model.SCVI.setup_anndata(adata, **setup_kwargs)

In [None]:
# Create the W&B logger for this run; the merged parameters are passed as its config.
logger = CustomWandbLogger(
    name=run_name,
    project='inflammation_atlas_R1',
    config=parameter_dict,
)

We decided to use `gene-label` dispersion; the rationale is discussed in the thread linked below:
https://discourse.scverse.org/t/what-model-to-use-when-integrating-batches-of-scrna-seq-matrices-containing-150-000-t-and-innate-lymphoid-cell-ilc-sub-populations/454/7

In [None]:
# Build the SCVI model on the registered AnnData using the architecture options in `scvi_kwargs`.
model = scvi.model.SCVI(adata, **scvi_kwargs)

In [None]:
# Train with the W&B logger: `plan_kwargs` is passed through as-is, while
# `trainer_kwargs` is unpacked into keyword arguments of train().
model.train(logger=logger, plan_kwargs = plan_kwargs, **trainer_kwargs)

In [None]:
# Close the active W&B run now that training is done.
wandb.finish()

## Save the results

In [None]:
# Persist the trained model (without the AnnData) when saving is enabled.
if overwriteData:
    model.save(
        here("01_data_processing/results/04_INFLAMMATION_main_HVGsubset_scVI_step00/"),
        overwrite=True,
        save_anndata=False,
    )

In [None]:
# Report the package versions of this session, excluding 'google3' from the listing.
session_info.show(excludes=['google3'],)