# Thymus ageing atlas - B compartment: kNN transfer of TAA cell labels

In [None]:
import os
import sys
import session_info

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import anndata as ad
import hdf5plugin
from sklearn.metrics import f1_score

# Add repo path to sys path (allows to access scripts and metadata from repo)
repo_path = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/B_compartment'
sys.path.insert(1, repo_path) 
sys.path.insert(2, '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/scripts')

# Add R libs path
#os.environ['LD_LIBRARY_PATH'] = '' # Uncomment on jhub
#os.environ['R_HOME'] = '/nfs/team205/lm25/condaEnvs/thymusAgeing/lib/R' # Uncomment on jhub
os.environ['R_LIBS_USER'] = f'{os.path.split(sys.path[0])[0]}/R/library'

%load_ext rpy2.ipython
%load_ext autoreload
%autoreload 2

from annotate_ct import get_kNN_predictions

In [None]:
# Output/input locations: plots and data inside the B-compartment repo, models
# under repo/models, shared objects under General_analysis/data
plots_path = repo_path + '/plots/'
data_path = repo_path + '/data/'
model_path = os.path.join(repo_path, 'models')
general_data_path = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/data'

print(f'Dir for plots: {plots_path}')
print(f'Dir for data: {data_path}')

## TAA-based labels

In [None]:
# Load the scVI-integrated B-split object. The object is stored as a .zarr store,
# so it is read with read_zarr (read_h5ad expects an HDF5 file).
object_version = 'v4_2024-11-06'
adata = ad.read_zarr(f'{data_path}/objects/rna/thyAgeing_bSplit_scvi_{object_version}.zarr')

# Attach leiden clusterings (replacing any stale copies already in .obs) as categoricals
leiden_clus = pd.read_csv(f'{data_path}/objects/rna/thyAgeing_bSplit_scvi_{object_version}_leidenClusters.csv', index_col=0)
adata.obs.drop(leiden_clus.columns, axis = 1, errors = 'ignore', inplace = True)
adata.obs = adata.obs.join(leiden_clus)
adata.obs[leiden_clus.columns] = adata.obs[leiden_clus.columns].astype('category')

# Add previous TAA annotations (curated, all-compartment object) — these are the labels to transfer
ct_labels = pd.read_csv(f'{general_data_path}/objects/rna/thyAgeing_all_scvi_v2_2024-06-20_curatedAnno_v6.csv', index_col = 0, dtype = 'category')
adata.obs.drop(ct_labels.columns, axis = 1, errors = 'ignore', inplace = True)
adata.obs = adata.obs.join(ct_labels)

adata

In [None]:
# Number of cells per finest TAA label (taa_l5)
adata.obs.taa_l5.value_counts()

### Create reference dataset

In [None]:
# Number of cells per coarsest TAA label (taa_l1)
adata.obs.taa_l1.value_counts()

In [None]:
# Reference = cells labelled 'B' at TAA level 1
is_b_cell = adata.obs['taa_l1'].isin(['B'])
adata_ref = adata[is_b_cell]
adata_ref

In [None]:
# Identify reference cell types with fewer than 5 cells: too rare to act as kNN labels.
# Counts are computed once and reused (the original recomputed them three times).
min_cells = 5
ref_ct_counts = adata_ref.obs['taa_l5'].value_counts()
ct_remove = ref_ct_counts[ref_ct_counts < min_cells].index.tolist()
print(f'Removing cell types with n < {min_cells}: {ct_remove}')
ref_ct_counts.to_frame()

In [None]:
# Remove ct_remove cell types from the reference. .copy() materialises the subset,
# so the kNN columns written to adata_ref.obs later land on a real AnnData
# rather than triggering an implicit view-to-copy conversion.
adata_ref = adata_ref[~adata_ref.obs['taa_l5'].isin(ct_remove)].copy()

In [None]:
# Sanity check of the split: which studies contribute reference cells
adata_ref.obs.study.value_counts()

In [None]:
# UMAP of the reference coloured by taa_l5, saved as PNG
fig = sc.pl.umap(adata_ref, color='taa_l5', ncols=1, legend_fontsize=4,
                 save=False, return_fig=True, show=False)
