# Boilerplate (imports)

In [None]:
import os
import sys 
import cobra
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import math
import torch
import scipy
import scipy.stats as stats
import sys
import os.path
import gurobipy # make sure some solver backend is available for iMAT/FBA/MoMA
from projection_methods import MoMAWrapper, FbaProjection, FbaProjectionLowMidConfidence, FbaProjectionHighMidConfidence, FBAWrapper, IMATWrapper
import util

In [None]:
# Set up gurobi or other optimizer license file if needed
# os.environ['GRB_LICENSE_FILE'] = "/path/to/your/gurobi.lic"

# Test with i.i.d. noise added to steady-state data

## Generate data

In [None]:
%%time

reaction_size_cutoff = 20000 # > 10600 for all models in model_organism_* variables
n_samples_per_graph = 10
base_noise_power = 1
base_frac_known_reactions = 0.1
device = torch.device("cpu")

models = list()
As = list()
ss_bases = list()
ss_fluxes = list()
perturbed_fluxes = list()
names = list()
masks = list()

data_dir = 'synthetic_data_experiment_files/data' # make sure this has the models specified below, in sbml format. 
caching_dir = 'synthetic_data_experiment_files/caching'
export_path = 'synthetic_data_experiment_files/outputs'

model_organism_model_names = ['e_coli_core', 'iML1515', 'iND750', 'iMM1415', 'RECON1', 'Recon3D']
model_organism_organisms = ['e_coli_core', 'e_coli_K_12', 'yeast', 'mouse', 'human_1', 'human_3d']
model_name_to_organism = {k: v for k, v in zip(model_organism_model_names, model_organism_organisms)}
# ordered_organisms = []

for i, model_filename in enumerate(os.listdir(data_dir)):
    name = model_filename.split('.')[0]
    if name not in model_organism_model_names:
        continue
    model = cobra.io.read_sbml_model(os.path.join(data_dir, model_filename))
    if len(model.reactions) > reaction_size_cutoff:
        continue    
    models.append(model)
    if os.path.exists(os.path.join(caching_dir, name + "_cached_A.npy")):
        print(i, name, "(cached)")
        A = np.load(os.path.join(caching_dir, name + "_cached_A.npy"))
        basis = np.load(os.path.join(caching_dir, name + "_cached_basis.npy")) 
    else:
        print(i, name)
        A = cobra.util.create_stoichiometric_matrix(model)
        basis = scipy.linalg.null_space(A) # a reactions X kernel-dimension array
        np.save(os.path.join(caching_dir, name + "_cached_A"), A)
        np.save(os.path.join(caching_dir, name + "_cached_basis"), basis) 
        
    print("Model shape: {}".format(A.shape))
    As.append(A)
    names.append(model_filename.split(".")[0])

    # A = A.astype(float)

    print("ss basis shape:", basis.shape)
    
    ss_bases.append(basis)
    
    coefficients = np.random.random(size=(basis.shape[1], n_samples_per_graph))
    steady_states = (basis @ coefficients).transpose() # a samples X reactions array.
    ss_fluxes.append(steady_states)
    
    known_indices = np.array(sorted(np.random.choice(np.array(list(range(A.shape[1]))), size=math.ceil(base_frac_known_reactions * A.shape[1]), replace=False)))
    mask = np.zeros(A.shape[1],dtype=bool)
    mask[known_indices] = True    

    noised_steady_states = [ss * np.where(mask, 1, (1 + base_noise_power * 2 * (np.random.random(size=ss.shape) - 0.5))) for ss in steady_states]
    
    # now only
    
    noised_steady_states = np.array(noised_steady_states)
    perturbed_fluxes.append(noised_steady_states)
    masks.append(mask)
    
    
ys = [torch.tensor(y.astype(np.float64)) for y in ss_fluxes]
Xs = [torch.tensor(X.astype(np.float64)) for X in perturbed_fluxes]

n_graphs = len(As)

