In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os

import scanpy as sc
import scirpy as ir
import skbio
import anndata as ann
import numpy as np
import pandas as pd
import seaborn as sb
from tqdm import tqdm
import math
from scipy import stats, sparse

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as clrs

from matplotlib import rcParams
import matplotlib.gridspec as gridspec

In [None]:
from skbio.stats.composition import clr

In [None]:
import sys
sys.path.append('..')
import utils.visualisation as utils_vis

In [None]:
path_figs = '../../figures/mvp/'

colormap = 'flare'
binding_mode = 'binding_ct'
dpi = 600  # resolution passed explicitly to the manuscript savefig calls below

# Configure scanpy with a single set_figure_params call: every call resets the
# parameters it is not given, so a separate earlier `set_figure_params(dpi=600)`
# was silently undone by the second call. The effective settings of the former
# cell are kept (default display dpi, vector_friendly disabled — it used to be
# switched on and then overwritten through the private `_vector_friendly`).
sc.set_figure_params(vector_friendly=False, color_map='viridis', transparent=True)
sc.settings.verbosity = 3
# Set figdir on the global settings object; constructing a new ScanpyConfig had no effect.
sc.settings.figdir = path_figs
sb.set_style('whitegrid')

In [None]:
# Annotated CD8 object produced by the preceding notebook.
path_adata = '../../data/mvp/02_mvp_annotated_cd8.h5ad'
adata = sc.read(path_adata)

In [None]:
# Per-cell-type figure directory. os.path.join avoids the doubled '/' of formatting a
# base path that already ends in '/'; the basename check keeps the cell idempotent
# (re-running it no longer appends the cell type a second time).
if os.path.basename(os.path.normpath(path_figs)) != adata.uns['celltype']:
    path_figs = os.path.join(path_figs, adata.uns['celltype'])
# Subdirectories that later cells save into; savefig does not create directories.
for subdir in ('spec', 'citeSeq', 'citeVsGex'):
    os.makedirs(os.path.join(path_figs, subdir), exist_ok=True)

In [None]:
# Epitope and CITE-seq feature names stored in the AnnData object.
epitopes = adata.uns['epitopes']
cite_ids = adata.uns['cite_ids']
custom_cite_ids = adata.uns['custom_cite_ids']
# Panel and custom CITE-seq ids combined into one plain list.
cite_ids_full = [*cite_ids.tolist(), *custom_cite_ids.tolist()]

## Dotplots

In [None]:
markers_6b = [
    'CD4', 'CD8A', 'CD8B', 'TBX21', 'EOMES', 'IL7R',
    'CCR6', 'ITGA1', 'JAML', 'PDCD1', 'LAG3', 'TIGIT',
    'KLRB1', 'KLRD1', 'IFNG', 'GZMA', 'GZMB', 'GZMH',
    'GZMK', 'PRF1', 'NKG7', 'IL32', 'CCL3', 'CCL4', 'CCL5',
]

# The same marker panel, drawn once per cell grouping.
dotplot_groupings = [
    ('leiden', 'By Leiden'),
    ('binding_10x', 'By Binding 10x'),
    ('donor', 'By Donor'),
]
for groupby_key, dotplot_title in dotplot_groupings:
    plot = sc.pl.dotplot(adata, markers_6b, show=True, groupby=groupby_key, title=dotplot_title)

In [None]:
selected_markers = [
    'SELL', 'TCF7', 'LTB', 'IL7R', 'GZMK',
    'KLRG1', 'CCL5', 'NKG7', 'GZMH', 'GZMB', 'PRF1', 'IFNG', 'IL2',
    'TNF', 'MKI67', 'CX3CR1',
    'KLRB1', 'TRAV1-2', 'KIR3DL1', 'PDCD1',
    'CTLA4', 'LAG3', 'TIGIT', 'TOX',
]
# use_raw=False: plot the values in adata.X instead of adata.raw.
plot = sc.pl.dotplot(adata, selected_markers, groupby='leiden', use_raw=False, show=True)

## Leiden Distribution

In [None]:
# binding_10x composition of each Leiden cluster: absolute counts, then normalized.
abundance_kwargs = dict(groupby='leiden', target_col='binding_10x')
ir.pl.group_abundance(adata, **abundance_kwargs)
ir.pl.group_abundance(adata, normalize=True, **abundance_kwargs)

## Clone purity
To assess the purity and consistency of dextramer (Dex) labeling per clonotype, display the Dex assignment per clone as in Minervina et al., Fig. 2b; to start, focus on clones with more than 5 cells.

In [None]:
# Palette for the stacked bar plots below: epitope colours, donor colours, and grey
# for unbound cells. scanpy stores `uns['<key>_colors']` aligned with the categories
# of obs[<key>]. Pair the colours with those categories. Pairing them with
# uns['epitopes'] order (or donor appearance order) can shift colours to the wrong labels.
binding_categories = pd.Categorical(adata.obs['binding_10x']).categories
colors = dict(zip(binding_categories, adata.uns['binding_10x_colors']))
donor_categories = pd.Categorical(adata.obs['donor']).categories
colors.update(zip(donor_categories, adata.uns['donor_colors']))
colors['No binding'] = 'tab:gray'

In [None]:
rcParams['figure.figsize'] = (5, 30)
# Clonotypes with an assigned clone id and more than 5 cells, largest first.
adata_tmp = adata[adata.obs['clone_id'] != 'nan']
adata_tmp = adata_tmp[adata_tmp.obs['clone_id_size'] > 5].copy()
order = adata_tmp.obs['clone_id'].value_counts().index.tolist()

for col in ('binding_10x', 'donor'):
    adata_tmp.obs[col] = adata_tmp.obs[col].replace(['None'], 'No binding')
    # Fraction of each label within a clonotype. The order is reversed because
    # barh draws the first row at the bottom, so the largest clone ends up on top.
    df_tmp = (adata_tmp.obs[['clone_id', col]]
              .groupby('clone_id')[col]
              .value_counts(normalize=True)
              .unstack()
              .reindex(order[::-1]))

    ax = df_tmp.plot(kind='barh', stacked=True, color=colors)
    ax.legend(bbox_to_anchor=(1., 0.75))
    ax.grid(False)
    ax.set_xticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    ax.tick_params(axis='y', labelsize='x-small')
    ax.set_title(col)
    plt.show()

In [None]:
rcParams['figure.figsize'] = (30, 10)

for donor in adata.obs['donor'].unique():
    adata_tmp = adata[adata.obs['clone_id']!='nan']
    adata_tmp = adata_tmp[adata_tmp.obs['donor']==donor]
    adata_tmp = adata_tmp[adata_tmp.obs['clone_size_donor_ct']>1]
    if len(adata_tmp) == 0:
        continue
    order = adata_tmp.obs['clone_id'].value_counts().index.tolist()

    df_tmp = adata_tmp.obs[['clone_id', 'binding_10x']]
    df_tmp = df_tmp.groupby('clone_id')['binding_10x'].value_counts(normalize=False).unstack()
    df_tmp = df_tmp.reindex(order)

    specific_cts = adata_tmp[adata_tmp.obs['binding_10x']!='No binding'].obs['clone_id'].unique()
    df_tmp = df_tmp[df_tmp.index.isin(specific_cts)].copy()
    
    if len(df_tmp) == 0:
        continue

    ax = df_tmp.plot(kind='bar', stacked=True, color=colors)
    ax.legend(bbox_to_anchor=(1., 0.75))
    ax.grid(False)
    #ax.set_xticks([])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    #ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.tick_params(axis='y', labelsize='x-small')
    ax.set_title(donor)
    plt.savefig(f'{path_figs}/spec/ct_expansion_10x_{donor}.pdf', bbox_inches='tight', dpi=300,)
    plt.show()

In [None]:
rcParams['figure.figsize'] = (30, 10)

for donor in adata.obs['donor'].unique():
    adata_tmp = adata[adata.obs['clone_id']!='nan']
    adata_tmp = adata_tmp[adata_tmp.obs['donor']==donor]
    adata_tmp = adata_tmp[adata_tmp.obs['clone_size_donor_ct']>1]
    if len(adata_tmp) == 0:
        continue
    order = adata_tmp.obs['clone_id'].value_counts().index.tolist()

    df_tmp = adata_tmp.obs[['clone_id', 'binding_10x']]
    df_tmp = df_tmp.groupby('clone_id')['binding_10x'].value_counts(normalize=True).unstack()
    df_tmp = df_tmp.reindex(order)

    specific_cts = adata_tmp[adata_tmp.obs['binding_10x']!='No binding'].obs['clone_id'].unique()
    df_tmp = df_tmp[df_tmp.index.isin(specific_cts)].copy()
    
    if len(df_tmp) == 0:
        continue

    ax = df_tmp.plot(kind='bar', stacked=True, color=colors)
    ax.legend(bbox_to_anchor=(1., 0.75))
    ax.grid(False)
    #ax.set_xticks([])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    #ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.tick_params(axis='y', labelsize='x-small')
    ax.set_title(donor)
    plt.savefig(f'{path_figs}/spec/ct_expansion_10x_{donor}_norm.pdf', bbox_inches='tight', dpi=300,)
    plt.show()

