## We'll use this notebook to build and grow metabolic models for the Arivale cohort, and find associations between resulting SCFA production predictions and blood metabolites. Arivale data is not publicly available, but is available to qualified researchers upon request given a data use agreement. Inquiries can be made at data-access@isbscience.org. 

In [None]:
import os
import pandas as pd
import numpy as np
import micom
from plotnine import *
import scipy
import scipy.stats  # scipy.stats.zscore; submodules are not guaranteed to load with a bare `import scipy`
import statsmodels
import statsmodels.formula.api  # statsmodels.formula.api.ols
import statsmodels.stats.multitest  # statsmodels.stats.multitest.fdrcorrection
import matplotlib.pyplot as plt  # plt.subplots in the heatmap cell
import seaborn as sns  # sns.heatmap / sns.diverging_palette in the heatmap cell

## The ADI is an internal tool used for wrangling Arivale data 
from arivale_data_interface import *
frozen_ss_path='/proj/arivale/snapshots/arivale_snapshot_ISB_2020-03-16_2156/'
sn=list_snapshot_contents()
def get_frozen_snapshot(ss_name, ss_path=frozen_ss_path):
    
    return get_snapshot(ss_name, path=ss_path)

%matplotlib inline

## First we'll pull in our medium. We'll use a functionally complete medium representing a European diet. 

In [None]:
# European-diet growth medium, passed unchanged to micom.workflows.grow further down
medium_path = '../media/european.csv'
medium = pd.read_csv(medium_path)

## Now we pull in our abundance data. The taxonomic assignment of each ASV lives in a different file, which we'll map onto the abundance table (the `taxonomy` dataframe) using the ASV hash

In [None]:
# Abundance table; rows are keyed by ASV 'hash' and carry 'id' and 'count' columns
taxonomy = pd.read_csv('/proj/arivale/microbiome/16S_processed/asvs.csv')
# Taxonomic ranks for every ASV, keyed by the same hash in its 'id' column
asv = pd.read_csv('/proj/arivale/microbiome/16S_processed/taxonomy_refseq.csv')

lineage_ranks = ['Kingdom', 'Phylum', 'Class', 'Order', 'Family', 'Genus']
# Collapse the ranks into one lower-case, dot-separated lineage string (missing ranks become "")
lineages = asv[lineage_ranks].apply(
    lambda ranks: ".".join(ranks.fillna("").str.lower()), axis=1)
# hash -> lineage lookup (later duplicates of a hash win, as with Series.to_dict)
asv = dict(zip(asv['id'], lineages))

taxonomy['taxon'] = taxonomy['hash'].map(asv)
# The genus is the last lineage component; an unassigned genus leaves an empty string, which becomes NaN
genus_names = taxonomy['taxon'].str.split('.').str[-1].str.capitalize()
taxonomy['genus'] = genus_names.replace('', np.nan)

# Column names expected by micom: sample_id, id (taxon), abundance; drop incompletely annotated rows
taxonomy = (
    taxonomy
    .rename(columns={'id': 'sample_id', 'taxon': 'id', 'count': 'abundance'})
    .dropna()
)

## We'll also identify the model database we want to use to construct the models

In [None]:
agora = '../agora/data'  # folder with the AGORA model database passed to micom.workflows.build

## Now we can build our models, using the AGORA (version 1.03) database. We'll use an abundance cutoff of 0.001 for consistency with previous analysis

In [None]:
# Build the community models into ../models/arivale; cutoff=0.001 matches the previous analysis
manifest = micom.workflows.build(
    taxonomy,
    agora,
    '../models/arivale',
    cutoff=0.001,
    threads=20,
)

## Now we grow the models using the standard European diet and a tradeoff parameter of 0.7, for consistency with previous analyses

In [None]:
# Simulate growth of every community model on the European medium (tradeoff 0.7, as before)
growth = micom.workflows.grow(
    manifest, '../models/arivale', medium,
    tradeoff=0.7, threads=20,
)
# Long-format production rates derived from the growth results
exchanges = micom.measures.production_rates(growth)

## We can scale the flux by abundance and pull out SCFA exchanges, pivoting the table for readability

In [None]:
# Keep only butyrate, propionate and acetate exchange reactions
scfa_reaction_prefixes = ['EX_but(e)', 'EX_ppa(e)', 'EX_ac(e)']
is_scfa = exchanges['reaction'].str.startswith(scfa_reaction_prefixes[0])
for prefix in scfa_reaction_prefixes[1:]:
    is_scfa = is_scfa | exchanges['reaction'].str.startswith(prefix)
exchanges = exchanges[is_scfa]

# Wide table: one row per sample, one column per metabolite name, cells are (mean) fluxes
exchanges = exchanges.pivot_table(
    index='sample_id',
    columns='name',
    values='flux',
)

## Let's add a column for total SCFA production. We'll also fill in missing values and rename columns

