In [1]:
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import seaborn as sns
import torch
from scipy.stats import linregress

from bamf.torchCR import *

# kept after the wildcard import (as in the original cell) so that `time`
# is always bound to the standard-library module
import time

# Seed every random number generator the notebook may draw from so that a
# fresh "Restart & Run All" is repeatable (previously only numpy was seeded).
np.random.seed(123)
torch.manual_seed(123)

# set plot parameters
params = {'legend.fontsize': 18,
          'figure.figsize': (8, 7),
          'axes.labelsize': 24,
          'axes.titlesize':24,
          'axes.linewidth':3,
          'xtick.labelsize':20,
          'ytick.labelsize':20}
plt.rcParams.update(params)

# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names; use whichever name this installation provides.
plt.style.use('seaborn-v0_8-colorblind'
              if 'seaborn-v0_8-colorblind' in plt.style.available
              else 'seaborn-colorblind')

# fonttype 42 embeds TrueType fonts in exported PDFs
plt.rcParams['pdf.fonttype'] = 42

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-akwpmt8_ because the default path (/home/jaron/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


# Import data

In [2]:
# Load the processed DSM time-series table (one row per treatment and time point);
# the same frame is used later for model validation.
# Alternative input file, currently disabled:
# gLV_data = pd.read_csv("gLV_data/gLV_data_for_CR.csv")
gLV_data = pd.read_csv("gLV_data/DSM_processed_mono.csv")
gLV_data

Unnamed: 0,Treatments,Time,CA,BT,BU,CS,CD,DP,CH,BV
0,DSM27147-BT,0.0,0.000000,0.005000,0.000000,0.0,0.005000,0.0,0.0,0.000000
1,DSM27147-BT,12.0,0.000000,1.309089,0.000000,0.0,0.219745,0.0,0.0,0.000000
2,DSM27147-BT,24.0,0.000000,1.069737,0.000000,0.0,0.109763,0.0,0.0,0.000000
3,DSM27147-BT-BU-BV-CA,0.0,0.002000,0.002000,0.002000,0.0,0.002000,0.0,0.0,0.002000
4,DSM27147-BT-BU-BV-CA,12.0,0.010216,1.074319,0.070321,0.0,0.287990,0.0,0.0,0.119654
...,...,...,...,...,...,...,...,...,...,...
356,MonocultureDSM,54.0,0.000000,0.000000,0.000000,0.0,0.146533,0.0,0.0,0.000000
357,MonocultureDSM,57.0,0.000000,0.000000,0.000000,0.0,0.149667,0.0,0.0,0.000000
358,MonocultureDSM,60.0,0.000000,0.000000,0.000000,0.0,0.095033,0.0,0.0,0.000000
359,MonocultureDSM,63.0,0.000000,0.000000,0.000000,0.0,0.074300,0.0,0.0,0.000000


In [3]:
# Species names are the column labels that follow the first two columns
# (the displayed frame has 'Treatments' and 'Time' in those positions).
species = gLV_data.columns[2:].to_numpy()
species

array(['CA', 'BT', 'BU', 'CS', 'CD', 'DP', 'CH', 'BV'], dtype=object)

# Fit model 

In [4]:
# Build the model on the GPU when one is visible and fall back to the CPU
# otherwise, so the notebook also runs on machines without CUDA.
# NOTE(review): n_r / n_h are presumably the number of resources and hidden
# units -- confirm against bamf.torchCR.CRNN.
model = CRNN(
    gLV_data,
    species,
    n_r=2,
    n_h=10,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    verbose=True,
)

In [None]:
# Fit the model. With verbose=True the run logs the total weighted fitting
# error, the optimizer result, the evidence and each precision update.
# NOTE(review): the tolerance / patience / max_fails arguments presumably
# control the stopping criteria -- confirm against bamf.torchCR.
model.fit(
    evidence_tol=1e-2,
    nlp_tol=1e-2,
    patience=1,
    max_fails=1,
)

Total samples: 294, Updated regularization: 1.00e-05
Total weighted fitting error: 53.705
Total weighted fitting error: 43.484
Total weighted fitting error: 16.494
Total weighted fitting error: 14.825
     fun: 14.824535712575862
     jac: array([ 6.40288115e-01,  1.75653518e+00,  7.53703107e-01, -9.18849734e-01,
        1.99161869e+00,  1.25346517e+00,  2.48681382e-01, -9.12704863e-02,
        1.05918204e-01,  3.09189226e+00,  5.80876969e-01,  3.82084366e+00,
       -1.89322885e+00,  9.22001258e-01, -1.29637477e-01, -8.53043294e-02,
       -4.04785794e-01, -3.37116080e+00,  9.11360457e-01,  7.43410957e+00,
       -3.73141273e+00,  1.67733006e+00, -1.80920919e-01,  8.10497161e-03,
       -2.04143025e-01, -5.33708358e+00, -3.97574259e-04,  4.73843637e-02,
       -4.48168543e-03, -1.93502397e-02, -2.16811265e-04,  3.42428156e-02,
        3.74675899e-02, -7.77094432e-02, -1.46583205e-02,  1.14902801e-01,
       -1.99038572e-03, -5.50978424e-03,  2.72906674e-04,  1.49486612e-02,
        5.

Evidence -321.130
Updating precision...


In [None]:
# Define function to make predictions on test data
def predict(self, x_test, t_eval):
    """Predict trajectories and their uncertainty for one initial condition.

    Parameters
    ----------
    self : fitted model object
        Must provide ``device``, ``runODEZ``, ``r0``, ``params``, ``n_x``,
        ``n_s``, ``BetaInv``, ``GAinvG`` and ``Linv``.
    x_test : array-like
        Initial condition; promoted to a 2-D float32 tensor.
    t_eval : array-like
        Time points at which the solution is evaluated.

    Returns
    -------
    Y_predicted : numpy.ndarray
        First output of ``runODEZ`` with NaNs replaced by zeros.
    stdv : numpy.ndarray
        Square root of the diagonal of every covariance matrix.
    covariance : array
        ``BetaInv`` (zero-padded to ``n_x`` x ``n_x``) plus
        ``GAinvG(G, Linv)``; assumed to be stacked as
        ``[steps, n_x, n_x]`` -- TODO confirm against ``GAinvG``.
    """
    # convert to torch tensors
    t_eval = torch.tensor(t_eval, dtype=torch.float32, device=self.device)
    x_test = torch.atleast_2d(torch.tensor(x_test, dtype=torch.float32, device=self.device))

    # integrate forward sensitivity equations
    xYZ = self.runODEZ(t_eval, x_test, self.r0, self.params)
    Y_predicted = torch.nan_to_num(xYZ[0]).cpu().numpy()
    Y = xYZ[1]
    Z = xYZ[2:]

    # flatten the trailing dimensions of every sensitivity block and join them
    Z = torch.cat([Z_i.reshape(Z_i.shape[0], Z_i.shape[1], -1) for Z_i in Z], dim=-1)

    # stack gradient matrices
    G = torch.cat((Y, Z), dim=-1)

    # covariance of each output: measurement noise only enters the first
    # n_s rows/columns, the remaining entries of BetaInv stay zero
    BetaInv = np.zeros([self.n_x, self.n_x])
    BetaInv[:self.n_s, :self.n_s] = self.BetaInv
    covariance = BetaInv + self.GAinvG(G.cpu().numpy(), self.Linv)

    # predicted stdv: diagonal of every covariance matrix, taken directly in
    # numpy (no functorch.vmap / torch round trip needed)
    stdv = np.sqrt(np.diagonal(np.asarray(covariance), axis1=-2, axis2=-1))

    return Y_predicted, stdv, covariance

def plot(model, df_test):
    """Plot measured vs. predicted trajectories for every treatment.

    For each unique value in ``df_test['Treatments']`` the rows are sorted by
    ``Time``, the first measurement is used as the initial condition, and
    ``predict`` is called twice: once at the measured time points (to collect
    end-point predictions) and once on a dense grid that extends 5 time units
    past the last measurement (for plotting). Two figures are shown per
    treatment: species abundances with a +/- one stdv band, and the
    exponentiated remaining model outputs labelled as resources.

    Relies on the notebook-level ``species`` array and ``predict`` function.

    Parameters
    ----------
    model : fitted model passed through to ``predict``.
    df_test : pandas.DataFrame with ``Treatments``, ``Time`` and one column
        per entry of ``species``.

    Returns
    -------
    tuple of four 1-D numpy arrays (concatenated over treatments):
    measured end-point values, predicted end-point values, predicted
    end-point standard deviations, and the matching species labels.

    Raises
    ------
    ValueError
        If ``df_test`` contains no treatments.
    """
    n_species = len(species)
    treatment_labels = df_test['Treatments'].values
    unique_treatments = np.unique(treatment_labels)
    if len(unique_treatments) == 0:
        raise ValueError("df_test contains no treatments to plot")

    # save true values and predictions
    true = []
    pred = []
    stdv = []
    spcs = []

    # loop over every community (treatment) in the data
    for treatment in unique_treatments:
        # rows of this treatment in chronological order (df_test is not mutated)
        comm_data = df_test.loc[treatment_labels == treatment].sort_values(by='Time', ascending=True)
        tspan = comm_data['Time'].values

        # pull just the community data
        output_true = comm_data[species].values

        # the first measurement is the initial condition
        x_test = np.copy(output_true[0, :])

        # predict end-point measured values
        output, output_stdv, _ = predict(model, x_test, tspan)
        true.append(output_true[-1])
        pred.append(output[-1, :n_species])
        stdv.append(output_stdv[-1, :n_species])
        spcs.append(species)

        # dense grid for plotting; it starts at the first measured time so the
        # curve lines up with the initial condition, and runs 5 units past the end
        t_eval = np.linspace(tspan[0], tspan[-1] + 5)
        output, output_stdv, _ = predict(model, x_test, t_eval)

        # plot the species that are present at the start
        plt.figure(figsize=(9, 6))
        for i, sp in enumerate(species):
            out = output[:, i]
            if out[0] > 0:
                color = f'C{i}'
                plt.scatter(tspan, output_true[:, i], color=color)
                plt.plot(t_eval, out, label=f"Predicted {sp}", color=color)
                plt.fill_between(t_eval, out - output_stdv[:, i], out + output_stdv[:, i],
                                 color=color, alpha=0.2)

        plt.xlabel("Time (hr)")
        plt.ylabel("Species abundance")
        plt.title(f"{treatment}")
        plt.show()

        # plot predictions of hidden variables: every output column after the
        # species columns, exponentiated. Start an explicit figure so the
        # curves never land on the species figure. No uncertainty band is drawn.
        plt.figure()
        for k in range(output.shape[-1] - n_species):
            plt.plot(t_eval, np.exp(output[:, n_species + k]),
                     label=f"Predicted R{k+1}", color=f'C{k+1}')

        plt.legend()
        plt.ylabel("Resource concentration")
        plt.xlabel("Time (hr)")
        plt.show()

    return np.concatenate(true), np.concatenate(pred), np.concatenate(stdv), np.concatenate(spcs)

In [None]:
# Draw the figures for every treatment; the trailing semicolon keeps the
# returned arrays from being dumped into the cell output.
plot(model, gLV_data);