In [None]:
import pickle
import pprint
import os
from src.simulate import experiment
import src.params as params
from src.plot_visuals import generate_gif, generate_pdf
from src.plot_metrics import (
    plot_metrics,
    plot_bar,
    plot_results_matrix,
    plot_multi_corrs,
    ttests,
)

# make sure pyplot uses retina display
%config InlineBackend.figure_format = 'retina'

### Define hyperparameters and experiment conditions

In [None]:
# Show the default simulation and plotting hyperparameters, in that order.
for hyperparam_cls in (params.SimHyperparams, params.PlotHyperparams):
    pprint.pprint(hyperparam_cls())

In [None]:
# Dose levels swept for each receptor; dose_names labels the levels
# position-by-position and is used to build the experiment keys.
inhibit_range = [0.0, 0.03, 0.06, 0.13, 0.25, 0.5]
excite_range = [0.0, 0.03, 0.06, 0.13, 0.25, 0.5]
dose_names = ["zero", "min", "light", "medium", "heavy", "max"]

# 5-HT2a agonism: sweep excite_str while inhibit_str stays at the first
# (zero) inhibition level.
excite_experiment = {
    f"2a_{label}": params.SimHyperparams(
        excite_str=strength, inhibit_str=inhibit_range[0]
    )
    for label, strength in zip(dose_names, excite_range)
}

# 5-HT1a agonism: sweep inhibit_str while excite_str stays at the first
# (zero) excitation level.
inhibit_experiment = {
    f"1a_{label}": params.SimHyperparams(
        inhibit_str=strength, excite_str=excite_range[0]
    )
    for label, strength in zip(dose_names, inhibit_range)
}

# Mixed agonism: both receptors at the same dose level. Each pair is
# [inhibit index, excite index] into inhibit_range / excite_range.
mix_ranges = [[level, level] for level in range(6)]
mix_experiment = {
    f"1a_{dose_names[inhibit_idx]}_2a_{dose_names[excite_idx]}": params.SimHyperparams(
        inhibit_str=inhibit_range[inhibit_idx],
        excite_str=excite_range[excite_idx],
    )
    for inhibit_idx, excite_idx in mix_ranges
}

# Compare agonism: hand-picked [inhibit index, excite index] pairs.
compare_ranges = [
    [0, 0],
    [0, 3],
    [3, 0],
    [0, 5],
    [5, 0],
    [3, 3],
    [5, 5],
    [5, 4],
]
compare_experiment = {
    f"1a_{dose_names[inhibit_idx]}_2a_{dose_names[excite_idx]}": params.SimHyperparams(
        inhibit_str=inhibit_range[inhibit_idx],
        excite_str=excite_range[excite_idx],
    )
    for inhibit_idx, excite_idx in compare_ranges
}

# Full agonism: every combination of inhibition and excitation strength.
# Keys are the raw strengths joined by "_" (inhibition first).
full_experiment = {
    f"{inhibit_level}_{excite_level}": params.SimHyperparams(
        inhibit_str=inhibit_level, excite_str=excite_level
    )
    for inhibit_level in inhibit_range
    for excite_level in excite_range
}

# Registry of the available experiment sets. The chosen key (plot_name) is
# also used in the names of the result files written and read further down.
experiment_choices = {
    "2a_agonism": excite_experiment,
    "1a_agonism": inhibit_experiment,
    "mixed_agonism": mix_experiment,
    "compare_agonism": compare_experiment,
    "full_agonism": full_experiment,
}

# Pick the experiment set that the remaining cells run, analyse and plot.
plot_name = "full_agonism"
experiment_dict = experiment_choices[plot_name]
plot_params = params.PlotHyperparams()
# The message used to end with a stray ")" after the list; removed here.
print(f"Experiments to run: \n{list(experiment_dict.keys())}")

### Run experiment and save results