In [None]:
# Relax the bounds of every reaction in every model: a negative lower bound becomes
# -1000, a positive upper bound becomes 1000, and reactions blocked in both
# directions (bounds (0, 0)) are opened to (-1000, 1000).
for model in models:
    for r in model.reactions:
        lower, upper = r.bounds
        if lower == 0 and upper == 0:
            # Match the bounds by value; a list of the bounds never equals the tuple (0, 0).
            lower, upper = -1000, 1000
        else:
            if lower < 0:
                lower = -1000
            if upper > 0:
                upper = 1000
        r.bounds = (lower, upper)

In [None]:
# Approximate per-sample flux bounds for FBA and the iterative projection methods.
# For unknown reactions the interval is the range the multiplicative noise could have
# produced around the true flux v, i.e. between (1 - p) * v and (1 + p) * v, ordered
# by the sign of v. For measured reactions it is v itself. Both are widened by epsilon
# and stored as float tensors on `device`.
ls = []
us = []
epsilon = 1e-5
for model_index, (_model, _X, Y, mask) in enumerate(zip(models, Xs, ys, masks)):
    print(model_index)
    unknown = torch.as_tensor(~mask)
    # Same elementwise arithmetic as a per-sample loop, broadcast over all samples at once.
    spread = base_noise_power * torch.sign(Y) * unknown
    lower = (1 - spread) * Y - epsilon
    upper = (1 + spread) * Y + epsilon
    ls.append(lower.to(dtype=torch.float, device=device))
    us.append(upper.to(dtype=torch.float, device=device))

## Apply projection and FBA to the noisy data

In [None]:
%%time

acond=1e-3
rcond=1e-3
    
methods = [
          FbaProjection, 
          FbaProjectionLowMidConfidence,
          FbaProjectionHighMidConfidence,
          FBAWrapper,
          IMATWrapper,
          MoMAWrapper
            ]
method_names = [m.__repr__(None) for m in methods]

containers = [[] for method in methods]


for i, model, A, X, mask, l, u, basis in zip(range(len(models)), models, As, Xs, masks, ls, us, ss_bases):
    print(i)

    obj_coefs = {rxn.id: rxn.objective_coefficient for rxn in model.reactions if rxn.objective_coefficient != 0}
    print(obj_coefs)
    print([r.id for r in model.reactions if 'biomass' in r.id.lower()])
    if len(obj_coefs) == 0:
        growth_related_reactions = [r.id for r in model.reactions if 'biomass' in r.id.lower()]
        assert len(growth_related_reactions) == 1
        obj_coefs = {growth_related_reactions[0]: 1}
    assert len(obj_coefs) == 1 and all([v == 1 for k, v in obj_coefs.items()])
    objective_reaction_ids = obj_coefs.popitem()[0]

    
    with model:
        # obj = linear_reaction_coefficients(model)
        # assert len(obj) == 1
        # objective_id = obj.popitem()[0].id
        unknown_indices = [q for q in range(len(mask)) if ~(mask[q])]
        measured_indices = [q for q in range(len(mask)) if mask[q]]
        for method, container in zip(methods, containers): 
            proj = method(stoichiometric_matrix=A, model=model, unknown_indices=unknown_indices, measured_indices=measured_indices, 
                          steady_state_basis_matrix=basis, objective_id=objective_reaction_ids,
                          l_bounds=l, u_bounds=u, device=device,
                          dtype=torch.float)
            y_pred = proj.forward(X.to(dtype=torch.float, device=device), l_bounds=l, u_bounds=u)
            container.append(y_pred)
            del proj

In [None]:
# Score every method on every model: Spearman correlation between the true and the
# predicted fluxes, restricted to either the known or the unknown reactions, and
# computed both per sample and per reaction. (The 'l2' metric and 'all' masking were
# disabled toggles and are omitted.) Each correlation value becomes one entry.
entries = []

