# Thymus ageing atlas: Integration and plotting of the final object

In [None]:
import os
import sys
import session_info
from datetime import datetime
# Date stamp (YYYY-MM-DD); used further down to version the integrated object
today = datetime.today().strftime('%Y-%m-%d')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import anndata as ad
# hdf5plugin supplies the zstd filter used when the integrated object is written
import hdf5plugin

import warnings
# Silence anndata's ImplicitModificationWarning for the whole session
warnings.filterwarnings('ignore', category=ad.ImplicitModificationWarning)

# Add repo path to sys path (allows access to scripts and metadata from the repo)
repo_path = '/nfs/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis'
sys.path.insert(1, repo_path)
# The scripts directory is derived from repo_path so the two entries cannot drift apart
sys.path.insert(2, f'{repo_path}/scripts')

# Autoreload custom scripts
%load_ext autoreload
%autoreload 2

# Define paths (all derived from repo_path so there is a single place to change)
# NOTE: plots_path and data_path keep their trailing slash; later cells build on them as-is
plots_path = f'{repo_path}/plots/'
data_path = f'{repo_path}/data/'
model_path = os.path.join(repo_path, 'models')
# Same directory as data_path, without the trailing slash
general_data_path = f'{repo_path}/data'

print(f'Dir for plots: {plots_path}')
print(f'Dir for data: {data_path}')

# Formatting: register the Arial font file with matplotlib, then apply the project style sheet
from matplotlib import font_manager
arial_font_file = "/nfs/team205/ny1/ThymusSpatialAtlas/software/Arial.ttf"
thyageing_style_sheet = '/nfs/team205/lm25/thymus_projects/thymus_ageing_atlas/General_analysis/scripts/plotting/thyAgeing.mplstyle'
font_manager.fontManager.addfont(arial_font_file)
plt.style.use(thyageing_style_sheet)

# Import custom scripts
from utils import get_latest_version,update_obs,freq_by_donor,cellxgene_prep
from anno_levels import get_ct_levels, get_ct_palette, age_group_levels, age_group_palette
from plotting.utils import plot_grouped_boxplot, calc_figsize
from scvi_wrapper import run_scvi

In [None]:
# Generate palette for compartment proportion plot
def midpoint_color(color_a, color_b):
    """Return the '#rrggbb' colour halfway between two '#rrggbb' colours.

    Each RGB channel is averaged independently (integer floor division).

    NOTE(review): `midpoint_color` was neither imported nor defined anywhere in
    the notebook, so this cell raised a NameError on a fresh kernel. This local
    definition assumes a plain RGB average - if the project already ships a
    helper with this name, import that one instead and confirm it behaves the same.
    """
    hex_a = color_a.lstrip('#')
    hex_b = color_b.lstrip('#')
    channels = (
        (int(hex_a[pos:pos + 2], 16) + int(hex_b[pos:pos + 2], 16)) // 2
        for pos in (0, 2, 4)
    )
    return '#' + ''.join(f'{value:02x}' for value in channels)


[midpoint_color(c1, c2) for c1, c2 in [('#9784ed', '#dc5999'), ('#f4833d', '#f5bc3d'), ('#90a195', '#4b9aa1')]]

In [None]:
# Load latest meta: resolve the newest metadata file by prefix, then read it as a table
metadata_dir = f'{general_data_path}/metadata'
latest_meta_path = get_latest_version(dir=metadata_dir, file_prefix='Thymus_ageing_metadata')
latest_meta = pd.read_excel(latest_meta_path)

In [None]:
# Import data
adata = ad.read_h5ad(f'{data_path}/objects/rna/thyAgeing_all_scvi_v3_2024-11-05_filt_prelim_anno.zarr')

# T/NK-split object is opened in backed mode so only the cells of interest are pulled into memory
tsplit_object_file = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/T_NK_compartment/data/objects/rna/thyAgeing_tSplit_scvi_v7_2024-11-06.zarr'
tsplit_leiden_file = '/lustre/scratch126/cellgen/team205/lm25/thymus_projects/thymus_ageing_atlas/T_NK_compartment/data/objects/rna/thyAgeing_tSplit_scvi_v7_2024-11-06_leidenClusters.csv'
adata_c9 = ad.read_h5ad(tsplit_object_file, backed='r')
leiden_c9 = pd.read_csv(tsplit_leiden_file, index_col=0)
adata_c9.obs = adata_c9.obs.join(leiden_c9)

