# Thymus ageing atlas: abTCR VDJ pseudotime

In [None]:
import os
import sys
import session_info
from datetime import datetime
today = datetime.today().strftime('%Y-%m-%d')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import anndata as ad
import hdf5plugin
import dandelion as ddl

import warnings
warnings.filterwarnings("ignore", category=ad.ImplicitModificationWarning)

# Add repo path to sys path (allows to access scripts and metadata from repo)
# The repository root is hard-coded; the commented line below is the former
# variant that derived it from the current working directory.
#repo_path,_ = os.path.split(os.path.split(os.getcwd())[0])
repo_path = '/nfs/team205/lm25/thymus_projects/thymus_ageing_atlas/T_NK_compartment'
# Insert after sys.path[0] so the custom imports further down (utils, anno_levels, plotting.utils)
# can be resolved from these two directories
sys.path.insert(1, repo_path) 
sys.path.insert(2, '/nfs/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/scripts')

# Autoreload custom scripts
%load_ext autoreload
%autoreload 2

# Define paths
# NOTE: plots_path and data_path end with '/' here, but both are re-assigned without
# the trailing slash in the second "Define paths" cell before they are used for I/O.
plots_path = f'{repo_path}/plots/'
data_path = f'{repo_path}/data/'
model_path = os.path.join(repo_path, 'models')
general_data_path = '/nfs/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/data'

print('Dir for plots: {}'.format(plots_path))
print('Dir for data: {}'.format(data_path))

# Formatting
# Register the Arial font file with matplotlib and apply the project style sheet
from matplotlib import font_manager
font_manager.fontManager.addfont("/nfs/team205/ny1/ThymusSpatialAtlas/software/Arial.ttf")
plt.style.use('/nfs/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/scripts/plotting/thyAgeing.mplstyle')

# Import custom scripts
from utils import get_latest_version,update_obs,freq_by_donor
from anno_levels import get_ct_levels, get_ct_palette, age_group_levels, age_group_palette, t_nk_groupings
from plotting.utils import plot_grouped_boxplot, calc_figsize, thyAgeing_colors, get_tint_palette, get_chroma_palette, create_blend_palette

In [None]:
from pertpy.tools import Milo
import palantir

# Milo instance; used further down to build the neighbourhoods (milo.make_nhoods)
milo = Milo()

#required because of Palantir
%matplotlib inline

# Scanpy figure defaults at 80 dpi
sc.settings.set_figure_params(dpi=80)

# Restore matplotlib's built-in rcParams.
# NOTE(review): this presumably also discards the thyAgeing.mplstyle settings applied in the
# setup cell and the scanpy defaults set just above — confirm this is intended.
import matplotlib
matplotlib.rcdefaults()

In [None]:
# Define paths (no trailing slash: later cells append '/<subdir>/...' themselves)
general_data_path = '/nfs/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/data'
plots_path = repo_path + '/plots'
data_path = repo_path + '/data'
model_path = os.path.join(repo_path, 'models')

# Report the two output locations
for label, location in (('plots', plots_path), ('data', data_path)):
    print('Dir for {}: {}'.format(label, location))

## Load and prepare data

In [None]:
# Load adata
object_version = 'v5_2025-04-03'
adata = ad.read_h5ad(f'{general_data_path}/objects/rna/thyAgeing_all_scvi_{object_version}.zarr')

# Add new annotations to adata: columns already present in obs are replaced by the curated ones
ct_anno = pd.read_csv(f'{general_data_path}/objects/rna/thyAgeing_all_scvi_v4_2025-02-04_curatedAnno_v10.csv', index_col = 0)
stale_cols = [c for c in ct_anno.columns if c in adata.obs.columns]
adata.obs = adata.obs.drop(columns = stale_cols).join(ct_anno)

# Filter data (only include annotated cells); cells without a taa_l5 label are dropped as well (na = True)
is_unannotated = adata.obs['taa_l5'].str.contains('locnt|-sp|explore', na = True)
adata = adata[~is_unannotated]

# Update metadata
latest_meta_path = get_latest_version(dir = f'{general_data_path}/metadata', file_prefix='Thymus_ageing_metadata')
latest_meta = pd.read_excel(latest_meta_path)
update_obs(adata, latest_meta, on = 'index', ignore_warning = True, add_cols = ['age_group3'])

