# 3) Analyze and evaluate optimization output - test protocols

## cell1_211006_3148

This final notebook loads the `runs.pkl` file created in notebook 2 and analyzes:

- the distance between different feature sets in the feature space
- the distance between different feature sets in the extracellular signals

In [None]:
import pickle
import pandas as pd
import seaborn as sns
import sys
import shutil

from tqdm import tqdm
import bluepyopt as bpopt
import bluepyopt.ephys as ephys
import neuroplotlib as nplt

import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.spatial import distance
import MEAutility as mu
import json
import time
import numpy as np
from pathlib import Path
from pprint import pprint

import multimodalfitting as mf

%matplotlib notebook

In [None]:
# Matplotlib colour-cycle entry assigned to each optimization strategy.
colors_dict = dict(soma="C0", all="C1", sections="C2", single="C3")
# Figure size passed to plt.subplots(figsize=...) for the default-sized plots.
figsize = (10, 7)

In [None]:
# general
# Root folder from which the data, model and result paths below are built.
base_dir = Path("../../..")

In [None]:
# Cell to analyze; the model is named after the cell.
cell_name = "cell1_211006_3148"  # "cell1_211006_3148" | "cell1_211011_3436"
model_name = cell_name

# "patch_data" folder of this cell inside the experimental data tree.
ephys_dir = base_dir.joinpath("experimental_data", cell_name, "patch_data")

# Folder with all cell models, and the sub-folder of the selected cell.
cell_models_folder = base_dir.joinpath("cell_models")
model_folder = cell_models_folder.joinpath(cell_name)

In [None]:
# Set to True to write the figures produced in this notebook to disk.
save_fig = False
# Per-cell output folder, relative to the current working directory.
figure_folder = Path(f"figures_{cell_name}")

# The folder is only created when figures are actually going to be saved.
if save_fig:
    figure_folder.mkdir(exist_ok=True)

In [None]:
# Change this to the name of the sub-folder of "results" that contains your
# pkl files.
results_date = '220214'  # other values used so far: '211124', '220111'
result_folder = base_dir / "results" / results_date

In [None]:
# Protocol used for the extracellular action potential (EAP): it is passed to
# the evaluator as `protocols_with_lfp` and to `calculate_eap` as `protocol_name`.
protocol_for_eap = "firepattern_120"

In [None]:
# Name of the pickle file, inside `result_folder`, holding the optimization runs.
pkl_file_name = "runs.pkl"

In [None]:
# Select the abd (axon_bearing_dendrite) option. The flag is forwarded to the
# model/evaluator constructors and selects the "_abd" / "_noabd" results file.
abd = True

In [None]:
# Load the optimization runs. The context manager closes the file handle,
# which a bare `pickle.load(open(...))` leaves open.
# NOTE: unpickling can execute arbitrary code - only load files you created.
with open(result_folder / pkl_file_name, 'rb') as f:
    data = pickle.load(f)
df_optimization = pd.DataFrame(data)

# Keep only the runs of the selected model. `.copy()` makes df_model an
# independent frame, so the column assignments below are guaranteed to stick
# and do not raise pandas' SettingWithCopyWarning.
df_model = df_optimization.query(f"model == '{model_name}'").copy()

# Set the strategy column: the extra strategy of the run, except for runs
# fitted on the "soma" feature set, which are labelled "soma".
df_model.loc[:, "strategy"] = df_model["extra_strategy"].values.copy()
df_model.loc[df_model["feature_set"] == "soma", "strategy"] = "soma"

# The training-validation results are stored separately for the two abd options.
results_name = "opt_results_training" + ("_abd" if abd else "_noabd")
results_file = result_folder / f"{results_name}.pkl"

if results_file.is_file():
    with open(results_file, 'rb') as f:
        opt_results_all = pickle.load(f)
else:
    # `opt_results_all` stays undefined in this case; the cells below need it.
    print(f"Couldn't find {results_name}.pkl. Run training validation first!")

# Define cell model and load optimization output

In [None]:
# Build the experimental cell model twice with the same abd option: once with
# the default arguments and once with release=True.
cell = mf.create_experimental_model(model_name=cell_name, abd=abd)
cell_sample = mf.create_experimental_model(model_name=cell_name, abd=abd, release=True)

