In [1]:
import numpy as np
import itertools
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy as sp
from scipy.stats import linregress, spearmanr

from bamf.bamfCR import *

import time

# set plot parameters
# The base style is applied FIRST so that the explicit rcParams below are not
# overridden by it. The dict is named `plot_params` so it is not confused with
# the model parameter vector `params` built later in the notebook.
try:
    # matplotlib >= 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*"
    # (the old names were removed in 3.8)
    plt.style.use('seaborn-v0_8-colorblind')
except OSError:
    # older matplotlib only knows the original name
    plt.style.use('seaborn-colorblind')

plot_params = {'legend.fontsize': 18,
               'figure.figsize': (8, 7),
               'axes.labelsize': 24,
               'axes.titlesize': 24,
               'axes.linewidth': 3,
               'xtick.labelsize': 20,
               'ytick.labelsize': 20,
               # embed TrueType (Type 42) fonts in saved PDFs
               'pdf.fonttype': 42}
plt.rcParams.update(plot_params)

# global NumPy seed: all later np.random draws depend on this and on call order
np.random.seed(123)



# Import data

In [2]:
# used later for model validation
# Load the dynamic community dataset; the CSV calls the treatment column
# "Experiments", which is renamed to "Treatments" for the rest of the notebook.
gLV_data = (
    pd.read_csv("gLV_data/2021_02_19_MultifunctionalDynamicData.csv")
    .rename(columns={"Experiments": "Treatments"})
)
# store sampling times as float64
gLV_data['Time'] = gLV_data['Time'].to_numpy(dtype=float)
gLV_data.head()

Unnamed: 0,Treatments,Time,PC_OD,PJ_OD,BV_OD,BF_OD,BO_OD,BT_OD,BC_OD,BY_OD,...,CG_OD,ER_OD,RI_OD,CC_OD,DL_OD,DF_OD,Butyrate,Acetate,Lactate,Succinate
0,PC-BV-BT-BC-BP-EL-FP-CH-AC-BH-CG-ER-RI-DF,0.0,0.000471,0.0,0.000471,0.0,0.0,0.000471,0.000471,0.0,...,0.000471,0.000471,0.000471,0.0,0.0,0.000471,0.0,0.0,28.0,0.0
1,PC-BV-BT-BC-BP-EL-FP-CH-AC-BH-CG-ER-RI-DF,16.0,0.465116,0.0,0.029207,0.0,0.0,0.249717,0.500651,0.0,...,0.024339,0.327601,0.00146,0.0,0.0,0.39283,23.092697,47.849302,18.910852,26.141885
2,PC-BV-BT-BC-BP-EL-FP-CH-AC-BH-CG-ER-RI-DF,32.0,0.104523,0.0,0.027928,0.0,0.0,0.220107,0.38021,0.0,...,0.020739,0.293384,0.00083,0.0,0.0,0.280111,23.996267,38.915218,17.977137,26.884748
3,PC-BV-BT-BC-BP-EL-FP-CH-AC-BH-CG-ER-RI-DF,48.0,0.124852,0.0,0.012194,0.0,0.0,0.268268,0.33397,0.0,...,0.02366,0.188188,0.002366,0.0,0.0,0.224952,24.839219,34.325914,19.406971,31.628061
4,BV-BF-BO-BT-BU-DP-BL-BA-BP-EL-FP-CH-AC-BH-CG-E...,0.0,0.0,0.0,0.000388,0.000388,0.000388,0.000388,0.0,0.0,...,0.000388,0.000388,0.000388,0.0,0.0,0.0,0.0,0.0,28.0,0.0


In [3]:
# Species are the optical-density columns (names ending in "_OD"), kept in the
# column order of the CSV. Selecting by name rather than by position keeps this
# correct if metadata or metabolite columns are added or reordered.
species = gLV_data.columns[gLV_data.columns.str.endswith("_OD")].to_numpy()
species

array(['PC_OD', 'PJ_OD', 'BV_OD', 'BF_OD', 'BO_OD', 'BT_OD', 'BC_OD',
       'BY_OD', 'BU_OD', 'DP_OD', 'BL_OD', 'BA_OD', 'BP_OD', 'CA_OD',
       'EL_OD', 'FP_OD', 'CH_OD', 'AC_OD', 'BH_OD', 'CG_OD', 'ER_OD',
       'RI_OD', 'CC_OD', 'DL_OD', 'DF_OD'], dtype=object)