In [None]:
# Metabolite columns produced by the pivot -> readable names used as regression predictors later
scfa_columns = {'but[e]': 'Butyrate', 'ppa[e]': 'Propionate', 'ac[e]': 'Acetate'}
# An SCFA with no entries for any sample would be absent from the pivot entirely; add it as missing
for metabolite in scfa_columns:
    if metabolite not in exchanges.columns:
        exchanges[metabolite] = np.nan
# Missing entries are treated as zero production. This must happen BEFORE summing: NaN + x is NaN,
# so summing first turned the totals of any sample lacking one SCFA into 0 after the fill.
exchanges = exchanges.fillna(0.0)
# Add column for total SCFAs
exchanges['Total'] = exchanges['but[e]'] + exchanges['ppa[e]'] + exchanges['ac[e]']
# Add column for Butyrate and Propionate (both anti-inflammatory)
exchanges['ButyrateAndPropionate'] = exchanges['but[e]'] + exchanges['ppa[e]']
# Rename the metabolite columns ('Butyrate', 'Propionate', 'Acetate' are what find_hits regresses on)
exchanges = exchanges.rename(columns=scfa_columns)

## We'll need metadata to merge with the growth data. 

In [None]:
# Sample metadata; its 'id' column is renamed to match the 'sample_id' key of the predictions
metadata = (
    pd.read_csv('/proj/gibbons/arivale_16S/data/metadata.csv')
    .rename(columns={'id': 'sample_id'})
)
# Inner join: keep only samples that have both predictions and metadata
exchanges = exchanges.merge(metadata, on='sample_id', how='inner')
# merge_asof (below) needs a non-missing, sorted days_in_program key
exchanges = (
    exchanges
    .dropna(subset=['days_in_program'])
    .sort_values(by='days_in_program')
)

## Now we'll get our blood chemistry panel and z-score all the markers (percentile columns are left as-is). Then we'll merge the chemistries with our predictions, matching each sample to the nearest blood draw within a tolerance of 30 days

In [None]:
# Grab the chemistries data
chems = get_frozen_snapshot('chemistries')
# Identify the names of features in the chemistries
# NOTE(review): assumes the first 12 columns are identifiers/metadata — confirm against the snapshot schema
chem_features = chems.columns[12:]
# z-score every feature except percentile columns. As before, statistics use all rows,
# including those later dropped for a missing days_in_program.
for feature in chem_features:
    if 'PERCENTILE' not in feature:
        chems[feature] = scipy.stats.zscore(chems[feature], nan_policy='omit')
# merge_asof needs a numeric, NaN-free key sorted in numeric order: convert and drop first, THEN sort.
# Sorting before the float conversion could order the key by a non-numeric dtype.
# A stable sort keeps tie resolution between same-day draws reproducible.
chems['days_in_program'] = chems['days_in_program'].astype('float64')
chems = (
    chems
    .dropna(subset=['days_in_program'])
    .sort_values(by='days_in_program', kind='mergesort')
)
# Pair each sample with the same client's nearest blood draw within 30 days. The left key is cast to
# float64 as well (order-preserving) because merge_asof requires matching key dtypes.
# Samples with no matched chemistry values are dropped.
chems_merged = pd.merge_asof(
    exchanges.astype({'days_in_program': 'float64'}),
    chems,
    by='public_client_id',
    on='days_in_program',
    direction='nearest',
    tolerance=30,
).dropna(subset=chem_features, how='all')
# Chemistry column name -> display name. find_hits receives this mapping as its `metadata` argument.
metadata = get_frozen_snapshot('chemistries_metadata').set_index('Name')['Display Name'].to_dict()

## Here we'll define our function for finding significantly associated markers. Some marker names must be rewritten first, because statsmodels formulas cannot contain characters such as `/`, spaces, `-`, parentheses or commas in variable names