In [None]:
rcParams['figure.figsize'] = (30, 10)

for donor in adata.obs['donor'].unique():
    adata_tmp = adata[adata.obs['clone_id']!='nan']
    adata_tmp = adata_tmp[adata_tmp.obs['donor']==donor]
    adata_tmp = adata_tmp[adata_tmp.obs['clone_size_donor_ct']>1]
    if len(adata_tmp) == 0:
        continue
    order = adata_tmp.obs['clone_id'].value_counts().index.tolist()

    df_tmp = adata_tmp.obs[['clone_id', 'binding_10x']]
    df_tmp = df_tmp.groupby('clone_id')['binding_10x'].value_counts().unstack()
    df_tmp = df_tmp / (adata.obs['donor']==donor).sum()
    df_tmp = df_tmp.reindex(order)

    specific_cts = adata_tmp[adata_tmp.obs['binding_10x']!='No binding'].obs['clone_id'].unique()
    df_tmp = df_tmp[df_tmp.index.isin(specific_cts)].copy()
    
    if len(df_tmp) == 0:
        continue

    ax = df_tmp.plot(kind='bar', stacked=True, color=colors)
    ax.legend(bbox_to_anchor=(1., 0.75))
    ax.grid(False)
    #ax.set_xticks([])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    #ax.spines['bottom'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.tick_params(axis='y', labelsize='x-small')
    ax.set_title(donor)
    plt.savefig(f'{path_figs}/spec/ct_expansion_10x_{donor}_over_all_cts.pdf', bbox_inches='tight', dpi=300,)
    plt.show()

In [None]:
# Per donor and clone-size threshold: max dextramer UMI count of each cell in the
# specific clonotypes (those with at least one epitope-bound cell).
for thresh in [0, 1, 5]:
    for donor in adata.obs['donor'].unique():
        adata_tmp = adata[adata.obs['clone_id'] != 'nan']
        adata_tmp = adata_tmp[adata_tmp.obs['donor'] == donor]
        adata_tmp = adata_tmp[adata_tmp.obs['clone_size_donor_ct'] > thresh]
        if len(adata_tmp) == 0:
            continue
        order = adata_tmp.obs['clone_id'].value_counts().index.tolist()

        df_tmp = adata_tmp.obs[['clone_id', 'n_max_dextramer', 'binding_10x']]
        specific_cts = adata_tmp[adata_tmp.obs['binding_10x'] != 'No binding'].obs['clone_id'].unique()
        df_tmp = df_tmp[df_tmp['clone_id'].isin(specific_cts)].copy()
        order = [el for el in order if el in specific_cts]
        if len(df_tmp) == 0:
            continue

        # Size this figure locally. Setting rcParams globally leaked into later cells,
        # including a width of 0 when the last donor had no specific clonotypes.
        fig, ax = plt.subplots(figsize=(1.5 * len(specific_cts), 8))
        sb.swarmplot(data=df_tmp, x='clone_id', y='n_max_dextramer', order=order,
                     hue='binding_10x', palette=colors, s=5, ax=ax)
        ax.set_title(donor)
        fig.savefig(f'{path_figs}/spec/ct_expansion_10x_{donor}_umis_{thresh}.pdf',
                    bbox_inches='tight', dpi=300)
        plt.show()

In [None]:
# Unique clonotypes per specificity (log scale): one panel per donor, donors ordered by cell count.
donors = adata.obs['donor'].value_counts().index
ncols = 3
# The row count follows the number of donors. The former fixed 2x3 grid raised
# IndexError for more than 6 donors.
nrows = max(1, math.ceil(len(donors) / ncols))
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 3, nrows * 3), squeeze=False)
axes = axes.reshape(-1)

for ax, donor in zip(axes, donors):
    adata_tmp = adata[adata.obs['donor'] == donor]
    adata_tmp = adata_tmp[adata_tmp.obs['binding_10x'] != 'No binding']
    df_counts = (pd.DataFrame(adata_tmp.obs.groupby('binding_10x')['clone_id'].nunique())
                 .sort_values('clone_id')
                 .reset_index())
    # Specificity with the most clonotypes first on the x-axis.
    sb.barplot(data=df_counts, y='clone_id', x='binding_10x',
               order=df_counts['binding_10x'].tolist()[::-1], ax=ax, palette=colors)
    ax.set_xlabel(donor)
    # Abbreviate the tick labels to their first three characters.
    ax.set_xticklabels([el.get_text()[:3] for el in ax.get_xticklabels()], rotation=90)
    ax.set_ylabel('Unique CTs')
    ax.set_yscale('log')
# Hide panels left empty in the last row.
for ax in axes[len(donors):]:
    ax.set_visible(False)
plt.tight_layout()
plt.show()


In [None]:
# Distinct clonotypes per specificity for each donor, excluding unbound cells.
for donor in adata.obs['donor'].value_counts().index:
    print(donor)
    # One AnnData subset (not an obs filter) so unused categories are pruned as before.
    adata_tmp = adata[(adata.obs['binding_10x'] != 'No binding') & (adata.obs['donor'] == donor)]
    clones_per_binding = adata_tmp.obs.groupby('binding_10x')['clone_id'].nunique()
    print(clones_per_binding)

In [None]:
# Distinct clonotypes per specificity over all donors.
is_bound = adata.obs['binding_10x'] != 'No binding'
adata_tmp = adata[is_bound]
clones_per_binding = adata_tmp.obs.groupby('binding_10x')['clone_id'].nunique()
print(clones_per_binding)

In [None]:
# Mean donor-level clone size over the distinct epitope-bound clonotypes of each donor.
bound_obs = adata[adata.obs['binding_10x'] != 'No binding'].obs
df_tmp = bound_obs.loc[:, ['donor', 'clone_id', 'clone_size_donor_ct']].drop_duplicates()
df_tmp.groupby('donor')['clone_size_donor_ct'].mean()

In [None]:
ir.pl.group_abundance(adata, groupby='donor', target_col='binding_ct', normalize=False)
# Same counts as text, one block per donor (largest donor first).
for d in adata.obs['donor'].value_counts().index:
    print(d)
    adata_tmp = adata[adata.obs['donor'] == d]
    print(adata_tmp.obs['binding_ct'].value_counts())

In [None]:
# 'LTD' for LTDEMIAQY-bound cells (NaN otherwise); MVP donor vs all other donors.
adata.obs['has_ltd'] = adata.obs['binding_10x'].apply(lambda b: np.nan if b != 'LTDEMIAQY' else 'LTD')
adata.obs['is_mvp'] = adata.obs['donor'].apply(lambda d: 'Other' if d != 'MVP' else 'MVP')

In [None]:
colors_tmp = {
    'No LTDs': 'gray',
    'Other': 'tab:blue',
    'MVP': 'tab:orange',
}
# Cells per Leiden cluster, split into LTD-bound cells of the MVP donor, LTD-bound
# cells of other donors, and all remaining cells ('No LTDs').
leiden_sizes = adata.obs.groupby('leiden').size()
ltd_obs = adata.obs[adata.obs['has_ltd'] == 'LTD']
df_ltds = (
    ltd_obs.groupby('leiden')['is_mvp'].value_counts()
    .unstack(fill_value=0)
    # Clusters without LTD cells, or a missing MVP/Other column, now count as 0.
    # Before, they produced NaN in 'No LTDs' or a KeyError.
    .reindex(index=leiden_sizes.index, columns=['MVP', 'Other'], fill_value=0)
)
df_ltds.index.name = 'leiden'
df_ltds['No LTDs'] = leiden_sizes - df_ltds['MVP'] - df_ltds['Other']
df_ltds.plot(kind='bar', stacked=True, color=colors_tmp)

In [None]:
# Same composition as fractions of each Leiden cluster.
row_totals = df_ltds.sum(axis=1)
df_ltds_norm = df_ltds.div(row_totals, axis=0)
df_ltds_norm.plot(stacked=True, kind='bar', color=colors_tmp)

## Binding Modes

Generate histograms for all epitope specificities with log-scaled UMI counts on the x-axis and the number of cells on the y-axis. Color cells defined as positive differently from negative cells, and show positive and negative cells in the same plot (not stacked on top of each other). Do this once with all time points pooled and once leaving out the S2 time point, since the latter has higher UMI counts than the other time points. Always use the same x-axis range so that expression levels can be compared between epitope specificities.

