# 3) Analyze and evaluate optimization output - training protocols

## cell1_211006_3148

This final notebook loads the `runs.pkl` file created in notebook 2 and analyzes:

- the distance between different feature sets in the feature space
- the distance between different feature sets in the extracellular signals

In [None]:
import pickle
import pandas as pd
import seaborn as sns
import sys
import shutil
from tqdm import tqdm

import bluepyopt as bpopt
import bluepyopt.ephys as ephys
import neuroplotlib as nplt

import matplotlib as mpl
import matplotlib.pyplot as plt
from scipy.spatial import distance
import MEAutility as mu
import json
import time
import numpy as np
from pathlib import Path
from pprint import pprint

import multimodalfitting as mf

%matplotlib notebook

In [None]:
# Project root: three directories above the notebook's working directory
base_dir = Path("..", "..", "..")

In [None]:
# Cell selection: "cell1_211006_3148" or "cell1_211011_3436"
cell_name = "cell1_211006_3148"
model_name = cell_name  # the model is named after the experimental cell

# patch-clamp recordings of the selected cell
ephys_dir = base_dir.joinpath("experimental_data", cell_name, "patch_data")

# folder holding the cell model definitions
cell_models_folder = base_dir.joinpath("cell_models")
model_folder = cell_models_folder.joinpath(cell_name)

In [None]:
# Date-stamped results folder containing the pkl file from notebook 2
# (other dates used previously: '211124', '220111')
results_date = '220214'
result_folder = base_dir.joinpath("results", results_date)

In [None]:
# abd = axon_bearing_dendrite option of the cell model and of the runs to analyze
abd = True

In [None]:
save_fig = False  # set to True to write the figures into figure_folder
figure_folder = Path(f"figures_{cell_name}")

# the output folder is only created when figures will actually be written
if save_fig:
    figure_folder.mkdir(exist_ok=True)

In [None]:
# plot color of each optimization strategy
colors_dict = dict(soma="C0", all="C1", sections="C2", single="C3")
# feature set of each strategy: "soma" for the soma strategy, "extra" for all others
feature_sets = {
    strategy: ("soma" if strategy == "soma" else "extra")
    for strategy in colors_dict
}
figsize = (10, 7)

# Define cell model and load optimization output

In [None]:
# cell model, plus a second instance of the same model created with release=True
cell = mf.create_experimental_model(model_name=cell_name, abd=abd)
cell_sample = mf.create_experimental_model(model_name=cell_name, release=True, abd=abd)

# probe definition used for the extracellular signals
probe_file = model_folder.joinpath("fitting", "efeatures", "probe_BPO.json")
probe = mf.define_electrode(probe_file=probe_file)

# names of the free (non-frozen) parameters, in model order
param_names = []
for param in cell.params.values():
    if param.frozen:
        continue
    param_names.append(param.name)

In [None]:
pkl_file_name = "runs.pkl"  # optimization output created in notebook 2

In [None]:
# Load all optimization runs. This is a local, trusted file produced by notebook 2;
# never unpickle data coming from untrusted sources.
with open(result_folder / pkl_file_name, 'rb') as f:
    data = pickle.load(f)
df_optimization = pd.DataFrame(data)
# .copy(): the columns assigned below must go to an independent frame, not a view
# of df_optimization (avoids pandas' SettingWithCopyWarning / chained assignment)
df_model = df_optimization.query(f"model == '{model_name}'").copy()
# strategy = extracellular strategy, except soma-only fits which are labelled "soma"
df_model["strategy"] = df_model["extra_strategy"].values.copy()
df_model.loc[df_model["feature_set"] == "soma", "strategy"] = "soma"

results_name = f"opt_results_training_{model_name}"
results_name += "_abd" if abd else "_noabd"
results_file = f"{results_name}.pkl"

# Previously computed results for this model/abd option, or None if not cached yet
opt_results_training = None
results_path = result_folder / results_file
if results_path.is_file():
    with open(results_path, 'rb') as f:
        opt_results_training = pickle.load(f)

In [None]:
fig, ax = plt.subplots()
min_evals = 3000

