# Notebook to add Information

In [None]:
%load_ext autoreload

In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os

import scanpy as sc
import scirpy as ir
import anndata as ann
import numpy as np
import pandas as pd
import seaborn as sb
from tqdm import tqdm
import math

import matplotlib.pyplot as plt
import matplotlib.colors as clrs

from matplotlib import rcParams

In [None]:
%autoreload 2
import sys
sys.path.append('..')

import utils.annotation as utils_annotation
import utils.representation as utils_representation
import utils.visualisation as utils_vis

In [None]:
# Global plotting configuration for the notebook.
# All scanpy figure options are passed in ONE call: `set_figure_params` re-applies its
# default `dpi=80` on every call, so a second call without `dpi` would undo `dpi=150`.
sc.settings.verbosity = 3
sc.set_figure_params(dpi=150, vector_friendly=True, color_map='viridis', transparent=True)
sb.set_style('whitegrid')

colormap = 'flare'

## Collect input data

In [None]:
# Input directory holding the merged runs, and output directory for this notebook.
path_base = '../../data/20231017'
path_out = '../../data/mvp'
# Target files for the annotated CD4 / CD8 subsets written in the last cell.
path_annotated_cd4 = f'{path_out}/02_mvp_annotated_cd4.h5ad'
path_annotated_cd8 = f'{path_out}/02_mvp_annotated_cd8.h5ad'

In [None]:
# Read the three merged run files and stack them into a single AnnData object.
adatas = []

for run_idx in (1, 2, 3):
    adata_run = sc.read(f'{path_base}/01_mixed_merged_{run_idx}.h5ad')
    # Reset the stored log1p base on every run before concatenation.
    adata_run.uns['log1p']['base'] = None
    adatas.append(adata_run)

first_run, *other_runs = adatas
adata = first_run.concatenate(other_runs)

In [None]:
# Keep only the cells that belong to one of the pools analysed in this notebook.
mixed_samples = [
    'run_1_HA1', 'run_1_HA2', 'run_1_HA3',
    'run_1_HA5', 'run_1_HA6', 'run_1_HA7',
    'run_2_HA3', 'run_2_HA8', 'run_3_HA6',
]
is_mixed_pool = adata.obs['pool'].isin(mixed_samples)
adata = adata[is_mixed_pool].copy()
# Displayed as cell output: number of cells per remaining pool.
adata.obs['pool'].value_counts()

In [None]:
# Bar plot of the number of cells per pool.
# `rename_axis` + `reset_index(name=...)` pin the column names to ('index', 'pool')
# independently of the pandas version (pandas >= 2.0 would otherwise name them
# ('pool', 'count'), which breaks x='index').
pool_counts = adata.obs['pool'].value_counts().rename_axis('index').reset_index(name='pool')
plot = sb.barplot(data=pool_counts, x='index', y='pool')
_ = plot.set_xticklabels(plot.get_xticklabels(), rotation=90)

## Pool level annotation

In [None]:
# Create list of epitopes
# Each entry is also the name of a column in adata.obs (read two lines below).
epitopes = ['LTDEMIAQY', 'QPYRVVVL', 'YLQPRTFLL', 'RLQSLQTYV', 
            'VLNDILSRL', 'KIADYNYKL', 'YTNSFTRGVY', 'NYNYLYRLF', 
            'TFEYVSQPFLMDLE', 'ATDSLNNEY', 'CTELKLSDY', 
            'FLRGRAYGL', 'RAKFKQLL', 'SPRRARSVA', 'FPQSAPHGV', 'IYKTPPIKDF',]
# Keep the epitope names and the per-cell epitope columns on the AnnData object.
adata.uns['epitopes'] = epitopes
adata.obsm['epitopes'] = adata.obs[epitopes]

In [None]:
# Epitopes associated with each pool; used further below to select the epitope
# columns that are plotted per pool.
pool_2_epitope = {
    'run_1_HA1': ['LTDEMIAQY', 'QPYRVVVL', 'TFEYVSQPFLMDLE'], 
    'run_1_HA2': ['LTDEMIAQY', 'YLQPRTFLL', 'RLQSLQTYV', 'VLNDILSRL', 'KIADYNYKL', 'QPYRVVVL', 'TFEYVSQPFLMDLE'], 
    'run_1_HA3': ['LTDEMIAQY', 'YTNSFTRGVY', 'NYNYLYRLF', 'QPYRVVVL', 'TFEYVSQPFLMDLE'], 
    'run_1_HA5': ['ATDSLNNEY', 'CTELKLSDY', 'FLRGRAYGL', 'RAKFKQLL'],
    'run_1_HA6': ['ATDSLNNEY', 'CTELKLSDY', 'FLRGRAYGL', 'RAKFKQLL'], 
    'run_1_HA7': ['ATDSLNNEY', 'CTELKLSDY', 'FLRGRAYGL', 'RAKFKQLL'],
    'run_2_HA3': ['LTDEMIAQY', 'SPRRARSVA', 'FPQSAPHGV', 'IYKTPPIKDF', 'TFEYVSQPFLMDLE', 'YTNSFTRGVY'], 
    'run_2_HA8': ['LTDEMIAQY'], 
    'run_3_HA6': ['LTDEMIAQY', 'QPYRVVVL', 'TFEYVSQPFLMDLE', 'YTNSFTRGVY'],
}

In [None]:
# Per-pool metadata. Each list is [donor, time, vaccination] (see `index=` below);
# np.nan marks values that are not available for a pool.
pool_annotation = {
    'run_1_HA1': ['MVP', 'd189', '215°'], 
    'run_1_HA2': ['A07', 'd189', '3°'], 
    'run_1_HA3': ['A15', 'd189', '3°'], 
    'run_1_HA5': ['MVP', 'd157', '215°'], 
    'run_1_HA6': ['A07', np.nan, np.nan], 
    'run_1_HA7': ['A15', np.nan, np.nan],
    'run_2_HA3': ['A04', 'd189', '3°'], 
    'run_2_HA8': ['A16', 'd189', '3°'], 
    'run_3_HA6': ['A08', 'd189', '3°']
}
# Transpose so that pools become rows and donor/time/vaccination become columns.
pool_annotation = pd.DataFrame(pool_annotation, index=['donor', 'time', 'vaccination']).transpose()
pool_annotation

In [None]:
# Copy every pool-level annotation column onto the cells by looking up each cell's pool.
for col in pool_annotation.columns:
    adata.obs[col] = adata.obs['pool'].map(pool_annotation[col])

