# Boilerplate (imports)

In [None]:
import os
import sys 
import cobra
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
import torch
import scipy
import os.path
import functools
import timeit
import gurobipy # make sure some solver backend is available for iMAT/FBA/MoMA
from projection_methods import MoMAWrapper, FbaProjection, FbaProjectionLowMidConfidence, FbaProjectionHighMidConfidence, FBAWrapper, IMATWrapper

In [None]:
# Set up the Gurobi (or other optimizer) license file if needed, e.g.:
# os.environ['GRB_LICENSE_FILE'] = "/path/to/your/gurobi.lic"

# Test with i.i.d. noise added to steady-state data

## Generate data

In [None]:
%%time

reaction_size_cutoff = 500 # > 10600 for all models in model_organism_* variables
n_samples_per_graph = 100
base_noise_power = 1
base_frac_known_reactions = 0.1
device = torch.device("cuda")

models = list()
As = list()
ss_bases = list()
ss_fluxes = list()
perturbed_fluxes = list()
names = list()
masks = list()

data_dir = 'synthetic_data_experiment_files/data' # make sure this has the models specified below, in sbml format. 
caching_dir = 'synthetic_data_experiment_files/caching'
export_path = 'synthetic_data_experiment_files/outputs'
    
model_organism_model_names = ['e_coli_core', 'iML1515', 'iND750', 'iMM1415', 'RECON1', 'Recon3D']
model_organism_organisms = ['e_coli_core', 'e_coli_K_12', 'yeast', 'mouse', 'human_1', 'human_3d']
model_name_to_organism = {k: v for k, v in zip(model_organism_model_names, model_organism_organisms)}
# ordered_organisms = []

for i, model_filename in enumerate(os.listdir(data_dir)):
    name = model_filename.split('.')[0]
    if name not in model_organism_model_names:
        continue
    model = cobra.io.read_sbml_model(os.path.join(data_dir, model_filename))
    if len(model.reactions) > reaction_size_cutoff:
        continue    
    models.append(model)
    if os.path.exists(os.path.join(caching_dir, name + "_cached_A.npy")):
        print(i, name, "(cached)")
        A = np.load(os.path.join(caching_dir, name + "_cached_A.npy"))
        basis = np.load(os.path.join(caching_dir, name + "_cached_basis.npy")) 
    else:
        print(i, name)
        A = cobra.util.create_stoichiometric_matrix(model)
        basis = scipy.linalg.null_space(A) # a reactions X kernel-dimension array
        np.save(os.path.join(caching_dir, name + "_cached_A"), A)
        np.save(os.path.join(caching_dir, name + "_cached_basis"), basis) 
        
    print("Model shape: {}".format(A.shape))
    As.append(A)
    names.append(model_filename.split(".")[0])

    # A = A.astype(float)

    print("ss basis shape:", basis.shape)
    
    ss_bases.append(basis)
    
    coefficients = np.random.random(size=(basis.shape[1], n_samples_per_graph))
    steady_states = (basis @ coefficients).transpose() # a samples X reactions array.
    ss_fluxes.append(steady_states)
    
    known_indices = np.array(sorted(np.random.choice(np.array(list(range(A.shape[1]))), size=math.ceil(base_frac_known_reactions * A.shape[1]), replace=False)))
    mask = np.zeros(A.shape[1],dtype=bool)
    mask[known_indices] = True    

    noised_steady_states = [ss * (1 + base_noise_power * 2 * (np.random.random(size=ss.shape) - 0.5))
                            for ss in steady_states]
    blanked_noised_steady_states = [np.where(mask, ss, 0) for ss in noised_steady_states]

    blanked_noised_steady_states = np.array(blanked_noised_steady_states)
    perturbed_fluxes.append(blanked_noised_steady_states)
    masks.append(mask)
    
    
ys = [torch.tensor(y.astype(np.float64)) for y in ss_fluxes]
Xs = [torch.tensor(X.astype(np.float64)) for X in perturbed_fluxes]

n_graphs = len(As)