# Set parameter prior based on monoculture growth curves

In [4]:
# global dimensions used throughout the model definition:
#   n_r -- number of latent resources
#   n_s -- number of species (one per "_OD" column)
n_r = 2
n_s = len(species)

In [5]:
# Initial guess for the maximum consumption rates, stored as logs
# (`system` exponentiates this block of the parameter vector).
# Shape is (n_r, n_s): one row per resource, one column per species.
# Values are drawn uniformly from [-1, 0) using the global NumPy seed.
C = np.random.uniform(low=-1., high=0., size=(n_r, n_s))
C.T

array([[-0.30353081, -0.67704109],
       [-0.71386067, -0.63821134],
       [-0.77314855, -0.77173677],
       [-0.44868523, -0.70628595],
       [-0.28053103, -0.36902388],
       [-0.57689354, -0.90789506],
       [-0.0192358 , -0.56629883],
       [-0.31517026, -0.56913724],
       [-0.5190681 , -0.5063149 ],
       [-0.60788248, -0.57416971],
       [-0.65682198, -0.68773878],
       [-0.27095029, -0.57364869],
       [-0.56142776, -0.10661084],
       [-0.9403221 , -0.05583998],
       [-0.60195574, -0.49816332],
       [-0.26200459, -0.37604705],
       [-0.81750827, -0.8843816 ],
       [-0.82454824, -0.68271452],
       [-0.46844863, -0.58517379],
       [-0.46817241, -0.13369084],
       [-0.36559904, -0.74954463],
       [-0.15056821, -0.51696574],
       [-0.27554468, -0.01444021],
       [-0.38897649, -0.48051488],
       [-0.27755662, -0.38710547]])

In [6]:
# Initial guess for the maximum production rates, stored as logs
# (exponentiated in `system`); shape (n_r, n_s), drawn uniformly from [-5, -1).
P = np.random.uniform(low=-5., high=-1., size=(n_r, n_s))

# Resource carrying-capacity parameters. `system` applies exp() to this block,
# so a value of 1 corresponds to a carrying capacity of e.
# NOTE(review): confirm whether 1 (rather than 0, i.e. K = 1) was intended.
K = np.ones(n_r)

P.T

array([[-4.51748534, -2.62227248],
       [-1.6946368 , -2.77285923],
       [-2.58775949, -4.36416142],
       [-2.81972797, -4.38771794],
       [-3.62894466, -2.21788188],
       [-3.78351684, -3.72493429],
       [-3.33191116, -2.23211882],
       [-2.27479694, -2.782467  ],
       [-1.49817263, -3.4441977 ],
       [-2.95831065, -1.29947004],
       [-2.32274487, -1.63332001],
       [-2.65625379, -3.57040973],
       [-2.50038599, -4.82563414],
       [-2.3012438 , -3.78092771],
       [-1.63063025, -3.40725727],
       [-4.66722005, -2.18016468],
       [-1.94526863, -1.01856607],
       [-4.0253345 , -3.57634054],
       [-4.22310816, -1.94980874],
       [-2.71017217, -2.62729233],
       [-4.61714993, -2.23319281],
       [-1.45869269, -4.39549019],
       [-2.49100411, -3.40449483],
       [-2.10633457, -4.03657641],
       [-4.93548317, -3.62617594]])

# Define function to make predictions on test data

In [7]:
# Define function to make predictions on test data

def _plot_species_fit(t_obs, y_obs, t_eval, y_pred, y_stdv, treatment):
    """Plot measured points and predicted mean +/- 1 std for species present at t_eval[0]."""
    plt.figure(figsize=(9, 6))
    for i, sp_name in enumerate(species):
        # only species with a nonzero predicted abundance at the start are drawn
        if y_pred[0, i] > 0:
            color = f'C{i}'
            plt.scatter(t_obs, y_obs[:, i], color=color)
            plt.plot(t_eval, y_pred[:, i], label=f"Predicted {sp_name}", color=color)
            plt.fill_between(t_eval,
                             y_pred[:, i] - y_stdv[:, i],
                             y_pred[:, i] + y_stdv[:, i],
                             color=color, alpha=0.2)
    plt.xlabel("Time (hr)")
    plt.ylabel("Species abundance")
    plt.title(f"{treatment}")
    # NOTE(review): the file name does not depend on the treatment, so every
    # call overwrites it and only the last treatment's figure is kept on disk.
    plt.savefig("figures/CRNN_mono_s.pdf")
    plt.show()


