In [1]:
import os
import sys

import json

import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt

from lightning.pytorch.callbacks import ModelCheckpoint

import torch

from dotenv import load_dotenv

from lightning.pytorch.loggers import WandbLogger
import wandb

import session_info
import warnings
from pyprojroot.here import here

#plt.style.use(['science','nature','no-latex'])
# DPI passed to scanpy as the resolution for saved figures.
dpi_fig_save = 300
# Global scanpy figure defaults: dpi=100, dpi_save taken from the value above.
sc.set_figure_params(dpi=100, dpi_save=dpi_fig_save, vector_friendly=True)

# Silence all Python warnings for the rest of the session.
# NOTE(review): this also hides deprecation/numerical warnings from the training libraries.
warnings.filterwarnings("ignore")

from sklearn.model_selection import StratifiedKFold

# When True, the fine-tuned scANVI model is saved to disk further down (with overwrite=True).
overwriteData = True
# NOTE(review): not referenced in the visible part of this notebook — confirm it is still needed.
overwriteFigures = True

# Random seed; assigned to scvi.settings.seed in this same cell.
random_seed = 42

import warnings
warnings.filterwarnings('ignore')

import scvi
# Number of data-loader workers used by scvi-tools
# (presumably 0 means loading happens in the main process — TODO confirm).
scvi.settings.dl_num_workers = 0
# Seed scvi-tools with the notebook-wide seed (the cell output reports "Seed set to 42").
scvi.settings.seed = random_seed

