In [1]:
import numpy as np
import itertools
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scipy as sp
from scipy.stats import linregress

from bamf.torchCR import *

import time

np.random.seed(123)

# set plot parameters
params = {'legend.fontsize': 18,
          'figure.figsize': (8, 7),
          'axes.labelsize': 24,
          'axes.titlesize':24,
          'axes.linewidth':3,
          'xtick.labelsize':20,
          'ytick.labelsize':20}
plt.rcParams.update(params)
plt.style.use('seaborn-colorblind')
plt.rcParams['pdf.fonttype'] = 42



In [37]:
# all pytorch needs to do is provide the gradient of a function
from functorch import grad, jacfwd, jacrev

def torch_grad(f, argnums):
    # 
    return jacrev(f, argnums)

In [66]:
f = lambda x, A: np.matmul(A, x) 

In [56]:
x = torch.tensor(np.random.randn(3), dtype=torch.float32)
A = torch.tensor(np.random.randn(3,3), dtype=torch.float32)

In [57]:
f(x, A)

tensor([ 1.8678, -0.9941,  1.2923])

In [60]:
torch_grad = jacrev(f, 0)

In [61]:
%timeit torch_grad(x, A)

245 µs ± 24.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [73]:
import jax.numpy as jnp


jax_grad = jax.jacrev(lambda x, A: jnp.matmul(A, x))


In [74]:
%timeit jax_grad(x.numpy(), A.numpy())

1.49 ms ± 58.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [38]:
torch_grad = jacrev(lambda x, A: A@x)

In [39]:
x = np.random.randn(3)
A = np.random.randn(3,3)

x = torch.tensor(x, dtype=torch.float32)
A = torch.tensor(A, dtype=torch.float32)

In [40]:
%timeit torch_grad(x, A)

293 µs ± 45.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Import data

In [2]:
# used later for model validation
# gLV_data = pd.read_csv("gLV_data/DSM_processed_mono.csv")
gLV_data = pd.read_csv("gLV_data/gLV_data_for_CR.csv")
gLV_data

Unnamed: 0,Treatments,Time,s1,s2,s3,s4,s5,s6
0,exp_1,0.0,0.086631,0.025046,0.048303,0.098556,0.051949,0.061289
1,exp_1,8.0,0.090712,0.215443,0.190611,0.174476,0.032546,0.002532
2,exp_1,16.0,0.023573,0.353033,0.282975,0.173349,0.000000,0.000000
3,exp_10,0.0,0.024086,0.034346,0.051313,0.066662,0.010591,0.013089
4,exp_10,8.0,0.059290,0.277299,0.194568,0.172302,0.012084,0.018922
...,...,...,...,...,...,...,...,...
187,exp_8,8.0,0.105865,0.058858,0.132456,0.164099,0.075496,0.017758
188,exp_8,16.0,0.083224,0.293845,0.269835,0.161158,0.025799,0.004992
189,exp_9,0.0,0.035591,0.076255,0.059318,0.069170,0.015113,0.039888
190,exp_9,8.0,0.054867,0.315027,0.211837,0.146970,0.027926,0.007899


In [3]:
# get species names
species = gLV_data.columns.values[2:]
species

array(['s1', 's2', 's3', 's4', 's5', 's6'], dtype=object)

# Define function to make predictions on test data

# Initialize model parameters

In [4]:
# dimensions 
n_s = len(species)
n_r = 3
# input to NN includes species, resources (and maybe also time) 
n_x = n_s + n_r

# CR parameters 
d = -3.*np.ones(n_s)
C = np.random.uniform(-1., 0., [n_r, n_s])
P = np.random.uniform(-5., -1., [n_r, n_s])
K = np.ones(n_r)

# dimension of hidden layer
n_h = 4

# map to hidden dimension
p_std = 1./np.sqrt(n_x)
W1 = p_std*np.random.randn(n_h, n_x)
b1 = np.random.randn(n_h)

# parameters to compute efficiency matrix
p_std = 1./np.sqrt(n_h)
W2 = p_std*np.random.randn(n_r+2*n_s, n_h) 
b2 = np.random.randn(n_r+2*n_s)

# concatenate parameter initial guess
params = np.concatenate((d, W1.flatten(), b1, C.flatten(), W2.flatten(), b2.flatten(), P.flatten(), K))

# set prior so that C is sparse 
W1prior = np.zeros_like(W1)
b1prior = np.zeros_like(b1)
Cprior = -5.*np.ones([n_r, n_s]) 
Pprior = -5.*np.ones([n_r, n_s])
W2prior = np.zeros_like(W2)
b2prior = np.zeros_like(b2)

# concatenate prior 
prior = np.concatenate((d, W1prior.flatten(), b1prior, Cprior.flatten(), W2prior.flatten(), b2prior.flatten(), Pprior.flatten(), K))

n_params = len(params)
n_params

160

# Define model

