## We'll use this notebook to build and grow metabolic models for the Arivale cohort, and find associations between resulting SCFA production predictions and blood metabolites.

In [None]:
import os
import pandas as pd
import numpy as np
import sklearn
import micom
import seaborn as sns
import matplotlib.pyplot as plt
from plotnine import *
import scipy
import statsmodels
from IPython.display import display
from arivale_data_interface import *
# Snapshot freezing: new trial samples keep arriving and normalization
# procedures change over time, so analyses are pinned to one snapshot
# directory. Use get_frozen_snapshot() to always read from that directory.
frozen_ss_path = '/proj/arivale/snapshots/arivale_snapshot_ISB_2020-03-16_2156/'
sn = list_snapshot_contents()


def get_frozen_snapshot(ss_name, ss_path=frozen_ss_path):
    """Load the snapshot called ``ss_name`` from a fixed snapshot directory.

    Parameters
    ----------
    ss_name : name of the snapshot table, passed through to ``get_snapshot``.
    ss_path : directory to read from; defaults to the module-level
        ``frozen_ss_path`` captured when this function was defined.

    Returns
    -------
    Whatever ``get_snapshot`` returns for that name and path
    (presumably a DataFrame -- confirm against arivale_data_interface).
    """
    snapshot = get_snapshot(ss_name, path=ss_path)
    return snapshot

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

# Turn off pandas' chained-assignment warning for the whole notebook.
pd.options.mode.chained_assignment = None  # default='warn'

## First we'll pull in our medium. We'll use a functionally complete medium representing a European diet. 

In [None]:
# Move into the diet folder so the medium can be read by its bare file name.
os.chdir('/proj/gibbons/nbohmann/exvivo/diets')
diet = pd.read_csv(
    'western_completed.csv',
)
# Preview the first rows of the medium table.
diet.head()

## Now we pull in our abundance data. 

In [None]:
# Read the processed 16S ASV table and display it as the cell output.
abundance = pd.read_csv(
    '/proj/arivale/microbiome/16S_processed/asvs.csv',
)
abundance

## To get taxonomy, we'll also grab the corresponding refseq taxonomy table, and map onto the abundance data

In [None]:
# Build a lookup from ASV id to a dot-joined, lower-cased lineage string
# (kingdom through genus); missing ranks become empty strings.
rank_columns = ['Kingdom', 'Phylum', 'Class', 'Order', 'Family', 'Genus']
taxonomy = pd.read_csv('/proj/arivale/microbiome/16S_processed/taxonomy_refseq.csv')
lineages = taxonomy[rank_columns].apply(
    lambda ranks: ".".join(ranks.fillna("").str.lower()),
    axis=1,
)
taxonomy['taxon'] = lineages
taxonomy = taxonomy.set_index('id')['taxon'].to_dict()

# Annotate every ASV row with its lineage, then keep only the genus (the last
# lineage field). An empty genus is converted to NaN so it can be dropped.
abundance['taxon'] = abundance['hash'].map(taxonomy)
genus = abundance['taxon'].str.split('.').str[-1].str.capitalize()
abundance['genus'] = genus.replace('', np.nan)
# NOTE(review): dropna() removes rows with a NaN in *any* column, not only
# 'genus' -- confirm no other column of the ASV table carries NaNs.
abundance = abundance.dropna().rename(columns={'id': 'sample_id'})

# Collapse counts to one row per (sample, genus) and convert them to
# within-sample relative abundances.
abundance = abundance.groupby(by=['sample_id', 'genus'])['count'].sum().reset_index()
sample_totals = abundance.groupby('sample_id')['count'].transform('sum')
abundance['abundance'] = abundance['count'] / sample_totals
abundance['taxon'] = abundance['genus']
abundance.head()

## Now we can build our models, using the AGORA (version 1.03) database. We'll use an abundance cutoff of 0.001 for consistency with previous analysis