In [None]:
# Relax every reaction's bounds to +/-1000 in each direction it already allows, and open
# fully blocked reactions (bounds (0, 0)) in both directions.
for model in models:
    for r in model.reactions:
        lower, upper = r.bounds
        if lower < 0:
            lower = -1000
        if upper > 0:
            upper = 1000
        # Compare as a tuple on both sides: comparing a list to (0, 0) is always False,
        # which silently left blocked reactions closed.
        if (lower, upper) == (0, 0):
            lower, upper = -1000, 1000
        r.bounds = (lower, upper)

In [None]:
# Generate approximate per-sample bounds to be used for FBA and iterative projection versions.
# Measured reactions get a band of +/- base_noise_power around the noisy measurement;
# unmeasured reactions keep the model's current bounds.
ls = []
us = []
for model_index, (model, X, mask) in enumerate(zip(models, Xs, masks)):
    print(model_index)
    measured = X.cpu().numpy()  # samples X reactions
    shrunk = (1 - base_noise_power) * measured
    grown = (1 + base_noise_power) * measured
    # For non-positive fluxes the roles swap so that lower <= upper
    # (assuming base_noise_power >= 0).
    positive = measured > 0
    band_lower = np.where(positive, shrunk, grown)
    band_upper = np.where(positive, grown, shrunk)

    model_lower = np.array([r.bounds[0] for r in model.reactions], dtype=float)
    model_upper = np.array([r.bounds[1] for r in model.reactions], dtype=float)

    # The 1-D mask / model bounds broadcast across the sample axis.
    l = torch.tensor(np.where(mask, band_lower, model_lower), dtype=torch.float)
    u = torch.tensor(np.where(mask, band_upper, model_upper), dtype=torch.float)
    ls.append(l)
    us.append(u)

## Test runtime of projection vs. FBA on single-sample and multi-sample (batched) data, and (for projection) on GPU

In [None]:
# %%time
# Benchmark FBA / iMAT / MoMA against the FBApro projection variants: single-sample
# runtimes for every method, plus setup and batched runtimes (CPU and GPU) for the
# projection variants.
pinv_size_threshold = 10000
n_timit_nonsetup_repeats = 5
acond = 1e-2
rcond = 1e-2

methods = [
    FBAWrapper,
    FbaProjection,
    FbaProjectionLowMidConfidence,
    FbaProjectionHighMidConfidence,
    IMATWrapper,
    MoMAWrapper,
]
# __repr__ is called unbound (self=None) to obtain each method's display name.
method_names = [m.__repr__(None) for m in methods]

cpu_device = torch.device('cpu')
gpu_device = torch.device('cuda')


def _timing_records(organism, n_reactions, method_label, batch, device_label, times):
    """Return one runtime record (dict) per timing in *times*."""
    return [{'organism': organism, 'reactions': n_reactions, 'method': method_label,
             'batch': batch, 'device': device_label, 'time': t} for t in times]