# Min-fitness curve of every run: abd runs use the strategy color, non-abd runs a
# darker shade of it. The row indices of abd runs are collected in keep_idxs.
keep_idxs = []
for idx, row in df_model.iterrows():
    base_rgb = mpl.colors.to_rgb(colors_dict[row["strategy"]])
    if row["abd"]:
        keep_idxs.append(idx)
        line_color = base_rgb
    else:
        line_color = np.clip(np.array(base_rgb) - 0.3, 0, 1)
    ax.plot(row["nevals"], row["logbook"].select("min"),
            color=line_color, ls='-', lw=0.5, alpha=0.75)

ax.set_title("Min fitness")
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
ax.set_xlabel("Neval")
ax.set_ylabel("Min fitness")
ax.set_yscale('log')

In [None]:
df_model = df_model[df_model["abd"] == abd]  # keep only runs with the selected abd option

# Load protocols and experimental features

In [None]:
# default extracellular settings, overriding the cutout window
# (ms_cut presumably = (before, after) spike in ms — confirm in mf.utils)
extra_kwargs = mf.utils.get_extra_kwargs()
extra_kwargs.update(ms_cut=[2, 5])

In [None]:
# display the extracellular settings passed to the evaluators below
extra_kwargs

In [None]:
protocol_for_eap = "IDrest_300"  # protocol whose responses are used to compute the EAP

In [None]:
# Reference ("validation") evaluator: "extra" feature set with the "all" strategy,
# extracellular signals computed for protocol_for_eap
evaluator_kwargs = dict(
    model_name=model_name,
    feature_set="extra",
    extra_strategy="all",
    protocols_with_lfp=protocol_for_eap,
    abd=abd,
)
eva_extra = mf.create_evaluator(**evaluator_kwargs, **extra_kwargs)

In [None]:
# Check the number of features (objectives) of the evaluator for each strategy.
# The evaluators are built with the same abd option and EAP protocol as eva_extra,
# so the counts are comparable with the validation count printed at the end.
for strategy in np.unique(df_model.strategy):
    # the soma strategy has no extracellular strategy
    extra_strategy = strategy if strategy in ("all", "single", "sections") else None
    eva = mf.create_evaluator(
        model_name=model_name,
        feature_set=feature_sets[strategy],
        extra_strategy=extra_strategy,
        protocols_with_lfp=protocol_for_eap,
        abd=abd,
        **extra_kwargs,
    )
    print(f"Strategy {strategy} --> num features {len(eva.fitness_calculator.objectives)}")

print(f"Validation: --> num features {len(eva_extra.fitness_calculator.objectives)}")

## Load experimental responses

In [None]:
from bluepyefe.extract import read_recordings, extract_efeatures_at_targets, compute_rheobase,\
    group_efeatures, create_feature_protocol_files, convert_legacy_targets
from bluepyefe.plotting import plot_all_recordings_efeatures

from multimodalfitting.efeatures_extraction import build_wcp_metadata, wcp_reader, get_ecode_targets, \
    ecodes_wcp_timings

In [None]:
# Stimulus timings of each eCode protocol in this experiment.
# Keys presumably: ton/toff = stimulus onset/offset, tmid/tmid2/t1..t4 = intermediate
# step times; units presumably ms — confirm against the recordings.
ecodes_cell_timings = dict(
    IDthres=dict(ton=400, toff=670),
    firepattern=dict(ton=500, toff=4100),
    IV=dict(ton=400, toff=3400),
    IDrest=dict(ton=400, toff=1750),
    APWaveform=dict(ton=350, toff=400),
    HyperDepol=dict(ton=400, toff=1120, tmid=850),
    sAHP=dict(ton=400, toff=1325, tmid=650, tmid2=875),
    PosCheops=dict(ton=1000, t1=9000, t2=10500, t3=14500, t4=16000, toff=18660),
)

In [None]:
# Recording runs (repetitions) to use for each cell
if cell_name == "cell1_211006_3148":
    runs = [1, 2, 3, 4, 5]  # run1 --> different rheobase
elif cell_name == "cell1_211011_3436":
    runs = [3, 4, 5, 6]
else:
    # fail early instead of a NameError on `runs` below
    raise ValueError(f"No run selection defined for cell '{cell_name}'")

ecode_names = list(ecodes_cell_timings.keys())

# list the patch files once; sorted so that the match order is deterministic
patch_files = sorted(ephys_dir.iterdir())