for model_idx, true_fluxes in enumerate(ys):
    print(model_idx)
    for method_name, container in zip(method_names, containers):
        metric = 'spearmanr'
        for masking in ('known', 'unknown'):
            known = torch.tensor(masks[model_idx], dtype=torch.bool)
            mask = known if masking == 'known' else ~known
            for axis in ('per sample', 'per reaction'):
                truth = true_fluxes[:, mask]
                predicted = container[model_idx][:, mask]
                if axis == 'per sample':
                    # run_correlation_tests works column by column; transpose so each column is a sample.
                    truth, predicted = truth.transpose(0, 1), predicted.transpose(0, 1)
                assert truth.shape == predicted.shape
                corrs, _, full_corr, _ = util.run_correlation_tests(
                    pd.DataFrame(truth.cpu().numpy()),
                    pd.DataFrame(predicted.cpu().numpy()),
                    corr_func=stats.spearmanr,
                    parallel=False,
                )
                n_samples, n_selected = true_fluxes[:, mask].shape
                assert len(corrs) == (n_samples if axis == 'per sample' else n_selected)
                entries.extend(
                    {'model': model_name_to_organism[models[model_idx].id], 'method': method_name,
                     'metric': metric, 'axis': axis, 'masking': masking, 'value': val}
                    for val in corrs.values()
                )

In [None]:
# Collect the scoring entries into a DataFrame (one row per correlation value) and display it.
results = pd.DataFrame.from_records(data=entries)
results

In [None]:
sns.set_style('whitegrid')
# Per-reaction Spearman correlation on the unknown reactions, grouped by model.
# .copy() so the column assignments below modify an independent frame, not a slice of `results`.
filtered = results.loc[(results['metric'] == 'spearmanr') & (results['axis'] == 'per reaction') & (results['masking'] == 'unknown')].copy()

filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed").str.replace("Reference Values", "Mid-bound Benchmark")
filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro",
                                 "FBAproPartial",
                                 "FBAproFixed",
                                 "FBA",
                                 "MoMA",
                                 "iMAT",
                                ])

# catplot creates and lays out its own figure; a plt.tight_layout() call before it
# would only act on an empty extra figure.
sns.catplot(data=filtered.sort_values(by=['method', 'model']), hue="method", x='model', y="value",
            height=8,
            aspect=1.7,
            kind='bar')
# The next cell saves its human_1 figure as "exact_rxn_unknown.png"; use a distinct
# name so this all-models figure is not overwritten.
filename = "exact_rxn_unknown_all_models.png"
plt.gcf().savefig(os.path.join(export_path, filename))

In [None]:
sns.set_style('darkgrid')
sns.set(style="darkgrid", font_scale=2)
# Spearman correlation on the unknown reactions of the human_1 model, per sample and per reaction.
filtered = results.loc[(results['metric'] == 'spearmanr') & (results['masking'] == 'unknown')]
# .copy() so the column assignments below modify an independent frame, not a slice of `results`.
filtered = filtered.loc[filtered['model'] == 'human_1'].copy()
filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed").str.replace("Reference Values", "Mid-bound Benchmark")
filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro",
                                 "FBAproPartial",
                                 "FBAproFixed",
                                 "FBA",
                                 "MoMA",
                                 "iMAT",
                                ])


g=sns.catplot(data=filtered.sort_values(by=['method', 'model']), hue="method", x='axis', y="value",
               height=8,
               aspect=1.7,
               kind='bar'
           )
plt.ylabel("Spearman Correlation")
g._legend.set_title("Method")
g._legend.set_bbox_to_anchor((1.12, 0.7))
plt.tight_layout()

filename =  "exact_rxn_unknown.png"
plt.gcf().savefig(os.path.join(export_path, filename))

# Check how performance depends on the fraction of known (measured) reactions

In [None]:
# The reconstruction methods compared in the sweep below, and the names stored in its results.
methods = [FbaProjection, FbaProjectionLowMidConfidence, FbaProjectionHighMidConfidence,
           FBAWrapper, IMATWrapper, MoMAWrapper]
method_names = list(map(lambda method_cls: method_cls.__repr__(None), methods))

In [None]:
%%time
known_reactions_steps = list(np.linspace(start=0.05, stop=0.5, num=10, endpoint=False))

acond=1e-2
rcond=1e-2