# Electrode definition stored with the fitting e-features of the model.
probe_file = model_folder / "fitting" / "efeatures" / "probe_BPO.json"
probe = mf.define_electrode(probe_file=probe_file)

# Names of the non-frozen parameters, in the order of cell.params.
param_names = []
for param in cell.params.values():
    if not param.frozen:
        param_names.append(param.name)

In [None]:
# One thin line per optimization run: the "min" column of the run's logbook
# plotted against its "nevals" column, coloured by strategy.
fig, ax = plt.subplots()
min_evals = 3000

keep_idxs = []
for idx, row in df_model.iterrows():
    color = mpl.colors.to_rgb(colors_dict[row["strategy"]])
    if row["abd"]:
        # remember the index of the runs that used the abd option
        keep_idxs.append(idx)
    else:
        # runs without abd get a darker shade of the strategy colour
        color = np.clip(np.array(color) - 0.3, 0, 1)
    ax.plot(row["nevals"], row["logbook"].select("min"),
            color=color, ls='-', lw=0.5, alpha=0.75)

ax.set_title("Min fitness")
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.set_xlabel("Neval")
ax.set_ylabel("Min fitness")
ax.set_yscale('log')

# Load protocols and experimental features

In [None]:
# Default extracellular keyword arguments from multimodalfitting; they are
# forwarded to the evaluator and to `calculate_eap`, and provide the "fs" and
# "upsample" values used for the extracellular features.
extra_kwargs = mf.utils.get_extra_kwargs()
# Override the "ms_cut" entry.
# NOTE(review): presumably the window (in ms) cut around the spike - confirm.
extra_kwargs["ms_cut"] = [2, 5]

In [None]:
# Protocols that were used during the optimization.
protocols_used_for_opt = [
    "IV_-20",
    "IV_-100",
    "IDrest_150",
    "IDrest_250",
    "IDrest_300",
    "APWaveform_260",
]
# Protocol name prefixes handed to the evaluator as `exclude_protocols`.
protocols_to_exclude = ["IV", "APWaveform", "IDrest"]

In [None]:
# Evaluator using the "extra" feature set with the "all" extra strategy. All
# protocols are loaded except those listed in `protocols_to_exclude`, and the
# LFP is recorded for `protocol_for_eap`.
evaluator_kwargs = dict(
    model_name=model_name,
    feature_set="extra",
    extra_strategy="all",
    protocols_with_lfp=protocol_for_eap,
    all_protocols=True,
    exclude_protocols=protocols_to_exclude,
    abd=abd,
)
eva_extra = mf.create_evaluator(**evaluator_kwargs, **extra_kwargs)

## Load experimental responses

In [None]:
# The loaded results must contain the experimental ("exp") entry.
assert "exp" in opt_results_all
# Experimental intracellular responses (a sized, indexable collection) and the
# experimental extracellular action potential.
responses_experimental = opt_results_all["exp"]["responses"]
eap_exp = opt_results_all["exp"]["eap"]
# Number of experimental response sets; used below to label them "run0", "run1", ...
num_runs = len(responses_experimental)

In [None]:
# Display how many experimental response sets were loaded.
len(responses_experimental)

In [None]:
# Protocols whose responses are shown in the response plots below.
protocols_to_plot = ["firepattern_200", "HyperDepol_-160", "HyperDepol_-40", "sAHP_250", "PosCheops_300"]

In [None]:
# Plot the experimental responses of the selected protocols, one labelled
# trace per experimental response set.
run_labels = [f"run{i}" for i in range(num_runs)]
fig_exp_intra = mf.plot_multiple_responses(
    responses_experimental,
    protocol_names=protocols_to_plot,
    labels=run_labels,
    return_fig=True,
)

