# Validation - distribution of optimized AP potentials

This notebook qualitatively evaluates how the membrane potentials and transmembrane currents of an action potential are distributed over the neuron morphology.

In [None]:
import pickle
import pandas as pd
import seaborn as sns
import sys
import shutil

import bluepyopt as bpopt
import bluepyopt.ephys as ephys

import matplotlib.pyplot as plt
from scipy.spatial import distance
import MEAutility as mu
import json
import time
import numpy as np
from pathlib import Path

import pandas as pd
import seaborn as sns
from scipy.spatial.distance import cosine

import multimodalfitting as mf

%matplotlib notebook

In [None]:
# Repository root, two directory levels above this notebook's working directory.
base_path = Path("..", "..")

In [None]:
# Matplotlib colour-cycle entry used for each optimization strategy.
colors_dict = dict(soma="C0", all="C1", sections="C2", single="C3")

# Default (width, height) for the per-position figures.
figsize = (10, 7)

## Load results and define model

In [None]:
# Probe layout handed to mf.define_electrode further down.
probe_type = "planar"
# Model to analyse; alternatives: "hay", "hay_ais", "hay_ais_hillock"
model_name = "hay_ais"

# <base_path>/cell_models/<model_name>
cell_models_folder = base_path.joinpath("cell_models")
model_folder = cell_models_folder.joinpath(model_name)

In [None]:
# change this with folder containing your pkl file
results_date = '220209'  # '211124' '220111' # 
result_folder = base_path / "results" / results_date

In [None]:
# Load the optimization runs and keep only the rows of the selected model.
# NOTE(review): `pkl_file_name` is not defined in any cell visible in this
# notebook; it must be set before this cell is run.
# NOTE: pickle executes arbitrary code on load -- only open trusted result files.
with open(result_folder / pkl_file_name, "rb") as f:
    data = pickle.load(f)
df_optimization = pd.DataFrame(data)

# Work on an explicit copy so the column assignments below modify `df_model`
# itself instead of a view of `df_optimization` (avoids pandas'
# SettingWithCopyWarning and silently lost writes).
df_model = df_optimization.query(f"model == '{model_name}'").copy()

# set strategy column: the extra strategy by default, "soma" for runs that
# were fitted on the "soma" feature set.
df_model.loc[:, "strategy"] = df_model["extra_strategy"].values.copy()
df_model.loc[df_model["feature_set"] == "soma", "strategy"] = "soma"

# Per-strategy optimization summary. `opt_results_all` stays undefined when
# the file is missing, and the later cells that run the protocols require it.
opt_results_file = result_folder / "opt_results.pkl"
if opt_results_file.is_file():
    with open(opt_results_file, "rb") as f:
        opt_results_all = pickle.load(f)

In [None]:
# Both BPO json files live in <model_folder>/fitting/efeatures.
efeatures_folder = model_folder / "fitting" / "efeatures"
protocols_file = efeatures_folder / "protocols_BPO_all.json"
features_file = efeatures_folder / "features_BPO_all.json"

In [None]:
# Two instances of the same ground-truth model: one with release=False and
# one with release=True (the latter provides the reference parameter values).
cell = mf.create_ground_truth_model(model_name=model_name, release=False)
cell_release = mf.create_ground_truth_model(model_name=model_name, release=True)

probe = mf.define_electrode(probe_type=probe_type)

# Names of the non-frozen parameters of the release=False model.
param_names = [p.name for p in cell.params.values() if not p.frozen]

# Release values of those parameters, keyed by parameter name.
params_release = {
    release_param.name: release_param.value
    for release_param in cell_release.params_by_names(param_names)
}

In [None]:
# Name of the protocol analysed in this notebook: it is passed to
# mf.create_evaluator as `protocols_with_lfp` and later used as the key into
# `eva_extra.fitness_protocols`.
# protocol_for_eap = "IDrest_300"
protocol_for_eap = "firepattern_120"

