In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt 
from scipy.stats import norm, linregress

import os
import time
import itertools

from glove.log_model import *

from sklearn.model_selection import KFold



In [2]:
# import file names
files = os.listdir("data/")
files = [f for f in files if "processed" in f]
files = files[:1]
files

['EXP0019_MS001_processed.csv']

# fit gLV models

In [3]:
def predict(df, species):
    
    # save measured and predicted values
    pred_species = []
    pred = []
    stdv = []
    true = []

    # pull just the community data
    test_data = process_df(df, species) 

    # plot the results
    for exp, t_span, Y_m, s_present in test_data:

        # increase evaluation time
        t_eval = np.linspace(t_span[0], t_span[-1])

        # predict 
        Y_p, Y_l, _ = model.predict(Y_m, t_eval, log=True)
        Y_std = Y_p - Y_l
        
        # set NaN to zero
        Y_p = np.nan_to_num(Y_p)
        Y_std = np.nan_to_num(Y_std)
        
        # unscale log 
        Y_m = np.einsum("ij,j->ij", np.exp(Y_m), s_present)

        ### append only end-point prediction results after initial condition! ###
        inds_pos = Y_m[-1,:] > 0 
        pred_species.append(np.array(species)[inds_pos])
        true.append(Y_m[-1,:][inds_pos])
        pred.append(Y_p[-1,:][inds_pos])
        stdv.append(Y_std[-1,:][inds_pos])
        
    # concatenate list
    pred_species = np.concatenate(pred_species)
    true = np.concatenate(true)
    pred = np.concatenate(pred)
    stdv = np.concatenate(stdv)
        
    return pred_species, true, pred, stdv

In [4]:
# run kfold for each file 
for file in files:
    
    # import data
    df = pd.read_csv(f"data/{file}")

    # determine species names 
    species = df.columns.values[2:]

    # separate mono culture data 
    # mono_df = pd.concat([df_i for name, df_i in df.groupby("Treatments") if "Mono" in name])
    dfs = [df_i for name, df_i in df.groupby("Treatments") if "Mono" not in name]

    # init kfold object
    kf = KFold(n_splits=20, shuffle=True, random_state=21)

    # keep track of all predictions
    all_pred_species = []
    all_true = []
    all_pred = []
    all_stdv = []

    # run Kfold 
    for train_index, test_index in kf.split(dfs):
        # get train df 
        train_df = pd.concat([dfs[i] for i in train_index])
        # train_df = pd.concat((mono_df, train_df))

        # get test df
        test_df = pd.concat([dfs[i] for i in test_index])

        # instantiate gLV fit 
        model = gLV(species, train_df) 

        # fit to data 
        model.fit()

        # plot fitness to data
        pred_species, true, pred, stdv = predict(test_df, species)

        # append predictions 
        all_pred_species = np.append(all_pred_species, pred_species)
        all_true = np.append(all_true, true)
        all_pred = np.append(all_pred, pred)
        all_stdv = np.append(all_stdv, stdv)

    # save prediction results to a .csv
    strain = file.split("_")[1]
    kfold_df = pd.DataFrame()
    kfold_df['species'] = all_pred_species
    kfold_df['true'] = all_true
    kfold_df['pred'] = all_pred
    kfold_df['stdv'] = all_stdv
    kfold_df.to_csv(f"kfold/{strain}_kfold_log.csv", index=False)
        
    # show prediction performance of individual species
    for sp in species:
        sp_inds = all_pred_species == sp
        R = linregress(all_true[sp_inds], all_pred[sp_inds]).rvalue
        plt.scatter(all_true[sp_inds], all_pred[sp_inds], label=f"{sp} " + "R={:.3f}".format(R))
        plt.errorbar(all_true[sp_inds], all_pred[sp_inds], yerr=all_stdv[sp_inds], 
                     fmt='.', capsize=3)

    plt.xlabel("Measured OD")
    plt.ylabel("Predicted OD")
    plt.legend()
    plt.title(strain)
    plt.savefig(f"kfold/{strain}_kfold_log.pdf", dpi=300)
    plt.show()