# Add vdj data
meta_tcr = pd.read_csv(f'{data_path}/objects/rna/thyAgeing_tSplit_scvi_v9_2025-03-28_tcrab_v6.csv', index_col = 0)
adata.obs = adata.obs.join(meta_tcr)

In [None]:
# Specify cell type columns
col_cell_type_broad = 'taa_l3'
col_cell_type_fine = 'taa_l4'
# Annotation levels for the T and NK lineages, taken from the project helper get_ct_levels
col_cell_type_broad_levels = get_ct_levels(col_cell_type_broad, taa_l1 = ['T', 'NK'])
col_cell_type_fine_levels = get_ct_levels(col_cell_type_fine, taa_l1 = ['T', 'NK'])
# obs column holding the age grouping (requested via add_cols in the update_obs call when loading)
col_age_group = 'age_group3'
# Age group labels, listed from youngest to oldest
col_age_group_levels = ['infant', 'paed(early)', 'paed(mid)', 'paed(late)', 'adult(early)', 'adult(mid)', 'adult(late)']

Filter adata to only contain libraries which were also TCR-sequenced

In [None]:
# Remove any samples which were not TCR sequenced
# Step 1: per-sample proportion of every chain status
chain_props = adata.obs.groupby('sample')['chain_status'].apply(lambda x: x.value_counts(normalize = True))
sample_freq = chain_props.reset_index(name = 'prop').rename(columns = {'level_1' : 'chain_status'})

# Step 2: a sample is excluded when all of its cells are 'No_contig'
only_no_contig = (sample_freq['chain_status'] == 'No_contig') & (sample_freq['prop'] == 1)
exclude_samples = sample_freq.loc[only_no_contig, 'sample'].unique()
print(f'Excluding {len(exclude_samples)} samples with no TCR data')

# Step 3: keep the remaining samples
keep_cells = ~adata.obs['sample'].isin(exclude_samples)
adata = adata[keep_cells]

sample_freq

In [None]:
# Cells per TCR chain status after dropping the samples without TCR data
adata.obs['chain_status'].value_counts()

In [None]:
# Cells per fine annotation label, displayed as a one-column table
adata.obs[col_cell_type_fine].value_counts().to_frame()

In [None]:
# UMAP coloured by the fine annotation (the next cell reads the matching '<column>_colors' entry from adata.uns)
sc.pl.umap(adata, color = col_cell_type_fine)

In [None]:
# Set up subsets and colors
col_anno = col_cell_type_fine
# Developmental order of the abT lineage populations used for the pseudotime analysis
ct_order = ['T_DN(early)','T_DN(P)', 'T_DN(Q)','T_DN(late)', 'T_DP(P)', 'T_DP(Q)', 'T_αβT(entry)', 'T_CD4_naive', 'T_CD8_naive', 'T_CD4_naive_recirc', 'T_CD8_naive_recirc', #'T_CD8αα', 'T_γδT',
            ]

# The '<column>_colors' list in adata.uns is ordered like the categories of the categorical
# obs column, so pair the colours with those categories. Pairing with np.unique() (sorted,
# observed labels only) gives wrong colours when categories are reordered or unused.
# astype('category') leaves an existing categorical untouched and also covers a plain string column.
anno_categories = adata.obs[col_anno].astype('category').cat.categories.astype(str)
ct_color_df = pd.DataFrame.from_dict(dict(zip(anno_categories, adata.uns[col_anno + '_colors'])), orient='index', columns=['color'])
# Palette restricted to the populations of interest (raises KeyError if one has no colour)
ct_color_map = ct_color_df.loc[ct_order]['color'].to_dict()

# Keep only the populations of interest; copy so that later cells modify a real object, not a view
adata = adata[adata.obs[col_anno].isin(ct_order)].copy()

In [None]:
# Cast the main V/J call columns to plain strings: the pseudobulk setup later inserts new
# values such as 'v_call_abT_VDJ_main_missing', which a fixed set of categories would not accept
vdj_call_cols = ['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main','v_call_abT_VJ_main', 'j_call_abT_VJ_main']
adata.obs = adata.obs.astype(dict.fromkeys(vdj_call_cols, 'str'))

