# Build neighbourhood VDJ feature space

In [None]:
import palantir

In [None]:
import dandelion as ddl

In [None]:
# import gspread as gs
import numpy as np
import seaborn as sns
import os
import pandas as pd

In [None]:
import scanpy as sc

In [None]:
from collections import Counter
import matplotlib.pyplot as plt

In [None]:
# Make the local milopy checkout importable by moving into its package directory,
# then import its `core` module under the alias `milo`.
# NOTE(review): presumably relies on the current directory being on sys.path (as in
# IPython/Jupyter). os.chdir also changes the working directory for the rest of the
# session, so every later relative path ('bdata.h5ad', 'nhood_adata.h5ad', ...) resolves
# inside /nfs/team205/ny1/milopy/milopy/. `milo` is re-bound to `milopy.core` in a later cell.
os.chdir('/nfs/team205/ny1/milopy/milopy/')
# import milopy
import core as milo
# Print package versions for the record.
ddl.logging.print_header()
sc.logging.print_header()

In [None]:
# set working directory


In [None]:
# Global plotting defaults: scanpy figure settings (160 dpi on screen, 300 dpi PDF output,
# RdYlBu_r colour map), 6x6 inch matplotlib figures and a colour-blind-safe palette.
sc.settings.set_figure_params(
    dpi=160,
    dpi_save=300,
    color_map='RdYlBu_r',
    format='pdf',
)
plt.rcParams["figure.figsize"] = [6, 6]
sns.set_palette('colorblind')

In [None]:
# Re-enable inline matplotlib output. Per the original author's note, importing palantir
# otherwise breaks scanpy's inline plotting in this notebook.
%matplotlib inline

# Load data

In [None]:
# object loaded with abTCR, gdTCR, BCR data in 02_panfetal_load_VDJ


In [None]:
# NOTE(review): `adata` is not created anywhere in this notebook. The "Load data" cell above
# contains only a comment. Load the object produced by 02_panfetal_load_VDJ before running.
# Inspect the cell barcodes; later cells match VDJ cell_ids against adata.obs_names.
adata.obs_names

In [None]:
# List the distinct labels of the 'anno_CITE_4v3' annotation column, most frequent first.
adata.obs['anno_CITE_4v3'].value_counts().keys()

In [None]:
# Show the AnnData summary before removing the old VDJ columns.
adata

In [None]:
# Remove VDJ annotation columns from a previous run before new ones are transferred.
# A column is dropped when its name contains any of these substrings. The regex
# alternation is equivalent to taking the union of one filter per substring.
_stale_vdj_pattern = 'VDJ|VJ|vj|vdj|clone|contig|isotype|chain|locus'
adata.obs = adata.obs.drop(columns=adata.obs.filter(regex=_stale_vdj_pattern).columns)
adata

In [None]:
# Ordered list of all cell-type labels, and one husl colour per label.

ct_all_order = ['uncommitted', 'committed_CD4neg','committed_CD4neg(P)','committed_CD4pos','committed_CD4pos(P)',
                'DP(P)_early', 'DP(P)_late','DP(Q)_early','DP(Q)_rearr','DP(Q)_CD99_CD31lo','DP_early_CD31',
                'DP_4hi8lo','DP(Q)_CD199','DP(Q)_HSPH1','DP_pos_sel','DP(Q)_CD99_CD31hi','DP(Q)_Th2',
                'SP_CD4_immature', 'SP_CD8_immature','SP_CD4_semimature', 'SP_CD8_semimature','CD8aaI_immature',
                'CD8aaII_immature','gdT_immature','gdT_semimature',
                'CD8aaI_mature','CD8aaII_mature','SP_CD4_mature',
                'SP_Treg_immature','SP_Treg_mature','SP_CD8_mature','SP_Treg_PD1', 'SP_Treg_CD8',
                'NK_tr_itg_hi','SP_CD8_NKlike', 'gdT_mature','NK_circ_56hi16lo', 'gdT_Vd2', 'NK_tr_itg_lo',
                'NK_circ_56lo16hi','iNKT','SP_Treg_recirc'
                ]