known_varying_entries = list()
# known_varying_entries = [e for e in known_varying_entries if e['method'] != 'FBAproHighMid']
for i, known_frac in enumerate(known_reactions_steps):
    print("known frac: {}".format(known_reactions_steps[i]))
    ss_fluxes = list()
    perturbed_fluxes = list()
    names = list()
    masks = list()

    coefficients = [np.random.random(size=(basis.shape[1], n_samples_per_graph)) for basis in ss_bases]
    steady_states = [(basis @ coef).transpose() for basis, coef in zip(ss_bases, coefficients)]     
    known_indices_list = [sorted(np.random.choice(np.array(list(range(nss.shape[1]))), size=math.ceil(known_frac * nss.shape[1]), 
                                      replace=False))
                         for nss in steady_states]
    masks = list()
    noised_steady_states = list()
    for ss, idxs in zip(steady_states, known_indices_list):
        mask = np.zeros(ss.shape[1],dtype=bool)
        mask[idxs] = True
        # print("mask sum:", sum(mask))
        masks.append(mask)
        noised = list()
        # print("ss shape:", ss.shape)
        for s in ss:
            # print("s shape:", s.shape)
            ns = s * np.where(mask, 1, (1 + base_noise_power * 2 * (np.random.random(size=s.shape) - 0.5)))
            noised.append(ns)
        noised_steady_states.append(noised)
    
    ys = [torch.tensor(y.astype(np.float64), device=device) for y in steady_states]
    Xs = [torch.tensor(np.array(X).astype(np.float64), device=device) for X in noised_steady_states]
        
    # Generate approximate bounds to be used for FBA and iterative projection versions
    ls = []
    us = []
    for model_index, model, X, Y, mask in zip(list(range(len(models))), models, Xs, ys, masks):
        l = torch.zeros((X.shape[0], X.shape[1]), dtype=torch.float, device=device)
        u = torch.zeros((X.shape[0], X.shape[1]), dtype=torch.float, device=device)
        for i, sample in enumerate(Y):
            # Change bounds to the gap that was possible during the noise generation if unknown, and roughly the measured values for measured indices
            sign = torch.sign(sample)
            mask = torch.tensor(mask, dtype=torch.bool, device=device)
            l[i, :] = (1 - base_noise_power * sign * ~mask) * sample - epsilon  
            u[i, :] = (1 + base_noise_power * sign * ~mask) * sample + epsilon  
    
        ls.append(l)
        us.append(u)
        
    containers = [[] for method in methods]

    for model, A, basis, X, y, mask, l, u in zip(models, As, ss_bases, Xs, ys, masks, ls, us):
        obj_coefs = {rxn.id: rxn.objective_coefficient for rxn in model.reactions if rxn.objective_coefficient != 0}
        print(obj_coefs)
        print([r.id for r in model.reactions if 'biomass' in r.id.lower()])
        if len(obj_coefs) == 0:
            growth_related_reactions = [r.id for r in model.reactions if 'biomass' in r.id.lower()]
            assert len(growth_related_reactions) == 1
            obj_coefs = {growth_related_reactions[0]: 1}
        assert len(obj_coefs) == 1 and all([v == 1 for k, v in obj_coefs.items()])
        objective_reaction_ids = obj_coefs.popitem()[0]

        with model:
            unknown_indices = [q for q in range(len(mask)) if ~(mask[q])]
            measured_indices = [q for q in range(len(mask)) if mask[q]]

            for method, container in zip(methods, containers): 
                # print("A shape, X shape:", A.shape, X.shape)
                proj = method(model=model, stoichiometric_matrix=A,
                              unknown_indices=unknown_indices, measured_indices=measured_indices, steady_state_basis_matrix=basis, 
                              l_bounds=l, u_bounds=u, device=device, acond=acond, rcond=rcond,
                              objective_id=objective_reaction_ids,
                              dtype=torch.float)
                y_pred = proj.forward(X.to(dtype=torch.float, device=device), l_bounds=l, u_bounds=u)
                container.append(y_pred)
                del proj
    for i, y in enumerate(ys):
        print(i)
        for method_name, container in zip(method_names, containers):
            for metric in ['spearmanr', 'l2']:
                if metric == 'l2':
                    continue
                corr_func = stats.spearmanr if metric == 'spearmanr' else lambda a, b, nan_policy='omit': (-scipy.linalg.norm((a - b).values, axis=0), 0)
                for masking in ['all', 'known', 'unknown']:
                    if masking == 'all':
                        continue
                        mask = torch.ones_like(y[0], dtype=torch.bool, device=device)
                    elif masking == 'known':
                        mask = masks[i]
                    else:
                        mask = ~masks[i]
                    for axis in ['per sample', 'per reaction']:
                        # print("mask shape and sum:", mask.shape, sum(mask))
                        # data = y[:, mask]
                        preds = container[i][:, mask].cpu().numpy()
                        data = y[:, mask].cpu().numpy()
                        if axis == 'per sample':
                            # continue
                            data = data.transpose()
                            preds = preds.transpose()
                            # assert sizes match and make them into dfs
                        assert data.shape == preds.shape
                        data = pd.DataFrame(data)
                        preds = pd.DataFrame(preds)
                        
                        # print(data.shape, preds.shape, data.columns, preds.columns)
                        
                        corrs, _, full_corr, _ = util.run_correlation_tests(data, preds, corr_func=corr_func, parallel=False)
                        # function is run per columns, transpose if doing per sample.
                        # print(len(corrs), y[:, mask].shape, data.shape, preds.shape)
                        if axis == 'per sample':
                            assert len(corrs) == y[:, mask].shape[0]
                        else:
                            assert len(corrs) == y[:, mask].shape[1]
                        for val in corrs.values():
                            known_varying_entries.append({'known_frac': known_frac, 'model': model_name_to_organism[models[i].id], 'method': method_name, 'metric': metric, 'axis': axis, 'masking': masking, 'value': val})
        
    print("Finished iteration for known fraction {}".format(known_frac))
    
