In [1]:
import numpy as np
import itertools
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy as sp
from scipy.stats import linregress

from bamf.bamfCR import *

import time

# seed NumPy's global RNG so the random initial guesses drawn below are reproducible
np.random.seed(123)

# set plot parameters
# (named plot_params, not params, so that re-running this cell can never
#  overwrite the model parameter vector `params` that is built further down)
plot_params = {'legend.fontsize': 18,
               'figure.figsize': (8, 7),
               'axes.labelsize': 24,
               'axes.titlesize': 24,
               'axes.linewidth': 3,
               'xtick.labelsize': 20,
               'ytick.labelsize': 20}
plt.rcParams.update(plot_params)

# matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*" and later
# releases dropped the old names; use whichever name this installation provides
colorblind_style = ('seaborn-v0_8-colorblind'
                    if 'seaborn-v0_8-colorblind' in plt.style.available
                    else 'seaborn-colorblind')
plt.style.use(colorblind_style)

# fonttype 42 embeds TrueType fonts in saved PDFs
plt.rcParams['pdf.fonttype'] = 42



# Import data

In [2]:
# Load the time-series table; the path is relative to the notebook's working directory.
# Alternative dataset kept for reference (original note: used later for model validation)
# gLV_data = pd.read_csv("gLV_data/DSM_processed_mono.csv")
gLV_data = pd.read_csv("gLV_data/gLV_data_for_CR.csv")
# bare expression: display the loaded frame
gLV_data

Unnamed: 0,Treatments,Time,s1,s2,s3,s4,s5,s6,s7,s8,s9,s10,s11,s12
0,exp_1,0.0,0.020140,0.081164,0.046799,0.080794,0.000743,0.055159,0.093193,0.058218,0.020610,0.071776,0.037899,0.066838
1,exp_1,8.0,0.158808,0.037643,0.107039,0.262354,0.021680,0.112563,0.082086,0.083350,0.000000,0.014638,0.005025,0.123679
2,exp_1,16.0,0.165128,0.000000,0.161716,0.290950,0.068935,0.118747,0.094958,0.093735,0.024552,0.013541,0.028713,0.129184
3,exp_10,0.0,0.042981,0.087288,0.035596,0.092976,0.014878,0.094003,0.083272,0.084605,0.012392,0.059649,0.001639,0.072118
4,exp_10,8.0,0.180660,0.021611,0.093696,0.285043,0.095564,0.130708,0.113519,0.127640,0.011734,0.024132,0.010197,0.140246
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
187,exp_8,8.0,0.081683,0.033165,0.105040,0.319787,0.105994,0.082605,0.000000,0.108795,0.064975,0.000000,0.036798,0.111022
188,exp_8,16.0,0.111730,0.015665,0.128014,0.331796,0.135531,0.058358,0.010763,0.088152,0.094879,0.013920,0.013347,0.161240
189,exp_9,0.0,0.063205,0.002620,0.088759,0.001612,0.012696,0.077716,0.004590,0.071100,0.097105,0.087168,0.071016,0.095851
190,exp_9,8.0,0.174650,0.000000,0.158850,0.097460,0.069699,0.115715,0.046565,0.061763,0.040231,0.027781,0.037776,0.195518


In [3]:
# get species names
# Positional selection: every column after the first two is treated as a species.
# NOTE(review): this relies on the two leading columns being the non-species
# columns (Treatments and Time in the table displayed above) - confirm if the
# input file layout changes.
species = gLV_data.columns.values[2:]
species

array(['s1', 's2', 's3', 's4', 's5', 's6', 's7', 's8', 's9', 's10', 's11',
       's12'], dtype=object)

# Define function to make predictions on test data

# Initialize model parameters

In [4]:
# problem dimensions
n_s = len(species)   # number of species
n_r = 3              # number of resources
# the network input is the concatenated species + resource state
# (time is not part of the input)
n_x = n_s + n_r

# consumer-resource parameters; system() exponentiates d, C, P and K,
# so the values stored here act as logs of the actual rates
d = np.full(n_s, -3.)
C = np.random.uniform(-1., 0., size=(n_r, n_s))
P = np.random.uniform(-5., -1., size=(n_r, n_s))
K = np.ones(n_r)

# width of the hidden layer
n_h = 4

# first layer (state -> hidden), weights scaled by 1/sqrt(fan-in)
p_std = 1./np.sqrt(n_x)
W1 = p_std*np.random.randn(n_h, n_x)
b1 = np.random.randn(n_h)

