# Differential enrichment analysis

Based on: Schürch et al., 2020, Cell (https://doi.org/10.1016/j.cell.2020.07.005) (Fig 6G)

In [2]:
import pandas as pd
import numpy as np

import statsmodels.api as sm
import statsmodels.stats.multitest as smmulti 

import seaborn as sns
import matplotlib.pyplot as plt

## Data wrangling

### cell types

In [4]:
# Load the cell-type table from the workbook's second sheet (sheet index 1).
# header=None: no row is used as column names, so columns come back numbered.
all_ct = pd.read_excel(
    '/data/T_subsets.xlsx',
    sheet_name=1,
    header=None,
)

In [5]:
# Promote the first spreadsheet column (label 0) to the row index.
all_ct = all_ct.set_index(0)

In [6]:
# Placeholder names for now: one letter per remaining column, A through H.
all_ct.columns = list('ABCDEFGH')

In [7]:
all_ct

Unnamed: 0_level_0,A,B,C,D,E,F,G,H
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
49,Stem/Progenitor cells,CD34+CD117- HSPCs,Mice1_Day0,Mice1,Day0,85,0.156408,0.579630
61,Stem/Progenitor cells,CD34+CD117+ HSPCs,Mice1_Day0,Mice1,Day0,123,0.226332,
73,Stem/Progenitor cells,CD34-CD117+ HSPCs,Mice1_Day0,Mice1,Day0,107,0.196890,
1,APCs,APCs,Mice1_Day0,Mice1,Day0,459,0.844604,0.853804
375,APCs,PD-L1+ APCs,Mice1_Day0,Mice1,Day0,2,0.003680,
...,...,...,...,...,...,...,...,...
374,T cells,PD-1+CD8+ T cells,Mice9_Day14,Mice9,Day14,224,0.137593,
479,Unidentified,Unidentified,Mice9_Day14,Mice9,Day14,9367,5.753721,5.753721
250,Vessels,LepR+ Vessels,Mice9_Day14,Mice9,Day14,1452,0.891897,6.951517
24,Vessels,Arterioles,Mice9_Day14,Mice9,Day14,364,0.223589,


In [8]:
# sanity check 1
# Sum the numeric columns within each cell type parent (A). In the rows shown
# above, H is filled only once per parent/mouse group and blank (NaN) elsewhere;
# NaNs are skipped by sum(), so the G and H totals should agree per parent.
summed = all_ct.groupby(by=["A"]).sum(numeric_only = True)
# Compare with a floating-point tolerance: G and H are sums of floats added in
# different groupings, so exact `==` can fail on rounding noise alone.
np.allclose(summed.G, summed.H)
# confirms that column H sums are consistent with column G sums for a given cell type parent (A), regardless of day

True

In [9]:
# Per-mouse totals (grouping by column D, the mouse label).
# In the output, both G and H add up to 100 for every mouse, i.e. they are
# percentages of that mouse's total cell count (F), and H accounts for the
# same total as the individual G entries.
all_ct.groupby("D").sum(numeric_only=True)

Unnamed: 0_level_0,F,G,H
D,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
Mice1,54345,100.0,100.0
Mice10,148987,100.0,100.0
Mice11,235810,100.0,100.0
Mice12,214801,100.0,100.0
Mice2,151771,100.0,100.0
Mice3,240613,100.0,100.0
Mice4,53109,100.0,100.0
Mice5,164964,100.0,100.0
Mice6,173675,100.0,100.0
Mice7,191517,100.0,100.0


In [10]:
# Keep only rows from the two timepoints of interest (column E) whose
# parent cell type (column A) is "T cells".
wanted_days = all_ct.E.isin(["Day0", "Day21"])
is_t_cell = all_ct.A == "T cells"
sub_ct = all_ct[wanted_days & is_t_cell]
sub_ct

Unnamed: 0_level_0,A,B,C,D,E,F,G,H
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
85,T cells,CD4+ T cells,Mice1_Day0,Mice1,Day0,177,0.325697,1.098537
97,T cells,CD8+ T cells,Mice1_Day0,Mice1,Day0,224,0.412181,
133,T cells,DN T cells,Mice1_Day0,Mice1,Day0,89,0.163769,
351,T cells,PD-1+CD4+ T cells,Mice1_Day0,Mice1,Day0,37,0.068084,
363,T cells,PD-1+CD8+ T cells,Mice1_Day0,Mice1,Day0,70,0.128807,
86,T cells,CD4+ T cells,Mice10_Day21,Mice10,Day21,663,0.445005,1.263869
98,T cells,CD8+ T cells,Mice10_Day21,Mice10,Day21,194,0.130213,
134,T cells,DN T cells,Mice10_Day21,Mice10,Day21,293,0.196661,
352,T cells,PD-1+CD4+ T cells,Mice10_Day21,Mice10,Day21,195,0.130884,
364,T cells,PD-1+CD8+ T cells,Mice10_Day21,Mice10,Day21,538,0.361105,


In [11]:
# Distinct T cell subtype labels (column B) present in the subset.
cells = sub_ct['B'].unique()

In [12]:
# Work on an independent copy rather than an alias of the filtered slice, so
# the column assignments made to ct_df further down can neither write back
# into sub_ct nor trigger pandas' SettingWithCopyWarning.
ct_df = sub_ct.copy()

In [13]:
# Column A (cell type parent) is the same on every row and column H (total
# frequency) can be recomputed when needed, so both are dropped.
ct_df = ct_df.drop(columns=['A', 'H'])
# Replace the remaining placeholder letters (B-G) with descriptive names.
ct_df.columns = [
    'Tct',
    'mouse_day',
    'mouse',
    'timepoint',
    'n_cell',
    'ct_freq',
]

In [14]:
# Swap non-breaking spaces (U+00A0) in the cell type labels for ordinary spaces.
ct_df['Tct'] = ct_df['Tct'].str.replace('\xa0', ' ')

In [15]:
ct_df

Unnamed: 0_level_0,Tct,mouse_day,mouse,timepoint,n_cell,ct_freq
0,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
85,CD4+ T cells,Mice1_Day0,Mice1,Day0,177,0.325697
97,CD8+ T cells,Mice1_Day0,Mice1,Day0,224,0.412181
133,DN T cells,Mice1_Day0,Mice1,Day0,89,0.163769
351,PD-1+CD4+ T cells,Mice1_Day0,Mice1,Day0,37,0.068084
363,PD-1+CD8+ T cells,Mice1_Day0,Mice1,Day0,70,0.128807
86,CD4+ T cells,Mice10_Day21,Mice10,Day21,663,0.445005
98,CD8+ T cells,Mice10_Day21,Mice10,Day21,194,0.130213
134,DN T cells,Mice10_Day21,Mice10,Day21,293,0.196661
352,PD-1+CD4+ T cells,Mice10_Day21,Mice10,Day21,195,0.130884
364,PD-1+CD8+ T cells,Mice10_Day21,Mice10,Day21,538,0.361105


### wrangle neighborhoods

Do this separately for each day, because of the way the spreadsheet is laid out, and then concatenate the two results.

#### Day0

In [None]:
# Day0 block of the first sheet: spreadsheet row 2 (0-based) supplies the
# column names, and at most 16 rows are read beneath it.
day0 = pd.read_excel(
    '/data/T_subsets.xlsx',
    sheet_name=0,
    header=[2],
    nrows=16,
)
day0

In [None]:
day0.columns

In [None]:
# Rename the columns (the spreadsheet header didn't translate well).
# Layout: a leading label column, the neighborhood name, then one column per
# "<mouse>--<T cell type>" pair, grouped by cell type with the three Day0
# mice inside each group.
day0_mice = ['mouse1', 'mouse2', 'mouse3']
day0_cell_types = [
    'CD4+ T cells',
    'CD8+ T cells',
    'DN T cells',
    'PD-1+CD4+ T cells',
    'PD-1+CD8+ T cells',
]
day0.columns = ['Day0', 'neighborhood'] + [
    f'{mouse_label}--{cell_type}'
    for cell_type in day0_cell_types
    for mouse_label in day0_mice
]

In [None]:
# Tidy up: discard the first data row (index label 0) and the leading 'Day0'
# label column, then renumber the remaining rows from 0.
# NOTE(review): row 0 is presumably a leftover sub-header row; confirm against the sheet.
day0 = day0.drop(index=0)
day0 = day0.drop(columns='Day0')
day0 = day0.reset_index(drop=True)

In [None]:
day0 # need to make long

In [None]:
# Reshape wide -> long: one row per (neighborhood, mouse/cell-type column),
# with the former column name in 'mouse_Tct' and the cell value in 'freq'.
day0_long = pd.melt(
    day0,
    id_vars=['neighborhood'],
    var_name='mouse_Tct',
    value_name='freq',
)

In [None]:
day0_long

In [None]:
# Break "<mouse>--<T cell type>" apart at the '--' separator into two columns.
day0_parts = day0_long['mouse_Tct'].str.split('--', expand=True)
day0_long[['mouse', 'Tct']] = day0_parts

In [None]:
day0_long

In [None]:
# Tag every row of this block with its timepoint.
day0_long = day0_long.assign(timepoint='Day0')

#### Day21

In [None]:
# Day21 block of the first sheet: spreadsheet row 22 (0-based) supplies the
# column names, and at most 16 rows are read beneath it.
day21 = pd.read_excel(
    '/data/T_subsets.xlsx',
    sheet_name=0,
    header=[22],
    nrows=16,
)
day21

In [None]:
day21.columns

In [None]:
# Rename the columns, same layout as the Day0 block: a leading label column,
# the neighborhood name, then one column per "<mouse>--<T cell type>" pair,
# grouped by cell type with the three Day21 mice inside each group.
day21_mice = ['mouse10', 'mouse11', 'mouse12']
day21_cell_types = [
    'CD4+ T cells',
    'CD8+ T cells',
    'DN T cells',
    'PD-1+CD4+ T cells',
    'PD-1+CD8+ T cells',
]
day21.columns = ['Day21', 'neighborhood'] + [
    f'{mouse_label}--{cell_type}'
    for cell_type in day21_cell_types
    for mouse_label in day21_mice
]

In [None]:
# Discard the first data row (index label 0) and the leading 'Day21' label
# column, then renumber the remaining rows from 0.
day21 = day21.drop(index=0)
day21 = day21.drop(columns='Day21')
day21 = day21.reset_index(drop=True)

In [None]:
day21 # need to make long

In [None]:
# Reshape wide -> long: one row per (neighborhood, mouse/cell-type column),
# with the former column name in 'mouse_Tct' and the cell value in 'freq'.
day21_long = pd.melt(
    day21,
    id_vars=['neighborhood'],
    var_name='mouse_Tct',
    value_name='freq',
)

In [None]:
day21_long

In [None]:
# Break "<mouse>--<T cell type>" apart at the '--' separator into two columns.
day21_parts = day21_long['mouse_Tct'].str.split('--', expand=True)
day21_long[['mouse', 'Tct']] = day21_parts

In [None]:
day21_long

In [None]:
# Tag every row of this block with its timepoint.
day21_long = day21_long.assign(timepoint='Day21')

#### concatenate the neighborhood data

In [None]:
# Stack the two long tables (Day0 first, then Day21) under a fresh 0..n-1 index.
per_day_tables = [day0_long, day21_long]
nb_df = pd.concat(per_day_tables, ignore_index=True)

In [None]:
# The cell frequency table labels mice "MiceN"; rewrite "mouseN" to match.
nb_df['mouse'] = nb_df['mouse'].str.replace('mouse', 'Mice')

In [None]:
nb_df # long neighborhood df

In [None]:
# "freq" here is the frequency of a cell type within a neighborhood; give it
# an unambiguous name before joining with the cell type table.
nb_df = nb_df.rename(columns={"freq": "ct_in_nb_freq"})

In [None]:
nb_df

### create one unified data frame

In [None]:
# Every mouse label in the neighborhood table must also exist in ct_df,
# otherwise the inner join below would silently lose its rows.
known_mice = set(ct_df.mouse.unique())
for mouse in nb_df.mouse.unique():
    print(mouse)
    assert mouse in known_mice, f'{mouse} not found in ct_df'

In [None]:
# Likewise, every T cell type label in the neighborhood table must exist in ct_df.
known_celltypes = set(ct_df.Tct.unique())
for celltype in nb_df.Tct.unique():
    print(celltype)
    assert celltype in known_celltypes, f'{celltype} not found in ct_df'

In [None]:
# Join the per-neighborhood frequencies with the overall cell type frequencies.
# The combined 'mouse_Tct' label and the raw cell counts are left out;
# rows are matched on mouse, T cell type and timepoint (inner join).
neighborhood_side = nb_df.drop(columns=['mouse_Tct'])
celltype_side = ct_df.drop(columns=['n_cell'])
full_df = pd.merge(
    neighborhood_side,
    celltype_side,
    on=['mouse', 'Tct', 'timepoint'],
    how='inner',
)

### prepare for model fitting

In [None]:
# code timepoint as integer for model fitting (Day0 -> 0, Day21 -> 1).
# Series.map plus an explicit cast is used instead of Series.replace: replacing
# strings with ints on an object column relies on implicit downcasting, which
# pandas 2.2 deprecates, and the column may then stay object dtype.
# The cast also raises if any timepoint other than Day0/Day21 is present
# (map leaves those as NaN) instead of passing a string through to the model.
timepoint_codes = {"Day0": 0, "Day21": 1}
full_df['timepoint_int'] = full_df['timepoint'].map(timepoint_codes).astype('int64')

In [None]:
# The neighborhood frequencies must be float64, otherwise the model fitting complains.
full_df['ct_in_nb_freq'] = full_df['ct_in_nb_freq'].astype('float64')

In [None]:
full_df.dtypes

In [None]:
# log2-transform both frequency columns, writing each result to "<column>_log2"
ct_freq_col = 'ct_freq'
ct_in_nb_freq_col = 'ct_in_nb_freq'

# 1e-3 is the pseudocount used in the paper; it is added before taking log2
pseudocount = 1e-3
for raw_col in (ct_freq_col, ct_in_nb_freq_col):
    full_df[f'{raw_col}_log2'] = np.log2(pseudocount + full_df[raw_col])

In [None]:
# Save the model-ready table (row index included) as CSV.
model_data_path = '/data/all_data_for_model.csv'
full_df.to_csv(model_data_path)

## Run the model

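Y = beta0 * 1 + beta1 * X + beta2 * Yc

where Y is the log2 frequency of a T cell type within a neighborhood, X is the timepoint indicator (0 = Day0, 1 = Day21), and Yc is the log2 overall frequency of that T cell type in the same mouse. One model is fit per (T cell type, neighborhood) pair; beta1 is the differential enrichment estimate.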

In [None]:
# Distinct values to loop over when fitting one model per (cell type, neighborhood).
unique_Tct = full_df['Tct'].unique().tolist()
unique_nb = full_df['neighborhood'].unique().tolist()
unique_mice = full_df['mouse'].unique().tolist()

# From here on the models use the log2-transformed columns.
ct_freq_col = 'ct_freq_log2'
ct_in_nb_freq_col = 'ct_in_nb_freq_log2'

In [None]:
# Fit one ordinary-least-squares model per (T cell type, neighborhood) pair and
# collect [cell type, neighborhood, p-value, estimate] for the timepoint term.
res = []

np.random.seed(19220116)

for ct in unique_Tct:
    # Select rows with boolean masks instead of interpolating the labels into a
    # query string: a label containing a quote character would break the query
    # expression. Row order is the same as with query().
    ct_rows = full_df[full_df['Tct'] == ct]
    for nb in unique_nb:
        df = ct_rows[ct_rows['neighborhood'] == nb]
        nsample = len(df)

        # a model is only fit when there is exactly one row per mouse
        if nsample != len(unique_mice):
            print(f'not enough samples for {ct}, {nb}: {nsample}')
            continue

        X = df[['timepoint_int', ct_freq_col]].to_numpy()
        # has_constant='add' always prepends the intercept column. With the
        # default ('skip'), a regressor that happens to be constant within this
        # subset suppresses the intercept, shifting every coefficient one
        # position left so that index 1 would no longer be the timepoint term.
        X = sm.add_constant(X, has_constant='add')
        y = df[ct_in_nb_freq_col].to_numpy()

        fit = sm.OLS(y, X).fit()
        #print(fit.summary())

        # fit.params[1] is the beta1 coeff, which is multiplied by timepoint_int
        # in the linear model (this is the effect attributed to Day21 (timepoint_int: 1)
        # vs. Day 0 (timepoint_int: 0)). we pull the corresponding p-value as well
        res.append([ct, nb, fit.pvalues[1], fit.params[1]])

In [None]:
# One row per fitted model: labels, raw p-value and beta1 estimate.
result_columns = ['T cell type', 'neighborhood', 'p-value', 'beta1_est']
results = pd.DataFrame(data=res, columns=result_columns)

### perform p-value correction for FDR

In [None]:
# False-discovery-rate correction across all fitted models at alpha = 0.05.
# fdrcorrection returns (reject flags, corrected p-values), in that order.
raw_pvalues = results['p-value'].to_numpy(dtype='float64')
fdr_res = smmulti.fdrcorrection(
    raw_pvalues,
    alpha=0.05,
    method='indep',
    is_sorted=False,
)
rejected, corrected = fdr_res

# add these to the results
results['h0_rejected'] = rejected
results['p-value_cor'] = corrected

In [None]:
# Save the model results (row index included) as CSV.
ols_results_path = '/results/ols_results.csv'
results.to_csv(ols_results_path)

## Visualize results

### beta1 coefficient (model results)

In [None]:
# Reshape the results into (T cell type x neighborhood) grids for the heatmap:
# one grid of beta1 estimates and one of FDR-corrected p-values.
grid_layout = {'index': 'T cell type', 'columns': 'neighborhood'}

beta1 = results.pivot(values='beta1_est', **grid_layout)

pvals = results.pivot(values='p-value_cor', **grid_layout)

In [None]:
# Heatmap of the beta1 estimates on a blue-white-red scale centred at 0.
colorbar_options = {
    "shrink": 0.5,
    "label": "differential enrichment\nred: higher on Day 21\nblue:lower on Day 21",
}
ax = sns.heatmap(
    beta1,
    cmap='bwr',
    center=0,
    # vmin = -1,
    # vmax = 1,
    square=True,
    cbar=True,
    cbar_kws=colorbar_options,
)

# Put a white asterisk on every cell whose corrected p-value is below 0.05.
# Cell (row, col) spans [col, col+1] x [row, row+1] in data coordinates.
significant_cells = np.argwhere((pvals < 0.05).to_numpy())
for row, col in significant_cells:
    plt.text(
        col + .5,
        row + .75,
        '*',
        fontsize=20,
        ha='center',
        va='center',
        c='white',
    )

plt.tight_layout()
plt.savefig('/results/heatmap.svg', bbox_inches="tight", pad_inches=0.2)

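In the heatmap, red (beta1 > 0) indicates that the cell type's frequency in that neighborhood is higher on Day 21 than on Day 0, after adjusting for the cell type's overall frequency; blue (beta1 < 0) indicates that it is lower.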

### heatmaps of the ct freq in each neighborhood

In [None]:
# Mean untransformed within-neighborhood frequency for each
# (neighborhood, T cell type) pair, computed separately per timepoint.
mean_keys = ["neighborhood", "Tct"]
day0_rows = full_df[full_df.timepoint == "Day0"]
day21_rows = full_df[full_df.timepoint == "Day21"]
day0_means = day0_rows.groupby(mean_keys)['ct_in_nb_freq'].mean()
day21_means = day21_rows.groupby(mean_keys)['ct_in_nb_freq'].mean()

In [None]:
# Day0 mean frequencies as a grid: rows are T cell types, columns are neighborhoods.
day0_grid = day0_means.unstack().transpose()
ax = sns.heatmap(
    day0_grid,
    vmin=0,
    vmax=1.5,
    square=True,
    cbar=True,
    cbar_kws={"shrink": 0.5},
)

plt.title('Day0')

In [None]:
# Day21 mean frequencies as a grid: rows are T cell types, columns are neighborhoods.
day21_grid = day21_means.unstack().transpose()
ax = sns.heatmap(
    day21_grid,
    vmin=0,
    vmax=1.5,
    square=True,
    cbar=True,
    cbar_kws={"shrink": 0.5},
)

plt.title('Day21')

#### dropped 'Erythroid_EryC'
This neighborhood's values made the others hard to see on a shared color scale, so it is left out of the next two heatmaps.

In [None]:
# Day0 grid again, without the 'Erythroid_EryC' neighborhood column and with
# the color scale capped at 0.5.
day0_grid_trimmed = day0_means.unstack().transpose().drop(columns='Erythroid_EryC')
ax = sns.heatmap(
    day0_grid_trimmed,
    vmin=0,
    vmax=0.5,
    square=True,
    cbar=True,
    cbar_kws={"shrink": 0.5},
)

plt.title('Day0')

In [None]:
# Day21 grid again, without the 'Erythroid_EryC' neighborhood column and with
# the color scale capped at 0.43.
# NOTE(review): the Day0 counterpart caps at 0.5, so the two panels' colors
# are not directly comparable; confirm this is intended.
day21_grid_trimmed = day21_means.unstack().transpose().drop(columns='Erythroid_EryC')
ax = sns.heatmap(
    day21_grid_trimmed,
    vmin=0,
    vmax=0.43,
    square=True,
    cbar=True,
    cbar_kws={"shrink": 0.5},
)

plt.title('Day21')