# Keep only cluster 9 of the 'leiden_r2.5' clustering and load those cells into memory
in_cluster_9 = adata_c9.obs['leiden_r2.5'] == 9
adata_c9 = adata_c9[in_cluster_9].to_memory()

In [None]:
# Sanity check before concatenating: both objects must list the same genes in the same order.
# The length test comes first because an element-wise comparison of indexes of
# different lengths raises instead of answering False.
len(adata.var_names) == len(adata_c9.var_names) and bool((adata.var_names == adata_c9.var_names).all())

In [None]:
# Spot check that X looks like raw counts: the sum over the first 100 cells x 100 genes has no fractional part
corner_total = adata_c9.X[:100, :100].sum()
corner_total % 1 == 0

In [None]:
# Stack the atlas and the cluster-9 cells; index_unique=None leaves the cell names unchanged
adata_concat = ad.concat([adata, adata_c9], index_unique=None)

# Set var to the atlas gene id / gene name columns, with 'gene_ids' renamed to 'gene_id'
gene_annotation = adata.var[['gene_ids', 'gene_name']].copy()
adata_concat.var = gene_annotation.rename(columns={'gene_ids': 'gene_id'})
adata_concat

In [None]:
import scvi
import torch
# Report whether torch can see a CUDA device before starting the scVI run below
torch.cuda.is_available()

In [None]:
# Version tag of the integrated object: fixed prefix plus today's date
object_version = f'v5_{today}'

# Run scvi
scvi_fig_dir = f'{plots_path}/preprocessing/scvi'
scvi_fig_prefix = f'thyAgeing_all_scvi_{object_version}'
scvi_run = run_scvi(
    adata_concat,
    layer_raw='X',
    # Excluded genes
    include_genes=[],
    exclude_cc_genes=True,
    exclude_mt_genes=True,
    exclude_vdjgenes=True,
    remove_cite=False,
    # Highly variable gene selection
    batch_hv="age_group",
    hvg=10000,
    span=0.5,
    # scVI
    batch_scvi="sample",
    cat_cov_scvi=["chemistry_simple", "sex", "donor"],
    cont_cov_scvi=None,
    max_epochs=200,
    batch_size=2000,
    early_stopping=True,
    early_stopping_patience=15,
    early_stopping_min_delta=10.0,
    plan_kwargs={'lr': 0.001, 'reduce_lr_on_plateau': True, 'lr_patience': 10, 'lr_threshold': 20},
    n_layers=3,
    n_latent=30,
    dispersion='gene-batch',
    # Leiden clustering
    leiden_clustering=None,
    col_cell_type=['taa_l3', 'taa_l1'],
    fig_dir=scvi_fig_dir,
    fig_prefix=scvi_fig_prefix,
)

In [None]:
# Save adata and scvi model
overwrite = True
ad.settings.allow_write_nullable_strings = True

integrated = scvi_run['data']
# NOTE(review): the file is written with write_h5ad although its name ends in '.zarr'
integrated_file = f'{data_path}/objects/rna/thyAgeing_all_scvi_{object_version}.zarr'

# Cast text-like obs columns before writing: object -> pandas 'string', categorical -> plain str
for col in integrated.obs.columns:
    col_dtype = integrated.obs[col].dtype
    if col_dtype == 'object':
        integrated.obs[col] = integrated.obs[col].astype('string')
    elif col_dtype == 'category':
        integrated.obs[col] = integrated.obs[col].astype(str)

# Columns holding predictions, probabilities or 'taa' annotations are dropped before the object is written
anno_cols = [col for col in integrated.obs.columns if '_pred_' in col or '_prob_' in col or 'taa' in col]

if os.path.exists(integrated_file) and not overwrite:
    print('File already exists')
else:
    integrated.obs = integrated.obs.drop(columns=anno_cols)
    integrated.write_h5ad(
        integrated_file,
        compression=hdf5plugin.FILTERS["zstd"],
        compression_opts=hdf5plugin.Zstd(clevel=5).filter_options,
    )
    # The model is stored without its AnnData (save_anndata=False); the object was written just above
    scvi_run['vae'].save(f'{model_path}/thyAgeing_all_scvi_{object_version}', save_anndata=False, overwrite=overwrite)

### Check cell type annotations

In [None]:
# Re-open the saved integrated object (backed, read-only on disk) and attach the curated annotation
object_version = 'v5_2025-04-03'
adata = ad.read_h5ad(f'{data_path}/objects/rna/thyAgeing_all_scvi_{object_version}.zarr', backed='r')
ct_anno = pd.read_csv(f'{data_path}/objects/rna/thyAgeing_all_scvi_v4_2025-02-04_curatedAnno_v10.csv', index_col=0)

