In [None]:
import sys 
import os
from datetime import datetime
today = datetime.today().strftime('%Y-%m-%d')

import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
import hdf5plugin
import matplotlib.pyplot as plt
import seaborn as sns

# Add repo path to sys path (allows to access scripts and metadata from repo)
repo_path = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/B_compartment'
sys.path.insert(1, repo_path) 
sys.path.insert(2, '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/scripts')

%reload_ext autoreload
%autoreload 2

# Define paths
plots_path = f'{repo_path}/plots/'
data_path = f'{repo_path}/data/'
model_path = os.path.join(repo_path, 'models')
general_data_path = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/data'

print('Dir for plots: {}'.format(plots_path))
print('Dir for data: {}'.format(data_path))

# Formatting
from matplotlib import font_manager
font_manager.fontManager.addfont("/nfs/team205/ny1/ThymusSpatialAtlas/software/Arial.ttf")
plt.style.use('/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/scripts/plotting/thyAgeing.mplstyle')

# Import custom scripts
from utils import get_latest_version,update_obs,freq_by_donor
from anno_levels import get_ct_levels, get_ct_palette, age_group_levels, age_group_palette
from plotting.utils import plot_grouped_boxplot, calc_figsize
from scvi_wrapper import run_scvi

## V1: Spatial integration only

In [None]:
# Raw Xenium B-cell object; use the raw counts layer as X.
# Despite its .zarr suffix the file is read with read_h5ad (i.e. as an HDF5 .h5ad file).
xenium_b_file = '/nfs/team205/vk8/projects/thymus_ageing_atlas/Spatial_analyses/data/xenium/adata_xenium_B_2025-03-20.zarr'
adata = ad.read_h5ad(xenium_b_file)
adata.X = adata.layers['counts']

adata

In [None]:
# Drop cells with fewer than 20 detected genes (modifies adata in place), then show the new size
sc.pp.filter_cells(adata, min_genes=20, inplace=True)
(adata.n_obs, adata.n_vars)

In [None]:
# Quick look at donor / sample-level metadata
# (note: the 'Study name ' column name really does end with a space)
donor_meta_cols = [
    'DonorID', 'Donor_type', 'Age_group', 'Age(misc)',
    'Age(numeric)', 'age_months', 'Source', 'Study', 'Study name ',
    'Study ID', 'published', 'Sex',
]
adata.obs.loc[:, donor_meta_cols].head()

In [None]:
# pandas is already imported as `pd` in the setup cell at the top of the notebook,
# so no mid-notebook re-import is needed here.

In [None]:
# Curated B-cell marker genes, grouped by the cell state / programme they are used to mark.
# Some genes (FOXO1, IRF4, PRDM1, CCL3) intentionally appear in more than one group.
b_markers = {'Other': ['IL6R', 'TNFRSF18', 'TNFSF13B'],
             'DZ': ['AICDA', 'CXCR4', 'MYC', 'MKI67', 'TOP2A', 'PCNA', 'BACH2', 'TCF3', 'PAX5', 'IRF4', 'MEF2B', 'FOXO1'],
             'LZ': ['CXCR5', 'CD83', 'CD86', 'MYBL1', 'SOCS3', 'CD40'],
             'T cell contact': ['CXCL10', 'CCL5', 'CCL3'],
             'GC_misc': ['ID2', 'ID3'],
             'EBV': ['CCL22', 'CCL17', 'EBI3', 'CCL3', 'ICAM1'],
             'Recruitment': ['CCR7'],
             'B_naive': ['CD22', 'SELL', 'IL4R', 'TCL1A', 'CR2', 'FOXO1', 'IGHM', 'IGHD'],
             'B_mem': ['CD27', 'CD38', 'FCRL4', 'FCRL5', 'CD44', 'PRDM1', 'IGHA1', 'IGHG1', 'IGHE'],
             'B_plasma': ['PRDM1', 'XBP1', 'MZB1'],
             'B_age-assoc': ['TBX21', 'ITGAX', 'IRF4'],
             'B_pan': ['CD19', 'MS4A1'],
             'B_med': ['HLA-DRA', 'HLA-DRB1', 'AIRE', 'IL15', 'LTA', 'LTB', 'PTPRC', 'CD5', 'SPN', 'CD80', 'LY6G6C'],
             'B_dev': ['IGLL1', 'MME', 'RAG1'],
             'B_dev_thy': ['CD34', 'VPREB1', 'TYROBP']}