### Define more recording points 

In [None]:
# Keyword arguments that are forwarded unchanged (**extra_kwargs) to every
# mf.create_evaluator call in this notebook.
extra_kwargs = mf.utils.get_extra_kwargs()

In [None]:
# First evaluator: built without custom recordings; its simulator and cell
# model are used to plot the cell and to derive the extra recording points.
evaluator_options = dict(
    model_name=model_name,
    feature_set="extra",
    extra_strategy="all",
    all_protocols=True,
    protocols_with_lfp=protocol_for_eap,
)
eva_extra = mf.create_evaluator(**evaluator_options, **extra_kwargs)

In [None]:
# Plot the evaluator's cell model using the release parameter values.
# NOTE(review): `color_ais="g"` presumably colours the AIS green -- confirm in mf.plot_cell.
mf.plot_cell(eva_extra.cell_model, eva_extra.sim, param_values=params_release, color_ais="g")

In [None]:
# 2D coordinates of the additional recording points, one row per entry of
# `position_names` (same order). A middle-AIS point at [12.5, -15.6]
# ("ais_middle") is currently disabled.
positions = np.array([
    [-62, 828],
    [-3, 954],
    [-27, 546],
    [-27, 85],
    [134, -28],
    [-117, -189],
    [13.4, -28.7],
    [11.7, -1.2],
])
position_names = [
    "apical_distal_left",
    "apical_distal_right",
    "apical_middle",
    "apical_proximal",
    "basal_right",
    "basal_left",
    "ais_distal",
    "ais_proximal",
]

In [None]:
# Create the extra recordings for the release model from `positions`, labelled
# by `position_names`; they are attached to the `protocol_for_eap` protocol
# when the evaluator is re-created below.
extra_recordings = mf.utils.extra_recordings_from_positions(cell_release, eva_extra.sim, positions, position_names)

In [None]:
# Show the generated recordings as the cell output for inspection.
extra_recordings

In [None]:
# Re-create the evaluator, this time attaching the extra recordings to the
# protocol selected by `protocol_for_eap`.
evaluator_options = dict(
    model_name=model_name,
    feature_set="extra",
    extra_strategy="all",
    protocols_with_lfp=protocol_for_eap,
    all_protocols=True,
    extra_recordings={protocol_for_eap: extra_recordings},
)
eva_extra = mf.create_evaluator(**evaluator_options, **extra_kwargs)

## Load protocols and original features

In [None]:
# Pick the protocol registered under `protocol_for_eap` in the evaluator and
# print it for inspection.
protocol_to_run = eva_extra.fitness_protocols[protocol_for_eap]
print(protocol_to_run)

In [None]:
# opt_results = {}
# for strategy in np.unique(df_model.strategy):
#     opt_results[strategy] = {}
#     opt_df = df_model.query(f"strategy == '{strategy}'")
#     best_idx = np.argmin(opt_df.best_fitness)
#     params_sample = opt_df.iloc[best_idx]
#     params_dict = {k: v for k, v in zip(param_names, params_sample.best_params)}
#     opt_results[strategy]["best_fitness"] = params_sample.best_fitness
#     opt_results[strategy]["best_params"] = params_dict
#     print(f"{strategy} --  best fitness: {params_sample.best_fitness}")

In [None]:
# Simulate the selected protocol with the release (ground-truth) parameters
# and with the best parameter set of every optimization strategy.
#
# The protocol object is `protocol_to_run` (selected from
# `eva_extra.fitness_protocols`); the previously used name `idrest` is not
# defined anywhere in this notebook.
print("Computing RELEASE")
responses_release = eva_extra.run_protocol(protocol_to_run, params_release)

