# Thymus ageing atlas: Adding HTSA celltypist predictions

In [None]:
import os
import sys
import session_info
from datetime import datetime
today = datetime.today().strftime('%Y-%m-%d')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import anndata as ad
import hdf5plugin
import celltypist

# Add repo path to sys path (allows to access scripts and metadata from repo)
#repo_path,_ = os.path.split(os.path.split(os.getcwd())[0])
repo_path = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis'
sys.path.insert(1, repo_path) 
sys.path.insert(2, '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/scripts')

# Add R libs path
#os.environ['LD_LIBRARY_PATH'] = '' # Uncomment on jhub
#os.environ['R_HOME'] = '/nfs/team205/lm25/condaEnvs/thymusAgeing/lib/R' # Uncomment on jhub
os.environ['R_LIBS_USER'] = '/nfs/team205/lm25/condaEnvs/thymusAgeing/lib/R/library'

%load_ext rpy2.ipython
%load_ext autoreload
%autoreload 2

In [None]:
# Input/output locations, all under the repo root
data_path = f'{repo_path}/data'
general_data_path = data_path  # same folder; both names are used below
plots_path = f'{repo_path}/plots/preprocessing'

## Train celltypist model on previous annotations

In [None]:
# Load the v2 scVI object that carries the curated annotation used for training
adata = ad.read_h5ad(f'{general_data_path}/objects/rna/thyAgeing_all_scvi_v2_2024-06-20.zarr')

# Attach the curated cell type labels (every column whose name contains 'taa')
cell_labels = pd.read_csv(f'{general_data_path}/objects/rna/thyAgeing_all_scvi_v2_2024-08-22_curatedAnno_v4.csv', index_col=0)
taa_cols = [c for c in cell_labels.columns if 'taa' in c]
adata.obs = adata.obs.join(cell_labels[taa_cols], how='left')

# Keep only labelled cells. `.copy()` turns the boolean-mask view into a real AnnData, so the
# subsampling, obs assignment and X replacement in the next cells do not act on an implicit view.
adata = adata[adata.obs['taa_l0'].notna()].copy()

adata

In [None]:
taa_l0_counts = adata.obs['taa_l0'].value_counts(); taa_l0_counts  # cells per level-0 label

In [None]:
# Keep a random half of the labelled cells (in place)
sc.pp.subsample(adata, fraction=0.5, copy=False)


def _coarse_l0_label(l0, l1):
    """Return 'B' for B cells, 'T/innate' for other lymphoid cells, otherwise the level-0 label."""
    if l1 == 'B':
        return 'B'
    if l0 == 'Lymphoid':
        return 'T/innate'
    return l0


adata.obs['taa_l0_mod'] = [
    _coarse_l0_label(l0, l1) for l0, l1 in zip(adata.obs['taa_l0'], adata.obs['taa_l1'])
]

adata.obs['taa_l0_mod'].value_counts()

In [None]:
# Log-normalise counts for celltypist: scale every cell to 10,000 total counts (all genes), then log1p.
# normalize_total(target_sum=1e4) replaces scanpy's deprecated normalize_per_cell
# (which additionally dropped zero-count cells and wrote obs['n_counts']).
# No raw-counts layer is kept: only the model/predictions are saved from this notebook.
adata.X = adata.X.astype(float)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# Restrict to the top 3,000 highly variable genes; with feature_selection=False these are the model's features
sc.pp.highly_variable_genes(adata, n_top_genes=3000, n_bins=20, subset=True)

In [None]:
# Train a celltypist model on the curated level-1 labels
model_taa_l1 = celltypist.train(
    adata,
    labels='taa_l1',
    n_jobs=-1,                # all available cores
    check_expression=False,
    feature_selection=False,  # genes were already restricted to HVGs above
)

In [None]:
model_file = f'{repo_path}/models/celltypist_thyAgeing_all_scvi_v2_2024-06-20_taa_l1.pkl'
model_taa_l1.write(model_file)

## Integrate the whole dataset with scVI

In [None]:
# Load the filtered whole-dataset object
object_version = 'v3_2024-11-04'
rna_dir = f'{general_data_path}/objects/rna'
adata = ad.read_h5ad(f'{rna_dir}/thyAgeing_all_filtered_{object_version}.zarr')

# Attach curated taa_* labels; cells without a label are kept with NaN (nothing is removed here)
cell_labels = pd.read_csv(f'{rna_dir}/thyAgeing_all_scvi_v2_2024-08-22_curatedAnno_v4.csv', index_col=0)
label_cols = [col for col in cell_labels.columns if 'taa' in col]
adata.obs = adata.obs.join(cell_labels[label_cols], how='left')

