In [1]:
# Parameters
# Looks like a papermill-style parameters cell: these values may be replaced by
# injected ones at run time (assumption — confirm how the notebook is launched).
# Both values are glob patterns relative to the project root; they are resolved
# with `here()` and are expected to match exactly one .h5ad file each.
adataPATH_ref = "03_downstream_analysis/08_PatientClassifier/scANVI/results/reference/scANVI_EXTERNAL_*_200_Level2_*.h5ad"
adataPATH_query = "03_downstream_analysis/08_PatientClassifier/scANVI/results/query/scANVI_EXTERNAL_*_200_Level2_*.h5ad"


In [2]:
# Fail fast if a required notebook parameter is missing, and echo the values
# that are actually used so the run log records the inputs.
for param_name in ['adataPATH_ref', 'adataPATH_query']:
    # At notebook top level locals() is globals(), so a single dict lookup is
    # enough and avoids eval() on the parameter name.
    if param_name not in globals():
        raise NameError(f"{param_name} not specified")
    print(f"{param_name} = {globals()[param_name]}")

adataPATH_ref = 03_downstream_analysis/08_PatientClassifier/scANVI/results/reference/scANVI_EXTERNAL_*_200_Level2_*.h5ad
adataPATH_query = 03_downstream_analysis/08_PatientClassifier/scANVI/results/query/scANVI_EXTERNAL_*_200_Level2_*.h5ad


In [3]:
import os
import sys
from glob import glob

import scanpy as sc
import pandas as pd

import pynndescent
import numpy as np
import numba

from pyprojroot import here

sys.path.insert(1, str(here('bin')))
# Import custom functions
from customPythonFunctions import aggregating_features

### Defining kNN label transfer function

In [4]:
class LabelTransferWithKNN:
    """Transfer categorical labels from a reference to a query with a kNN vote.

    Adapted from https://docs.scvi-tools.org/en/stable/tutorials/notebooks/scrna/query_hlca_knn.html

    ``fit`` builds a pynndescent index on ``ref_adata.X``. ``predict`` finds the
    nearest reference cells of every query cell (rows of ``query_adata.X``),
    converts their distances into per-neighbour weights and picks, per label
    key, the reference category with the largest summed weight. Reference and
    query ``.X`` are assumed to be in the same space (e.g. a shared latent
    embedding) — this is not checked.
    """

    def fit(self, ref_adata):
        """Index the reference cells.

        Parameters
        ----------
        ref_adata : AnnData
            Reference object. Its ``.X`` is indexed; its ``.obs`` must contain
            categorical columns for every label key later passed to ``predict``.

        Returns
        -------
        LabelTransferWithKNN
            ``self``, so construction and fitting can be chained.
        """
        self.ref_adata = ref_adata
        self.nn_index = pynndescent.NNDescent(self.ref_adata.X)
        # Prepare the index for querying now rather than on the first query() call.
        self.nn_index.prepare()
        return self

    def predict(self, query_adata, label_keys: "list[str] | tuple[str, ...]" = ('Level1',)):
        """Predict reference labels for every query cell (writes into ``query_adata.obs``).

        For each key ``l`` in ``label_keys`` two columns are set in place:
        ``l + "_pred"`` (winning reference label) and ``l + "_uncertainty"``
        (``1 - summed weight of the winning label``). The prediction is ``None``
        when the winning code is not a real category: either the neighbours'
        labels are missing in the reference (category code -1), or no category
        received any weight (e.g. NaN in ``query_adata.X``), in which case the
        uncertainty is 1.0.

        Parameters
        ----------
        query_adata : AnnData
            Query object; its ``.X`` is searched against the reference index.
        label_keys : sequence of str
            Categorical ``ref_adata.obs`` columns to transfer.

        Returns
        -------
        AnnData
            The same ``query_adata`` object, with the new ``.obs`` columns.
        """
        ref_neighbors, ref_distances = self.nn_index.query(query_adata.X)
        weights = self._neighbor_weights(ref_distances)

        for key in label_keys:
            ref_labels = self.ref_adata.obs[key]
            categories = np.asarray(ref_labels.cat.categories)
            # One row of neighbour category codes per query cell; -1 marks a
            # missing reference label.
            neighbor_codes = ref_labels.cat.codes.to_numpy()[ref_neighbors]
            best_codes, uncertainty = self.weighted_prediction(weights, neighbor_codes)
            query_adata.obs[key + "_pred"] = self._codes_to_labels(best_codes, categories)
            query_adata.obs[key + "_uncertainty"] = uncertainty

        return query_adata

    @staticmethod
    def _neighbor_weights(ref_distances):
        """Convert per-row kNN distances into weights that sum to 1 per row.

        Same kernel as the scvi-tools tutorial: ``exp(-d / s)`` with
        ``s = (2 / std(row)) ** 2``. The row minimum of ``d / s`` is subtracted
        before exponentiating. This does not change the normalised weights
        mathematically, but it stops all affinities of a row from underflowing
        to 0, which would otherwise give 0/0 = NaN weights.
        """
        # A row of equidistant neighbours has std 0 -> s = inf -> d / s = 0 ->
        # uniform weights; silence the expected divide-by-zero warning.
        with np.errstate(divide='ignore'):
            scale = ((2.0 / np.std(ref_distances, axis=1)) ** 2).reshape(-1, 1)
        scaled = np.true_divide(ref_distances, scale)
        scaled -= scaled.min(axis=1, keepdims=True)
        affinities = np.exp(-scaled)
        return affinities / np.sum(affinities, axis=1, keepdims=True)

    @staticmethod
    def _codes_to_labels(codes, categories):
        """Map category codes to labels, turning negative codes into ``None``.

        Plain ``categories[codes]`` would silently turn code -1 into the last
        category through negative indexing.
        """
        valid = codes >= 0
        if valid.all():
            return categories[codes]
        labels = np.full(codes.shape, None, dtype=object)
        labels[valid] = categories[codes[valid]]
        return labels

    @staticmethod
    @numba.njit
    def weighted_prediction(weights, ref_cats):
        """Get highest weight category.

        Parameters
        ----------
        weights : 2-D float array, one row of neighbour weights per query cell.
        ref_cats : 2-D signed integer array (same shape), the neighbours' codes.

        Returns
        -------
        predictions : code with the largest summed weight per row, or -1 if no
            code has positive weight (e.g. NaN weights).
        uncertainty : ``max(1 - best summed weight, 0)`` per row; 1.0 when no
            code won.
        """
        N = len(weights)
        predictions = np.zeros((N,), dtype=ref_cats.dtype)
        predictions[:] = -1
        uncertainty = np.ones((N,))
        for i in range(N):
            obs_weights = weights[i]
            obs_cats = ref_cats[i]
            best_prob = 0
            for c in np.unique(obs_cats):
                cand_prob = np.sum(obs_weights[obs_cats == c])
                if cand_prob > best_prob:
                    best_prob = cand_prob
                    predictions[i] = c
                    uncertainty[i] = max(1 - best_prob, 0)

        return predictions, uncertainty

