# 3) Analyze and evaluate optimization output - training protocols

## cell1_211006_3148

This final notebook loads the `runs.pkl` file created in notebook 2 and analyzes:

- the distance between different feature sets in the feature space
- the distance between different feature sets in the extracellular signals

In [None]:
import pickle
import pandas as pd
import seaborn as sns
import sys
import shutil

import bluepyopt as bpopt
import bluepyopt.ephys as ephys
import neuroplotlib as nplt

import matplotlib.pyplot as plt
from scipy.spatial import distance
import MEAutility as mu
import json
import time
import numpy as np
from pathlib import Path
from pprint import pprint

import multimodalfitting as mf

%matplotlib notebook

In [None]:
# Base directory from which the experimental-data, cell-model and results paths below are built
base_dir = Path("../../..")

In [None]:
# Cell to analyze: it selects the experimental data folder, the model folder,
# and the rows of the optimization results used in this notebook
cell_name = "cell1_211006_3148" # "cell1_211006_3148" | "cell1_211011_3436"
model_name = cell_name
# Folder that is scanned later for the patch-clamp files of each run
ephys_dir = base_dir / "experimental_data" / cell_name / "patch_data"

# Folder of this cell's model; the probe file and the EAP template are read from it
cell_models_folder = base_dir / "cell_models"
model_folder = cell_models_folder / cell_name

In [None]:
# Sub-folder of <base_dir>/results that contains your pkl file; change it to match your run
results_date = '220209'  # alternatives: '211124' | '220111'
result_folder = base_dir / "results" / results_date

In [None]:
# abd (axon_bearing_dendrite) option: passed to the model and evaluator constructors
# and used to filter the optimization runs
abd = False

In [None]:
# When True, the experimental figures are written as PNG files into `figure_folder`
save_fig = False
figure_folder = Path(".") / f"figures_{cell_name}"

# Create the output folder only if figures are going to be saved
if save_fig:
    figure_folder.mkdir(exist_ok=True)

In [None]:
# Plot color for each optimization strategy ("soma" plus the extracellular strategies)
colors_dict = {"soma": "C0",
               "all": "C1",
               "sections": "C2",
               "single": "C3"}
# Feature set associated with each strategy: only "soma" uses the "soma" set, the others use "extra"
feature_sets = {"soma": "soma",
                "all": "extra",
                "sections": "extra",
                "single": "extra"}
# Figure size used for the larger plots below
figsize = (10, 7)

# Define cell model and load optimization output

In [None]:
# Build the cell model twice: with default options and with release=True
cell = mf.create_experimental_model(model_name=cell_name, abd=abd)
cell_sample = mf.create_experimental_model(model_name=cell_name, release=True, abd=abd)

# Electrode description read from the model's efeatures folder
probe = mf.define_electrode(probe_file=model_folder / "fitting" / "efeatures" / "probe_BPO.json")
# Names of the non-frozen parameters; later zipped with each run's `best_params`,
# so the order here is assumed to match the order used by the optimizer (TODO confirm)
param_names = [param.name for param in cell.params.values() if not param.frozen]

In [None]:
# Name of the pickle file (inside `result_folder`) holding the optimization runs
pkl_file_name = "runs.pkl"

In [None]:
# Load the optimization runs. The file is opened in a context manager so the handle is closed.
# NOTE: pickle can execute arbitrary code on load; only open result files you produced yourself.
with open(result_folder / pkl_file_name, 'rb') as pkl_file:
    data = pickle.load(pkl_file)
df_optimization = pd.DataFrame(data)
# Keep only the runs for the selected cell and abd option. The .copy() makes df_model an
# independent frame, so writing new columns into it does not raise SettingWithCopyWarning.
df_model = df_optimization.query(f"model == '{cell_name}' and abd == {abd}").copy()

In [None]:
# set strategy column: the extracellular strategy of each run, except for runs fitted on the
# "soma" feature set, which are labelled "soma".
# `assign` rebinds df_model to a new frame instead of writing into a filtered one, which
# avoids SettingWithCopyWarning and makes this cell safe to re-run.
is_soma = df_model["feature_set"] == "soma"
df_model = df_model.assign(strategy=df_model["extra_strategy"].mask(is_soma, "soma"))

In [None]:
fig, ax = plt.subplots()
min_evals = 3000