In [None]:
# Number of cells whose TRB V call is the literal string 'nan'
# (presumably missing values turned into text by the str cast in the previous cell)
(adata.obs['v_call_abT_VDJ_main'] == 'nan').sum()

In [None]:
# Prepare the object for abTCR V(D)J pseudobulking with dandelion:
# - cells are restricted to the populations in ct_order (matched against the col_anno column)
# - no filtering on productive contigs, and cells with missing calls are kept (remove_missing=False)
# - filter_pattern lists the call values treated as unusable
#   (NOTE(review): exact matching semantics are defined by dandelion — confirm against its docs)
bdata = ddl.tl.setup_vdj_pseudobulk(adata, 
                                    mode='abT',
                                    subsetby=col_anno, 
                                    groups = ct_order, 
                                    productive_vdj=False, 
                                    productive_vj=False,
                                    remove_missing=False,
                                    filter_pattern = ",|None|No_contig|nan",
                                    allowed_chain_status=['Single pair','Extra pair','Orphan VDJ','Orphan VJ'] #so remove 'No_contig' and 'ambiguous' and 'Extra pair-exception'
                                   )

In [None]:
# Chain status distribution of the cells returned by setup_vdj_pseudobulk
bdata.obs['chain_status'].value_counts()

In [None]:
# Remove cells which have chains present in less than 5 cells
min_cells_per_gene = 5
for chain in ['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main','v_call_abT_VJ_main', 'j_call_abT_VJ_main']:
    # Counts are recomputed for every column because bdata shrinks as rare genes are removed
    gene_counts = bdata.obs[chain].value_counts()
    rare_counts = gene_counts[gene_counts < min_cells_per_gene]
    genes_to_remove = rare_counts.index.tolist()
    # Test the list itself: any() would inspect the truthiness of each gene name instead
    if genes_to_remove:
        n_cells_removed = rare_counts.sum()
        print(f'Removing {len(genes_to_remove)} genes from {chain}: {genes_to_remove}. (cells removed: n = {n_cells_removed})')
        bdata = bdata[~bdata.obs[chain].isin(genes_to_remove)]

In [None]:
# check only the right contig in each chain i.e. TRAJ in j_call_abT_VJ_main
# The '<column>_missing' placeholder does not carry the locus prefix but is a valid value,
# so it is excluded before deciding whether anything has to be reported or removed.
for x, y in zip(['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main','v_call_abT_VJ_main', 'j_call_abT_VJ_main'], ['TRBV','TRBJ','TRAV','TRAJ']):

    observed_calls = pd.Series(bdata.obs[x].unique())
    placeholder = f'{x}_missing'
    wrong_assignments = [
        call for call in observed_calls[~observed_calls.str.startswith(y)]
        if call != placeholder
    ]

    # Only report and filter when a genuinely wrong call is left (avoids printing an empty list)
    if wrong_assignments:
        print(f'Incorrect assignments in {x}: {wrong_assignments}')
        bdata = bdata[~bdata.obs[x].isin(wrong_assignments)]
    

In [None]:
# Cells per age group after VDJ filtering
# NOTE(review): this reads 'age_group', whereas the subsetting below uses col_age_group ('age_group3') — confirm which is intended
bdata.obs['age_group'].value_counts()

## Create nhood adata

In [None]:
# Analysis age groups (key) and the age_group3 levels each of them pools (value)
age_group_dict = {'infant' : ['infant'],
                  'paed(early)' : ['paed(early)'],
                  'paed(mid-late)' : ['paed(mid)', 'paed(late)'],
                  'adult' : ['adult(early)', 'adult(mid)', 'adult(late)'],
                  }
# Age group processed in this run; change the key and re-run the cells below for each group
age_group_key = 'infant'
selected_levels = age_group_dict[age_group_key]
in_age_group = bdata.obs[col_age_group].isin(selected_levels)
bdata_sub = bdata[in_age_group]

bdata_sub