In [None]:
# Save the figures of the experimental data.
if save_fig:
    fig_exp_intra.savefig(figure_folder / "exp_intra.png", dpi=300)
    # `fig_exp_extra` is not created anywhere in this notebook, so saving it
    # used to fail with a NameError. Build it here from the experimental EAP,
    # using the same MEAutility call that draws the EAPs further below.
    fig_exp_extra, ax_exp_extra = plt.subplots(figsize=figsize)
    mu.plot_mea_recording(eap_exp, probe, ax=ax_exp_extra, lw=1)
    ax_exp_extra.set_title("EAP", fontsize=15)
    fig_exp_extra.savefig(figure_folder / "exp_extra.png", dpi=300)

### Compute experimental extra features

In [None]:
# compute extracellular features
# `std_from_mean` is the relative tolerance used when scoring extracellular
# features below: the difference to the experimental value is divided by
# std_from_mean * |experimental value|.
std_from_mean = 0.05
# Features of the experimental EAP; later iterated as
# {feature name: one value per channel}.
extra_features = mf.efeatures_extraction.compute_extra_features(
    eap_exp, fs=extra_kwargs["fs"],
    upsample=extra_kwargs["upsample"])

In [None]:
# Reference feature table. Features that are evaluator objectives get a None
# placeholder, keyed by the name of the objective's first feature.
features_release = {
    obj.features[0].name: None
    for obj in eva_extra.fitness_calculator.objectives
}
# add extra features: one entry per (feature, channel), named
# "<protocol>.MEA.<feature>_<channel>" and holding the experimental value
for efeat_name, feat in extra_features.items():
    for chan, feat_val in enumerate(feat):
        feature_name = f"{protocol_for_eap}.MEA.{efeat_name}_{chan}"
        features_release[feature_name] = {"value": feat_val}

# Compute and plot best responses

In [None]:
# Value substituted further below for features that evaluate to None.
max_feature_value = 50
# Per-strategy results collected by the following cells.
opt_results = {}

In [None]:
# For every strategy, pick the run with the lowest best fitness and store its
# fitness and parameters (as a {parameter name: value} dict).
for strategy in np.unique(df_model.strategy):
    opt_df = df_model.query(f"strategy == '{strategy}'")
    # position (not index label) of the best run within opt_df
    best_idx = np.argmin(opt_df.best_fitness)
    params_sample = opt_df.iloc[best_idx]
    params_dict = dict(zip(param_names, params_sample.best_params))
    opt_results[strategy] = {
        "best_fitness": params_sample.best_fitness,
        "best_params": params_dict,
    }
    print(f"{strategy} --  best fitness: {params_sample.best_fitness}")

In [None]:
# Simulate every protocol with the best parameters of each strategy and
# measure how far the simulated EAP is from the experimental one.

# Experimental EAP with each channel divided by the peak-to-peak range of its
# absolute value. The original code referenced `eap_release`, which is not
# defined in this notebook; the experimental EAP is the reference here.
eap_exp_norm = eap_exp / np.ptp(np.abs(eap_exp), 1, keepdims=True)

for strategy in opt_results_all:
    if strategy == "exp":
        # the experimental entry holds recordings, not parameters to simulate
        continue
    # setdefault keeps the best_fitness / best_params stored by the previous
    # cell instead of replacing the entry with an empty dict
    strategy_results = opt_results.setdefault(strategy, {})
    print(f"Simulating best '{strategy}' -- seed: {opt_results_all[strategy]['best_seed']}")
    best_params = opt_results_all[strategy]["best_params"]
    t_start = time.time()
    responses = eva_extra.run_protocols(eva_extra.fitness_protocols.values(),
                                        param_values=best_params)
    eap = mf.utils.calculate_eap(responses=responses, protocols=eva_extra.fitness_protocols,
                                 protocol_name=protocol_for_eap, **extra_kwargs)
    t_stop = time.time()
    print(f"Simulated responses in {np.round(t_stop - t_start, 2)} s")
    # same per-channel normalization as for the experimental EAP
    eap_norm = eap / np.ptp(np.abs(eap), 1, keepdims=True)
    # distance = sum of absolute differences between the normalized EAPs
    eap_dist = np.sum(np.abs(eap_exp_norm.ravel() - eap_norm.ravel()))
    strategy_results["eap_dist"] = eap_dist
    strategy_results["responses"] = responses
    strategy_results["eap"] = eap

