# Thymus ageing atlas – T/NK compartment: kNN transfer of TAA cell labels

In [None]:
import os
import sys
import session_info

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import anndata as ad
import hdf5plugin
from sklearn.metrics import f1_score

# Make repo-local code importable: the T/NK compartment repo (scripts and metadata)
# goes to sys.path position 1 and the shared General_analysis scripts to position 2.
repo_path = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/T_NK_compartment'
for insert_pos, extra_path in enumerate(
    [
        repo_path,
        '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/scripts',
    ],
    start=1,
):
    sys.path.insert(insert_pos, extra_path)

%load_ext autoreload
%autoreload 2

from annotate_ct import get_kNN_predictions

In [None]:
# Define input/output locations (plots_path and data_path keep their trailing slash)
plots_path = f'{repo_path}/plots/'
data_path = f'{repo_path}/data/'
model_path = os.path.join(repo_path, 'models')
general_data_path = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/data'

for dir_label, dir_value in (('plots', plots_path), ('data', data_path)):
    print(f'Dir for {dir_label}: {dir_value}')

## TAA-based labels

In [None]:
# Load the T/NK split object and attach Leiden clusters, CellTypist taa_l1 predictions
# and the previous curated TAA annotations to its .obs.
object_version = 'v8_2024-11-07'
adata_file = f'{data_path}/objects/rna/thyAgeing_tSplit_scvi_{object_version}.zarr'
# A zarr store is a directory and has to be read with read_zarr; read_h5ad only reads
# HDF5 (.h5ad) files. Fall back to read_h5ad for a single-file object.
adata = ad.read_zarr(adata_file) if os.path.isdir(adata_file) else ad.read_h5ad(adata_file)

# NOTE(review): the CellTypist section of this notebook reads this CSV from
# `{data_path}/objects/rna/`; confirm which location is current.
leiden_clus = pd.read_csv(f'{data_path}/thyAgeing_tSplit_scvi_{object_version}_leidenClusters.csv', index_col=0)
# Drop stale copies of the columns about to be joined. errors='ignore' makes a partial
# overlap safe (dropping every leiden column when only some exist raises KeyError).
adata.obs = adata.obs.drop(columns=leiden_clus.columns, errors='ignore').join(leiden_clus)
adata.obs[leiden_clus.columns] = adata.obs[leiden_clus.columns].astype('category')

# Add celltypist predictions to adata (DataFrame.join raises on overlapping column
# names, so any existing copies are dropped first)
celltypist_predictions = pd.read_csv(f'{general_data_path}/objects/rna/thyAgeing_all_scvi_v3_2024-11-05_celltypist_taa_l1.csv', index_col=0)
adata.obs = adata.obs.drop(columns=celltypist_predictions.columns, errors='ignore').join(celltypist_predictions, how='left')

# Add previous TAA annotations
ct_labels = pd.read_csv(f'{general_data_path}/objects/rna/thyAgeing_all_scvi_v2_2024-06-20_curatedAnno_v6.csv', index_col=0, dtype='category')
adata.obs = adata.obs.drop(columns=ct_labels.columns, errors='ignore').join(ct_labels)

adata

In [None]:
# UMAP of the full object coloured by the curated `taa_l5` annotation
sc.pl.umap(adata, color = 'taa_l5')

In [None]:
# Number of cells per taa_l5 label
adata.obs.taa_l5.value_counts()

### Create reference dataset

In [None]:
# Number of cells per taa_l1 label (used below to pick the reference lineages)
adata.obs.taa_l1.value_counts()

#### Select T, ILC and NK cells as reference

In [None]:
# Reference = cells whose taa_l1 label is T, ILC or NK (an AnnData view on adata)
ref_mask = adata.obs['taa_l1'].isin(['T', 'ILC', 'NK'])
adata_ref = adata[ref_mask]
adata_ref

In [None]:
# Identify reference taa_l5 labels with fewer than 5 cells; their cells are removed
# from the reference in the next cell.
# taa_l5 is categorical, so value_counts also lists unused categories with count 0.
# Those have no cells to remove and are left out of ct_remove and the log.
min_cells_per_label = 5
ref_label_counts = adata_ref.obs['taa_l5'].value_counts()
is_rare = (ref_label_counts > 0) & (ref_label_counts < min_cells_per_label)
ct_remove = ref_label_counts[is_rare].index.tolist()
print(f'Removing cell types with n < {min_cells_per_label}: {ct_remove}')
ref_label_counts.to_frame()

In [None]:
# Remove cells with rare taa_l5 labels (ct_remove) from the reference.
# .copy() materialises the subset: knn_pred_* columns are written to adata_ref.obs
# later, and modifying a view only works via an implicit copy with a warning.
adata_ref = adata_ref[~adata_ref.obs['taa_l5'].isin(ct_remove)].copy()

In [None]:
# Sanity check of the reference: number of cells contributed by each study
adata_ref.obs.study.value_counts()