In [None]:
# Bar plot of the number of cells per donor.
# `rename_axis` + `reset_index(name=...)` pin the column names to ('index', 'donor')
# independently of the pandas version (pandas >= 2.0 would otherwise name them
# ('donor', 'count'), which breaks x='index').
donor_counts = adata.obs['donor'].value_counts().rename_axis('index').reset_index(name='donor')
plot = sb.barplot(data=donor_counts, x='index', y='donor')
_ = plot.set_xticklabels(plot.get_xticklabels(), rotation=90)

In [None]:
# Bar plot of the number of cells per vaccination label.
# `rename_axis` + `reset_index(name=...)` pin the column names to ('index', 'vaccination')
# independently of the pandas version (pandas >= 2.0 would otherwise name them
# ('vaccination', 'count'), which breaks x='index').
vaccination_counts = (
    adata.obs['vaccination'].value_counts().rename_axis('index').reset_index(name='vaccination')
)
plot = sb.barplot(data=vaccination_counts, x='index', y='vaccination')
_ = plot.set_xticklabels(plot.get_xticklabels(), rotation=90)

## Cell Filtering based on Phenotypes

### Initial UMAP and Leiden

In [None]:
# Project helper (utils/representation.py, not shown here); presumably computes the UMAP
# on 5000 highly variable genes with TCR genes excluded -- confirm in the helper.
utils_representation.calculate_umap(adata, n_high_var=5000, remove_tcr_genes=True)

In [None]:
# Project helper; the resulting labels are read from adata.obs['leiden'] below.
# The high resolution (3.0) is used for the cluster-level filtering that follows.
utils_representation.calculate_leiden(adata, resolution=3.0, n_high_var=5000, remove_tcr_genes=True)

In [None]:
# UMAPs coloured by cluster, by pool-level annotation, and by sample/pool.
sc.pl.umap(adata, color='leiden')
sc.pl.umap(adata, color=['donor', 'time', 'vaccination'])
sc.pl.umap(adata, color=['sample', 'pool'])

In [None]:
# Project helper (not shown here); presumably draws one UMAP panel per Leiden cluster
# on a 6 x 6 grid -- confirm argument meaning in utils/visualisation.py.
utils_vis.separate_umaps_by_condition(adata, 'leiden', 6, 6, do_int_sort=True)

### Filter clusters based on scores

In [None]:
# Project helper from utils/annotation.py (implementation not shown in this notebook).
utils_annotation.add_seumois_score(adata)

In [None]:
# Project helper from utils/annotation.py (implementation not shown in this notebook).
utils_annotation.add_all_scores(adata)

In [None]:
# Project helper from utils/visualisation.py (implementation not shown in this notebook).
utils_vis.plot_marker_genes(adata)

In [None]:
# UMAP coloured by the immune-receptor columns 'has_ir' and 'chain_pairing' of adata.obs.
sc.pl.umap(adata, color=['has_ir', 'chain_pairing'])

In [None]:
# Normalised abundance of each chain-pairing category within every Leiden cluster.
ir.pl.group_abundance(adata, groupby='leiden', target_col='chain_pairing', normalize=True, fig_kws={'figsize': (12, 5)})

In [None]:
# Remove the Leiden clusters selected for exclusion and report the cell counts.
print('Before Filtering Cluster: ', len(adata))
clusters_remove = ['23', '26', '31']
# `.copy()` materialises the subset (as done for the other filters in this notebook);
# without it `adata` stays a view and the following cells would write into a view.
adata = adata[~adata.obs['leiden'].isin(clusters_remove)].copy()
print('After Filtering Cluster: ', len(adata))

## Clonotypes (open question: define over the entire dataset, or separately?)

In [None]:
# Define clonotype clusters with scirpy on identical amino-acid sequences
# (metric='identity', sequence='aa'); the cluster id is stored under 'clone_id'.
ir.tl.chain_pairing(adata)
ir.pp.ir_dist(adata, metric='identity', sequence='aa')
ir.tl.define_clonotype_clusters(adata, metric='identity', receptor_arms='all', dual_ir='any', sequence='aa', 
                                key_added='clone_id')

In [None]:
# Cells lacking either the primary VJ or the primary VDJ junction get no clonotype id.
missing_chain = (
    adata.obs['IR_VJ_1_junction_aa'].isna()
    | adata.obs['IR_VDJ_1_junction_aa'].isna()
)
adata.obs.loc[missing_chain, 'clone_id'] = np.nan
adata.obs['clone_id'] = adata.obs['clone_id'].astype(float)

# Clonal expansion twice: clipped at 3, and effectively unclipped (clip at the cell count).
for size_key, clip_value in (('clone_size_clipped', 3), ('clone_size', len(adata))):
    ir.tl.clonal_expansion(adata, target_col='clone_id', key_added=size_key, clip_at=clip_value)
adata.obs['clone_size'] = adata.obs['clone_size'].astype(float)
sc.pl.umap(adata, color=['clone_size_clipped', 'clone_size'])

### Filter Cells without IR

In [None]:
# Keep only cells flagged as having an immune receptor ('has_ir' holds the string 'True').
print(f'Amount of cells: {len(adata)}')
with_receptor = adata.obs['has_ir'] == 'True'
adata = adata[with_receptor].copy()
print(f'Amount of cells with IR: {len(adata)}')

### Extract clonotype information

In [None]:
# Run the project helper once per receptor attribute; the second name is the target
# under which the helper stores its result (read later as adata.obs columns).
for ir_attribute, target_name in (
    ('junction_aa', 'clonotype_sequence'),
    ('v_call', 'v_genes'),
    ('j_call', 'j_genes'),
):
    utils_annotation.extract_clonotype_information(adata, ir_attribute, target_name)

### Assign MAITs

Mark cells whose TCR is typical for MAIT cells, based on a reported MAIT CDR3 sequence and on common V/J gene combinations.

In [None]:
def assign_mait(row, gene_combination=True, cdr3=True):
    """Flag a cell as MAIT-like based on its TCR genes and/or CDR3 sequence.

    Parameters
    ----------
    row : mapping
        Must provide 'j_genes' and 'v_genes' (if `gene_combination`) and
        'clonotype_sequence' (if `cdr3`); values are compared via their `str()` form.
    gene_combination : bool
        Match TRAJ33 + TRAV1-2 together with TRBV20-1 or TRBV6.
    cdr3 : bool
        Match the CDR3 motif 'CAVMDSSYKLIF'.

    Returns
    -------
    str
        'True' if any enabled criterion matches, otherwise 'False'.
    """
    def _mentions(column, pattern):
        # Substring search on the stringified entry of the row.
        return pattern in str(row[column])

    if gene_combination and _mentions('j_genes', 'TRAJ33') and _mentions('v_genes', 'TRAV1-2'):
        if _mentions('v_genes', 'TRBV20-1') or _mentions('v_genes', 'TRBV6'):
            return 'True'
    if cdr3 and _mentions('clonotype_sequence', 'CAVMDSSYKLIF'):
        return 'True'
    return 'False'