In [None]:
# Run every selected experiment and collect its outputs by experiment name:
#   all_results    -> raw metrics returned by experiment()
#   metric_results -> (average metrics, standard-error metrics)
#   sim_results    -> (e_mods, x, y, zs, e_star), used by the visualisation cell
metric_results = {}
sim_results = {}
all_results = {}
# The loop variable is deliberately NOT called `params`: that name is the
# imported `src.params` module, and rebinding it would break the earlier
# cells (params.SimHyperparams / params.PlotHyperparams) when they are re-run.
for exp_name, sim_params in experiment_dict.items():
    all_metrics, avg_metrics, std_error_metrics, e_mods, x, y, zs, e_star = experiment(
        sim_params
    )
    all_results[exp_name] = all_metrics
    metric_results[exp_name] = (avg_metrics, std_error_metrics)
    sim_results[exp_name] = (e_mods, x, y, zs, e_star)
    print("Finished experiment: ", exp_name)

# Create the output directory if needed; exist_ok avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs("./output/results", exist_ok=True)

with open(f"./output/results/metric_results_{plot_name}.pkl", "wb") as f:
    pickle.dump(metric_results, f)

with open(f"./output/results/all_results_{plot_name}.pkl", "wb") as f:
    pickle.dump(all_results, f)

# Simulation outputs are only written when visualisations are requested.
if plot_params.generate:
    with open(f"./output/results/sim_results_{plot_name}.pkl", "wb") as f:
        pickle.dump(sim_results, f)

### Perform statistical analysis

In [None]:
# Reload the raw per-experiment metrics saved by the run cell.
# NOTE: pickle is only safe for files this notebook wrote itself.
all_results_path = f"./output/results/all_results_{plot_name}.pkl"
with open(all_results_path, "rb") as results_file:
    all_results = pickle.load(results_file)

# t-tests on a single metric; summary_type is one of: final, avg, max, min
metric = "div_monotonicity"
summary_type = "final"
ttests(all_results, experiment_dict, metric, summary_type)

# Correlation plots: each x metric is paired with the summary type applied
# to it (final, avg, max, min) and correlated against the y metric.
x_metric_summaries = [
    ("energy", "avg"),
    ("gradient_mags", "avg"),
    ("local_minima", "avg"),
    ("state_counts", "final"),
]
x_metrics = [name for name, _ in x_metric_summaries]
summary_type_xs = [summary for _, summary in x_metric_summaries]
y_metric = "divergence"
summary_type_y = "final"
plot_multi_corrs(
    all_results,
    experiment_dict,
    x_metrics,
    y_metric,
    summary_type_xs,
    summary_type_y,
    plot_name,
)

### Plot result figures

In [None]:
# Reload the (average, standard error) metrics saved by the run cell.
metric_results_path = f"./output/results/metric_results_{plot_name}.pkl"
with open(metric_results_path, "rb") as results_file:
    metric_results = pickle.load(results_file)

# Experiment sets with 36 or 49 entries are drawn as a results matrix
# (36 matches the 6 x 6 full-agonism grid; 49 is presumably a 7 x 7 grid -
# TODO confirm). Anything else gets time-series plots plus bar charts.
if len(experiment_dict) not in (49, 36):
    plot_metrics(
        metric_results,
        plt_name=plot_name,
        minimal_timeseries=plot_params.minimal_timeseries,
        format="pdf",
        render_titles=False,
        exp_dict=experiment_dict,
    )
    # Bar charts for selected metrics: metric key -> chart title
    bar_chart_titles = {
        "divergence": "Final KL-divergence",
        "div_monotonicity": "KL-divergence monotonicity",
    }
    keys = list(bar_chart_titles)
    titles = list(bar_chart_titles.values())
    plot_bar(metric_results, keys, titles, plt_name=plot_name, format="pdf")
else:
    plot_results_matrix(metric_results, inhibit_range, excite_range)

### Generate visualization images

In [None]:
if plot_params.generate:
    # Reload the simulation outputs saved by the run cell.
    sim_results_path = f"./output/results/sim_results_{plot_name}.pkl"
    with open(sim_results_path, "rb") as results_file:
        sim_results = pickle.load(results_file)

    # One visualisation per experiment: a PDF when requested, a GIF otherwise.
    for exp_name, (e_mods, x, y, zs, e_star) in sim_results.items():
        if plot_params.output_type != "pdf":
            generate_gif(e_mods, x, y, zs, e_star, plot_params, exp_name)
        else:
            # generate_pdf takes e_star as its second argument, unlike
            # generate_gif; the trailing 0 is passed through unchanged.
            generate_pdf(e_mods, e_star, x, y, zs, plot_params, exp_name, 0)
        print(f"Generated {exp_name} plots")