def _plot_resources(t_eval, y_pred, n_species):
    """Plot exp() of the predicted latent columns that follow the species columns."""
    for k in range(y_pred.shape[-1] - n_species):
        plt.plot(t_eval, np.exp(y_pred[:, n_species + k]),
                 label=f"Predicted R{k+1}", color=f'C{k+1}')
    plt.legend()
    plt.ylabel("Resource concentration")
    plt.xlabel("Time (hr)")
    plt.show()


def plot(model, df_test):
    """Simulate every treatment in `df_test`, plot the fits and collect end-point values.

    For each unique treatment the model is started from the first measured
    species abundances. It is then run twice: once on the measured time points,
    to get end-point predictions, and once on a dense grid, for plotting.

    Args:
        model: fitted model exposing ``predict_latent(x0, t) -> (mean, std, _)``;
            the first ``len(species)`` output columns are the species and the
            remaining columns are exponentiated for the resource plot
            (presumably log-resources -- confirm against the model).
        df_test: DataFrame with a ``Treatments`` column, a ``Time`` column and
            one column per entry of the global ``species``.

    Returns:
        Tuple ``(true, pred, stdv, spcs)`` of 1-D arrays. These are the measured
        end-point abundances, the predicted means, the predicted stds, and the
        species name of each entry, concatenated over treatments.

    Raises:
        ValueError: if ``df_test`` has no rows.
    """
    if len(df_test) == 0:
        raise ValueError("df_test contains no rows to predict")

    n_species = len(species)
    true, pred, stdv, spcs = [], [], [], []

    # iterate over every treatment (sorted by name)
    for treatment in np.unique(df_test.Treatments.values):
        # rows for this treatment, in chronological order
        mask = df_test['Treatments'].values == treatment
        comm_data = df_test.loc[mask].sort_values(by='Time', ascending=True)
        tspan = comm_data.Time.values
        output_true = comm_data[species].values

        # initial condition = first measured species abundances
        x_test = np.copy(output_true[0, :])

        # predict end-point measured values
        output, output_stdv, _ = model.predict_latent(x_test, tspan)
        true.append(output_true[-1])
        pred.append(output[-1, :n_species])
        stdv.append(output_stdv[-1, :n_species])
        spcs.append(species)

        # dense evaluation grid for smooth curves; it starts at the first measured
        # time so it matches the time at which x_test was observed
        t_eval = np.linspace(tspan[0], tspan[-1])
        output, output_stdv, _ = model.predict_latent(x_test, t_eval)

        _plot_species_fit(tspan, output_true, t_eval, output, output_stdv, treatment)
        _plot_resources(t_eval, output, n_species)

    return np.concatenate(true), np.concatenate(pred), np.concatenate(stdv), np.concatenate(spcs)

# Initialize model parameters

In [8]:
# The network input is the full state: species abundances plus resources.
n_x = n_s + n_r

# width of the single hidden layer
n_h = 8

# largest sampling time in the dataset
# NOTE(review): not referenced by `system` as written (time is not a network input)
t_max = gLV_data.Time.values.max()

# log death rate of every species
d = np.full(n_s, -3.)

# first layer (state -> hidden); weights scaled by 1/sqrt(fan-in)
p_std = 1./np.sqrt(n_x)
W1 = p_std*np.random.randn(n_h, n_x)
b1 = np.random.randn(n_h)

# second layer (hidden -> n_r resource gates + n_s growth gates + n_s production gates)
p_std = 1./np.sqrt(n_h)
W2 = p_std*np.random.randn(n_r + 2*n_s, n_h)
b2 = np.random.randn(n_r + 2*n_s)

# Flatten every group into one vector. The order MUST match the unpacking in `system`:
# [d | W1 | b1 | C | W2 | b2 | P | K]
params = np.concatenate([block.ravel() for block in (d, W1, b1, C, W2, b2, P, K)])