In [None]:
# Row-wise MAIT flag ('True'/'False' strings), highlighted on the UMAP and counted.
adata.obs['has_mait'] = adata.obs.apply(assign_mait, axis=1)
sc.pl.umap(adata, color='has_mait', groups='True')
adata.obs['has_mait'].value_counts()

## Specificity Annotation

In [None]:
# Add a log(x + 1) transformed copy of every epitope column as 'log_<epitope>'.
for ep in adata.uns['epitopes']:
    adata.obs[f'log_{ep}'] = np.log(adata.obs[ep]+1)

In [None]:
# UMAPs coloured by the log-transformed epitope columns, four panels per row.
rcParams['figure.figsize'] = (6, 6)
sc.pl.umap(adata, color=[f'log_{ep}' for ep in adata.uns['epitopes']], ncols=4, size=10)

In [None]:
# Same panels as above, coloured by the untransformed epitope columns.
sc.pl.umap(adata, color=adata.uns['epitopes'], ncols=4, size=10)

In [None]:
# Violin plot of the log-transformed epitope columns over all cells.
rcParams['figure.figsize'] = (12, 3)
# show=False keeps the figure open so the title can be set afterwards.
plot = sc.pl.violin(adata, [f'log_{el}' for el in adata.uns['epitopes']], rotation=90, title='All pools', show=False)
plt.title('All Samples')
plt.show()

In [None]:
# One violin plot of the log-transformed epitope columns per pool, titled with the pool name.
for p in adata.obs['pool'].unique():
    sc.pl.violin(adata[adata.obs['pool']==p], [f'log_{el}' for el in adata.uns['epitopes']], rotation=90, show=False)
    plt.title(p)
    plt.show()

In [None]:
# Project helper (not shown here); presumably plots the distribution of every epitope
# column on a 4 x 4 grid with fixed axis limits -- confirm in utils/visualisation.py.
utils_vis.distributions_over_columns(adata, adata.uns['epitopes'], 4, 4, x_lim=500, y_lim=0.005)

In [None]:
# Per pool: distributions restricted to the epitopes listed for that pool in pool_2_epitope.
for p in adata.obs['pool'].unique():
    adata_tmp = adata[adata.obs['pool']==p]
    cols = pool_2_epitope[p]
    utils_vis.distributions_over_columns(adata_tmp, cols, 1, len(cols), x_lim=500, y_lim=0.005, title=p)

In [None]:
# Per-cell summary over the epitope columns stored in adata.obsm['epitopes'].
# Row-wise total and row-wise maximum:
adata.obs['n_count_dextramer'] = np.sum(adata.obsm['epitopes'], axis=1)
adata.obs['n_max_dextramer'] = np.max(adata.obsm['epitopes'], axis=1)
# Position of the largest (non-NaN) value per row, translated to the epitope name.
adata.obs['max_dextramer'] = np.nanargmax(adata.obsm['epitopes'], axis=1)
adata.obs['max_dextramer'] = adata.obs['max_dextramer'].apply(lambda x: adata.uns['epitopes'][x])
# Share of the top epitope in the row total (NaN when the total is 0, i.e. 0/0).
adata.obs['%_max_dextramer'] = adata.obs['n_max_dextramer'] / adata.obs['n_count_dextramer']

In [None]:
# Minimum value per epitope column for a cell to count as binding that epitope
# (compared with `>=` in the next cell).
thresholds_umi = {
    'LTDEMIAQY': 30,
    'TFEYVSQPFLMDLE': 4,
    'CTELKLSDY': 10,
    'FLRGRAYGL': 10,
    'RAKFKQLL': 10,
}
# Minimum fraction of binding cells a clonotype needs to be assigned to an epitope.
threshold_ct_purity = 0.9
# NOTE(review): not referenced anywhere in the visible notebook code.
threshold_umi_purity = 0.0

In [None]:
# Clonotype-level specificity: a clonotype is assigned to an epitope when at least
# `threshold_ct_purity` of its cells (with a measured value for that epitope) pass the
# epitope's threshold. Clonotypes assigned to more than one epitope become 'Ambiguous'.
adata.obs['binding_ct'] = None
cts_by_epitope = {}
for ep, thresh in thresholds_umi.items():
    # String flag ('True'/'False') so the column can be used as a categorical colour.
    adata.obs[f'has_{ep}'] = (adata.obs[ep]>=thresh).astype(str)

    # Fraction of binding cells per clonotype, among cells where `ep` was measured.
    # Working on the obs table avoids building AnnData views and, unlike
    # `value_counts().unstack()['True']`, does not fail when no cell passes the threshold.
    measured = adata.obs[adata.obs[ep].notna()]
    frac_binds = (measured[f'has_{ep}'] == 'True').groupby(measured['clone_id']).mean()

    # `> 0` keeps the original restriction to clonotypes with at least one binding cell.
    is_specific = (frac_binds >= threshold_ct_purity) & (frac_binds > 0)
    cts_spec = frac_binds[is_specific].index.values
    adata.obs.loc[adata.obs['clone_id'].isin(cts_spec), 'binding_ct'] = ep
    cts_by_epitope[ep] = cts_spec

# Every unordered pair of epitopes is checked once; set intersection replaces the
# element-wise membership test on arrays.
epitope_names = list(cts_by_epitope)
for pos, ep1 in enumerate(epitope_names):
    for ep2 in epitope_names[pos + 1:]:
        overlap = set(cts_by_epitope[ep1]).intersection(cts_by_epitope[ep2])
        adata.obs.loc[adata.obs['clone_id'].isin(list(overlap)), 'binding_ct'] = 'Ambiguous'
adata.obs['binding_ct'].value_counts()

In [None]:
# Cell-level binding calls. Each rule labels a cell with its top epitope
# ('max_dextramer') if the rule's conditions hold, otherwise with 'No binding'.
# Vectorised masks replace the row-wise `apply` with positional `x[0]`/`x[1]` lookups
# (positional access on a labelled Series is deprecated in pandas). Comparisons with
# NaN evaluate to False, so such cells are labelled 'No binding' exactly as before.
top_epitope = adata.obs['max_dextramer']
top_fraction = adata.obs['%_max_dextramer']
top_count = adata.obs['n_max_dextramer']

# Top epitope holds >= 30 % of the total and has a value >= 4.
adata.obs['binding_minerva'] = top_epitope.where((top_fraction >= 0.3) & (top_count >= 4), 'No binding')

# Top epitope has a value >= 10.
adata.obs['binding_10x'] = top_epitope.where(top_count >= 10, 'No binding')