# One {ecode: file} dict per run. A file matches when its name contains "run<N>" and
# the lower-cased ecode name; if several files match, the last one in sorted order wins.
# NOTE(review): substring matching means "run1" would also match "run10" files.
files_list = []
for run in runs:
    rep_dict = {}
    for ecode in ecode_names:
        for patch_file in patch_files:
            if f"run{run}" in patch_file.name and ecode.lower() in patch_file.name:
                rep_dict[ecode] = patch_file
    files_list.append(rep_dict)

In [None]:
# bluepyefe metadata for the selected runs (repetitions are not treated as separate cells)
files_metadata = build_wcp_metadata(
    cell_id=cell_name,
    files_list=files_list,
    ecode_timings=ecodes_cell_timings,
    repetition_as_different_cells=False,
)

In [None]:
# read the .wcp recordings described in files_metadata
cells = read_recordings(files_metadata=files_metadata, recording_reader=wcp_reader)

In [None]:
# target features of each protocol, derived from the eCode timings
targets = get_ecode_targets(ecodes_cell_timings)

In [None]:
# the same tolerance is applied to every target
global_tolerance = 30
for tgt in targets:
    tgt["tolerance"] = global_tolerance

In [None]:
# majority threshold used by the rheobase computation (cell specific)
majority = 0.4 if cell_name == "cell1_211006_3148" else 0.2

In [None]:
# rheobase from the IDthres protocol with the "majority" strategy and threshold above
compute_rheobase(
    cells,
    protocols_rheobase=['IDthres'],
    rheobase_settings={"majority": majority},
    rheobase_strategy="majority",
)

In [None]:
print("Cell rheobase:", cells[0].rheobase)

In [None]:
# group recordings and features by protocol, using the global rheobase computed above
protocols = group_efeatures(cells, targets, use_global_rheobase=True)

In [None]:
# names of the protocols used by the reference evaluator
protocols_opt = list(eva_extra.fitness_protocols)
print(protocols_opt)

In [None]:
# all eCode protocol names of the experiment
all_protocols = list(ecodes_cell_timings)
print(all_protocols)

### Build BPO response dicts

In [None]:
# For every experimental run (repetition) build one TimeVoltageResponse per protocol:
# responses_all keeps every protocol, responses_experimental only those in protocols_opt.
responses_experimental = []
responses_all = []
num_runs = len(protocols[0].recordings)
for run in range(num_runs):
    print(f"Populating responses for run {run}")
    response_dict = {}
    response_all_dict = {}
    for protocol in protocols:
        # protocols without a recording for this run are skipped
        if run >= len(protocol.recordings):
            continue
        rec = protocol.recordings[run]
        key = f"{protocol.stimulus_name}.soma.v"
        response = bpopt.ephys.responses.TimeVoltageResponse(
            name=protocol.name, time=rec.t, voltage=rec.voltage
        )
        response_all_dict[key] = response
        if protocol.stimulus_name in protocols_opt:
            response_dict[key] = response
    responses_experimental.append(response_dict)
    responses_all.append(response_all_dict)

In [None]:
# protocols shown in the sample experimental figure; titles are the protocol names
responses_to_plot = [
    "APWaveform_290",
    "IDrest_250",
    "firepattern_200",
    "sAHP_250",
    "PosCheops_300",
]
titles = responses_to_plot

In [None]:
# experimental traces of the run at index 2 for the selected protocols
fig_exp_sample = mf.plot_responses(
    responses_all[2],
    protocol_names=responses_to_plot,
    titles=titles,
    color="k",
    return_fig=True,
)

In [None]:
if save_fig:
    stem = f"{cell_name}_sample_response"
    fig_exp_sample.savefig(figure_folder / f"{stem}.pdf")
    fig_exp_sample.savefig(figure_folder / f"{stem}.png", dpi=300)

In [None]:
# Disabled: overview of all experimental runs. While this stays commented out,
# fig_exp_intra is not defined in this notebook.
# fig_exp_intra = mf.plot_multiple_responses(responses_all, return_fig=True, 
#                                            labels=[f"run{i}" for i in range(num_runs)])

In [None]:
# experimental EAP template; divided by 1000 (presumably a unit conversion — confirm)
eap_template_file = model_folder / "fitting" / "efeatures" / "template_BPO.npy"
eap_exp = np.load(eap_template_file) / 1000