# Prior means: zero for the network weights and biases, and -5 for the log
# consumption/production rates (exp(-5) is about 0.0067), which pushes most
# rates towards zero.
W1prior = np.zeros_like(W1)
b1prior = np.zeros_like(b1)
Cprior = np.full((n_r, n_s), -5.)
Pprior = np.full((n_r, n_s), -5.)
W2prior = np.zeros_like(W2)
b2prior = np.zeros_like(b2)

# same layout as `params`
prior = np.concatenate([block.ravel() for block in
                        (d, W1prior, b1prior, Cprior, W2prior, b2prior, Pprior, K)])

n_params = params.size
n_params

819

# Define model

In [9]:
# using consumer resource model
def system(t, x, params):
    """Right-hand side of the neural-network-gated consumer-resource ODE.

    Args:
        t: time; unused, kept for the solver's ``f(t, x, params)`` signature.
        x: state vector. The first ``n_s`` entries are species abundances and
            the remaining ``n_r`` entries are log resource concentrations.
        params: flat parameter vector packed as
            ``[log d | W1 | b1 | log Cmax | W2 | b2 | log Pmax | log K]``.

    Returns:
        ``concat(ds/dt, d(log r)/dt)`` with the same length as ``x``.
    """
    # start offset of each parameter group inside the flat vector
    i_W1 = n_s
    i_b1 = i_W1 + n_h*n_x
    i_C = i_b1 + n_h
    i_W2 = i_C + n_r*n_s
    i_b2 = i_W2 + (n_r + 2*n_s)*n_h
    i_P = i_b2 + n_r + 2*n_s
    i_K = i_P + n_r*n_s

    # species abundances and resource concentrations (resources are stored as logs)
    s = x[:n_s]
    r = jnp.exp(x[n_s:])
    state = jnp.concatenate((s, r))

    # death rates
    d = jnp.exp(params[:i_W1])

    # first layer: state -> hidden
    W1 = jnp.reshape(params[i_W1:i_b1], (n_h, n_x))
    b1 = params[i_b1:i_C]
    h1 = jnp.tanh(W1 @ state + b1)

    # maximum consumption rates, shape (n_r, n_s)
    Cmax = jnp.exp(jnp.reshape(params[i_C:i_W2], (n_r, n_s)))

    # second layer: hidden -> sigmoid gates in (0, 1)
    W2 = jnp.reshape(params[i_W2:i_b2], (n_r + 2*n_s, n_h))
    b2 = params[i_b2:i_P]
    h2 = jax.nn.sigmoid(W2 @ h1 + b2)

    # split gates: resource availability (f), species growth efficiency (g),
    # and resource production efficiency (h)
    f = h2[:n_r]
    g = h2[n_r:n_r + n_s]
    h = h2[n_r + n_s:]

    # consumption matrix: row i of Cmax scaled by availability f[i]
    C = jnp.einsum("i,ij->ij", f, Cmax)

    # maximum production rates and carrying capacities
    Pmax = jnp.exp(jnp.reshape(params[i_P:i_K], (n_r, n_s)))
    K = jnp.exp(params[i_K:])

    # production matrix: column j of Pmax scaled by efficiency h[j]
    P = jnp.einsum("ij,j->ij", Pmax, h)

    # rate of change of species
    dsdt = s * (g * (C.T @ r) - d)

    # rate of change of the log of resources, logistically limited by K
    dlrdt = (1. - r / K) * ((P - C) @ s)

    return jnp.concatenate((dsdt, dlrdt))

# Define observation matrix

In [10]:
# Observation matrix of shape (n_s, n_s + n_r). It selects the species
# abundances out of the full [species | resources] state, so the resources are
# not observed directly.
O = np.eye(n_s, n_s + n_r)
O

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],


# Fit model to the community time-series data

In [11]:
# Initial guess for the initial resource state. These values are presumably in
# log space, since `system` exponentiates the resource part of the state.
# Confirm against ODE.
r0 = np.random.uniform(-1, 0, n_r)
print(r0)

model = ODE(system=system,
            dataframe=gLV_data,
            C=O,               # the observation matrix O, passed via ODE's `C` keyword
            CRparams=params,
            r0=r0,
            prior=prior,
            species=species,
            alpha_0=1e-5,
            verbose=True)

