# Querying the Atlas

In [1]:
import os

import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=DeprecationWarning)
warnings.simplefilter(action="ignore", category=UserWarning)

import sys
import scanpy as sc
import numpy as np
import pandas as pd
import scarches as sca
import anndata as ad
from scipy import sparse
import gdown
import gzip
import shutil
import urllib.request
sc.set_figure_params(figsize=(4, 4), dpi=50)
import pynndescent
# for white background of figures (only for docs rendering)
%config InlineBackend.print_figure_kwargs={'facecolor' : "w"}
%config InlineBackend.figure_format='retina'

Global seed set to 0


In [2]:
# Reference dataset (AnnData in .h5ad format) that supplies the cell-type labels.
path = "/bmbl_data/cankun_notebook/loss_y/GC_sample.rds.h5ad"
adata = sc.read_h5ad(filename=path)

In [3]:
# Normalise and log-transform the reference in place, then keep only the
# 2000 most variable genes, selected with `sample` as the batch key.
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2000, batch_key="sample", subset=True)

In [4]:
# Build an approximate nearest-neighbour index over the reference cells.
# X_train is the reference expression matrix after the highly-variable-gene
# subsetting (one row per reference cell, one column per retained gene).
X_train = adata.X
ref_nn_index = pynndescent.NNDescent(X_train)
# Finish building the index up front so later queries do not pay that cost.
ref_nn_index.prepare()
# NOTE(review): this index lives in gene-expression space, but it is later
# queried with `query_emb.X`, which comes from `model.get_latent_representation()`.
# Those feature spaces presumably differ in both dimension and meaning, so the
# returned neighbours may be meaningless (the recorded run labels 9997 of
# 10000 cells identically with uncertainty ~0.70). The index probably needs to
# be built on the reference cells' latent representation from the same model
# -- TODO confirm and fix.

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


In [5]:
adata

AnnData object with n_obs × n_vars = 1105 × 2000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'scrublet_scores', 'scrublet_predict', 'sample', 'study', 'percent.mt', 'manual_doublet', 'RNA_snn_res.0.5', 'seurat_clusters', 'celltype', 'cohort', 'patient_recode', 'celltype.big', 'Gender', 'Source', 'Type', 'Age', 'Lauren.s.classification', 'Primary.site', 'MSI.status', 'H..pylori', 'Signet.ring.cell.carcinoma', 'The.WHO.classification', 'Prior.treatment', 'loy_avg', 'housekeeping_avg', 'ratio_Y_housekeeping', 'is_fLOY'
    var: 'vst.mean', 'vst.variance', 'vst.variance.expected', 'vst.variance.standardized', 'vst.variable', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'log1p', 'hvg'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'

## Load custom query data

In [6]:
# Query dataset (AnnData in .h5ad format) whose cells will be annotated.
query_path = '/bmbl_data/cankun_notebook/loss_y/Seurat.merge.T.excluded.qs.h5ad'
query_data = sc.read_h5ad(query_path)

In [7]:
query_data

AnnData object with n_obs × n_vars = 110024 × 87266
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'cell_name', 'sample', 'source', 'cell_type', 'cell_subtype_level1', 'cell_subtype_level2', 'malignant'
    var: 'features'

In [8]:
# Restrict the query to the reference's retained genes, in the reference's
# gene order, and materialise the view as an independent object.
reference_genes = adata.var_names
query_data = query_data[:, reference_genes].copy()
query_data

AnnData object with n_obs × n_vars = 110024 × 2000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'cell_name', 'sample', 'source', 'cell_type', 'cell_subtype_level1', 'cell_subtype_level2', 'malignant'
    var: 'features'

## (Skip this section for real data) Randomly subset 10,000 cells for the example


In [9]:
# Randomly select up to 10,000 cells.
# A dedicated, seeded generator makes the draw reproducible regardless of the
# global NumPy random state or of how often this cell is re-run; the size is
# clamped because sampling without replacement fails when fewer cells exist.
n_subset = min(10000, query_data.n_obs)
subset_rng = np.random.default_rng(0)
random_subset = subset_rng.choice(query_data.obs_names, n_subset, replace=False)

# Subset the AnnData object
query_data = query_data[random_subset, :].copy()


In [10]:
query_data

AnnData object with n_obs × n_vars = 10000 × 2000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'cell_name', 'sample', 'source', 'cell_type', 'cell_subtype_level1', 'cell_subtype_level2', 'malignant'
    var: 'features'

## Train model

In [11]:
# Directory holding the saved reference model.
ref_path = 'T_atlas_reference_model/'

# Create a query model from the saved reference model and attach the query
# data to it; dropout is frozen for the subsequent fine-tuning.
model = sca.models.SCVI.load_query_data(
    adata=query_data,
    reference_model=ref_path,
    freeze_dropout=True,
)