In [None]:
def plot_binding_distributions(adata, binding_mode, title, ncols=4):
    """Histogram of each epitope's UMI counts, split by binding assignment.

    One panel per epitope in ``adata.uns['epitopes']``. Cells assigned to the epitope
    in ``adata.obs[binding_mode]`` are drawn in green on a twin y-axis (only if
    there are any); all other cells are drawn in red. Both use percent with a KDE.
    All panels share a symlog x-axis from 0 to the largest UMI count over all
    epitopes, so levels are comparable across panels.

    Parameters
    ----------
    adata : AnnData
        obs must contain ``binding_mode`` and one UMI column per epitope.
    binding_mode : str
        obs column holding the per-cell epitope assignment.
    title : str
        Figure suptitle.
    ncols : int, default 4
        Panels per row; rows are added as needed (the former fixed 4x4 grid
        raised IndexError for more than 16 epitopes).
    """
    # Read epitopes from the argument rather than the notebook-global `epitopes`.
    epitope_list = list(adata.uns['epitopes'])
    nrows = max(1, math.ceil(len(epitope_list) / ncols))
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 3, nrows * 3), squeeze=False)
    axes = axes.reshape(-1)

    vmax = adata.obs[epitope_list].max().max()
    for ax, c in zip(axes, epitope_list):
        is_assigned = adata.obs[binding_mode] == c
        if is_assigned.sum() > 0:
            ax2 = ax.twinx()
            sb.histplot(adata[is_assigned].obs[c], ax=ax2, color='tab:green', stat='percent', kde=True)
            ax2.set_yticks([])
            ax2.set_yticklabels('')
            ax2.set_ylabel(None)

        sb.histplot(adata[~is_assigned].obs[c], ax=ax, color='tab:red', stat='percent', kde=True)
        ax.set_yticks([])
        ax.set_yticklabels([])
        ax.set_ylabel(None)
        ax.set_title(c)
        ax.set_xscale('symlog')
        ax.set_xlim((0, vmax))
        ax.set_xlabel(None)
        ax.grid(False, axis='y', which='both')

    # Hide panels left empty in the last row.
    for ax in axes[len(epitope_list):]:
        ax.set_visible(False)

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()
    
# One grid per binding assignment. To repeat without time point S2, pass
# adata[adata.obs['sample'] != 'S2'] and a matching title instead.
for assignment_col in ('binding_ct', 'binding_10x'):
    plot_binding_distributions(adata, assignment_col, title=f'{assignment_col} - Full Data')

### Purity vs UMI

In [None]:
bindings = adata.obs['binding_10x'].astype(str).value_counts().index.tolist()
bindings.remove('No binding')
nrows = 4
ncols = 4


fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 3, nrows * 3))
axes = axes.reshape(-1)

for i, ep in enumerate(bindings):
    cts_bind = adata[adata.obs['binding_10x']==ep].obs['clone_id'].unique().tolist()
    if 'nan' in cts_bind:
        cts_bind.remove('nan')
    df_tmp = pd.DataFrame(index=cts_bind, columns=['Purity', 'UMI'])
    for ct in cts_bind:
        df_ct = adata[adata.obs['clone_id']==ct].obs
        umi = df_ct[ep].mean()
        df_tmp.loc[ct, 'UMI'] = umi
        purity = np.sum(df_ct['binding_10x']==ep) / len(df_ct)
        df_tmp.loc[ct, 'Purity'] = purity
    
    sb.scatterplot(data=df_tmp, x='Purity', y='UMI', ax=axes[i])

    if len(df_tmp) > 1:
        corr, pval = stats.pearsonr(df_tmp['UMI'].values, df_tmp['Purity'].values)
    else:
        corr, pval = np.nan, np.nan
    axes[i].set_title(f'{ep}\ncorr: {corr:.2f}, pval: {pval:.2f}')

plt.suptitle('UMI vs Purity')
plt.tight_layout()
plt.show()

In [None]:
bindings = adata.obs['binding_10x'].astype(str).value_counts().index.tolist()
bindings.remove('No binding')
nrows = 3
ncols = 4


fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 3, nrows * 3))
axes = axes.reshape(-1)

for i, ep in enumerate(bindings):
    cts_bind = adata[(adata.obs['binding_10x']==ep)
                     & (adata.obs['clone_size_ct']>1)
                    ].obs['clone_id'].unique().tolist()
    if 'nan' in cts_bind:
        cts_bind.remove('nan')
    df_tmp = pd.DataFrame(index=cts_bind, columns=['Purity', 'UMI', 'Expansion'])
    for ct in cts_bind:
        df_ct = adata[adata.obs['clone_id']==ct].obs
        umi = df_ct[ep].mean()
        df_tmp.loc[ct, 'UMI'] = umi
        purity = np.sum(df_ct['binding_10x']==ep) / len(df_ct)
        df_tmp.loc[ct, 'Purity'] = purity
        exp = df_ct['clone_size_ct']
        df_tmp.loc[ct, 'Expansion'] = df_ct.iloc[0]['clone_size_ct']
    
    sb.scatterplot(data=df_tmp, x='Purity', y='UMI', ax=axes[i])

    if len(df_tmp) > 1:
        corr, pval = stats.pearsonr(df_tmp['UMI'].values, df_tmp['Purity'].values)
    else:
        corr, pval = np.nan, np.nan
    axes[i].set_title(f'{ep}\ncorr: {corr:.2f}, pval: {pval:.2f}')

plt.suptitle('UMI vs Purity')
plt.tight_layout()
plt.show()

In [None]:
# Per donor: each bound epitope's UMI count vs its share of the cell's total dextramer UMIs.
for donor in adata.obs['donor'].unique():
    adata_donor = adata[adata.obs['donor'] == donor]
    # Split each cell's assignment into single epitopes ('A+B' -> ['A', 'B']).
    assigned = adata_donor.obs[binding_mode].astype(str).str.split('+')

    # Epitopes bound by this donor's cells. Sorted because iterating a set of strings
    # depends on the hash seed, so the panel order changed between kernel runs.
    eps = sorted({ep for parts in assigned for ep in parts} - {'nan', 'No binding'})
    if not eps:
        continue  # no bound cells; plt.subplots(ncols=0) would raise

    ncols = len(eps)
    fig, axes = plt.subplots(nrows=1, ncols=ncols, figsize=(ncols * 5, 1 * 5), squeeze=False)
    axes = axes.reshape(-1)
    for ax, ep in zip(axes, eps):
        # Exact epitope membership. str.contains is a regex substring test.
        has_ep = assigned.apply(lambda parts, ep=ep: ep in parts).to_numpy()
        df_ep = adata_donor[has_ep].obs.copy()
        df_ep['Purity_umi'] = df_ep[ep] / df_ep['n_count_dextramer']
        df_ep = df_ep[['Purity_umi', ep, binding_mode, 'clone_id']]
        df_ep = df_ep[df_ep[ep].notna()]
        if len(df_ep) == 0:
            continue

        sb.scatterplot(data=df_ep, x=ep, y='Purity_umi', ax=ax, hue=binding_mode)
        if len(df_ep) > 1:
            corr, pval = stats.pearsonr(df_ep[ep].values, df_ep['Purity_umi'].values)
        else:
            corr, pval = np.nan, np.nan
        ax.set_title(f'{ep}\ncorr: {corr:.2f}, pval: {pval:.2f}')
        ax.set_xlabel('UMI')

    plt.suptitle(f'{donor} - UMI counts vs UMI Purity')
    plt.tight_layout()
    plt.show()

In [None]:
# Distinct clonotypes per donor among cells with a binding_ct assignment.
cells_with_binding_ct = adata[adata.obs['binding_ct'].notna()]
cells_with_binding_ct.obs.groupby('donor')['clone_id'].nunique()

## UMAPs

### Selected Genes

In [None]:
rcParams['figure.figsize'] = (5, 5)
# Violin plots per Leiden cluster: two genes, then the two CD8 scores.
for violin_keys in (['JUNB', 'CD69'], ['CD8 Cytotoxic_score', 'CD8 Cytokine_score']):
    sc.pl.violin(adata, keys=violin_keys, groupby='leiden')

In [None]:
# One UMAP call per entry; list entries are drawn side by side.
umap_colors = [
    ['JUNB', 'CD69'],
    ['CD8 Cytotoxic_score', 'CD8 Cytokine_score'],
    'MKI67', 'SP140', 'LAG3', 'TIGIT', 'ifn_seumois',
]
for umap_color in umap_colors:
    sc.pl.umap(adata, color=umap_color, cmap=colormap)

In [None]:
# Raw vs CLR-transformed value of every CITE-seq feature on the UMAP.
for cite_id in cite_ids_full:
    fig, (ax_raw, ax_clr) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
    sc.pl.umap(adata, color=cite_id, ax=ax_raw, show=False)
    sc.pl.umap(adata, color=f'clr_{cite_id}', ax=ax_clr, show=False)

    plt.tight_layout()
    plt.savefig(f'{path_figs}/citeSeq/{cite_id}.png', bbox_inches='tight', dpi=300)
    plt.show()

## SARS-CoV vs Others

In [None]:
epitopes_covid = ['LTDEMIAQY', 'QPYRVVVL', 'TFEYVSQPFLMDLE', 'YLQPRTFLL',
                  'RLQSLQTYV', 'VLNDILSRL', 'KIADYNYKL', 'YTNSFTRGVY', 'NYNYLYRLF',
                  'SPRRARSVA', 'FPQSAPHGV', 'IYKTPPIKDF']