fig.savefig(f'{plots_path}/ctAnnotation/thyAgeing_bSplit_scvi_{object_version}_htsaKnnRef_umap.png',
            dpi=300, bbox_inches='tight')

### Create query dataset

In [None]:
# Query = every cell whose barcode is not in the reference. .copy() makes this an
# independent AnnData (not a view), so replacing .obs below and adding the kNN
# columns later does not go through implicit view-to-copy conversion.
ref_barcodes = adata_ref.obs_names.tolist()
adata_query = adata[~adata.obs_names.isin(ref_barcodes)].copy()

# Remove old annotations (the TAA label columns joined from ct_labels)
adata_query.obs = adata_query.obs.drop(columns=ct_labels.columns)

adata_query

In [None]:
# Sanity check: reference and query together account for every cell
adata.n_obs == adata_ref.n_obs + adata_query.n_obs

### Predict labels

#### Determine optimal k

In [None]:
# Fraction of all cells that are in the reference
adata_ref.n_obs / adata.n_obs

In [None]:
# Split the reference in two: a random half acts as a mini-reference, the
# remaining cells as a labelled mini-query used to choose k
adata_ref_ref = sc.pp.subsample(adata_ref, fraction=0.5, copy=True)
in_sampled_half = adata_ref.obs_names.isin(adata_ref_ref.obs_names)
adata_ref_query = adata_ref[~in_sampled_half].copy()

print(f'Number of cells in reference adata: {adata_ref_ref.shape[0]}')
print(f'Number of cells in query adata: {adata_ref_query.shape[0]}')

In [None]:
# Evaluate which k to use: transfer taa_l5 labels from one half of the reference
# to the other half and score them against the known labels (weighted F1).
# f1_score is already imported in the setup cell, so it is not re-imported here.
out_res = {}
f1_scores = {}

k_vals = [10, 20, 30, 40, 50, 75, 100]
true_labels = adata_ref_query.obs['taa_l5']  # loop-invariant ground truth
for k in k_vals:
    key = str(k)
    # Transfer labels; the first element of the result holds the predicted labels
    out_res[key] = get_kNN_predictions(adata_ref_ref, adata_ref_query, "X_scVI", k, "taa_l5")
    lab = out_res[key][0]
    # Calculating F1 score
    f1_scores[key] = f1_score(true_labels, lab['taa_l5'], average='weighted')

In [None]:
# Inspect weighted F1 scores for each k (dict keyed by str(k)).
# NOTE(review): an earlier note here said k = 30 seemed best — confirm against the
# printed scores and the k actually used for the final prediction.
f1_scores

#### Predict labels

In [None]:
# Predict taa_l5 labels for all query cells from the full (filtered) B reference,
# using k = 20 neighbours in the X_scVI embedding. The two return values are used
# further down as predicted labels and per-label uncertainties (prob = 1 - uncert).
# NOTE(review): this is the final prediction, not a k test — confirm k = 20 is the
# intended choice given the F1 scores computed above.
labels, uncert = get_kNN_predictions(adata_ref, adata_query, 'X_scVI', 20, 'taa_l5')

In [None]:
# Copy kNN results into .obs. Query cells get the predicted label and a probability
# (1 - uncertainty); reference cells get their own label as the "prediction".
import re
pattern = r'(taa_.+)'

for label_col in labels.columns:
    taa_level = re.search(pattern, label_col).group(0)
    pred_key = 'knn_pred_' + taa_level
    prob_key = 'knn_prob_' + taa_level

    # Query: predicted label and confidence
    adata_query.obs[pred_key] = labels[taa_level]
    adata_query.obs[prob_key] = 1 - uncert[taa_level]

    # Reference: keep the original annotation
    adata_ref.obs[pred_key] = adata_ref.obs[taa_level]

In [None]:
# UMAP of the query coloured by predicted taa_l5 label and its probability, saved as PNG
fig = sc.pl.umap(adata_query, color=['knn_pred_taa_l5', 'knn_prob_taa_l5'], ncols=1,
                 legend_loc="on data", legend_fontsize=4,
                 save=False, return_fig=True, show=False)
fig.savefig(f'{plots_path}/ctAnnotation/thyAgeing_bSplit_scvi_{object_version}_htsaKnnQuery_umap.png',
            dpi=300, bbox_inches='tight')