# second layer (hidden -> n_r + 2*n_s outputs), weights scaled by 1/sqrt(fan-in)
p_std = 1./np.sqrt(n_h)
W2 = p_std*np.random.randn(n_r+2*n_s, n_h)
b2 = np.random.randn(n_r+2*n_s)

# flat initial guess; this ordering is the layout system() slices apart:
# d | W1 | b1 | C | W2 | b2 | P | K
params = np.concatenate([d, W1.ravel(), b1, C.ravel(), W2.ravel(), b2, P.ravel(), K])

# prior means: zero for every network weight and bias, -5 for the log-scale
# C and P entries (set so that C is sparse), and the initial values for d and K
W1prior = np.zeros(W1.shape)
b1prior = np.zeros(b1.shape)
Cprior = np.full((n_r, n_s), -5.)
Pprior = np.full((n_r, n_s), -5.)
W2prior = np.zeros(W2.shape)
b2prior = np.zeros(b2.shape)

# flat prior, same ordering as params
prior = np.concatenate([d, W1prior.ravel(), b1prior, Cprior.ravel(), W2prior.ravel(), b2prior, Pprior.ravel(), K])

# total number of model parameters
n_params = params.size
n_params

286

# Define model

In [5]:
# using consumer resource model
def system(t, x, params):
    """Right-hand side of the neural consumer-resource ODE.

    Reads the notebook-level dimensions n_s, n_r, n_x and n_h.

    Parameters
    ----------
    t : time; accepted for the ODE-solver signature but not used.
    x : state vector; the first n_s entries are species abundances and the
        remaining entries are log resource concentrations.
    params : flat parameter vector laid out as
        d | W1 | b1 | C | W2 | b2 | P | K
        with sizes
        n_s | n_h*n_x | n_h | n_r*n_s | (n_r+2*n_s)*n_h | n_r+2*n_s | n_r*n_s | rest.
        d, C, P and K are exponentiated before use.

    Returns
    -------
    1-D array: d(species)/dt followed by d(log resources)/dt.
    """
    # number of outputs of the second network layer
    n_out = n_r + 2*n_s

    # running end offsets of each segment of the flat parameter vector
    d_end = n_s
    W1_end = d_end + n_x*n_h
    b1_end = W1_end + n_h
    C_end = b1_end + n_r*n_s
    W2_end = C_end + n_out*n_h
    b2_end = W2_end + n_out
    P_end = b2_end + n_r*n_s

    # species
    s = x[:n_s]

    # resources (the state carries their logs)
    r = jnp.exp(x[n_s:])

    # network input
    state = jnp.concatenate((s, r))

    # death rate
    d = jnp.exp(params[:d_end])

    # map to hidden layer
    # (jnp.reshape throughout: params may be a traced JAX value)
    W1 = jnp.reshape(params[d_end:W1_end], [n_h, n_x])
    b1 = params[W1_end:b1_end]
    h1 = jnp.tanh(W1@state + b1)

    # maximum consumption rate parameters
    Cmax = jnp.exp(jnp.reshape(params[b1_end:C_end], [n_r, n_s]))

    # second layer, squashed to (0, 1)
    W2 = jnp.reshape(params[C_end:W2_end], [n_out, n_h])
    b2 = params[W2_end:b2_end]
    h2 = jax.nn.sigmoid(W2@h1 + b2)

    # divide output layer into resource availability, species growth efficiency,
    # resource production efficiency
    f = h2[:n_r]
    g = h2[n_r:n_r+n_s]
    h = h2[n_r+n_s:]

    # scale each resource row of the consumption matrix by its availability
    C = f[:, None]*Cmax

    # max production rate and the resource scale K used in (1 - r/K)
    Pmax = jnp.exp(jnp.reshape(params[b2_end:P_end], [n_r, n_s]))
    K = jnp.exp(params[P_end:])

    # scale each species column of the production matrix by its efficiency
    P = Pmax*h[None, :]

    # rate of change of species
    dsdt = s*(g*(C.T@r) - d)

    # rate of change of log of resources
    dlrdt = (1. - r/K) * ((P-C)@s)

    return jnp.append(dsdt, dlrdt)

# Define observation matrix

In [6]:
# observation matrix: an n_s x n_s identity block followed by n_r zero columns,
# i.e. it picks the first n_s entries out of a length n_s+n_r vector
O = np.hstack((np.eye(n_s), np.zeros((n_s, n_r))))
O

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.]])

# Fit model to monoculture data