# Drop any obs column that the curated table re-defines: DataFrame.join raises on
# overlapping column names, and the curated values are the ones wanted here
# (same handling as in the cellxgene export cell at the end of the notebook)
stale_cols = [c for c in ct_anno.columns if c in adata.obs.columns]
adata.obs = adata.obs.drop(columns=stale_cols).join(ct_anno, how='left')

In [None]:
# One UMAP panel per annotation level, stacked in a single column; save via the returned figure
anno_umap_fig = sc.pl.umap(adata, color=['taa_l4', 'taa_l3', 'taa_l2', 'taa_l1'], ncols=1, return_fig=True)
anno_umap_fig.savefig(
    f'{plots_path}/preprocessing/scvi/thyAgeing_all_scvi_{object_version}_ctAnnotation_umap.png',
    dpi=300,
    bbox_inches='tight',
)

In [None]:
# List the distinct level-3 labels as a NumPy array.
# Series.unique() returns a plain ndarray for object columns (which has no
# .to_numpy()) and an extension array for categorical/string columns;
# np.asarray handles both.
np.asarray(adata.obs['taa_l3'].unique())

In [None]:
# Pull the selected level-3 T cell populations into memory
egress_populations = ['T_CD4', 'T_CD8', 'T_Treg', 'T_αβT(entry)']
is_egress_population = adata.obs['taa_l3'].isin(egress_populations)
adata_sub = adata[is_egress_population].to_memory()

# Normalise each cell to a total of 1e4, then log1p-transform (both in place)
sc.pp.normalize_total(adata_sub, target_sum=1e4)
sc.pp.log1p(adata_sub)

In [None]:
import itertools
# Preview the '<cell type>_<age group>' labels, cell type varying slowest
['_'.join(pair) for pair in itertools.product(['T_αβT(entry)', 'T_CD8', 'T_CD4', 'T_Treg'], ['infant', 'paed', 'adult'])]

In [None]:
import itertools

# Row order of the dot plot: every cell type paired with every age group, cell type varying slowest
cat_order = ['_'.join(pair) for pair in itertools.product(['T_αβT(entry)', 'T_CD8', 'T_CD4', 'T_Treg'], ['infant', 'paed', 'adult'])]

# Build the dot plot step by step: construct, add group totals, set dot sizes, save
egress_dotplot = sc.pl.DotPlot(
    adata_sub,
    var_names=['CXCR4', 'CD38', 'CCR7', 'S1PR1'],
    groupby=['taa_l3', 'age_group'],
    categories_order=cat_order,
    mean_only_expressed=True,
    figsize=calc_figsize(width=100, height=70),
    cmap='magma',
)
egress_dotplot = egress_dotplot.add_totals()
egress_dotplot = egress_dotplot.style(smallest_dot=1, largest_dot=40)
egress_dotplot.savefig(f'{plots_path}/egressMarkers.pdf')

## Save cellxgene object

In [None]:
# Load adata
object_version = 'v5_2025-04-03'
adata = ad.read_h5ad(f'{general_data_path}/objects/rna/thyAgeing_all_scvi_{object_version}.zarr')

# Add new annotations to adata; curated columns replace same-named columns already in obs
ct_anno = pd.read_csv(f'{general_data_path}/objects/rna/thyAgeing_all_scvi_v4_2025-02-04_curatedAnno_v10.csv', index_col=0)
stale_cols = [c for c in ct_anno.columns if c in adata.obs.columns]
adata.obs = adata.obs.drop(columns=stale_cols).join(ct_anno)

# Filter data (only include annotated cells).
# .copy() turns the subset into a standalone AnnData, so the obs/var edits that
# follow act on a real object rather than on a view of the unfiltered data.
adata = adata[adata.obs['anno_status'] == 'include', :].copy()

# Update metadata
latest_meta_path = get_latest_version(dir=f'{general_data_path}/metadata', file_prefix='Thymus_ageing_metadata')
latest_meta = pd.read_excel(latest_meta_path)
update_obs(adata, latest_meta, on='index', ignore_warning=True)

adata

In [None]:
# Duplicate 'gene_id' under the name 'gene_ids'
# (presumably the column name the cellxgene export expects - TODO confirm)
adata.var['gene_ids'] = adata.var.loc[:, 'gene_id'].copy()

In [None]:
# Hand the filtered, annotated object to the project's cellxgene helper under the name 'ThyAge_all'
cellxgene_prep(adata, object_name='ThyAge_all')