### Loading data

In [5]:
def _glob_single_file(pattern: str) -> list[str]:
    """Resolve a project-relative glob pattern that must match exactly one file.

    Returns the match as a one-element list (the notebook uses ``[0]`` later).
    Raises FileNotFoundError when nothing matches and ValueError when several
    files match, naming the pattern and the matches. A bare ``assert`` would
    give no diagnostic and is skipped under ``python -O``.
    """
    matches = sorted(glob(str(here(pattern))))
    if not matches:
        raise FileNotFoundError(f"No file matches {pattern!r}")
    if len(matches) > 1:
        raise ValueError(f"Expected exactly one file for {pattern!r}, found {len(matches)}: {matches}")
    return matches


adataPATH_ref_list = _glob_single_file(adataPATH_ref)
adataPATH_query_list = _glob_single_file(adataPATH_query)

In [6]:
# Read the single reference and query scANVI outputs, reference first.
adataR, adataQ = (sc.read_h5ad(path_list[0]) for path_list in (adataPATH_ref_list, adataPATH_query_list))
adataR, adataQ

(AnnData object with n_obs × n_vars = 4435922 × 200
     obs: 'studyID', 'libraryID', 'sampleID', 'chemistry', 'disease', 'sex', 'binned_age', 'Level1', 'Level2', '_scvi_batch', '_scvi_labels',
 AnnData object with n_obs × n_vars = 572872 × 200
     obs: 'studyID', 'libraryID', 'sampleID', 'chemistry', 'technology', 'disease', 'sex', 'binned_age', '_scvi_batch', 'Level2', '_scvi_labels', 'labels')

In [7]:
# Level2 -> Level1 lookup taken from the reference annotation. The distinct
# (Level2, Level1) pairs are deduplicated first: the reference has millions of
# cells but few distinct pairs. A Level2 label with more than one Level1 parent
# would otherwise be resolved silently by whichever row comes last, so it is
# rejected instead.
level_pairs = adataR.obs[['Level2', 'Level1']].drop_duplicates()
if not level_pairs['Level2'].is_unique:
    ambiguous = level_pairs.loc[level_pairs['Level2'].duplicated(keep=False)]
    raise ValueError(f"Level2 labels map to several Level1 labels:\n{ambiguous}")
Level1_dict = dict(zip(level_pairs['Level2'], level_pairs['Level1']))

