In [None]:
# Notebook parameters. SCPOLI_ID, QUERY_ADATA_NAME and BUCKET_DIRPATH are only
# annotated here, never assigned, so they must be supplied before the rest of the
# notebook runs (presumably injected by a parameterising runner such as papermill — confirm).
SCPOLI_ID: str  # suffix of the reference model directory and of every output file name
QUERY_ADATA_NAME: str  # used in the query input file name and in every output file name
BUCKET_DIRPATH: str  # root directory that `here()` joins relative paths onto
N_EPOCHS: int = 50  # passed to scPoli training as `n_epochs`
PRETRAINING_EPOCHS: int = 40  # passed to scPoli training as `pretraining_epochs`
ETA: int = 1  # passed to scPoli training as `eta`
LEARNING_RATE: float = 1e-3  # passed to scPoli training as `lr`

In [None]:
from scarches.models import scPoli
import scanpy as sc
import wandb
import os
import numpy as np
from dotenv import load_dotenv

In [None]:
# Load environment variables from a .env file (presumably W&B credentials — confirm).
# This is deliberately not written as `assert load_dotenv()`: assertions are stripped
# under `python -O`, which would silently skip the load_dotenv() call itself.
if not load_dotenv():
    raise RuntimeError(
        "load_dotenv() found no variables to load; "
        "check that a .env file is present and non-empty."
    )

In [None]:
def here(fpath):
    """Resolve ``fpath`` against the ``BUCKET_DIRPATH`` root directory.

    Args:
        fpath: path relative to the bucket root.

    Returns:
        The joined path string produced by ``os.path.join``.
    """
    bucket_root = BUCKET_DIRPATH
    resolved = os.path.join(bucket_root, fpath)
    return resolved

In [None]:
# When True, the trained query model is saved (overwriting any previous save) in the save cell.
overwriteData: bool = True

In [None]:
# Open both datasets in backed read-only mode ('r').
# NOTE(review): unlike the model and output paths used later, these two file names are
# not resolved via `here()`, so they are relative to the current working directory.
reference_adata = sc.read_h5ad("04_MAIN_geneUniverse.h5ad", backed='r')
target_adata = sc.read_h5ad(f"05_{QUERY_ADATA_NAME}_geneUniverse.h5ad", backed='r')

In [None]:
# The query cells carry no annotation: fill both label levels with the placeholder
# that load_query_data() is told to treat as unknown (`unknown_ct_names`).
for label_level in ('Level1', 'Level2'):
    target_adata.obs[label_level] = 'unknown'

In [None]:
# Directory of the pretrained reference model. The "scPoly" spelling is part of the
# saved directory name and must stay as-is.
reference_model_dir = here(
    f"03_Downstream_Analysis/PatientClassifier/scPoli/results/01_reference/scPoly_model_{SCPOLI_ID}"
)
# Build the query model from the reference; no query cell is passed as labelled.
scpoli_query = scPoli.load_query_data(
    adata=target_adata,
    reference_model=reference_model_dir,
    labeled_indices=[],
    unknown_ct_names=['unknown'],
)

In [None]:
# Early-stopping settings forwarded to scPoli's trainer: monitor the validation
# prototype loss (lower is better). reduce_lr is False, so the lr_* entries are
# presumably unused — confirm against the scArches trainer.
early_stopping_kwargs = dict(
    early_stopping_metric="val_prototype_loss",
    mode="min",
    threshold=0.1,
    patience=5,
    reduce_lr=False,
    lr_patience=5,
    lr_factor=0.1,
)

In [None]:
# Keyword arguments for scpoli_query.train(), taken from the parameter cell.
train_parameters_dict = dict(
    lr=LEARNING_RATE,
    n_epochs=N_EPOCHS,
    pretraining_epochs=PRETRAINING_EPOCHS,
    eta=ETA,
    n_workers=8,  # presumably the number of data-loader workers — confirm
)

In [None]:
# Flat view of every hyperparameter, used as the W&B run config.
# On a key collision, the training parameters take precedence.
scPoli_params = {**early_stopping_kwargs, **train_parameters_dict}
scPoli_params