epitopes_others = ['ATDSLNNEY', 'CTELKLSDY', 'FLRGRAYGL', 'RAKFKQLL']
# Build the column as object dtype from the start. Writing strings into the float
# column created by `= np.nan` is an incompatible-dtype setitem: a FutureWarning in
# pandas 2.x (hidden by the warnings filter at the top), slated to become an error.
specificity_group = pd.Series(np.nan, index=adata.obs.index, dtype=object)
specificity_group.loc[adata.obs['binding_10x'].isin(epitopes_covid)] = 'Sars-CoV'
specificity_group.loc[adata.obs['binding_10x'].isin(epitopes_others)] = 'Other'
adata.obs['specificity_group'] = specificity_group
adata.obs['specificity_group'].value_counts(dropna=False)

In [None]:
for obs_key in ('specificity_group', 'binding_10x'):
    sc.pl.umap(adata, color=obs_key, s=20)

In [None]:
# One UMAP panel per epitope: all cells as background, the epitope's binders on top.
epitope_list = list(adata.uns['epitopes'])
ncols = 4
# The grid follows the number of epitopes (the fixed 4x4 grid overflowed beyond 16).
nrows = max(1, math.ceil(len(epitope_list) / ncols))
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 3, nrows * 3), squeeze=False)
axes = axes.reshape(-1)

for ax, c in zip(axes, epitope_list):
    sc.pl.umap(adata, ax=ax, show=False)
    adata_tmp = adata[adata.obs['binding_10x'] == c]
    if len(adata_tmp) > 0:
        sc.pl.umap(adata_tmp, color='binding_10x', s=40, ax=ax, show=False)
    # Remove the existing legend. ax.legend() would build a new one just to delete
    # it, and it warns when the panel has no labelled artists.
    legend = ax.get_legend()
    if legend is not None:
        legend.remove()
    ax.set_ylabel(None)
    ax.set_title(c)
    ax.set_xlabel(None)
for ax in axes[len(epitope_list):]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig(f'{path_figs}/specificity_10x.pdf',
            bbox_inches='tight', dpi=300)
plt.show()

In [None]:
ncols = 2
nrows = 3
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 3.5, nrows * 3))
axes = axes.reshape(-1)


vmax = adata.obs[epitopes].max().max()
for i, c in enumerate(adata.obs['donor'].value_counts().index):
    sc.pl.umap(adata, ax=axes[i], show=False)
    adata_tmp = adata[adata.obs['donor']==c] 
    adata_tmp = adata_tmp[adata_tmp.obs['binding_10x']!='No binding']
    if len(adata_tmp) > 0:
        sc.pl.umap(adata_tmp, color='binding_10x', s=40,
                   ax=axes[i], show=False)
    if i != 1:
        axes[i].legend().remove()
    axes[i].set_ylabel(None)
    axes[i].set_title(c)
    axes[i].set_xlabel(None)

plt.tight_layout()
plt.savefig(f'{path_figs}/specificity_donor_10x.pdf', 
            bbox_inches='tight', dpi=300,)
plt.show()

In [None]:
ncols = 2
nrows = 3
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 3.5, nrows * 3))
axes = axes.reshape(-1)


vmax = adata.obs[epitopes].max().max()
for i, c in enumerate(adata.obs['donor'].value_counts().index):
    sc.pl.umap(adata, ax=axes[i], show=False)
    adata_tmp = adata[adata.obs['donor']==c] 
    adata_tmp = adata_tmp[adata_tmp.obs['binding_10x']=='LTDEMIAQY']
    if len(adata_tmp) > 0:
        sc.pl.umap(adata_tmp, color='binding_10x', s=40,
                   ax=axes[i], show=False)
    if i != 1:
        axes[i].legend().remove()
    axes[i].set_ylabel(None)
    axes[i].set_title(c)
    axes[i].set_xlabel(None)

plt.tight_layout()
plt.savefig(f'{path_figs}/ltd_donor_10x.pdf', 
            bbox_inches='tight', dpi=300,)
plt.show()

In [None]:
# Project helper from utils/visualisation.py, split by 'donor'.
# NOTE(review): the positional 3 and 2 look like grid dimensions — confirm which is rows and which is cols.
utils_vis.separate_umaps_by_condition(adata, 'donor', 3, 2, do_int_sort=False)

In [None]:
rcParams['figure.figsize'] = (4, 4)
# LTDEMIAQY-bound cells labelled by origin (MVP donor vs the rest); NaN for all other cells.
is_ltd = adata.obs['binding_10x'] == 'LTDEMIAQY'
ltd_label = adata.obs['donor'].eq('MVP').map({True: 'LTD-MVP', False: 'LTD-Cova'})
adata.obs['LTD_donor'] = ltd_label.where(is_ltd)
sc.pl.umap(adata, color='LTD_donor')

## LTD by donor

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(adata, show=False, ax=ax)
# Overlay the epitope-bound cells, coloured by the same column used to select them.
# The colour column was hard-coded to 'binding_ct', which only matched while
# binding_mode == 'binding_ct'.
sc.pl.umap(adata[adata.obs[binding_mode] != 'No binding'], color=binding_mode, show=False, ax=ax,
           s=50)
plt.tight_layout()
plt.show()

In [None]:
# All cells as background; LTDEMIAQY-bound cells on top, coloured by donor.
ltd_cells = adata[adata.obs[binding_mode] == 'LTDEMIAQY']
fig, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(adata, show=False, ax=ax)
sc.pl.umap(ltd_cells, color='donor', show=False, ax=ax, s=50)
ax.set_title(f'LTD by donor {binding_mode}')
plt.tight_layout()
plt.savefig(f'../../figures/mvp/manuscript/{adata.uns["celltype"]}_umap_ltd.pdf',
            bbox_inches='tight', dpi=dpi)
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
df_tmp = adata[(adata.obs[binding_mode]=='LTDEMIAQY')
              & (adata.obs['clone_id']!='nan')].obs[['donor', 'clone_id']]
df_tmp = df_tmp.groupby('donor')['clone_id'].value_counts().unstack().fillna(0)

# Ordering the df
stack_order = df_tmp.sum(axis=0).sort_values(ascending=False).index
df_tmp = df_tmp[stack_order]
donor_order = df_tmp.sum(axis=1).sort_values().index
df_tmp = df_tmp.loc[donor_order]


df_tmp.plot(kind='bar', stacked=True,# color=adata.uns['leiden_colors'], 
            ax=ax)

n_cts_by_donor = (df_tmp>0).sum(axis=1)
for i, (donor, n) in enumerate(n_cts_by_donor.items()):
    ax.text((i+0.5)/len(n_cts_by_donor), 1, f'n={n}', ha='center', 
            transform=ax.transAxes)


ax.legend().remove()
sb.despine(ax=ax)
ax.grid(False)
ax.set_ylabel('# Cells')
plt.tight_layout()
plt.savefig(f'../../figures/mvp/manuscript/{adata.uns["celltype"]}_ltd_clone_expansion.pdf', 
            bbox_inches='tight', dpi=dpi)
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
df_tmp = adata[(adata.obs[binding_mode]=='LTDEMIAQY')
              & (adata.obs['clone_id']!='nan')].obs[['donor', 'clone_id']]
df_tmp = df_tmp.groupby('donor')['clone_id'].value_counts().unstack().fillna(0)

# Ordering the df
stack_order = df_tmp.max(axis=0).sort_values(ascending=False).index
df_tmp = df_tmp[stack_order]
donor_order = df_tmp.sum(axis=1).sort_values().index
df_tmp = df_tmp.loc[donor_order]


df_tmp.plot(kind='bar', stacked=True,# color=adata.uns['leiden_colors'], 
            ax=ax)

n_cts_by_donor = (df_tmp>0).sum(axis=1)
for i, (donor, n) in enumerate(n_cts_by_donor.items()):
    ax.text((i+0.5)/len(n_cts_by_donor), 1, f'n={n}', ha='center', 
            transform=ax.transAxes)


ax.legend().remove()
sb.despine(ax=ax)
ax.grid(False)
ax.set_ylabel('# Cells')
plt.tight_layout()
plt.savefig(f'../../figures/mvp/manuscript/{adata.uns["celltype"]}_ltd_clone_expansion_v2.pdf', 
            bbox_inches='tight', dpi=dpi)
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))
ltd_obs = adata[(adata.obs[binding_mode] == 'LTDEMIAQY')
                & (adata.obs['clone_id'] != 'nan')].obs
# Recompute the donor order (ascending LTD cell count) here, so this cell no longer
# relies on `donor_order` left over from the clone-expansion cell above.
donor_order = ltd_obs.groupby('donor').size().sort_values().index
# Fraction of each donor's LTD-bound cells in every Leiden cluster.
df_tmp = ltd_obs.groupby('donor')['leiden'].value_counts(normalize=True).unstack().fillna(0)