In [None]:
# Combine the kNN columns of query and reference into one table, fix dtypes, and save.
cols_cat = adata_query.obs.columns[adata_query.obs.columns.str.startswith("knn_pred")]
cols_num = adata_query.obs.columns[adata_query.obs.columns.str.startswith("knn_prob")]

knn_cols_query = adata_query.obs.columns[adata_query.obs.columns.str.startswith('knn_')]
knn_cols_ref = adata_ref.obs.columns[adata_ref.obs.columns.str.startswith('knn_')]
all_annot = pd.concat((adata_query.obs[knn_cols_query], adata_ref.obs[knn_cols_ref]))

# Reference cells have no knn_prob_* columns, so those become NaN after concat
all_annot[cols_cat] = all_annot[cols_cat].astype('category')
all_annot[cols_num] = all_annot[cols_num].astype(float)

# Save annotations. Name the barcode column explicitly: the reload (and later
# cells) look barcodes up via a column called 'names', which would otherwise
# depend on whatever name (if any) the obs index happens to carry.
all_annot.to_csv(f'{data_path}/objects/thyAgeing_bSplit_scvi_{object_version}_htsaKnnAnnot.csv',
                 index_label='names')

In [None]:
# Reload the saved kNN annotations indexed by cell barcode ('names')
all_annot = pd.read_csv(f'{data_path}/objects/thyAgeing_bSplit_scvi_{object_version}_htsaKnnAnnot.csv',
                        index_col='names')

In [None]:
# Display the combined kNN annotation table
all_annot

#### Add matched l0-l4 for l4_explore

In [None]:
# Load paired annotations
# NOTE(review): this cell reads HTSA-level columns (knn_pred_htsa_l4_explore,
# knn_pred_htsa_l0..l4), but the kNN transfer in this notebook only writes
# knn_pred_taa_* / knn_prob_taa_* columns (pattern r'(taa_.+)'), so the column
# selection on all_annot below would raise KeyError. Looks carried over from the
# HTSA-label version of this notebook — confirm, adapt, or remove.
anno_pairing = pd.read_csv(f'{data_path}/curated/thyAgeing_htsa_matchedAnnoLevels.csv').dropna()

# Save annotation using only l4_explore and matched l0-l4
# (left-merge on the l4_explore label attaches the matching l0-l4 labels per cell)
all_annot_matched = all_annot[['knn_pred_htsa_l4_explore', 'knn_prob_htsa_l4_explore']].reset_index(names='names')
all_annot_matched = all_annot_matched.merge(anno_pairing, on='knn_pred_htsa_l4_explore', how='left').set_index('names')
all_annot_matched = all_annot_matched[['knn_pred_htsa_l0','knn_pred_htsa_l1','knn_pred_htsa_l2','knn_pred_htsa_l3','knn_pred_htsa_l4', 'knn_pred_htsa_l4_explore', 'knn_prob_htsa_l4_explore']]

all_annot_matched.to_csv(f'{data_path}/objects/thyAgeing_bSplit_scvi_{object_version}_htsaKnnAnnot_matched.csv')

# Sanity check: distinct label combinations per cell should equal the number of pairing rows
all_annot_matched.drop(columns='knn_prob_htsa_l4_explore').drop_duplicates().shape[0] == anno_pairing.shape[0]

### Diagnostic plots: prediction uncertainty

In [None]:
# Load data. The object is a .zarr store, so read it with read_zarr.
# NOTE(review): version 'v5_2024-04-03' and the objects/ path (no rna/ subfolder)
# differ from the v4_2024-11-06 object loaded at the top of the notebook — confirm
# this points at the intended object.
object_version = 'v5_2024-04-03'
adata = ad.read_zarr(f'{data_path}/objects/thyAgeing_bSplit_scvi_{object_version}.zarr')

# Add knn predictions to adata (original HTSA reference does not have uncertainties)
knn_predictions = pd.read_csv(f'{data_path}/objects/thyAgeing_bSplit_scvi_{object_version}_htsaKnnAnnot.csv',
                              index_col='names')
adata.obs = adata.obs.join(knn_predictions)

In [None]:
# UMAPs of kNN prediction probabilities; also collect the prediction columns
obs_cols = list(adata.obs.columns)
prob_cols = [c for c in obs_cols if 'knn_prob_' in c]
pred_cols = [c for c in obs_cols if 'knn_pred_' in c]

