# Case vs. Control Analysis
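In this notebook we'll use logistic regression to examine differences in taxonomic composition between cases and controls. The analysis is run separately for each study to account for study-specific covariates. We hope to uncover upper respiratory tract (URT) microbiome associations with respiratory diseases that may point to taxa with causative or protective roles. Because the data are observational, these results are associations, not demonstrated causal effects.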
____


In [None]:
import pandas as pd
import os
import statsmodels.formula.api as smf
import seaborn as sns
import scipy
import statsmodels as sm
import matplotlib.pyplot as plt
from utils import *
from matplotlib.font_manager import FontProperties


%matplotlib inline

## Collect Reads
First, load the merged table constructed earlier

In [None]:
# location of the merged table built earlier in the pipeline
MERGED_TABLE_PATH = '../data/merged_table.csv'
merged_table = pd.read_csv(MERGED_TABLE_PATH)

## Specify Color Encoding
Define the disease-specific color dictionary we've been using, so each disease keeps the same color across figures.

In [None]:
# Disease -> hex color, used below for the per-study disease annotation strip.
# Listed as ordered pairs so the dictionary's insertion order is explicit.
_disease_palette = [
    ('Asthma', '#a6cee3'),
    ('COVID-19', '#1f78b4'),
    ('Influenza', '#b2df8a'),
    ('Pneumonia', '#33a02c'),
    ('RSV', '#fb9a99'),
    ('RTI', '#e31a1c'),
    ('Resp. Allergies', '#fdbf6f'),
    ('Rhinosinusitis', '#ff7f00'),
    ('COPD', '#cab2d6'),
    ('Tonsillitis', '#6a3d9a'),
]
color_dict = dict(_disease_palette)

## Per-Study Case vs. Control Logistic Regression
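For every study and every taxon, we fit a logistic regression of case/control status on the taxon's CLR-transformed abundance. We record the p-value of the abundance term, together with the log2 fold change of mean relative abundance (cases / controls). Fitting each study separately keeps between-study differences (covariates that vary by study) from biasing the associations. Taxa whose model cannot be fit are skipped, for example because only one condition is present, the outcome is perfectly separated, or the design matrix is singular.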

In [None]:
def _case_control_test(taxon_df):
    """Test one taxon within one study for a case/control association.

    Fits the logistic model ``condition_bin ~ clr``, where ``condition_bin`` is 1 for
    rows whose ``condition`` is ``'control'`` and 0 otherwise. Only the p-value of the
    ``clr`` term is kept.

    Returns ``(pvalue, log2_foldchange)`` or ``None`` when the taxon cannot be tested:
    only one outcome level is present, the fit raises a perfect-separation or
    singular-matrix error, or the fit yields a non-finite p-value.
    ``log2_foldchange`` is log2(mean ``relative`` in cases / mean ``relative`` in
    controls); it may be infinite or NaN when a group mean is zero or missing.
    """
    labelled = taxon_df.assign(
        condition_bin=(taxon_df['condition'] == 'control').astype(int))

    # the model needs both outcome levels in this study
    if labelled['condition_bin'].nunique() < 2:
        return None

    try:
        fit = smf.logit('condition_bin ~ clr', data=labelled).fit(disp=0)
    except (sm.tools.sm_exceptions.PerfectSeparationError, np.linalg.LinAlgError):
        return None

    pvalue = fit.pvalues['clr']
    # a single NaN p-value would make every Benjamini-Hochberg q-value NaN in the
    # FDR step, so such fits are skipped like the failed ones above
    if not np.isfinite(pvalue):
        return None

    case_mean = labelled.loc[labelled['condition'] == 'case', 'relative'].mean()
    control_mean = labelled.loc[labelled['condition'] == 'control', 'relative'].mean()
    return pvalue, np.log2(case_mean / control_mean)


# One test per (study, taxon) pair. Results are collected as plain records and turned
# into a DataFrame once; growing a DataFrame with pd.concat inside the loop is quadratic.
records = []
for (study, taxon), taxon_df in merged_table.groupby(['study', 'full_taxonomy'], sort=False):
    result = _case_control_test(taxon_df)
    if result is None:
        continue
    pvalue, log2_fc = result
    records.append({'taxon': taxon,
                    'pvalue': pvalue,
                    'log2_foldchange': log2_fc,
                    'study': study})

study_specific = pd.DataFrame(records, columns=['taxon', 'pvalue', 'log2_foldchange', 'study'])

## Format Results
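Correct the p-values for multiple testing (Benjamini–Hochberg FDR). Then label each association by direction: 1 = enriched in cases, −1 = enriched in controls, 0 = not significant at q ≤ 0.05. Finally, pivot the results into genus × study tables.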

In [None]:
# significance threshold applied to the FDR-corrected q-values
ALPHA = 0.05

# FDR correction (Benjamini-Hochberg, statsmodels' default 'indep' method) of the p-values.
# Only non-missing p-values enter the correction, because a single NaN would turn every
# corrected value into NaN. Rows with a missing p-value keep q = NaN (non-significant below).
pvalues = study_specific['pvalue'].to_numpy(dtype=float)
tested = ~pd.isna(pvalues)
q_values = pvalues.copy()
if tested.any():
    q_values[tested] = sm.stats.multitest.fdrcorrection(pvalues[tested])[1]
study_specific['q'] = q_values

# shorten taxon id to just genus name
study_specific['genus'] = study_specific['taxon'].str.split('|').str[-1]

# enrichment direction from the log2 fold change (cases / controls):
# 1 = enriched in cases, -1 = enriched in controls, 0 = zero or undefined fold change
lfc = study_specific['log2_foldchange']
study_specific['enrichment'] = (lfc > 0).astype(int) - (lfc < 0).astype(int)

# keep a direction only for significant results; NaN q-values count as not significant
study_specific.loc[~(study_specific['q'] <= ALPHA), 'enrichment'] = 0

# genus x study tables of q-values and of enrichment directions
# (missing genus/study combinations are filled with 0 = no enrichment, for plotting)
p_frame = study_specific.pivot(index='genus', columns='study', values='q')
hits = study_specific.pivot(index='genus', columns='study', values='enrichment').fillna(0.0)

# remove rows with no significant enrichments
hits = hits.loc[(hits != 0).any(axis=1)]
hits

## Calculate Prevalence and Abundance
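Here we calculate two summaries for each genus in the analysis. Prevalence is the fraction of all samples in which the genus is detected (reads > 0). Abundance is its mean relative abundance across the samples in which it is detected.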

In [None]:
# total number of samples (denominator for prevalence)
n = merged_table.sample_id.nunique()

# rows in which the taxon was actually detected
detected = merged_table[merged_table['reads'] > 0]

# shorten each full taxonomy string to its last rank (genus), matching the `hits` index
detected_genus = detected['full_taxonomy'].str.split('|').str[-1]

# prevalence: fraction of samples in which the genus is detected. Counting distinct
# sample_ids (not rows) and aggregating on the shortened name keeps the result correct,
# and its index unique, even when several full taxonomy strings end in the same genus.
prevalence = detected.groupby(detected_genus)['sample_id'].nunique() / n

# abundance: mean relative abundance over the rows in which the genus is detected
abundance = detected.groupby(detected_genus)['relative'].mean()

# map onto the genus-indexed results table
hits['prevalence'] = hits.index.map(prevalence)
hits['abundance'] = hits.index.map(abundance)

## Calculate Enrichment Heuristic
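Here we'll calculate a between-study consistency score for each genus. The score is the number of studies in which the genus is significantly enriched in cases minus the number in which it is significantly enriched in controls. Genera with a mean relative abundance of 0.005 or less are removed first. A genus whose score is at least 3 is assigned a case signature (1), one whose score is at most −3 is assigned a control signature (−1), and all others get 0.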

In [None]:
# minimum abundance (as computed in the previous cell) for a genus to be kept
MIN_ABUNDANCE = 0.005
# net number of studies (cases minus controls) needed to assign a signature
MIN_NET_STUDIES = 3

# remove low abundance taxa; .copy() so the column assignments below act on an
# independent frame rather than on a filtered view
hits = hits[hits['abundance'] > MIN_ABUNDANCE].copy()

# heuristic: study columns hold +1 / -1 / 0, so their row sum is
# N(enriched in cases) - N(enriched in controls)
study_cols = hits.columns[0:-2]  # everything except prevalence, abundance
overall = hits[study_cols].sum(axis=1)

# sort by abundance
hits = hits.sort_values(by='abundance')

# signature: 1 / -1 when the net count reaches the threshold in either direction, else 0.
# Assigned in one step (aligned on the genus index) instead of a chained in-place
# fillna, which does not update `hits` under pandas Copy-on-Write.
hits['signature'] = ((overall >= MIN_NET_STUDIES).astype(float)
                     - (overall <= -MIN_NET_STUDIES).astype(float))

# transpose for plotting: rows = studies + prevalence, abundance, signature; columns = genera
hits = hits.T

# format for plotting
hits['authors'] = hits.index.str.split(',').str[0]
# study -> disease lookup (if a study lists several diseases, the last row wins)
study_to_disease = merged_table.set_index('study')['disease'].to_dict()
hits['disease'] = hits.index.map(study_to_disease)
hits['fill'] = hits['disease'].map(color_dict)
# the summary rows have no disease and sort to the bottom (na_position='last');
# the plotting cell relies on them being the last three rows
hits = hits.sort_values(by='disease', na_position='last')

## Plot associations
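Now we plot the logistic-regression results as a heatmap of per-study enrichment, with a colored strip marking each study's disease. Below it are three more panels: the cross-study signature row, then bar plots of each genus's abundance and prevalence.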

In [None]:
# Layout of `hits`: rows are the studies followed by three summary rows
# (prevalence, abundance, signature); columns are the genera followed by three
# annotation columns (authors, disease, fill).
study_rows = hits.iloc[0:-3]
genera = hits.columns[0:-3]
annotation_cols = ['fill', 'disease', 'authors']

# One color scale for both heatmaps, pinned to [-1, 1] and centered on 0, so that
# "no significant enrichment" is always drawn in the light middle color, even when a
# panel happens to contain only one direction of enrichment.
enrichment_cmap = sns.diverging_palette(220, 20, center='light', as_cmap=True)
heatmap_kwargs = dict(cmap=enrichment_cmap, vmin=-1, vmax=1, center=0, cbar=False)

fig, (ax1, ax2, ax3, ax4) = plt.subplots(nrows=4,
                                         figsize=(18, 10),
                                         gridspec_kw={'height_ratios': [30, 1.5, 3, 3]})

# panel 1: per-study enrichment, plus a disease-colored strip just left of the y axis
sns.heatmap(study_rows.drop(columns=annotation_cols), ax=ax1, **heatmap_kwargs)
for i, color in enumerate(study_rows['fill']):
    ax1.add_patch(plt.Rectangle(xy=(-0.02, i), width=0.02, height=1, color=color, lw=0,
                                transform=ax1.get_yaxis_transform(), clip_on=False))

# panel 2: cross-study signature row
sns.heatmap(hits.loc[['signature']].drop(columns=annotation_cols), ax=ax2, **heatmap_kwargs)

# panels 3 and 4: abundance and prevalence of each genus, in heatmap column order
sns.barplot(x=genera, y=hits.loc['abundance', genera].astype(float), ax=ax3, color='gray')
sns.barplot(x=genera, y=hits.loc['prevalence', genera].astype(float), ax=ax4, color='gray')

font_props = FontProperties().copy()
font_props.set_size(14)

ax1.set_xticks([])
ax1.tick_params(axis='y', which='major', pad=25, length=0)
ax1.set(xlabel=None)
ax1.set_yticklabels(ax1.get_ymajorticklabels(), fontproperties=font_props)

ax2.set_xticks([])
ax2.set_yticks([])
ax2.set(xlabel=None, ylabel=None)

ax3.set_ylabel('', fontsize=14)
ax3.set(xlabel=None)
ax3.set_xticklabels([])

# genus names on the bottom axis: italic and rotated
font_props.set_style('italic')
ax4.set_ylabel('', fontsize=14, rotation=0)
ax4.set(xlabel=None)
ax4.set_xticklabels(ax4.get_xmajorticklabels(), fontproperties=font_props, rotation=80)

fig.savefig('../visualizations/logistic_regression.svg', dpi=300, bbox_inches="tight", format='svg')
plt.show()

## Calculate Effect Sizes
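Calculate effect sizes (Cohen's d of CLR abundance between controls and cases) for each taxon–study pair shown in the heatmap, and restrict the q-value table to the same taxa.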

In [None]:
# genera shown in the heatmap (drop the three annotation columns) and the studies
# (drop the three summary rows)
taxa = hits.columns[0:-3].unique()
studies = hits.iloc[0:-3].index.unique()

# taxon x study table of effect sizes, filled cell by cell
effects = pd.DataFrame(index=taxa, columns=studies)

for taxon in taxa:
    # filter on genus once per taxon, then narrow to each study
    genus_rows = merged_table[merged_table['genus'] == taxon]
    for study in studies:
        pair_rows = genus_rows[genus_rows['study'] == study]
        control_clr = pair_rows[pair_rows['condition'] == 'control']['clr']
        case_clr = pair_rows[pair_rows['condition'] == 'case']['clr']
        # `effectsize` comes from utils; arguments are (controls, cases)
        effects.at[taxon, study] = effectsize(control_clr, case_clr)

# keep heatmap genera and transpose to study x taxon.
# NOTE(review): p_frame is overwritten in place, so re-running this cell alone transposes it again.
effects = effects[effects.index.isin(hits.columns)].T
p_frame = p_frame[p_frame.index.isin(hits.columns)].T