In [None]:
# Open a W&B run for this query model; the run name is keyed by the reference model id.
wandb_run_name = f'scPoli_query_{SCPOLI_ID}'
logger = wandb.init(
    project='inflammation_atlas_PatientClassifier_scPoly',
    entity='inflammation',
    name=wandb_run_name,
    config=scPoli_params,
)

In [None]:
# Fine-tune the query model, logging metrics to the W&B run opened above under a
# "query/" prefix.
# The run is closed in `finally` so it does not stay open in the kernel after training,
# or after a training error. Re-run the wandb.init cell before training again.
try:
    scpoli_query.train(
        **train_parameters_dict,
        w_logger=logger,
        prefix="query/",
        early_stopping_kwargs=early_stopping_kwargs,
    )
finally:
    logger.finish()

In [None]:
if overwriteData:
    query_model_dir = here(
        f"03_Downstream_Analysis/PatientClassifier/scPoli/results/02_query/query_models/scPoli_{QUERY_ADATA_NAME}_{SCPOLI_ID}_query"
    )
    # Save the trained query model, replacing any previous save, without the query AnnData.
    scpoli_query.save(query_model_dir, overwrite=True, save_anndata=False)

In [None]:
from tqdm.auto import trange

# Classify the query in fixed-size chunks (presumably to bound memory use with the
# backed AnnData — confirm).
# Stepping range() by batch_size yields exactly ceil(n_cells / batch_size) chunks.
# The previous `n // batch_size + 1` count produced an extra, empty trailing slice
# whenever n_cells was an exact multiple of batch_size.
batch_size = 1000
n_query_cells = target_adata.shape[0]
query_labels = []
for start_idx in trange(0, n_query_cells, batch_size):
    query_batch = target_adata[start_idx:start_idx + batch_size]
    query_labels.append(scpoli_query.classify(adata=query_batch, scale_uncertainties=False))

In [None]:
# Stitch the per-chunk classify() outputs into one array per label level and field.
# Chunks are concatenated in order, so rows line up with target_adata.
query_labels_final = {
    level: {
        field: np.concatenate([batch_result[level][field] for batch_result in query_labels])
        for field in ('preds', 'uncert', 'weighted_distances')
    }
    for level in ('Level1', 'Level2')
}

In [None]:
# Latent embeddings from the trained query model for both datasets.
# mean=True presumably returns the posterior mean rather than a sample — confirm in scArches.
query_latents = scpoli_query.get_latent(target_adata, mean=True)
reference_latents = scpoli_query.get_latent(reference_adata, mean=True)

In [None]:
# Query obs with the 'unknown' placeholders replaced by predicted labels and their uncertainties.
query_obs = target_adata.obs.assign(
    Level1=query_labels_final['Level1']['preds'],
    Level1_unc=query_labels_final['Level1']['uncert'],
    Level2=query_labels_final['Level2']['preds'],
    Level2_unc=query_labels_final['Level2']['uncert'],
)
f_query_ad = sc.AnnData(X=query_latents, obs=query_obs)
f_query_ad.uns.update(scpoli_query.get_conditional_embeddings())
query_output_path = here(
    f"03_Downstream_Analysis/PatientClassifier/scPoli/results/02_query/query_output/scPoli_{QUERY_ADATA_NAME}_{SCPOLI_ID}.h5ad"
)
f_query_ad.write(query_output_path, compression='gzip')

In [None]:
# Reference latents from the same query model, with the reference obs, plus the conditional embeddings.
f_reference_ad = sc.AnnData(X=reference_latents, obs=reference_adata.obs)
f_reference_ad.uns.update(scpoli_query.get_conditional_embeddings())
reference_output_path = here(
    f"03_Downstream_Analysis/PatientClassifier/scPoli/results/02_query/ref_latents/scPoli_{QUERY_ADATA_NAME}_{SCPOLI_ID}.h5ad"
)
f_reference_ad.write(reference_output_path, compression='gzip')