In [None]:
# Debugging aid: grab the "firepattern_120" protocol from the evaluator.
fp120 = eva_extra.fitness_protocols["firepattern_120"]

In [None]:
# Run only that protocol. `params_dict` is whatever the best-run loop above
# left behind, i.e. the parameters of the last strategy it processed.
resp_fp120 = fp120.run(eva_extra.cell_model, param_values=params_dict, sim=eva_extra.sim)

In [None]:
# Display the somatic voltage response of the protocol.
resp_fp120['firepattern_120.soma.v']

In [None]:
# Plot the extracellular recording of the protocol: the transposed "voltage"
# array against "time".
lfp_fp120 = resp_fp120['firepattern_120.MEA.LFP']
plt.figure()
plt.plot(lfp_fp120["time"], lfp_fp120["voltage"].T)

In [None]:
# Re-run all protocols with `best_params` (left over from the last iteration of
# the simulation loop above). The result is only displayed, not stored.
eva_extra.run_protocols(eva_extra.fitness_protocols.values(), 
                                        param_values=best_params)

In [None]:
# Plot the responses currently bound to `responses`.
mf.plot_responses(responses)

In [None]:
# Drop the extracellular (LFP) recording before plotting again. `responses` is
# the same dict object that was stored in opt_results[...]["responses"], so
# deleting the key in place would also remove it from the stored results.
# Rebinding `responses` to a filtered copy avoids that and lets the cell be
# re-run without a KeyError.
responses = {name: response for name, response in responses.items()
             if name != "firepattern_120.MEA.LFP"}

In [None]:
# Plot the responses again (at this point without the LFP entry).
mf.plot_responses(responses)

In [None]:
# Name of the first feature of every evaluator objective, in objective order,
# so that an objective can be looked up by feature name via list.index().
feat_objectives = [obj.features[0].name for obj in eva_extra.fitness_calculator.objectives]

In [None]:
# Score every reference feature for the best model of each strategy.
# Objective features are scored by the evaluator; per-channel extracellular
# features are scored against the experimental value.
feat_release_keys = list(features_release.keys())
for strategy in np.unique(df_model.strategy):
    responses = opt_results[strategy]["responses"]
    eap = opt_results[strategy]["eap"]
    extra_features_strategy = mf.efeatures_extraction.compute_extra_features(
        eap,
        fs=extra_kwargs["fs"],
        upsample=extra_kwargs["upsample"],
    )
    opt_results[strategy]["extra_features"] = extra_features_strategy

    features_best = {}
    for feat_name in tqdm(feat_release_keys, desc=f"computing features {strategy}"):
        feat_entry = {}
        features_best[feat_name] = feat_entry

        if feat_name in feat_objectives:
            obj = eva_extra.fitness_calculator.objectives[feat_objectives.index(feat_name)]
            if len(obj.features) != 1:
                # the entry stays empty for multi-feature objectives
                print(f"More than one feature for objective: {obj.name}")
                continue
            feat_entry["score"] = obj.features[0].calculate_score(responses)
            continue

        # extra: the name is "<protocol>.MEA.<feature>_<channel>"
        protocol, _mea, efeat_full = feat_name.split(".")
        efeat, _sep, chan_str = efeat_full.rpartition("_")
        chan = int(chan_str)
        feat_value = extra_features_strategy[efeat][chan]
        release_value = features_release[feat_name]["value"]
        # relative deviation in units of std_from_mean * reference value;
        # plain absolute deviation when the reference value is zero
        deviation = abs(feat_value - release_value)
        if release_value != 0:
            feat_score = deviation / abs(std_from_mean * release_value)
        else:
            feat_score = deviation
        feat_entry["value"] = feat_value
        feat_entry["score"] = feat_score

    opt_results[strategy]["features"] = features_best
    

In [None]:
# Experimental response set (the one at index 1) used as reference in the plots below.
responses_exp = responses_experimental[1]