# fit to data; perf_counter is a monotonic clock suited to measuring elapsed time
t0 = time.perf_counter()
model.fit(evidence_tol=1e-3, patience=1, max_fails=2)
print("Elapsed time {:.2f}s".format(time.perf_counter() - t0))

[-0.6737644  -0.51678308]
Updating precision...
Total samples: 7125, Updated regularization: 1.00e-05
Total weighted fitting error: 79.557


KeyboardInterrupt: 

In [None]:
# Plot the fitted trajectories for every treatment in the dataset and keep the
# end-point measured/predicted values (and their species labels) for the
# parity plot below.
(true, pred, stdv, spcs) = plot(model, gLV_data)

In [None]:
# Parity plot of measured vs predicted end-point abundance, one colour per species.
# Only points with nonzero measured abundance are shown. The Pearson R is
# reported only when there are at least two distinct measured values
# (linregress raises otherwise).
for i, sp_name in enumerate(np.unique(spcs)):
    sp_inds = spcs == sp_name
    inds_pos = true[sp_inds] > 0
    x_true = true[sp_inds][inds_pos]
    y_pred = pred[sp_inds][inds_pos]
    if x_true.size == 0:
        continue

    if x_true.size >= 2 and x_true.max() > x_true.min():
        label = sp_name + " R={:.2f}".format(linregress(x_true, y_pred).rvalue)
    else:
        label = sp_name
    plt.scatter(x_true, y_pred, c=f"C{i}", label=label)

plt.xlabel("Measured species abundance")
plt.ylabel("Predicted species abundance")
plt.legend(loc="lower right")
plt.show()

# Correlate predicted resource concentrations with measured metabolite concentrations

In [None]:
# Measured metabolite concentrations at the final time point of each treatment.
# Rows are sorted by Time first (stable sort), so "final" means the latest
# measurement rather than whichever row happens to come last in the CSV.
final_timepoint = (
    gLV_data.sort_values("Time", kind="stable")
    .groupby("Treatments")
    .tail(1)
    .set_index("Treatments")
)

# treatment -> final measured concentration
Lactate = final_timepoint["Lactate"].to_dict()
Acetate = final_timepoint["Acetate"].to_dict()
Succinate = final_timepoint["Succinate"].to_dict()

In [None]:
# Predicted resource concentrations at the final time of each training
# trajectory. Each trajectory is simulated ONCE, and both resources are read
# from that run.
R1_preds = {}
R2_preds = {}

for treatment, t_eval, Y_measured in model.dataset:
    # run model using current parameters, output = [n_time, self.n_sys_vars]
    # NOTE(review): assumes model.params is laid out as [r0 (n_r values) | CR params]
    # and that the columns after the first n_s are log-resources -- confirm with ODE.
    output = np.nan_to_num(model.runODE(t_eval, Y_measured[0], model.params[:n_r], model.params[n_r:]))

    final_resources = np.exp(output[-1, n_s:])
    R1_preds[treatment] = final_resources[0]
    R2_preds[treatment] = final_resources[1]

In [None]:
# Spearman correlation: predicted final R1 vs measured final lactate, one point per treatment
treatments = [treatment for treatment, _ in gLV_data.groupby("Treatments")]
pred_vals = [R1_preds[treatment] for treatment in treatments]
true_vals = [Lactate[treatment] for treatment in treatments]

rho, pvalue = spearmanr(pred_vals, true_vals)
corr_label = r"$\rho$={:.3f}, p-value={:.2e}".format(rho, pvalue)

plt.scatter(pred_vals, true_vals, label=corr_label)
plt.xlabel("Predicted resource 1")
plt.ylabel("Measured Lactate (mM)")
plt.legend(loc='upper left', fontsize=16)
plt.title("Lactate vs. R1")
plt.show()

In [None]:
# Spearman correlation: predicted final R2 vs measured final lactate, one point per treatment
treatments = [treatment for treatment, _ in gLV_data.groupby("Treatments")]
pred_vals = [R2_preds[treatment] for treatment in treatments]
true_vals = [Lactate[treatment] for treatment in treatments]

rho, pvalue = spearmanr(pred_vals, true_vals)
corr_label = r"$\rho$={:.3f}, p-value={:.2e}".format(rho, pvalue)