# Each label maps to one RGB row of the husl palette, in ct_all_order order.
_husl_rgb = np.array(sns.color_palette("husl", len(ct_all_order)))
ct_color_map = {label: _husl_rgb[i] for i, label in enumerate(ct_all_order)}

# Load abTCR

In [None]:
# Turn the Google Sheet edit URL into a CSV-export URL and load the sample metadata table.
# NOTE(review): `sheet_url` is not defined anywhere in this notebook; set it before running.
url_1 = sheet_url.replace('/edit#gid=', '/export?format=csv&gid=')
meta = pd.read_csv(url_1)       
meta

In [None]:
# Keep only libraries that have an abTCR path and are flagged True in the 'cite' column.
meta = meta.loc[meta['path_TCRab'].notna()]
meta = meta.loc[meta['cite']]
meta

In [None]:
# Create one dandelion object per abTCR library, keyed by the library path.
import os.path
import warnings
warnings.filterwarnings("ignore")
from os import path
from tqdm import tqdm
tcrab = {}
# zip has no len(), so pass total= to give the progress bar a denominator.
for x, y in tqdm(zip(meta['path_TCRab'], meta['sample']), total=len(meta)):
    # The 10x output sits either directly in the library folder or under outs/.
    candidates = ['/' + x + '/all_contig_annotations.json',
                  '/' + x + '/outs/all_contig_annotations.json']
    existing = [f for f in candidates if path.exists(f)]
    if not existing:
        raise FileNotFoundError(f'no all_contig_annotations.json for library {x}; tried {candidates}')
    tmp = ddl.read_10x_vdj(existing[0])

    # Rename cell_id to '<sample>-<barcode>' (barcode with its '-1' suffix removed)
    # so that it matches adata.obs_names.
    tmp.data['cell_id'] = [y + '-' + z.split('-1')[0] for z in tmp.data['cell_id']]
    ddl.utl.update_metadata(tmp)  # refresh the metadata after renaming cell_ids
    # Keep only contigs whose cell is present in adata.
    tmp = tmp[tmp.data['cell_id'].isin(adata.obs_names)].copy()

    tcrab[x] = tmp
len(tcrab)

In [None]:
# Merge the per-library objects into one. The library paths (dict keys, in insertion
# order) are passed as `prefixes`, one per object.
tcrab = ddl.concat(list(tcrab.values()), prefixes=list(tcrab.keys()))
tcrab

In [None]:
# Transfer the dandelion VDJ annotations from tcrab onto adata.
# NOTE(review): the AnnData returned by check_contigs below (trab_adata) is what is used
# downstream; confirm whether this transfer is still needed.
ddl.tl.transfer(adata, tcrab)

In [None]:
# Inspect the merged contig table.
tcrab.data

In [None]:
# Copy 'sequence' into 'sequence_alignment' before contig QC.
# NOTE(review): presumably check_contigs expects a 'sequence_alignment' column that is not
# populated here; confirm against the dandelion version in use.
# library_type='tr-ab' filters out contigs whose 'locus' is not TRA/TRB (this is an abTCR library).
# productive_only=False presumably keeps non-productive contigs as well.
tcrab.data['sequence_alignment'] = tcrab.data['sequence']
tcrab_checked, trab_adata = ddl.pp.check_contigs(tcrab, adata, productive_only = False, library_type = 'tr-ab')
tcrab_checked

In [None]:
# AnnData returned by check_contigs alongside the checked contigs.
trab_adata

In [None]:
# Independent copy used for the rest of the abTCR analysis.
adata_abtcr = trab_adata.copy()

# Filter cells

In [None]:
# Large figures with small fonts for the dense plots below.
# sc.set_figure_params applies scanpy's rcParams defaults, which include figure.figsize.
# A figsize set on rcParams *before* the call was therefore overwritten. Pass it to the
# call instead, and set other rcParams afterwards so they stick.
# NOTE(review): this call also resets dpi_save and the default colour map configured at
# the top of the notebook; pass dpi_save / color_map here too if those should persist.
sc.set_figure_params(fontsize=4, dpi=200, figsize=(20, 20))
plt.rcParams['font.family'] = 'sans-serif'
adata_abtcr

## Subset to cells from the DP stages onwards that have a paired TCRab (all four V/J calls)

In [None]:
# set up subsets and colors