fig = sc.pl.umap(adata, color=prob_cols, ncols=2, legend_loc="on data", legend_fontsize=4,
                 return_fig=True, show=False)
fig.savefig(f'{plots_path}/ctAnnotation/thyAgeing_bSplit_htsaKnnQuery_prob_umap.png',
            dpi=300, bbox_inches='tight')

In [None]:
# Boxplot input: kNN probabilities, predictions and age group for cells that have
# a probability for every level. dropna returns a new frame, which avoids the
# in-place modification of a slice of adata.obs (SettingWithCopyWarning).
prob_df = adata.obs[prob_cols + pred_cols + ['age_group2']].dropna(subset=prob_cols)

prob_df.head()

In [None]:
%%R -i prob_df,plots_path -w 300 -h 150 -u mm

# One boxplot per annotation level: kNN prediction probability per predicted label,
# split by age group, saved as PNG under plots_path/ctAnnotation.
# NOTE(review): the level names are HTSA levels, while this notebook transfers
# taa_* labels — confirm these columns exist in prob_df.
level = c('htsa_l0', 'htsa_l1', 'htsa_l2', 'htsa_l3', 'htsa_l4', 'htsa_l4_explore')
age_levels = c('infant', 'paediatric', 'paed(mid)', 'paed(late)', 'adult(young)', 'adult(middle)', 'adult(aged)')

for (l in level) {

    pred_col = paste0('knn_pred_', l)
    prob_col = paste0('knn_prob_', l)

    # Keep the plot in a variable and pass it to ggsave explicitly rather than
    # relying on ggsave's default last_plot() inside a loop.
    p = prob_df %>%
        dplyr::filter(!!sym(pred_col) != 'see_lv4_explore') %>%
        dplyr::mutate(age_group2 = factor(age_group2, levels = age_levels)) %>%
        ggplot(aes(x = !!sym(pred_col), y = !!sym(prob_col), fill = age_group2)) +
        geom_boxplot(outlier.size = 0.5, position = position_dodge2(preserve = 'single')) +
        ggsci::scale_fill_d3() +
        scale_y_continuous(limits = c(0, 1), labels = scales::percent_format()) +
        theme_simple() +
        theme(axis.text.x = element_text(angle = 90, hjust = 1))
    ggsave(paste0('thyAgeing_bSplit_htsaKnnQuery_prob_', l, '_boxplot.png'), plot = p,
           path = file.path(plots_path, 'ctAnnotation'),
           width = 300, height = 150, units = 'mm')

}

## Marker expression by cluster

In [None]:
# Scale each cell to 10,000 total counts, then log1p (both modify adata in place)
sc.pp.normalize_total(adata, target_sum=1e4, inplace=True)
sc.pp.log1p(adata, copy=False)

In [None]:
# Marker genes per B-cell subset (plus a T-cell control group), used for dot plots/UMAPs.
# NOTE(review): some symbols (e.g. H2AFZ, now H2AZ1) may be outdated names — check they exist in adata.var_names.
b_markers = {
    'B_cells': ['CD79A', 'TCL1A'],
    'B_IFN': ['MX1', 'IFI44L', 'STAT1'],
    'B_naive': ['FCER2', 'BANK1', 'FCMR'],
    'B_transitional': ['CD24', 'MYO1C', 'MS4A1'],
    'B_activated': ['CD69', 'FOS', 'FOSB', 'DUSP1', 'CD83'],
    'B_preGC': ['MIR155HG', 'HIVEP3', 'PARVB'],
    'B_GC': ['GMDS', 'LMO2', 'LPP', 'BCL6', 'AICDA', 'H2AFZ', 'MKI67', 'POU2AF1', 'CD40', 'SUGCT'],
    'B_pre-pro': ['IL7R', 'ZCCHC7', 'RAG1'],
    'B_pro': ['MME', 'DNTT', 'IGLL1'],
    'B_small-pre': ['MME', 'CD24'],
    'B_large-pre': ['MME', 'CD24', 'MKI67'],
    'B_cycling': ['TOP2A', 'CD19', 'MKI67'],
    'B_follicular': ['CXCR5', 'TNFRSF13B', 'CD22'],
    'B_prePB': ['FRZB', 'BTNL9', 'HOPX'],
    'B_dev': ['SPN', 'VPREB1'],
    'B_plasma': ['XBP1', 'PRDM1', 'FKBP11'],
    'B_mem': ['TNFRSF13B', 'FCRL4', 'CLECL1', 'CR2', 'CD27', 'MS4A1'],
    'B_age-associated': ['FCRL2', 'ITGAX', 'TBX21'],
    'perivasc_B': ['CXCR3', 'CR2', 'CD72', 'CD37'],
    'med_B': ['CD80', 'CD83', 'CD86', 'HLA-DRA', 'AIRE', 'IL15', 'LTA', 'LTB'],
    'T_cell': ['CD3E', 'CD8A', 'CD3D'],
}