In [None]:
# Version tag for the new integration, stamped with today's date
object_version = f'v3_{today}'

# Run scVI integration through the repo helper
from scvi_wrapper import run_scvi

scvi_params = {
    'layer_raw': 'X',
    # Gene exclusion
    'include_genes': [],
    'exclude_cc_genes': True,
    'exclude_mt_genes': True,
    'exclude_vdjgenes': True,
    'remove_cite': False,
    # Highly variable gene selection
    'batch_hv': 'study',
    'hvg': 10000,
    'span': 0.5,
    # scVI model and training
    'batch_scvi': 'sample',
    'cat_cov_scvi': ['donor', 'chemistry_simple', 'sex'],
    # Disabled: 'cont_cov_scvi': ['percent_mito', 'percent_ribo', 'n_genes'],
    'max_epochs': 50,
    'batch_size': 2000,
    # Disabled: 'early_stopping': True, 'early_stopping_patience': 15, 'early_stopping_min_delta': 10.0,
    'plan_kwargs': {'lr': 0.001, 'reduce_lr_on_plateau': True, 'lr_patience': 10, 'lr_threshold': 20},
    'n_layers': 3,
    'n_latent': 30,
    'dispersion': 'gene-batch',
    # Leiden clustering / annotation columns and figure output
    'leiden_clustering': None,
    'col_cell_type': ['taa_l5', 'taa_l1', 'taa_l0'],
    'fig_dir': f'{plots_path}/scvi',
    'fig_prefix': f'thyAgeing_all_scvi_{object_version}',
}
scvi_run = run_scvi(adata, **scvi_params)

In [None]:
# Save the integrated AnnData and the trained scVI model
model_path = f'{repo_path}/models'
overwrite = True
out_file = f'{data_path}/objects/rna/thyAgeing_all_scvi_{object_version}.zarr'
adata_scvi = scvi_run['data']

# Convert object-dtype obs columns to Python strings before writing.
# (The previous `.astype('|S')` produced bytes objects such as b'label', not str.)
for c in adata_scvi.obs.columns:
    if adata_scvi.obs[c].dtype == 'object':
        adata_scvi.obs[c] = adata_scvi.obs[c].astype(str)

# Prediction/probability, curated taa_* and leiden columns are not stored with the embedding
anno_cols = [c for c in adata_scvi.obs.columns if '_pred_' in c or '_prob_' in c or 'taa' in c or 'leiden' in c]
if overwrite or not os.path.exists(out_file):
    print('Saving new embeddings: {}'.format(object_version))
    adata_scvi.obs = adata_scvi.obs.drop(columns=anno_cols)
    adata_scvi.write_h5ad(
        out_file,
        compression=hdf5plugin.FILTERS["zstd"],
        compression_opts=hdf5plugin.Zstd(clevel=5).filter_options,
    )
    scvi_run['vae'].save(f'{model_path}/thyAgeing_all_scvi_{object_version}', save_anndata=False, overwrite=overwrite)
else:
    print('File already exists')

## Predict compartment-level cell type labels on whole dataset

In [None]:
# Load the integrated scVI object (version v3_2024-11-05)
object_version = 'v3_2024-11-05'
scvi_file = f'{general_data_path}/objects/rna/thyAgeing_all_scvi_{object_version}.zarr'
adata = ad.read_h5ad(scvi_file)

In [None]:
# Load the celltypist model used for prediction.
# Alternative (the taa_l1 model trained in this notebook):
#   celltypist.Model.load(f'{repo_path}/models/celltypist_thyAgeing_all_scvi_v2_2024-06-20_taa_l1.pkl')
model_taa_l1 = celltypist.Model.load(f'{repo_path}/models/Celltypist_mod_thyAgeing_all_scvi_v2_mrk_sel_2024-11-03.pkl')

# Do NOT subset adata to the model's features before normalisation. Celltypist expects log1p
# expression normalised to 10,000 counts per cell over all genes (as in the training cell above).
# Subsetting first would rescale each cell using only the model's genes.
# celltypist.annotate matches query genes to the model features itself, so the full object is kept.
n_model_features = len(model_taa_l1.features)
n_shared_features = int(np.isin(model_taa_l1.features, adata.var_names).sum())
print(f'{n_shared_features} of {n_model_features} model features found in adata')

adata.shape

In [None]:
# Log-normalise counts for celltypist: scale every cell to 10,000 total counts, then log1p.
# normalize_total(target_sum=1e4) replaces scanpy's deprecated normalize_per_cell
# (which additionally dropped zero-count cells and wrote obs['n_counts']).
adata.X = adata.X.astype(float)
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