records = []
for i, model in enumerate(models):
    n_rxns = len(model.reactions)
    print("model {}, size {}".format(i, n_rxns))
    # names[i] is the SBML file stem, already checked against model_organism_model_names
    # when loading; model.id comes from inside the file and may not match those keys.
    organism_name = model_name_to_organism[names[i]]

    # Single objective reaction with coefficient 1; if the model declares none, fall
    # back to the unique reaction whose id contains "biomass".
    obj_coefs = {rxn.id: rxn.objective_coefficient for rxn in model.reactions if rxn.objective_coefficient != 0}
    print(obj_coefs)
    biomass_ids = [r.id for r in model.reactions if 'biomass' in r.id.lower()]
    print(biomass_ids)
    if not obj_coefs:
        assert len(biomass_ids) == 1
        obj_coefs = {biomass_ids[0]: 1}
    assert len(obj_coefs) == 1 and all(v == 1 for v in obj_coefs.values())
    objective_reaction_ids = next(iter(obj_coefs))

    # Per-model inputs shared by every method.
    A = As[i]
    basis = ss_bases[i]
    mask = masks[i]
    unknown_indices = np.flatnonzero(~mask).tolist()
    measured_indices = np.flatnonzero(mask).tolist()
    X = Xs[i].to(dtype=torch.float)
    n_samples = X.shape[0]
    # NOTE(review): every sample reuses the bounds computed for sample 0; confirm intended.
    l = ls[i][[0]].to(dtype=torch.float).repeat(n_samples, 1)
    u = us[i][[0]].to(dtype=torch.float).repeat(n_samples, 1)

    for method, name in zip(methods, method_names):
        print(name)
        is_projection = "pro" in name
        cur_records = []
        # Constructor with all per-model arguments bound; callers supply device and bounds.
        build = functools.partial(method, model=model, stoichiometric_matrix=A,
                                  unknown_indices=unknown_indices, measured_indices=measured_indices,
                                  steady_state_basis_matrix=basis, acond=acond, rcond=rcond,
                                  objective_id=objective_reaction_ids)

        if is_projection:
            # Setup (construction) time; the CPU bound tensors are passed in both cases.
            for dev, dev_label in ((cpu_device, 'CPU'), (gpu_device, 'GPU')):
                times = timeit.repeat(functools.partial(build, l_bounds=l, u_bounds=u, device=dev),
                                      number=1, repeat=1)
                cur_records.extend(_timing_records(organism_name, n_rxns, name + "_setup", 1, dev_label, times))

        proj = build(l_bounds=l, u_bounds=u, device=cpu_device)
        if is_projection:
            # Row 0 is used explicitly (all rows of l/u are identical copies) instead of
            # the `sample` variable left over from a previous loop.
            gpu_proj = build(l_bounds=l[[0]], u_bounds=u[[0]], device=gpu_device)

        # Single-sample (batch of 1) timings for every sample.
        for sample in range(n_samples):
            print("sample {}/{}".format(sample, n_samples))
            x_one, l_one, u_one = X[[sample], :], l[[sample]], u[[sample]]
            times = timeit.repeat(functools.partial(proj.forward, x_one, l_one, u_one),
                                  number=1, repeat=n_timit_nonsetup_repeats)
            cur_records.extend(_timing_records(organism_name, n_rxns, name, 1, 'CPU', times))

            if is_projection:
                # NOTE(review): only X is moved to the GPU here; the bounds stay on CPU.
                times = timeit.repeat(functools.partial(gpu_proj.forward, x_one.to(device=gpu_device), l_one, u_one),
                                      number=1, repeat=n_timit_nonsetup_repeats)
                cur_records.extend(_timing_records(organism_name, n_rxns, name, 1, 'GPU', times))

        if is_projection:
            # Batched timings over all samples, reported per sample (divided by batch size).
            for dev, dev_label in ((cpu_device, 'CPU'), (gpu_device, 'GPU')):
                X_dev, l_dev, u_dev = X.to(dev), l.to(dev), u.to(dev)
                proj = build(l_bounds=l_dev, u_bounds=u_dev, device=dev)
                times = timeit.repeat(functools.partial(proj.forward, X_dev, l_dev, u_dev),
                                      number=1, repeat=n_timit_nonsetup_repeats)
                cur_records.extend(_timing_records(organism_name, n_rxns, name, n_samples, dev_label,
                                                   [t / n_samples for t in times]))
        records.extend(cur_records)
runtimes_df = pd.DataFrame(records)

In [None]:
# to_csv does not create intermediate directories, so ensure the output folder exists first.
os.makedirs(os.path.join(export_path, "figure_recreation_data"), exist_ok=True)
runtimes_df.to_csv(os.path.join(export_path, "figure_recreation_data", "runtimes_df"))

In [None]:
sns.set_style("whitegrid")
# Combined "organism, device, batch" and "device, batch" label columns, so several
# grouping variables can share a single categorical x axis in the plots below.
runtimes_df['organism_device_batch'] = runtimes_df['organism'].str.cat(
    [runtimes_df['device'], runtimes_df['batch'].astype(str)], sep=', ')
runtimes_df['device_batch'] = runtimes_df['device'].str.cat(runtimes_df['batch'].astype(str), sep=', ')

In [None]:
# Bar chart of non-setup runtimes, grouped by organism/device/batch and coloured by
# method, on a logarithmic time axis.
non_setup_rows = ~runtimes_df["method"].str.contains('setup')
plt.figure(figsize=(20, 10))
sns.barplot(data=runtimes_df[non_setup_rows], x='organism_device_batch', y='time', hue='method')
plt.title("FBA vs FBApro runtime")
plt.legend()
plt.yscale('log')
plt.xticks(rotation=45)
plt.show()