#torch.set_float32_matmul_precision('high')
#torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)
# torch.multiprocessing.set_sharing_strategy('file_system')

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm
 captum (see https://github.com/pytorch/captum).
INFO: Seed set to 42
INFO:lightning.fabric.utilities.seed:Seed set to 42


In [2]:
# Load environment variables from the project's .env file and fail loudly if
# nothing was loaded. An explicit check is used instead of `assert` because
# assert statements are removed under `python -O`, which would silently skip
# the load_dotenv() call itself.
if not load_dotenv():
    raise RuntimeError(
        "load_dotenv() did not load any variables: no .env file was found or it is empty."
    )

In [3]:
# Directory of this analysis step, resolved relative to the project root by pyprojroot's here().
# Every input/output path below is built from it.
# NOTE(review): the saved output of the model-loading cell shows a path containing
# '02_scANVI_integration_with_annotation', which this value does not produce —
# confirm the notebook outputs are not stale.
workingDir = here('04_visualizing_final_embedding_space/SCGT00_CentralizedDataset/')
workingDir

PosixPath('/scratch_isilon/groups/singlecell/shared/projects/Inflammation-PBMCs-Atlas/04_visualizing_final_embedding_space/SCGT00_CentralizedDataset')

In [4]:
class CustomWandbLogger(WandbLogger):
    """WandbLogger variant whose ``save_dir`` is the directory of its experiment."""

    @property
    def save_dir(self):
        """Directory in which this logger saves its files.

        Returns:
            The ``dir`` attribute of ``self.experiment``.

        """
        current_run = self.experiment
        return current_run.dir

In [5]:
# Display the installed scvi-tools version so it is recorded in the notebook output.
scvi.__version__

'1.1.2'

# Loading data


In [6]:
# Read the AnnData object fully into memory from the h5ad file.
# Backed loading (backed='r+', chunk_size=50000) is left disabled.
# NOTE(review): this file lives under 'scVI_model_pretreined/', while the scVI model
# is loaded from 'scVI_model_pretreined_batches/' — confirm the two belong together.
adata_path = here(f"{workingDir}/results/scVI_model_pretreined/adata.h5ad")
adata = sc.read_h5ad(adata_path)


#### Preparing scANVI training

In [7]:
# Restore the pretrained scVI model from disk and attach the AnnData loaded above.
scvi_model_dir = here(f"{workingDir}/results/scVI_model_pretreined_batches/")
scvi_model = sca.models.SCVI.load(scvi_model_dir, adata=adata)

[34mINFO    [0m File                                                                                                      
         [35m/scratch_isilon/groups/singlecell/shared/projects/Inflammation-PBMCs-Atlas/04_visualizing_final_embedding_space/SCGT00_CentralizedDataset/02_scANVI_integration_wit[0m
         [35mh_annotation/results/scVI_model_pretreined_batches/[0m[95mmodel.pt[0m already downloaded                            


Outdated cuSPARSE installation found.
Version JAX was built against: 12200
Minimum supported: 12100
Installed version: 12002
The local installation version must be no lower than 12100. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


### Fine-tuning with scANVI

**Parameters**

In [8]:
# Keyword arguments unpacked into scanvi_model.train(): data split, early stopping,
# checkpoint monitor and the epoch budget.
scANVI_trainer_kwargs = {
    "n_samples_per_label": None,
    "check_val_every_n_epoch": None,
    "train_size": 0.8,
    "validation_size": 0.2,
    "shuffle_set_split": True,
    "checkpointing_monitor": 'elbo_validation',
    "early_stopping_monitor": 'reconstruction_loss_validation',
    "early_stopping_patience": 2,
    "early_stopping_min_delta": 0.1,
    "early_stopping": True,
    "max_epochs": 1000,
}

# Training-plan options; see
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.train.TrainingPlan.html#scvi.train.TrainingPlan
plan_kwargs = {
    "lr": 5e-5,
    # "reduce_lr_on_plateau": True,  (disabled)
}

# Options passed to the data splitter.
datasplitter_kwargs = {"pin_memory": False}

# Flat view of every setting above; used as the W&B run config.
scanvi_parameter_dict = {
    **scANVI_trainer_kwargs,
    **plan_kwargs,
    **datasplitter_kwargs,
}

In [9]:
# Name under which this training run is logged.
run_name = "MAINobj_scANVI_fineTuning_lowLR_batches"
run_name

'MAINobj_scANVI_fineTuning_lowLR_batches'

In [10]:
# Build the scANVI model from the pretrained scVI model; "unknown" is passed
# as the unlabeled category.
scanvi_model = sca.models.SCANVI.from_scvi_model(
    scvi_model,
    unlabeled_category="unknown",
)
scanvi_model



In [11]:
# W&B logger for this run; the flattened parameter dictionary is stored as the run config.
logger = CustomWandbLogger(
    name=run_name,
    project='inflammation_atlas_R1_scANVI',
    config=scanvi_parameter_dict,
)

In [None]:
# Fine-tune scANVI with the trainer, training-plan and data-splitter settings
# defined in the parameters cell, logging to W&B.
scanvi_model.train(
    logger=logger,
    plan_kwargs=plan_kwargs,
    datasplitter_kwargs=datasplitter_kwargs,
    **scANVI_trainer_kwargs,
)

[34mINFO    [0m Training for [1;36m1000[0m epochs.                                                                                 


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize th

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 5/1000:   0%|▏                                         | 4/1000 [13:21<56:51:38, 205.52s/it, v_num=a86i, train_loss_step=1.89e+3, train_loss_epoch=1.99e+3]

In [None]:
# Finish the Weights & Biases run once training is done.
wandb.finish()

In [None]:
# Persist the fine-tuned model together with its AnnData, replacing any existing
# copy; skipped entirely when overwriteData is False.
if overwriteData:
    scanvi_model_dir = here(f"{workingDir}/results/scANVI_model_fineTuned_lowLR_batches/")
    scanvi_model.save(scanvi_model_dir, overwrite=True, save_anndata=True)

In [None]:
# Compute the latent representation of `adata` with the fine-tuned scANVI model.
scanvi_emb = scanvi_model.get_latent_representation(adata=adata)

In [None]:
# Store the scANVI embedding as a compressed .npz archive (array key: "arr").
embedding_path = here(f"{workingDir}/results/scANVI_model_fineTuned_lowLR_batches/scANVI_embedding.npz")
# The target directory is otherwise only created by the model-saving cell, which
# is skipped when overwriteData is False; create it here so this cell does not
# fail with FileNotFoundError in that case.
embedding_path.parent.mkdir(parents=True, exist_ok=True)
np.savez_compressed(file=str(embedding_path), arr=scanvi_emb)