df_tmp = df_tmp.loc[donor_order]
df_tmp.plot(kind='bar', stacked=True, ax=ax)

ax.legend(bbox_to_anchor=(1., 0.75), title='Leiden')
sb.despine(ax=ax)
ax.grid(False)
ax.set_ylabel('Fraction Leiden')
ax.set_title('LTD-specific cells')
plt.tight_layout()
plt.savefig(f'../../figures/mvp/manuscript/{adata.uns["celltype"]}_ltd_over_leiden.pdf',
            bbox_inches='tight', dpi=dpi)
plt.show()
# NOTE: the CSV file name ('..._nrom.csv') is kept as-is in case downstream code reads it.
df_tmp.to_csv(f'../../figures/mvp/manuscript/{adata.uns["celltype"]}_ltd_over_leiden_nrom.csv')

## Pseudotime

In [None]:
# UMAP coloured by 'dpt_pseudotime', saved for the manuscript.
fig, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(adata, ax=ax, show=False, color='dpt_pseudotime')

plt.tight_layout()
plt.savefig(f'../../figures/mvp/manuscript/{adata.uns["celltype"]}_umap_pseudotime.pdf',
            bbox_inches='tight', dpi=dpi)
plt.show()

In [None]:
# Pseudotime distribution per Leiden cluster.
sc.pl.violin(adata, keys='dpt_pseudotime', groupby='leiden')

## Genes over Pseudo-Time

In [None]:
# Cells per pseudotime bin (edges 0, 1/n_bins, ...; see np.digitize), stacked by
# Leiden cluster, one panel per binding_10x specificity.
n_bins = 10
bins = [1/n_bins * i for i in range(n_bins)]
# Copy only obs; copying the whole AnnData just to add one column is wasteful.
obs_binned = adata.obs.copy()
obs_binned['pseudotime_bins'] = np.digitize(obs_binned['dpt_pseudotime'].values, bins)

count_groups = obs_binned.groupby(['binding_10x', 'pseudotime_bins', 'leiden']).size()

spec_order = adata.obs['binding_10x'].value_counts().index

n_cols = 4
# The grid follows the number of specificities (the fixed 3x4 grid overflowed beyond 12).
n_rows = max(1, math.ceil(len(spec_order) / n_cols))
fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * 4, n_rows * 4), squeeze=False)
axes = axes.reshape(-1)

for j, spec in enumerate(spec_order):
    ax = axes[j]
    counts_ax = count_groups.loc[spec].unstack()
    counts_ax.plot(kind='bar', stacked=True, color=adata.uns['leiden_colors'], ax=ax)

    # Keep a single legend (first panel).
    if j != 0:
        ax.legend().remove()
    ax.set_title(spec)
for ax in axes[len(spec_order):]:
    ax.set_visible(False)

fig.tight_layout()
plt.show()

In [None]:
adata_tmp = adata.copy()
n_bins = 10
bins = [1/n_bins * i for i in range(n_bins)]
adata_tmp.obs['pseudotime_bins'] = np.digitize(adata_tmp.obs['dpt_pseudotime'].values, bins)

clone_size_cmap = sb.color_palette('flare', as_cmap=True)
for group in ['leiden', 'pseudotime_bins']:
    count_groups = adata_tmp.obs.groupby([group, 'clone_size']).size().unstack()

    # Colour each clone-size column by log(size) relative to the largest clone size.
    # Use a local dict: reassigning the global `colors` clobbered the epitope/donor palette.
    log_exp_max = np.log(count_groups.columns[-1])
    # A largest clone size of 1 gives log_exp_max == 0; avoid dividing by zero.
    color_scale = log_exp_max if log_exp_max > 0 else 1.0
    clone_size_colors = {size: clone_size_cmap(np.log(size) / color_scale)
                         for size in count_groups.columns}

    fig, ax = plt.subplots()
    count_groups.plot(kind='bar', stacked=True, color=clone_size_colors, ax=ax)
    ax.legend().remove()

    norm = mpl.colors.Normalize(vmin=0, vmax=log_exp_max)
    cbar = mpl.cm.ScalarMappable(norm=norm, cmap='flare')
    # A free-standing ScalarMappable needs an explicit `ax` (ValueError in matplotlib >= 3.6).
    fig.colorbar(cbar, ax=ax, label='Log expansion')
    ax.set_title(f'Over {group}')
    plt.show()

In [None]:
pseudotime_bin_counts = adata_tmp.obs['pseudotime_bins'].value_counts()
pseudotime_bin_counts

## Cite vs GEX

In [None]:
# Build a mapping from CITE-seq antibody names (in the naming of adata.uns['cite_ids'])
# to gene names, using the TotalSeq-C Human Universal Cocktail barcode sheet.
mapping_cite_gex = pd.read_excel('../../data/external/' +
                                 'TotalSeq_C_Human_Universal_Cocktail_v1_137_Antibodies_399905_Barcodes.xlsx', skiprows=1,
                                index_col=1)
mapping_cite_gex = mapping_cite_gex[['Description', 'Clone', 'Barcode', 'Ensemble ID', 'Gene name']]
# Only antibodies with an annotated gene can be compared with expression.
mapping_cite_gex = mapping_cite_gex[~mapping_cite_gex['Gene name'].isna()]
mapping_cite_gex['Description'] = mapping_cite_gex['Description'].str.replace('anti-human ', 'Hu.')
# Substring rewrites applied in dict order — the order matters ('-' is replaced only
# after the 'anti-mouse/human ' style prefixes have been rewritten).
# NOTE(review): pandas' default for `regex` in str.replace differs between versions;
# these patterns appear to contain no regex metacharacters, so both modes should agree.
dict_replace = {
    'anti-mouse/human ': 'HuMs.',
    'anti-human/mouse ': 'HuMs.',
    'anti-human/mouse/rat ': 'HuMsRt.',
    '-': '.',
    ',': '',
    'integrin': 'integrin.b7',
    'FcεRIα': 'FceRIa',
    'CD105': 'CD105_43A3',
    'CD226': 'CD226_11A8',
    'CD38': 'CD38_HIT2',
    'CD20': 'CD20_2H7'
}
for k, v in dict_replace.items():
    mapping_cite_gex['Description'] = mapping_cite_gex['Description'].str.replace(k, v)
    
# Keep only the first whitespace-separated token of each description.
mapping_cite_gex['Description'] = mapping_cite_gex['Description'].str.split().str[0]
# Whole-value renames (Series.replace, not substring) for ids that carry a clone suffix.
dict_replace = {
    'Hu.CD4': 'Hu.CD4_RPA.T4',
    'Hu.CD14': 'Hu.CD14_M5E2',
    'Hu.CD45': 'Hu.CD45_HI30',
    'Hu.Ig': 'Hu.Ig.LightChain.k',
    'Hu.CD3': 'Hu.CD3_UCHT1',
}
for k, v in dict_replace.items():
    mapping_cite_gex['Description'] = mapping_cite_gex['Description'].replace(k, v)
# Sanity check: descriptions that still do not match any id in adata.uns['cite_ids'].
print([el for el in mapping_cite_gex['Description'] if el not in adata.uns['cite_ids']])
# Final mapping: description -> list of genes (the sheet lists several genes separated by ', ').
mapping_cite_gex = dict(mapping_cite_gex[['Description', 'Gene name']].values)
mapping_cite_gex = {k: v.split(', ') for k, v in mapping_cite_gex.items()}
#mapping_cite_gex

In [None]:
genes = [gene for gene_list in mapping_cite_gex.values() for gene in gene_list]
# Mapped genes that are absent from the expression matrix.
missing_genes = [gene for gene in genes if gene not in adata.var_names]
missing_genes

In [None]:
def jitter(vals, rng=None):
    """Add Gaussian noise to ``vals`` so overlapping scatter points separate.

    The noise standard deviation is 1% of the magnitude of the largest non-NaN
    value. For non-negative maxima this matches the former ``vals.max() * 0.01``.
    It no longer raises for all-negative input (e.g. CLR values), where the scale
    used to become negative.

    Parameters
    ----------
    vals : array-like
        Values to jitter; NaN entries stay NaN.
    rng : numpy.random.Generator, optional
        Source of randomness for reproducible plots; the global ``np.random``
        state is used when omitted.

    Returns
    -------
    numpy.ndarray
        Float array with the same shape as ``vals``.
    """
    vals = np.asarray(vals, dtype=float)
    if vals.size == 0:
        return vals.copy()  # max() of an empty array would raise
    scale = abs(np.nanmax(vals)) * 0.01
    sampler = np.random if rng is None else rng
    return sampler.normal(vals, scale, vals.shape)