# Combination of both rules: >= 30 % of the total and a value >= 10.
adata.obs['binding_10x_minervina'] = top_epitope.where((top_fraction >= 0.3) & (top_count >= 10), 'No binding')

def assign_10x(thresh):
    """Add the column `binding_10x_<thresh>` to the global `adata.obs`.

    A cell is labelled with its top epitope ('max_dextramer') when 'n_max_dextramer'
    is >= `thresh`; otherwise (including NaN) it gets the string 'None'.
    Implemented as a vectorised mask instead of a row-wise `apply` with positional
    Series indexing, which is deprecated in pandas and slow for many thresholds.

    Parameters
    ----------
    thresh : number
        Minimum value of 'n_max_dextramer' for a cell to count as binding.
    """
    passes = adata.obs['n_max_dextramer'] >= thresh
    adata.obs[f'binding_10x_{thresh}'] = adata.obs['max_dextramer'].where(passes, 'None')

# Create the columns binding_10x_1 ... binding_10x_30 (one per threshold).
for i in range(1, 31):
    assign_10x(i)

In [None]:
# Column names of the three binding rules compared below.
modes = ['minerva', '10x', '10x_minervina']
modes = [f'binding_{el}' for el in modes]

# One 4 x 4 figure per rule, one panel per epitope.
for mode in modes:
    n_cols = 4
    n_rows = 4

    fig, axes = plt.subplots(ncols=n_cols, nrows=n_rows, figsize=(n_cols*3, n_rows*3))
    axes = axes.reshape(-1)

    for i, ep in enumerate(adata.uns['epitopes']):
        # Background: all cells without colouring; foreground: cells called for `ep`,
        # coloured by their log-transformed value.
        sc.pl.umap(adata, ax=axes[i], show=False, size=30)
        sc.pl.umap(adata[adata.obs[mode]==ep], color=f'log_{ep}', ax=axes[i], show=False, size=30)
        axes[i].set_title(ep)
    fig.tight_layout()
    plt.show()

In [None]:
# One 'hls' colour per epitope plus light grey for 'No binding', shared by all rules.
rcParams['figure.figsize'] = (6, 6)
category_labels = adata.uns['epitopes'] + ['No binding']
category_colors = sb.color_palette('hls', len(adata.uns['epitopes'])) + ['lightgrey']
colors = dict(zip(category_labels, category_colors))
sc.pl.umap(adata, color=modes, palette=colors, ncols=1)

In [None]:
# Normalised abundance of every binding label within each pool, once per rule.
for mode in modes:
    ir.pl.group_abundance(adata, groupby='pool', target_col=mode, normalize=True)


In [None]:
# For every clonotype with at least two cells, count how many distinct epitopes its
# cells are called for under each binding rule (NaN if none of its cells is called).
large_clones = adata.obs.loc[adata.obs['clone_size'].astype(float) >= 2, 'clone_id'].unique()


def _count_bindings(binding_col):
    """Number of distinct non-'No binding' labels per clonotype, aligned to `large_clones`."""
    binders = adata.obs[adata.obs[binding_col] != 'No binding']
    n_labels = binders.groupby('clone_id')[binding_col].nunique()
    # Clonotypes without any called cell are missing from `n_labels` -> NaN after reindex.
    return n_labels.reindex(large_clones).to_numpy()


# One grouped pass per rule instead of one AnnData view per clonotype and rule.
amount_bindings_10x = _count_bindings('binding_10x')
amount_bindings_min = _count_bindings('binding_minerva')
amount_bindings_10x_min = _count_bindings('binding_10x_minervina')

comparision_binding = pd.DataFrame({'clone_id': large_clones,
                                    'bindings_minerva': amount_bindings_min,
                                    'bindings_10x_minervina': amount_bindings_10x_min,
                                    'bindings_10x': amount_bindings_10x})
comparision_binding.head()

In [None]:
# Fraction of the compared clonotypes whose cells are called for exactly one epitope.
n_compared = len(comparision_binding)
for message, count_col in (
    ('Pure clones - 10x: ', 'bindings_10x'),
    ('Pure clones - Minerva: ', 'bindings_minerva'),
    ('Pure clones - 10x Minerva: ', 'bindings_10x_minervina'),
):
    n_pure = np.sum(comparision_binding[count_col].values==1)
    print(message, (n_pure / n_compared))

In [None]:
import matplotlib.gridspec as gridspec


# Per binding rule: a 4 x 4 grid with one cell per epitope. Each grid cell is split into
# two stacked axes: top = cells called for the epitope, bottom = all other cells.
for mode in modes:
    n_cols = 4
    n_rows = 4

    fig = plt.figure(figsize=(n_cols*3, n_rows*3))

    gs = fig.add_gridspec(figure=fig, nrows=n_rows, ncols=n_cols, hspace=0.5,)


    for i, ep in enumerate(adata.uns['epitopes']):
        gs_internal = gridspec.GridSpecFromSubplotSpec(subplot_spec=gs[i], nrows=2, ncols=1)
        axes_0 = fig.add_subplot(gs_internal[0])
        axes_1 = fig.add_subplot(gs_internal[1])


        # Shared upper x-limit for both axes of this epitope (at least 10).
        vmax = max(adata.obs[ep].max()+20, 10)

        # Top axis is only drawn when at least one cell is called for this epitope.
        if np.sum(adata.obs[mode]==ep)>0:
            sb.distplot(adata[adata.obs[mode]==ep].obs[ep], ax=axes_0, color='tab:pink', 
                        hist=True, kde_kws={'fill': True, 'bw_adjust': 1}, axlabel=False, kde=True)
            axes_0.set_title(None)
            axes_0.set_xscale('symlog')
            axes_0.set_xlim((0, vmax))
            axes_0.set_xticks([])
            axes_0.set_yticks([])
            axes_0.set_ylabel(None)
        else:
            axes_0.set_visible(False)#axis('off')
            pass

        # Bottom axis: distribution of the same column for all cells not called for `ep`.
        sb.distplot(adata[adata.obs[mode]!=ep].obs[ep], ax=axes_1, color='tab:olive', 
                    hist=True, kde_kws={'fill': True}, kde=True)
        axes_1.set_yticks([])
        axes_1.set_ylabel(None)
        #axes[1].set_yticklabels('')
        axes_1.set_xscale('symlog')
        axes_1.set_xlim((0, vmax))

    plt.suptitle(mode)
    plt.tight_layout()
    plt.show()



## CD4-CD8 Assignment

