# Logistic Regression

In this notebook we use the output of our QIIME 2 workflow to investigate differences in upper respiratory tract (URT) composition between cases and controls using logistic regression.

In [None]:
import os

import numpy as np
import pandas as pd
import seaborn as sns
import statsmodels as sm
import statsmodels.formula.api as smf
import matplotlib.pyplot as plt
# Star import supplies qiime_to_dataframe, filter_taxa, clr and the ggplot-style
# plotting names used below; it is kept before the explicit imports that follow
# so those still take precedence, as in the original ordering.
from utils import *
from matplotlib.font_manager import FontProperties
from statsmodels.stats.multitest import fdrcorrection


%matplotlib inline

## Build Dataframe
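First we build a single dataframe of genus-level reads from every study in the meta-analysis (NP and OP samples), merged with each study's sample metadata, then filter rare taxa and apply a centered log-ratio (CLR) transform.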

In [None]:
# Root of the meta-analysis inputs and the taxonomic levels used to collapse
# the QIIME 2 feature tables.
METABUG_DIR = '/proj/gibbons/nbohmann/metabug'
collapse_on = ["kingdom", "phylum", "class", "order", "family", "genus"]

# Study accession -> disease type (first column of conditions.csv -> second column).
disease = pd.read_csv(os.path.join(METABUG_DIR, 'conditions.csv'),
                      index_col=0, header=None)[1].to_dict()


def load_site_studies(site_dir, manifest, urt):
    """Load every study listed in one URT site's manifest.

    Parameters
    ----------
    site_dir : str
        Directory containing the site's ``qiime/`` artifacts and ``metadata/`` tables.
    manifest : pd.DataFrame
        Site manifest; its index lists the study accessions to load.
    urt : str
        Value written to the ``URT`` column for every row (e.g. 'NP' or 'OP').

    Returns
    -------
    list of pd.DataFrame
        One frame per study: the ``qiime_to_dataframe`` output inner-merged with
        that study's sample metadata on ``sample_id``, tagged with ``URT`` and ``Study``.
    """
    frames = []
    for study in manifest.index:
        # pull out feature tables with total reads
        ab = qiime_to_dataframe(
            feature_table=os.path.join(site_dir, 'qiime', study + '_table.qza'),
            taxonomy=os.path.join(site_dir, 'qiime', study + '_taxonomy.qza'),
            collapse_on=collapse_on)
        meta = pd.read_csv(os.path.join(site_dir, 'metadata', study + '_metadata.tsv'), sep="\t")
        # The first metadata column is the sample identifier; normalise its name for the merge.
        meta = meta.rename(columns={meta.columns[0]: "sample_id"})
        ab = pd.merge(ab, meta, on="sample_id")
        ab['URT'] = urt
        ab['Study'] = study
        frames.append(ab)
    return frames


# Absolute paths instead of os.chdir, so the cell does not change the kernel's working directory.
NP_DIR = os.path.join(METABUG_DIR, 'manifest', 'NP')
OP_DIR = os.path.join(METABUG_DIR, 'manifest', 'OP')
manifest_NP = pd.read_csv(os.path.join(NP_DIR, 'NP_manifest.csv'), index_col=0, header=None)
manifest_OP = pd.read_csv(os.path.join(OP_DIR, 'OP_manifest.csv'), index_col=0, header=None)

# Concatenate once rather than growing `res` inside the loops (quadratic copying).
res = pd.concat(load_site_studies(NP_DIR, manifest_NP, 'NP')
                + load_site_studies(OP_DIR, manifest_OP, 'OP'))

# Drop rows without a genus call and placeholder / chloroplast genus labels.
res = res.dropna(subset=['genus'])
res = res[~(res.genus.str.contains('None'))
          & ~(res.genus.str.contains('uncultured'))
          & ~(res.genus.str.contains('Chloroplast'))]
# Filter rare taxa, then apply the centre-log-ratio transform (helpers from utils).
res = clr(filter_taxa(res, min_reads=2, min_prevalence=0.5))
res['disease'] = res['Study'].map(disease)
# NOTE(review): np.log10 gives -inf where reads == 0; the column appears unused below.
res['log10'] = np.log10(res['reads'])
# Relabel after the disease lookup above, which still uses the original accession.
res['Study'] = res['Study'].str.replace('PRJEB15534', 'PRJEB22676')

## Visualize Reads
Now visualize the number of samples contributed by each study, coloured by disease type.

In [None]:
# One row per (sample, study, disease); the figure only counts samples per study.
# (The previous `res.groupby([...]).mean()` averaged every column and raises a
# TypeError on pandas >= 2.0 because `res` has string columns.) dropna mirrors
# groupby's default of dropping rows with missing keys.
res_grouped = (res[['sample_id', 'Study', 'disease']]
               .drop_duplicates()
               .dropna()
               .sort_values(by=['disease', 'Study'])
               .reset_index(drop=True))
# Study axis ordered by disease so studies of the same disease sit together.
yscale = res_grouped['Study'].unique()
fillscale = res_grouped['disease'].unique()[::-1]
# NOTE(review): a histogram over the discrete Study axis; bins=26 presumably matches the
# number of studies — a count geom (geom_bar) may be the intended geom; confirm.
sample_plot = (ggplot(
    res_grouped, aes(x='Study', fill='disease'))
    + geom_histogram(bins=26)
    + scale_x_discrete(limits=yscale)
    + scale_fill_discrete(limits=fillscale)
    + coord_flip()
    + theme(text=element_text(size=20),
            panel_background=element_rect(fill="white", colour="white",
                                          size=0.5, linetype="solid"),
            panel_grid=element_line(size=.2, linetype="solid", colour="gray"),
            axis_line=element_line(size=2, linetype="solid", colour="black"),
            legend_title=element_blank(),
            legend_position='right',
            figure_size=(16, 12))
)
sample_plot

## Calculate Overall Enrichments
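For each taxon, fit a logistic regression of control/case status on the taxon's CLR abundance, adjusting for URT site and region, to determine taxa that are enriched in cases or controls irrespective of disease type. P-values are corrected with the Benjamini–Hochberg FDR procedure and taxa with q < 0.05 are kept.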

In [None]:
# Per-taxon logistic regression across all diseases, adjusting for URT site and region.
# Results are collected as records and turned into a frame once (no concat-in-a-loop).
overall_records = []
for x in res['full_taxonomy'].unique():
    df = res[res['full_taxonomy'] == x].copy()
    # Outcome is 1 for controls and 0 for every other condition value.
    df['condition_bin'] = (df['condition'] == 'control').astype(int)
    # Only one outcome class present for this taxon -> the logit is not estimable
    # (same guard as the disease-specific cell).
    if df['condition_bin'].nunique() == 1:
        continue
    try:
        sol = smf.logit('condition_bin ~ clr + URT+ region', data=df).fit(method='bfgs', disp=0)
    except sm.tools.sm_exceptions.PerfectSeparationError:
        print("Skipping group", x, "due to perfect predictor error")
        continue
    # Direction of effect: log2 of mean relative abundance in cases over controls.
    log2 = np.log2(df.loc[df['condition'] == 'case', 'relative'].mean() /
                   df.loc[df['condition'] == 'control', 'relative'].mean())
    overall_records.append({'taxon': x,
                            'pvalue': sol.pvalues['clr'],
                            'log2_foldchange': log2})

overall = pd.DataFrame(overall_records, columns=['taxon', 'pvalue', 'log2_foldchange'])
# Benjamini-Hochberg FDR across all taxa; keep q < 0.05.
overall['q'] = fdrcorrection(overall['pvalue'])[1]
overall = overall[overall['q'] < 0.05].sort_values(by='log2_foldchange')
overall['genus'] = overall['taxon'].str.split('|').str[-1]
# +1: higher in cases, -1: higher in controls (NaN fold changes fall to -1, as before).
overall['enrichment'] = np.where(overall['log2_foldchange'] > 0, 1, -1)
overall = overall.set_index('genus')

## Calculate Disease-Specific Enrichment
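Now we do the same within each disease type, adjusting only for URT site. The FDR correction is applied jointly across all disease-specific tests.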

In [None]:
# Same per-taxon logistic regression, fitted separately within each disease type and
# adjusting only for URT site. The loop variable is `disease_type` so it no longer
# overwrites the study -> disease dict `disease` built with the dataframe.
specific_records = []
for disease_type in res['disease'].unique():
    res_sub = res[res['disease'] == disease_type]
    for x in res_sub['full_taxonomy'].unique():
        df = res_sub[res_sub['full_taxonomy'] == x].copy()
        # Outcome is 1 for controls and 0 for every other condition value.
        df['condition_bin'] = (df['condition'] == 'control').astype(int)
        # Only one outcome class present -> the logit is not estimable.
        if df['condition_bin'].nunique() == 1:
            continue
        try:
            sol = smf.logit('condition_bin ~ clr + URT', data=df).fit(disp=0)
        except sm.tools.sm_exceptions.PerfectSeparationError:
            print("Skipping group", x, "in", disease_type, "due to perfect predictor error")
            continue
        # Direction of effect: log2 of mean relative abundance in cases over controls.
        log2 = np.log2(df.loc[df['condition'] == 'case', 'relative'].mean() /
                       df.loc[df['condition'] == 'control', 'relative'].mean())
        specific_records.append({'taxon': x,
                                 'pvalue': sol.pvalues['clr'],
                                 'log2_foldchange': log2,
                                 'disease': disease_type})

# Explicit columns keep the downstream steps valid even when no model was fitted.
disease_specific = pd.DataFrame(specific_records,
                                columns=['taxon', 'pvalue', 'log2_foldchange', 'disease'])
# Benjamini-Hochberg FDR applied jointly across every (taxon, disease) test; keep q < 0.05.
disease_specific['q'] = fdrcorrection(disease_specific['pvalue'])[1]
disease_specific = disease_specific[disease_specific['q'] < 0.05].sort_values(by='q')
disease_specific['genus'] = disease_specific['taxon'].str.split('|').str[-1]
# +1: higher in cases, -1: higher in controls.
disease_specific['enrichment'] = np.where(disease_specific['log2_foldchange'] > 0, 1, -1)
# Wide table: one row per genus, one column per disease type.
disease_specific = disease_specific.pivot(index='genus', columns='disease', values='enrichment')

## Total Enrichments
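Now combine the significant disease-specific and overall enrichments into one dataframe (genera without a significant call get 0) and calculate each genus's prevalence (fraction of samples with at least one read) and mean relative abundance across the samples where it is detected.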

In [None]:
# Side-by-side enrichment calls: one column per disease type plus 'Overall'.
# Genera lacking a significant call in a column are filled with 0.
hits = pd.concat([disease_specific, overall[['enrichment']]], axis=1)
hits = hits.rename(columns={'enrichment': 'Overall'}).fillna(0.0)

n = res.sample_id.nunique()
detected = res[res.reads > 0]


def _genus_lookup(per_taxon, keep):
    """Re-key a Series indexed by full taxonomy string to its last ('|'-separated)
    level and return it as a dict restricted to labels present in `keep`."""
    by_genus = per_taxon.copy()
    by_genus.index = by_genus.index.str.split('|').str[-1]
    return by_genus[by_genus.index.isin(keep)].to_dict()


# prevalence: share of all samples with at least one read of the taxon;
# abundance: mean relative abundance over the samples where it is detected.
prevalence = _genus_lookup(detected.full_taxonomy.value_counts() / n, hits.index)
abundance = _genus_lookup(detected.groupby('full_taxonomy')['relative'].mean(), hits.index)
hits['prevalence'] = hits.index.map(prevalence)
hits['abundance'] = hits.index.map(abundance)
hits = hits.sort_values(by='abundance', ascending=False)
hits

## Plot results
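Plot the disease-specific and overall enrichment calls as heatmaps. In them, +1 means higher mean relative abundance in cases, −1 means higher in controls, and 0 means not significant. Below the heatmaps, a bar chart shows each genus's mean relative abundance.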

In [None]:
# Row order of the disease-specific heatmap (alphabetical, matching the pivot's columns).
DISEASE_ORDER = ['COPD', 'Influenza', 'Pneumonia', 'RSV', 'RTI',
                 'Resp. Allergies', 'Rhinosinusitis', 'Tonsillitis']
# Enrichment calls are -1 / 0 / +1. Pinning vmin/vmax keeps 0 on the light centre of the
# diverging palette and gives both heatmaps the same scale, even when a panel lacks
# one of the signs.
heatmap_kws = dict(cmap=sns.diverging_palette(220, 20, center='light', as_cmap=True),
                   vmin=-1, vmax=1, cbar=False)

fig, (ax1, ax2, ax3) = plt.subplots(nrows=3,
                                    figsize=(10, 7),
                                    gridspec_kw={'height_ratios': [8, 1, 3]})

# reindex: a disease with no significant genus is shown as an all-zero row
# instead of raising KeyError.
sns.heatmap(hits.reindex(columns=DISEASE_ORDER, fill_value=0.0).T, ax=ax1, **heatmap_kws)
sns.heatmap(hits[['Overall']].T, ax=ax2, **heatmap_kws)
# Bars follow hits' row order, so they line up with the heatmap columns above.
sns.barplot(x=hits.index, y=hits['abundance'], ax=ax3)

label_font = FontProperties(size=15)
for ax in (ax1, ax2):
    ax.set_xticks([])
    ax.set(xlabel=None)
    ax.set_yticklabels(ax.get_ymajorticklabels(), fontproperties=label_font)

# Genus names in italics, rotated on ax3 explicitly rather than via the pyplot
# "current axes" state.
genus_font = FontProperties(size=15, style='italic')
ax3.set_ylabel('Relative \n Abundance', fontsize=15)
ax3.set(xlabel=None)
ax3.set_xticklabels(ax3.get_xmajorticklabels(), rotation=80, fontproperties=genus_font)

plt.show()