In [None]:
# draw the experimental EAP on the probe layout and keep a handle to its figure
ax = mu.plot_mea_recording(eap_exp, probe)
fig_exp_extra = ax.get_figure()

In [None]:
vscale = np.abs(eap_exp).max()  # common voltage scale for all EAP plots

In [None]:
if save_fig:
    # fig_exp_intra only exists if the (currently commented-out) overview cell was
    # run; without this check, saving would fail with a NameError
    if "fig_exp_intra" in globals():
        fig_exp_intra.savefig(figure_folder / "exp_intra.png", dpi=300)
    fig_exp_extra.savefig(figure_folder / "exp_extra.png", dpi=300)

# Compute and plot best responses

In [None]:
# value substituted for features that cannot be computed (None)
max_feature_value = 50
# Reuse the results loaded from the cached results file (opt_results_training);
# if there is no cache, start empty and simulate the best solutions below.
opt_results_all = opt_results_training
if opt_results_all is None:
    opt_results_all = {}
    compute_responses = True
else:
    compute_responses = False

In [None]:
# bounds of every free parameter, used to flag out-of-bounds solutions
param_boundaries = {
    name: param.bounds
    for name, param in eva_extra.cell_model.params.items()
    if not param.frozen
}

In [None]:
if compute_responses:
    for strategy in np.unique(df_model.strategy):
        print(f"Simulating best '{strategy}'")
        # simulate the best individual of the final population of every seed
        strategy_runs = df_model.query(f"strategy == '{strategy}'")
        opt_results_all[strategy] = {}

        all_responses, all_eaps, all_params = {}, {}, {}
        for _, run_row in strategy_runs.iterrows():
            seed = run_row.seed
            print("\tSeed", run_row.seed)
            population = run_row.population
            # best individual = lowest sum of its objective values
            total_scores = [sum(individual.fitness.values) for individual in population]
            params = population[np.argmin(total_scores)]
            params_dict = dict(zip(param_names, params))

            # report parameters outside their allowed range
            for param_name, param_value in params_dict.items():
                bounds = param_boundaries[param_name]
                if param_value < bounds[0] or bounds[1] < param_value:
                    print(f"{param_name} out of bounds: {bounds}")
            all_params[seed] = params_dict

            responses_seed = eva_extra.run_protocols(eva_extra.fitness_protocols.values(),
                                                     param_values=params_dict)
            all_responses[seed] = responses_seed
            all_eaps[seed] = mf.utils.calculate_eap(responses=responses_seed,
                                                    protocols=eva_extra.fitness_protocols,
                                                    protocol_name=protocol_for_eap,
                                                    align_extra=True, **extra_kwargs)

        opt_results_all[strategy].update(eaps=all_eaps, responses=all_responses, params=all_params)

In [None]:
if compute_responses:
    objectives = eva_extra.fitness_calculator.objectives
    for strategy in opt_results_all:
        print(strategy)
        opt_results_all[strategy]["fitness"] = {}
        for seed, responses in opt_results_all[strategy]["responses"].items():
            # scores are split into intracellular and extracellular ("MEA") parts
            extra_fitness = 0
            intra_fitness = 0
            for obj in tqdm(objectives, desc=f"computing features {strategy}"):
                # only single-feature objectives are scored
                if len(obj.features) != 1:
                    continue
                feat = obj.features[0]
                # the feature value computed separately before was never used;
                # calculate_score is expected to evaluate the feature itself
                feat_score = feat.calculate_score(responses)
                if "MEA" in feat.name:
                    extra_fitness += feat_score
                else:
                    intra_fitness += feat_score
            opt_results_all[strategy]["fitness"][seed] = {"intra": intra_fitness, "extra": extra_fitness,
                                                          "total": intra_fitness + extra_fitness}
            print("seed", seed)
            print("\tINTRA", intra_fitness)
            print("\tEXTRA", extra_fitness)
            print("\tTOTAL", intra_fitness + extra_fitness)