In [None]:
def clr(x):
    """Centered-log-ratio style transform of a count vector.

    Divides `x` by exp(sum(log1p(x)) / x.shape[0]) and returns log1p of the result.

    Parameters
    ----------
    x : numpy.ndarray
        Values to transform; the normaliser uses the length of the first axis.

    Returns
    -------
    numpy.ndarray
        Transformed values with the same shape as `x`.
    """
    mean_log = np.log1p(x).sum() / x.shape[0]
    scaled = x / np.exp(mean_log)
    return np.log1p(scaled)

In [None]:
# CLR-transform the two surface-protein columns, using only cells with a measured value;
# cells without a value keep NaN in the new 'clr_<name>' column.
for c in ['Hu.CD8', 'Hu.CD4_RPA.T4']:
    measured = ~adata.obs[c].isna()
    adata.obs.loc[measured, f'clr_{c}'] = clr(adata.obs.loc[measured, c].values)

In [None]:
# UMAP coloured by the expression of the CD8A, CD8B and CD4 genes.
sc.pl.umap(adata, color=['CD8A', 'CD8B', 'CD4'])

In [None]:
# Gene-expression violins with the candidate threshold (0.65, used below) drawn as a line.
sc.pl.violin(adata, keys=['CD8A', 'CD8B', 'CD4'], show=False)
plt.axhline(0.65, c='black')
plt.show()

In [None]:
# UMAP coloured by the CLR-transformed CD8 and CD4 surface-protein columns.
sc.pl.umap(adata, color=['clr_Hu.CD8', 'clr_Hu.CD4_RPA.T4'])

In [None]:
# Violins of the CLR-transformed proteins with the two thresholds used below
# (0.75 for CD8, 0.95 for CD4), followed by the untransformed columns for reference.
ax = sc.pl.violin(adata, keys=['clr_Hu.CD8', 'clr_Hu.CD4_RPA.T4'], show=False)
plt.axhline(0.75, c='blue')
plt.axhline(.95, c='orange')
plt.show()
sc.pl.violin(adata, keys=['Hu.CD8', 'Hu.CD4_RPA.T4'])

In [None]:
# Threshold per marker: keys starting with 'clr' are adata.obs columns (surface protein),
# all other keys are gene names looked up in the expression matrix.
thresholds = {
    'CD8A': 0.65, 
    'CD8B': 0.65, 
    'clr_Hu.CD8': 0.75,
    'clr_Hu.CD4_RPA.T4': 0.95,
    'CD4': 0.65,
}

for n, t in thresholds.items():
    if n.startswith('clr'):
        vals = adata.obs[n]
    else:
        # `.X.A` is not available on current SciPy sparse matrices and fails on a dense X.
        # Densify explicitly and flatten to 1-D so the flag is a plain per-cell vector.
        expr = adata[:, n].X
        expr = expr.toarray() if hasattr(expr, 'toarray') else np.asarray(expr)
        vals = expr.ravel()
    # NaN values compare as False, i.e. cells without a measurement are not flagged.
    adata.obs[f'has_{n}'] = vals > t
    print(f'{n}: {np.sum(adata.obs[f"has_{n}"])}')
# Cells without a CD8 surface-protein value.
adata.obs['has_citeNaN'] = adata.obs['Hu.CD8'].isna()
print(f'citeNaN: {np.sum(adata.obs["has_citeNaN"])}')

In [None]:
# Pairwise co-occurrence of the marker flags: entry (i, j) = cells with both flags set.
cols = list(thresholds.keys()) + ['citeNaN']
df_overlap = pd.DataFrame(index=cols, columns=cols, dtype=float)
for i in cols:
    flag_i = adata.obs[f'has_{i}']
    for j in cols:
        df_overlap.loc[i, j] = float(np.sum(flag_i & adata.obs[f'has_{j}']))
sb.heatmap(df_overlap, annot=True, fmt='.4g')

In [None]:
# Combine RNA and surface-protein evidence: a lineage flag is set if ANY of its markers passes.
adata.obs['has_CD8_joined'] = adata.obs['has_CD8A'] | adata.obs['has_CD8B'] | adata.obs['has_clr_Hu.CD8']
adata.obs['has_CD4_joined'] = adata.obs['has_CD4'] | adata.obs['has_clr_Hu.CD4_RPA.T4']
# Neither lineage flagged / both lineages flagged.
adata.obs['has_NaN_joined'] = ~adata.obs['has_CD8_joined'] & ~adata.obs['has_CD4_joined']
adata.obs['has_CD4+8_joined'] = adata.obs['has_CD8_joined'] & adata.obs['has_CD4_joined']

In [None]:
# Co-occurrence heatmap of the joined flags: entry (i, j) = cells with both flags set.
rcParams['figure.figsize'] = (6, 6)
cols = ['CD8_joined', 'CD4_joined', 'NaN_joined']
df_overlap = pd.DataFrame(index=cols, columns=cols, dtype=float)
for i in cols:
    for j in cols:
        counts = np.sum(adata.obs[f'has_{i}'] & adata.obs[f'has_{j}'])
        df_overlap.loc[i, j] = float(counts)
sb.heatmap(df_overlap, annot=True, fmt='.4g')

In [None]:
# Marker values of the cells flagged by neither lineage (boolean mask at this point).
rcParams['figure.figsize'] = (6, 3)
sc.pl.violin(adata[adata.obs['has_NaN_joined']], keys=['clr_Hu.CD8', 'clr_Hu.CD4_RPA.T4'])
sc.pl.violin(adata[adata.obs['has_NaN_joined']], keys=['CD8A', 'CD8B', 'CD4'])

In [None]:
# Marker values of the cells flagged by both lineages (boolean mask at this point).
rcParams['figure.figsize'] = (6, 3)
sc.pl.violin(adata[adata.obs['has_CD4+8_joined']], keys=['clr_Hu.CD8', 'clr_Hu.CD4_RPA.T4'])
sc.pl.violin(adata[adata.obs['has_CD4+8_joined']], keys=['CD8A', 'CD8B', 'CD4'])

In [None]:
rcParams['figure.figsize'] = (6, 6)
# Convert the two flags to strings so `groups='True'` can highlight them on the UMAP.
# From here on these two columns hold 'True'/'False' strings, not booleans.
adata.obs['has_NaN_joined'] = adata.obs['has_NaN_joined'].astype(str)
adata.obs['has_CD4+8_joined'] = adata.obs['has_CD4+8_joined'].astype(str)
sc.pl.umap(adata, color='has_NaN_joined', groups='True', s=30)
sc.pl.umap(adata, color='has_CD4+8_joined', groups='True', s=30)