# Long-format table (one row per group/gene pair) for use outside this notebook
b_markers_df = pd.DataFrame(
    [(label, gene) for label, genes in b_markers.items() for gene in genes],
    columns=['cell_label', 'gene_name'],
)

# Write under the configured `data_path` instead of a separately hard-coded absolute path.
# NOTE(review): the previous target was /nfs/team205/lm25/thymus_projects/thymus_ageing_atlas/
# B_compartment/data/curated/ — confirm /nfs/team205 resolves to the same storage as
# /lustre/scratch126/cellgen/team205 so downstream readers find the file.
b_markers_df.to_csv(os.path.join(data_path, 'curated', 'thyAgeing_bMarkers_detailled.csv'), index=False)

In [None]:
# `b_markers` was defined (with identical content) in the previous cell. It is reused
# here rather than pasted a second time, so the two copies cannot drift apart.

# Proliferation markers are left out of the forced-in gene list
# (the scVI run below is also called with exclude_cc_genes=True).
cc_genes_to_drop = {'PCNA', 'MKI67', 'TOP2A'}

# Unique marker genes present in the Xenium panel, in first-seen order.
# dict.fromkeys de-duplicates while keeping a reproducible order. A set of strings
# iterates in an order that changes between Python sessions because of hash randomisation.
b_genes = list(dict.fromkeys(
    gene
    for genes in b_markers.values()
    for gene in genes
    if gene in adata.var_names and gene not in cc_genes_to_drop
))

In [None]:
# Version tag for the V1 (Xenium-only) integration outputs
object_version = f'v1_{today}'

# Train scVI on the Xenium B cells alone (see scvi_wrapper.run_scvi for argument semantics)
scvi_run = run_scvi(
    adata,
    layer_raw='X',
    # --- gene selection / exclusion
    include_genes=b_genes,
    exclude_cc_genes=True,
    exclude_mt_genes=True,
    exclude_vdjgenes=True,
    remove_cite=False,
    # --- highly variable genes
    batch_hv="Age_group",
    span=0.5,
    hvg=300,
    # --- scVI batch key and categorical covariates
    batch_scvi="SampleID",
    cat_cov_scvi=["DonorID", "Sex"],
    # --- training schedule
    max_epochs=400,
    batch_size=2000,
    early_stopping=True,
    early_stopping_patience=20,
    early_stopping_min_delta=5.0,
    plan_kwargs={'lr': 0.001, 'reduce_lr_on_plateau': True, 'lr_patience': 10, 'lr_threshold': 20},
    # --- architecture
    n_layers=3,
    n_latent=30,
    dispersion='gene-batch',
    # --- clustering and diagnostic figures
    leiden_clustering=None,
    col_cell_type=['ctypist_taa_l4_predicted_labels', 'taa_l5'],
    fig_dir=f'{plots_path}/preprocessing/xenium',
    fig_prefix=f'thyAgeing_bSplitXenium_scvi_{object_version}',
)

In [None]:
# Persist the V1 scVI result: the AnnData (without transferred annotation columns) and the model.
overwrite = False

scvi_adata_file = f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}.zarr'
# Prediction / probability / 'taa' annotation columns are not stored with the object
anno_cols = [c for c in scvi_run['data'].obs.columns if '_pred_' in c or '_prob_' in c or 'taa' in c]

if overwrite or not os.path.exists(scvi_adata_file):
    # Create the output directory on a fresh checkout instead of failing inside write_h5ad
    os.makedirs(os.path.dirname(scvi_adata_file), exist_ok=True)
    scvi_run['data'].obs = scvi_run['data'].obs.drop(columns=anno_cols)
    # Written with write_h5ad (HDF5 + zstd via hdf5plugin) despite the .zarr suffix
    scvi_run['data'].write_h5ad(
        scvi_adata_file,
        compression=hdf5plugin.FILTERS["zstd"],
        compression_opts=hdf5plugin.Zstd(clevel=5).filter_options,
    )
    scvi_run['vae'].save(f'{model_path}/thyAgeing_bSplitXenium_scvi_{object_version}', save_anndata=False, overwrite=overwrite)
    print(f'Saving version {object_version} of scVI model and adata object')
else:
    print('File already exists')

### Leiden clustering

In [None]:
# Reload the saved V1 scVI object for clustering
scvi_v1_file = f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}.zarr'
adata = ad.read_h5ad(scvi_v1_file)