ct_all_order = ['uncommitted', 'committed_CD4neg','committed_CD4neg(P)','committed_CD4pos','committed_CD4pos(P)',
                'DP(P)_early', 'DP(P)_late','DP(Q)_early','DP(Q)_rearr','DP(Q)_CD99_CD31lo','DP_early_CD31',
                'DP_4hi8lo','DP(Q)_CD199','DP(Q)_HSPH1','DP_pos_sel','DP(Q)_CD99_CD31hi','DP(Q)_Th2',
                'SP_CD4_immature', 'SP_CD8_immature','SP_CD4_semimature', 'SP_CD8_semimature','CD8aaI_immature',
                'CD8aaII_immature','gdT_immature','gdT_semimature',
                'CD8aaI_mature','CD8aaII_mature','SP_CD4_mature', 
                'SP_Treg_immature','SP_Treg_mature','SP_CD8_mature','SP_Treg_PD1', 'SP_Treg_CD8',
                'NK_tr_itg_hi','SP_CD8_NKlike', 'gdT_mature','NK_circ_56hi16lo', 'gdT_Vd2', 'NK_tr_itg_lo',
                'NK_circ_56lo16hi','iNKT','SP_Treg_recirc'
                  ]

# one husl colour per label of the full list
ct_color_map_all = dict(zip(ct_all_order, np.array(sns.color_palette("husl", len(ct_all_order)))))

# cell types kept for the DP-onwards analysis
ct_order = ['DP(P)_early','DP(Q)_early','DP(Q)_rearr',
                'DP_4hi8lo','DP(Q)_CD199','DP(Q)_HSPH1','DP(Q)_CD99_CD31lo','DP(Q)_CD99_CD31hi','DP_pos_sel',
                'SP_CD4_immature', 'SP_CD8_immature','SP_CD4_semimature', 'SP_CD8_semimature','SP_CD4_mature','SP_CD8_mature',
                  ]

# One colour per retained cell type. Keyed by ct_order: the palette has len(ct_order)
# colours, and zipping it with ct_all_order gave colours to the wrong labels.
ct_color_map = dict(zip(ct_order, np.array(sns.color_palette("husl", len(ct_order)))))

# main V/J gene-call column for each chain
vdj_call_cols = ['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main', 'v_call_abT_VJ_main', 'j_call_abT_VJ_main']

# Subset cells to celltypes within ct_order. .copy() makes bdata a real AnnData, so the
# .obs edits below do not act on a view of adata_abtcr.
bdata = adata_abtcr[adata_abtcr.obs['anno_CITE_4v3'].isin(ct_order)].copy()

# Values meaning "no call" once the column is cast to str. 'nan' is what missing values
# become under astype('str'); without it they would look like a real gene.
_missing_calls = {'None', '', 'No_contig', 'nan'}
for chain in vdj_call_cols:
    calls = bdata.obs[chain].astype('str')
    # Ambiguous calls (several genes joined by ',') and missing calls become '<column>_None'.
    unusable = calls.str.contains(',', regex=False) | calls.isin(_missing_calls)
    bdata.obs[chain] = calls.where(~unusable, chain + '_None')

In [None]:
# DP onwards: keep only cells with a usable call in all four V/J columns. Calls that were
# missing or ambiguous were replaced by a value ending in 'None'.
_call_cols = ['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main', 'v_call_abT_VJ_main', 'j_call_abT_VJ_main']
_any_missing = np.column_stack(
    [np.array(bdata.obs[col].str.endswith('None')) for col in _call_cols]
).any(axis=1)
bdata = bdata[~_any_missing]

In [None]:
# Summary of the filtered AnnData.
bdata

# Select neighbourhoods 

In [None]:
## The neighbourhood graph must be recomputed after subsetting cells, before running milo.
# n_neighbors sets k of the kNN graph.
# NOTE(review): the original note says this decides the minimum neighbourhood size for milo; confirm.
# use_rep="X_scVI" because data integration was done using scVI.
sc.pp.neighbors(bdata, use_rep = "X_scVI", n_neighbors = 50)
sc.tl.umap(bdata, random_state = 1712)