In [None]:
# Cell type: 'CD8' or 'CD4' when exactly one lineage is flagged, otherwise 'Ambiguous'
# (covers both the double-positive and the double-negative cells).
adata.obs['celltype'] = 'Ambiguous'
adata.obs.loc[adata.obs['has_CD8_joined'] & ~adata.obs['has_CD4_joined'], 'celltype'] = 'CD8'
adata.obs.loc[adata.obs['has_CD4_joined'] & ~adata.obs['has_CD8_joined'], 'celltype'] = 'CD4'
sc.pl.umap(adata, color='celltype')
adata.obs['celltype'].value_counts()

In [None]:
# Inspect the cells flagged as both CD4 and CD8.
adata_tmp = adata[adata.obs['has_CD4+8_joined']=='True']
print('CD4+CD8\n', adata_tmp.obs['binding_10x'].value_counts())

# Number of these cells without a clonotype id.
print(len(adata_tmp[adata_tmp.obs['clone_id'].isna()]))

# Number of distinct clonotypes among these cells.
cts_tmp = adata_tmp.obs['clone_id'].unique()
cts_tmp = cts_tmp[~np.isnan(cts_tmp)]
print(len(cts_tmp))
# Cell types of ALL cells that share a clonotype with a double-positive cell.
adata_overlap = adata[adata.obs['clone_id'].isin(cts_tmp)]
print('\n', adata_overlap.obs['celltype'].value_counts())

adata_overlap.obs[['clone_id', 'celltype']].value_counts().head(10)

In [None]:
# Inspect the cells flagged as neither CD4 nor CD8.
adata_tmp = adata[adata.obs['has_NaN_joined']=='True']
# Label fixed: this cell looks at the double-negative cells, not the CD4+CD8 ones.
print('CD4-CD8-\n', adata_tmp.obs['binding_10x'].value_counts())

# Number of these cells without a clonotype id.
print(len(adata_tmp[adata_tmp.obs['clone_id'].isna()]))

# Number of distinct clonotypes among these cells.
cts_tmp = adata_tmp.obs['clone_id'].unique()
cts_tmp = cts_tmp[~np.isnan(cts_tmp)]
print(len(cts_tmp))
# Cell types of ALL cells that share a clonotype with a double-negative cell.
adata_overlap = adata[adata.obs['clone_id'].isin(cts_tmp)]
print('\n', adata_overlap.obs['celltype'].value_counts())

adata_overlap.obs[['clone_id', 'celltype']].value_counts().head(10)

In [None]:
# Cell-type composition per Leiden cluster and per pool (normalised) and per donor (counts).
ir.pl.group_abundance(adata, groupby='leiden', target_col='celltype', normalize=True)
ir.pl.group_abundance(adata, groupby='donor', target_col='celltype', normalize=False)
ir.pl.group_abundance(adata, groupby='pool', target_col='celltype', normalize=True)#, fig_kws={'figsize': (12, 5)})

## Separate CD4-CD8 datasets

In [None]:
# Split into independent CD4 and CD8 objects ('Ambiguous' cells are in neither).
adata_cd4 = adata[adata.obs['celltype']=='CD4'].copy()
adata_cd8 = adata[adata.obs['celltype']=='CD8'].copy()

# Tag each subset; the tag is used as figure title in the cells below.
adata_cd4.uns['celltype'] = 'CD4'
adata_cd8.uns['celltype'] = 'CD8'

# Keep the unsplit data as `adata_full` and free the name `adata`
# (it is reused as loop variable over the two subsets from here on).
adata_full = adata.copy()
del adata

In [None]:
# Clonotypes that occur in both the CD4 and the CD8 subset (NaN ids excluded).
cts_cd4 = adata_cd4.obs['clone_id'].unique()
cts_cd4 = cts_cd4[~np.isnan(cts_cd4)]
cts_cd8 = adata_cd8.obs['clone_id'].unique()
cts_cd8 = cts_cd8[~np.isnan(cts_cd8)]
# Set lookup instead of scanning the CD4 array once per CD8 clonotype;
# the result keeps the order of `cts_cd8`.
cd4_ids = set(cts_cd4)
cts_overlap = [el for el in cts_cd8 if el in cd4_ids]
print(cts_overlap)
# Cell types (in the unsplit data) of all cells belonging to the shared clonotypes.
adata_full[adata_full.obs['clone_id'].isin(cts_overlap)].obs['celltype'].value_counts()

## UMAPs II

### Recalculate UMAPs

In [None]:
# Recompute embedding and clustering separately per subset (lower resolution than before).
# Note: the loop variable rebinds the global name `adata`.
for adata in [adata_cd4, adata_cd8]:
    utils_representation.calculate_umap(adata, n_high_var=5000, remove_tcr_genes=True)
    utils_representation.calculate_leiden(adata, resolution=0.5, n_high_var=5000, remove_tcr_genes=True)

In [None]:
# Overview UMAPs per subset; show=False defers display until plt.show().
# NOTE(review): plt.suptitle acts on the current figure, i.e. presumably only the
# second (binding_10x) figure receives the cell-type title -- confirm.
for adata in [adata_cd4, adata_cd8]:
    sc.pl.umap(adata, color=['leiden', 'sample', 'donor', 'pool'], ncols=2, show=False)
    sc.pl.umap(adata, color=['binding_10x'], ncols=2, show=False)
    plt.suptitle(adata.uns['celltype'])
    plt.tight_layout()
    plt.show()

In [None]:
# Project helper (not shown here), called per subset with a 2 x 3 grid and the
# cell type as title.
for adata in [adata_cd4, adata_cd8]:
    utils_vis.separate_umaps_by_condition(adata, 'leiden', 2, 3, do_int_sort=True, title=adata.uns['celltype'])

### Clonal Expansion

In [None]:
# Clone sizes recomputed within each subset ('_ct' suffix), overall and per donor / pool.
for adata in [adata_cd4, adata_cd8]:
    # Clipped at 3, and effectively unclipped (clip at the number of cells).
    ir.tl.clonal_expansion(adata, target_col='clone_id', key_added='clone_size_ct_clipped', clip_at=3)
    ir.tl.clonal_expansion(adata, target_col='clone_id', clip_at=len(adata), key_added='clone_size_ct')
    ir.tl.clonal_expansion(adata, target_col='clone_id', clip_at=len(adata), expanded_in='donor',
                           key_added='clone_size_donor_ct')
    ir.tl.clonal_expansion(adata, target_col='clone_id', clip_at=len(adata), expanded_in='pool',
                           key_added='clone_size_pool_ct')

    # Numeric sizes for continuous colouring.
    adata.obs['clone_size_ct'] = adata.obs['clone_size_ct'].astype(float)
    adata.obs['clone_size_donor_ct'] = adata.obs['clone_size_donor_ct'].astype(float)
    adata.obs['clone_size_pool_ct'] = adata.obs['clone_size_pool_ct'].astype(float)
    # From here on clone ids are strings (missing ids become the string 'nan').
    adata.obs['clone_id'] = adata.obs['clone_id'].astype(str)

    # 'clone_size_clipped' / 'clone_size' were computed on the unsplit data earlier.
    ax = sc.pl.umap(adata, color=['clone_size_clipped', 'clone_size', 'clone_size_ct',
                             'clone_size_donor_ct', 'clone_size_pool_ct'], ncols=3, show=False)
    plt.suptitle(adata.uns['celltype'])
    plt.tight_layout()
    plt.show()