adata

In [None]:
# Leiden clustering at two resolutions (uses the neighbour graph already stored in adata)
res_list = [1.5, 2.5]
leiden_cols = [f'leiden_r{r}' for r in res_list]
for res, col in zip(res_list, leiden_cols):
    sc.tl.leiden(adata, resolution=res, key_added=col)
adata.obs[leiden_cols] = adata.obs[leiden_cols].astype('category')

# Save the labels so later cells can re-attach them to a freshly loaded object
adata.obs[leiden_cols].to_csv(f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}_leidenClusters.csv')

In [None]:
# UMAP coloured by each Leiden resolution. The Figure returned by scanpy is saved
# explicitly instead of relying on pyplot's implicit "current figure".
leiden_umap_file = f'{plots_path}/ctAnnotation/xenium_v1/thyAgeing_bSplitXenium_scvi_{object_version}_leidenClusters.png'
os.makedirs(os.path.dirname(leiden_umap_file), exist_ok=True)
fig = sc.pl.umap(adata, color=[f'leiden_r{r}' for r in res_list], ncols=2, return_fig=True)
fig.savefig(leiden_umap_file, dpi=300, bbox_inches='tight')

### Marker expression

In [None]:
# Fresh copy of the V1 scVI object with the saved Leiden labels attached
adata = ad.read_h5ad(f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}.zarr')

leiden_clusters = pd.read_csv(
    f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}_leidenClusters.csv',
    index_col=0,
)
adata.obs = adata.obs.join(leiden_clusters)
# The CSV round-trip yields integer labels; store them as categoricals built from ints
cluster_cols = list(leiden_clusters.columns)
adata.obs[cluster_cols] = adata.obs[cluster_cols].astype(int).astype('category')

# Normalise each cell to 10,000 total counts, then log1p (both in place) for marker plots
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

adata

In [None]:
# Marker panel for the dot plots: the curated B-cell markers defined earlier in this
# notebook plus T-cell markers (presumably to spot T-cell contamination in the clusters).
# Extending the existing dict instead of pasting a third copy keeps the panels in sync.
# Re-running this cell is idempotent: the 'T' key is simply overwritten.
b_markers = {**b_markers, 'T': ['CD8A', 'CD3E']}

In [None]:
# Dot plot of marker expression per Leiden (r=1.5) cluster.
# Genes absent from the panel are dropped. Groups left with no genes are skipped entirely,
# since an empty group would presumably give a degenerate var-group bracket/label.
b_markers_filtered = {
    group: [gene for gene in genes if gene in adata.var_names]
    for group, genes in b_markers.items()
}
b_markers_filtered = {group: genes for group, genes in b_markers_filtered.items() if genes}

sc.pl.DotPlot(
    adata,
    b_markers_filtered,
    groupby='leiden_r1.5',
    mean_only_expressed=True,
    cmap='magma',
).add_totals().savefig(
    f'{plots_path}/ctAnnotation/xenium_v1/thyAgeing_bSplitXenium_scvi_{object_version}_leiden_r1.5_dotplot.pdf',
    dpi=300,
    bbox_inches='tight',
)

## V2: Integration with suspension data

In [None]:
# V2 starts again from the raw Xenium B-cell object, using the counts layer as X.
# NOTE(review): unlike V1, sc.pp.filter_cells(adata, min_genes=20) is not applied before
# integration here — confirm that using the unfiltered Xenium cells is intended.
xenium_b_file = '/nfs/team205/vk8/projects/thymus_ageing_atlas/Spatial_analyses/data/xenium/adata_xenium_B_2025-03-20.zarr'
adata = ad.read_h5ad(xenium_b_file)
adata.X = adata.layers['counts']

adata

In [None]:
# Load the annotated suspension (RNA) B-cell object that is integrated with Xenium in V2.
# Its version gets a dedicated name so it does not overwrite `object_version`, which
# tracks the Xenium integration outputs everywhere else in this notebook.
susp_object_version = 'v4_2024-11-06'
adata_susp = ad.read_h5ad(f'{data_path}/objects/rna/thyAgeing_bSplit_scvi_{susp_object_version}.zarr')