In [None]:
rcParams['figure.figsize'] = (5, 5)
# Jittered scatter + KDE of each CITE-seq CLR value against its mapped gene's expression.
# `gene_list` avoids clobbering the global `genes` list from the cell above.
for cite, gene_list in mapping_cite_gex.items():
    present_genes = [gene for gene in gene_list if gene in adata.var_names]
    if not present_genes:
        continue
    # Cells with a CLR value for this antibody; computed once per antibody, not per gene.
    adata_tmp = adata[~adata.obs[f'clr_{cite}'].isna()]
    for gene in present_genes:
        x = jitter(adata_tmp.obs[f'clr_{cite}'].values)
        expr = adata_tmp[:, gene].X
        # `.A` was removed from SciPy sparse matrices and never existed on dense arrays.
        expr = expr.toarray() if sparse.issparse(expr) else np.asarray(expr)
        y = jitter(expr.reshape(-1))
        plot = sb.scatterplot(x=x, y=y, s=4)
        plot = sb.kdeplot(x=x, y=y, levels=5, fill=True, alpha=0.6, cut=2)
        plot.set_xlabel(cite)
        plot.set_ylabel(gene)
        plt.savefig(f'../../figures/mvp/{adata.uns["celltype"]}/citeVsGex/{cite}_vs_{gene}.png', dpi=dpi,
                    bbox_inches='tight')
        plt.show()

## Story Structure Plots

### Fig 2A
Leiden clustering of Dextramer UMAP (to show phenotypic assignment)

In [None]:
sc.settings.vector_friendly = False
# Fig 2A: UMAP coloured by Leiden cluster, saved for the manuscript.
fig, ax = plt.subplots(figsize=(5, 5))
sc.pl.umap(adata, color='leiden', ax=ax, show=False)
plt.savefig(f'../../figures/mvp/manuscript/{adata.uns["celltype"]}_umap_leiden.pdf',
            dpi=dpi, bbox_inches='tight')
plt.show()

### Fig 2B
Show the defining genes of the Leiden clusters. Try a dot plot (average expression, percent of cells expressing) in which each gene appears only once, with its expression shown across all clusters.

In [None]:
def double_histogram_dotplot(adata, n_genes, groupby='leiden', save_name=None):
    """Dotplot of top-ranked features per group with a feature dendrogram drawn on top.

    Takes the union of the top ``n_genes`` ranked features of every group in
    ``adata.uns[f'rank_genes_groups_{groupby}']``. The features are ordered by
    complete-linkage clustering of their per-group mean values (distance:
    1 - Pearson correlation). That feature dendrogram is drawn above scanpy's
    rank_genes_groups dotplot, which itself draws the group dendrogram.

    Parameters
    ----------
    adata : AnnData
        Must contain ``uns[f'rank_genes_groups_{groupby}']`` and ``obs[groupby]``.
        ``X`` may be sparse or dense.
    n_genes : int
        Number of top-ranked features taken from each group.
    groupby : str, default 'leiden'
        obs column used for grouping; also selects the rank_genes_groups key.
    save_name : str or None
        If given, the figure is saved to
        ``../../figures/mvp/manuscript/{save_name}.pdf``.

    Notes
    -----
    Relies on the axes dict returned by scanpy ('gene_group_ax', 'mainplot_ax')
    and on the private ``scipy.cluster.hierarchy._plot_dendrogram``; both may
    change between library versions.
    """
    import scipy.cluster.hierarchy as sch
    from scipy.spatial import distance

    # Union of the top `n_genes` features of all groups: `names` holds one record per
    # rank with one entry per group. dict.fromkeys keeps first-seen order, so the
    # clustering input (and hence the dendrogram leaf order) is reproducible across
    # runs, unlike set(), whose order depends on string-hash randomisation.
    top_ranked = adata.uns[f'rank_genes_groups_{groupby}']['names'][:n_genes]
    top_genes = list(dict.fromkeys(g for rank in top_ranked.flatten().tolist() for g in rank))

    # Mean value per group for each feature (dense or sparse X).
    X = adata[:, top_genes].X
    X = X.toarray() if sparse.issparse(X) else np.asarray(X)
    df_genes = pd.DataFrame(data=X, columns=top_genes, index=adata.obs.index)
    df_genes[groupby] = adata.obs[groupby]
    df_counts = df_genes.groupby(groupby)[top_genes].mean()

    # Feature dendrogram, computed without plotting.
    correlation_matrix = df_counts.corr(method='pearson')
    distance_matrix = (1 - correlation_matrix).to_numpy()
    # Features whose mean is constant across groups have NaN correlation; treat them
    # as maximally uncorrelated (distance 1) so squareform/linkage do not fail.
    distance_matrix = np.nan_to_num(distance_matrix, nan=1.0)
    np.fill_diagonal(distance_matrix, 0.0)
    correlation_condensed = distance.squareform(distance_matrix, checks=False)
    z_var = sch.linkage(correlation_condensed, method='complete')
    dendro_info = sch.dendrogram(z_var, labels=list(top_genes), no_plot=True,
                                 color_threshold=0, above_threshold_color='k')

    # Standard dotplot. The dummy bracket creates an extra axis above the plot
    # ('gene_group_ax'), which is reused for the feature dendrogram.
    plot = sc.pl.rank_genes_groups_dotplot(adata, var_names=dendro_info['ivl'], dendrogram=True, show=False,
                                           var_group_labels=[''], var_group_positions=[(4, 10)],
                                           key=f'rank_genes_groups_{groupby}')

    # Remove the dummy bracket.
    group_ax = plot['gene_group_ax']
    group_ax.cla()
    group_ax.grid(False)
    sb.despine(ax=group_ax, bottom=True, left=True)

    # Draw the feature dendrogram into the bracket axis: scale its x coordinates to the
    # number of features, shift it up by 0.03, then set the axis height to 0.5
    # (figure coordinates).
    icoord = np.array(dendro_info['icoord'])
    scale_factor = (len(top_genes) - 0.5) / icoord.max()
    sch._plot_dendrogram(icoords=icoord * scale_factor, dcoords=np.array(dendro_info['dcoord']) + 0.03,
                         ivl=dendro_info['ivl'],
                         p=30, n=len(top_genes), mh=max(z_var[:, 2]),
                         orientation='top', no_labels=True, color_list=dendro_info['color_list'], ax=group_ax)
    group_ax.set_xlim([1, len(top_genes) + 1])
    pos = group_ax.get_position()
    group_ax.set_position([pos.x0, pos.y0, pos.x1 - pos.x0, 0.5])

    # Restore the feature tick labels on the main plot; they are lost while plotting.
    main_ax = plot['mainplot_ax']
    main_ax.set_xticks([el + 0.5 for el in range(len(top_genes))])
    main_ax.set_xlim([0, len(top_genes)])
    main_ax.set_xticklabels(dendro_info['ivl'])

    if save_name is not None:
        plt.savefig(f'../../figures/mvp/manuscript/{save_name}.pdf',
                    bbox_inches='tight', dpi=dpi)
    plt.show()
    
# Fig 2B (gene expression): DEG dotplot with feature + group dendrograms ...
double_histogram_dotplot(
    adata, 5, groupby='leiden',
    save_name=f'{adata.uns["celltype"]}_gex_degs_2dendro',
)

# ... and scanpy's built-in DEG dotplot (group dendrogram only) for comparison.
fig, ax = plt.subplots(figsize=(15, 4))
sc.pl.rank_genes_groups_dotplot(
    adata,
    n_genes=5,
    dendrogram=True,
    key='rank_genes_groups_leiden',
    ax=ax,
    show=False,
)
gex_degs_pdf = f'../../figures/mvp/manuscript/{adata.uns["celltype"]}_gex_degs.pdf'
plt.savefig(gex_degs_pdf, bbox_inches='tight', dpi=dpi)
plt.show()

In [None]:
# Fig 2B (CITE-seq): AnnData whose X holds the CLR-normalised surface-protein values,
# so the same DEG dotplots can be drawn for proteins.
# `uns` is passed as a shallow copy. AnnData may keep a reference to the dict it is
# given; with a reference, the key swap below would also delete and overwrite
# adata.uns['rank_genes_groups_leiden'] (the GEX ranking). The GEX cell above and this
# cell would then fail on re-run.
clr_columns = [f'clr_{el}' for el in adata.uns['cite_ids']]
adata_cite = ann.AnnData(obs=adata.obs[['leiden']],
                         uns=dict(adata.uns),
                         X=adata.obs[clr_columns])
# Strip the 'clr_' prefix so var_names match the CITE ids.
adata_cite.var_names = [el[len('clr_'):] for el in adata_cite.var_names]
adata_cite.X = sparse.csr_matrix(adata_cite.X)
# Put the CITE-based ranking under the key that the plotting functions read.
adata_cite.uns['rank_genes_groups_leiden'] = adata_cite.uns.pop('rank_genes_groups_leiden_cite')

double_histogram_dotplot(adata_cite, 5, groupby='leiden',
                         save_name=f'{adata.uns["celltype"]}_cite_degs_2dendro')

fig, ax = plt.subplots(figsize=(15, 4))
sc.pl.rank_genes_groups_dotplot(adata_cite, n_genes=5, dendrogram=True,
                                key='rank_genes_groups_leiden',
                                ax=ax, show=False)