In [8]:
# Lift each query cell's scANVI Level2 label ('labels') to its Level1 parent.
scanvi_level2_labels = adataQ.obs['labels']
adataQ.obs['Level1_scANVIpredict'] = scanvi_level2_labels.map(Level1_dict)

In [9]:
# Fraction of query cells per scANVI-derived Level1 label.
(
    adataQ.obs['Level1_scANVIpredict']
    .value_counts(normalize=True)
)

Level1_scANVIpredict
Mono              0.240013
T_CD4_NonNaive    0.236730
T_CD4_Naive       0.132894
T_CD8_NonNaive    0.114005
B                 0.078096
ILC               0.065545
T_CD8_Naive       0.061225
UTC               0.037452
Platelets         0.011968
DC                0.011479
pDC               0.004270
Cycling_cells     0.003189
Plasma            0.001243
Progenitors       0.001107
RBC               0.000786
Name: proportion, dtype: float64

### Label transfer with kNN

In [10]:
# Build the kNN index on the reference cells.
kNNclf = LabelTransferWithKNN()
kNNclf = kNNclf.fit(adataR)  # fit() returns the fitted instance itself

In [11]:
# Transfer Level1 onto the query: predict() adds 'Level1_pred' and
# 'Level1_uncertainty' to adataQ.obs and hands adataQ back.
adataQ = kNNclf.predict(
    query_adata=adataQ,
    label_keys=['Level1'],
)
adataQ

AnnData object with n_obs × n_vars = 572872 × 200
    obs: 'studyID', 'libraryID', 'sampleID', 'chemistry', 'technology', 'disease', 'sex', 'binned_age', '_scvi_batch', 'Level2', '_scvi_labels', 'labels', 'Level1_scANVIpredict', 'Level1_pred', 'Level1_uncertainty'

In [12]:
# Store the kNN prediction under a name parallel to 'Level1_scANVIpredict'.
knn_level1_pred = adataQ.obs['Level1_pred']
adataQ.obs['Level1_kNN_pred'] = knn_level1_pred

In [13]:
# Fraction of query cells per kNN-transferred Level1 label.
(
    adataQ.obs['Level1_kNN_pred']
    .value_counts(normalize=True)
)

Level1_kNN_pred
Mono              0.235677
T_CD4_Naive       0.206100
T_CD4_NonNaive    0.170462
T_CD8_NonNaive    0.107530
B                 0.079892
ILC               0.070147
T_CD8_Naive       0.060794
UTC               0.026708
Platelets         0.021612
DC                0.009248
Cycling_cells     0.004284
pDC               0.004238
Progenitors       0.001290
Plasma            0.001231
RBC               0.000787
Name: proportion, dtype: float64

In [14]:
from sklearn.metrics import adjusted_rand_score as ari

In [15]:
# Agreement between the two Level1 annotations of the query cells
# (neither is ground truth; ARI is symmetric in its two arguments).
ari(
    labels_true=adataQ.obs['Level1_kNN_pred'],
    labels_pred=adataQ.obs['Level1_scANVIpredict'],
)

0.7699958962974668

### Generating pseudobulks

In [16]:
# Reference pseudobulks: aggregate adataR.X with mode='mean', keyed by
# sampleID and the reference's own Level1 annotation.
adataPB_R = aggregating_features(
    Z=adataR.X,
    obsDF=adataR.obs[['sampleID', 'Level1', 'disease']],
    mode='mean',
    obs_names_col=['sampleID', 'Level1'],
    min_observation=0,
)
adataPB_R

AnnData object with n_obs × n_vars = 11372 × 200
    obs: 'sampleID', 'Level1', 'disease', 'n_observation'

**Considering scANVI predicted labels**

In [17]:
# Query pseudobulks keyed by sampleID and the scANVI-derived Level1 label; the
# label column is renamed to 'Level1' to match the reference pseudobulks.
adataPB_Q_scANVI = aggregating_features(
    Z=adataQ.X,
    obsDF=adataQ.obs[['sampleID', 'Level1_scANVIpredict', 'disease']],
    mode='mean',
    obs_names_col=['sampleID', 'Level1_scANVIpredict'],
    min_observation=0,
)
adataPB_Q_scANVI.obs.rename(columns={'Level1_scANVIpredict': 'Level1'}, inplace=True)
adataPB_Q_scANVI

AnnData object with n_obs × n_vars = 1207 × 200
    obs: 'sampleID', 'Level1', 'disease', 'n_observation'

In [18]:
# Number of scANVI-based query pseudobulks per Level1 label.
adataPB_Q_scANVI.obs['Level1'].value_counts()

Level1
B                 86
DC                86
ILC               86
Mono              86
T_CD4_Naive       86
T_CD4_NonNaive    86
T_CD8_Naive       86
T_CD8_NonNaive    86
UTC               86
Platelets         84
Cycling_cells     81
Plasma            81
Progenitors       81
pDC               77
RBC               29
Name: count, dtype: int64

