In [1]:
import scarches as sca
import pickle as pkl
import scvi
import scanpy as sc
import os
import numpy as np
from dotenv import load_dotenv
from lightning.pytorch.loggers import WandbLogger

from pyprojroot import here

# Switch guarding the model-saving cell further down: when True, the trained
# query model is written to disk with overwrite enabled.
overwriteData = True

  from .autonotebook import tqdm as notebook_tqdm
 captum (see https://github.com/pytorch/captum).


In [2]:
# Load variables from a .env file into the process environment.
# An explicit check replaces `assert load_dotenv()`: assert statements are
# stripped when Python runs with -O, which would silently skip the load itself.
if not load_dotenv():
    raise RuntimeError(
        "load_dotenv() returned False: no .env file was found or it defined no variables."
    )

### Loading data

In [3]:
# Reference AnnData stored next to the fine-tuned scANVI model.
reference_adata_path = here("09_patient_classifier/SCGT00_CentralizedDataset/results_batches/scANVI_model_fineTuned_lowLR_batches/adata.h5ad")
reference_adata = sc.read_h5ad(reference_adata_path)

# External (query) dataset; the file name indicates it was saved after QC.
external_adata_path = here("01_data_processing/SCGT00_CentralizedDataset/results/SCGT00_EXTERNAL_afterQC.h5ad")
external_adata = sc.read_h5ad(external_adata_path)

**Keep only selected genes**

In [4]:
# Restrict the external data to the reference's variables, in the reference's order;
# .copy() materialises the view into an independent AnnData.
reference_genes = reference_adata.var.index
external_adata = external_adata[:, reference_genes].copy()

In [5]:
# Display both AnnData summaries side by side (sanity check that n_vars matches
# between reference and external data after the gene subsetting).
reference_adata, external_adata

(AnnData object with n_obs × n_vars = 756120 × 8253
     obs: 'studyID', 'libraryID', 'sampleID', 'chemistry', 'disease', 'sex', 'binned_age', 'batches', 'Level1', '_scvi_batch', '_scvi_labels'
     var: 'hgnc_id', 'symbol', 'locus_group', 'HUGO_status'
     uns: '_scvi_manager_uuid', '_scvi_uuid',
 AnnData object with n_obs × n_vars = 379359 × 8253
     obs: 'studyID', 'libraryID', 'sampleID', 'chemistry', 'technology', 'patientID', 'disease', 'timepoint_replicate', 'treatmentStatus', 'therapyResponse', 'sex', 'age', 'BMI', 'binned_age', 'diseaseStatus', 'smokingStatus', 'ethnicity', 'institute', 'diseaseGroup', 'batches', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_20_genes', 'total_counts_mt', 'log1p_total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'log1p_total_counts_ribo', 'pct_counts_ribo', 'total_counts_hb', 'log1p_total_counts_hb', 'pct_counts_hb', 'total_counts_plt', 'log1p_total_counts_plt', 'pct_counts_plt'


In [6]:
# Restore the fine-tuned reference scANVI model and bind it to the reference AnnData.
scanvi_model_dir = here("09_patient_classifier/SCGT00_CentralizedDataset/results_batches/scANVI_model_fineTuned_lowLR_batches/")
scanvi_model = scvi.model.SCANVI.load(scanvi_model_dir, adata=reference_adata)

[34mINFO    [0m File                                                                                                      
         [35m/scratch_isilon/groups/singlecell/shared/projects/Inflammation-PBMCs-Atlas/04_visualizing_final_embedding_space/SCGT00_CentralizedDataset/02_scANVI_integration_wit[0m
         [35mh_annotation/results/scANVI_model_fineTuned_lowLR_batches/[0m[95mmodel.pt[0m already downloaded                     


Outdated cuSPARSE installation found.
Version JAX was built against: 12200
Minimum supported: 12100
Installed version: 12002
The local installation version must be no lower than 12100. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


### Parametrize model

In [8]:
# Derive a query model from the reference scANVI model for the external cells.
model = scvi.model.SCANVI.load_query_data(external_adata, scanvi_model, freeze_dropout=True)

# Override the model's private index bookkeeping so that every external cell is
# treated as unlabeled and none as labeled.
# NOTE(review): these are private scvi-tools attributes; confirm they are still
# honoured by the installed scvi-tools version.
n_query_cells = external_adata.n_obs
model._unlabeled_indices = np.arange(n_query_cells)
model._labeled_indices = []

In [9]:
# Keyword arguments unpacked into `model.train(...)`.
trainer_kwargs = {
    'checkpointing_monitor': 'elbo_validation',
    'early_stopping_monitor': 'reconstruction_loss_validation',
    'early_stopping_patience': 10,
    'early_stopping': True,
    'max_epochs': 100,
    'batch_size': 128,  # if QUERY_ADATA_NAME != 'EXTERNAL' else 127
}

# Training-plan settings, handed to `model.train` separately as `plan_kwargs`.
plan_kwargs = {'weight_decay': 0.0}

# Flat view of every hyperparameter (trainer first, then plan), used as the
# run configuration of the experiment logger.
parameter_dict = {**trainer_kwargs, **plan_kwargs}

In [10]:
# Weights & Biases logger for this run; the hyperparameters collected in
# `parameter_dict` are attached as the run config.
wandb_settings = dict(
    project='inflammation_atlas_PatientClassifier_scANVI',
    entity='inflammation',
    config=parameter_dict,
    name='scANVI_query_test_SCGT00_EXTERNAL',
)
logger = WandbLogger(**wandb_settings)

### Querying external

In [11]:
# Train the query model on the external data, logging through `logger`.
model.train(**trainer_kwargs, plan_kwargs=plan_kwargs, logger=logger)

[34mINFO    [0m Training for [1;36m100[0m epochs.                                                                                  


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA RTX A6000') that has Tensor Cores. To properly utilize th

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/dmaspero/miniconda3/envs/scarches/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=35` in the `DataLoader` to improve performance.
/home/dmaspero/miniconda3/envs/scarches/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=35` in the `DataLoader` to improve performance.


Epoch 2/100:   1%|▍                                             | 1/100 [01:19<2:11:55, 79.96s/it, v_num=dqe3, train_loss_step=1.91e+3, train_loss_epoch=1.88e+3]

/home/dmaspero/miniconda3/envs/scarches/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [12]:
# Persist the trained query model (weights only, no AnnData copy) when saving is enabled.
if overwriteData:
    queried_model_dir = here("09_patient_classifier/SCGT00_CentralizedDataset/results_batches/scANVI_SCGT00_EXTERNAL_batches_queried")
    model.save(queried_model_dir, overwrite=True, save_anndata=False)

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7b34c8528130>>
Traceback (most recent call last):
  File "/home/dmaspero/miniconda3/envs/scarches/lib/python3.9/site-packages/ipykernel/ipkernel.py", line 770, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 

KeyboardInterrupt



### Label transfer

In [None]:
# Label transfer: run the query model's predictor on the external cells.
# Presumably one predicted label per cell — it is later attached to
# external_adata.obs as the `labels` column.
query_labels = model.predict(external_adata)

In [None]:
# Embed both datasets with the same (query) model: external cells first,
# then the reference cells.
query_latents, reference_latents = [
    model.get_latent_representation(adata)
    for adata in (external_adata, reference_adata)
]

In [None]:
# Package the query embedding as an AnnData: X holds the latent coordinates,
# obs is the external metadata plus the predicted labels in a `labels` column.
query_obs = external_adata.obs.assign(labels=query_labels)
query_ad = sc.AnnData(X=query_latents, obs=query_obs)

query_latent_path = here("09_patient_classifier/SCGT00_CentralizedDataset/results_batches/scANVI_SCGT00_EXTERNAL_batches_latent.h5ad")
query_ad.write(query_latent_path, compression='gzip')

In [None]:
# Package the reference embedding as an AnnData: X holds the latent coordinates,
# obs is the reference metadata as-is.
reference_obs = reference_adata.obs
reference_ad = sc.AnnData(X=reference_latents, obs=reference_obs)

reference_latent_path = here("09_patient_classifier/SCGT00_CentralizedDataset/results_batches/scANVI_SCGT00_MAIN_batches_latent.h5ad")
reference_ad.write(reference_latent_path, compression='gzip')

<IPython.core.display.HTML object>
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>