# Germinal-centre programme markers (dark/light zone, re-entry, exit, etc.)
gc_markers = {
    'DZ': ['CXCR4', 'BACH2', 'PCNA', 'MKI67', 'CDK1', 'CDC20', 'FOXP1', 'AICDA', 'MYC',
           'EZH2', 'E2F1', 'FOXO1', 'BCL6'],
    'LZ': ['CD83', 'SERPINA9', 'CAMK1', 'MYC', 'RGS13', 'CD44', 'CD38', 'LMO2', 'EBI3',
           'HLA-DQB2', 'TRAF4', 'PLEK', 'IER2', 'NFKBIA', 'BCAR3', 'DUSP2', 'SNX11', 'PLPP5',
           'PHACTR1', 'TAP1', 'RAB3GAP2', 'DHRS9', 'FCRL5'],
    're-entry': ['SLA', 'FCRL2', 'CFLAR', 'FOXP1'],
    'bcr_activation': ['BTK', 'BLK', 'BLNK'],
    'TFh_INF_help': ['CD40', 'TRAF1', 'ICAM1', 'NFKB1', 'NFKB2', 'REL', 'RELB'],
    'pre-mem': ['BANK1', 'CCR6', 'CELF2', 'IFITM1', 'IFITM2', 'IFNGR1', 'GPR183', 'CD69',
                'TNFRSF13B', 'SELL', 'MYC', 'FXYD5', 'STAT1'],
    'exit': ['MEF2B', 'RGS13', 'S1PR2'],
    'LZ_plasmablasts': ['PAX5', 'CD27', 'TNFSF13', 'CD9', 'PRDM1', 'XBP1', 'MZB1', 'TNFRSF17', 'FKBP11'],
    'DZ/LZ': ['NFKBIA', 'BCAR3', 'DUSP2', 'SNX11', 'PLPP5', 'PHACTR1', 'TAP1', 'PCNA', 'MKI67',
              'CDK1', 'CDC20', 'CD72', 'PTPN6', 'IFNGR1', 'CAMK1', 'CD22'],
    'DN': ['RAB3GAP2', 'DHRS9', 'FCRL5', 'SLAMF7', 'CD22', 'PDCD1', 'TBX21', 'ZEB2', 'CD19', 'IL12A'],
    'misc': ['IGHM', 'IGHD', 'IGHE', 'IGHA1', 'CCR2', 'RAG1', 'RAG2'],
}

In [None]:
# Dot plot of B-subset markers per leiden_r2.5 cluster (with cell totals), saved as PNG
b_dotplot = sc.pl.DotPlot(adata,
                          groupby='leiden_r2.5',
                          var_names=b_markers,
                          mean_only_expressed=True,
                          cmap='viridis')
b_dotplot.add_totals()
b_dotplot.savefig(f'{plots_path}/ctAnnotation/thyAgeing_bSplit_{object_version}_bMarkers_dotplot.png',
                  dpi=300, bbox_inches='tight')

In [None]:
# Dot plot of germinal-centre programme markers per leiden_r2.5 cluster (with totals), saved as PNG
gc_dotplot = sc.pl.DotPlot(adata,
                           groupby='leiden_r2.5',
                           var_names=gc_markers,
                           mean_only_expressed=True,
                           cmap='viridis')
gc_dotplot.add_totals()
gc_dotplot.savefig(f'{plots_path}/ctAnnotation/thyAgeing_bSplit_{object_version}_gcMarkers_dotplot.png',
                   dpi=300, bbox_inches='tight')