In [None]:
# Remove mature T cells without paired TCR
# (orphan chains are only tolerated in the developing populations)
mature_populations = ['T_CD4_naive','T_CD8_naive', 'T_CD4_naive_recirc','T_CD8_naive_recirc']
is_orphan = bdata_sub.obs['chain_status'].isin(['Orphan VDJ','Orphan VJ']).to_numpy()
is_mature = bdata_sub.obs[col_anno].isin(mature_populations).to_numpy()
remove_unpaired = is_orphan & is_mature
print(f'Removing {remove_unpaired.sum()} cells w/o paired TCR')
bdata_sub = bdata_sub[~remove_unpaired]

In [None]:
# Construct nhood graph
# kNN graph with 50 neighbours on the representation stored in obsm['X_scVI'],
# followed by a UMAP of the subset for visual inspection
sc.pp.neighbors(bdata_sub, use_rep="X_scVI", n_neighbors=50)
sc.tl.umap(bdata_sub)
sc.pl.umap(bdata_sub, color = col_anno,ncols=1, legend_loc='on data', legend_fontsize=6)

In [None]:
# Same UMAP coloured by annotation, donor, age group and sex to check how the covariates are distributed
sc.pl.umap(bdata_sub, color = [col_anno, 'donor', 'age_group', 'sex'],ncols=2, legend_fontsize=6, wspace = 0.5)

In [None]:
# Create nhoods
milo.make_nhoods(bdata_sub)

In [None]:
# Cast object-dtype obs columns to byte strings before writing.
# NOTE(review): astype('|S') presumably fails on non-ASCII text — confirm that no object
# column carries labels such as 'T_αβT(entry)'.
for c in bdata_sub.obs.columns:
    if bdata_sub.obs[c].dtypes == 'object':
        bdata_sub.obs[c] = bdata_sub.obs[c].astype('|S')

# Save the per-age-group cell object (zstd-compressed). The file is written with write_h5ad
# even though its name ends in '.zarr'; the pseudotime projection cell reads it back with ad.read_h5ad.
bdata_sub.write_h5ad(
        f'{data_path}/analyses/vdj_pseudotime/thyAgeing_subAdata_{age_group_key}.zarr',
        compression=hdf5plugin.FILTERS["zstd"],
        compression_opts=hdf5plugin.Zstd(clevel=5).filter_options,
    )

### VDJ pseudobulk

In [None]:
# Pseudobulk the V/J gene usage per neighbourhood: pbs is the Milo membership matrix
# (presumably cells x neighbourhoods — it is summed over axis 0 to get neighbourhood sizes below).
# The annotation, donor and sex columns are carried over to the neighbourhood level.
nhood_adata = ddl.tl.vdj_pseudobulk(
    bdata_sub, pbs=bdata_sub.obsm["nhoods"], obs_to_take=[col_anno, 'donor', 'sex'], 
    renormalise=True, min_count=10, 
    extract_cols=['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main','v_call_abT_VJ_main', 'j_call_abT_VJ_main'],
)

In [None]:
# histogram to look at neighbourhood sizes
plt.rcParams["figure.figsize"] = [4,4]
# Column sums of the membership matrix = number of cells in each neighbourhood
nhood_sizes = np.array(bdata_sub.obsm["nhoods"].sum(0)).flatten()
plt.hist(nhood_sizes, bins=50)
plt.title('neighborhood sizes')
plt.xlabel('number of cells in the neighborhood')
# plt.hist draws raw counts unless density=True is passed, so label the axis as a count
plt.ylabel('number of neighborhoods')

In [None]:
# replace NAs with 0 (only when the usage matrix actually contains NaN)
contains_nan = np.isnan(nhood_adata.X).any()
if not contains_nan:
    print('No NAs')
else:
    print('Replacing NAs')
    nhood_adata.X = np.nan_to_num(nhood_adata.X, copy=True, nan=0.0)

In [None]:
# PCA of the neighbourhood-level V/J usage matrix (fixed seed), coloured with the cell type palette
sc.pp.pca(nhood_adata, random_state = 1712)
sc.pl.pca(nhood_adata, color=col_anno, palette=ct_color_map)

In [None]:
# Find nhoods in vdj nhood space
# Neighbour graph and UMAP of the neighbourhoods, both with a fixed seed
sc.pp.neighbors(nhood_adata, random_state = 1712)
sc.tl.umap(nhood_adata, random_state = 1712)