In [None]:
# Sanity check: different cell types should cluster separately on the UMAP.
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = [5.5, 5]
sc.pl.umap(
    bdata,
    color=['anno_CITE_4v3'],
    legend_loc='on data',
    legend_fontsize=5,
    palette='tab20',
)

In [None]:
import milopy
import milopy.core as milo

# Sample neighbourhoods on bdata's kNN graph with milo.
# NOTE(review): presumably writes neighbourhood membership to bdata.obsm['nhoods'], which
# vdj_pseudobulk reads later; confirm.
milo.make_nhoods(bdata)
# Needed to build the neighbourhood AnnData in bdata.uns['nhood_adata']; sample_col can be
# any obs column.
milo.count_nhoods(bdata, sample_col='sample')
# Neighbourhood graph, needed for the neighbourhood plots.
milopy.utils.build_nhood_graph(bdata)
# Label each neighbourhood by majority vote over its cells' 'anno_CITE_4v3'. Results are in
# bdata.uns['nhood_adata'].obs['nhood_annotation'] and ['nhood_annotation_frac'].
milopy.utils.annotate_nhoods(bdata, anno_col='anno_CITE_4v3')
bdata

In [None]:
# NOTE(review): 'annotation_labels' is overwritten with NaN, presumably because the value set
# by annotate_nhoods cannot be written to h5ad. This discards it; confirm that is intended.
bdata.uns['nhood_adata'].uns['annotation_labels'] = np.nan

# Relative path, resolved against the current working directory.
bdata.write_h5ad('bdata.h5ad')

The neighbourhood AnnData is now stored in `bdata.uns['nhood_adata']`.

# Create neighbourhood VDJ feature space

In [None]:
#### this option for DP
# Build the neighbourhood VDJ feature space: one observation per neighbourhood (membership
# from bdata.obsm['nhoods']), with features from the four V/J gene-call columns.
# 'anno_CITE_4v3' is carried into nhood_adata.obs.
nhood_adata = ddl.tl.vdj_pseudobulk(bdata, pbs = bdata.obsm['nhoods'], obs_to_take = 'anno_CITE_4v3', extract_cols=['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main','v_call_abT_VJ_main', 'j_call_abT_VJ_main'])
nhood_adata

     `nhood_adata` is the new neighbourhood VDJ feature space, in which each observation is a cell neighbourhood:
     - VDJ gene usage frequencies are stored in `nhood_adata.X`
     - VDJ genes are stored in `nhood_adata.var`
     - neighbourhood metadata is stored in `nhood_adata.obs`
     The data can be visualised with PCA or UMAP (see below).

In [None]:
# Order neighbourhood annotation categories by ct_order, so colours and legends follow it.
# reorder_categories(ct_order) raises as soon as one ct_order cell type never wins a
# neighbourhood's majority vote. Instead, keep the ct_order types that occur, in ct_order
# order, and append any other labels at the end rather than dropping them.
_labels = nhood_adata.obs['anno_CITE_4v3']
_present = set(_labels.dropna().unique())
_categories = [ct for ct in ct_order if ct in _present] + sorted(_present - set(ct_order), key=str)
nhood_adata.obs['anno_CITE_4v3'] = pd.Categorical(_labels, categories=_categories)

In [None]:
# PCA of the neighbourhood VDJ usage matrix, coloured by neighbourhood annotation.
_seed = 1712
sc.pp.pca(nhood_adata, random_state=_seed)
sc.pl.pca(nhood_adata, color=['anno_CITE_4v3'])

In [None]:
# kNN graph and UMAP embedding of the neighbourhood VDJ feature space, with a fixed seed.
_umap_seed = 1712
sc.pp.neighbors(nhood_adata, random_state=_umap_seed)
sc.tl.umap(nhood_adata, random_state=_umap_seed)

In [None]:
# Neighbourhood UMAP: first highlighting only the three listed DP groups, then all annotations.
_highlight_groups = ['DP(Q)_CD99_CD31lo', 'DP(Q)_CD99_CD31hi', 'DP_pos_sel']
sc.pl.umap(nhood_adata, color=['anno_CITE_4v3'], groups=_highlight_groups, s=20)
sc.pl.umap(nhood_adata, color=['anno_CITE_4v3'], s=20)