Total samples: 100, Updated regularization: 1.00e-05
Total weighted fitting error: 410.838
Total weighted fitting error: 308.875
Total weighted fitting error: 271.141
Total weighted fitting error: 247.717
Total weighted fitting error: 194.458
Total weighted fitting error: 181.162
Total weighted fitting error: 177.343
Total weighted fitting error: 170.099
Total weighted fitting error: 157.177
Total weighted fitting error: 142.462
Total weighted fitting error: 137.660
Total weighted fitting error: 128.928
Total weighted fitting error: 121.618
Total weighted fitting error: 120.840
Total weighted fitting error: 114.334
Total weighted fitting error: 111.436
Total weighted fitting error: 111.329
Total weighted fitting error: 107.368
Total weighted fitting error: 101.672
Total weighted fitting error: 101.430
Total weighted fitting error: 99.204
Total weighted fitting error: 95.390
Total weighted fitting error: 95.028
Total weighted fitting error: 92.085
Total weighted fitting error: 92.030
To

Total weighted fitting error: 393.753
Total weighted fitting error: 393.747
Total weighted fitting error: 393.742
Total weighted fitting error: 393.572
Total weighted fitting error: 393.569
Total weighted fitting error: 393.248
Total weighted fitting error: 393.247
Evidence 243.248
Updating hyper-parameters...
Total samples: 100, Updated regularization: 1.26e-01
Total weighted fitting error: 397.715
Total weighted fitting error: 397.646
Total weighted fitting error: 397.517
Total weighted fitting error: 397.291
Total weighted fitting error: 396.862
Total weighted fitting error: 396.110
Total weighted fitting error: 396.057
Total weighted fitting error: 395.574
Total weighted fitting error: 395.565
Total weighted fitting error: 395.234
Total weighted fitting error: 394.712
Total weighted fitting error: 394.703
Total weighted fitting error: 394.632
Total weighted fitting error: 394.624
Total weighted fitting error: 394.339
Total weighted fitting error: 394.337
Total weighted fitting erro

Total samples: 100, Updated regularization: 3.84e-01
Total weighted fitting error: 399.054
Total weighted fitting error: 399.021
Total weighted fitting error: 398.968
Total weighted fitting error: 398.886
Total weighted fitting error: 398.740
Total weighted fitting error: 398.525
Total weighted fitting error: 398.525
Total weighted fitting error: 398.510
Total weighted fitting error: 398.488
Total weighted fitting error: 398.483
Total weighted fitting error: 398.446
Total weighted fitting error: 398.402
Total weighted fitting error: 398.399
Total weighted fitting error: 398.398
Total weighted fitting error: 398.396
Total weighted fitting error: 398.393
Total weighted fitting error: 398.391
Total weighted fitting error: 398.390
Total weighted fitting error: 398.390
Total weighted fitting error: 398.384
Total weighted fitting error: 398.384
Evidence 268.558
Updating hyper-parameters...
Total samples: 100, Updated regularization: 3.83e-01
Total weighted fitting error: 398.718
Evidence 269

Total weighted fitting error: 395.511
Total weighted fitting error: 395.510
Total weighted fitting error: 395.500
Total weighted fitting error: 395.487
Total weighted fitting error: 395.486
Total weighted fitting error: 395.473
Total weighted fitting error: 395.471
Total weighted fitting error: 395.455
Total weighted fitting error: 395.452
Total weighted fitting error: 395.448
Total weighted fitting error: 395.441
Total weighted fitting error: 395.437
Total weighted fitting error: 395.437
Total weighted fitting error: 395.433
Total weighted fitting error: 395.426
Total weighted fitting error: 395.425
Total weighted fitting error: 395.424
Total weighted fitting error: 395.416
Total weighted fitting error: 395.415
Total weighted fitting error: 395.413
Total weighted fitting error: 395.410
Total weighted fitting error: 395.408
Total weighted fitting error: 395.404
Total weighted fitting error: 395.403
Total weighted fitting error: 395.402
Total weighted fitting error: 395.400
Total weight