In [None]:
# One figure per strategy, overlaying the experimental responses (black) and
# the responses of the strategy's best model for the selected protocols.
protocols_to_plot = [
    "firepattern_200",
    "HyperDepol_-160",
    "HyperDepol_-40",
    "sAHP_250",
    "PosCheops_300",
]
titles = protocols_to_plot
figs_intra = {}
for strategy in np.unique(df_model.strategy):
    responses_to_plot = [responses_exp, opt_results[strategy]["responses"]]
    colors = ["k", colors_dict[strategy]]
    # built for reference only; the labels are not passed to the plot call
    labels = ["GT", strategy.upper()]
    fig = mf.plot_multiple_responses(
        responses_to_plot,
        protocol_names=protocols_to_plot,
        titles=titles,
        colors=colors,
        figsize=(7, 12),
        return_fig=True,
    )
    figs_intra[strategy] = fig

In [None]:
# Display order of the strategies, restricted to those present in opt_results.
order_full = ["soma", "all", "sections", "single"]
order = [strategy for strategy in order_full if strategy in opt_results]

## Compare features

In [None]:
# Flatten the per-strategy feature scores into a long table with one row per
# (strategy, feature).
feature_name_array = []
feature_set_array = []
feature_score_array = []
feature_type_array = []
protocol_type_array = []

for strategy, strategy_dict in opt_results.items():
    for feat_name, feat_dict in strategy_dict["features"].items():
        feature_set_array.append(strategy)
        feature_name_array.append(feat_name)
        # features recorded on the MEA are extracellular, all others intracellular
        feature_type_array.append("extra" if "MEA" in feat_name else "intra")
        feature_score_array.append(feat_dict["score"])
        # protocol type = protocol name up to its first underscore
        protocol_name = feat_name.partition(".")[0]
        protocol_type = protocol_name.partition("_")[0]
        protocol_type_array.append(protocol_type)

df_feats = pd.DataFrame({
    "feature_set": feature_set_array,
    "feat_name": feature_name_array,
    "feature_type": feature_type_array,
    "feat_score": feature_score_array,
    "protocol_type": protocol_type_array,
})

In [None]:

# Box plots of the feature scores per strategy, split into intracellular and
# extracellular features. `n` is the number of features of the "soma" strategy.
df_feats_intra = df_feats[df_feats["feature_type"] == "intra"]
df_feats_extra = df_feats[df_feats["feature_type"] == "extra"]

# --- intracellular ---
fig_feat_intra, ax = plt.subplots(figsize=(7, 10))
sns.boxplot(data=df_feats_intra, x="feature_set", y="feat_score",
            order=order, ax=ax)
n = len(df_feats_intra[df_feats_intra["feature_set"] == "soma"])
ax.set_ylabel("Feature scores", fontsize=12)
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.set_title(f"Intracellular features\n(n={n})", fontsize=20)
ax.set_xlabel("Strategy", fontsize=15)
ax.set_ylabel("Score", fontsize=15)
ax.set_ylim([0, 21])

# --- extracellular ---
fig_feat_extra, ax = plt.subplots(figsize=(7, 10))
sns.boxplot(data=df_feats_extra, x="feature_set", y="feat_score",
            order=order, ax=ax)
n = len(df_feats_extra[df_feats_extra["feature_set"] == "soma"])
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.set_title(f"Extracellular features\n(n={n})", fontsize=20)
ax.set_xlabel("Strategy", fontsize=15)
ax.set_ylabel("Score", fontsize=15)
ax.set_ylim(0, 21)

In [None]:
# Recompute value and score of every evaluator objective for the best model of
# each strategy; this replaces the "features" entry stored earlier.
for strategy in np.unique(df_model.strategy):
    responses = opt_results[strategy]["responses"]
    features_best = {}
    for obj in eva_extra.fitness_calculator.objectives:
        feat = obj.features[0]
        features_best[feat.name] = {}
        if len(obj.features) != 1:
            # the entry stays empty for multi-feature objectives
            print(f"More than one feature for objective: {obj.name}")
            continue
        feat_value = feat.calculate_feature(responses)
        if feat_value is None:
            # the feature could not be computed: use the fallback value
            feat_value = max_feature_value
        features_best[feat.name]["value"] = feat_value
        if "MEA" in feat.name:
            # extracellular: cosine distance to the experimental mean
            feat_score = np.abs(distance.cosine(feat.exp_mean, feat_value))
        else:
            # intracellular: deviation from the experimental mean in units of exp_std
            feat_score = np.abs(feat.exp_mean - feat_value) / feat.exp_std
        features_best[feat.name]["score"] = feat_score
    opt_results[strategy]["features"] = features_best