# Index labels of the runs whose evaluation count exceeded `min_evals`
keep_idxs = []
for idx, row in df_model.iterrows():
    long_enough = max(row["nevals"]) > min_evals
    if long_enough:
        keep_idxs.append(idx)
    # Solid, thicker line for runs past the threshold; dashed, thinner line otherwise
    line_style, line_width = ('-', 0.8) if long_enough else ('--', 0.5)
    ax.plot(row["nevals"],
            row["logbook"].select("min"),
            color=colors_dict[row["strategy"]],
            ls=line_style,
            lw=line_width,
            alpha=0.75)

ax.set(title="Min fitness", xlabel="Neval", ylabel="Min fitness", yscale='log')
for side in ("top", "right"):
    ax.spines[side].set_visible(False)

# Load protocols and experimental features

In [None]:
# Keyword arguments forwarded to create_evaluator and calculate_eap below
extra_kwargs = mf.utils.get_extra_kwargs()
# Override the default cut window (presumably ms before/after the spike peak — TODO confirm)
extra_kwargs["ms_cut"] = [2, 5]

In [None]:
# Display the resulting keyword arguments for inspection
extra_kwargs

In [None]:
# Protocol whose responses are used to compute the extracellular action potential (EAP)
protocol_for_eap = "IDrest_300"

In [None]:
# Evaluator built with the "extra" feature set and the "all" strategy; it is used below to
# simulate the best parameters of every strategy and to score their features
eva_extra = mf.create_evaluator(
    model_name=model_name,
    feature_set="extra",
    extra_strategy="all",
    protocols_with_lfp=protocol_for_eap,
    abd=abd,
    **extra_kwargs
)

## Load experimental responses

In [None]:
from bluepyefe.extract import read_recordings, extract_efeatures_at_targets, compute_rheobase,\
    group_efeatures, create_feature_protocol_files, convert_legacy_targets
from bluepyefe.plotting import plot_all_recordings_efeatures

from multimodalfitting.efeatures_extraction import build_wcp_metadata, wcp_reader, get_ecode_targets, \
    ecodes_wcp_timings

In [None]:
# define timings for this experiment
# One entry per e-code: 'ton'/'toff' plus, for some protocols, intermediate time points
# ('tmid', 'tmid2', 't1'..'t4'). Values are presumably in ms — TODO confirm.
# The keys are also used below to match patch-clamp file names and to list all protocols.
ecodes_cell_timings = {
    "IDthres": {
        'ton': 400,
        'toff': 670
    },
    "firepattern": {
        'ton': 500,
        'toff': 4100
    },
    "IV": {
        'ton': 400,
        'toff': 3400
    },
    "IDrest": {
        'ton': 400,
        'toff': 1750
    },
    "APWaveform": {
        'ton': 350,
        'toff': 400
    },
    "HyperDepol": {
        'ton': 400,
        'toff': 1120,
        'tmid': 850
    },
    "sAHP": {
        'ton': 400,
        'toff': 1325,
        'tmid': 650,
        'tmid2': 875
    },
    "PosCheops": {
        'ton': 1000,
        't1': 9000,
        't2': 10500,
        't3': 14500,
        't4': 16000,
        'toff': 18660
    }
}

In [None]:
# Patch-clamp runs to include for each recorded cell
if cell_name == "cell1_211006_3148":
    runs = [1, 2, 3, 4, 5]  # run1 --> different rheobase
elif cell_name == "cell1_211011_3436":
    runs = [3, 4, 5, 6]
else:
    # Fail here with a clear message instead of a NameError on `runs` further down
    raise ValueError(f"No runs defined for cell '{cell_name}'")

ecode_names = list(ecodes_cell_timings.keys())

# List the directory once instead of once per (run, ecode) pair
patch_files = list(ephys_dir.iterdir())

# One dict per run, mapping each e-code name to the matching patch file.
# Matching is by substring: the file name must contain "run<N>" and the lower-cased e-code.
# NOTE(review): "run1" would also match e.g. "run10"; if several files match, the last one
# listed wins — confirm the file naming makes matches unique.
files_list = []
for run in runs:
    rep_dict = {}
    for ecode in ecode_names:
        for patch_file in patch_files:
            if f"run{run}" in patch_file.name and ecode.lower() in patch_file.name:
                rep_dict[ecode] = patch_file
    files_list.append(rep_dict)

