# Test model example 

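This notebook tests the different cell models (HAY, HAY_AIS, HAY_AIS_HILLOCK), selected via `model_name`, on a simple step-current stimulus and plots the resulting intracellular and extracellular responses.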

In [None]:
import json
import sys
import os
import time
import numpy as np
from pathlib import Path

import matplotlib.pyplot as plt

import bluepyopt as bpopt
import bluepyopt.ephys as ephys

import MEAutility as mu

import shutil

import multimodalfitting as mf

%matplotlib widget

In [None]:
# Root folder of the project, resolved relative to the notebook's working
# directory (assumes the notebook lives one level below the root — confirm).
base_dir = Path("..")

In [None]:
# Choice of probe geometry and cell model for this run.
probe_type = "planar"
model_name = "hay_ais"  # one of: "hay", "hay_ais", "hay_ais_hillock"

# Per-model folder; it is also passed to the simulator as the mechanisms directory.
cell_models_folder = base_dir / "cell_models"
model_folder = cell_models_folder / model_name

In [None]:
# Extra recording locations (presumably along the axon initial segment, AIS);
# they are attached to the "TestStep" protocol below.
ais_recordings = mf.utils.get_ais_extra_recordings()

In [None]:
# Instantiate the model with release=False; it is only used here to inspect
# its parameters, and the names of the free (non-frozen) ones are collected.
cell_opt = mf.create_ground_truth_model(release=False, model_name=model_name)
param_names = [name for name, param in cell_opt.params.items() if not param.frozen]

In [None]:
# Display the names of the free (non-frozen) parameters.
param_names

In [None]:
# Display all model parameters, frozen and free.
cell_opt.params

In [None]:
# Simulation setup: probe, cell model wired to that probe, a step protocol with
# the extra AIS recordings, and the LFPy-based simulator.
probe = mf.define_electrode(probe_type=probe_type)  # could also be defined externally

cell = mf.create_ground_truth_model(
    model_name=model_name,
    release=True,
    electrode=probe,
)

step_extra_recordings = {"TestStep": ais_recordings}
protocols = mf.define_test_step_protocol(
    step_amplitude=0.5,
    tot_duration=2000,
    step_duration=1500,
    probe=probe,
    extra_recordings=step_extra_recordings,
)

sim = ephys.simulators.LFPySimulator(
    cvode_active=True,
    mechanisms_directory=model_folder,
)

In [None]:
# Run the step protocol with default parameters and report wall-clock time.
tic = time.time()
responses = protocols["TestStep"].run(cell, param_values={}, sim=sim)
elapsed = time.time() - tic
print(f"Elapsed time: {elapsed}")

In [None]:
# Plot the responses returned by the step protocol.
fig = mf.plot_responses(responses)

In [None]:
# Compute the EAP (extracellular action potential) from the "TestStep" responses.
eap = mf.calculate_eap(responses, protocols=protocols, protocol_name="TestStep")

In [None]:
# Row index of the global minimum of the 2-D EAP array; axis 0 is used as the
# channel index in the plots, so this is the channel with the most negative sample.
flat_argmin = np.argmin(eap)
max_chan, _ = np.unravel_index(flat_argmin, eap.shape)

In [None]:
# The max-channel marker must be drawn on `ax_eap`, which is only created by the
# EAP/probe figure cell further down; calling `ax_eap.plot` here fails with a
# NameError on a fresh top-to-bottom run. The marker is plotted after that
# figure has been built.

In [None]:
# (A stray, incomplete expression `pro` was removed here: it is not defined
# anywhere and raised a NameError when running the notebook top to bottom.)

In [None]:
# EAP of every channel drawn on the probe layout (black), the max channel
# highlighted in green, and the cell drawn on the same axes
# (AIS in green, myelin in blue).
fig_eap, ax_eap = plt.subplots()
ax_eap = mu.plot_probe(probe, ax=ax_eap, type="planar")
ax_eap = mu.plot_mea_recording(eap, probe, ax=ax_eap, colors="k", lw=0.5)
ax_eap = mu.plot_mea_recording(
    eap,
    probe,
    ax=ax_eap,
    channels=[max_chan],
    colors="g",
    lw=2,
)
mf.plot_cell(
    cell,
    sim=sim,
    ax=ax_eap,
    detailed=True,
    color_ais="g",
    color_myelin="b",
)

In [None]:
# Export the EAP/probe figure (transparent background, 300 dpi).
fig_eap.savefig("eap_probe_zoom.pdf", transparent=True, dpi=300)

In [None]:
# Mark the max channel on the EAP/probe axes (its position minus the last coordinate).
chan_xy = probe.positions[max_chan, :-1]
ax_eap.plot(*chan_xy, color="k", marker="o")

In [None]:
# Stand-alone plot of the cell (AIS in green, myelin in blue).
cell_plot_style = dict(color_ais="g", color_myelin="b", detailed=True)
fig, ax = plt.subplots()
mf.plot_cell(cell, sim=sim, ax=ax, **cell_plot_style)

In [None]:
# Export the stand-alone cell figure. The filename follows `model_name`
# ("hay_ais_zoom.pdf" for the default model) so runs of other models do not
# overwrite or mislabel it.
fig.savefig(f"{model_name}_zoom.pdf", dpi=300, transparent=True)

In [None]:
# Display the names of all recorded responses.
responses.keys()

In [None]:
# Voltage recordings to compare: soma, middle of the AIS, end of the AIS.
recording_sites = ("soma", "ais_mid_v", "ais_end_v")
responses_to_plot = [f"TestStep.{site}.v" for site in recording_sites]

In [None]:
# Overlay soma and AIS voltage traces in the 554-562 ms window; axes are hidden
# and replaced by scale bars (1 ms horizontal, 20 mV vertical).
labels = ["soma", "AIS middle", "AIS distal"]
fig_ap, ax = plt.subplots()
for resp_key, label in zip(responses_to_plot, labels):
    trace = responses[resp_key]
    ax.plot(trace["time"], trace["voltage"], label=label, lw=2, alpha=0.8)
ax.set_xlim(554, 562)
ax.legend()
ax.axis("off")
ax.plot([555, 556], [-80, -80], color="k")  # time scale bar
ax.plot([555, 555], [-80, -60], color="k")  # voltage scale bar

In [None]:
# Export the soma/AIS voltage-trace figure. It lives in `fig_ap`; `fig` is the
# stand-alone cell figure at this point. The filename follows `model_name`.
fig_ap.savefig(f"{model_name}_ap_responses.pdf", dpi=300, transparent=True)

In [None]:
# Export the soma/AIS voltage traces (held in `fig_ap`, not `fig`).
fig_ap.savefig("ais_traces.pdf")

## (Optional) Save the EAP for comparison across models

In [None]:
# Store the EAP array as eap/eap_<model_name>.npy so models can be compared later.
eap_folder = Path("eap")
eap_folder.mkdir(exist_ok=True)
eap_path = eap_folder / f"eap_{model_name}.npy"
np.save(eap_path, eap)



# Plot ECODE responses

In [None]:
# Folder with the stored fitting responses: <model_folder>/fitting/responses.
response_folder = model_folder.joinpath("fitting", "responses")

In [None]:
import pandas as pd

In [None]:
# Load the stored responses: one sub-folder per protocol, each holding one
# CSV file per sweep, read with pandas.
# Path.iterdir() yields entries in arbitrary order, but sweeps are later picked
# by index, so entries are sorted by path for a reproducible order. The sort is
# lexicographic: e.g. "x_10" sorts before "x_2".
# Non-directory entries at the protocol level (stray files) are skipped, because
# calling .iterdir() on them would fail.
ecode_response_dict = {}
for protocol_folder in sorted(response_folder.iterdir()):
    if not protocol_folder.is_dir():
        continue
    ecode_response_dict[protocol_folder.name] = [
        pd.read_csv(sweep_file) for sweep_file in sorted(protocol_folder.iterdir())
    ]

In [None]:
# Display the names of the loaded protocols.
ecode_response_dict.keys()

In [None]:
# One figure per protocol, in this order: voltage (top) and injected current
# (bottom) of every sweep in grey, with axes hidden and manual scale bars.
protocol_order = ['firepattern', 'IV', 'APWaveform', 'IDrest', 'HyperDepol', 'sAHP', 'PosCheops']

# Length of the horizontal (time) scale bar per protocol; all others use the default.
time_scalebar = {'APWaveform': 25, 'PosCheops': 1000}
default_time_scalebar = 100

figs_responses = {}

for protocol in protocol_order:
    fig, axs = plt.subplots(nrows=2, figsize=(10, 10), sharex=True)
    ax_v, ax_i = axs
    for sweep in ecode_response_dict[protocol]:
        ax_v.plot(sweep["time"], sweep["voltage"], alpha=0.8, color="grey")
        ax_i.plot(sweep["time"], sweep["current"], alpha=0.8, color="grey", lw=2)
    ax_v.axis("off")
    ax_i.axis("off")
    scalebar = time_scalebar.get(protocol, default_time_scalebar)
    ax_v.plot([100, 100 + scalebar], [-100, -100], color="k")  # time scale bar
    ax_v.plot([100, 100], [-100, -110], color="k")  # voltage scale bar (10 units)
    ax_i.plot([100, 100], [0.1, 0.2], color="k")  # current scale bar (0.1 units)
    figs_responses[protocol] = fig

In [None]:
# Export each protocol figure as <protocol>.pdf with a transparent background.
for protocol_name, protocol_fig in figs_responses.items():
    protocol_fig.savefig(f"{protocol_name}.pdf", transparent=True)

In [None]:
# Pick the first two IDrest sweeps as examples.
idrest_sweeps = ecode_response_dict["IDrest"]
sweep_example0 = idrest_sweeps[0]
sweep_example1 = idrest_sweeps[1]

In [None]:
# Overlay the injected current of the two example IDrest sweeps
# (one plot call with two x/y pairs draws two lines, as two calls would).
plt.figure()
plt.plot(
    sweep_example0["time"], sweep_example0["current"],
    sweep_example1["time"], sweep_example1["current"],
)