In [None]:
# UMAP of the V(D)J neighbourhoods coloured by annotation, donor, cell count and the
# donor / annotation fraction columns, saved per age group.
# plt.savefig writes the current pyplot figure (presumably the one just returned by sc.pl.umap).
sc.pl.umap(nhood_adata, color=[col_anno,'donor', 'cell_count', 'donor_fraction', f'{col_anno}_fraction'], ncols = 2, return_fig = True, wspace = 0.5)
plt.savefig(f'{plots_path}/vdj/pseudotime/thyAgeing_vdjNhoods_covariates_{age_group_key}_umap.png', dpi=300, bbox_inches='tight')

In [None]:
# Inspect donor metadata
bdata_sub.obs.groupby(['donor', 'age_group', 'age_months', 'sex', 'study'], observed = True).size()

In [None]:
# UMAP coloured by the '<column>_missing' features that exist in the neighbourhood object
cols=['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main','v_call_abT_VJ_main', 'j_call_abT_VJ_main']
missing_features = []
for col in cols:
    feature = col + '_missing'
    if feature in nhood_adata.var_names:
        missing_features.append(feature)
sc.pl.umap(nhood_adata, color=missing_features, color_map = 'RdYlBu_r', return_fig = True)
plt.savefig(f'{plots_path}/vdj/pseudotime/thyAgeing_vdjNhoods_missingFreq_{age_group_key}_umap.png', dpi=300, bbox_inches='tight')

In [None]:
# Save vdj nhood adata
nhood_adata.write_h5ad(f'{data_path}/analyses/vdj_pseudotime/thyAgeing_vdjNhoods_{age_group_key}.h5ad')

### GEX pseudobulk

In [None]:
from scipy.sparse import csr_matrix 

# Make X raw counts
# NOTE(review): the assignment below is commented out, so X is pseudobulked as it currently is — confirm it holds the intended values
# bdata_sub.X = bdata_sub.X.copy()

# Make GEX pseudobulk
# Same neighbourhoods and carried-over obs columns as the V(D)J pseudobulk; X is stored as a CSR sparse matrix
nhood_adata_gex = ddl.tl.pseudobulk_gex(bdata_sub, pbs=bdata_sub.obsm["nhoods"], obs_to_take=[col_anno, 'donor', 'sex'])
nhood_adata_gex.X = csr_matrix(nhood_adata_gex.X)
nhood_adata_gex

In [None]:
# Save gex nhood adata
nhood_adata_gex.write_h5ad(f'{data_path}/analyses/vdj_pseudotime/thyAgeing_gexNhoods_{age_group_key}.h5ad')

### Run pseudotime analysis

In [None]:
# Define dictionary for age groups
age_group_dict = {'infant' : ['infant'],
                  'paed(early)' : ['paed(early)'],
                  'paed(mid-late)' : ['paed(mid)', 'paed(late)'],
                  'adult' : ['adult(early)', 'adult(mid)', 'adult(late)'],
                  }

# Concatenate vdj nhood adatas
# Read the per-age-group neighbourhood objects, then stack them; the dict key of each object
# is recorded in the obs column named by col_age_group
nhood_adata_by_age = {}
for k in age_group_dict:
    nhood_adata_by_age[k] = ad.read_h5ad(f'{data_path}/analyses/vdj_pseudotime/thyAgeing_vdjNhoods_{k}.h5ad')
nhood_adata = ad.concat(nhood_adata_by_age, join='outer', label=col_age_group)

In [None]:
# Neighbourhoods per age group (age_group3 is the label column created by ad.concat, i.e. col_age_group)
nhood_adata.obs.age_group3.value_counts()    

In [None]:
# Make obs names unique (only index anyway)
# The original names are kept in 'orig_index' because they are restored per age group
# when the pseudotime is projected back to cells
nhood_adata.obs['orig_index'] = nhood_adata.obs_names
nhood_adata.obs_names_make_unique()

In [None]:
# Order annotations: categorical with the categories arranged as in ct_order
anno_as_category = nhood_adata.obs[col_anno].astype('category')
nhood_adata.obs[col_anno] = anno_as_category.cat.reorder_categories(ct_order)