known_varying_results = pd.DataFrame.from_records(known_varying_entries)

In [None]:
# Per-reaction Spearman correlation on the unknown reactions vs. the fraction of known reactions.
# .copy() so the column assignments below modify an independent frame, not a slice of the results.
filtered = known_varying_results.loc[(known_varying_results['metric'] == 'spearmanr') & (known_varying_results['axis'] == 'per reaction') & (known_varying_results['masking'] == 'unknown')].copy()

sns.set(rc={'figure.figsize':(10, 6)})

filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")

filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro",
                                 "FBAproPartial",
                                 "FBAproFixed",
                                 "FBA",
                                 "MoMA",
                                 "iMAT",
                                ])
plt.xlabel("Fraction of known reactions")
plt.ylabel("Per Reaction Spearman Correlation")

sns.lineplot(data=filtered.sort_values(by=["method", "model"]), hue="method", x='known_frac', y="value", marker="o")
filename = "exact_knownfrac_varying_rxn_unknown.png"
plt.gcf().savefig(os.path.join(export_path, filename))

In [None]:
# Per-reaction Spearman correlation on the known reactions vs. the fraction of known reactions,
# with the same method labels and legend order as the unknown-reaction figure.
filtered = known_varying_results.loc[(known_varying_results['metric'] == 'spearmanr') & (known_varying_results['axis'] == 'per reaction') & (known_varying_results['masking'] == 'known')].copy()
filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")
filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro", "FBAproPartial", "FBAproFixed", "FBA", "MoMA", "iMAT"])
plt.xlabel("Fraction of known reactions")
plt.ylabel("Per Reaction Spearman Correlation")
sns.lineplot(data=filtered.sort_values(by=["method", "model"]), hue="method", x='known_frac', y="value", marker="o")

In [None]:
# Per-sample Spearman correlation on the unknown reactions vs. the fraction of known reactions.
# .copy() so the column assignments below modify an independent frame, not a slice of the results.
filtered = known_varying_results.loc[(known_varying_results['metric'] == 'spearmanr') & (known_varying_results['axis'] == 'per sample') & (known_varying_results['masking'] == 'unknown')].copy()