In [None]:
# Reference UMAP coloured by taa_l5, saved to the ctAnnotation plot folder
ref_umap_file = f'{plots_path}/ctAnnotation/thyAgeing_tSplit_scvi_{object_version}_htsaKnnRef_umap.png'
p = sc.pl.umap(
    adata_ref,
    color='taa_l5',
    ncols=1,
    legend_fontsize=4,
    save=False,
    return_fig=True,
    show=False,
)
plt.savefig(ref_umap_file, dpi=300, bbox_inches='tight')

### Create query dataset

In [None]:
# Query = every cell of adata not in the (filtered) reference. That is, cells without
# a T/ILC/NK taa_l1 label plus the rare-label cells removed from the reference.
# .copy() makes it a real AnnData, because .obs is replaced just below.
adata_query = adata[~adata.obs_names.isin(adata_ref.obs_names)].copy()

# Remove old annotations so query cells only carry kNN-transferred labels
adata_query.obs = adata_query.obs.drop(columns=ct_labels.columns, errors='ignore')

adata_query

In [None]:
# Sanity check: reference and query together account for every cell (expect True)
adata.n_obs == adata_ref.n_obs + adata_query.n_obs

### Predict labels

#### Determine optimal k

In [None]:
# Fraction of all cells that ended up in the reference
adata_ref.n_obs / adata.n_obs

In [None]:
# Split the reference 70/30: the 30% held-out part acts as a labelled query for choosing k
adata_ref_ref = sc.pp.subsample(adata_ref, fraction=0.7, copy=True)
in_ref_ref = adata_ref.obs_names.isin(adata_ref_ref.obs_names)
adata_ref_query = adata_ref[~in_ref_ref].copy()

print(f'Number of cells in reference adata: {adata_ref_ref.shape[0]}')
print(f'Number of cells in query adata: {adata_ref_query.shape[0]}')

In [None]:
# Evaluate which k to use for predictions: transfer taa_l5 from the 70% split to the
# held-out 30% and score against their curated labels with weighted F1.
# (f1_score is already imported in the first cell.)
out_res = {}
f1_scores = {}

k_vals = [10, 20, 30, 40, 50, 75, 100]
for k in k_vals:
    k_key = str(k)
    # Transfer labels
    out_res[k_key] = get_kNN_predictions(adata_ref_ref, adata_ref_query, "X_scVI", k, "taa_l5")
    # f1_score compares its inputs by position, so align predictions to the held-out
    # cells by barcode first. Assumes the prediction table is indexed by query
    # obs_names — TODO confirm in annotate_ct.get_kNN_predictions.
    pred_l5 = out_res[k_key][0]['taa_l5'].reindex(adata_ref_query.obs_names)
    # Calculating F1 score
    f1_scores[k_key] = f1_score(adata_ref_query.obs['taa_l5'], pred_l5, average='weighted')

In [None]:
# Inspect weighted F1 per k (dict keys are k as strings); on inspection k = 30 was
# judged best and is the value used for the final prediction.
f1_scores

#### Predict query labels with the chosen k

In [None]:
# Final transfer: predict taa_l5 for every query cell from the full reference,
# using the k chosen from the F1 comparison above
knn_k = 30
labels, uncert = get_kNN_predictions(adata_ref, adata_query, 'X_scVI', knn_k, 'taa_l5')

In [None]:
# Add labels and uncertainties to adata.
# Query cells get knn_pred_<level> (predicted label) and knn_prob_<level> (1 - uncertainty).
# Reference cells get their curated label copied to knn_pred_<level>, so both can later
# be concatenated into one annotation table.
import re
pattern = r'(taa_.+)'

for label_col in labels.columns:
    match = re.search(pattern, label_col)
    if match is None:
        # Not a TAA annotation column; nothing to transfer
        continue
    level = match.group(1)

    # Add annotations for query. Look up by the actual column name; assumes `uncert`
    # uses the same column names as `labels` — TODO confirm in annotate_ct.
    adata_query.obs[f'knn_pred_{level}'] = labels[label_col]
    adata_query.obs[f'knn_prob_{level}'] = 1 - uncert[label_col]

    # Add annotations for reference
    adata_ref.obs[f'knn_pred_{level}'] = adata_ref.obs[level]

In [None]:
# Inspect query UMAP: kNN-predicted taa_l5 label and its probability.
# File name uses the T/NK split tag (tSplit), matching the object and the other plots.
query_umap_file = f'{plots_path}/ctAnnotation/thyAgeing_tSplit_scvi_{object_version}_htsaKnnQuery_umap.png'
p = sc.pl.umap(adata_query, color=['knn_pred_taa_l5', 'knn_prob_taa_l5'], ncols=1, legend_loc="on data", legend_fontsize=4,
               save=False, return_fig=True, show=False)
plt.savefig(query_umap_file, dpi=300, bbox_inches='tight')

In [None]:
# Add matched levels.
# Combine the kNN annotations of query cells (predicted) and reference cells (curated
# label copied to knn_pred_*; they have no knn_prob_*, so it becomes NaN). Then attach
# the other annotation levels listed for each taa_l5 label in the curated levels table,
# and save one row per cell barcode.
query_knn = adata_query.obs.loc[:, adata_query.obs.columns.str.startswith('knn_')]
ref_knn = adata_ref.obs.loc[:, adata_ref.obs.columns.str.startswith('knn_')]
all_annot = pd.concat((query_knn, ref_knn))