In [None]:
# Build one community model per sample from the genus-level AGORA database,
# writing the models to the models_reclass folder.
agora = '/proj/gibbons/refs/micom_dbs/agora103_genus.qza'
manifest = micom.workflows.build(
    abundance,
    agora,
    '/proj/gibbons/nbohmann/arivale/models_reclass',
    cutoff=0.001,   # abundance cutoff
    threads=20,
)

## We'll need a manifest file too, which we rebuild from the contents of the model folder and save as manifest.csv

In [None]:
# Rebuild the manifest from the files present in the model folder.
os.chdir('/proj/gibbons/nbohmann/arivale/models_reclass')
model_files = os.listdir()
manifest = pd.DataFrame({'file': model_files})
# The sample id is the part of the file name before the first dot.
manifest['sample_id'] = manifest['file'].str.split('.').str[0]
# Leave out any manifest file that already sits in the folder.
is_manifest_file = manifest['file'].str.contains('manifest')
manifest = manifest[~is_manifest_file]
manifest.to_csv('manifest.csv')
manifest

## Now we grow the models again, using the standard European diet and a tradeoff parameter of 0.7, for consistency

In [None]:
# Run from the project folder so the relative model folder name resolves.
os.chdir('/proj/gibbons/nbohmann/arivale')
growth = micom.workflows.grow(
    manifest,
    'models_reclass',
    diet,
    tradeoff=0.7,
    threads=20,
)
# Persist both result tables so later cells can start from the CSV files.
growth_rates, exchange_fluxes = growth.growth_rates, growth.exchanges
growth_rates.to_csv('/proj/gibbons/nbohmann/arivale/reclass_growthrates.csv')
exchange_fluxes.to_csv('/proj/gibbons/nbohmann/arivale/reclass_exchanges.csv')

## We can scale the flux by abundance and pull out SCFA exchanges, pivoting the table for readability

In [None]:
# Reload the saved exchange fluxes and keep only the butyrate, propionate and
# acetate exchange reactions.
exchanges = pd.read_csv('/proj/gibbons/nbohmann/arivale/reclass_exchanges.csv', index_col=0)
reaction_ids = exchanges.reaction
is_butyrate = reaction_ids.str.startswith('EX_but(e)')
is_propionate = reaction_ids.str.startswith('EX_ppa(e)')
is_acetate = reaction_ids.str.startswith('EX_ac(e)')
exchanges = exchanges[is_butyrate | (is_propionate | is_acetate)]

# For exported fluxes only, sum flux * abundance within each
# (sample, metabolite, reaction) group.
exported = exchanges[exchanges.direction == "export"]
scaled_flux = exported.groupby(["sample_id", "metabolite", "reaction"]).apply(
    lambda group: sum(group.flux * group.abundance))
exchanges = scaled_flux.reset_index()  # the summed values land in a column named 0

# One row per sample, one column per metabolite.
exchanges = pd.pivot_table(exchanges, index='sample_id', columns='metabolite', values=0)

## Let's add a column for total SCFA production. We'll also fill in missing values and rename columns

In [None]:
# Fill missing fluxes with 0 *before* summing. A sample that has no exported
# flux for one SCFA carries NaN in that column; adding NaN would make the
# total NaN, and a later fillna would then turn the whole total into 0 and
# discard the production of the other SCFAs.
exchanges = exchanges.fillna(0.0)

# Derived columns: total SCFA production, and butyrate + propionate only.
exchanges['Total'] = exchanges['but[e]'] + exchanges['ppa[e]'] + exchanges['ac[e]']
exchanges['ButyrateAndPropionate'] = exchanges['but[e]'] + exchanges['ppa[e]']

# Replace the metabolite ids with readable names.
exchanges = exchanges.rename(
    columns={'ac[e]': 'Acetate', 'but[e]': 'Butyrate', 'ppa[e]': 'Propionate'})
exchanges

## We'll need sample metadata to merge with the SCFA production predictions.

In [None]:
# Read the sample metadata and align its key column name with the predictions.
metadata = pd.read_csv('/proj/gibbons/arivale_16S/data/metadata.csv')
metadata = metadata.rename(columns={'id': 'sample_id'})