In [None]:
# Replace NAs with 0
# (the outer join in ad.concat presumably leaves NaN for features absent from some age groups)
nhood_adata.X = np.nan_to_num(nhood_adata.X, copy=True, nan=0.0)

In [None]:
# PCA of the combined neighbourhood V/J usage (fixed seed); obsm['X_pca'] feeds the diffusion maps below
sc.pp.pca(nhood_adata, random_state = 1712)
sc.pl.pca(nhood_adata, color=col_anno, palette=ct_color_map)

In [None]:
# Neighbour graph and UMAP of all neighbourhoods; the UMAP coordinates are later used to pick the root and terminal neighbourhoods
# NOTE(review): seed 42 here vs 1712 for the PCA above — harmless, but confirm it is deliberate
sc.pp.neighbors(nhood_adata, random_state = 42)
sc.tl.umap(nhood_adata, random_state = 42)
sc.pl.umap(nhood_adata, color=col_anno, palette=ct_color_map)

In [None]:
# Plot donor and age group as categoricals
categorical_covariates = ['donor', col_age_group]
nhood_adata.obs = nhood_adata.obs.astype({c: 'category' for c in categorical_covariates})

# Covariate overview on the combined neighbourhood UMAP
covariate_panels = ['donor','cell_count', 'donor_fraction', f'{col_anno}_fraction', col_age_group]
sc.pl.umap(nhood_adata, color=covariate_panels, ncols = 2, return_fig = True, wspace = 0.5)
plt.savefig(f'{plots_path}/vdj/pseudotime/thyAgeing_vdjNhoods_covariates_allAges_umap.png', dpi=300, bbox_inches='tight')

In [None]:
# Run diffusion maps
# Palantir is given the PCA coordinates as a DataFrame indexed by the (unique) neighbourhood names
pca_projections = pd.DataFrame(nhood_adata.obsm['X_pca'], index=nhood_adata.obs_names)
dm_res = palantir.utils.run_diffusion_maps(pca_projections, n_components=20)

In [None]:
# Eigenvalues of the diffusion components, plotted to choose how many to keep.
# The x axis follows the number of eigenvalues returned, so it stays valid if n_components changes.
eigenvalues = dm_res['EigenValues']
plt.scatter(np.arange(len(eigenvalues)), eigenvalues)

In [None]:
# Multiscale diffusion space used as input for Palantir
# (n_eigs=17 was presumably chosen from the eigenvalue plot above — confirm when the data change)
ms_data = palantir.utils.determine_multiscale_space(dm_res, n_eigs=17)

In [None]:
# UMAP of all neighbourhoods coloured by the '<column>_missing' features that are present
cols=['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main','v_call_abT_VJ_main', 'j_call_abT_VJ_main']
candidate_features = [col + '_missing' for col in cols]
present_features = [name for name in candidate_features if name in nhood_adata.var_names]
sc.pl.umap(nhood_adata, color=present_features, color_map = 'RdYlBu_r', return_fig = True)
plt.savefig(f'{plots_path}/vdj/pseudotime/thyAgeing_vdjNhoods_missingFreq_allAges_umap.png', dpi=300, bbox_inches='tight')

In [None]:
# Neighbourhoods per fine annotation label (the root and terminal populations must be present)
nhood_adata.obs[col_cell_type_fine].value_counts()

In [None]:
def calculate_centroid(adata, cell_population, embedding='X_umap', groupby=None):
    """
    Return the name of the observation closest to the centroid of a cell population.

    The centroid is the mean position of all observations labelled `cell_population`
    in the chosen embedding; closeness is measured by Euclidean distance.

    Parameters:
    adata (anndata.AnnData): The annotated data matrix.
    cell_population (str): The label of the population to calculate the centroid for.
    embedding (str): Key in `adata.obsm` with the coordinates to use. Default is 'X_umap'.
    groupby (str, optional): Column in `adata.obs` holding the population labels.
        Defaults to the notebook-level `col_cell_type_fine`.

    Returns:
    str: The obs name (barcode) of the observation closest to the centroid.

    Raises:
    ValueError: If no observation is labelled `cell_population`.
    """
    if groupby is None:
        # Backward-compatible default: the annotation column configured at the top of the notebook
        groupby = col_cell_type_fine

    # Boolean mask of the population; indexing the arrays directly avoids subsetting the whole AnnData
    in_population = np.asarray(adata.obs[groupby] == cell_population)
    if not in_population.any():
        raise ValueError(f"No observations with {groupby} == {cell_population!r}")

    coords = np.asarray(adata.obsm[embedding])[in_population]

    # Calculate the centroid
    centroid = coords.mean(axis=0)

    # Find the observation closest to the centroid
    distances = np.linalg.norm(coords - centroid, axis=1)
    closest_idx = np.argmin(distances)

    return adata.obs_names[in_population][closest_idx]