In [13]:
# using consumer resource model  
def system(t, x, params): 
    
    # species 
    s = x[:n_s]
    
    # resources
    r = torch.exp(x[n_s:])
    
    # compute state 
    state = torch.concatenate((s, r))
    
    # death rate
    d = torch.exp(params[:n_s])
    
    # map to hidden layer
    W1 = torch.reshape(params[n_s:n_s+n_x*n_h], [n_h, n_x])
    b1 = params[n_s+n_x*n_h:n_s+n_x*n_h+n_h]
    h1 = torch.tanh(W1@state + b1)
    
    # maximum consumption rate parameters
    Cmax = torch.exp(torch.reshape(params[n_s+n_x*n_h+n_h:n_s+n_x*n_h+n_h+n_r*n_s], [n_r, n_s]))
    
    # attractiveness of resource i to species j / consumption efficiency
    W2 = torch.reshape(params[n_s+n_x*n_h+n_h+n_r*n_s:n_s+n_x*n_h+n_h+n_r*n_s+(n_r+2*n_s)*n_h], [n_r+2*n_s, n_h])
    b2 = torch.reshape(params[n_s+n_x*n_h+n_h+n_r*n_s+(n_r+2*n_s)*n_h:n_s+n_x*n_h+n_h+n_r*n_s+(n_r+2*n_s)*n_h+n_r+2*n_s], [n_r+2*n_s])
    h2 = torch.sigmoid(W2@h1 + b2)
    
    # divide hidden layer into resource availability, species growth efficiency, resource production efficiency
    f = h2[:n_r]
    g = h2[n_r:n_r+n_s]
    h = h2[n_r+n_s:]
    
    # update Consumption matrix according to resource attractiveness 
    C = torch.einsum("i,ij->ij", f, Cmax)
    
    # max production rate
    Pmax = torch.exp(torch.reshape(params[n_s+n_x*n_h+n_h+n_r*n_s+(n_r+2*n_s)*n_h+n_r+2*n_s:n_s+n_x*n_h+n_h+n_r*n_s+(n_r+2*n_s)*n_h+n_r+2*n_s+n_r*n_s], [n_r, n_s]))
    K = torch.exp(params[n_s+n_x*n_h+n_h+n_r*n_s+(n_r+2*n_s)*n_h+n_r+2*n_s+n_r*n_s:])
    
    # scaled production rate
    P = torch.einsum("ij,j->ij", Pmax, h)
    
    # rate of change of species 
    dsdt = s*(g*(C.T@r) - d)

    # rate of change of log of resources 
    dlrdt = (1. - r/K) * ((P-C)@s) 

    return torch.concatenate((dsdt, dlrdt))

In [15]:
system(1., torch.tensor([np.ones(9)]), torch.tensor(params))

RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x9 and 1x9)

# Define observation matrix

In [6]:
# define observation matrices 
O = np.zeros([n_s, n_s+n_r])
O[:n_s,:n_s] = np.eye(n_s)
O

array([[1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0.]])

# Fit model to mono culture data

In [9]:
torch.tensor([n_s, n_r, n_h], dtype=int)

tensor([6, 3, 4])

In [11]:
r0 = np.random.uniform(-3, 0, n_r)
print(r0)

model = ODE(system = system, 
            dataframe=gLV_data,
            C=O,
            CRparams = params, 
            shapes = [n_s, n_r, n_h],
            r0 = r0,
            prior = prior,
            species = species,
            alpha_0=1e-5,
            verbose=True)

# fit to data 
t0 = time.time()
model.fit(evidence_tol=1e-3, nlp_tol=1e-3, patience=1, max_fails=1)
print("Elapsed time {:.2f}s".format(time.time()-t0))

[-1.80186387 -1.72707942 -1.31334486]


RuntimeError: 

aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a):
Expected a value of type 'List[int]' for argument 'shape' but instead found type 'List[Tensor]'.
Empty lists default to List[Tensor]. Add a variable annotation to the assignment to create an empty list of another type (torch.jit.annotate(List[T, []]) where T is the type of elements in the list for Python 2)
:
  File "/tmp/ipykernel_8975/1259436785.py", line 23
    
    # map to hidden layer
    W1 = torch.reshape(params[n_s:n_s+n_x*n_h], [n_h, n_x])
         ~~~~~~~~~~~~~ <--- HERE
    b1 = params[n_s+n_x*n_h:n_s+n_x*n_h+n_h]
    h1 = torch.tanh(W1@state + b1)
'system' is being compiled since it was called from 'dX_dt'
  File "/home/jaron/Documents/BAMF/bamf/torchCR.py", line 187
        def dX_dt(t, x, params):
            # concatentate x and z
            return system(t, x, params)
                   ~~~~~~~~~~~~~~~~~~~ <--- HERE


In [None]:
def batch(self):
    # loop over each sample in dataset
    for n_t, (t_eval, Y_batch) in self.dataset.items():

        # split samples into batches
        n_samples = Y_batch.shape[0]
        for batch_inds in np.array_split(np.arange(n_samples), n_samples//self.batch_size):
            
            # run model using current parameters, output = [n_time, self.n_sys_vars]
            outputs = np.nan_to_num(self.batchODEZ(t_eval, Y_batch[batch_inds], self.params[:self.n_r], self.params[self.n_r:]))
            
            
def forward(self):
    # loop over each sample in dataset
    for n_t, (t_eval, Y_batch) in self.dataset.items():

        # split samples into batches
        for Y_measured in Y_batch:

            # run model using current parameters, output = [n_time, self.n_sys_vars]
            output = np.nan_to_num(self.runODEZ(t_eval, Y_measured, self.params[:self.n_r], self.params[self.n_r:]))

In [None]:
model.batch_size = 4
model.batch_size

model.batchODEZ = jit(vmap(model.runODEZ, (None, 0, None, None)))

In [None]:
%timeit batch(model)

In [None]:
%timeit forward(model)