# `opt_results` is initialised here because the cell that used to create it
# is commented out; it maps strategy -> {"best_params", "best_seed", "responses"}.
opt_results = {}
for strategy in opt_results_all:
    best_seed = opt_results_all[strategy]["best_seed"]
    print(f"Computing '{strategy}' -- seed: {best_seed}")
    best_params = opt_results_all[strategy]["best_params"]
    responses = eva_extra.run_protocol(protocol_to_run, best_params)
    opt_results[strategy] = {
        "best_params": best_params,
        "best_seed": best_seed,
        "responses": responses,
    }

In [None]:
# Plot the full simulated responses of the "all" strategy.
# NOTE(review): `protocol_names` presumably restricts the plot to the "soma"
# and "ais" responses -- confirm in mf.plot_responses.
mf.plot_responses(opt_results["all"]["responses"], protocol_names=["soma", "ais"])

In [None]:
# Plot the cut-out responses of the "all" strategy.
# NOTE(review): the "responses_cut" entry is only filled in by the cutout cell
# further down (mf.utils.get_peak_cutout); on a fresh top-to-bottom run this
# cell raises KeyError, so run it after the cutouts have been computed.
mf.plot_responses(opt_results["all"]["responses_cut"], protocol_names=["soma", "ais"])

In [None]:
# Cutout window in ms: time kept before / after the peak.
ms_before, ms_after = 10, 50

In [None]:
# Cut the responses with mf.utils.get_peak_cutout, keeping `ms_before` /
# `ms_after` ms and requesting the average (average=True), for the release
# model and for every optimized strategy.
cutout_kwargs = dict(ms_before=ms_before, ms_after=ms_after, average=True)
responses_cut_release = mf.utils.get_peak_cutout(responses_release, **cutout_kwargs)

# Iterate over the strategies that actually have simulated responses (the keys
# of `opt_results`), consistent with the cells that fill and plot
# `opt_results`; looping over the strategies found in `df_model` raises
# KeyError whenever the two sets differ.
for strategy in opt_results:
    print(f"Cutting {strategy}")
    responses = opt_results[strategy]["responses"]
    responses_cut = mf.utils.get_peak_cutout(responses, **cutout_kwargs)
    opt_results[strategy]["responses_cut"] = responses_cut

In [None]:
# Keep only the release responses whose name contains one of the extra
# recording position labels.
responses_to_plot = [
    resp
    for resp in responses_release
    if any(pos_name in resp for pos_name in position_names)
]

In [None]:
# x-axis limits (ms) per recording position: the three distal/middle apical
# positions show the whole cutout window, the others a narrower one.
xlims = {
    name: [-ms_before, ms_after]
    for name in ("apical_distal_left", "apical_distal_right", "apical_middle")
}
xlims["apical_proximal"] = [-5, 20]
xlims.update({name: [-3, 10] for name in ("basal_right", "basal_left")})
xlims.update({name: [-2, 4] for name in ("ais_distal", "ais_middle", "ais_proximal")})

In [None]:
# position label -> trace figure (filled by the trace-plotting loop).
figs_traces = dict()

In [None]:
# Strategies listed here are skipped when plotting and computing distances.
exclude_strategies = list()

In [None]:
# When true, each figure gets the position label as its title.
add_title = bool(1)

In [None]:
# Column accumulators for the long-format distance table built at the end
# (one row per position / strategy pair).
distance_arr, strategy_arr, position_arr = [], [], []