[34mINFO    [0m File T_atlas_reference_model/model.pt already downloaded                            


In [12]:
# Upper bound on fine-tuning epochs.
surgery_epochs = 500
# Training options: early stopping watches the training ELBO (patience 10,
# minimum improvement 0.001), and weight decay is switched off.
train_kwargs_surgery = dict(
    early_stopping=True,
    early_stopping_monitor="elbo_train",
    early_stopping_patience=10,
    early_stopping_min_delta=0.001,
    plan_kwargs=dict(weight_decay=0.0),
)

In [13]:
# Fine-tune the query model with the options configured above.
model.train(max_epochs=surgery_epochs, **train_kwargs_surgery)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Epoch 500/500: 100%|██████████| 500/500 [05:40<00:00,  1.47it/s, loss=794, v_num=1]


## Extract prediction from model

In [14]:
# Wrap the query cells' latent representation in an AnnData object and carry
# over the cell identifiers so predictions can be traced back to cells.
query_latent = model.get_latent_representation()
query_emb = sc.AnnData(query_latent)
query_emb.obs_names = query_data.obs_names

In [15]:
import numba
# For every query cell, fetch its nearest reference cells and their distances.
ref_neighbors, ref_distances = ref_nn_index.query(query_emb.X)

# convert distances to affinities
# Each query cell gets its own scale, (2 / std of its neighbour distances)^2,
# kept as a column so it broadcasts across that cell's neighbours.
stds = np.std(ref_distances, axis=1, keepdims=True)
stds = np.square(2.0 / stds)
ref_distances_tilda = np.exp(-(ref_distances / stds))
# Normalise each row so a cell's neighbour weights sum to one.
weights = ref_distances_tilda / ref_distances_tilda.sum(axis=1, keepdims=True)

@numba.njit
def weighted_prediction(weights, ref_cats):
    """Pick, for each observation, the category with the largest summed weight.

    Parameters
    ----------
    weights
        2-D array; row ``i`` holds the weights of observation ``i``'s neighbours.
    ref_cats
        2-D array of the same layout holding each neighbour's category code.

    Returns
    -------
    predictions
        1-D array (dtype of ``ref_cats``) with the winning category code per
        observation. Ties keep the smallest code, because categories are
        visited in sorted order and only a strictly larger sum replaces the
        current best.
    uncertainty
        1-D float array with ``max(1 - winning_weight_sum, 0)`` per observation.
        An observation for which no category reaches a positive weight sum
        (all-zero or NaN weights) keeps prediction ``0`` and is reported with
        uncertainty ``1.0`` rather than being presented as fully confident.
    """
    n_obs = len(weights)
    predictions = np.zeros((n_obs,), dtype=ref_cats.dtype)
    uncertainty = np.zeros((n_obs,))
    for i in range(n_obs):
        obs_weights = weights[i]
        obs_cats = ref_cats[i]
        # Float literal keeps the variable's type stable across the loop.
        best_prob = 0.0
        for c in np.unique(obs_cats):
            cand_prob = np.sum(obs_weights[obs_cats == c])
            if cand_prob > best_prob:
                best_prob = cand_prob
                predictions[i] = c
        # Computed after the category loop so that rows where nothing won
        # end up maximally uncertain (1.0) instead of keeping the initial 0.
        uncertainty[i] = max(1.0 - best_prob, 0.0)

    return predictions, uncertainty

# Transfer every requested annotation column from the reference to the query:
# gather the category codes of each query cell's reference neighbours, vote
# with the neighbour weights, then map the winning codes back to label names.
label_keys = ["celltype"]
for label_key in label_keys:
    ref_labels = adata.obs[label_key]
    ref_cats = ref_labels.cat.codes.to_numpy()[ref_neighbors]
    pred_codes, pred_uncertainty = weighted_prediction(weights, ref_cats)
    pred_names = np.asarray(ref_labels.cat.categories)[pred_codes]
    query_emb.obs[label_key + "_pred"] = pred_names
    query_emb.obs[label_key + "_uncertainty"] = pred_uncertainty


In [17]:
# Tally how many query cells were assigned to each predicted cell type.
frequency = query_emb.obs["celltype_pred"].value_counts()
print(frequency)


CD8_C2    9997
CD4_C4       3
Name: celltype_pred, dtype: int64


In [18]:
query_emb.obs

Unnamed: 0,celltype_pred,celltype_uncertainty
SMC07-T_CATTATCTCCCAACGG_1,CD8_C2,0.699972
P08_T_0860_1,CD8_C2,0.699965
SSN09_SSN_CACACTCCAAGAAAGG_2,CD8_C2,0.699973
ACGTTGAAAGTTCCAGAC_10367_2,CD8_C2,0.699981
P14_GTTCATTCAGCCTATA-1_1,CD8_C2,0.699972
...,...,...
P11_T_0216_1,CD8_C2,0.699965
SSN12_SSN_TTCTCAAGTGGTCCGT_2,CD8_C2,0.699965
SSN28_SSN_GCACATAGTCCGTTAA_2,CD8_C2,0.699960
sc5rJUQ050_GCAATCACAGGATTGG_2,CD8_C2,0.699963


## Optional: Set celltype_pred to "Unknown" where the prediction uncertainty exceeds uncertainty_threshold. Adjust uncertainty_threshold if needed (default = 0.2)

In [None]:
# Predictions whose uncertainty exceeds this value are relabelled "Unknown".
uncertainty_threshold = 0.2
for label_key in label_keys:
    pred_col = label_key + "_pred"
    mask = query_emb.obs[label_key + "_uncertainty"] > uncertainty_threshold
    # Fraction of cells that will be relabelled.
    print(f"{label_key}: {mask.mean()} unknown")
    # A categorical column rejects values outside its categories, so register
    # "Unknown" first when needed.
    pred_values = query_emb.obs[pred_col]
    if isinstance(pred_values.dtype, pd.CategoricalDtype) and "Unknown" not in pred_values.cat.categories:
        query_emb.obs[pred_col] = pred_values.cat.add_categories("Unknown")
    # Single .loc call on the frame: chained `obs[col].loc[mask] = ...` may
    # write to a temporary copy and leave the frame unchanged.
    query_emb.obs.loc[mask, pred_col] = "Unknown"
# Tag every cell with the dataset name.
query_emb.obs["dataset"] = "test_dataset_TcellAtlas"

## Save results

In [19]:
# Write the per-cell annotation table (predicted label and uncertainty) to CSV.
# Only `query_emb.obs` is written here; the latent embedding itself is not saved.
output_csv = './TcellAtlas_query_emb_full.csv'
query_emb.obs.to_csv(output_csv)