In [None]:
# select the start and end points
# start: the T_DN(early) neighbourhood with the largest second UMAP coordinate
dn_early_nhoods = nhood_adata[nhood_adata.obs[col_anno] == 'T_DN(early)']
top_position = dn_early_nhoods.obsm['X_umap'][:, 1].argmax()
rootcell = dn_early_nhoods.obs_names[top_position]

# Indicator column for plotting: 1 for the root neighbourhood, 0 elsewhere
nhood_adata.obs['rootcell'] = 0
nhood_adata.obs.loc[rootcell, 'rootcell'] = 1

In [None]:
# ends
# Terminal neighbourhoods are the ones closest to the UMAP centroid of each mature population.
# (An earlier variant picked UMAP extremes instead; its results were always overwritten by
# the centroid-based choice, so that dead computation was dropped.)
endcell1 = calculate_centroid(nhood_adata, 'T_CD8_naive')
endcell2 = calculate_centroid(nhood_adata, 'T_CD4_naive')

# Map the terminal neighbourhood name (index) to its population label (value)
terminal_states = pd.Series(['T_CD8_naive','T_CD4_naive'],index=[endcell1,endcell2])

In [None]:
# plot rootcell and terminal states
# Indicator column: 1 for the two terminal neighbourhoods, 0 elsewhere
nhood_adata.obs['terminal_states'] = 0
nhood_adata.obs.loc[terminal_states.index, 'terminal_states'] = 1
plt.rcParams["figure.figsize"] = [5,5]
# Three panels: root indicator, terminal indicator and the neighbourhood annotation
sc.pl.umap(nhood_adata,color=['rootcell','terminal_states',col_anno],
           title=['root cell','terminal states','nhood annotation'],color_map='OrRd',size=50)

In [None]:
# Run Palantir on the multiscale space, starting from the root neighbourhood, with 500 waypoints
# and the two terminal neighbourhoods passed by name
pr_res = palantir.core.run_palantir(ms_data,  rootcell, num_waypoints=500, 
                                    terminal_states = terminal_states.index)

In [None]:
# Rename the branch probability columns from terminal neighbourhood names to their population labels
pr_res.branch_probs.columns = terminal_states[pr_res.branch_probs.columns]

In [None]:
# Copy the Palantir results into nhood_adata.obs; the cells below read them as
# 'pseudotime_nhood_vdj' and 'prob_<terminal state>_nhood_vdj'
ddl.tl.pseudotime_transfer(adata = nhood_adata, pr_res = pr_res, suffix = '_nhood_vdj')

In [None]:
plt.rcParams["figure.figsize"] = [4,4]
# Columns to plot (without suffix); the order must match the titles passed below
plot = ['pseudotime', 'prob_T_CD8_naive', 'prob_T_CD4_naive']
#plot = ['pseudotime', 'prob_CD4+T']
# UMAP panels for pseudotime and both branch probabilities, saved for all age groups combined
sc.pl.umap(nhood_adata,color=[term + '_nhood_vdj' for term in plot],
           title=['pseudotime','branch probability to CD8+T','branch probability to CD4+T'],
           frameon=False,wspace=0.1,
           color_map = 'RdYlBu_r',
           return_fig=True,
           show = False,
          )
plt.savefig(f'{plots_path}/vdj/pseudotime/thyAgeing_vdjNhoods_branchProb_allAges_umap.png', dpi=300, bbox_inches='tight')