In [None]:
# Build the metadata for the selected files; repetitions are kept under a single cell
# (repetition_as_different_cells=False)
files_metadata = build_wcp_metadata(cell_id=cell_name, 
                                    files_list=files_list, 
                                    ecode_timings=ecodes_cell_timings, 
                                    repetition_as_different_cells=False)
# Show the metadata entry of this cell for inspection
pprint(files_metadata[cell_name])

In [None]:
# Read the recordings described by `files_metadata`, using the WCP reader
cells = read_recordings(
    files_metadata=files_metadata,
    recording_reader=wcp_reader
)

In [None]:
# define target features for different protocols, derived from the e-code timings
targets = get_ecode_targets(ecodes_cell_timings)

In [None]:
# Apply the same tolerance to every target, overriding whatever get_ecode_targets set
global_tolerance = 30
for target in targets:
    target["tolerance"] = global_tolerance

In [None]:
# Value of the "majority" rheobase setting: 0.4 for cell1_211006_3148, 0.2 for any other cell
majority = 0.4 if cell_name == "cell1_211006_3148" else 0.2

In [None]:
# Compute the rheobase from the IDthres protocol with the "majority" strategy;
# the result is read back from `cells[0].rheobase` in the next cell
compute_rheobase(
    cells, 
    protocols_rheobase=['IDthres'],
    rheobase_strategy="majority",
    rheobase_settings={"majority": majority}
)

In [None]:
# Report the rheobase of the first (and presumably only) cell
print(f"Cell rheobase: {cells[0].rheobase}")

In [None]:
# Group the recordings into protocols for the given targets, using the global rheobase
protocols = group_efeatures(cells, targets, use_global_rheobase=True)

In [None]:
# Names of the protocols used by the evaluator; only these go into `responses_experimental`
protocols_opt = list(eva_extra.fitness_protocols.keys())
print(protocols_opt)

In [None]:
# All e-code names defined for this experiment (same keys as `ecode_names` above)
all_protocols = list(ecodes_cell_timings.keys())
print(all_protocols)

### Build BPO response dicts

In [None]:
# Build one response dict per experimental run:
#   - responses_all:          every protocol recorded in that run
#   - responses_experimental: only the protocols used by the evaluator (`protocols_opt`)
# The number of runs is taken from the first protocol's recordings.
responses_experimental = []
responses_all = []
num_runs = len(protocols[0].recordings)
for run in range(num_runs):
    print(f"Populating responses for run {run}")
    response_dict = {}
    response_all_dict = {}
    for protocol in protocols:
        for i, rec in enumerate(protocol.recordings):
            # Only the recording at position `run` belongs to this run
            if i != run:
                continue
            response = bpopt.ephys.responses.TimeVoltageResponse(
                name=protocol.name, time=rec.t, voltage=rec.voltage
            )
            response_key = f"{protocol.stimulus_name}.soma.v"
            response_all_dict[response_key] = response
            if protocol.stimulus_name in protocols_opt:
                response_dict[response_key] = response
    responses_experimental.append(response_dict)
    responses_all.append(response_all_dict)

In [None]:
# Display the per-run response dicts for inspection (the output can be long)
responses_all

In [None]:
# Protocols shown in the single-run plot below; `titles` is an alias of the same list
responses_to_plot = ["APWaveform_290", "IDrest_250", "firepattern_200", "PosCheops_300"]
titles = responses_to_plot

In [None]:
# Plot the selected protocols for the run at index 2 only
fig_exp_intra = mf.plot_responses(responses_all[2], return_fig=True, 
                                  protocol_names=responses_to_plot)

In [None]:
# Overlay all runs for every recorded protocol (rebinds `fig_exp_intra`)
fig_exp_intra = mf.plot_multiple_responses(responses_all, return_fig=True, 
                                           labels=[f"run{i}" for i in range(num_runs)])

In [None]:
# Overlay all runs for the evaluator's protocols only; this is the `fig_exp_intra`
# that gets saved below, since it is the last assignment to that name
fig_exp_intra = mf.plot_multiple_responses(responses_experimental, return_fig=True, 
                                           labels=[f"run{i}" for i in range(num_runs)])