## Differentially expressed genes (DEG)
- General DEG analysis across all Leiden clusters

In [None]:
def deg_over_condition(condition, adata):
    """Rank marker genes for every group of `condition`, plot them and store the result.

    The ranking runs on a working object produced by the project helpers
    `filter_tcr_genes` and `filter_high_var(..., 5000)` (presumably: TCR genes removed and
    the top 5000 highly variable genes kept -- confirm in utils/representation.py).

    Parameters
    ----------
    condition : str
        Column of `adata.obs` passed as `groupby` to `sc.tl.rank_genes_groups`.
    adata : AnnData
        Dataset to analyse; `adata.uns['celltype']` is used as the figure title.

    Returns
    -------
    None
        The ranking is written to `adata.uns[f'rank_genes_groups_{condition}']`.
    """
    adata_tmp = utils_representation.filter_tcr_genes(adata)
    adata_tmp = utils_representation.filter_high_var(adata_tmp, 5000)
    sc.tl.rank_genes_groups(adata_tmp, groupby=condition, n_genes=20)
    # The plot reads the grouping from the stored result; the former misspelled
    # `groubpy=` keyword is not a parameter of this function and has been dropped.
    sc.pl.rank_genes_groups(adata_tmp, show=False)
    plt.suptitle(adata.uns['celltype'])
    plt.tight_layout()
    plt.show()
    # NOTE: a CSV export of the per-group (gene, score) lists was sketched here
    # previously and is currently disabled.
    adata.uns[f'rank_genes_groups_{condition}'] = adata_tmp.uns['rank_genes_groups']

In [None]:
# Marker genes per Leiden cluster, separately for the CD4 and CD8 subset.
for adata in [adata_cd4, adata_cd8]:
    deg_over_condition('leiden', adata)

In [None]:
# Marker genes per donor, separately for the CD4 and CD8 subset.
for adata in [adata_cd4, adata_cd8]:
    deg_over_condition('donor', adata)

## PseudoTime

In [None]:
# Exclude one Leiden cluster per subset before the pseudotime analysis.
# `.copy()` gives independent objects: the helpers below write results into them,
# which should not happen on a view of adata_cd4 / adata_cd8.
adata_tmp_cd4 = adata_cd4[~adata_cd4.obs['leiden'].isin(['3'])].copy()
adata_tmp_cd8 = adata_cd8[~adata_cd8.obs['leiden'].isin(['5'])].copy()
for adata in [adata_tmp_cd4, adata_tmp_cd8]:
    utils_representation.calculate_diffmap(adata, n_high_var=5000, remove_tcr_genes=True)

In [None]:
# Visual aid for choosing a pseudotime root: for every diffusion component, mark on the
# UMAP the cell with the minimal value of that component (red dot).
for adata in [adata_tmp_cd4, adata_tmp_cd8]:
    nrows = 3
    ncols = 5
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols * 3, nrows * 3))
    axes = axes.reshape(-1)

    # zip stops at the shorter input, so at most nrows * ncols components are drawn.
    for i, ax in zip(range(adata.obsm['X_diffmap'].shape[1]), axes):
        root_ixs = adata.obsm['X_diffmap'][:, i].argmin()
        root_umap = adata.obsm['X_umap'][root_ixs]

        sc.pl.umap(adata, show=False, title=str(i), ax=ax)
        ax.plot(root_umap[0], root_umap[1],  marker='o', markersize=5, color="red")
    plt.suptitle(adata.uns['celltype'])
    fig.tight_layout()
    plt.show()

In [None]:
# Root cell = cell with the minimal value of the chosen diffusion component
# (component 1 for CD4, component 8 for CD8); then run the project's DPT helper.
for adata, root_nr in [(adata_tmp_cd4, 1), (adata_tmp_cd8, 8)]:
    root_ixs = adata.obsm['X_diffmap'][:, root_nr].argmin()
    adata.uns['iroot'] = root_ixs
    utils_representation.calculate_dpt(adata, n_high_var=5000, remove_tcr_genes=True)

# Copy the pseudotime back; assignment aligns on the cell index, so cells excluded
# from the filtered objects get NaN.
adata_cd4.obs['dpt_pseudotime'] = adata_tmp_cd4.obs['dpt_pseudotime']
adata_cd8.obs['dpt_pseudotime'] = adata_tmp_cd8.obs['dpt_pseudotime']

In [None]:
# Clusters next to pseudotime on the UMAP, per subset.
for adata in [adata_cd4, adata_cd8]:
    rcParams['figure.figsize'] = (6, 6)
    sc.pl.umap(adata, color=['leiden', 'dpt_pseudotime'], show=False)
    plt.suptitle(adata.uns['celltype'])
    plt.tight_layout()
    plt.show()

## T cell score

In [None]:
# Project helper from utils/annotation.py (implementation not shown), applied per subset.
for adata in [adata_cd4, adata_cd8]:
    utils_annotation.add_tc_scores(adata)

## CiteSeq