In [None]:
# One row per (strategy, seed) with intracellular, extracellular and total scores.
# Entries without "fitness" (e.g. the "exp" entry stored in a reloaded results file)
# are skipped instead of raising a KeyError.
fitness_rows = []
for strategy, strategy_results in opt_results_all.items():
    for seed, fitness in strategy_results.get("fitness", {}).items():
        fitness_rows.append({"seed": seed, "strategy": strategy,
                             "intra_score": fitness["intra"],
                             "extra_score": fitness["extra"],
                             "total_score": fitness["total"]})
df_fitness = pd.DataFrame(fitness_rows,
                          columns=["seed", "strategy", "intra_score", "extra_score", "total_score"])

In [None]:
# best responses are the solutions that minimize extra_score;
# idxmin returns index labels, so select them with .loc (not positional .iloc)
best_extras = df_fitness.loc[df_fitness.groupby("strategy")["extra_score"].idxmin()]
best_extras

In [None]:
# plotting order: strategies of colors_dict that are present in df_fitness
present_strategies = set(df_fitness["strategy"])
order = [strategy for strategy in colors_dict if strategy in present_strategies]
# largest number of seeds of any strategy, shown in the titles instead of a fixed "10"
n_seeds = df_fitness.groupby("strategy")["seed"].nunique().max()


def _plot_seed_scores(score_column, title):
    """Box plot of `score_column` across seeds, one box per strategy; return (fig, ax)."""
    fig, ax = plt.subplots(figsize=(7, 10))
    sns.boxplot(data=df_fitness, x="strategy", y=score_column, order=order,
                palette=colors_dict, ax=ax)
    ax.set_xlabel("Strategy", fontsize=15)
    ax.set_ylabel("Score", fontsize=15)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.tick_params(axis="x", labelsize=12)
    ax.set_title(f"{title}\n({n_seeds} seeds)", fontsize=20)
    return fig, ax


fig_intra_seeds, ax = _plot_seed_scores("intra_score", "Intracellular")
fig_extra_seeds, ax = _plot_seed_scores("extra_score", "Extracellular")

In [None]:
# Experimental EAP normalized row-wise (presumably per channel) by the peak-to-peak
# of its absolute value; it does not depend on the strategy, so compute it once.
eap_exp_norm = eap_exp / np.ptp(np.abs(eap_exp), 1, keepdims=True)

for _, best_row in best_extras.iterrows():
    strategy = best_row["strategy"]
    seed = best_row["seed"]
    print("Strategy", strategy, "best seed", seed)
    strategy_results = opt_results_all[strategy]
    responses = strategy_results["responses"][seed]
    eap = strategy_results["eaps"][seed]
    params = strategy_results["params"][seed]
    strategy_results.update(best_seed=seed, best_responses=responses,
                            best_eap=eap, best_params=params)
    # same normalization for the simulated EAP, then sum of absolute differences
    eap_norm = eap / np.ptp(np.abs(eap), 1, keepdims=True)
    eap_dist = np.abs(eap_exp_norm.ravel() - eap_norm.ravel()).sum()
    strategy_results["best_eap_dist"] = eap_dist
    print(eap_dist)

In [None]:
responses_exp = responses_experimental[1]  # experimental run (index 1) shown as "GT" below

In [None]:
# For each strategy, compare its best solution with the experimental data ("GT"):
# intracellular traces for a few protocols and the EAP on the probe.
figs_intra = {}
figs_extra = {}
protocols_to_plot = ["APWaveform_290", "IDrest_250", "IV_-100"]
titles = protocols_to_plot
for strategy in np.unique(df_model.strategy):
    best = opt_results_all[strategy]
    strategy_color = colors_dict[strategy]
    strategy_label = strategy.upper()

    responses_to_plot = [responses_exp, best["best_responses"]]
    colors = ["k", strategy_color]
    labels = ["GT", strategy_label]
    fig_intra = mf.plot_multiple_responses(responses_to_plot,
                                           protocol_names=protocols_to_plot,
                                           colors=colors,
                                           titles=titles,
                                           return_fig=True,
                                           labels=labels)

    eap = best["best_eap"]
    fig_extra, ax_extra = plt.subplots(figsize=figsize)
    ax_extra = mu.plot_mea_recording(eap_exp, probe, vscale=vscale, lw=1, ax=ax_extra)
    ax_extra.get_lines()[-1].set_label("GT")
    ax_extra = mu.plot_mea_recording(eap, probe, ax=ax_extra, vscale=vscale,
                                     colors=strategy_color, lw=1)
    ax_extra.get_lines()[-1].set_label(strategy_label)
    ax_extra.set_title("EAP", fontsize=15)
    ax_extra.legend()

    figs_intra[strategy] = fig_intra
    figs_extra[strategy] = fig_extra

