In [None]:
# Run parameters (default values). They are validated and echoed in the next cell.
N_HIDDEN: int = 256          # forwarded to the model as n_hidden
N_LATENT: int = 30           # forwarded to the model as n_latent
RUN_TAG: str = "test"        # used in the run name and as a wandb tag
REF_NAME: str = ""           # empty string means: do not subset to a CV fold
LABELS_KEY: str = "Level2"   # forwarded to setup_anndata as labels_key
BUCKET_DIRPATH: str = ""     # prefix that here() joins in front of every path

In [None]:
# Fail fast when a run parameter is missing and echo every value so that it is
# recorded in the executed notebook's output.
# REF_NAME is part of the list because the fold-selection cell reads it.
for v in ['N_HIDDEN', 'N_LATENT', 'RUN_TAG', 'REF_NAME', 'LABELS_KEY', 'BUCKET_DIRPATH']:
    if v not in globals():
        raise ValueError(f"{v} not defined")
    # Plain namespace lookup instead of eval(): same value, no code evaluation.
    print(f"{v} = {globals()[v]}")

In [None]:
import os
import sys

import json

import pickle as pkl
import scvi
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report
from scarches.models.scpoli import scPoli

import torch

from dotenv import load_dotenv

from lightning.pytorch.loggers import WandbLogger
import wandb

import session_info
import warnings
from pyprojroot.here import here

torch.set_float32_matmul_precision('high')

#plt.style.use(['science','nature','no-latex'])
dpi_fig_save = 300
sc.set_figure_params(dpi=100, dpi_save=dpi_fig_save, vector_friendly=True)

# Setting some parameters
warnings.filterwarnings("ignore")

from sklearn.model_selection import StratifiedKFold

overwriteData = True
overwriteFigures = True

# Set random seed
random_seed = 5

import warnings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2

In [None]:
def here(fpath):
    """Return ``fpath`` joined onto the ``BUCKET_DIRPATH`` prefix.

    This definition replaces the ``here`` imported from ``pyprojroot.here``
    in the imports cell, so every later ``here(...)`` call resolves paths
    relative to ``BUCKET_DIRPATH`` instead.

    Parameters
    ----------
    fpath : str
        Path to append to ``BUCKET_DIRPATH``.

    Returns
    -------
    str
        Result of ``os.path.join(BUCKET_DIRPATH, fpath)``.
    """
    joined_path = os.path.join(BUCKET_DIRPATH, fpath)
    return joined_path

In [None]:
# Load environment variables via python-dotenv. The call is kept outside the
# assert so that it still runs when assertions are disabled (python -O).
dotenv_loaded = load_dotenv()
assert dotenv_loaded, "load_dotenv() did not report success; check that a .env file is available"

# Loading data


In [None]:
# Read the input AnnData object from its h5ad file.
# Note that this path is used as written; it is not passed through here().
adata_path = "04_MAIN_geneUniverse.h5ad"
adata = sc.read_h5ad(filename=adata_path)

In [None]:
# Optionally restrict the data to the training cells of one cross-validation fold.
# fold_idx is initialised up front so later cells can reference it even when
# REF_NAME is empty; before, it was only bound inside the `if` and the cell that
# displays it raised NameError with the default REF_NAME.
fold_idx = None
if REF_NAME != '':
    # Assumes REF_NAME looks like "<prefix>_<int>[...]" — TODO confirm naming scheme.
    fold_idx = int(REF_NAME.split('_')[1])
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted split files.
    with open(here('03_Downstream_Analysis/PatientClassifier/5foldCV/data/K_FOLD_cellID.pkl'), 'rb') as f:
        splits = pkl.load(f)
    # Each entry of splits is unpacked as a pair; only the first element is used here.
    train_idx, _ = splits[fold_idx]
    adata = adata[train_idx].copy()

In [None]:
# Display the AnnData object that will be used for training.
adata

In [None]:
# Echo the fold index derived from REF_NAME so it is visible in the executed notebook.
fold_idx

### Setting up parameters for scANVI training