In [None]:
# Names of the surface-protein (CITE-seq) columns in adata.obs, including isotype controls.
cite_ids = ['Hu.CD101', 'Hu.CD103', 'Hu.CD105_43A3', 'Hu.CD107a', 'Hu.CD112', 'Hu.CD119', 
            'Hu.CD11a', 'Hu.CD11b', 'Hu.CD11c', 'Hu.CD122', 'Hu.CD123', 'Hu.CD124', 
            'Hu.CD127', 'Hu.CD13', 'Hu.CD134', 'Hu.CD137', 'Hu.CD141', 'Hu.CD146', 
            'Hu.CD14_M5E2', 'Hu.CD152', 'Hu.CD154', 'Hu.CD155', 'Hu.CD158', 'Hu.CD158b', 
            'Hu.CD158e1', 'Hu.CD16', 'Hu.CD161', 'Hu.CD163', 'Hu.CD169', 'Hu.CD18', 'Hu.CD183', 
            'Hu.CD185', 'Hu.CD19', 'Hu.CD194', 'Hu.CD195', 'Hu.CD196', 'Hu.CD1c', 'Hu.CD1d', 
            'Hu.CD2', 'Hu.CD20_2H7', 'Hu.CD21', 'Hu.CD22', 'Hu.CD223', 'Hu.CD224', 'Hu.CD226_11A8', 
            'Hu.CD23', 'Hu.CD24', 'Hu.CD244', 'Hu.CD25', 'Hu.CD26', 'Hu.CD267', 'Hu.CD268', 'Hu.CD27', 
            'Hu.CD270', 'Hu.CD272', 'Hu.CD274', 'Hu.CD279', 'Hu.CD28', 'Hu.CD29', 'Hu.CD303', 
            'Hu.CD31', 'Hu.CD314', 'Hu.CD319', 'Hu.CD32', 'Hu.CD328', 'Hu.CD33', 'Hu.CD335',
            'Hu.CD35', 'Hu.CD352', 'Hu.CD36', 'Hu.CD38_HIT2', 'Hu.CD39', 'Hu.CD3_UCHT1', 'Hu.CD40', 
            'Hu.CD41', 'Hu.CD42b', 'Hu.CD45RA', 'Hu.CD45RO', 'Hu.CD45_HI30', 'Hu.CD47', 'Hu.CD48', 
            'Hu.CD49a', 'Hu.CD49b', 'Hu.CD49d', 'Hu.CD4_RPA.T4', 'Hu.CD5', 'Hu.CD52', 'Hu.CD54', 
            'Hu.CD56', 'Hu.CD57', 'Hu.CD58', 'Hu.CD62L', 'Hu.CD62P', 'Hu.CD64', 'Hu.CD69', 'Hu.CD7', 
            'Hu.CD71', 'Hu.CD73', 'Hu.CD79b', 'Hu.CD8', 'Hu.CD81', 'Hu.CD82', 'Hu.CD83', 'Hu.CD85j', 
            'Hu.CD86', 'Hu.CD88', 'Hu.CD94', 'Hu.CD95', 'Hu.CD99', 'Hu.CLEC12A', 'Hu.CX3CR1', 'Hu.FceRIa', 
            'Hu.GPR56', 'Hu.HLA.ABC', 'Hu.HLA.DR', 'Hu.HLA.E', 'Hu.Ig.LightChain.k', 'Hu.Ig.LightChain.l', 
            'Hu.IgD', 'Hu.IgM', 'Hu.KLRG1', 'Hu.LOX.1', 'Hu.TCR.AB', 'Hu.TCR.Va7.2', 'Hu.TCR.Vd2', 'Hu.TIGIT', 
            'HuMs.CD44', 'HuMs.CD49f', 'HuMs.integrin.b7', 'HuMsRt.CD278', 'Isotype_HTK888', 'Isotype_MOPC.173', 
            'Isotype_MOPC.21', 'Isotype_MPC.11', 'Isotype_RTK2071', 'Isotype_RTK2758', 'Isotype_RTK4530']
# Second group of protein columns; three of them are compared below with their
# 'Hu.*' counterparts (Hu.CD62L, Hu.CD183, Hu.CD45RA).
customs_cite_ids = ['CCR7', 'CD62L', 'CXCR3', 'CD45RA']

In [None]:
# Per subset: CLR-transform all surface-protein columns and rank proteins per Leiden cluster.
for adata in [adata_cd4, adata_cd8]:
    adata.uns['cite_ids'] = cite_ids
    adata.uns['custom_cite_ids'] = customs_cite_ids
    for c in cite_ids + customs_cite_ids:
        # Transform only cells with a measured value; the others keep NaN.
        adata.obs.loc[~adata.obs[c].isna(), f'clr_{c}'] = clr(adata[~adata.obs[c].isna()].obs[c].values)

    # Temporary AnnData whose "genes" are the CLR-transformed proteins in `cite_ids`.
    adata_cite = ann.AnnData(X=adata.obs[[f'clr_{el}' for el in cite_ids]
                                        ].values, obs=adata.obs[['leiden'] + cite_ids])
    adata_cite.var_names = cite_ids
    # Drop cells without protein measurements (first protein column used as indicator).
    # `.copy()` makes this a real object: rank_genes_groups writes its result into it.
    adata_cite = adata_cite[~adata_cite.obs[cite_ids[0]].isna()].copy()

    sc.tl.rank_genes_groups(adata_cite, groupby='leiden', n_genes=20)
    # The plot reads the grouping from the stored result; the former misspelled
    # `groubpy=` keyword is not a parameter of this function and has been dropped.
    sc.pl.rank_genes_groups(adata_cite, show=False)
    plt.suptitle(adata.uns['celltype'])
    plt.tight_layout()
    plt.show()
    adata.uns['rank_genes_groups_leiden_cite'] = adata_cite.uns['rank_genes_groups']

In [None]:
# NOTE(review): defined here but not referenced in the visible notebook code;
# presumably the pools stained with the antibody cocktail -- confirm.
cocktail_samples = ['run_1_HA1', 
                    'run_1_HA2', 
                    'run_1_HA3']

In [None]:
# Compare each protein measured under its 'Hu.*' (cocktail) column with the same protein
# measured under its custom column, on log(x + 1) scale.
# Both groups are taken from `adata_full`: the bare name `adata` is only a leftover loop
# variable at this point (last bound to the CD8 subset), which made the two groups
# come from different cell populations.
for cock, cust in [('Hu.CD62L', 'CD62L'), ('Hu.CD183', 'CXCR3'), ('Hu.CD45RA', 'CD45RA')]:
    df_1 = adata_full.obs.loc[adata_full.obs[cock].notna(), [cock]].copy()
    df_1['in_cocktail'] = 'True'
    # Rename so both groups share one value column.
    df_1 = df_1.rename(columns={cock: cust})
    df_2 = adata_full.obs.loc[adata_full.obs[cust].notna(), [cust]].copy()
    df_2['in_cocktail'] = 'False'
    # ignore_index avoids duplicate row labels if a cell has values in both columns.
    df_full = pd.concat([df_1, df_2], ignore_index=True)
    df_full[cust] = np.log(df_full[cust]+1)
    sb.violinplot(data=df_full, x='in_cocktail', y=cust)
    plt.show()

In [None]:
# Displayed as cell output: number of CD4 cells per pool.
adata_cd4.obs['pool'].value_counts()

## Save all

In [None]:
# Remove the helper dictionaries from `.uns` before writing the files.
# The `None` default makes the cell safe to re-run (the keys are gone after the first run).
for adata in [adata_cd4, adata_cd8]:
    for col in ['j_genes_dict', 'v_genes_dict', 'clonotype_sequence_dict']:
        _ = adata.uns.pop(col, None)

In [None]:
# Write the annotated CD4 and CD8 subsets to the output paths defined at the top.
sc.write(adata=adata_cd4, filename=path_annotated_cd4)
sc.write(adata=adata_cd8, filename=path_annotated_cd8)