In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import linregress, spearmanr

from armored.models import *
from armored.preprocessing import *

from sklearn.model_selection import KFold

import os
import re
import itertools

from tqdm import tqdm

import shap

# plotting palette: ten base colours followed by five lighter tints
colors = [
    # base colours
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",  # blue, orange, green, red, purple
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf",  # brown, pink, gray, olive, teal
    # lighter tints
    "#9edae5", "#c7c7c7", "#c49c94", "#98df8a", "#f7b6d2",  # light blue, light gray, light red, light green, light pink
]

In [3]:
# community and monoculture measurements are stored in separate CSV files
df_comm = pd.read_csv("Data/arc_allcomm.csv")
df_mono = pd.read_csv("Data/arc_allmono.csv")

# stack both tables row-wise; the combined table is what gets scaled and fit below
df = pd.concat([df_comm, df_mono])

# variable names used as data-frame columns: the "_OD" species/strain columns, then pH
species = [
    'AC_OD', 'BA_OD', 'BC_OD', 'BL_OD', 'BT_OD', 'BV_OD', 'CC_OD', 'DF_OD',
    'wt_OD', 'delarc_OD', 'parc-_OD', 'parc+_OD',
]
metabolites = ['pH']
controls = []

# observed = species followed by metabolites;
# system variables = observed followed by controls (none are defined here)
observed = np.concatenate([np.asarray(species), np.asarray(metabolites)])
system_variables = np.concatenate([observed, np.asarray(controls)])
system_variables

array(['AC_OD', 'BA_OD', 'BC_OD', 'BL_OD', 'BT_OD', 'BV_OD', 'CC_OD',
       'DF_OD', 'wt_OD', 'delarc_OD', 'parc-_OD', 'parc+_OD', 'pH'],
      dtype='<U32')

In [4]:
# average replicates so the data transform is fit on one trajectory per experiment
replicate_suffix = re.compile(r'(_\d+)')
df_copy = df.copy()
# keep only the text before the first "_<digits>" so replicates share one experiment name
df_copy['Experiments'] = [replicate_suffix.split(name, maxsplit=1)[0]
                          for name in df_copy.Experiments.values]

# per experiment: mean of every system variable at each time point
avg_frames = []
for exp_name, df_exp in df_copy.groupby("Experiments"):
    time_means = df_exp.groupby("Time")[system_variables].mean().reset_index()
    time_means.insert(0, "Experiments", [exp_name] * len(time_means))
    avg_frames.append(time_means)
df_avg = pd.concat(avg_frames)

# scale data: the scaler is fit on the averaged table, then applied to both tables
# (alternative kept for reference: MinQuantileScaler(observed, system_variables, quantile=.75))
scaler = MinMaxScaler(observed, system_variables)
scaler.fit(df_avg)

df_scaled = scaler.transform(df.copy())
df_avg_scaled = scaler.transform(df_avg.copy())

# format data into matrix [n_samples, n_timepoints, dt+n_outputs+n_controls]
format_kwargs = dict(observed=observed)

# unscaled versions
data = format_data(df, species, metabolites, controls, **format_kwargs)
data_avg = format_data(df_avg, species, metabolites, controls, **format_kwargs)

# scaled versions
data_scaled = format_data(df_scaled, species, metabolites, controls, **format_kwargs)
data_avg_scaled = format_data(df_avg_scaled, species, metabolites, controls, **format_kwargs)

In [5]:
# construct the miRNN; dimensions come from the variable lists, hidden size is fixed at 16
brnn = miRNN(
    n_species=len(species),
    n_metabolites=len(metabolites),
    n_controls=len(controls),
    n_hidden=16,
)

In [6]:
# fit the model on the scaled, un-averaged data
# NOTE(review): alpha_0 presumably sets the initial regularization reported in the log — confirm
brnn.fit(
    data_scaled,
    alpha_0=1e-3,
    evd_tol=1e-3,
)

Total measurements: 6156, Number of parameters: 717, Initial regularization: 1.00e-03
Loss: 1110.632, Residuals: -0.00695
Loss: 1051.771, Residuals: -0.02609
Loss: 967.924, Residuals: -0.01695
Loss: 926.150, Residuals: -0.00036
Loss: 901.454, Residuals: -0.00023
Loss: 844.330, Residuals: -0.01079
Loss: 823.322, Residuals: 0.00133
Loss: 680.525, Residuals: 0.00287
Loss: 678.264, Residuals: -0.00040
Loss: 661.246, Residuals: 0.00074
Loss: 631.117, Residuals: 0.00197
Loss: 629.958, Residuals: -0.00033
Loss: 586.862, Residuals: 0.00133
Loss: 584.798, Residuals: -0.00152
Loss: 565.582, Residuals: -0.00076
Loss: 534.548, Residuals: 0.00109
Loss: 522.714, Residuals: 0.00179
Loss: 521.163, Residuals: -0.00046
Loss: 518.238, Residuals: -0.00042
Loss: 513.780, Residuals: -0.00020
Loss: 508.634, Residuals: 0.00082
Loss: 508.387, Residuals: -0.00009
Loss: 499.287, Residuals: 0.00012
Loss: 489.160, Residuals: 0.00000
Loss: 487.991, Residuals: 0.00048
Loss: 481.780, Residuals: 0.00165
Loss: 478.813,