In [None]:
# Predict taa_l1 labels, refined by majority voting within over-clusters (min_prop=0.15),
# and write the per-cell results into adata.obs with the 'celltypist_taa_l1_' prefix
predictions = celltypist.annotate(
    adata,
    model=model_taa_l1,
    majority_voting=True,
    min_prop=0.15,
)
adata = predictions.to_adata(prefix='celltypist_taa_l1_')

In [None]:
# Display the AnnData summary; the celltypist_taa_l1_* columns added by predictions.to_adata should be listed in obs
adata

In [None]:
# Keep the celltypist outputs, renamed to the *_taa_l1 convention used in the saved CSV
celltypist_rename = {
    'celltypist_taa_l1_predicted_labels': 'celltypist_pred_taa_l1',
    'celltypist_taa_l1_over_clustering': 'celltypist_over_clustering_taa_l1',
    'celltypist_taa_l1_majority_voting': 'celltypist_mv_pred_taa_l1',
    'celltypist_taa_l1_conf_score': 'celltypist_prob_taa_l1',
}
celltypist_predictions = adata.obs[list(celltypist_rename)].rename(columns=celltypist_rename)
celltypist_predictions

In [None]:
predictions_csv = f'{general_data_path}/objects/rna/thyAgeing_all_scvi_{object_version}_celltypist_taa_l1.csv'
celltypist_predictions.to_csv(predictions_csv)

In [None]:
# One histogram of the celltypist confidence score per predicted taa_l1 label
# (4 panels per row, each panel with its own axes)
grid = sns.FacetGrid(
    celltypist_predictions,
    col='celltypist_pred_taa_l1',
    col_wrap=4,
    sharex=False,
    sharey=False,
)
grid.map(plt.hist, 'celltypist_prob_taa_l1', bins=30, color='blue', edgecolor='black')
plt.show()

## Sanity-check celltypist annotations

In [None]:
# Load the integrated object and attach the saved celltypist predictions (left join on the obs index)
object_version = 'v3_2024-11-05'
rna_dir = f'{general_data_path}/objects/rna'
adata = ad.read_h5ad(f'{rna_dir}/thyAgeing_all_scvi_{object_version}.zarr')

celltypist_predictions = pd.read_csv(f'{rna_dir}/thyAgeing_all_scvi_{object_version}_celltypist_taa_l1.csv', index_col=0)
adata.obs = adata.obs.join(celltypist_predictions, how='left')

# Attach curated taa_* labels; cells without a curated label keep NaN (nothing is removed here)
cell_labels = pd.read_csv(f'{rna_dir}/thyAgeing_all_scvi_v2_2024-08-22_curatedAnno_v4.csv', index_col=0)
taa_cols = [col for col in cell_labels.columns if 'taa' in col]
adata.obs = adata.obs.join(cell_labels[taa_cols], how='left')

In [None]:
# Compare majority-voting celltypist predictions with the curated taa_l1 labels.
# Rows: curated labels; columns: predicted labels; each row sums to 1.
confusion_df = adata.obs[['celltypist_mv_pred_taa_l1', 'taa_l1']].dropna().astype(str)

# pd.crosstab keeps every predicted label. The previous sklearn call passed only the curated
# labels as `labels`, which silently dropped cells predicted as any other label.
conf_matrix_df = pd.crosstab(confusion_df['taa_l1'], confusion_df['celltypist_mv_pred_taa_l1'], normalize='index')
# Order columns like the rows so matching labels sit on the diagonal; predicted-only labels go last
col_order = list(conf_matrix_df.index) + [c for c in conf_matrix_df.columns if c not in conf_matrix_df.index]
conf_matrix_df = conf_matrix_df.reindex(columns=col_order, fill_value=0)

# Save BEFORE plt.show(): the Jupyter inline backend closes the figure on show,
# so saving afterwards writes an empty image
fig_dir = f'{plots_path}/ctAnnotation'
os.makedirs(fig_dir, exist_ok=True)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix_df, annot=True, fmt='.2f', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted Labels MV')
ax.set_ylabel('Labels before')
ax.set_title('Confusion Matrix')
fig.savefig(f'{fig_dir}/celltypist_confusion_matrix_mv_taa_l1.png', bbox_inches='tight')
plt.show()

In [None]:
# Compare per-cell (no majority voting) celltypist predictions with the curated taa_l1 labels.
# Rows: curated labels; columns: predicted labels; each row sums to 1.
confusion_df = adata.obs[['celltypist_pred_taa_l1', 'taa_l1']].dropna().astype(str)