In [None]:
# Manual label -> cluster assignments (presumably leiden_r2.5 cluster ids — confirm)
b_dict = dict(
    B_GC=[17],
    B_GC_plasma=[34],
    B_GC_IFN=[36],
    B_plasma=[3, 12, 41],
)

In [None]:
# Donor composition (with age group) of clusters 17, 36 and 34.
# Compare cluster labels as strings so the filter works whether leiden_r2.5 holds
# int or str categories (an integer isin on string labels silently matches nothing).
gc_clusters = ['17', '36', '34']
in_gc = adata.obs['leiden_r2.5'].astype(str).isin(gc_clusters).values
donor_counts = adata[in_gc].obs['donor'].value_counts().to_frame().reset_index()
donor_counts.merge(adata.obs[['donor', 'age_group']].drop_duplicates(), on='donor')

In [None]:
# UMAP coloured by leiden_r2.5 clusters, saved as PNG
fig = sc.pl.umap(adata, color='leiden_r2.5', return_fig=True, show=False)
fig.savefig(f'{plots_path}/preprocessing/scvi/thyAgeing_bSplit_{object_version}_leidenClusters.png',
            dpi=300, bbox_inches='tight')

In [None]:
# Wilcoxon ranking (top 50 genes) for clusters 22, 28, 29 and 19; a string copy of
# leiden_r2.5 is used so the groups can be named as strings
adata.obs['leiden_group'] = adata.obs['leiden_r2.5'].astype(str)
clusters_of_interest = ['22', '28', '29', '19']
sc.tl.rank_genes_groups(adata, groupby='leiden_group', groups=clusters_of_interest,
                        method='wilcoxon', n_genes=50)

In [None]:
# Wilcoxon ranking (top 50 genes) for cluster 17. Stored under its own key so it does
# not overwrite the default 'rank_genes_groups' results for clusters 22/28/29/19,
# which the following cells query with sc.get.rank_genes_groups_df.
sc.tl.rank_genes_groups(adata, groupby='leiden_group', groups=['17'], method='wilcoxon',
                        n_genes=50, key_added='rank_genes_groups_17')
sc.get.rank_genes_groups_df(adata, group='17', key='rank_genes_groups_17').head(20)

In [None]:
# Top 20 ranked genes for cluster 19
sc.get.rank_genes_groups_df(adata, group='19', key='rank_genes_groups').iloc[:20]

In [None]:
# Top 20 ranked genes for cluster 22
sc.get.rank_genes_groups_df(adata, group='22', key='rank_genes_groups').iloc[:20]

- TCL1A: proto-oncogene in B-cell lymphomas; reported to be critical for tertiary lymphoid structure (TLS) formation in oral squamous cell carcinoma (OSCC).
- BTG1: acts as an immune gatekeeper that strictly limits B-cell fitness during antibody affinity maturation. Mutations affecting BTG1 disrupt this mechanism and convert germinal center B cells into supercompetitors that rapidly outstrip their normal counterparts.

In [None]:
# Top 20 ranked genes for cluster 28
sc.get.rank_genes_groups_df(adata, group='28', key='rank_genes_groups').iloc[:20]

In [None]:
# Top 20 ranked genes for cluster 29
sc.get.rank_genes_groups_df(adata, group='29', key='rank_genes_groups').iloc[:20]

In [None]:
# AIRE expression on the UMAP
sc.pl.umap(adata, color=['AIRE'])

In [None]:
# Study composition of clusters 2, 28 and 29. Labels are compared as strings so the
# filter works for int- or str-typed leiden categories.
# NOTE(review): nearby cells inspect clusters 22/28/29/19 — confirm whether '2' was meant to be '22'.
selected_clusters = ['2', '28', '29']
in_selected = adata.obs['leiden_r2.5'].astype(str).isin(selected_clusters).values
adata[in_selected].obs['study'].value_counts()

In [None]:
# Cast age in months to float (it is used as a UMAP colour below)
adata.obs['age_months'] = adata.obs.age_months.astype('float64')

In [None]:
# Studies present in the object
adata.obs.study.unique()

In [None]:
# UMAP coloured by donor/sample covariates
covariates = ['age_group', 'study', 'age_months']
sc.pl.umap(adata, color=covariates)