plt.savefig(f'../../figures/mvp/manuscript/{adata.uns["celltype"]}_cite_degs.pdf',
            bbox_inches='tight', dpi=dpi)
plt.show()

### Fig 2C
Showcases of epitope specificities: UMAP with evolution of population over time (as above, take UMAP from all epitope specificities and non-binders from all donors and time points as basis and then highlight only specific data: here highlight one epitope specificity and create a different UMAP for each time point, but merge data for all donors).


In [None]:
# Fig 2C: one UMAP panel per epitope. All cells form the background; cells of clonotypes
# with at least one cell labelled with that epitope (binding_mode column) are overlaid
# and coloured by Leiden cluster.
n_cols = 4
# Enough rows for every epitope (a fixed 4x4 grid raises IndexError above 16 epitopes).
n_rows = max(1, math.ceil(len(adata.uns['epitopes']) / n_cols))
fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * 3, n_rows * 3), squeeze=False)
axes = axes.reshape(-1)

# `leiden_colors` is aligned with the category order of obs['leiden']. value_counts()
# orders clusters by size and would pair clusters with the wrong colours.
colors = dict(zip(adata.obs['leiden'].cat.categories.astype(str), adata.uns['leiden_colors']))

for i, ep in enumerate(adata.uns['epitopes']):
    ax = axes[i]

    sc.pl.umap(adata, ax=ax, show=False)

    cts_ep = adata.obs.loc[adata.obs[binding_mode] == ep, 'clone_id'].unique().tolist()
    if 'nan' in cts_ep:
        cts_ep.remove('nan')

    adata_tmp = adata[adata.obs['clone_id'].isin(cts_ep)]
    if len(adata_tmp) > 0:
        sc.pl.umap(adata_tmp,
                   color='leiden',
                   palette=colors,
                   size=30,
                   ax=ax, show=False)
        ax.get_legend().remove()

    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_title(ep)

# Hide grid cells that have no epitope.
for ax_unused in axes[len(adata.uns['epitopes']):]:
    ax_unused.axis('off')

fig.tight_layout()
#plt.savefig('../../figures/dextramer/paper/2c_umap_binders_by_epitope+time.pdf', dpi=dpi)
plt.show()

### Fig 2D
Quantifications: Phenotype (Leiden cluster fraction) of different epitope specificities over time

In [None]:
# Fig 2D data: per epitope, the fraction of each donor's cells in each Leiden cluster.
# Only cells of clonotypes with at least one cell labelled with the epitope count.
# Result per epitope: rows = donors, columns = Leiden clusters (int), rows sum to 1
# (or are all 0 for donors without such cells).
dfs_binding_by_leiden = {}

# Fixed cluster and donor sets, so every epitope's table has the same rows and columns.
all_clusters = sorted(adata.obs['leiden'].astype(str).astype(int).unique())
all_donors = sorted(adata.obs['donor'].astype(str).unique())

for ep in adata.uns['epitopes']:
    cts_ep = adata.obs.loc[adata.obs[binding_mode] == ep, 'clone_id'].unique().tolist()
    if 'nan' in cts_ep:
        cts_ep.remove('nan')

    obs_ep = adata.obs.loc[adata.obs['clone_id'].isin(cts_ep), ['leiden', 'donor']]
    if obs_ep.empty:
        df_tmp = pd.DataFrame(0.0, index=all_clusters, columns=all_donors)
    else:
        # Cell counts per (cluster, donor). Missing clusters and donors become zero rows
        # and columns under their own labels (not under a positional index).
        df_tmp = pd.crosstab(obs_ep['leiden'].astype(str).astype(int),
                             obs_ep['donor'].astype(str))
        df_tmp = df_tmp.reindex(index=all_clusters, columns=all_donors, fill_value=0).astype(float)
    df_tmp = df_tmp.rename_axis(index='leiden', columns='donor')

    df_tmp = df_tmp.transpose()
    # Donors without cells give 0/0 = NaN; report them as 0.
    df_tmp = df_tmp.div(df_tmp.sum(axis=1), axis=0).fillna(0)
    dfs_binding_by_leiden[ep] = df_tmp

In [None]:
# Fig 2D: stacked bars of Leiden-cluster fractions per donor, one panel per epitope.
n_epitopes = len(dfs_binding_by_leiden)
n_cols = 4
# Enough rows for every epitope (a fixed 4x4 grid silently drops epitopes above 16).
n_rows = max(1, math.ceil(n_epitopes / n_cols))
fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * 3, n_rows * 3), squeeze=False)
axes = axes.reshape(-1)

# `leiden_colors` is aligned with the category order of obs['leiden'] (not with the
# cluster-size order of value_counts()). Keys are ints to match the table columns.
colors = dict(zip(adata.obs['leiden'].cat.categories.astype(int), adata.uns['leiden_colors']))

# Legend on the last panel of the first row (or the last panel if there are fewer).
legend_idx = min(n_cols, n_epitopes) - 1

for i, (ep, df_fracs) in enumerate(dfs_binding_by_leiden.items()):
    ax = axes[i]
    df_fracs.plot(kind='bar', stacked=True, ax=ax, ylim=[0, 1], color=colors,
                  title=ep, xlabel='')

    if i == legend_idx:
        handles, labels = ax.get_legend_handles_labels()
        label_dict = dict(zip(labels, handles))
        ax.legend(list(label_dict.values()), list(label_dict.keys()),
                  markerscale=0.5,
                  loc='right',
                  bbox_to_anchor=(1.3, .5),
                  frameon=False)
    else:
        ax.get_legend().remove()
    ax.tick_params(axis='x', labelrotation=90)
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

# Hide unused panels; relies on the panel count, not on a leaked loop variable.
for ax_unused in axes[n_epitopes:]:
    ax_unused.axis('off')

fig.tight_layout()
#plt.savefig('../../figures/dextramer/paper/2d_leiden_over_time_per_specificity.pdf', dpi=dpi)
plt.show()


### Fig 2E alternative
Show UMAP with clonal expansion.

In [None]:
# UMAP coloured by clone size (clonal expansion).
sc.pl.umap(adata=adata, color='clone_size')

In [None]:
# UMAP coloured by the clipped clone size.
sc.pl.umap(adata=adata, color='clone_size_clipped')

### Fig 2E
Show the UMAP colored by pseudotime (starting point still to be specified). Then, for each time point and epitope specificity, show a separate histogram-like plot with pseudotime on the x-axis and the number of cells on the y-axis, with bars color-coded by phenotypic cluster.

In [None]:
# UMAP coloured by diffusion pseudotime.
sc.pl.umap(adata=adata, color='dpt_pseudotime')

In [None]:
# Fig 2E: per donor (rows) and epitope label (columns), cell counts per pseudotime bin,
# stacked by Leiden cluster.
epitopes_present = adata.obs[binding_mode].value_counts().index
donors = adata.obs['donor'].unique()
n_cols = len(epitopes_present)
n_rows = len(donors)
# squeeze=False keeps `axes` 2-D even with a single donor or a single epitope.
fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols * 3, n_rows * 3), squeeze=False)

# `leiden_colors` is aligned with the category order of obs['leiden'] (value_counts()
# would order by cluster size and mismatch colours).
leiden_order = list(adata.obs['leiden'].cat.categories)
colors = dict(zip(leiden_order, adata.uns['leiden_colors']))

step_width = 0.1
# Pseudotime bin index per cell (NaN where pseudotime is NaN); stored in adata.obs.
adata.obs['pseudotime_bin'] = adata.obs['dpt_pseudotime'] // step_width
# NaN is excluded so it never appears as an extra bar.
bin_order = sorted(adata.obs['pseudotime_bin'].dropna().unique().tolist())

for j, d in enumerate(donors):
    for i, ep in enumerate(epitopes_present):
        ax = axes[j][i]

        in_panel = ((adata.obs[binding_mode] == ep) &
                    (adata.obs['donor'] == d) &
                    adata.obs['pseudotime_bin'].notna())
        obs_panel = adata.obs.loc[in_panel, ['pseudotime_bin', 'leiden']]
        if len(obs_panel) > 0:
            df_counts = obs_panel.groupby('pseudotime_bin')['leiden'].value_counts().unstack(fill_value=0)
            df_counts.columns = list(df_counts.columns)
            # Same bins and cluster order in every panel, so bars line up and stack alike.
            df_counts = df_counts.reindex(index=bin_order, columns=leiden_order, fill_value=0)

            df_counts.plot(kind='bar', stacked=True, ax=ax, color=colors)
            sb.despine(ax=ax)
            ax.grid(axis='x')
            ax.get_legend().remove()
        else:
            ax.grid(False)
            sb.despine(ax=ax, left=True)
            ax.set_yticklabels([])
            ax.set_xticks(range(len(bin_order)))

        ax.set_ylabel(None)
        ax.set_xlabel(None)
        ax.set_title(None)