sns.set(rc={'figure.figsize':(10, 6)})

filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")

filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro",
                                 "FBAproPartial",
                                 "FBAproFixed",
                                 "FBA",
                                 "MoMA",
                                 "iMAT",
                                ])
plt.xlabel("Fraction of known reactions")
plt.ylabel("Per Sample Spearman Correlation")

sns.lineplot(data=filtered.sort_values(by=["method", "model"]), hue="method", x='known_frac', y="value", marker="o")
filename = "exact_knownfrac_varying_sample_unknown.png"
plt.gcf().savefig(os.path.join(export_path, filename))

In [None]:
# Per-sample Spearman correlation on the known reactions vs. the fraction of known reactions,
# with the same method labels and legend order as the unknown-reaction figure.
filtered = known_varying_results.loc[(known_varying_results['metric'] == 'spearmanr') & (known_varying_results['axis'] == 'per sample') & (known_varying_results['masking'] == 'known')].copy()
filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")
filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro", "FBAproPartial", "FBAproFixed", "FBA", "MoMA", "iMAT"])
plt.xlabel("Fraction of known reactions")
plt.ylabel("Per Sample Spearman Correlation")
sns.lineplot(data=filtered.sort_values(by=["method", "model"]), hue="method", x='known_frac', y="value", marker="o")

In [None]:
# Write the raw known-fraction sweep results into export_path/figure_recreation_data.
recreation_dir = os.path.join(export_path, "figure_recreation_data")
os.makedirs(recreation_dir, exist_ok=True)  # to_csv does not create missing directories
known_varying_results.to_csv(os.path.join(recreation_dir, "exact_noisy_data_known"))

# Check how performance depends on the noise power

In [None]:
# The reconstruction methods compared in the noise sweep below, and the names stored in its results.
methods = [FbaProjection, FbaProjectionLowMidConfidence, FbaProjectionHighMidConfidence,
           FBAWrapper, IMATWrapper, MoMAWrapper]
method_names = list(map(lambda method_cls: method_cls.__repr__(None), methods))

In [None]:
%%time
noise_power_steps = list(np.linspace(start=0.01, stop=4, num=20, endpoint=True))