In [None]:
# UMAP coloured by QC metrics
qc_metrics = ['percent_mito', 'n_genes', 'percent_ribo', 'n_counts', 'scrublet_score']
sc.pl.umap(adata, color=qc_metrics, cmap='jet')

In [None]:
# Display the B-subset marker dictionary
b_markers

In [None]:
# UMAP of pan-B, naive and activated B-cell markers
core_b_genes = [*b_markers['B_cells'], *b_markers['B_naive'], *b_markers['B_activated']]
sc.pl.umap(adata, color=core_b_genes, cmap='viridis')

In [None]:
# UMAP of selected ribosomal-protein (RPL/RPS) and metallothionein (MT) genes
ribo_mt_genes = ["RPL28", 'RPL41', 'RPS15A', 'RPL18A', 'RPL12', 'MT2A', 'MT1E']
sc.pl.umap(adata, color=ribo_mt_genes)

## Celltypist immune low model

In [None]:
# Reload the v4 B-split object for CellTypist. It is a .zarr store, so read it with read_zarr.
object_version = 'v4_2024-11-06'
adata = ad.read_zarr(f'{data_path}/objects/rna/thyAgeing_bSplit_scvi_{object_version}.zarr')

# Attach leiden clusterings (replacing stale copies) as categoricals
leiden_clus = pd.read_csv(f'{data_path}/objects/rna/thyAgeing_bSplit_scvi_{object_version}_leidenClusters.csv', index_col=0)
adata.obs.drop(leiden_clus.columns, axis = 1, errors = 'ignore', inplace = True)
adata.obs = adata.obs.join(leiden_clus)
adata.obs[leiden_clus.columns] = adata.obs[leiden_clus.columns].astype('category')

In [None]:
# Log-normalise counts for CellTypist: scale each cell to 10,000 total counts, then log1p.
# normalize_total replaces the deprecated normalize_per_cell (and matches the marker
# normalisation earlier in the notebook). Unlike normalize_per_cell, it does not drop
# cells below a min_counts threshold. Only predictions are saved, so raw counts are not kept.
adata.X = adata.X.astype(float)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
import celltypist
immune_low_model = celltypist.models.Model.load(model = 'Immune_All_Low.pkl')

# Celltypist predictions: Immune low model, with majority voting
predictions = celltypist.annotate(adata, model=immune_low_model, majority_voting=True)
# Keep the AnnData returned by to_adata (with the prefixed prediction columns in .obs)
# rather than relying on it mutating `adata` in place; the next cell reads those columns.
adata = predictions.to_adata(prefix='celltypist_immune_low_')

In [None]:
# Collect CellTypist output columns and rename them to the project's naming scheme
celltypist_cols = [c for c in adata.obs.columns if c.startswith('celltypist_immune_low_')]
rename_map = {
    'celltypist_immune_low_predicted_labels': 'celltypist_pred_immune_low',
    'celltypist_immune_low_over_clustering': 'celltypist_mv_overclustering_pred_immune_low',
    'celltypist_immune_low_majority_voting': 'celltypist_mv_pred_immune_low',
    'celltypist_immune_low_conf_score': 'celltypist_prob_immune_low',
}
celltypist_predictions = adata.obs[celltypist_cols].rename(columns=rename_map)

celltypist_predictions

In [None]:
# Save the renamed CellTypist predictions
celltypist_out = f'{data_path}/objects/rna/thyAgeing_bSplit_scvi_{object_version}_celltypistImmuneLowAnnot.csv'
celltypist_predictions.to_csv(celltypist_out)

In [None]:
# UMAP of CellTypist majority-voting labels, raw predictions and confidence, saved as PNG
celltypist_panels = ['celltypist_immune_low_majority_voting',
                     'celltypist_immune_low_predicted_labels',
                     'celltypist_immune_low_conf_score']
fig = sc.pl.umap(adata, color=celltypist_panels, wspace=0.5, legend_loc='on data',
                 legend_fontsize=6, ncols=2, return_fig=True, show=False)
fig.savefig(f'{plots_path}/ctAnnotation/thyAgeing_bSplit_scvi_{object_version}_celltypistImmuneLow_umap.png',
            dpi=300, bbox_inches='tight')