for j, e in enumerate(epitopes_present):
    axes[0][j].set_title(e)
for i, d in enumerate(donors):
    axes[i][0].set_ylabel(d)

fig.tight_layout()
#plt.savefig('../../figures/dextramer/paper/2e_pseudotime_by_time_specificity.pdf', dpi=dpi)
plt.show()

### Fig 2F
Diversity of repertoires over time (Gini index and inverse Simpson index, as in Bacher et al., for evenness and richness, respectively). Compute it once per time point with all phenotypic clusters pooled, once per phenotypic cluster with all time points pooled, and once per combination of time point and phenotypic cluster. Do everything once with all donors pooled and once per donor (only for YLQ and LTD, and only for donors with the cognate HLA for these epitopes).

In [None]:
print(len(adata))
# Keep only cells with a clonotype: drop both missing clone_ids and the string 'nan'
# (a plain `!= 'nan'` test keeps real NaN values).
adata_diversity = adata[adata.obs['clone_id'].notna() & (adata.obs['clone_id'] != 'nan')]
print(len(adata_diversity))

In [None]:
def gini(*args, **kwargs):
    """Gini index of clonotype counts, using the trapezoid rule by default.

    Wrapper around ``skbio.diversity.alpha.gini_index`` that can be passed to
    scirpy as a diversity metric. skbio defaults to rectangular estimation,
    which yields negative Gini indices when all clonotypes are equally
    frequent; the trapezoid rule avoids that.

    A ``method`` keyword passed by the caller overrides the default instead of
    raising ``TypeError`` for a duplicate keyword argument.
    """
    kwargs.setdefault('method', 'trapezoids')
    return skbio.diversity.alpha.gini_index(*args, **kwargs)


def inverse_simpson(*args, **kwargs):
    """Inverse Simpson index, the form reported by Bacher et al.

    ``skbio.diversity.alpha.simpson`` returns 1 - D, where D is Simpson's
    dominance (sum of squared clonotype frequencies); this returns 1 / D.
    All arguments are forwarded to skbio's ``simpson``.
    """
    simpson_index = skbio.diversity.alpha.simpson
    return 1 / (1 - simpson_index(*args, **kwargs))

In [None]:
def plot_diversity(adata_tmp, condition, metric='gini_index', title_suffix='', plot_type=None, ax=None):
    """Compute clonotype alpha diversity per group and optionally plot it.

    Parameters
    ----------
    adata_tmp : AnnData
        Cells to analyse; ``obs`` must contain ``condition`` and ``clone_id``.
    condition : str
        obs column to group by. When it equals the global ``binding_mode``, cells
        without a label in that column are dropped first.
    metric : str or callable, default 'gini_index'
        Passed to ``scirpy.tl.alpha_diversity``; callables are named by ``__name__``.
    title_suffix : str
        Text placed before ``' Diversity over {condition}'`` in the title.
    plot_type : callable or None
        Seaborn-style function called as ``plot_type(data=, x=, y=, ax=)``; no plot
        when falsy.
    ax : matplotlib Axes or None
        Target axes. If None, the figure is shown after plotting.

    Returns
    -------
    pandas.DataFrame
        One row per group, with the metric column and a ``condition`` column.
    """
    if condition == binding_mode:
        adata_tmp = adata_tmp[~adata_tmp.obs[binding_mode].isna()]
    diversity = ir.tl.alpha_diversity(adata_tmp, groupby=condition, target_col='clone_id', metric=metric,
                                      inplace=False)
    metric_name = metric if isinstance(metric, str) else metric.__name__
    diversity = diversity.rename(columns={0: metric_name})
    diversity[condition] = diversity.index
    if plot_type:
        plot = plot_type(data=diversity, x=condition, y=metric_name, ax=ax)
        plot.set_title(f'{title_suffix} Diversity over {condition}')
        if condition == binding_mode:
            # Epitope labels are long; rotate them. Uses the returned axes, so this also
            # works when no `ax` was given.
            plot.tick_params(axis='x', labelrotation=90)
        if ax is None:
            # Show only after all formatting is applied.
            plt.show()
    return diversity


# Cells with a clonotype (neither missing nor the string 'nan'), and the subset that
# carries an epitope label.
clone_known = adata.obs['clone_id'].notna() & (adata.obs['clone_id'] != 'nan')
adata_diversity = adata[clone_known]
adata_div_spec = adata_diversity[adata_diversity.obs[binding_mode].notna()]

# Grid layout: rows are (title prefix, cells); columns are (grouping column, plot function).
adatas = [('All cells -', adata_diversity), ('Specific cells -', adata_div_spec)]
conditions = [(grouping, sb.barplot) for grouping in ('donor', 'leiden', binding_mode)]

# One figure per diversity metric.
for metric in (gini, 'simpson', inverse_simpson):
    fig, axes = plt.subplots(ncols=len(conditions), nrows=len(adatas),
                             figsize=(len(conditions) * 5, len(adatas) * 5))
    for row, (title, adata_subset) in enumerate(adatas):
        for col, (condition, plot_type) in enumerate(conditions):
            plot_diversity(adata_subset, condition, metric, title_suffix=title,
                           plot_type=plot_type, ax=axes[row][col])
    plt.tight_layout()
    plt.show()

In [None]:
# Number of cells per 10x binding label, most frequent first.
adata.obs['binding_10x'].value_counts(sort=True, dropna=True)

In [None]:
# Gini index of clonotype counts per donor, restricted to cells labelled with the LTD epitope.
epitope_ltd = 'LTDEMIAQY'
fig, ax = plt.subplots(figsize=(6, 6))
has_clone = adata.obs['clone_id'].notna() & (adata.obs['clone_id'] != 'nan')
adata_tmp = adata[has_clone & (adata.obs[binding_mode] == epitope_ltd)]
# title_suffix is prepended to ' Diversity over donor'; it names the epitope and metric
# instead of repeating the word "Diversity".
div_ltd = plot_diversity(adata_tmp, 'donor', gini, title_suffix=f'{epitope_ltd} Gini',
                         plot_type=sb.barplot, ax=ax)
path_manuscript = '../../figures/mvp/manuscript'
os.makedirs(path_manuscript, exist_ok=True)
div_ltd.to_csv(f'{path_manuscript}/{adata.uns["celltype"]}_diversity_gini_ltd.csv')
plt.tight_layout()
plt.show()

In [None]:
# Number of distinct clonotypes per donor / Leiden cluster / epitope. Top row: all
# cells with a clonotype; bottom row: epitope-labelled cells. The counts do not depend
# on a diversity metric, so the figure is drawn once (not once per metric).
fig, axes = plt.subplots(ncols=len(conditions), nrows=len(adatas),
                         figsize=(len(conditions) * 5, len(adatas) * 5), squeeze=False)

for i, (title, adata_tmp) in enumerate(adatas):
    for j, (condition, plot_type) in enumerate(conditions):
        ax = axes[i][j]
        counts = adata_tmp.obs.groupby(condition)['clone_id'].nunique().reset_index()
        sb.barplot(data=counts, y='clone_id', x=condition, ax=ax)
        ax.set_title(f'{title} Unique clonotypes per {condition}')
        ax.set_ylabel('Number of unique clonotypes')
        if condition == binding_mode:
            # Epitope labels are long; rotate them.
            ax.tick_params(axis='x', labelrotation=90)

plt.tight_layout()
plt.show()

## Supplementary 2

### Supp 2A
UMAPs with epitope specificity assignment

In [None]:
# Supp 2A: all cells as background, epitope-labelled cells drawn on top in their epitope colour.
ax = sc.pl.umap(adata, show=False)
adata_labelled = adata[adata.obs[binding_mode].notna()]
ax = sc.pl.umap(adata_labelled, color=binding_mode, size=15, ax=ax, show=False)
#plt.savefig(f'{path_figs}/paper/S2a_umap_binding.pdf', bbox_inches='tight', dpi=dpi)

### Supp 2C 
Show UMAP with IFN score (IFN_seumois) and MKI67; potentially also show other genes (depending on DEG results) such as TCF7, etc.

In [None]:
# Supp 2C: UMAPs of the IFN score and MKI67 expression (non-raw values).
sc.pl.umap(
    adata,
    color=['ifn_seumois', 'MKI67'],
    use_raw=False,
    cmap=colormap,
    show=False,
)
#plt.savefig(f'{path_figs}/paper/S2c_umaps_scores_genes.pdf', bbox_inches='tight', dpi=dpi)

### Supp 2E 
Show pseudotime plots with pseudotime on the x-axis and the expression of selected genes (to be specified; see the point above) on the y-axis, with points color-coded by phenotypic cluster.

In [None]:
# Supp 2E: each score/gene against diffusion pseudotime, coloured by Leiden cluster.
genes = ['ifn_seumois', 'MKI67', 'TCF7']
for gene in genes:
    sc.pl.scatter(
        adata,
        x='dpt_pseudotime',
        y=gene,
        color='leiden',
        size=15,
    )