noise_varying_entries = list()
for i, cur_noise_power in enumerate(noise_power_steps):
    print("noise level: {}".format(noise_power_steps[i]))
    ss_fluxes = list()
    perturbed_fluxes = list()
    names = list()
    masks = list()

    coefficients = [np.random.random(size=(basis.shape[1], n_samples_per_graph)) for basis in ss_bases]
    steady_states = [(basis @ coef).transpose() for basis, coef in zip(ss_bases, coefficients)]     
    known_indices_list = [sorted(np.random.choice(np.array(list(range(nss.shape[1]))), size=math.ceil(base_frac_known_reactions * nss.shape[1]), replace=False)) for nss in steady_states]
    masks = list()
    noised_steady_states = list()
    for ss, idxs in zip(steady_states, known_indices_list):
        mask = torch.zeros(ss.shape[1],dtype=bool, device=device)
        mask[idxs] = True
        # print("mask sum:", sum(mask))
        masks.append(mask)
        noised = list()
        # print("ss shape:", ss.shape)
        for s in ss:
            # print("s shape:", s.shape)
            ns = s * np.where(mask.cpu(), 1, (1 + cur_noise_power * 2 * (np.random.random(size=s.shape) - 0.5)))
            noised.append(ns)
        noised_steady_states.append(noised)
    
    ys = [torch.tensor(y.astype(np.float64), device=device) for y in steady_states]
    Xs = [torch.tensor(np.array(X).astype(np.float64), device=device) for X in noised_steady_states]
        
    # Generate approximate bounds to be used for FBA and iterative projection versions
    ls = []
    us = []
    for model_index, model, X, Y, mask in zip(list(range(len(models))), models, Xs, ys, masks):
        l = torch.zeros((X.shape[0], X.shape[1]), dtype=torch.float, device=device)
        u = torch.zeros((X.shape[0], X.shape[1]), dtype=torch.float, device=device)
        for i, sample in enumerate(Y):
            # Change bounds to the gap that was possible during the noise generation if unknown, and roughly the measured values for measured indices
            sign = torch.sign(sample)
            l[i, :] = (1 - cur_noise_power * sign * ~mask) * sample - epsilon  
            u[i, :] = (1 + cur_noise_power * sign * ~mask) * sample + epsilon  
    
        ls.append(l)
        us.append(u)
        
    containers = [[] for method in methods]
    for model, A, basis, X, y, mask, l, u in zip(models, As, ss_bases, Xs, ys, masks, ls, us):
        obj_coefs = {rxn.id: rxn.objective_coefficient for rxn in model.reactions if rxn.objective_coefficient != 0}
        print(obj_coefs)
        print([r.id for r in model.reactions if 'biomass' in r.id.lower()])
        if len(obj_coefs) == 0:
            growth_related_reactions = [r.id for r in model.reactions if 'biomass' in r.id.lower()]
            assert len(growth_related_reactions) == 1
            obj_coefs = {growth_related_reactions[0]: 1}
        assert len(obj_coefs) == 1 and all([v == 1 for k, v in obj_coefs.items()])
        objective_reaction_ids = obj_coefs.popitem()[0]

        with model:
            unknown_indices = [q for q in range(len(mask)) if ~(mask[q])]
            measured_indices = [q for q in range(len(mask)) if mask[q]]
            for method, container in zip(methods, containers): 
                # print("A shape, X shape:", A.shape, X.shape)
                proj = method(model=model, stoichiometric_matrix=A,
                              unknown_indices=unknown_indices, measured_indices=measured_indices, steady_state_basis_matrix=basis, 
                              l_bounds=l, u_bounds=u, device=device, objective_id=objective_reaction_ids,
                              acond=1e-2, rcond=1e-2,
                              dtype=torch.float)
                y_pred = proj.forward(X.to(dtype=torch.float, device=device), l_bounds=l, u_bounds=u)
                container.append(y_pred)
                del proj
    for i, y in enumerate(ys):
        print(i)
        for method_name, container in zip(method_names, containers):
            for metric in ['spearmanr', 'l2']:
                if metric == 'l2':
                    continue
                corr_func = stats.spearmanr if metric == 'spearmanr' else lambda a, b, nan_policy='omit': (-scipy.linalg.norm((a - b).values, axis=0), 0)
                for masking in ['all', 'known', 'unknown']:
                    if masking == 'all':
                        # mask = torch.ones_like(y[0], dtype=torch.bool)
                        continue
                    elif masking == 'known':
                        mask = masks[i]
                    else:
                        mask = ~masks[i]
                    for axis in ['per sample', 'per reaction']:
                        preds = container[i][:, mask].cpu().numpy()
                        data = y[:, mask].cpu().numpy()
                        if axis == 'per sample':
                            # continue ## unused currently.
                            data = data.transpose()
                            preds = preds.transpose()
                            # assert sizes match and make them into dfs
                        assert data.shape == preds.shape
                        data = pd.DataFrame(data)
                        preds = pd.DataFrame(preds)
                        
                        # print(data.shape, preds.shape, data.columns, preds.columns)
                        
                        corrs, _, full_corr, _ = util.run_correlation_tests(data, preds, corr_func=corr_func, parallel=False)
                        # function is run per columns, transpose if doing per sample.
                        # print(len(corrs), y[:, mask].shape, data.shape, preds.shape)
                        if axis == 'per sample':
                            assert len(corrs) == y[:, mask].shape[0]
                        else:
                            assert len(corrs) == y[:, mask].shape[1]
                        for val in corrs.values():
                            noise_varying_entries.append({'noise_power': cur_noise_power, 'model': model_name_to_organism[models[i].id], 'method': method_name, 'metric': metric, 'axis': axis, 'masking': masking, 'value': val})
        
    print("Finished iteration for noise power {}".format(cur_noise_power))
    