Total weighted fitting error: 398.340
Total weighted fitting error: 398.256
Total weighted fitting error: 398.168
Total weighted fitting error: 398.158
Total weighted fitting error: 398.141
Total weighted fitting error: 398.120
Total weighted fitting error: 398.119
Total weighted fitting error: 398.119
Total weighted fitting error: 398.118
Total weighted fitting error: 398.116
Total weighted fitting error: 398.116
Total weighted fitting error: 398.116
Total weighted fitting error: 398.116
Total weighted fitting error: 398.116
Total weighted fitting error: 398.116
Total weighted fitting error: 398.116
Total weighted fitting error: 398.116
Evidence 313.113
Updating hyper-parameters...
Total samples: 100, Updated regularization: 1.04e+00
Total weighted fitting error: 398.453
Total weighted fitting error: 398.411
Total weighted fitting error: 398.361
Total weighted fitting error: 398.320
Total weighted fitting error: 398.317
Total weighted fitting error: 398.311
Total weighted fitting erro

Total weighted fitting error: 394.057
Total weighted fitting error: 394.023
Total weighted fitting error: 393.962
Total weighted fitting error: 393.928
Total weighted fitting error: 393.601
Total weighted fitting error: 393.599
Total weighted fitting error: 393.203
Total weighted fitting error: 393.195
Total weighted fitting error: 393.125
Total weighted fitting error: 392.503
Total weighted fitting error: 392.497
Total weighted fitting error: 392.485
Total weighted fitting error: 392.465
Total weighted fitting error: 392.432
Total weighted fitting error: 392.372
Total weighted fitting error: 392.371
Total weighted fitting error: 391.606
Total weighted fitting error: 391.591
Total weighted fitting error: 391.564
Total weighted fitting error: 391.523
Total weighted fitting error: 391.447
Total weighted fitting error: 391.445
Total weighted fitting error: 391.115
Total weighted fitting error: 391.108
Total weighted fitting error: 391.095
Total weighted fitting error: 391.073
Total weight

Total samples: 100, Updated regularization: 1.67e-01
Total weighted fitting error: 398.060
Total weighted fitting error: 398.034
Total weighted fitting error: 397.989
Total weighted fitting error: 397.936
Total weighted fitting error: 397.929
Total weighted fitting error: 397.878
Total weighted fitting error: 397.818


KeyboardInterrupt: 

In [5]:
for file in files:
    strain = file.split("_")[1]
    kfold_df = pd.read_csv(f"kfold/{strain}_kfold_log.csv")
        
    all_pred_species = kfold_df['species'].values
    all_true = kfold_df['true'].values 
    all_pred = kfold_df['pred'].values
    all_stdv = kfold_df['stdv'].values
        
    R_overall = linregress(all_true, all_pred).rvalue
        
    # show prediction performance of individual species
    for sp in species:
        sp_inds = all_pred_species == sp
        R = linregress(all_true[sp_inds], all_pred[sp_inds]).rvalue
        plt.scatter(all_true[sp_inds], all_pred[sp_inds], label=f"{sp} " + "R={:.3f}".format(R))
        plt.errorbar(all_true[sp_inds], all_pred[sp_inds], yerr=all_stdv[sp_inds], 
                     fmt='.', capsize=3)

    plt.xlabel("Measured OD")
    plt.ylabel("Predicted OD")
    plt.legend()
    plt.title(strain + " R={:.2f}".format(R_overall))
    plt.show()

FileNotFoundError: [Errno 2] No such file or directory: 'kfold/MS001_kfold_log.csv'