In [None]:
# Legacy code (disabled): picks the best individual via `best_fitness` and stores it
# in `opt_results`, a dict that no active cell of this notebook defines.
# for strategy in np.unique(df_model.strategy):
#     opt_results[strategy] = {}
#     opt_df = df_model.query(f"strategy == '{strategy}'")
#     best_idx = np.argmin(opt_df.best_fitness)
#     params_sample = opt_df.iloc[best_idx]
#     params_dict = {k: v for k, v in zip(param_names, params_sample.best_params)}
#     opt_results[strategy]["best_fitness"] = params_sample.best_fitness
#     opt_results[strategy]["best_params"] = params_dict
#     print(f"{strategy} --  best fitness: {params_sample.best_fitness}")

In [None]:
# Legacy code (disabled): simulates the parameters stored in `opt_results` and
# computes a cosine EAP distance; `opt_results` is not defined by any active cell.
# for strategy in np.unique(df_model.strategy):
#     print(f"Simulating best '{strategy}'")
#     responses = eva_extra.run_protocols(eva_extra.fitness_protocols.values(), 
#                                         param_values=opt_results[strategy]["best_params"])
#     opt_results[strategy]["responses"] = responses
#     eap = mf.utils.calculate_eap(responses=responses, protocols=eva_extra.fitness_protocols, 
#                                  protocol_name=protocol_for_eap, align_extra=True, **extra_kwargs)
#     opt_results[strategy]["eap"] = eap   
#     eap_dist = distance.cosine(eap_exp.ravel(), eap.ravel())
#     opt_results[strategy]["eap_dist"] = eap_dist

In [None]:
# Per-feature value and score of the best solution of each strategy:
# intracellular score = |exp_mean - value| / exp_std, extracellular ("MEA") score =
# cosine distance between the experimental and simulated feature.
for strategy in np.unique(df_model.strategy):
    responses = opt_results_all[strategy]["best_responses"]
    features_best = {}
    for obj in tqdm(eva_extra.fitness_calculator.objectives,
                    desc=f"computing features {strategy}"):
        if len(obj.features) > 1:
            # Not scored. No (empty) entry is created, because downstream code
            # reads feat_dict["score"] and would fail with a KeyError.
            print(f"More than one feature for objective: {obj.name}")
            continue
        feat = obj.features[0]
        feat_value = feat.calculate_feature(responses)
        if feat_value is None:
            # NOTE(review): for MEA features a scalar substitute cannot be compared
            # with the vector exp_mean by the cosine distance — confirm handling
            feat_value = max_feature_value
        if "MEA" not in feat.name:
            feat_score = np.abs(feat.exp_mean - feat_value) / feat.exp_std
        else:
            feat_score = np.abs(distance.cosine(feat.exp_mean, feat_value))
        features_best[feat.name] = {"value": feat_value, "score": feat_score}
    opt_results_all[strategy]["features"] = features_best

In [None]:
# store the experimental data next to the fitted results
opt_results_all["exp"] = {"responses": responses_all, "eap": eap_exp}

