In [None]:
# Parameters cell. NOTE(review): presumably overridden by an external runner such as papermill — confirm.
# Identifier of the trained reference scANVI model; used to build model, W&B run and output paths.
SCANVI_ID: str = '256_30_Level1_run1'
# Query dataset name. 'EXTERNAL' loads a separate query file; a name containing 'SPLIT'
# has its second '_'-separated token parsed as a fold index into the reference CV splits.
QUERY_ADATA_NAME: str = 'EXTERNAL'
# Root directory that ``here`` resolves project paths against. Declared without a value,
# so it must be injected before the notebook runs, otherwise ``here`` raises NameError.
BUCKET_DIRPATH: str

In [10]:
import scarches as sca
import pickle as pkl
import scvi
import scanpy as sc
import os
import numpy as np
from dotenv import load_dotenv
from lightning.pytorch.loggers import WandbLogger

 captum (see https://github.com/pytorch/captum).


In [None]:
# Load environment variables from a ``.env`` file.
# The call is deliberately kept outside ``assert``: assert statements are stripped under
# ``python -O``, which would silently skip loading the environment altogether.
dotenv_loaded = load_dotenv()
if not dotenv_loaded:
    raise RuntimeError("load_dotenv() did not load any variables; check that a non-empty .env file can be found")

In [None]:
def here(fpath):
    """Resolve ``fpath`` against the project root ``BUCKET_DIRPATH``.

    Parameters
    ----------
    fpath : str
        Path relative to the bucket root. An absolute ``fpath`` makes
        ``os.path.join`` discard the root and return ``fpath`` unchanged.

    Returns
    -------
    str
        The joined path.
    """
    root_dir = BUCKET_DIRPATH
    return os.path.join(root_dir, fpath)

In [None]:
# Controls whether the query-fitted model is written to disk (saved with ``overwrite=True``).
# NOTE: when False the model is not saved at all, not merely "not overwritten".
overwriteData = True

In [None]:
# Reference dataset (gene-universe subset) used to build the scANVI reference.
reference_adata = sc.read_h5ad("04_MAIN_geneUniverse.h5ad")

if 'SPLIT' not in QUERY_ADATA_NAME:
    # Stand-alone query dataset stored in its own file.
    target_adata = sc.read_h5ad(f"05_{QUERY_ADATA_NAME}_geneUniverse.h5ad")
else:
    # Cross-validation mode: the second '_'-separated token of the name is the fold index.
    fold_idx = int(QUERY_ADATA_NAME.split('_')[1])
    kfold_path = here('03_Downstream_Analysis/PatientClassifier/5foldCV/data/K_FOLD_cellID.pkl')
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted split files.
    with open(kfold_path, 'rb') as split_file:
        splits = pkl.load(split_file)
    print(fold_idx)
    train_idx, validation_idx = splits[fold_idx]
    # Carve out the held-out fold first, before ``reference_adata`` is rebound to the training fold.
    target_adata = reference_adata[validation_idx].copy()
    reference_adata = reference_adata[train_idx].copy()

In [None]:
# Display the query AnnData as a quick sanity check after loading.
target_adata

In [None]:
# Display the (possibly fold-restricted) reference AnnData as a quick sanity check.
reference_adata

In [None]:
# Reload the trained reference scANVI model, bound to the reference AnnData loaded above.
reference_model_dir = here(
    f"03_Downstream_Analysis/PatientClassifier/scANVI/results/012_fine_tuning/models/scANVI_{SCANVI_ID}"
)
scanvi_model = scvi.model.SCANVI.load(reference_model_dir, adata=reference_adata)

In [None]:
# Derive a query model from the reference scANVI via ``load_query_data``, keeping dropout frozen.
model = scvi.model.SCANVI.load_query_data(target_adata, scanvi_model, freeze_dropout=True)

# Mark every query cell as unlabeled and none as labeled.
# NOTE(review): these are private scvi-tools attributes; confirm the installed version still reads them.
model._unlabeled_indices = np.arange(target_adata.n_obs)
model._labeled_indices = []

In [None]:
# Options forwarded to ``model.train``.
# NOTE(review): checkpointing monitors the ELBO while early stopping monitors reconstruction
# loss; the batch size of 127 for the EXTERNAL query looks deliberate (possibly to avoid a
# size-1 final minibatch) — confirm and document both choices.
trainer_kwargs = {
    'checkpointing_monitor': 'elbo_validation',
    'early_stopping_monitor': 'reconstruction_loss_validation',
    'early_stopping_patience': 10,
    'early_stopping': True,
    'max_epochs': 100,
    'batch_size': 127 if QUERY_ADATA_NAME == 'EXTERNAL' else 128,
}

# Options for the training plan (optimizer settings).
plan_kwargs = {'weight_decay': 0.0}

# Flat record of every training setting, logged as the W&B run config.
scvi_parameter_dict = {**trainer_kwargs, **plan_kwargs}

In [None]:
# Weights & Biases logger; the collected training settings are attached as the run config.
wandb_run_name = f'scANVI_query_test_{QUERY_ADATA_NAME}_{SCANVI_ID}'
logger = WandbLogger(
    name=wandb_run_name,
    project='inflammation_atlas_PatientClassifier_scANVI',
    entity='inflammation',
    config=scvi_parameter_dict,
)

In [None]:
# Fit the query model; trainer options are expanded as keyword arguments, plan options passed as a dict.
model.train(logger=logger, plan_kwargs=plan_kwargs, **trainer_kwargs)

In [None]:
# Persist the query-fitted model (weights only, without the AnnData) when saving is enabled.
if overwriteData:
    query_model_dir = here(
        f"03_Downstream_Analysis/PatientClassifier/scANVI/results/02_query/models/scANVI_{QUERY_ADATA_NAME}_{SCANVI_ID}_query"
    )
    model.save(query_model_dir, save_anndata=False, overwrite=True)

In [None]:
# Predicted labels for every query cell.
query_labels = model.predict(adata=target_adata)

In [None]:
# Embed query and reference cells with the same query-fitted model.
query_latents = model.get_latent_representation(adata=target_adata)
reference_latents = model.get_latent_representation(adata=reference_adata)

In [None]:
# Store the query latent embedding with the query obs metadata plus the predicted labels.
# NOTE(review): ``assign(labels=...)`` overwrites any existing obs column named 'labels'.
query_ad = sc.AnnData(
    X=query_latents,
    obs=target_adata.obs.assign(labels=query_labels),
)
query_output_path = here(
    f"03_Downstream_Analysis/PatientClassifier/scANVI/results/02_query/output/scANVI_{QUERY_ADATA_NAME}_{SCANVI_ID}.h5ad"
)
# Create the output directory if missing so the write does not fail on a fresh checkout.
os.makedirs(os.path.dirname(query_output_path), exist_ok=True)
query_ad.write(query_output_path, compression='gzip')

In [None]:
# Store the reference latent embedding (computed by the query-fitted model) with the reference obs metadata.
reference_ad = sc.AnnData(
    X=reference_latents,
    obs=reference_adata.obs,
)
reference_output_path = here(
    f"03_Downstream_Analysis/PatientClassifier/scANVI/results/02_query/ref_latents/scANVI_{QUERY_ADATA_NAME}_{SCANVI_ID}.h5ad"
)
# Create the output directory if missing so the write does not fail on a fresh checkout.
os.makedirs(os.path.dirname(reference_output_path), exist_ok=True)
reference_ad.write(reference_output_path, compression='gzip')