plt.scatter(pred_vals, true_vals, label=corr_label)
plt.xlabel("Predicted resource 2")
plt.ylabel("Measured Lactate (mM)")
plt.legend(loc='upper left', fontsize=16)
plt.title("Lactate vs. R2")
plt.show()

In [None]:
# Spearman correlation: predicted final R1 vs measured final acetate, one point per treatment
treatments = [treatment for treatment, _ in gLV_data.groupby("Treatments")]
pred_vals = [R1_preds[treatment] for treatment in treatments]
true_vals = [Acetate[treatment] for treatment in treatments]

rho, pvalue = spearmanr(pred_vals, true_vals)
corr_label = r"$\rho$={:.3f}, p-value={:.2e}".format(rho, pvalue)

plt.scatter(pred_vals, true_vals, label=corr_label)
plt.xlabel("Predicted resource 1")
plt.ylabel("Measured Acetate (mM)")
plt.legend(loc='upper left', fontsize=16)
plt.title("Acetate vs. R1")
plt.show()

In [None]:
# Spearman correlation: predicted final R2 vs measured final acetate, one point per treatment
treatments = [treatment for treatment, _ in gLV_data.groupby("Treatments")]
pred_vals = [R2_preds[treatment] for treatment in treatments]
true_vals = [Acetate[treatment] for treatment in treatments]

rho, pvalue = spearmanr(pred_vals, true_vals)
corr_label = r"$\rho$={:.3f}, p-value={:.2e}".format(rho, pvalue)

plt.scatter(pred_vals, true_vals, label=corr_label)
plt.xlabel("Predicted resource 2")
plt.ylabel("Measured Acetate (mM)")
plt.legend(loc='upper left', fontsize=16)
plt.title("Acetate vs. R2")
plt.show()

In [None]:
# Spearman correlation: predicted final R1 vs measured final succinate, one point per treatment
treatments = [treatment for treatment, _ in gLV_data.groupby("Treatments")]
pred_vals = [R1_preds[treatment] for treatment in treatments]
true_vals = [Succinate[treatment] for treatment in treatments]

rho, pvalue = spearmanr(pred_vals, true_vals)
corr_label = r"$\rho$={:.3f}, p-value={:.2e}".format(rho, pvalue)

plt.scatter(pred_vals, true_vals, label=corr_label)
plt.xlabel("Predicted resource 1")
plt.ylabel("Measured Succinate (mM)")
plt.legend(loc='upper left', fontsize=16)
plt.title("Succinate vs. R1")
plt.show()

In [None]:
# Spearman correlation: predicted final R2 vs measured final succinate, one point per treatment
treatments = [treatment for treatment, _ in gLV_data.groupby("Treatments")]
pred_vals = [R2_preds[treatment] for treatment in treatments]
true_vals = [Succinate[treatment] for treatment in treatments]

rho, pvalue = spearmanr(pred_vals, true_vals)
corr_label = r"$\rho$={:.3f}, p-value={:.2e}".format(rho, pvalue)

plt.scatter(pred_vals, true_vals, label=corr_label)
plt.xlabel("Predicted resource 2")
plt.ylabel("Measured Succinate (mM)")
plt.legend(loc='upper left', fontsize=16)
plt.title("Succinate vs. R2")
plt.show()

# Correlation between Acetate and Lactate

In [None]:
# Final measured lactate and acetate for each treatment (in groupby/sorted order).
# Each group is sorted by Time (stable) so the last row is the latest
# measurement, independent of row order in the CSV.
lactate_vals = []
acetate_vals = []

for treatment, df in gLV_data.groupby("Treatments"):
    final_row = df.sort_values("Time", kind="stable").iloc[-1]
    lactate_vals.append(final_row['Lactate'])
    acetate_vals.append(final_row['Acetate'])

In [None]:
# Spearman correlation between the two measured metabolites across treatments
rho, pvalue = spearmanr(acetate_vals, lactate_vals)
corr_label = r"$\rho$={:.3f}, p-value={:.2e}".format(rho, pvalue)

plt.scatter(acetate_vals, lactate_vals, label=corr_label)
plt.xlabel("Measured Acetate (mM)")
plt.ylabel("Measured Lactate (mM)")
plt.legend(loc='upper left', fontsize=16)
# the "figures/" directory must already exist
plt.savefig("figures/lactate_vs_acetate.pdf")
plt.show()