# Keep samples present in both tables, drop rows without a time point, and
# order the rows by time point.
exchanges = exchanges.merge(metadata, on='sample_id', how='inner')
exchanges = exchanges.dropna(subset=['days_in_program'])
exchanges = exchanges.sort_values(by='days_in_program')
exchanges

## Now we'll get our blood chemistry panel and z-score all the markers. Then we'll merge the chemistries with our predictions, using a tolerance of 30 days

In [None]:
# Load the chemistries snapshot, ordered by time point.
chems = get_frozen_snapshot("chemistries").sort_values(by='days_in_program')
# Columns from position 12 onwards are treated as chemistry features; inspect
# the table to find out where the metadata ends.
chem_features = chems.columns[12:]

# z-score every feature except percentile columns (ordinal, shouldn't be
# z-scored); NaNs are ignored when computing the score.
zscored_features = [name for name in chem_features if 'PERCENTILE' not in name]
for name in zscored_features:
    chems[name] = scipy.stats.zscore(chems[name], nan_policy='omit')

# Cast the time key to float64 and drop rows that have no time point.
chems['days_in_program'] = chems['days_in_program'].astype('float64')
chems = chems.dropna(subset=['days_in_program'])

# For each SCFA row, attach the same client's nearest chemistry draw that lies
# within 30 days, then drop rows where every chemistry feature is missing.
chems_merged = pd.merge_asof(
    exchanges,
    chems,
    by="public_client_id",
    on="days_in_program",
    direction="nearest",
    tolerance=30,
)
chems_merged = chems_merged.dropna(subset=chem_features, how="all")

chem_features = chems.columns[chems.columns.isin(chem_features)]
# Map each chemistry name to its display name. Note that this rebinds
# `metadata`, which until now held the sample metadata table.
chem_metadata = get_frozen_snapshot('chemistries_metadata')
metadata = chem_metadata.set_index('Name')['Display Name'].to_dict()

## Here we'll define our function for finding significantly associated markers. We need to rename some of the markers, because the statsmodels formula interface cannot use column names that contain spaces or special characters

In [None]:
# Characters rewritten so a column name can be used bare in a formula:
# separators become underscores; brackets and commas are removed.
_FORMULA_NAME_TABLE = str.maketrans({'/': '_', ' ': '_', '-': '_',
                                     '(': None, ')': None, ',': None})


def _formula_safe_name(name):
    """Return ``name`` with '/', ' ', '-' turned into '_' and '(', ')', ',' removed."""
    return name.translate(_FORMULA_NAME_TABLE)


def _fit_marker_model(response, marker, data):
    """Fit ``response ~ marker + sex + age + vendor`` by OLS and return the fitted result."""
    # Imported here explicitly: a bare ``import statsmodels`` is not guaranteed
    # to load the ``formula`` subpackage.
    import statsmodels.formula.api as smf
    formula = response + ' ~' + marker + ' + sex + age + vendor'
    return smf.ols(formula=formula, data=data).fit()


def _fdr_adjust(pvalues):
    """Return FDR-corrected p-values (``method='indep'``) for ``pvalues``."""
    if len(pvalues) == 0:
        # Nothing to correct; keep the output column empty.
        return np.array([], dtype=float)
    from statsmodels.stats.multitest import fdrcorrection
    return fdrcorrection(pvalues, method='indep')[1]