matched_anno_levels = pd.read_excel(f'{general_data_path}/curated/thyAgeing_full_curatedAnno_v2_2024-08-15_levels.xlsx', index_col=0, dtype='category').reset_index(drop=True)
matched_anno_levels.columns = [f'knn_pred_{c}' if 'taa' in c else c for c in matched_anno_levels.columns]

# DataFrame.merge on a column returns a new RangeIndex, which would drop the cell
# barcodes from the saved table. A left, many-to-one merge yields exactly one row per
# cell in the original order, so the barcodes can be restored. validate= raises if a
# taa_l5 label appears more than once in the levels table. Cells whose label is missing
# from that table are kept, with NaN levels, instead of being silently dropped.
cell_barcodes = all_annot.index
all_annot = all_annot.merge(matched_anno_levels, on='knn_pred_taa_l5', how='left', validate='many_to_one')
all_annot.index = cell_barcodes

# Reformat columns (all label levels -> category, probabilities -> float)
cols_cat = all_annot.columns[all_annot.columns.str.startswith("knn_pred")]
cols_num = all_annot.columns[all_annot.columns.str.startswith("knn_prob")]
all_annot[cols_cat] = all_annot[cols_cat].astype('category')
all_annot[cols_num] = all_annot[cols_num].astype(float)

# Save annotations (index = cell barcodes)
all_annot.to_csv(f'{data_path}/objects/thyAgeing_tSplit_scvi_{object_version}_taaKnnAnnot.csv')

all_annot.head()

## Celltypist immune-low model

In [None]:
# Reload the T/NK split object (fresh, unnormalised X) and attach its Leiden clusters
object_version = 'v8_2024-11-07'
adata_file = f'{data_path}/objects/rna/thyAgeing_tSplit_scvi_{object_version}.zarr'
# A zarr store is a directory and has to be read with read_zarr; read_h5ad only reads
# HDF5 (.h5ad) files. Fall back to read_h5ad for a single-file object.
adata = ad.read_zarr(adata_file) if os.path.isdir(adata_file) else ad.read_h5ad(adata_file)

leiden_clus = pd.read_csv(f'{data_path}/objects/rna/thyAgeing_tSplit_scvi_{object_version}_leidenClusters.csv', index_col=0)
# Drop stale copies of the columns about to be joined; errors='ignore' makes a partial
# overlap safe (dropping every leiden column when only some exist raises KeyError).
adata.obs = adata.obs.drop(columns=leiden_clus.columns, errors='ignore').join(leiden_clus)
adata.obs[leiden_clus.columns] = adata.obs[leiden_clus.columns].astype('category')

In [None]:
# Log-normalise counts for celltypist: scale each cell to 10,000 total counts, then log1p.
# X is overwritten in place because only the predictions from this object are saved.
# normalize_total replaces the deprecated normalize_per_cell (same scaling, but it does
# not drop cells with fewer than min_counts counts).
adata.X = adata.X.astype(float)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
import celltypist

# Celltypist predictions with the Immune_All_Low model (majority voting enabled).
# to_adata writes the prefixed prediction columns that the next cell reads from adata.obs.
immune_low_model = celltypist.models.Model.load(model='Immune_All_Low.pkl')
predictions = celltypist.annotate(
    adata,
    model=immune_low_model,
    majority_voting=True,
)
predictions.to_adata(prefix='celltypist_immune_low_')

In [None]:
# Collect the Immune_All_Low outputs under this project's column naming scheme
immune_low_renames = {
    'celltypist_immune_low_predicted_labels': 'celltypist_pred_immune_low',
    'celltypist_immune_low_over_clustering': 'celltypist_mv_overclustering_pred_immune_low',
    'celltypist_immune_low_majority_voting': 'celltypist_mv_pred_immune_low',
    'celltypist_immune_low_conf_score': 'celltypist_prob_immune_low',
}
is_immune_low_col = adata.obs.columns.str.startswith('celltypist_immune_low_')
celltypist_predictions = adata.obs[adata.obs.columns[is_immune_low_col]].rename(columns=immune_low_renames)

celltypist_predictions

In [None]:
# Persist only the per-cell CellTypist predictions (index = cell barcodes)
immune_low_csv = f'{data_path}/objects/rna/thyAgeing_tSplit_scvi_{object_version}_celltypistImmuneLowAnnot.csv'
celltypist_predictions.to_csv(immune_low_csv)

In [None]:
# UMAP of CellTypist majority-voting labels and confidence scores, saved as PNG
immune_low_umap_file = f'{plots_path}/ctAnnotation/thyAgeing_tSplit_scvi_{object_version}_celltypistImmuneLow_umap.png'
sc.pl.umap(
    adata,
    color=['celltypist_immune_low_majority_voting', 'celltypist_immune_low_conf_score'],
    wspace=0.5,
    legend_fontsize=6,
    ncols=1,
    return_fig=True,
)
plt.savefig(immune_low_umap_file, dpi=300, bbox_inches='tight')

In [None]:
# Report the session's package versions for reproducibility
session_info.show()