In [None]:
# cache all results in the file that is loaded at the top of this notebook
with open(result_folder / results_file, 'wb') as out_file:
    pickle.dump(opt_results_all, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Legacy code (disabled): comparison plots based on the `opt_results` dict, which is
# not defined by any active cell of this notebook.
# # plot
# response_experimental_plot = responses_experimental[1]
# for strategy in np.unique(df_model.strategy):
#     responses_to_plot = [response_experimental_plot, opt_results[strategy]["responses"]]
#     eap = opt_results[strategy]["eap"]
#     colors = ["k", colors_dict[strategy]]
#     labels = ["GT", strategy.upper()]
#     fig_extra_intra_single = mf.plot_multiple_responses(responses_to_plot, 
#                                                         colors=colors, return_fig=True, 
#                                                         labels=labels)
#     fig, ax_extra = plt.subplots(figsize=figsize)
#     ax_extra = mu.plot_mea_recording(eap_exp, probe, vscale=vscale, lw=1, ax=ax_extra)
#     ax_extra.get_lines()[-1].set_label("GT")
#     ax_extra = mu.plot_mea_recording(eap, probe, ax=ax_extra, vscale=vscale, 
#                                      colors=colors_dict[strategy], lw=1)
#     ax_extra.get_lines()[-1].set_label(strategy.upper())
#     ax_extra.set_title("EAP", fontsize=15)
#     ax_extra.legend()
# #     ax_extra = mf.plot_multiple_eaps(responses_to_plot, 
# #                                      eva_extra.fitness_protocols, probe,
# #                                      protocol_name=protocol_for_eap, 
# #                                      colors=colors, labels=labels)

In [None]:
# Legacy code (disabled): relies on `opt_results`, which no active cell defines.
# for strategy in opt_results:
#     print(f"Cosine dist {strategy}: {opt_results[strategy]['eap_dist']}")

## Compare best-fitted models

In [None]:
# Disabled: df_test is not created while these lines stay commented out.
# df_test = pd.DataFrame.from_dict(opt_results_all, orient="index")
# df_test["strategy"] = df_test.index

### Compare features

In [None]:
# One row per (strategy, feature): its score and whether it is an extracellular
# ("MEA" in the name) or intracellular feature.
feature_rows = [
    {
        "feature_set": strategy,
        "feat_name": feat_name,
        "feat_score": feat_dict["score"],
        "feature_type": "extra" if "MEA" in feat_name else "intra",
    }
    for strategy in np.unique(df_model.strategy)
    for feat_name, feat_dict in opt_results_all[strategy]["features"].items()
]
df_feats = pd.DataFrame(feature_rows,
                        columns=["feature_set", "feat_name", "feat_score", "feature_type"])

In [None]:
def _plot_feature_scores(feature_type, label):
    """Horizontal box plot of the `feature_type` scores per feature set; return (fig, ax)."""
    fig, ax = plt.subplots(figsize=figsize)
    sns.boxplot(data=df_feats.query(f"feature_type == '{feature_type}'"),
                y="feature_set", x="feat_score", ax=ax, palette=colors_dict)
    # scores are on the x axis and feature sets on the y axis (labels were swapped before)
    ax.set_xlabel(f"Feature scores ({label})", fontsize=12)
    ax.set_ylabel("Feature set", fontsize=12)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_title(f"{label.capitalize()} features", fontsize=15)
    return fig, ax


fig_feat_intra, ax = _plot_feature_scores("intra", "intracellular")
fig_feat_extra, ax = _plot_feature_scores("extra", "extracellular")


### Compare EAP distance

In [None]:
# Distance between the best EAP of each strategy and the experimental EAP, as stored in
# "best_eap_dist" (sum of absolute differences of row-normalized EAPs, not a cosine
# distance). Entries without it (e.g. "exp") are skipped.
df_eap_dist = pd.DataFrame(
    [{"strategy": strategy, "eap_dist": res["best_eap_dist"]}
     for strategy, res in opt_results_all.items() if "best_eap_dist" in res],
    columns=["strategy", "eap_dist"],
)
fig_cos, ax = plt.subplots()
sns.barplot(data=df_eap_dist, x="strategy", y="eap_dist", ax=ax, palette=colors_dict)
ax.set_ylabel("EAP distance (normalized L1)", fontsize=12)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_title("Extracellular difference", fontsize=15)

In [None]:
# Cell morphology (best "soma" parameters) drawn over the experimental EAP on the probe.
# Uses opt_results_all; the previously referenced `opt_results` is never defined.
fig, ax = plt.subplots(figsize=(10, 10))
ax = mu.plot_mea_recording(eap_exp, probe, ax=ax)
mf.plot_cell(cell=eva_extra.cell_model, sim=eva_extra.sim, ax=ax,
             param_values=opt_results_all["soma"]["best_params"],
             color_ais="g", color_axon="b", alpha=0.4, detailed=False)
fig_morph_eap = ax.get_figure()

In [None]:
if save_fig:
    morph_png = figure_folder / f"{cell_name}_eap_morph.png"
    fig_morph_eap.savefig(morph_png, dpi=300)