# Attach the curated cell-type annotation, replacing any existing columns with the same names
ct_anno = pd.read_csv(f'{data_path}/preprocessing/ctAnnotation/thyAgeing_bSplitxTissue_scvi_v2_2025-02-20_v5.csv', index_col=0)
cols_overlapping = [col for col in ct_anno.columns if col in adata_susp.obs.columns]
adata_susp.obs = adata_susp.obs.drop(columns=cols_overlapping).join(ct_anno)

# Refresh donor / sample metadata from the most recent metadata spreadsheet
latest_meta_path = get_latest_version(dir=f'{general_data_path}/metadata', file_prefix='Thymus_ageing_metadata')
latest_meta = pd.read_excel(latest_meta_path)
update_obs(adata_susp, latest_meta, on='index', ignore_warning=True)

In [None]:
# Align Xenium obs column names with the suspension object's naming before merging
adata.obs = adata.obs.rename(columns={'DonorID': 'donor', 'SampleID': 'sample', 'Sex': 'sex'})

# Keep only genes measured in both modalities (np.intersect1d returns them sorted)
overlap_genes = np.intersect1d(adata.var_names, adata_susp.var_names).tolist()

# NOTE: AnnData.concatenate is deprecated upstream in favour of anndata.concat,
# which has different defaults (e.g. for var annotations) — port with care.
adata_concat = adata_susp[:, overlap_genes].concatenate(
    adata[:, overlap_genes],
    batch_key='Batch',
    batch_categories=['RNA', 'Xenium'],
    index_unique=None,
)

adata_concat

In [None]:
# Number of cells contributed by each modality (largest first)
adata_concat.obs['Batch'].value_counts(sort=True)

In [None]:
# Version tag for the V2 (RNA + Xenium) integration outputs
object_version = f'v2_{today}'

# Train scVI jointly on suspension and Xenium cells (see scvi_wrapper.run_scvi for semantics)
scvi_run = run_scvi(
    adata_concat,
    layer_raw='X',
    # --- gene selection / exclusion
    include_genes=b_genes,
    exclude_cc_genes=True,
    exclude_mt_genes=True,
    exclude_vdjgenes=True,
    remove_cite=False,
    # --- highly variable genes, computed per modality
    batch_hv="Batch",
    span=0.5,
    hvg=300,
    # --- scVI batch key and categorical covariates
    batch_scvi="sample",
    cat_cov_scvi=["donor", "sex", 'Batch'],
    # --- training schedule
    max_epochs=400,
    batch_size=2000,
    early_stopping=True,
    early_stopping_patience=20,
    early_stopping_min_delta=5.0,
    plan_kwargs={'lr': 0.001, 'reduce_lr_on_plateau': True, 'lr_patience': 10, 'lr_threshold': 20},
    # --- architecture
    n_layers=3,
    n_latent=30,
    dispersion='gene-batch',
    # --- clustering and diagnostic figures
    leiden_clustering=None,
    col_cell_type=['ctypist_taa_l4_predicted_labels', 'taa_l5'],
    fig_dir=f'{plots_path}/preprocessing/xenium',
    fig_prefix=f'thyAgeing_bSplitXenium_scvi_{object_version}',
)

In [None]:
# Persist the V2 (RNA + Xenium) scVI result: the AnnData (without transferred annotation
# columns) and the trained model.
# overwrite=True: re-running this cell replaces any existing object/model of this version.
overwrite = True

# Cast object-dtype obs columns to str, presumably so write_h5ad gets homogeneous string
# columns after concatenation. NOTE: missing values become literal strings (e.g. NaN -> 'nan').
for col in scvi_run['data'].obs.columns:
    if scvi_run['data'].obs[col].dtype == 'object':
        scvi_run['data'].obs[col] = scvi_run['data'].obs[col].astype('str')

scvi_adata_file = f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}.zarr'
# Prediction / probability / 'taa' annotation columns are not stored with the object
anno_cols = [c for c in scvi_run['data'].obs.columns if '_pred_' in c or '_prob_' in c or 'taa' in c]

if overwrite or not os.path.exists(scvi_adata_file):
    # Create the output directory on a fresh checkout instead of failing inside write_h5ad
    os.makedirs(os.path.dirname(scvi_adata_file), exist_ok=True)
    scvi_run['data'].obs = scvi_run['data'].obs.drop(columns=anno_cols)
    # Written with write_h5ad (HDF5 + zstd via hdf5plugin) despite the .zarr suffix
    scvi_run['data'].write_h5ad(
        scvi_adata_file,
        compression=hdf5plugin.FILTERS["zstd"],
        compression_opts=hdf5plugin.Zstd(clevel=5).filter_options,
    )
    scvi_run['vae'].save(f'{model_path}/thyAgeing_bSplitXenium_scvi_{object_version}', save_anndata=False, overwrite=overwrite)
    print(f'Saving version {object_version} of scVI model and adata object')