# pd.crosstab keeps every predicted label. The previous sklearn call passed only the curated
# labels as `labels`, which silently dropped cells predicted as any other label.
conf_matrix_df = pd.crosstab(confusion_df['taa_l1'], confusion_df['celltypist_pred_taa_l1'], normalize='index')
# Order columns like the rows so matching labels sit on the diagonal; predicted-only labels go last
col_order = list(conf_matrix_df.index) + [c for c in conf_matrix_df.columns if c not in conf_matrix_df.index]
conf_matrix_df = conf_matrix_df.reindex(columns=col_order, fill_value=0)

# Save BEFORE plt.show(): the Jupyter inline backend closes the figure on show,
# so saving afterwards writes an empty image
fig_dir = f'{plots_path}/ctAnnotation'
os.makedirs(fig_dir, exist_ok=True)
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix_df, annot=True, fmt='.2f', cmap='Blues', ax=ax)
ax.set_xlabel('Predicted Labels')  # these are per-cell predictions, not majority voting
ax.set_ylabel('Labels before')
ax.set_title('Confusion Matrix')
fig.savefig(f'{fig_dir}/celltypist_confusion_matrix_taa_l1.png', bbox_inches='tight')
plt.show()

In [None]:
# show=False keeps scanpy from calling plt.show() (which closes the inline figure),
# so the figure still exists when it is saved
fig_dir = f'{plots_path}/ctAnnotation'
os.makedirs(fig_dir, exist_ok=True)
sc.pl.umap(
    adata,
    color=['celltypist_pred_taa_l1', 'celltypist_mv_pred_taa_l1', 'celltypist_prob_taa_l1'],
    ncols=2,
    wspace=0.5,
    show=False,
)
plt.savefig(f'{fig_dir}/thyAgeing_all_scvi_{object_version}_celltypist_taa_l1.png', bbox_inches='tight')
plt.show()

## Save compartment splits

In [None]:
# Load the integrated object that will be split into compartments
object_version = 'v3_2024-11-05'
rna_dir = f'{general_data_path}/objects/rna'
adata = ad.read_h5ad(f'{rna_dir}/thyAgeing_all_scvi_{object_version}.zarr')

# Attach celltypist predictions (left join on the obs index)
predictions_csv = f'{rna_dir}/thyAgeing_all_scvi_{object_version}_celltypist_taa_l1.csv'
celltypist_predictions = pd.read_csv(predictions_csv, index_col=0)
adata.obs = adata.obs.join(celltypist_predictions, how='left')

# Attach curated taa_* labels; unlabelled cells keep NaN (nothing is removed here)
cell_labels = pd.read_csv(f'{rna_dir}/thyAgeing_all_scvi_v2_2024-08-22_curatedAnno_v4.csv', index_col=0)
adata.obs = adata.obs.join(cell_labels[[col for col in cell_labels.columns if 'taa' in col]], how='left')

In [None]:
mv_labels = adata.obs['celltypist_mv_pred_taa_l1'].unique(); np.array(mv_labels)  # distinct majority-voting labels

In [None]:
# taa_l1 labels assigned to each compartment split
compartment_dict = {
    'T_NK': ['T', 'NK', 'ILC'],
    'B': ['B'],
    'Myeloid': ['Myeloid_dev', 'Neutrophil', 'DC', 'Mac', 'Mono', 'Mast'],
    'FB_Vasc': ['Fb', 'RBC', 'Pericyte', 'SMC', 'Mesothelium', 'EC'],
    'TEC': ['TEC'],
}

# Predicted labels not covered by any compartment (an empty result means full coverage)
assigned_labels = np.concatenate([np.array(labels) for labels in compartment_dict.values()])
np.setdiff1d(np.array(adata.obs['celltypist_mv_pred_taa_l1'].unique()), assigned_labels)

In [None]:
# Write one AnnData per compartment, selected by the majority-voting celltypist label;
# celltypist columns are dropped from each split before saving
celltypist_cols = [c for c in adata.obs.columns if 'celltypist' in c]
for compartment, compartment_labels in compartment_dict.items():
    in_compartment = adata.obs['celltypist_mv_pred_taa_l1'].isin(compartment_labels)
    adata_sub = adata[in_compartment].copy()
    adata_sub.obs = adata_sub.obs.drop(columns=celltypist_cols)

    out_file = f'{general_data_path}/compartmentSplits/thyAgeing_{compartment.lower()}Split_scvi_{object_version}.zarr'
    adata_sub.write_h5ad(
        out_file,
        compression=hdf5plugin.FILTERS["zstd"],
        compression_opts=hdf5plugin.Zstd(clevel=5).filter_options,
    )