In [None]:
# plot: for every strategy, one figure with the intracellular responses and
# one with the EAP, each overlaying the experimental data ("GT", black) and
# the best model of the strategy.
response_experimental_plot = responses_experimental[1]

# `vscale` was never defined in this notebook (NameError). Derive one common
# scale from the experimental EAP so that both traces of a figure share it.
# NOTE(review): 1.5 x the peak absolute amplitude is believed to mirror
# MEAutility's own default when vscale is None - confirm against its docs.
vscale = 1.5 * np.max(np.abs(eap_exp))

for strategy in np.unique(df_model.strategy):
    responses_to_plot = [response_experimental_plot, opt_results[strategy]["responses"]]
    eap = opt_results[strategy]["eap"]
    colors = ["k", colors_dict[strategy]]
    labels = ["GT", strategy.upper()]
    fig_extra_intra_single = mf.plot_multiple_responses(responses_to_plot,
                                                        colors=colors, return_fig=True,
                                                        labels=labels)
    fig, ax_extra = plt.subplots(figsize=figsize)
    ax_extra = mu.plot_mea_recording(eap_exp, probe, vscale=vscale, lw=1, ax=ax_extra)
    # label only the last line drawn so the legend gets one entry per recording
    ax_extra.get_lines()[-1].set_label("GT")
    ax_extra = mu.plot_mea_recording(eap, probe, ax=ax_extra, vscale=vscale,
                                     colors=colors_dict[strategy], lw=1)
    ax_extra.get_lines()[-1].set_label(strategy.upper())
    ax_extra.set_title("EAP", fontsize=15)
    ax_extra.legend()

In [None]:
# Report the EAP distance of every strategy. `eap_dist` is computed above as
# the sum of absolute differences between the normalized EAPs, so the message
# no longer calls it a cosine distance.
for strategy in opt_results:
    print(f"EAP dist (sum of abs. differences) {strategy}: {opt_results[strategy]['eap_dist']}")

## Compare best-fitted models

In [None]:
# One row per strategy (taken from the keys of opt_results), with the strategy
# name repeated in a regular "strategy" column.
df_test = pd.DataFrame.from_dict(opt_results, orient="index").assign(
    strategy=lambda frame: frame.index
)

### Compare features

In [None]:
# Rebuild the long feature-score table (one row per strategy and feature) from
# the "features" entries currently stored in opt_results.
feature_name_array = []
feature_set_array = []
feature_score_array = []
feature_type_array = []

for strategy, res in opt_results.items():
    for feat_name, feat_dict in res["features"].items():
        feature_set_array.append(strategy)
        feature_name_array.append(feat_name)
        # features recorded on the MEA are extracellular, all others intracellular
        feature_type_array.append("extra" if "MEA" in feat_name else "intra")
        feature_score_array.append(feat_dict["score"])

df_feats = pd.DataFrame({
    "feature_set": feature_set_array,
    "feat_name": feature_name_array,
    "feat_score": feature_score_array,
    "feature_type": feature_type_array,
})

In [None]:
# Horizontal box plots of the feature scores per strategy, one figure for the
# intracellular and one for the extracellular features.
fig_feat_intra, ax = plt.subplots(figsize=figsize)
intra_scores = df_feats[df_feats["feature_type"] == "intra"]
sns.boxplot(data=intra_scores, y="feature_set", x="feat_score", ax=ax)
ax.set_ylabel("Feature scores (intracellular)", fontsize=12)
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.set_title("Intracellular features", fontsize=15)

fig_feat_extra, ax = plt.subplots(figsize=figsize)
extra_scores = df_feats[df_feats["feature_type"] == "extra"]
sns.boxplot(data=extra_scores, y="feature_set", x="feat_score", ax=ax)
ax.set_ylabel("Feature scores (extracellular)", fontsize=12)
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.set_title("Extracellular features", fontsize=15)