In [7]:
# random starting guess, one value per resource, drawn uniformly from [-3, 0)
# (presumably on the log scale, since system() exponentiates the resource state - confirm)
r0 = np.random.uniform(low=-3, high=0, size=n_r)
print(r0)

# bundle the ODE right-hand side, data, observation matrix, initial guesses and prior
model = ODE(
    system=system,
    dataframe=gLV_data,
    C=O,
    CRparams=params,
    r0=r0,
    prior=prior,
    species=species,
    alpha_0=1e-5,
    verbose=True,
)

# fit to data and report the wall-clock duration
t0 = time.time()
model.fit(evidence_tol=1e-3, nlp_tol=1e-3, patience=1, max_fails=1)
elapsed = time.time() - t0
print("Elapsed time {:.2f}s".format(elapsed))

[-1.88281056 -0.42854083 -2.92016665]
Total samples: 128, Updated regularization: 1.00e-05
Total weighted fitting error: 2.854
Total weighted fitting error: 2.415
Total weighted fitting error: 2.137
Total weighted fitting error: 1.868
Total weighted fitting error: 1.796
 message: Optimization terminated successfully.
 success: True
  status: 0
     fun: 1.795741331635591
       x: [-1.779e+00  6.008e-02 ...  8.453e-01  9.997e-01]
     nit: 5
     jac: [-8.832e-02 -1.574e+00 ...  3.313e-01  2.203e-04]
    nfev: 11
    njev: 11
    nhev: 5
Evidence -279.596
Updating precision...
Total samples: 128, Updated regularization: 2.04e-05
Total weighted fitting error: 18.971
Total weighted fitting error: 18.091
Total weighted fitting error: 18.038
 message: Optimization terminated successfully.
 success: True
  status: 0
     fun: 18.037609100341797
       x: [-1.908e+00  1.340e-02 ...  8.296e-01  9.996e-01]
     nit: 3
     jac: [ 1.650e-01 -2.337e+00 ...  5.235e-01  3.217e-04]
    nfev: 4
    

In [8]:
def batch(self):
    """Run the batched solver once over every sample in ``self.dataset``.

    Benchmark helper: the solver outputs are computed and discarded, and
    nothing is returned.

    Reads ``self.dataset`` (mapping whose values are ``(t_eval, Y_batch)``
    pairs), ``self.batch_size``, ``self.n_r``, ``self.params`` and the
    callable ``self.batchODEZ``.
    """
    # loop over each group of samples in dataset
    for n_t, (t_eval, Y_batch) in self.dataset.items():

        n_samples = Y_batch.shape[0]
        if n_samples == 0:
            # nothing to solve for this group
            continue

        # np.array_split needs at least one section; plain floor division
        # yields 0 (and a ValueError) whenever a group holds fewer samples
        # than batch_size, so fall back to a single batch in that case
        n_batches = max(1, n_samples // self.batch_size)

        # split samples into batches
        for batch_inds in np.array_split(np.arange(n_samples), n_batches):

            # run model using current parameters; the first n_r entries of
            # self.params are passed separately from the remaining ones
            outputs = np.nan_to_num(self.batchODEZ(t_eval, Y_batch[batch_inds], self.params[:self.n_r], self.params[self.n_r:]))
            
            
def forward(self):
    """Solve the ODE for every sample in ``self.dataset``, one sample at a time.

    Benchmark helper: each solver output is computed and discarded, and
    nothing is returned.
    """
    # each dataset value pairs the evaluation times with the stacked samples
    for _, (t_eval, samples) in self.dataset.items():

        # one solver call per individual sample
        for Y_measured in samples:

            # the first n_r entries of self.params go in separately from the rest
            leading = self.params[:self.n_r]
            remaining = self.params[self.n_r:]

            # run model using current parameters; NaNs in the result are zeroed
            solution = self.runODEZ(t_eval, Y_measured, leading, remaining)
            output = np.nan_to_num(solution)

In [9]:
# nominal batch size read by batch() when it splits each group of samples
model.batch_size = 4
model.batch_size

# vectorise the single-sample solver over axis 0 of its second argument only
# (the other three arguments are shared by every sample), then JIT-compile it
model.batchODEZ = jit(vmap(model.runODEZ, in_axes=(None, 0, None, None)))

In [10]:
# time one full pass over the dataset through the batched solver (model.batchODEZ)
%timeit batch(model)

1.25 s ± 28.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
# time one full pass over the dataset solving each sample individually (model.runODEZ)
%timeit forward(model)

1 s ± 30.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