else:
    print('File already exists')

### Leiden clustering

In [None]:
# Pin the specific V2 run used for annotation (not today's version tag), presumably because
# the hard-coded cluster IDs in the annotation cells refer to this run.
object_version = 'v2_2025-03-26'
scvi_v2_file = f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}.zarr'
adata = ad.read_h5ad(scvi_v2_file)

adata

In [None]:
# Leiden clustering at two resolutions (uses the neighbour graph already stored in adata)
res_list = [1.5, 2.5]
leiden_cols = [f'leiden_r{r}' for r in res_list]
for res, col in zip(res_list, leiden_cols):
    sc.tl.leiden(adata, resolution=res, key_added=col)
adata.obs[leiden_cols] = adata.obs[leiden_cols].astype('category')

# Save the labels alongside the object for later reuse
adata.obs[leiden_cols].to_csv(f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}_leidenClusters.csv')

### Marker expression

In [None]:
# UMAP coloured by each Leiden resolution. The Figure returned by scanpy is saved
# explicitly instead of relying on pyplot's implicit "current figure".
leiden_umap_file = f'{plots_path}/ctAnnotation/xenium_v2/thyAgeing_bSplitXenium_scvi_{object_version}_leidenClusters.png'
os.makedirs(os.path.dirname(leiden_umap_file), exist_ok=True)
fig = sc.pl.umap(adata, color=[f'leiden_r{r}' for r in res_list], ncols=2, return_fig=True)
fig.savefig(leiden_umap_file, dpi=300, bbox_inches='tight')

In [None]:
# Normalise each cell to 10,000 total counts, then log1p-transform; both modify adata in place.
# NOTE(review): not idempotent — running this cell twice on the same object applies log1p twice.
sc.pp.normalize_total(adata, target_sum=1e4, inplace=True)
sc.pp.log1p(adata, copy=False)

In [None]:
# Dot plot of marker expression per Leiden (r=1.5) cluster of the integrated object.
# Genes absent from the panel are dropped. Groups left with no genes are skipped entirely,
# since an empty group would presumably give a degenerate var-group bracket/label.
b_markers_filtered = {
    group: [gene for gene in genes if gene in adata.var_names]
    for group, genes in b_markers.items()
}
b_markers_filtered = {group: genes for group, genes in b_markers_filtered.items() if genes}

sc.pl.DotPlot(
    adata,
    b_markers_filtered,
    groupby='leiden_r1.5',
    mean_only_expressed=True,
    cmap='magma',
).add_totals().savefig(
    f'{plots_path}/ctAnnotation/xenium_v2/thyAgeing_bSplitXenium_scvi_{object_version}_leiden_r1.5_dotplot.pdf',
    dpi=300,
    bbox_inches='tight',
)

### Cell annotations

In [None]:
# Manual mapping of Leiden (r=1.5) cluster IDs to B-cell labels.
# 'Remove' collects clusters to discard; labels with an empty list currently get no cluster.
cluster_assignment = {'B_mem': [2, 11, 8, 9, 10, 4, 7],
                      'B_naive': [1, 17],
                      'B_plasma': [0, 5, 13, 16],
                      'B_plasma_IgE': [19],
                      'B_plasmablast': [18],
                      'B_GC-like': [3, 12],
                      'B_dev': [],
                      'B_dev_thy': [],
                      'B_med': [],
                      'B_age-assoc': [],
                      'Remove': [6, 14, 15, 20]}

# Leiden labels are strings, so compare against stringified cluster IDs
assigned_clusters = [str(cluster) for clusters in cluster_assignment.values() for cluster in clusters]

# A cluster listed under two labels would silently end up with whichever label is applied last
assert len(assigned_clusters) == len(set(assigned_clusters)), 'A Leiden cluster is assigned to more than one label'

# Clusters not assigned to any label yet (empty once the annotation is complete)
np.array([c for c in adata.obs['leiden_r1.5'].unique().tolist() if c not in assigned_clusters])

In [None]:
# Genes in the panel whose name contains 'IGH' (immunoglobulin heavy-chain genes)
list(filter(lambda gene_name: 'IGH' in gene_name, adata.var_names))