In [None]:
# Bar chart of all methods' runtimes (setup timings excluded), grouped by device and
# batch size, on a logarithmic time axis; saved to export_path.
plt.figure(figsize=(20, 10))

# .copy() gives an independent frame, so the column assignments below are not
# chained assignments on a slice of runtimes_df (SettingWithCopyWarning).
filtered = runtimes_df.loc[~runtimes_df["method"].str.contains('setup')].copy()
filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")

# Categorical categories fix the order seaborn uses for the hue levels.
filtered['method'] = pd.Categorical(filtered['method'],
                                    ["FBApro",
                                     "FBAproPartial",
                                     "FBAproFixed",
                                     "FBA",
                                     "MoMA",
                                     "iMAT",
                                     ])

sns.barplot(data=filtered, x='device_batch', y='time', hue='method')
plt.title("FBA vs FBApro runtime")
plt.legend()
plt.yscale('log')
filename = "all_devices_timing_methods.png"
plt.gcf().savefig(os.path.join(export_path, filename))
plt.show()

In [None]:
# Per-model runtime comparison: every row of the non-FBApro methods, and only the
# batch == 100 rows of the FBApro variants.
# (Named is_plotted rather than `filter` to avoid shadowing the builtin.)
is_plotted = ~runtimes_df["method"].str.contains("FBApro") | (runtimes_df["batch"] == 100)
runtimes_df['organism, #reactions'] = runtimes_df['organism'] + ", " + runtimes_df['reactions'].astype(str)

# .copy() avoids assigning into a slice of runtimes_df (SettingWithCopyWarning).
filtered = runtimes_df.loc[is_plotted].copy()
filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")

# Categorical categories fix the order seaborn uses for the hue levels.
filtered['method'] = pd.Categorical(filtered['method'],
                                    ["FBApro",
                                     "FBAproPartial",
                                     "FBAproFixed",
                                     "FBA",
                                     "MoMA",
                                     "iMAT",
                                     ])


g = sns.catplot(data=filtered.sort_values(['method', 'organism']), x='organism, #reactions', y='time', hue='method',
                kind='bar', height=8, aspect=1.7)
sns.move_legend(g, loc="upper left", bbox_to_anchor=(-0.1, 1.2))
plt.title("FBA vs FBApro runtime")
plt.yscale('log')
plt.tight_layout()
filename = "timing_methods_different_models.png"
plt.gcf().savefig(os.path.join(export_path, filename), bbox_inches='tight')
plt.show()

In [None]:
sns.set(style="darkgrid", font_scale=2)

# Runtime comparison restricted to the human_1 (RECON1) model: every row of the
# non-FBApro methods, and only the batch == 100 rows of the FBApro variants.
# (Named is_plotted rather than `filter` to avoid shadowing the builtin.)
is_plotted = ~runtimes_df["method"].str.contains("FBApro") | (runtimes_df["batch"] == 100)
runtimes_df['organism, #reactions'] = runtimes_df['organism'].str.replace("human_1", "Recon1") + ", " + runtimes_df['reactions'].astype(str)

# .copy() avoids assigning into a slice of runtimes_df (SettingWithCopyWarning).
filtered = runtimes_df.loc[is_plotted & (runtimes_df['organism'] == 'human_1')].copy()
filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")

# Categorical categories fix the order seaborn uses for the hue levels.
filtered['method'] = pd.Categorical(filtered['method'],
                                    ["FBApro",
                                     "FBAproPartial",
                                     "FBAproFixed",
                                     "FBA",
                                     "MoMA",
                                     "iMAT",
                                     ])


g = sns.catplot(data=filtered.sort_values(['method', 'organism']), x='organism, #reactions', y='time', hue='method',
                kind='bar', height=8, aspect=1.7)
sns.move_legend(g, loc="upper left", bbox_to_anchor=(0.1, 0.95))