for resp_name in responses_to_plot:
    fig, ax = plt.subplots(figsize=figsize)
    # Response names are dot-separated; the second field is the position label.
    position = resp_name.split(".")[1]

    gt_cut = responses_cut_release[resp_name]
    gt_voltage = gt_cut["voltage"]
    # Time axis relative to the cutout, spanning [-ms_before, ms_after].
    times = np.linspace(-ms_before, ms_after, len(gt_cut["time"]))
    ax.plot(times, gt_voltage, color="k", label="GT", lw=3)

    for strategy in opt_results:
        if strategy in exclude_strategies:
            continue
        resp_cut = opt_results[strategy]["responses_cut"]
        fitted_voltage = resp_cut[resp_name]["voltage"]
        ax.plot(
            times,
            fitted_voltage,
            color=colors_dict[strategy],
            label=strategy.upper(),
            alpha=0.7,
            lw=1.5,
        )
        # Distance to ground truth: summed absolute voltage difference
        # (scipy's `cosine` distance is not used here).
        dist = np.sum(np.abs(gt_voltage - fitted_voltage))
        position_arr.append(position)
        distance_arr.append(dist)
        strategy_arr.append(strategy)

    ax.set_xlabel("time (ms)", fontsize=20)
    ax.set_ylabel("$V_m$ (mV)", fontsize=20)
    ax.set_xlim(xlims[position])
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    if add_title:
        ax.set_title(f"{position}", fontsize=15)
    # Dashed marker at t = 0.
    ax.axvline(0, color="gray", ls="--")
    ax.legend(fontsize=20)
    figs_traces[position] = fig

df = pd.DataFrame(
    dict(strategy=strategy_arr, distance=distance_arr, position=position_arr)
)

In [None]:
# position label -> distance bar-plot figure.
figs_dists = dict()

In [None]:
# Left-to-right order of the strategies on the bar plots' x-axis.
order = [
    "soma",
    "all",
    "sections",
    "single",
]

In [None]:
# One bar plot per recording position, comparing each strategy's distance to
# the ground-truth trace.
for position in np.unique(df["position"]):
    print(position)
    fig, ax = plt.subplots(figsize=figsize)
    df_pos = df.query(f"position == '{position}'")
    sns.barplot(data=df_pos, x="strategy", y="distance", ax=ax, order=order)
    ax.set_xlabel("Strategy", fontsize=20)
    # The "distance" column holds the summed absolute voltage difference to
    # the ground truth (the cosine distance is commented out where it is
    # computed), so the axis must not be labelled as a cosine distance.
    ax.set_ylabel("summed abs. $V_m$ diff. (mV)", fontsize=20)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    if add_title:
        ax.set_title(f"{position}", fontsize=15)
    figs_dists[position] = fig

In [None]:
# Overview figure: the cell morphology with the extra recording positions.
fig_cell, ax_cell = plt.subplots(figsize=(7, 10))
mf.plot_cell(
    eva_extra.cell_model,
    eva_extra.sim,
    param_values=params_release,
    ax=ax_cell,
    alpha=0.5,
    color="black",
    detailed=True,
)

# One marker per recording position, coloured by its index along "tab20".
cmap = "tab20"
cm = plt.get_cmap(cmap)
marker_style = dict(marker="o", alpha=0.8, markersize=3)
for i, (pos, pos_name) in enumerate(zip(positions, position_names)):
    ax_cell.plot(*pos, label=pos_name, color=cm(i / len(positions)), **marker_style)

In [None]:
# Toggle for writing the figures to disk, and the output folder.
save_fig = True
# figure_folder = Path(".") / "figures_hay_ais"

# NOTE(review): machine-specific absolute path -- adjust it (or switch to the
# relative folder above) when running on another machine.
figure_folder = Path("/Users/abuccino/Documents/Submissions/papers/multimodal/hay_ais/ap_distr")

if save_fig:
    # parents=True: also create missing intermediate folders; without it,
    # mkdir raises FileNotFoundError when any parent of this path is missing.
    figure_folder.mkdir(parents=True, exist_ok=True)

In [None]:
# Write the cell overview plus, for every position, its trace and distance
# figures as transparent PDFs into `figure_folder`.
if save_fig:
    fig_cell.savefig(figure_folder / "cell.pdf", transparent=True)

    for position, fig_trace in figs_traces.items():
        fig_trace.savefig(figure_folder / f"trace_{position}.pdf", transparent=True)
        figs_dists[position].savefig(figure_folder / f"dist_{position}.pdf", transparent=True)
        figs_dists[position].savefig(figure_folder / f"dist_{position}.pdf", transparent=True)        