In [None]:
# Arguments for scvi.model.SCVI.setup_anndata.
setup_kwargs = {
    'layer': None,
    # scArches supports only one batch variable; thus the two main sources of
    # batch effect are concatenated into one column (original author's note).
    'batch_key': 'chemistry',
    #'categorical_covariate_keys': ['sex','binned_age'],
    # needed for the following scANVI fine tuning
    'labels_key': LABELS_KEY,
}

# Arguments for the scvi.model.SCVI constructor.
scvi_kwargs = {
    'n_hidden': N_HIDDEN,
    'n_latent': N_LATENT,
    'n_layers': 3,
    'gene_likelihood': 'nb',
    'dispersion': 'gene-batch',
}

# Arguments unpacked into scvi_model.train().
trainer_kwargs = {
    'checkpointing_monitor': 'elbo_validation',
    'early_stopping_monitor': 'reconstruction_loss_validation',
    'early_stopping_patience': 10,
    'early_stopping_min_delta': 0.1,
    'early_stopping': True,
    'max_epochs': 1000,
}

# Training-plan options, passed to train() as plan_kwargs.
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.train.TrainingPlan.html#scvi.train.TrainingPlan
plan_kwargs = {
    'lr': 5e-4,
}

In [None]:
# Merge setup, model and trainer settings into one flat dict (later keys win on
# clashes). plan_kwargs is not part of this dict.
scvi_parameter_dict = {**setup_kwargs, **scvi_kwargs, **trainer_kwargs}
scvi_parameter_dict

#### Connect to wandb

In [None]:
# Weights & Biases logger for this run; the merged parameter dict is attached
# as the run config and RUN_TAG is the only tag.
logger = WandbLogger(
    name=f'scANVI_{N_HIDDEN}_{N_LATENT}_{LABELS_KEY}_{REF_NAME}_{RUN_TAG}',
    project='inflammation_atlas_PatientClassifier_scANVI',
    entity='inflammation',
    tags=[RUN_TAG],
    config=scvi_parameter_dict,
)

In [None]:
#RunParams = dict(random_seed=random_seed, wandb_run_id = logger.id)
#RunParams.update(scPoli_params)
#RunParams

# scVI integration

In [None]:
# Register adata with scVI using the layer / batch / label keys from setup_kwargs.
scvi.model.SCVI.setup_anndata(adata, **setup_kwargs)

In [None]:
# Build the scVI model on the registered AnnData with the configured architecture.
scvi_model = scvi.model.SCVI(adata, **scvi_kwargs)

In [None]:
# Train with the wandb logger, the training-plan options and the trainer settings.
scvi_model.train(logger=logger, plan_kwargs=plan_kwargs, **trainer_kwargs)

## Save the results

In [None]:
# Persist the trained model (without the AnnData) only when overwriting is enabled.
if overwriteData:
    scvi_model.save(
        dir_path=here(f"03_Downstream_Analysis/PatientClassifier/scANVI/results/01_reference/scANVI_{N_HIDDEN}_{N_LATENT}_{LABELS_KEY}_{REF_NAME}_{RUN_TAG}"),
        save_anndata=False,
        overwrite=True,
    )

In [None]:
# Latent representation of the training AnnData from the trained model.
latents = scvi_model.get_latent_representation(adata=adata)

In [None]:
# Wrap the latent matrix in a new AnnData that reuses the original obs table.
final_ad = sc.AnnData(latents, obs=adata.obs)

In [None]:
# Write the latent-space AnnData as a gzip-compressed h5ad file.
output_h5ad_path = here(f"03_Downstream_Analysis/PatientClassifier/scANVI/results/01_reference/output/scANVI_{N_HIDDEN}_{N_LATENT}_{LABELS_KEY}_{REF_NAME}_{RUN_TAG}.h5ad")
# Create the output folder first so the write does not fail on a missing
# directory once training has already finished.
os.makedirs(os.path.dirname(output_h5ad_path), exist_ok=True)
final_ad.write(
    output_h5ad_path,
    compression='gzip')