In [None]:
# Provisional per-cell label from the cluster -> label mapping (cells in unmapped clusters stay NA)
adata.obs['temp_anno'] = pd.NA
# Leiden labels are strings; stringify the mapping's cluster IDs (a no-op when the cell is re-run)
cluster_assignment = {label: [str(c) for c in clusters] for label, clusters in cluster_assignment.items()}
for label, clusters in cluster_assignment.items():
    adata.obs.loc[adata.obs['leiden_r1.5'].isin(clusters), 'temp_anno'] = label

# Save the Figure scanpy returns rather than pyplot's implicit "current figure"
fig = sc.pl.umap(adata, color='temp_anno', return_fig=True)
fig.savefig(f'{plots_path}/ctAnnotation/xenium_v2/thyAgeing_bSplitXenium_scvi_{object_version}_tempAnno_umap.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Curated annotation hierarchy; presumably one row per finest-level (taa_l5) label
anno_levels_file = f'{general_data_path}/curated/thyAgeing_full_curatedAnno_v9_2025-03-03_levels.xlsx'
anno_levels = pd.read_excel(anno_levels_file)

anno_levels

In [None]:
# Expand each cell's provisional label (as taa_l5) into the full annotation hierarchy.
ct_anno = (
    adata.obs[['temp_anno']]
    .reset_index(names='names')
    .rename(columns={'temp_anno': 'taa_l5'})
    # validate='m:1' raises if `anno_levels` lists the same taa_l5 twice. Without it, such a
    # row would silently duplicate cells (and their index) in the annotation table.
    .merge(anno_levels, how='left', on='taa_l5', validate='m:1')
    .set_index('names')[anno_levels.columns]
)
# Cells from clusters marked 'Remove' get no final label
ct_anno.loc[ct_anno['taa_l5'] == 'Remove', 'taa_l5'] = pd.NA

ct_anno.to_csv(f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}_v2.csv')

ct_anno.head()

In [None]:
# Display the full per-cell annotation table (pandas truncates the rendered view)
ct_anno

In [None]:
# Re-attach the saved final annotation, replacing any same-named columns already in obs
ct_anno = pd.read_csv(f'{data_path}/objects/xenium/thyAgeing_bSplitXenium_scvi_{object_version}_v2.csv', index_col=0)
stale_cols = [col for col in ct_anno.columns if col in adata.obs.columns]
adata.obs = adata.obs.drop(columns=stale_cols).join(ct_anno)

In [None]:
# This `adata` was already normalised and log1p-transformed earlier in this section,
# and sc.pp.log1p records that in adata.uns['log1p']. Normalising again would apply
# log1p twice and distort the final dot plot, so only normalise if that has not happened.
if 'log1p' not in adata.uns:
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)

In [None]:
# Dot plot of marker expression per final label (taa_l5), Xenium cells only.
# Genes absent from the panel are dropped. Groups left with no genes are skipped entirely,
# since an empty group would presumably give a degenerate var-group bracket/label.
b_markers_filtered = {
    group: [gene for gene in genes if gene in adata.var_names]
    for group, genes in b_markers.items()
}
b_markers_filtered = {group: genes for group, genes in b_markers_filtered.items() if genes}

sc.pl.DotPlot(
    adata[adata.obs['Batch'] == 'Xenium', :],
    b_markers_filtered,
    groupby='taa_l5',
    figsize=calc_figsize(width=200, height=50),
    mean_only_expressed=True,
    cmap='magma',
).add_totals().style(smallest_dot=1, largest_dot=40).savefig(
    f'{plots_path}/ctAnnotation/xenium_v2/thyAgeing_bSplitXenium_scvi_{object_version}_finalAnno_dotplot.pdf',
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# UMAP of Xenium cells coloured by final label. The Figure returned by scanpy is saved
# explicitly instead of relying on pyplot's implicit "current figure".
with plt.rc_context({'figure.figsize': calc_figsize(width=70, height=40)}):
    fig = sc.pl.umap(adata[adata.obs['Batch'] == 'Xenium', :], color='taa_l5', return_fig=True, show=False)
    fig.savefig(f'{plots_path}/ctAnnotation/xenium_v2/thyAgeing_bSplitXenium_scvi_{object_version}_finalAnno_umap.png', dpi=300, bbox_inches='tight')