In [None]:
# Pseudotime distribution per annotated population (categories follow ct_order)
plt.rcParams["figure.figsize"] = [10,4]
sc.pl.violin(nhood_adata, keys = ['pseudotime_nhood_vdj'],groupby=col_anno)

In [None]:
# Pseudotime vs. CD4 branch probability for every neighbourhood, coloured by annotation
sns.set_theme(style='white')
fig, ax = plt.subplots(figsize=(10,2))

sns.scatterplot(data=nhood_adata.obs,
                x='pseudotime_nhood_vdj',
                y='prob_T_CD4_naive_nhood_vdj',
                s=8,
                hue=col_anno,
                palette=ct_color_map,
                ax=ax)
ax.set_ylabel('probability to CD4+T')
ax.set_xlabel('pseudotime')
ax.set_title('')

# Show every population in the legend: a fixed cut-off of 8 entries dropped the last
# populations of ct_order (among them the T_CD8_naive terminal state)
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles, labels, loc='upper right', bbox_to_anchor=(1.2, 1), frameon=False, fontsize='small')

In [None]:
# Save new nhood vdj
# Combined neighbourhood object including the pseudotime and branch probability columns;
# it is reloaded at the start of the cell-level projection below
nhood_adata.write_h5ad(f'{data_path}/analyses/vdj_pseudotime/thyAgeing_vdjNhoods_allAges_complete.h5ad')

Transfer pseudotime to cells

In [None]:
# Load nhood adata
# Reloading lets the projection below run without recomputing the pseudotime
nhood_adata = ad.read_h5ad(f'{data_path}/analyses/vdj_pseudotime/thyAgeing_vdjNhoods_allAges_complete.h5ad')

In [None]:
# Define age groups (same grouping as used when the neighbourhoods were built; only the keys are used here)
age_group_dict = {'infant' : ['infant'],
                  'paed(early)' : ['paed(early)'],
                  'paed(mid-late)' : ['paed(mid)', 'paed(late)'],
                  'adult' : ['adult(early)', 'adult(mid)', 'adult(late)'],
                  }

# Columns transferred from the neighbourhood level to the cells
pseudotime_cols = ['pseudotime_nhood_vdj', 'prob_T_CD8_naive_nhood_vdj', 'prob_T_CD4_naive_nhood_vdj']

# Initialise pseudotime df
pseudotime_df = {}
for age_group_key in age_group_dict:

    # Load subsets of data
    bdata_sub = ad.read_h5ad(f'{data_path}/analyses/vdj_pseudotime/thyAgeing_subAdata_{age_group_key}.zarr')

    # Neighbourhoods of this age group; copy explicitly because the object is modified below
    nhood_adata_sub = nhood_adata[nhood_adata.obs[col_age_group] == age_group_key].copy()
    # Attach the neighbourhood x cell membership matrix of this age group
    nhood_adata_sub.obsm['pbs'] = bdata_sub.obsm['nhoods'].T
    # Restore the per-age-group neighbourhood names that were made unique after concatenation
    nhood_adata_sub.obs_names = nhood_adata_sub.obs['orig_index']

    # project the nhood level pseudotime to cell level pseudotime.
    cdata_sub = ddl.tl.project_pseudotime_to_cell(adata = bdata_sub,
                                pb_adata = nhood_adata_sub,
                                term_states=['T_CD8_naive','T_CD4_naive'],
                                suffix = '_nhood_vdj')

    # Add to pseudotime df
    pseudotime_df[age_group_key] = cdata_sub.obs[pseudotime_cols]

# Make df from dict: stack the per-age-group tables and name the cell index 'names'
# (independent of how the obs index happens to be named in the loaded objects)
pseudotime_df = pd.concat(list(pseudotime_df.values()), axis=0).rename_axis('names')

pseudotime_df

In [None]:
# Save pseudotime assignments
# Written twice: once with the analysis outputs and once next to the objects, tagged with object_version
pseudotime_df.to_csv(f'{data_path}/analyses/vdj_pseudotime/thyAgeing_vdjPseudotime_allAges.csv')
pseudotime_df.to_csv(f'{data_path}/objects/thyAgeing_tSplit_scvi_{object_version}_vdjPseudotime.csv')

In [None]:
# Record the package versions of this session for reproducibility
session_info.show()