## Run Pseudotime on VDJ feature space

In [None]:
# Palantir diffusion maps (10 components) on the neighbourhood PCA coordinates.
_pca_coords = nhood_adata.obsm['X_pca']
pca_projections = pd.DataFrame(_pca_coords, index=nhood_adata.obs_names)
dm_res = palantir.utils.run_diffusion_maps(pca_projections, n_components=10)
dm_res

In [None]:
# Diffusion-map eigenvalue spectrum, used to choose n_eigs in the next cell.
# The x positions come from the number of eigenvalues, so the plot stays correct if
# n_components is changed.
_eigenvalues = dm_res['EigenValues']
plt.scatter(np.arange(len(_eigenvalues)), _eigenvalues)

In [None]:
# Multiscale diffusion space from the diffusion-map result. n_eigs=5 was chosen by
# inspecting the eigenvalue plot above.
ms_data = palantir.utils.determine_multiscale_space(dm_res, n_eigs=5)

In [None]:
# Start point: the DP(P)_early neighbourhood with the largest first UMAP coordinate.
# It is flagged in nhood_adata.obs['rootcell'] (1 = root, 0 otherwise) for plotting.
_dp_early = nhood_adata[nhood_adata.obs['anno_CITE_4v3'] == 'DP(P)_early']
rootcell = _dp_early.obs_names[np.argmax(_dp_early.obsm['X_umap'][:, 0])]
nhood_adata.obs['rootcell'] = 0
nhood_adata.obs.loc[rootcell, 'rootcell'] = 1

In [None]:
# End points: one extreme neighbourhood per mature SP population.
# CD8 end = SP_CD8_mature neighbourhood with the largest second UMAP coordinate.
# CD4 end = SP_CD4_mature neighbourhood with the largest first UMAP coordinate.
# NOTE(review): the two ends use different UMAP axes, which depends on this embedding's layout.
def _extreme_nhood(label, umap_dim):
    """Return the obs name of the `label` neighbourhood with the largest UMAP value on `umap_dim`."""
    subset = nhood_adata[nhood_adata.obs['anno_CITE_4v3'] == label]
    return subset.obs_names[np.argmax(subset.obsm['X_umap'][:, umap_dim])]


endcell1 = _extreme_nhood('SP_CD8_mature', 1)
endcell2 = _extreme_nhood('SP_CD4_mature', 0)

# terminal neighbourhood ID -> cell-type label
terminal_states = pd.Series(['SP_CD8_mature', 'SP_CD4_mature'], index=[endcell1, endcell2])

In [None]:
# Flag the terminal neighbourhoods and plot them next to the root cell and the annotation.
nhood_adata.obs['terminal_states'] = 0
nhood_adata.obs.loc[terminal_states.index, 'terminal_states'] = 1
plt.rcParams["figure.figsize"] = [4, 4]
_panels = ['rootcell', 'terminal_states', 'anno_CITE_4v3']
_titles = ['root cell', 'terminal states', 'nhood annotation']
sc.pl.umap(nhood_adata, color=_panels, title=_titles, color_map='OrRd', s=10)

In [None]:
# Run Palantir from the root neighbourhood towards the two terminal neighbourhoods (500 waypoints).
pr_res = palantir.core.run_palantir(ms_data,  rootcell, num_waypoints=500, 
                                    terminal_states = terminal_states.index)

In [None]:
# Rename the branch-probability columns from terminal neighbourhood IDs to cell-type labels.
pr_res.branch_probs.columns = terminal_states.loc[pr_res.branch_probs.columns].to_numpy()

## Visualise the data

In [None]:
# Copy the Palantir results into nhood_adata.obs with the suffix '_nhood_vdj'
# (e.g. 'pseudotime_nhood_vdj', which is plotted below).
ddl.tl.pseudotime_transfer(adata = nhood_adata, pr_res = pr_res, suffix = '_nhood_vdj')

In [None]:
# Pseudotime and the two branch probabilities on the neighbourhood UMAP.
plt.rcParams["figure.figsize"] = [4, 4]
_panel_titles = {
    'pseudotime': 'pseudotime',
    'prob_SP_CD8_mature': 'branch probability to CD8',
    'prob_SP_CD4_mature': 'branch probability to CD4',
}
sc.pl.umap(
    nhood_adata,
    color=[key + '_nhood_vdj' for key in _panel_titles],
    title=list(_panel_titles.values()),
    frameon=False,
    wspace=0.1,
    color_map='RdYlBu_r',
)