In [7]:
# collect one design row per sample, plus the matching experiment names
Xs = []
all_exp_names = []

for (T, X, U, Y, exp_names) in data_scaled:
    all_exp_names.append(exp_names)
    # each row is a sample's X entry followed by the first entry of its U sequence
    # (presumably the initial condition and the first control input — TODO confirm)
    Xs.extend(np.append(xi, ui[0]) for xi, ui in zip(X, U))

# stack rows into a 2-D design matrix and flatten the per-batch name arrays
X = np.stack(Xs)
all_exp_names = np.concatenate(all_exp_names)

In [8]:
# all conditions without ecoli
# Locate the E. coli strain columns by name instead of the positional slice -5:-1,
# so the mask stays correct if controls are added or the variable order changes.
# Columns of X follow the order of system_variables.
ecoli_cols = [list(system_variables).index(strain)
              for strain in ('wt_OD', 'delarc_OD', 'parc-_OD', 'parc+_OD')]

# a condition has no E. coli when the strain columns sum to zero
no_ecoli_idx = np.sum(X[:, ecoli_cols], 1) == 0
X_no_ecoli = X[no_ecoli_idx]
no_ecoli_exp_names = all_exp_names[no_ecoli_idx]

In [9]:
# species that aren't ecoli
not_ecoli = ['AC_OD', 'BA_OD', 'BC_OD', 'BL_OD', 'BT_OD', 'BV_OD', 'CC_OD', 'DF_OD', 'pH']

# set of ecoli strains
ecoli_strains =  ['wt_OD', 'delarc_OD', 'parc-_OD', 'parc+_OD']

# columns of X follow the order of system_variables; build the lookup list once
variable_names = list(system_variables)

# loop over ecoli strains
for ecoli_strain in ecoli_strains:

    # species + ecoli strain
    species_and_strain = not_ecoli + [ecoli_strain]

    # matrix of conditions to explain includes all conditions without ecoli strain
    # and all conditions with just the one ecoli strain.
    # This depends only on the strain, so it is computed once per strain
    # rather than once per receiver.
    ecoli_idx = variable_names.index(ecoli_strain)
    strain_samples = X[:, ecoli_idx] > 0
    X_strain = np.concatenate((X[strain_samples], X_no_ecoli), axis=0)
    strain_exp_names = np.append(all_exp_names[strain_samples], no_ecoli_exp_names)

    # to_csv does not create missing directories, so make sure the target exists
    os.makedirs(f"insights/{ecoli_strain}", exist_ok=True)

    # loop over receiver species
    for receiver in species_and_strain:

        # index of target species
        i = variable_names.index(receiver)

        # create wrapper for brnn to match SHAP model
        def model(X_batch):
            """Return the model's last-time-point prediction of variable `i` for each row.

            `i` is read from the enclosing loop; the wrapper is used within the
            same iteration, so it always refers to the current receiver.
            """
            # empty control tensor, one per row
            # NOTE(review): 5 is presumably the number of time points to roll out and
            # the trailing 0 the number of controls — confirm against forward_batch
            U = np.empty(shape=(len(X_batch), 5, 0))

            # matrix of predictions over time
            Y = brnn.forward_batch(brnn.params, X_batch, U)

            # return endpoint species predictions
            return Y[:, -1, i]

        # compute the SHAP values for the model (X_strain is both background and explained set)
        explainer = shap.Explainer(model, X_strain)
        shap_values = explainer(X_strain)

        # collect output columns in order, then build the frame in one go
        columns = {"Experiments": strain_exp_names}

        # SHAP value of each affector on this receiver;
        # affectors belonging to other ecoli strains are skipped
        for j, affector in enumerate(variable_names):
            if affector in species_and_strain:
                # name of interaction edge
                interaction_name = receiver + "<--" + affector
                columns[interaction_name] = shap_values.values[:, j]

        # add exp conditions (every system variable)
        for j, affector in enumerate(variable_names):
            columns[affector] = X_strain[:, j]

        df_sensitivity = pd.DataFrame(columns)

        # save df
        df_sensitivity.to_csv(f"insights/{ecoli_strain}/{receiver}_shap.csv", index=False)

PermutationExplainer explainer: 203it [00:15,  6.01it/s]                                            