plt.xlabel("")
plt.ylabel("Time (seconds)")
g.set(xticklabels=[])
plt.yscale('log')
filename = "timing_methods.png"
plt.tight_layout()
plt.gcf().savefig(os.path.join(export_path, filename))
plt.show()


In [None]:
# Setup (construction) time of the FBApro variants for the human_1 (RECON1) model on CPU.
# Filtering first and copying gives an independent frame, so the column assignments
# below are not chained assignments on a slice (SettingWithCopyWarning).
data = runtimes_df.loc[runtimes_df["method"].str.contains('setup')
                       & (runtimes_df['organism'] == 'human_1')
                       & (runtimes_df['device'] == 'CPU')].copy()
data['method'] = data['method'].str.replace("_setup", "")
data['method'] = data['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")
g = sns.barplot(data=data, x='device', y='time', hue='method')
plt.legend()
plt.xlabel("")
plt.ylabel("Time (seconds)")
g.set(xticklabels=[])
plt.yscale('log')
filename = "timing_setup.png"
plt.gcf().savefig(os.path.join(export_path, filename))
plt.show()

In [None]:

# Setup (construction) time of the FBApro variants for every model, ordered by model size.
# NOTE: relies on the 'organism, #reactions' column, which is added to runtimes_df by an
# earlier plotting cell — run that cell first.
# .copy() avoids assigning into a slice of runtimes_df (SettingWithCopyWarning).
filtered = runtimes_df.sort_values('reactions').loc[runtimes_df["method"].str.contains('setup')].copy()
# Strip the "_setup" suffix, then apply the display names (the rename only needs to run once).
filtered['method'] = (filtered['method']
                      .str.replace("_setup", "")
                      .str.replace("LowMid", "Partial")
                      .str.replace("HighMid", "Fixed"))
sns.barplot(data=filtered,
            x='organism, #reactions', y='time', hue='method')
plt.legend()
plt.yscale('log')
filename = "timing_setup_multi_organisms.png"
plt.gcf().savefig(os.path.join(export_path, filename))
plt.show()

## Test amortized runtime of FBApro variants as a function of the number of samples

In [None]:
# Generate approximate per-sample bounds to be used for FBA and iterative projection versions.
# Measured reactions get a band of +/- base_noise_power around the noisy measurement;
# unmeasured reactions keep the model's current bounds.
ls = []
us = []
for model_index, (model, X, mask) in enumerate(zip(models, Xs, masks)):
    print(model_index)
    measured = X.cpu().numpy()  # samples X reactions
    shrunk = (1 - base_noise_power) * measured
    grown = (1 + base_noise_power) * measured
    # For non-positive fluxes the roles swap so that lower <= upper
    # (assuming base_noise_power >= 0).
    positive = measured > 0
    band_lower = np.where(positive, shrunk, grown)
    band_upper = np.where(positive, grown, shrunk)

    model_lower = np.array([r.bounds[0] for r in model.reactions], dtype=float)
    model_upper = np.array([r.bounds[1] for r in model.reactions], dtype=float)

    # The 1-D mask / model bounds broadcast across the sample axis.
    l = torch.tensor(np.where(mask, band_lower, model_lower), dtype=torch.float)
    u = torch.tensor(np.where(mask, band_upper, model_upper), dtype=torch.float)
    ls.append(l)
    us.append(u)

In [None]:
# %%time
# Amortized (per-sample) GPU runtime of the FBApro variants as a function of batch size.
pinv_size_threshold = 1000
n_timit_nonsetup_repeats = 50
acond = 1e-2
rcond = 1e-2
n_samples_per_graph = 100

methods = [
    FBAWrapper,
    FbaProjection,
    FbaProjectionLowMidConfidence,
    FbaProjectionHighMidConfidence,
    IMATWrapper,
    MoMAWrapper,
]
# __repr__ is called unbound (self=None) to obtain each method's display name.
method_names = [m.__repr__(None) for m in methods]
gpu_device = torch.device("cuda")

ammortized_records = []
for i, model in enumerate(models):
    n_rxns = len(model.reactions)
    print("model {}, size {}".format(i, n_rxns))
    # names[i] is the SBML file stem, already checked against model_organism_model_names
    # when loading; model.id comes from inside the file and may not match those keys.
    organism_name = model_name_to_organism[names[i]]

    # Single objective reaction with coefficient 1; if the model declares none, fall
    # back to the unique reaction whose id contains "biomass".
    obj_coefs = {rxn.id: rxn.objective_coefficient for rxn in model.reactions if rxn.objective_coefficient != 0}
    print(obj_coefs)
    biomass_ids = [r.id for r in model.reactions if 'biomass' in r.id.lower()]
    print(biomass_ids)
    if not obj_coefs:
        assert len(biomass_ids) == 1
        obj_coefs = {biomass_ids[0]: 1}
    assert len(obj_coefs) == 1 and all(v == 1 for v in obj_coefs.values())
    objective_reaction_ids = next(iter(obj_coefs))

    # Per-model inputs shared by every method, placed on the GPU once.
    A = As[i]
    basis = ss_bases[i]
    mask = masks[i]
    unknown_indices = np.flatnonzero(~mask).tolist()
    measured_indices = np.flatnonzero(mask).tolist()
    X = Xs[i].to(dtype=torch.float, device=gpu_device)
    n_available = X.shape[0]
    # NOTE(review): every sample reuses the bounds computed for sample 0; confirm intended.
    l = ls[i][[0]].to(dtype=torch.float, device=gpu_device).repeat(n_available, 1)
    u = us[i][[0]].to(dtype=torch.float, device=gpu_device).repeat(n_available, 1)

    for method, name in zip(methods, method_names):
        if "pro" not in name:
            continue
        print(name)
        cur_records = []

        # 99 evenly spaced fractions in (0, 1] of n_samples_per_graph as batch sizes.
        for samples_frac in np.linspace(start=0, stop=1, num=100)[1:]:
            n_samples = int(math.floor(n_samples_per_graph * samples_frac))
            # Row indices are drawn with replacement, so a batch may repeat samples.
            random_samples = np.random.randint(0, n_available, size=(n_samples,))

            proj = method(model=model, stoichiometric_matrix=A, unknown_indices=unknown_indices,
                          measured_indices=measured_indices, steady_state_basis_matrix=basis,
                          acond=acond, rcond=rcond,
                          l_bounds=l[random_samples], u_bounds=u[random_samples], device=gpu_device,
                          objective_id=objective_reaction_ids)

            timeit_results = timeit.repeat(
                functools.partial(proj.forward, X[random_samples, :], l[random_samples], u[random_samples]),
                number=1, repeat=n_timit_nonsetup_repeats)
            cur_records.extend({'organism': organism_name, 'reactions': n_rxns, 'method': name,
                                'batch': n_samples, 'time': t / n_samples} for t in timeit_results)

        # Extend once per method: cur_records accumulates over all batch sizes, so extending
        # inside the loop above re-added earlier batch sizes' records on every iteration.
        ammortized_records.extend(cur_records)
ammortized_runtimes_df = pd.DataFrame(ammortized_records)

In [None]:
# Rebuild the amortized-runtime DataFrame from the collected per-run records.
ammortized_runtimes_df = pd.DataFrame(data=ammortized_records)

In [None]:
# to_csv does not create intermediate directories, so ensure the output folder exists first.
os.makedirs(os.path.join(export_path, "figure_recreation_data"), exist_ok=True)
ammortized_runtimes_df.to_csv(os.path.join(export_path, "figure_recreation_data", "ammortized_runtimes_df"))

In [None]:
sns.set(style="darkgrid", font_scale=2)

# Per-sample (amortized) runtime of each FBApro variant versus batch size, log scale.
filtered = ammortized_runtimes_df.copy()
display_names = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")
# Categorical categories fix the hue order used by seaborn.
filtered['method'] = pd.Categorical(display_names, ["FBApro", "FBAproPartial", "FBAproFixed"])

sns.lineplot(data=filtered, x='batch', y='time', hue='method')
plt.yscale('log')

filename = "timing_per_batch_size.png"
plt.gcf().savefig(os.path.join(export_path, filename))


In [None]:
# Bar chart of the amortized per-sample runtimes, one bar group per batch size, coloured by method.
sns.barplot(data=ammortized_runtimes_df, x='batch', y='time', hue='method')