In [None]:
# save nhood object
# NOTE(review): the write is commented out, so the read below replaces the nhood_adata
# computed above with whatever 'nhood_adata.h5ad' holds on disk (path relative to the cwd).
# nhood_adata.write_h5ad('nhood_adata.h5ad')
nhood_adata = sc.read_h5ad('nhood_adata.h5ad')

# end of testing 


## Project pseudotime and branch probabilities back to cells

In [None]:

# Project neighbourhood-level pseudotime and branch probabilities back to individual cells.
# Every following cell reads cdata and its 'pseudotime_nhood_vdj' /
# 'prob_SP_CD8_mature_nhood_vdj' columns, so this projection has to run.
# NOTE(review): presumably cdata holds the bdata cells covered by at least one neighbourhood; confirm.
cdata = ddl.tl.project_pseudotime_to_cell(adata = bdata,
                               pb_adata = nhood_adata,
                               term_states=['SP_CD8_mature','SP_CD4_mature'],
                               suffix = '_nhood_vdj')

In [None]:
# Cell-level pseudotime against CD8 branch probability, coloured by annotation.
sns.set_theme(style='white')
fig, ax = plt.subplots(figsize=(15,5))
# Order annotation categories by ct_order. Types absent from cdata are skipped instead of
# making reorder_categories raise, and unexpected labels are appended at the end.
_labels = cdata.obs['anno_CITE_4v3']
_present = set(_labels.dropna().unique())
_categories = [ct for ct in ct_order if ct in _present] + sorted(_present - set(ct_order), key=str)
cdata.obs['anno_CITE_4v3'] = pd.Categorical(_labels, categories=_categories)
df = cdata.obs.copy()
# Vertical jitter so cells with identical probabilities do not overplot. It is uniform in
# [-sigma/2, sigma/2), so values are not shifted upwards on average, and comes from a
# seeded generator so the figure is reproducible.
sigma = 0.05
_rng = np.random.default_rng(1712)
df['prob_SP_CD8_mature_nhood_vdj'] = (
    df['prob_SP_CD8_mature_nhood_vdj'] + (_rng.random(len(df)) - 0.5) * sigma
)
ax = sns.scatterplot(
    data=df,
    x='pseudotime_nhood_vdj',
    y='prob_SP_CD8_mature_nhood_vdj',
    s=4,
    hue='anno_CITE_4v3',
    ax=ax,
)
ax.set_ylabel('probability to CD8')
ax.set_xlabel('pseudotime')
ax.set_title('')
# Legend outside the axes on the right, with at most 20 entries.
h, l = ax.get_legend_handles_labels()
ax.legend(h[:20], l[:20], loc='upper right', bbox_to_anchor=(1.2, 1), frameon=False, fontsize='small')
# plt.savefig(fig_path+'/pseudotime_scatterplot_nhood_vdj.pdf',bbox_inches='tight')

# Look at TRAV/TRAJ usage in DP(Q) cells beyond the bifurcation point

In [None]:
# Pseudobulk V/J gene usage per 'anno_CITE_4v3_pseudotime_bin' group of cdata.
# NOTE(review): no cell in this notebook creates 'anno_CITE_4v3_pseudotime_bin' on cdata.
# It must be added (labels such as 'DP(Q)_rearr_bin_1', as used for the heatmap row order)
# before this runs.
bulk_adata = ddl.tl.vdj_pseudobulk(adata = cdata, obs_to_bulk = ['anno_CITE_4v3_pseudotime_bin'], obs_to_take = ['anno_CITE_4v3_pseudotime_bin'],
                                   extract_cols= ['v_call_abT_VDJ_main', 'j_call_abT_VDJ_main','v_call_abT_VJ_main', 'j_call_abT_VJ_main'])