In [None]:
# Enter the list of features, the dataframe, the SCFA of interest and the associated metadata
def find_hits(features, merged_df, marker, metadata, covariates=('sex', 'age', 'vendor')):
    """Regress each blood marker on an SCFA prediction and collect FDR-corrected results.

    For every feature, an OLS model ``feature ~ marker + covariates`` is fitted on
    ``merged_df``. The coefficient and p-value of ``marker`` are recorded.

    Parameters
    ----------
    features : iterable of str
        Original (unsanitised) column names of the markers to test.
    merged_df : pandas.DataFrame
        Holds the feature columns, the ``marker`` column and the covariates.
        NOTE: each feature column is renamed IN PLACE to its formula-safe name.
        This side effect is kept for backward compatibility, because later cells
        save this frame.
    marker : str
        Column used as the predictor of interest (e.g. 'Butyrate', 'Total').
    metadata : mapping
        Original feature name -> display name. Unmapped features fall back to
        their own name instead of raising KeyError.
    covariates : sequence of str, optional
        Adjustment terms added to every model (default: sex, age, vendor).

    Returns
    -------
    pandas.DataFrame
        Columns ``feature`` (sanitised name), ``display``, ``p``, ``rho`` (the OLS
        coefficient of ``marker``), ``p_corrected`` (statsmodels fdrcorrection,
        method='indep') and ``measure`` (= ``marker``). Rows are sorted by ``p``
        and rows with missing values are dropped.
    """
    # Characters the original sanitiser rewrites for use in formulas
    replacements = (('/', '_'), (' ', '_'), ('-', '_'), ('(', ''), (')', ''), (',', ''))
    rhs = ' + '.join([marker, *covariates])
    # Collect plain records and build the frame once, instead of concatenating in a loop.
    # (The old frame was seeded with an always-NaN 'beta hgfd' column, so dropna() removed every row.)
    records = []
    for x in features:
        # Rename features for use in statsmodels OLS
        x_fixed = x
        for old, new in replacements:
            x_fixed = x_fixed.replace(old, new)
        merged_df.rename(columns={x: x_fixed}, inplace=True)
        # Skip features with no observations. This must look at merged_df, not a global frame.
        if merged_df[x_fixed].count() == 0:
            continue
        # Q('...') quotes the response, so names that are still not valid identifiers
        # (e.g. a leading digit or '%') do not break the formula parser
        formula = 'Q({!r}) ~ {}'.format(x_fixed, rhs)
        res = statsmodels.formula.api.ols(formula=formula, data=merged_df).fit()
        records.append({'feature': x_fixed,
                        'display': metadata.get(x, x),
                        'p': res.pvalues[marker],
                        'rho': res.params[marker]})
    # Reformat the dataframe
    df = pd.DataFrame(records, columns=['feature', 'display', 'p', 'rho'])
    df = df.sort_values(by='p').dropna().reset_index(drop=True)
    # FDR correction of p values (skipped when nothing was tested)
    if df.empty:
        df['p_corrected'] = pd.Series(dtype='float64')
    else:
        df['p_corrected'] = statsmodels.stats.multitest.fdrcorrection(df['p'], method='indep')[1]
    # Add the name of the SCFA being used as the predictor
    df['measure'] = marker
    return df

## Now we'll iterate this function across all our SCFAs to find significant hits

In [None]:
# Regress every chemistry feature on each SCFA measure in turn (order matters for the row order)
scfa_measures = ['Propionate', 'Acetate', 'Butyrate', 'Total']
df = pd.concat([find_hits(chem_features, chems_merged, measure, metadata)
                for measure in scfa_measures])
# Keep all measures' rows for any feature that reaches FDR < 0.05 for at least one measure
significant_features = df.loc[df['p_corrected'] < 0.05, 'feature'].unique()
df = df[df['feature'].isin(significant_features)]


## We need to replace the p-values with a symbolic representation. This function will do that when applied to a dataframe of p-values

In [None]:
def pvalue_stars(val):
    """Convert a p-value into significance stars for heatmap annotation.

    Returns '***' below 0.001, '**' below 0.01, '*' below 0.05 and '' otherwise.
    NaN also yields '' because every comparison with NaN is False.
    """
    for cutoff, stars in ((0.001, '***'), (0.01, '**'), (0.05, '*')):
        if val < cutoff:
            return stars
    return ''

## Now, build the pivot tables and plot the results in a heatmap

In [None]:
# Column order shared by the effect-size table and its annotation table
measure_order = ['Butyrate', 'Propionate', 'Acetate', 'Total']
# OLS coefficients ('rho') as a display-name x SCFA-measure table.
# reindex (instead of [[...]]) yields an empty column rather than a KeyError if a measure has no rows.
df_pivot = pd.pivot_table(df, columns='measure',
                          index='display', values='rho').reindex(columns=measure_order)
# Matching table of FDR-corrected p-values, converted to significance stars for the annotations.
# Column-wise Series.map does what the deprecated DataFrame.applymap did and works on all pandas versions.
pvals = (pd.pivot_table(df, columns='measure', index='display', values='p_corrected')
         .reindex(columns=measure_order)
         .apply(lambda column: column.map(pvalue_stars)))

# Build the heatmap
sns.set(font_scale=1.5)  # set font
f, ax = plt.subplots(figsize=(5, 9))  # initialize plot
cmap = sns.diverging_palette(230, 10, sep=20, as_cmap=True)
# Heatmap with star annotations, drawn explicitly on the axes created above
ax = sns.heatmap(df_pivot, center=0, cmap=cmap, annot=pvals, fmt='', ax=ax,
                 annot_kws={'fontsize': 18, 'color': 'white',
                            'verticalalignment': 'center'})
ax.set(xlabel=None, ylabel=None);

## Save results to use later

In [None]:
# Persist the merged predictions/chemistries table and the significant association hits
for table, out_path in ((chems_merged, '../data/arivale_exchanges.csv'),
                        (df, '../data/arivale_hits.csv')):
    table.to_csv(out_path)