def find_hits(features, merged_df, marker, metadata):
    """Regress every feature on one SCFA marker and collect the results.

    For each feature the model ``feature ~ marker + sex + age + vendor`` is fit
    by ordinary least squares, and the marker's p-value and coefficient are
    recorded. P-values are then FDR-corrected across all fitted features.

    Parameters
    ----------
    features : iterable of column names in ``merged_df`` to test.
    merged_df : DataFrame holding the features, ``marker``, ``sex``, ``age``
        and ``vendor``. NOTE: feature columns are renamed **in place** to
        formula-safe names, so the caller's frame keeps the new names.
    marker : name of the SCFA column used as the predictor of interest.
    metadata : mapping from the original feature name to its display name.

    Returns
    -------
    DataFrame sorted by p-value with columns ``feature`` (formula-safe name),
    ``p``, ``rho`` (the fitted coefficient of ``marker``), ``display``,
    ``p_corrected`` and ``measure`` (the marker name). Features whose column
    contains no values are skipped.

    Raises
    ------
    KeyError if a fitted feature has no entry in ``metadata``.
    """
    rows = []
    for x in features:
        x_fixed = _formula_safe_name(x)
        # Renaming is a no-op when the column already carries the safe name
        # (e.g. on a second call with the same frame).
        merged_df.rename(columns={x: x_fixed}, inplace=True)
        # Check the frame that was passed in, not a notebook-level global.
        if merged_df[x_fixed].count() == 0:  # skip all-NaN features
            continue
        res = _fit_marker_model(x_fixed, marker, merged_df)
        rows.append({'feature': x_fixed,
                     'p': res.pvalues[marker],
                     'rho': res.params[marker],
                     'display': metadata[x]})

    # Build the frame once; the explicit column list fixes the column order
    # and keeps the columns present even when no feature was fitted.
    df = pd.DataFrame(rows, columns=['feature', 'p', 'rho', 'display'])
    df = df.sort_values(by='p')  # most significant first
    df.dropna(inplace=True)
    df.reset_index(inplace=True, drop=True)
    df['p_corrected'] = _fdr_adjust(df['p'])
    df['measure'] = marker  # which SCFA marker these rows belong to
    return df

## Now we'll iterate this function across all our SCFAs to find significant hits

In [None]:
# Run the association scan once per SCFA measure, in this order, and stack
# the per-measure result tables.
scfa_measures = ['Propionate', 'Acetate', 'Butyrate', 'Total']
per_measure_hits = [
    find_hits(chem_features, chems_merged, measure, metadata)
    for measure in scfa_measures
]
df = pd.concat(per_measure_hits)

# Keep every row of any feature that is significant (corrected p < 0.05) for
# at least one measure.
significant_features = df.loc[df['p_corrected'] < 0.05, 'feature'].unique()
df = df[df.feature.isin(significant_features)]
df

## Now, plot significant hits

In [None]:
# Row order of the heatmap (the pivots are transposed when plotted).
measure_order = ['Butyrate', 'Propionate', 'Acetate', 'Total']

# Coefficients: one row per marker display name, one column per SCFA measure.
df_pivot = pd.pivot_table(df, columns='measure',
                          index='display', values='rho')  # heatmap values
df_pivot = df_pivot[measure_order]

# Corrected p-values in the same layout, used only for the annotations.
corrected = pd.pivot_table(df, columns='measure',
                           index='display', values='p_corrected')
corrected = corrected[measure_order]

# Significance stars: '*' for p <= 0.05, '**' for p <= 0.01, '***' for
# p <= 0.001, and an empty string otherwise (including missing values, since
# comparisons with NaN are False). Thresholds are applied from weakest to
# strongest, so every p-value lands in exactly one band.
pvals = pd.DataFrame('', index=corrected.index, columns=corrected.columns)
pvals[corrected <= 0.05] = '*'
pvals[corrected <= 0.01] = '**'
pvals[corrected <= 0.001] = '***'

sns.set(font_scale=1.5)  # set font
f, ax = plt.subplots(figsize=(18, 3))  # initialize plot
cmap = sns.diverging_palette(230, 10, sep=20, as_cmap=True)
# fmt='' is required because the annotations are strings, not numbers.
ax = sns.heatmap(df_pivot.T, center=0, cmap=cmap, annot=pvals.T,
                 fmt='', annot_kws={'fontsize': 18, 'color': 'white',
                                    'verticalalignment': 'center'})

# Remove both axis labels.
ax.set(xlabel=None)
ax.set(ylabel=None)

## Save results

In [None]:
# Write the merged table and the table of hits to the paper folder.
outputs = (
    (chems_merged, '/proj/gibbons/nbohmann/exvivo/scfa_paper/arivale.csv'),
    (df, '/proj/gibbons/nbohmann/exvivo/scfa_paper/arivale_hits.csv'),
)
for table, destination in outputs:
    table.to_csv(destination)