In [None]:
# Experimental EAP template; the division by 1000 is presumably a unit conversion — TODO confirm
eap_exp = np.load(model_folder / "fitting" / "efeatures" / "template_BPO.npy") / 1000

In [None]:
# Plot the experimental template on the probe layout and keep the figure for saving
ax = mu.plot_mea_recording(eap_exp, probe)
fig_exp_extra = ax.get_figure()

In [None]:
# Peak absolute amplitude of the experimental EAP, reused as `vscale` in the comparison plots
vscale = np.max(np.abs(eap_exp))

In [None]:
# Save the experimental intracellular and extracellular figures when requested
if save_fig:
    fig_exp_intra.savefig(figure_folder / "exp_intra.png", dpi=300)
    fig_exp_extra.savefig(figure_folder / "exp_extra.png", dpi=300)

# Compute and plot best responses

In [None]:
# Value substituted for a feature that could not be computed (calculate_feature returned None)
max_feature_value = 50
# Per-strategy results, filled by the following cells:
# best_fitness, best_params, responses, eap, eap_dist, features
opt_results = {}

In [None]:
# For every strategy, pick the run with the lowest best_fitness and store its parameters
for strategy in np.unique(df_model.strategy):
    opt_df = df_model.query(f"strategy == '{strategy}'")
    # Position (within opt_df) of the run with the lowest fitness
    best_idx = np.argmin(opt_df.best_fitness)
    params_sample = opt_df.iloc[best_idx]
    # Pair each non-frozen parameter name with the corresponding optimized value
    params_dict = dict(zip(param_names, params_sample.best_params))
    opt_results[strategy] = {
        "best_fitness": params_sample.best_fitness,
        "best_params": params_dict,
    }
    print(f"{strategy} --  best fitness: {params_sample.best_fitness}")

In [None]:
# Simulate the best parameter set of every strategy with the same evaluator, then compute
# its EAP and the cosine distance to the experimental template
for strategy in np.unique(df_model.strategy):
    print(f"Simulating best '{strategy}'")
    responses = eva_extra.run_protocols(eva_extra.fitness_protocols.values(), 
                                        param_values=opt_results[strategy]["best_params"])
    opt_results[strategy]["responses"] = responses
    eap = mf.utils.calculate_eap(responses=responses, protocols=eva_extra.fitness_protocols, 
                                 protocol_name=protocol_for_eap, align_extra=True, **extra_kwargs)
    opt_results[strategy]["eap"] = eap   
    # Both arrays are flattened before the comparison, so they must have the same number of elements
    eap_dist = distance.cosine(eap_exp.ravel(), eap.ravel())
    opt_results[strategy]["eap_dist"] = eap_dist

In [None]:
# Compute value and score of every single-feature objective for the best model of each strategy.
# Objectives with more than one feature are reported and skipped. They no longer leave an
# empty entry in `features_best`, which used to raise KeyError when the scores were collected.
for strategy in np.unique(df_model.strategy):
    responses = opt_results[strategy]["responses"]
    features_best = {}
    for obj in eva_extra.fitness_calculator.objectives:
        if len(obj.features) > 1:
            print(f"More than one feature for objective: {obj.name}")
            continue
        feat = obj.features[0]
        feat_value = feat.calculate_feature(responses)
        if feat_value is None:
            # Feature could not be computed: fall back to the penalty value
            feat_value = max_feature_value
        if "MEA" not in feat.name:
            # Intracellular feature: absolute deviation in units of the experimental std
            feat_score = np.abs(feat.exp_mean - feat_value) / feat.exp_std
        else:
            # Extracellular (MEA) feature: cosine distance to the experimental values
            # NOTE(review): if feat_value fell back to the scalar penalty above, this call
            # likely fails on mismatched shapes — confirm how missing MEA features should score.
            feat_score = np.abs(distance.cosine(feat.exp_mean, feat_value))
        features_best[feat.name] = {"value": feat_value, "score": feat_score}
    opt_results[strategy]["features"] = features_best