noise_varying_results = pd.DataFrame.from_records(noise_varying_entries)

In [None]:
# Write the raw noise-power sweep results into export_path/figure_recreation_data.
recreation_dir = os.path.join(export_path, "figure_recreation_data")
os.makedirs(recreation_dir, exist_ok=True)  # to_csv does not create missing directories
noise_varying_results.to_csv(os.path.join(recreation_dir, "exact_noisy_data_noise"))

In [None]:
# Per-reaction Spearman correlation on the unknown reactions vs. the noise power δ.
# .copy() so the column assignments below modify an independent frame, not a slice of the results.
filtered = noise_varying_results.loc[(noise_varying_results['metric'] == 'spearmanr') & (noise_varying_results['axis'] == 'per reaction') & (noise_varying_results['masking'] == 'unknown')].copy()


sns.set(rc={'figure.figsize':(10, 6)})

filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")

filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro",
                                 "FBAproPartial",
                                 "FBAproFixed",
                                 "FBA",
                                 "MoMA",
                                 "iMAT",
                                ])
plt.xlabel("Noise Power δ")
plt.ylabel("Per Reaction Spearman Correlation")

sns.lineplot(data=filtered.sort_values(by=["method", "model"]), hue="method", x='noise_power', y="value", marker="o")
filename = "exact_noise_varying_rxn_unknown.png"
plt.gcf().savefig(os.path.join(export_path, filename))

In [None]:
# Per-reaction Spearman correlation on the known reactions vs. the noise power δ,
# with the same method labels and legend order as the unknown-reaction figure.
filtered = noise_varying_results.loc[(noise_varying_results['metric'] == 'spearmanr') & (noise_varying_results['axis'] == 'per reaction') & (noise_varying_results['masking'] == 'known')].copy()
filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")
filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro", "FBAproPartial", "FBAproFixed", "FBA", "MoMA", "iMAT"])
plt.xlabel("Noise Power δ")
plt.ylabel("Per Reaction Spearman Correlation")
sns.lineplot(data=filtered.sort_values(by=["method", "model"]), hue="method", x='noise_power', y="value", marker="o")

In [None]:
# Per-sample Spearman correlation on the unknown reactions vs. the noise power δ.
# .copy() so the column assignments below modify an independent frame, not a slice of the results.
filtered = noise_varying_results.loc[(noise_varying_results['metric'] == 'spearmanr') & (noise_varying_results['axis'] == 'per sample') & (noise_varying_results['masking'] == 'unknown')].copy()


sns.set(rc={'figure.figsize':(10, 6)})

filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")

filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro",
                                 "FBAproPartial",
                                 "FBAproFixed",
                                 "FBA",
                                 "MoMA",
                                 "iMAT",
                                ])
plt.xlabel("Noise Power δ")
plt.ylabel("Per Sample Spearman Correlation")

sns.lineplot(data=filtered.sort_values(by=["method", "model"]), hue="method", x='noise_power', y="value", marker="o")
filename = "exact_noise_varying_sample_unknown.png"
plt.gcf().savefig(os.path.join(export_path, filename))

In [None]:
# Per-sample Spearman correlation on the known reactions vs. the noise power δ,
# with the same method labels and legend order as the unknown-reaction figure.
filtered = noise_varying_results.loc[(noise_varying_results['metric'] == 'spearmanr') & (noise_varying_results['axis'] == 'per sample') & (noise_varying_results['masking'] == 'known')].copy()
filtered['method'] = filtered['method'].str.replace("LowMid", "Partial").str.replace("HighMid", "Fixed")
filtered['method'] = pd.Categorical(filtered['method'],
                                ["FBApro", "FBAproPartial", "FBAproFixed", "FBA", "MoMA", "iMAT"])
plt.xlabel("Noise Power δ")
plt.ylabel("Per Sample Spearman Correlation")
sns.lineplot(data=filtered.sort_values(by=["method", "model"]), hue="method", x='noise_power', y="value", marker="o")