In [None]:
# Load the TRAV and TRAJ gene lists ordered by genomic location (lists in the github folder
# metadata/TCR_genes). The file path is deliberately not stored in a variable named `path`,
# which would shadow the `path` module imported from os in the contig-loading cell.
_gene_list_dir = '/lustre/scratch126/cellgen/team205/cs42/VDJ_collab_manuscript/gene_list/'
TCR_list_by_loc_dict = {
    chain: list(pd.read_csv(_gene_list_dir + chain + '_list_by_location.csv', header=None)[0])
    for chain in ['TRAV', 'TRAJ']
}

In [None]:
# Usage matrix with one row per pseudotime bin and one column per gene. Columns are the
# TRAV genes then the TRAJ genes, in genomic order, restricted to genes in bulk_adata.
_available = set(bulk_adata.var_names)
gene_intersection = [
    gene for gene in TCR_list_by_loc_dict['TRAV'] + TCR_list_by_loc_dict['TRAJ']
    if gene in _available
]
_usage = bulk_adata[:, gene_intersection].X
# DataFrame needs a dense 2-D array; densify if the pseudobulk matrix is sparse.
if hasattr(_usage, 'toarray'):
    _usage = _usage.toarray()
trav = pd.DataFrame(index=bulk_adata.obs['anno_CITE_4v3_pseudotime_bin'],
                    columns=gene_intersection,
                    data=_usage)

In [None]:

# Heatmap row order: DP stages (DP(Q)_rearr split into five pseudotime bins), then the CD4
# stages, then the CD8 stages. Stored as bin_order rather than re-using ct_order, because
# ct_order holds the cell-type order that earlier cells rely on when they are re-run.
bin_order = ['DP(P)_early','DP(Q)_rearr_bin_1','DP(Q)_rearr_bin_2','DP(Q)_rearr_bin_3','DP(Q)_rearr_bin_4','DP(Q)_rearr_bin_5','DP_pos_sel',
                'SP_CD4_immature','SP_CD4_semimature','SP_CD4_mature', 'SP_CD8_immature', 'SP_CD8_semimature','SP_CD8_mature']
trav = trav.reindex(bin_order)


In [None]:
# Heatmap of TRAV/TRAJ usage per annotation / pseudotime bin, saved as a PDF.
# NOTE(review): `fig_path` is not defined anywhere in this notebook; set it before running.
plt.rcParams["figure.figsize"] = [30, 10]
heatmap_ax = sns.heatmap(trav)
heatmap_ax.get_figure().savefig(fig_path + '/vdj_usage.pdf', dpi=200)
# Observation: DP_late uses the end of the TRAV/TRAJ loci.


In [None]:
# Return the cell-level results to the original abTCR object; inspect it first.
adata_abtcr

In [None]:
# Inspect cdata, which holds the cell-level projected pseudotime.
cdata

In [None]:

# Copy the cell-level pseudotime and branch probabilities onto adata_abtcr. Assignment
# aligns on obs_names, so cells that are not in cdata get NaN.
for _col in ('pseudotime_nhood_vdj', 'prob_SP_CD8_mature_nhood_vdj', 'prob_SP_CD4_mature_nhood_vdj'):
    adata_abtcr.obs[_col] = cdata.obs[_col]

In [None]:
# Annotation, pseudotime and branch probabilities on the full abTCR UMAP.
plt.rcParams["figure.figsize"] = [10, 10]
_umap_keys = ['anno_CITE_4v3', 'pseudotime_nhood_vdj', 'prob_SP_CD8_mature_nhood_vdj', 'prob_SP_CD4_mature_nhood_vdj']
sc.pl.umap(
    adata_abtcr,
    color=_umap_keys,
    cmap='jet',
    palette='tab20',
    legend_loc='on data',
    legend_fontsize=1,
)

In [None]:
# Save the annotated abTCR object. AnnData.write() with no filename raises for an object
# that is not file-backed, so give an explicit path (relative to the working directory).
adata_abtcr.write_h5ad('adata_abtcr.h5ad')

In [None]:
# Reload the saved abTCR object. sc.read() requires a filename.
adata_abtcr = sc.read_h5ad('adata_abtcr.h5ad')

In [None]:
# Export the cell metadata. Called without a path, DataFrame.to_csv() only returns the CSV
# text and writes nothing.
adata_abtcr.obs.to_csv('adata_abtcr_obs.csv')