In [None]:
# Plot, for every strategy, the best model against the experimental data:
# one figure with the intracellular responses and one with the EAPs on the probe layout.
# The experimental reference is the run at index 1.
response_experimental_plot = responses_experimental[1]
for strategy in np.unique(df_model.strategy):
    # NOTE: this rebinds `responses_to_plot`, which earlier held a list of protocol names
    responses_to_plot = [response_experimental_plot, opt_results[strategy]["responses"]]
    eap = opt_results[strategy]["eap"]
    colors = ["k", colors_dict[strategy]]
    labels = ["GT", strategy.upper()]
    # Each iteration overwrites `fig_extra_intra_single` and `fig`; only the last ones stay bound
    fig_extra_intra_single = mf.plot_multiple_responses(responses_to_plot, 
                                                        colors=colors, return_fig=True, 
                                                        labels=labels)
    fig, ax_extra = plt.subplots(figsize=figsize)
    # Experimental and simulated EAPs share the same `vscale` so amplitudes are comparable
    ax_extra = mu.plot_mea_recording(eap_exp, probe, vscale=vscale, lw=1, ax=ax_extra)
    # Label only the last drawn line of each recording so the legend has one entry per trace set
    ax_extra.get_lines()[-1].set_label("GT")
    ax_extra = mu.plot_mea_recording(eap, probe, ax=ax_extra, vscale=vscale, 
                                     colors=colors_dict[strategy], lw=1)
    ax_extra.get_lines()[-1].set_label(strategy.upper())
    ax_extra.set_title("EAP", fontsize=15)
    ax_extra.legend()
# Alternative (disabled): plot the EAPs directly from the responses
#     ax_extra = mf.plot_multiple_eaps(responses_to_plot, 
#                                      eva_extra.fitness_protocols, probe,
#                                      protocol_name=protocol_for_eap, 
#                                      colors=colors, labels=labels)

In [None]:
# Report the cosine distance between simulated and experimental EAP for each strategy
for strategy in opt_results:
    print(f"Cosine dist {strategy}: {opt_results[strategy]['eap_dist']}")

## Compare best-fitted models

In [None]:
# One row per strategy, one column per key of its result dict; the strategy name is
# also copied from the index into a regular column so it can be used as a plot axis
df_test = pd.DataFrame.from_dict(opt_results, orient="index")
df_test["strategy"] = df_test.index

### Compare features

In [None]:
# Flatten the per-strategy feature scores into a long-format table: one row per
# (strategy, feature), tagged "extra" when the feature name contains "MEA", else "intra"
feature_name_array = []
feature_set_array = []
feature_score_array = []
feature_type_array = []

for strategy, res in opt_results.items():
    for feat_name, feat_dict in res["features"].items():
        feature_set_array.append(strategy)
        feature_name_array.append(feat_name)
        feature_type_array.append("extra" if "MEA" in feat_name else "intra")
        feature_score_array.append(feat_dict["score"])

df_feats = pd.DataFrame({
    "feature_set": feature_set_array,
    "feat_name": feature_name_array,
    "feat_score": feature_score_array,
    "feature_type": feature_type_array,
})

In [None]:
# Horizontal box plots of the feature scores per strategy, one figure for intracellular
# and one for extracellular features. The scores are on the x-axis (x="feat_score"), so
# the "Feature scores" label belongs on x; the y-axis holds the strategies.
fig_feat_intra, ax = plt.subplots(figsize=figsize)

sns.boxplot(data=df_feats.query("feature_type == 'intra'"), y="feature_set", x="feat_score", ax=ax)
ax.set_xlabel("Feature scores (intracellular)", fontsize=12)
ax.set_ylabel("Strategy", fontsize=12)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_title("Intracellular features", fontsize=15)

fig_feat_extra, ax = plt.subplots(figsize=figsize)

sns.boxplot(data=df_feats.query("feature_type == 'extra'"),
            y="feature_set", x="feat_score", ax=ax)
ax.set_xlabel("Feature scores (extracellular)", fontsize=12)
ax.set_ylabel("Strategy", fontsize=12)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.set_title("Extracellular features", fontsize=15)


### Compare EAP distance

In [None]:
# Bar plot of the cosine distance between simulated and experimental EAP, one bar per strategy
fig_cos, ax = plt.subplots()
sns.barplot(data=df_test, x="strategy", y="eap_dist", ax=ax)
ax.set_title("Extracellular difference", fontsize=15)
ax.set_ylabel("Cosine distance", fontsize=12)
for side in ("top", "right"):
    ax.spines[side].set_visible(False)