In [1]:
import os
import sys

import json

import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt

from lightning.pytorch.callbacks import ModelCheckpoint

import torch

from dotenv import load_dotenv

from lightning.pytorch.loggers import WandbLogger
import wandb

import session_info
import warnings
from pyprojroot.here import here

#plt.style.use(['science','nature','no-latex'])
dpi_fig_save = 300
sc.set_figure_params(dpi=100, dpi_save=dpi_fig_save, vector_friendly=True)

# Setting some parameters
warnings.filterwarnings("ignore")

from sklearn.model_selection import StratifiedKFold

overwriteData = True
overwriteFigures = True

# Set random seed
random_seed = 42

import warnings
warnings.filterwarnings('ignore')

import scvi
scvi.settings.dl_num_workers = 0
scvi.settings.seed = random_seed

#torch.set_float32_matmul_precision('high')
#torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)
# torch.multiprocessing.set_sharing_strategy('file_system')

%load_ext autoreload
%autoreload 2



  from .autonotebook import tqdm as notebook_tqdm


 captum (see https://github.com/pytorch/captum).






INFO: Seed set to 42


INFO:lightning.fabric.utilities.seed:Seed set to 42


In [2]:
# Load environment variables from a local `.env` file (presumably the W&B credentials used by
# the logger below — confirm). `load_dotenv()` returns False when it sets no variables.
# An explicit check is used instead of `assert load_dotenv()`: asserts are stripped under
# `python -O`, which would silently skip loading the .env file altogether.
if not load_dotenv():
    raise RuntimeError(
        "load_dotenv() loaded no variables: check that a .env file exists and is non-empty."
    )

In [3]:
# Directory of this analysis step, anchored at the project root by `here`.
# Model, embedding and checkpoint paths later in the notebook are built from it.
workingDir = here("03_downstream_analysis") / "04_integration_with_annotation"
workingDir

PosixPath('/home/jupyter/Inflammation-PBMCs-Atlas/03_downstream_analysis/03_scANVI_integration_with_annotation')

In [4]:
class CustomWandbLogger(WandbLogger):
    """``WandbLogger`` whose ``save_dir`` is the local directory of the W&B run.

    Only ``save_dir`` is overridden; all other behaviour is inherited unchanged.
    """

    @property
    def save_dir(self):
        """Return ``self.experiment.dir``, the run's local directory.

        NOTE(review): reading ``self.experiment`` may start the W&B run if it
        has not been created yet (``WandbLogger`` creates it lazily) — confirm
        for the installed Lightning version.
        """
        active_run = self.experiment
        return active_run.dir

In [5]:
# Record the installed scvi-tools version alongside the results (provenance).
scvi_version = scvi.__version__
scvi_version

'1.1.2'

# Loading data


In [6]:
# Load the AnnData object for this step fully into memory. The file lives in `workingDir`,
# so the path is built from it instead of repeating the project-relative directory by hand.
# For lower memory use, `backed='r+'` could be passed to keep the matrix on disk instead.
adata = sc.read_h5ad(workingDir / "04_MAIN_geneUniverse_noRBCnPlatelets.h5ad")


In [7]:
# Cast `binned_age` to plain strings, whatever its stored dtype.
# Note: missing values become literal strings (e.g. 'nan') after this cast.
binned_age_str = adata.obs["binned_age"].astype(str)
adata.obs["binned_age"] = binned_age_str

#### Preparing scANVI training

In [8]:
# Reload the pre-trained scVI model and attach it to the freshly loaded `adata`.
# (`workingDir` is already absolute, so the path is joined directly.)
scvi_model_dir = workingDir / "results" / "scVI_model_pretreined_noRBCnPlat"
scvi_model = sca.models.SCVI.load(scvi_model_dir, adata=adata)

[34mINFO    [0m File                                                                                                      
         [35m/home/jupyter/Inflammation-PBMCs-Atlas/03_downstream_analysis/04_integration_with_annotation/resu[0m
         [35mlts/scVI_model_pretreined_noRBCnPlat/[0m[95mmodel.pt[0m already downloaded                                          




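### Fine-tuning with scANVI

**Parameters** — data-split, checkpointing and early-stopping settings (`scANVI_trainer_kwargs`), optimizer settings (`plan_kwargs`) and data-splitter settings (`datasplitter_kwargs`). Their merged union is logged to W&B as the run configuration.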

In [9]:
# Data-split, checkpointing and early-stopping settings forwarded to `SCANVI.train`.
scANVI_trainer_kwargs = {
    "n_samples_per_label": None,
    "check_val_every_n_epoch": None,
    "train_size": 0.8,
    "validation_size": 0.2,
    "shuffle_set_split": True,
    "checkpointing_monitor": "elbo_validation",
    "early_stopping_monitor": "reconstruction_loss_validation",
    "early_stopping_patience": 2,
    "early_stopping_min_delta": 0.1,
    "early_stopping": True,
    "max_epochs": 1000,
}

# Training-plan (optimizer) settings; see
# https://docs.scvi-tools.org/en/stable/api/reference/scvi.train.TrainingPlan.html#scvi.train.TrainingPlan
# A low learning rate is used because this run fine-tunes a pre-trained scVI model.
plan_kwargs = {
    "lr": 5e-5,
    # "reduce_lr_on_plateau": True,
}

# Settings passed to the data splitter.
datasplitter_kwargs = {"pin_memory": False}

# Flat union of all settings above (later dicts win on key clashes), used as the W&B run config.
scanvi_parameter_dict = {**scANVI_trainer_kwargs, **plan_kwargs, **datasplitter_kwargs}

In [10]:
# Name of the W&B run. It has no placeholders, so it is a plain string rather than an f-string.
run_name = "MAINobj_scANVI_fineTuning_lowLR_noRBCnPlat"
run_name

'MAINobj_scANVI_fineTuning_lowLR_noRBCnPlat'

In [11]:
# Initialise scANVI from the pre-trained scVI model; cells labelled "unknown"
# are treated as the unlabeled category.
SCANVI = sca.models.SCANVI
scanvi_model = SCANVI.from_scvi_model(scvi_model, unlabeled_category="unknown")
scanvi_model



In [12]:
# W&B logger for this run; the merged parameter dict is stored as the run config.
wandb_project = "inflammation_atlas_R1_scANVI"
logger = CustomWandbLogger(
    name=run_name,
    project=wandb_project,
    config=scanvi_parameter_dict,
)

In [13]:
# Keep a checkpoint after every epoch (save_top_k=-1), plus a "last" checkpoint.
checkpoint_dir = f"{workingDir}/results/scANVI_model_fineTuned_lowLR_noRBCnPlat_checkpoints/"
model_checkpoint = ModelCheckpoint(
    dirpath=checkpoint_dir,
    monitor="elbo_validation",
    every_n_epochs=1,
    save_top_k=-1,
    save_last=True,
)

In [14]:
# Fine-tune scANVI, then persist the model (without the AnnData) and its latent embedding.
# Training uses the same parameters as the scVI pre-training run (author's note).
# `workingDir` is already absolute, so output paths are joined to it directly.
scanvi_model_dir = workingDir / "results" / "scANVI_model_fineTuned_lowLR_noRBCnPlat"
try:
    scanvi_model.train(
        logger=logger,
        plan_kwargs=plan_kwargs,
        datasplitter_kwargs=datasplitter_kwargs,
        enable_checkpointing=True,
        callbacks=[model_checkpoint],
        **scANVI_trainer_kwargs,
    )
    scanvi_model.save(scanvi_model_dir, overwrite=True, save_anndata=False)
    scanvi_emb = scanvi_model.get_latent_representation(adata=adata)
    np.savez_compressed(file=str(scanvi_model_dir / "scANVI_embedding.npz"), arr=scanvi_emb)
except Exception as e:
    # Best effort: keep whatever the model has learned so a multi-hour run is not lost.
    # Then re-raise, so the full traceback is shown and "Run All" stops instead of carrying on
    # as if training had succeeded. The W&B run stays open; call `wandb.finish()` to close it.
    error_dir = workingDir / "results" / "scANVI_model_fineTuned_lowLR_noRBCnPlat_WITHERRORS"
    print(f"An error occurred: {e}")
    print(f"Saving the model state at failure time to {error_dir}")
    scanvi_model.save(error_dir, overwrite=True, save_anndata=False)
    raise

[34mINFO    [0m Training for [1;36m1000[0m epochs.                                                                                 


INFO: GPU available: True (cuda), used: True


INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True


INFO: TPU available: False, using: 0 TPU cores


INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores


INFO: IPU available: False, using: 0 IPUs


INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs


INFO: HPU available: False, using: 0 HPUs


INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


[34m[1mwandb[0m: Currently logged in as: [33mdav1989[0m ([33minflammation[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: wandb version 0.17.0 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


[34m[1mwandb[0m: Tracking run with wandb version 0.16.5


[34m[1mwandb[0m: Run data is saved locally in [35m[1m./wandb/run-20240528_215509-ye7pjj32[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.


[34m[1mwandb[0m: Syncing run [33mMAINobj_scANVI_fineTuning_lowLR_noRBCnPlat[0m


[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/inflammation/inflammation_atlas_R1_scANVI[0m


[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/inflammation/inflammation_atlas_R1_scANVI/runs/ye7pjj32/workspace[0m


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Training:   0%|                                                                                                                                  | 0/1000 [00:00<?, ?it/s]

Epoch 1/1000:   0%|                                                                                                                              | 0/1000 [00:00<?, ?it/s]

Epoch 1/1000:   0%|                                                                                                                | 1/1000 [26:33<442:14:11, 1593.65s/it]

Epoch 1/1000:   0%|                                                  | 1/1000 [26:33<442:14:11, 1593.65s/it, v_num=jj32, train_loss_step=1.79e+3, train_loss_epoch=1892.5]

Epoch 2/1000:   0%|                                                  | 1/1000 [26:34<442:14:11, 1593.65s/it, v_num=jj32, train_loss_step=1.79e+3, train_loss_epoch=1892.5]

Epoch 2/1000:   0%|                                                  | 2/1000 [50:21<414:48:25, 1496.30s/it, v_num=jj32, train_loss_step=1.79e+3, train_loss_epoch=1892.5]

Epoch 2/1000:   0%|                                                 | 2/1000 [50:21<414:48:25, 1496.30s/it, v_num=jj32, train_loss_step=1.95e+3, train_loss_epoch=1.88e+3]

Epoch 3/1000:   0%|                                                 | 2/1000 [50:22<414:48:25, 1496.30s/it, v_num=jj32, train_loss_step=1.95e+3, train_loss_epoch=1.88e+3]

Epoch 3/1000:   0%|▏                                              | 3/1000 [1:14:10<405:50:57, 1465.45s/it, v_num=jj32, train_loss_step=1.95e+3, train_loss_epoch=1.88e+3]

Epoch 3/1000:   0%|▏                                              | 3/1000 [1:14:10<405:50:57, 1465.45s/it, v_num=jj32, train_loss_step=1.86e+3, train_loss_epoch=1.88e+3]

Epoch 4/1000:   0%|▏                                              | 3/1000 [1:14:11<405:50:57, 1465.45s/it, v_num=jj32, train_loss_step=1.86e+3, train_loss_epoch=1.88e+3]

Epoch 4/1000:   0%|▏                                              | 4/1000 [1:37:56<401:05:24, 1449.72s/it, v_num=jj32, train_loss_step=1.86e+3, train_loss_epoch=1.88e+3]

Epoch 4/1000:   0%|▏                                               | 4/1000 [1:37:56<401:05:24, 1449.72s/it, v_num=jj32, train_loss_step=1.9e+3, train_loss_epoch=1.88e+3]

Epoch 5/1000:   0%|▏                                               | 4/1000 [1:37:57<401:05:24, 1449.72s/it, v_num=jj32, train_loss_step=1.9e+3, train_loss_epoch=1.88e+3]

Epoch 5/1000:   0%|▏                                               | 5/1000 [2:01:39<398:04:42, 1440.28s/it, v_num=jj32, train_loss_step=1.9e+3, train_loss_epoch=1.88e+3]

Epoch 5/1000:   0%|▏                                              | 5/1000 [2:01:39<398:04:42, 1440.28s/it, v_num=jj32, train_loss_step=1.87e+3, train_loss_epoch=1.88e+3]

Epoch 6/1000:   0%|▏                                              | 5/1000 [2:01:40<398:04:42, 1440.28s/it, v_num=jj32, train_loss_step=1.87e+3, train_loss_epoch=1.88e+3]

Epoch 6/1000:   1%|▎                                              | 6/1000 [2:25:37<397:25:51, 1439.39s/it, v_num=jj32, train_loss_step=1.87e+3, train_loss_epoch=1.88e+3]

Epoch 6/1000:   1%|▎                                              | 6/1000 [2:25:37<397:25:51, 1439.39s/it, v_num=jj32, train_loss_step=1.88e+3, train_loss_epoch=1.88e+3]

Epoch 7/1000:   1%|▎                                              | 6/1000 [2:25:38<397:25:51, 1439.39s/it, v_num=jj32, train_loss_step=1.88e+3, train_loss_epoch=1.88e+3]

Epoch 7/1000:   1%|▎                                              | 7/1000 [2:49:24<395:56:00, 1435.41s/it, v_num=jj32, train_loss_step=1.88e+3, train_loss_epoch=1.88e+3]

Epoch 7/1000:   1%|▎                                              | 7/1000 [2:49:24<395:56:00, 1435.41s/it, v_num=jj32, train_loss_step=1.84e+3, train_loss_epoch=1.88e+3]

Epoch 8/1000:   1%|▎                                              | 7/1000 [2:49:25<395:56:00, 1435.41s/it, v_num=jj32, train_loss_step=1.84e+3, train_loss_epoch=1.88e+3]

Epoch 8/1000:   1%|▍                                              | 8/1000 [3:13:20<395:36:01, 1435.65s/it, v_num=jj32, train_loss_step=1.84e+3, train_loss_epoch=1.88e+3]

Epoch 8/1000:   1%|▍                                              | 8/1000 [3:13:20<395:36:01, 1435.65s/it, v_num=jj32, train_loss_step=1.84e+3, train_loss_epoch=1.88e+3]

Epoch 9/1000:   1%|▍                                              | 8/1000 [3:13:21<395:36:01, 1435.65s/it, v_num=jj32, train_loss_step=1.84e+3, train_loss_epoch=1.88e+3]

Epoch 9/1000:   1%|▍                                              | 9/1000 [3:37:16<395:13:05, 1435.71s/it, v_num=jj32, train_loss_step=1.84e+3, train_loss_epoch=1.88e+3]

Epoch 9/1000:   1%|▍                                              | 9/1000 [3:37:16<395:13:05, 1435.71s/it, v_num=jj32, train_loss_step=1.85e+3, train_loss_epoch=1.88e+3]

Epoch 10/1000:   1%|▍                                             | 9/1000 [3:37:17<395:13:05, 1435.71s/it, v_num=jj32, train_loss_step=1.85e+3, train_loss_epoch=1.88e+3]

Epoch 10/1000:   1%|▍                                            | 10/1000 [4:01:19<395:24:53, 1437.87s/it, v_num=jj32, train_loss_step=1.85e+3, train_loss_epoch=1.88e+3]

Epoch 10/1000:   1%|▍                                            | 10/1000 [4:01:19<395:24:53, 1437.87s/it, v_num=jj32, train_loss_step=1.98e+3, train_loss_epoch=1.88e+3]

Epoch 11/1000:   1%|▍                                            | 10/1000 [4:01:20<395:24:53, 1437.87s/it, v_num=jj32, train_loss_step=1.98e+3, train_loss_epoch=1.88e+3]

Epoch 11/1000:   1%|▍                                            | 11/1000 [4:25:18<395:05:32, 1438.15s/it, v_num=jj32, train_loss_step=1.98e+3, train_loss_epoch=1.88e+3]

Epoch 11/1000:   1%|▍                                            | 11/1000 [4:25:18<395:05:32, 1438.15s/it, v_num=jj32, train_loss_step=1.75e+3, train_loss_epoch=1.88e+3]

Epoch 12/1000:   1%|▍                                            | 11/1000 [4:25:18<395:05:32, 1438.15s/it, v_num=jj32, train_loss_step=1.75e+3, train_loss_epoch=1.88e+3]

Epoch 12/1000:   1%|▌                                            | 12/1000 [4:49:17<394:48:13, 1438.56s/it, v_num=jj32, train_loss_step=1.75e+3, train_loss_epoch=1.88e+3]

Epoch 12/1000:   1%|▌                                            | 12/1000 [4:49:17<394:48:13, 1438.56s/it, v_num=jj32, train_loss_step=1.95e+3, train_loss_epoch=1.88e+3]

Epoch 13/1000:   1%|▌                                            | 12/1000 [4:49:18<394:48:13, 1438.56s/it, v_num=jj32, train_loss_step=1.95e+3, train_loss_epoch=1.88e+3]

Epoch 13/1000:   1%|▌                                            | 13/1000 [5:13:02<393:16:32, 1434.44s/it, v_num=jj32, train_loss_step=1.95e+3, train_loss_epoch=1.88e+3]

Epoch 13/1000:   1%|▌                                            | 13/1000 [5:13:02<393:16:32, 1434.44s/it, v_num=jj32, train_loss_step=1.85e+3, train_loss_epoch=1.88e+3]

Epoch 14/1000:   1%|▌                                            | 13/1000 [5:13:03<393:16:32, 1434.44s/it, v_num=jj32, train_loss_step=1.85e+3, train_loss_epoch=1.88e+3]

Epoch 14/1000:   1%|▋                                            | 14/1000 [5:37:02<393:18:40, 1436.02s/it, v_num=jj32, train_loss_step=1.85e+3, train_loss_epoch=1.88e+3]

Epoch 14/1000:   1%|▋                                            | 14/1000 [5:37:02<393:18:40, 1436.02s/it, v_num=jj32, train_loss_step=1.88e+3, train_loss_epoch=1.88e+3]

Epoch 15/1000:   1%|▋                                            | 14/1000 [5:37:02<393:18:40, 1436.02s/it, v_num=jj32, train_loss_step=1.88e+3, train_loss_epoch=1.88e+3]

Epoch 15/1000:   2%|▋                                            | 15/1000 [6:01:16<394:23:52, 1441.45s/it, v_num=jj32, train_loss_step=1.88e+3, train_loss_epoch=1.88e+3]

Epoch 15/1000:   2%|▋                                            | 15/1000 [6:01:16<394:23:52, 1441.45s/it, v_num=jj32, train_loss_step=1.78e+3, train_loss_epoch=1.88e+3]

Epoch 16/1000:   2%|▋                                            | 15/1000 [6:01:17<394:23:52, 1441.45s/it, v_num=jj32, train_loss_step=1.78e+3, train_loss_epoch=1.88e+3]

Epoch 16/1000:   2%|▋                                            | 16/1000 [6:25:29<394:56:18, 1444.90s/it, v_num=jj32, train_loss_step=1.78e+3, train_loss_epoch=1.88e+3]

Epoch 16/1000:   2%|▋                                            | 16/1000 [6:25:29<394:56:18, 1444.90s/it, v_num=jj32, train_loss_step=1.94e+3, train_loss_epoch=1.88e+3]

Epoch 16/1000:   2%|▋                                            | 16/1000 [6:25:29<395:08:10, 1445.62s/it, v_num=jj32, train_loss_step=1.94e+3, train_loss_epoch=1.88e+3]




Monitored metric reconstruction_loss_validation did not improve in the last 2 records. Best score: 1855.179. Signaling Trainer to stop.


In [15]:
# Finish the active W&B run (started by `logger` during training) so its history is flushed and synced.
wandb.finish()

[34m[1mwandb[0m: - 0.004 MB of 0.004 MB uploaded

[34m[1mwandb[0m: \ 0.004 MB of 0.004 MB uploaded

[34m[1mwandb[0m: | 0.012 MB of 0.012 MB uploaded

[34m[1mwandb[0m:                                                                                


[34m[1mwandb[0m: 
[34m[1mwandb[0m: Run history:
[34m[1mwandb[0m:                     elbo_train ██▇▆▅▄▄▃▃▂▂▂▁▁▁▁
[34m[1mwandb[0m:                elbo_validation █▇▆▅▄▄▃▃▂▂▂▂▁▁▁▁
[34m[1mwandb[0m:                          epoch ▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇█████
[34m[1mwandb[0m:                kl_global_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:           kl_global_validation ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
[34m[1mwandb[0m:                 kl_local_train ▇█▇▆▅▄▄▃▃▂▂▂▁▁▁▁
[34m[1mwandb[0m:            kl_local_validation █▇▆▅▄▄▃▃▃▂▂▂▁▁▁▁
[34m[1mwandb[0m:      reconstruction_loss_train █▄▃▃▂▂▂▂▂▁▁▁▁▁▁▁
[34m[1mwandb[0m: reconstruction_loss_validation █▆▅▄▄▃▂▂▂▂▂▁▁▁▁▁
[34m[1mwandb[0m:                 train_accuracy ▁▅▅▆▆▇▇▇▇▇▇█████
[34m[1mwandb[0m:        train_calibration_error █▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁
[34m[1mwandb[0m:      train_classification_loss █▄▃▃▃▂▂▂▂▂▂▁▁▁▁▁
[34m[1mwandb[0m:                 train_f1_score ▁▅▅▆▆▇▇▇▇▇▇█████
[34m[1mwandb[0m:               trai

[34m[1mwandb[0m: 🚀 View run [33mMAINobj_scANVI_fineTuning_lowLR_noRBCnPlat[0m at: [34m[4mhttps://wandb.ai/inflammation/inflammation_atlas_R1_scANVI/runs/ye7pjj32/workspace[0m
[34m[1mwandb[0m: Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)


[34m[1mwandb[0m: Find logs at: [35m[1m./wandb/run-20240528_215509-ye7pjj32/logs[0m