**Considering labels transferred with kNN**

In [19]:
# Query pseudobulks keyed by sampleID and the kNN-transferred Level1 label; the
# label column is renamed to 'Level1' to match the reference pseudobulks.
adataPB_Q_kNN = aggregating_features(
    Z=adataQ.X,
    obsDF=adataQ.obs[['sampleID', 'Level1_kNN_pred', 'disease']],
    mode='mean',
    obs_names_col=['sampleID', 'Level1_kNN_pred'],
    min_observation=0,
)
adataPB_Q_kNN.obs.rename(columns={'Level1_kNN_pred': 'Level1'}, inplace=True)
adataPB_Q_kNN

AnnData object with n_obs × n_vars = 1217 × 200
    obs: 'sampleID', 'Level1', 'disease', 'n_observation'

In [20]:
# Number of kNN-based query pseudobulks per Level1 label.
adataPB_Q_kNN.obs['Level1'].value_counts()

Level1
B                 86
DC                86
ILC               86
Mono              86
Platelets         86
T_CD4_Naive       86
T_CD4_NonNaive    86
T_CD8_Naive       86
T_CD8_NonNaive    86
UTC               86
Cycling_cells     85
Progenitors       82
Plasma            81
pDC               78
RBC               31
Name: count, dtype: int64

### Saving pseudobulk adata objects

In [21]:
# Display the resolved reference input path; the pseudobulk output path below is derived from it.
adataPATH_ref_list[0]

'/scratch_isilon/groups/singlecell/shared/projects/Inflammation-PBMCs-Atlas/03_downstream_analysis/08_PatientClassifier/scANVI/results/reference/scANVI_EXTERNAL_512_200_Level2_run1_finetuning.h5ad'

In [22]:
# Preview the reference pseudobulk output path derived from the input path.
adataPATH_ref_list[0].replace(
    '/reference/scANVI_',
    '/PSEUDOBULKs/scANVI_PSEUDOBULK_reference_',
)

'/scratch_isilon/groups/singlecell/shared/projects/Inflammation-PBMCs-Atlas/03_downstream_analysis/08_PatientClassifier/scANVI/results/PSEUDOBULKs/scANVI_PSEUDOBULK_reference_EXTERNAL_512_200_Level2_run1_finetuning.h5ad'

In [23]:
# Save the reference pseudobulks under results/PSEUDOBULKs/. The output path is
# derived by substring replacement; if that replacement does not apply, the
# "output" would be the input .h5ad itself, so refuse instead of overwriting it.
pb_ref_source = adataPATH_ref_list[0]
pb_ref_target = pb_ref_source.replace('/reference/scANVI_', '/PSEUDOBULKs/scANVI_PSEUDOBULK_reference_')
if pb_ref_target == pb_ref_source:
    raise ValueError(f"Cannot derive a pseudobulk output path from {pb_ref_source!r}")
pb_ref_path = here(pb_ref_target)
pb_ref_path.parent.mkdir(parents=True, exist_ok=True)
adataPB_R.write(pb_ref_path, compression='gzip')

In [24]:
# Save the scANVI-label query pseudobulks under results/PSEUDOBULKs/. Refuse to
# write if the path substitution did not apply (it would overwrite the input).
pb_query_scanvi_source = adataPATH_query_list[0]
pb_query_scanvi_target = pb_query_scanvi_source.replace('/query/scANVI_', '/PSEUDOBULKs/scANVI_PSEUDOBULK_query_')
if pb_query_scanvi_target == pb_query_scanvi_source:
    raise ValueError(f"Cannot derive a pseudobulk output path from {pb_query_scanvi_source!r}")
pb_query_scanvi_path = here(pb_query_scanvi_target)
pb_query_scanvi_path.parent.mkdir(parents=True, exist_ok=True)
adataPB_Q_scANVI.write(pb_query_scanvi_path, compression='gzip')

In [25]:
# Save the kNN-label query pseudobulks under results/PSEUDOBULKs/. Refuse to
# write if the path substitution did not apply (it would overwrite the input).
pb_query_knn_source = adataPATH_query_list[0]
pb_query_knn_target = pb_query_knn_source.replace('/query/scANVI_', '/PSEUDOBULKs/scANVI_PSEUDOBULK_kNN_query_')
if pb_query_knn_target == pb_query_knn_source:
    raise ValueError(f"Cannot derive a pseudobulk output path from {pb_query_knn_source!r}")
pb_query_knn_path = here(pb_query_knn_target)
pb_query_knn_path.parent.mkdir(parents=True, exist_ok=True)
adataPB_Q_kNN.write(pb_query_knn_path, compression='gzip')