# PV-IN Axon Dysfunction in Dravet syndrome


### build NEURON mod files

In [None]:
# compile the NEURON mechanism (.mod) files found in ./mechanisms
!nrnivmodl mechanisms

### imports and setup

In [None]:
%load_ext autoreload
%autoreload 2

import os
from functools import lru_cache
from itertools import product
from pathlib import Path
from typing import Iterable, Union

from neuron import h
from neuron import gui
from tqdm import tqdm
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from pv_nrn import get_pv, get_pv_params, reset_biophys

from src.constants import *
from src.settings import *
from src.data import get_cached_df, get_file_path, set_cache_root, get_cache_root, wide_to_long, concise_df
from src.measure import get_max_propagation, get_ap_times, calculate_failures
from src.run import get_trace, set_relative_nav11bar, set_nrn_prop
from src.utils import get_key, format_nav_loc, perc_decrease, str_to_tuple, nearest_idx, nearest_value, get_last_sec
from src.vis import plot_voltage_trace, set_default_style, save_fig, get_pulse_times, get_pulse_xy

# Simulation results are cached on disk. Set the PV_CACHE_ROOT environment variable to
# use another location instead of editing the (machine-specific) default below.
set_cache_root(os.environ.get("PV_CACHE_ROOT", "D:\\.cache"))

# load the compiled NEURON mechanisms
# NOTE(review): Windows-specific DLL name — confirm it matches the nrnivmodl build output
h.nrn_load_dll("nrnmech_def.dll")

# load PV template
h.load_file('PV_template.hoc')
h.load_file('PV_template_orig.hoc')

# values as per optimisation by BBP
h.celsius = 34
h.v_init = -80
h.check_simulator()

set_default_style()

# Vary the fraction of Nav1.1 in specific sections


## Voltage traces

### create method for running multiple simulations

In [None]:
def run_sims(pv, stims, nav_loc_changes, fractions, dur, load=False, arrow=False, reset_biophys=reset_biophys):
    """Simulate (or reuse cached results for) every stimulus x NaV1.1 location x fraction.

    A simulation is only run when its cache file (``get_file_path(key)``) does not exist.

    Args:
        pv: PV interneuron model.
        stims: iterable of ``(amp, freq)`` stimulation tuples.
        nav_loc_changes: NaV1.1 locations to scale; each entry is either a location
            name (str) or an iterable of names that are scaled together.
        fractions: NaV1.1 fractions relative to the baseline returned by ``reset_biophys``.
        dur: stimulation duration forwarded to the key/caching helpers.
        load: when True, collect every result in the returned dict.
        arrow: when True, also write a long-format copy with metadata columns to an
            ``.arrow`` (feather) file if that file does not exist yet.
        reset_biophys: callable restoring baseline biophysics; its return value is
            indexed by location to obtain the NaV1.1 baseline.

    Returns:
        dict keyed by simulation key; each value holds the long-format frame ("df"),
        the metadata fields and the "APCount" object. Empty unless ``load`` is True.
    """
    def _metadata(fraction, location, current, stim_freq):
        # fields shared by the feather export and the returned result entries
        return {
            NAV_FRAC_LABEL: fraction,
            NAV_PERC_LABEL: perc_decrease(fraction),
            NAV_SECTIONS_LABEL: format_nav_loc(location),
            CURRENT_LABEL: current,
            "Stim. duration": dur,
            STIM_FREQ_LABEL: stim_freq,
        }

    # materialise the product so the progress bar knows the total
    combinations = tuple(product(stims, nav_loc_changes, fractions))
    pbar = tqdm(combinations)

    baseline_nav = reset_biophys(pv)

    results = {}

    for stim, nav_loc, frac in pbar:
        amp, freq = stim
        key_name = get_key(pv, frac, nav_loc, stim, dur)
        pbar.set_description(f"{key_name}")

        cache_path = get_file_path(key_name)
        arrow_path = get_file_path(key_name, ext="arrow")

        # stays None unless this iteration simulates or loads the data
        x_df = None

        if not cache_path.exists():
            pbar.set_description(f"{key_name} running")
            reset_biophys(pv, display=False)
            # a plain string is one location; anything else is a group of locations
            locations = (nav_loc,) if isinstance(nav_loc, str) else nav_loc
            for location in locations:
                set_relative_nav11bar(pv, frac, at=location, base=baseline_nav[location])
            # run sim and save results
            AP, x_df = get_cached_df(key_name, pv, amp, dur, stim_freq=freq, shape_plot=True)

        if arrow and not arrow_path.exists():
            # export in .feather format, to be loaded using vaex and arrow
            if x_df is None:
                pbar.set_description(f"{key_name} loading")
                AP, x_df = get_cached_df(key_name)
            long_df = wide_to_long(x_df)
            # metadata columns carry the same value on every row
            for column, value in {**_metadata(frac, nav_loc, amp, freq), "key": key_name}.items():
                long_df[column] = value
            pbar.set_description(f"{key_name} saving")
            long_df.to_feather(arrow_path)

        if load:
            if x_df is None:
                pbar.set_description(f"{key_name} loading")
                AP, x_df = get_cached_df(key_name)
            results[key_name] = {"df": wide_to_long(x_df),
                                 **_metadata(frac, nav_loc, amp, freq),
                                 "APCount": AP}

        pbar.set_description("done")

    return results
    

### run method with desired parameters

In [None]:
dur = 100 # excludes STIM_ONSET time (see settings.py)

# list of (current, frequency) stimulations
# NOTE(review): frequency 0 presumably means a non-pulsed current step — confirm in get_cached_df
stims = [(0.75, 0)]

# NaV1.1 conductance as a fraction of its baseline value (1 = 100 %, 0 = none)
fractions = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0]
# where NaV1.1 is scaled: a single location name, or a tuple of locations scaled together
nav_loc_changes = [
                   "somatic", 
                   "ais", 
                   "nodes", 
                   ("somatic", "ais"),
                   ("ais", "nodes"),
                   ("somatic", "ais", "nodes") # all locations (no nav1.1 anywhere when the fraction is 0)
                   ]

# display labels for each entry of nav_loc_changes (these are the row labels used when plotting)
formatted_nav_loc_changes = [format_nav_loc(nv) for nv in nav_loc_changes]

# create neuron and set nav11
pv = get_pv(name="default", node_spacing=33, node_length=1., ais_L=26.5)
# baseline gNav11bar_Nav11 per location; run_sims computes its own baseline again
base_nav = reset_biophys(pv)

# run (or load from cache) every combination and keep the results in memory
amp_results = run_sims(pv, stims, nav_loc_changes, fractions, dur, load=True)


### create method for plotting multiple traces in a grid - NaV1.1 Fraction (columns) vs NaV1.1 Section (rows)

In [None]:
def plot_voltage_trace_grid(amp_results, plot_fractions, plot_nav_loc_changes, 
                            time_windows=None, max_dur=-1, stim_onset=0, save=True, 
                            failure_tol=2., marker_kws=None, amp=None,
                            **kwargs):
    """Plot voltage traces in a grid: NaV1.1 fraction (columns) x NaV1.1 section (rows).

    Two figures with the same layout are produced: a zoomed-in one (one time window per
    panel) and a wide one with the full traces, where the zoom window is shaded.
    Failures returned by ``calculate_failures`` (AIS vs terminal spikes) are marked with
    a scatter at ``y = 50``.

    Args:
        amp_results: dict as returned by ``run_sims(..., load=True)``.
        plot_fractions: NaV1.1 fractions to plot, one per column (in order).
        plot_nav_loc_changes: formatted section labels to plot, one per row (in order).
        time_windows: ``None`` (zoom around the second-to-last AIS spike of each trace),
            a single ``(start, end)`` tuple used for every row, or one tuple per row.
            Times are absolute; ``stim_onset`` is subtracted here.
        max_dur: if > 0, drop samples later than this (after subtracting ``stim_onset``).
        stim_onset: time subtracted from every trace.
        save: save both figures via ``save_fig``.
        failure_tol: tolerance passed to ``calculate_failures``; markers are shifted by it.
        marker_kws: overrides for the failure-marker ``scatter`` style.
        amp: stimulus amplitude used in the file name; defaults to the amplitude stored
            in ``amp_results``.
        **kwargs: only ``thresh`` is used (forwarded to ``plot_voltage_trace``).
    """
    thresh = kwargs.pop("thresh", False)
    default_marker_kws = dict(c='b', marker='v', lw=0.1, alpha=0.5, zorder=9)
    ymax = 50
    # user-supplied marker options override the defaults
    marker_kws = {**default_marker_kws, **(marker_kws or {})}

    n_cols, n_rows = len(plot_fractions), len(plot_nav_loc_changes)

    if time_windows is not None and not np.iterable(time_windows[0]):
        # a single time_window was passed; windows are looked up per row (i) below
        time_windows = [time_windows]*n_rows
    
    figsize = (n_cols*1.5, n_rows)
    
    # zoomed-in traces (own x-range per panel)
    fig, axes = plt.subplots(ncols=n_cols, 
                             nrows=n_rows,
                             squeeze=False,
                             sharey=True, sharex=False, figsize=figsize,
                             gridspec_kw=dict(wspace=0.2, hspace=0.1))
    # full-length traces (shared x-range)
    fig_wide, axes_wide = plt.subplots(ncols=n_cols, 
                                       nrows=n_rows,
                                       squeeze=False,
                                       sharey=True, sharex=True, figsize=figsize,
                                       gridspec_kw=dict(wspace=0.2, hspace=0.1))

    for key, result in tqdm(amp_results.items()):
        
        frac = result[NAV_FRAC_LABEL]
        nav_loc = result[NAV_SECTIONS_LABEL]

        if frac not in plot_fractions or nav_loc not in plot_nav_loc_changes:
            continue

        long_df = result["df"].copy()
        
        # time relative to stimulus onset
        long_df[TIME_LABEL] = long_df[TIME_LABEL]-stim_onset
        if max_dur>0:
            long_df = long_df[long_df[TIME_LABEL]<=max_dur]

        j = plot_fractions.index(frac)
        i = plot_nav_loc_changes.index(nav_loc)
        ap_times_ais = get_ap_times(long_df, sec="axon[0]", thresh=-20)
        ap_times_term = get_ap_times(long_df, sec=get_last_sec(long_df))
        failure_times = calculate_failures(ap_times_ais, ap_times_term, tol=failure_tol)
        if time_windows is None:
            # zoom around the second-to-last AIS spike; fall back to the only spike, or to
            # the stimulus onset when the trace has no AIS spike at all
            if len(ap_times_ais) >= 2:
                ref_time = ap_times_ais[-2]
            elif len(ap_times_ais) == 1:
                ref_time = ap_times_ais[0]
            else:
                ref_time = 0
            time_window = (ref_time-1, ref_time+6)
        else:
            time_window = (time_windows[i][0] - stim_onset, time_windows[i][1] - stim_onset)

        # include a 0.5 margin around the window so traces reach the panel edges
        zoom_df = long_df[(long_df[TIME_LABEL]>=time_window[0]-0.5)&(long_df[TIME_LABEL]<=time_window[1]+0.5)]
        zoom_failures = failure_times[(failure_times>=time_window[0]-0.5) & (failure_times<=time_window[1]+0.5)]
        
        # display: thick lines in the zoomed figure, thin lines in the wide one
        panels = zip([zoom_df, long_df], [axes, axes_wide], [zoom_failures, failure_times], [2, 0.5])
        for _df, _axes, fts, lw in panels:
            ax = _axes[i, j]
            size = lw*8
            plot_voltage_trace(_df, 
                               concise=True,
                               thresh=thresh, 
                               legend="brief" if (i==0 and j==(n_cols-1)) else False,
                               alpha=1,
                               palette=SECTION_PALETTE,
                               lw=lw,
                               ax=ax)
            # failure markers at y = ymax
            y = np.ones(shape=len(fts))*ymax
            ax.scatter(fts+failure_tol, y, s=size, **marker_kws)
            # format plot
            if i==0:
                # first row
                ax.set(title=f"{frac*100:.0f}%")
            if j==0:
                ax.annotate(nav_loc.replace("+","+\n"),
                        xy=(-0.2, 1), xycoords="axes fraction",
                        ha="right", va="top", rotation=0,
                        fontsize="medium")

        axes[i,j].set(xlim=time_window)
        axes[i,j].patch.set_facecolor('k')
        axes[i,j].patch.set_alpha(0.1)

        # shade the zoom window on the wide figure
        ylim = long_df[VOLTAGE_LABEL].min(), long_df[VOLTAGE_LABEL].max()
        rect = plt.Rectangle((time_window[0],ylim[0]),time_window[1]-time_window[0],ymax-ylim[0], linewidth=1, edgecolor='None', facecolor='k', alpha=0.2)
        axes_wide[i,j].add_artist(rect)

    # pretty the figure
    fig.suptitle(NAV_PERC_LABEL, va="bottom")
    fig_wide.suptitle(NAV_PERC_LABEL, va="bottom")
    
    # format axes: bottom-left panel gets scale bars instead of axes, others are bare
    for _axes in [axes, axes_wide]:
        for i, ax_row in enumerate(_axes):
            for j, ax in enumerate(ax_row):
                if i==_axes.shape[0]-1 and j==0:
                    sns.despine(ax=ax, left=True, bottom=True)
                    # draw scale bars
                    xticks, yticks = ax.get_xticks(), ax.get_yticks()
                    xsize = xticks[-2] - xticks[1]
                    ysize = yticks[-2] - yticks[1]
                    x1 = (xticks[1], xticks[1]+xsize)
                    y1 = (-80, -80)
                    x2 = (xticks[1], xticks[1])
                    y2 = (-80, -80+ymax)
                    ax.plot(x1, y1, lw=1, color='k', clip_on=False)
                    ax.plot(x2, y2, lw=1, color='k', clip_on=False)
                    ax.set_clip_on(False)
                    ax.set_ylabel(f"{ymax:.0f} mV")
                    ax.set_xlabel(f"{xsize:.0f} ms")
                    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
                else:
                    if i==0 and j==_axes.shape[1]-1:
                        ax.legend(title=SITE_LABEL, title_fontsize="medium",
                                  loc="upper left", bbox_to_anchor=(1,1))
                    sns.despine(ax=ax, bottom=True, left=True)
                    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
                    ax.set_xlabel("")
                    ax.set_ylabel("")

    if save:
        if amp is None:
            # amplitude of the plotted simulations (read from the results, not a notebook global)
            amp = next((r[CURRENT_LABEL] for r in amp_results.values()), None)
        all_fractions = {r[NAV_FRAC_LABEL] for r in amp_results.values()}
        if set(plot_fractions) != all_fractions:
            save_name = f"save/fig1a_{amp}"
        else:
            save_name = f"save/suppfig1_{amp}"
        save_fig(f"{save_name}_zoom", fig=fig)
        save_fig(f"{save_name}_wide", fig=fig_wide)

### plot chosen parameters 

In [None]:
from src.vis import get_pulse_times  # NOTE(review): already imported in the setup cell
chosen_stim = (0.75, 0)
amp, freq = chosen_stim
# NaV1.1 fractions (columns) and section labels (rows) shown in the main figure
select_fractions = [1, 0.5]
select_nav_loc_changes = ["Soma+AIS+Nodes", "Soma+AIS", "Nodes"]
other_nav_changes = ["AIS+Nodes", "Soma", "AIS"]

# NOTE(review): only used by the commented-out time windows below
pulse_times = get_pulse_times(freq, dur)

pulse_idx = 4 # zoom into 5th action potential

stim_onset = 20 # defined in run.py - stimulus only starts after 20 ms
# spike times of the first stored result (first combination simulated by run_sims) used to place the zoom windows
ap_times = get_ap_times(list(amp_results.values())[0]["df"])

# model object is used here to build the cache key below
pv = get_pv(name="default", node_spacing=33, node_length=1., ais_L=26.5)

key_name = get_key(pv, 0.5, "nodes", chosen_stim, dur)
# NOTE(review): not used further in the visible notebook
nav_loc_nodes_ap_times = get_ap_times(amp_results[key_name]["df"])

# (before, after) the reference spike
ww = window_width = (1, 6)
# time_window = (pulse_times[pulse_idx]-ww[0], pulse_times[pulse_idx]+ww[1])
# last_time_window = (pulse_times[-1]-ww[0], pulse_times[-1]+ww[1])
# one window per row of select_nav_loc_changes; the extra shifts (0.5, 1.5) look hand-tuned
time_windows = [(ap_times[pulse_idx]-ww[0]-0.5, ap_times[pulse_idx]+ww[1]-0.5),
                (ap_times[pulse_idx+2]-ww[0]-1.5, ap_times[pulse_idx+2]+ww[1]-1.5),
                (ap_times[pulse_idx+4]-ww[0], ap_times[pulse_idx+4]+ww[1]),
               ]
plot_voltage_trace_grid(amp_results, select_fractions, select_nav_loc_changes, time_windows=time_windows, stim_onset=stim_onset, save=True)

### plot all fractions and sections

In [None]:
# every fraction x every section set; zoom windows are chosen per trace from its AIS spikes
plot_voltage_trace_grid(amp_results, fractions, formatted_nav_loc_changes, time_windows=None, stim_onset=stim_onset, save=True)

## Firing rate

### Helper methods

In [None]:
# note no scale: raw AP counts, not rates
def get_ap(ap):
    """Return the raw AP counts at (soma, initiation site, terminal) from an APCount mapping.

    Each entry of ``ap`` is expected to expose its count as ``.n``; no conversion to a
    rate is applied.
    """
    return tuple(ap[site].n for site in ("soma", "init", "comm"))

def get_all_ap_times(long_df, last_sec=None, thresh=0):
    """Spike times at the soma (``soma[0]``), the AIS (``axon[1]``) and the terminal section.

    NOTE(review): the voltage-grid plot uses ``axon[0]`` for the AIS — confirm which is intended.

    Args:
        long_df: long-format voltage traces.
        last_sec: terminal section name; defaults to ``get_last_sec(long_df)``.
        thresh: voltage threshold forwarded to ``get_ap_times``.

    Returns:
        tuple ``(ap_times_soma, ap_times_ais, ap_times_term)``.
    """
    terminal = get_last_sec(long_df) if last_sec is None else last_sec
    recording_sites = ("soma[0]", "axon[1]", terminal)
    ap_times_soma, ap_times_ais, ap_times_term = (
        get_ap_times(long_df, thresh=thresh, sec=site) for site in recording_sites
    )
    return ap_times_soma, ap_times_ais, ap_times_term


def get_IFR(all_ap_times):
    """Firing rate at soma, AIS and terminal as ``1000 / mean inter-spike interval``.

    Assumes spike times in ms, so the rate is in Hz. A site with fewer than two spikes
    gets a rate of 0.

    Args:
        all_ap_times: ``(ap_times_soma, ap_times_ais, ap_times_term)``.

    Returns:
        tuple ``(fr_soma, fr_ais, fr_term)``.
    """
    def _rate(spike_times):
        if len(spike_times) <= 1:
            return 0
        return 1000/np.diff(spike_times).mean()

    ap_times_soma, ap_times_ais, ap_times_term = all_ap_times
    return _rate(ap_times_soma), _rate(ap_times_ais), _rate(ap_times_term)

def calculate_failure_rates(pulse_times, all_ap_times, init_failure_cutoff=0.95):
    """Count-based failure rates for initiation, propagation and overall stimulation.

    * Initiation: ``(n_pulses - n_ais) / n_pulses`` (pulses without an AIS spike).
    * Propagation: ``(n_ais - n_term) / n_ais`` (AIS spikes not reaching the terminal);
      reported as 0 when there are no AIS spikes or when the initiation failure rate is
      not below ``init_failure_cutoff`` (too few spikes for a meaningful ratio).
    * Stimulation: ``(n_pulses - n_term) / n_pulses`` (pulses without a terminal spike).

    Spikes are not matched one-to-one, and counts are not clipped. A rate can therefore
    be negative if a site fires more spikes than there are pulses. With no pulses,
    the pulse-based rates are undefined and returned as NaN, and propagation is 0.

    Args:
        pulse_times: stimulus pulse times (only their number is used).
        all_ap_times: ``(ap_times_soma, ap_times_ais, ap_times_term)``.
        init_failure_cutoff: initiation failure rate at or above which propagation
            failure is reported as 0.

    Returns:
        dict with keys "Propagation", "Initiation" and "Stimulation".
    """
    _ap_times_soma, ap_times_ais, ap_times_term = all_ap_times
    n_pulses = len(pulse_times)
    n_ais = len(ap_times_ais)
    n_term = len(ap_times_term)

    if n_pulses == 0:
        # no stimulus pulses: avoid ZeroDivisionError, the rates are undefined
        failure_rate_init = float("nan")
        failure_rate_stim = float("nan")
    else:
        failure_rate_init = (n_pulses - n_ais)/n_pulses
        failure_rate_stim = (n_pulses - n_term)/n_pulses

    # NaN < cutoff is False, so an undefined initiation rate also yields 0 here
    if failure_rate_init < init_failure_cutoff and n_ais > 0:
        failure_rate_prop = (n_ais - n_term)/n_ais
    else:
        failure_rate_prop = 0

    return {
        "Propagation": failure_rate_prop, 
        "Initiation": failure_rate_init, 
        "Stimulation": failure_rate_stim
    }
                    


In [None]:
max_dur = 100

ap_cols = [SOMA_LABEL, AIS_LABEL, TERMINAL_LABEL]

# One record per (simulation, recording site), converted to a DataFrame once at the end.
# DataFrame.append was removed in pandas 2.0, and growing a frame in a loop is quadratic.
ap_records = []

for key, result in tqdm(amp_results.items()):
    apcounts = result["APCount"]
    # boolean indexing returns a new frame, so the stored result is not modified
    long_df = result["df"]
    long_df = long_df[long_df[TIME_LABEL]<=max_dur]
    
    all_ap_times = get_all_ap_times(long_df, get_last_sec(long_df), thresh=-20)
    
    ap_soma, ap_init, ap_comm = get_ap(apcounts)
    fr_soma, fr_ais, fr_term = get_IFR(all_ap_times)

    base_info = {NAV_FRAC_LABEL: result[NAV_FRAC_LABEL],
                 NAV_SECTIONS_LABEL: result[NAV_SECTIONS_LABEL]}

    ap_records.extend([{SECTION_LABEL: SOMA_LABEL, INSTA_FR_LABEL: fr_soma, AP_LABEL: ap_soma, **base_info},
                       {SECTION_LABEL: AIS_LABEL, INSTA_FR_LABEL: fr_ais, AP_LABEL: ap_init, **base_info},
                       {SECTION_LABEL: TERMINAL_LABEL, INSTA_FR_LABEL: fr_term, AP_LABEL: ap_comm, **base_info}])

ap_long_df = pd.DataFrame(ap_records).fillna(0)
ap_long_df

### Plot firing rate

In [None]:
PROP_FAIL_LABEL = "Propagation\nfailure"

ap_long_df["frac_section"] = ap_long_df.agg(lambda x: f"{x[NAV_SECTIONS_LABEL]} = {x[NAV_FRAC_LABEL]}", axis=1)
# AP count over the stimulus duration -> rate per second
ap_long_df[FIRING_RATE_LABEL] = ap_long_df[AP_LABEL]*(1000/dur)
ap_long_df[NAV_PERC_LABEL] = perc_decrease(ap_long_df)
ap_long_df[SITE_LABEL] = ap_long_df[SECTION_LABEL]

def diff_firing_rate(data, y1=AIS_LABEL, y2=TERMINAL_LABEL, label=None, color='b', alpha=0.2, **kwargs):
    """Shade the gap between the firing-rate curves of sites ``y1`` and ``y2``.

    Meant for ``FacetGrid.map_dataframe`` on a ``relplot`` grid. It reads relplot's
    internal ``x``/``y``/``style`` columns.
    NOTE(review): this relies on the seaborn version keeping those column names in the
    grid data — confirm when upgrading. It also assumes both sites share the same x
    values, in the same order.
    """
    ax = plt.gca()
    sec1_df = data[data['style']==y1]
    sec2_df = data[data['style']==y2]
    ax.fill_between(x=sec1_df.x, y1=sec1_df.y, y2=sec2_df.y,
                    color=color, alpha=alpha)    
sites = [AIS_LABEL, TERMINAL_LABEL]


def plot_firing_rate_grid(data, col_order, save_name, legend_kws=None):
    """Plot AIS vs terminal firing rate against NaV1.1 reduction, one panel per section set.

    The gap between the two curves is shaded (propagation failure). The figure is saved
    to ``save_name``.

    Args:
        data: long-format firing-rate table (``ap_long_df``).
        col_order: section-set labels to plot, one panel each.
        save_name: path passed to ``save_fig``.
        legend_kws: extra keyword arguments for the shaded-area legend.

    Returns:
        the seaborn ``FacetGrid``.
    """
    legend_kws = {} if legend_kws is None else legend_kws
    # proxy artist for the legend entry of the shaded area
    fail_patch = plt.Rectangle((0,0), 1, 1, color='b', alpha=0.2)
    grid = sns.relplot(data=data,
                       kind="line",
                       col=NAV_SECTIONS_LABEL, 
                       col_order=col_order,
                       col_wrap=3,
                       x=NAV_PERC_LABEL,
                       y=FIRING_RATE_LABEL,
                       hue=SITE_LABEL,
                       style=SITE_LABEL,
                       hue_order=sites,
                       style_order=sites,
                       height=2.,
                       lw=2,
                       palette=SECTION_PALETTE)
    grid.map_dataframe(diff_firing_rate, y1=AIS_LABEL, y2=TERMINAL_LABEL, 
                       color='b')
    grid.set_titles("{col_name}")
    for ax in grid.axes.flat:
        ax.set_yticks(range(0, 201, 100))
        ax.set_yticks(range(50, 201, 100), minor=True)
        ax.set_xticks(range(0, 100, 10), minor=True)
        # these are reset after map_dataframe
        ax.set(xlabel=NAV_PERC_LABEL, ylabel=FIRING_RATE_LABEL)
    # legend goes on the current (last-drawn) facet
    plt.legend([fail_patch], [PROP_FAIL_LABEL], **legend_kws)
    grid.tight_layout()
    save_fig(save_name, fig=grid.fig)
    return grid


with sns.plotting_context("notebook"):
    g = plot_firing_rate_grid(ap_long_df, select_nav_loc_changes, f"save/firing_rate_{amp}")
    g = plot_firing_rate_grid(ap_long_df, other_nav_changes, f"save/firing_rate_{amp}_supp",
                              legend_kws=dict(loc=(0,0)))

## Pulses

### Run simulations

In [None]:
def reset_biophys_pulse(pv, **kwargs):
    """Reset biophysics with ``reset_biophys``, then cut somatic NaTs2_t to 1 % of its reset value.

    This lowers the model's dependence on somatic transient Na, which was quite high in
    the original model. Every somatic section receives ``0.01 *`` the value read from
    ``soma[0]``.

    Returns:
        whatever ``reset_biophys`` returns (``run_sims`` uses it as the NaV1.1 baseline).
    """
    nav_baseline = reset_biophys(pv, **kwargs)
    reduced_gnat = 0.01*pv.soma[0].gNaTs2_tbar_NaTs2_t
    for soma_sec in pv.somatic:
        soma_sec.gNaTs2_tbar_NaTs2_t = reduced_gnat
    return nav_baseline

# separate model object for the pulse protocol
pv = get_pv(name="pulse", node_spacing=33, node_length=1., ais_L=26.5)
# NOTE(review): uses reset_biophys, not reset_biophys_pulse; run_sims below recomputes its own baseline with reset_biophys_pulse
base_nav = reset_biophys(pv)

# NOTE: overrides the global `dur` (100) used by the earlier cells
dur = 250
# (current, frequency): 0.75 amplitude delivered as 120 Hz pulses
stims = [(0.75, 120)]

fractions = [1, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0]
nav_loc_changes = [
                   "somatic", 
                   "ais", 
                   "nodes", 
                   ("somatic", "ais"),
                   ("ais", "nodes"),
                   ("somatic", "ais", "nodes") # all locations (no nav1.1 anywhere when the fraction is 0)
                   ]

formatted_nav_loc_changes = [format_nav_loc(nv) for nv in nav_loc_changes]

pulse_results = run_sims(pv, stims, nav_loc_changes, fractions, dur, load=True, reset_biophys=reset_biophys_pulse)


### Calculate neuronal failure rate from pulses

In [None]:
# Records are collected in lists and converted once at the end.
# DataFrame.append was removed in pandas 2.0, and growing a frame in a loop is quadratic.
amp_records = []
fail_records = []
last_sec = get_last_sec(list(pulse_results.values())[0]["df"])
# analysis window (same time units as TIME_LABEL)
min_time = 25
dur = 250
max_time = dur

for key, result in tqdm(pulse_results.items()):
    apcounts = result["APCount"]
    freq = result[STIM_FREQ_LABEL]

    # presumably the pulse times for `freq` over the window length, shifted to start at min_time
    pulse_times  = get_pulse_times(freq, max_time-min_time)+min_time
    long_df = result["df"]
    long_df = long_df[(long_df[TIME_LABEL]>=min_time) & (long_df[TIME_LABEL]<=max_time)]
    all_ap_times = get_all_ap_times(long_df, last_sec, thresh=-20)
    
    ap_soma, ap_init, ap_comm = get_ap(apcounts)
    failure_rates = calculate_failure_rates(pulse_times, all_ap_times)
    fr_soma, fr_ais, fr_term = get_IFR(all_ap_times)

    base_info = {STIM_FREQ_LABEL: freq,
                 NAV_FRAC_LABEL: result[NAV_FRAC_LABEL],
                 NAV_SECTIONS_LABEL: result[NAV_SECTIONS_LABEL],
                }

    amp_records.extend([{SECTION_LABEL: SOMA_LABEL, INSTA_FR_LABEL: fr_soma, AP_LABEL: ap_soma, **base_info},
                        {SECTION_LABEL: AIS_LABEL, INSTA_FR_LABEL: fr_ais, AP_LABEL: ap_init, **base_info},
                        {SECTION_LABEL: TERMINAL_LABEL, INSTA_FR_LABEL: fr_term, AP_LABEL: ap_comm, **base_info}])
    fail_records.append({**failure_rates, **base_info})

amp_df = pd.DataFrame(amp_records)
fail_df = pd.DataFrame(fail_records)
fail_df[NAV_PERC_LABEL] = perc_decrease(fail_df)
amp_df.head()

### Plot

In [None]:
# Round before casting: a float error such as 9.999... would otherwise truncate to 9.
fail_df[NAV_PERC_LABEL] = perc_decrease(fail_df).round().astype(int)
fail_long_df = fail_df.melt(id_vars=[NAV_SECTIONS_LABEL, NAV_PERC_LABEL, NAV_FRAC_LABEL, STIM_FREQ_LABEL], var_name="Failure type", value_name="Failure rate")

# colours for the failure types, in the same order as `order`
pal = ["#5D5D5D", "#EF5B29", "#16A787"]
order = ["Stimulation", "Propagation", "Initiation"]

with sns.plotting_context("notebook"):
    g = sns.relplot(data=fail_long_df, 
                    col=NAV_SECTIONS_LABEL,
                    col_order=["Soma+AIS+Nodes", "Soma+AIS", "Nodes"],
                    row=STIM_FREQ_LABEL,
                    row_order=[120],
                    x=NAV_PERC_LABEL, 
                    y="Failure rate", 
                    hue="Failure type",
                    style="Failure type",
                    hue_order=order,
                    style_order=order,
                    height=2,
                    lw=2,
                    palette=pal,
                    kind='line',
                    clip_on=False
                   )
    g.set_titles("{col_name} | {row_name:.0f} Hz")
    g.set(ylim=(0,1))
    for ax in g.axes.flat:
        ax.set_xticks(range(0, 100, 10), minor=True)
    sns.despine(offset=5)
    save_fig(f"save/failure_rate_{stims[0][0]}", fig=g.fig)

# Recovery

In [None]:
# Recovery: baseline, NaV1.1 reduced everywhere, and two rescue strategies
dur = 100
# scales AP counts over `dur` (ms) to a rate per second
scale = 1000/dur

amp = 0.75

# get the interneuron (object is in cache/memory if called with the same arguments)
pv = get_pv(node_spacing=33.0, node_length=1., ais_L=26.5)
# reset to baseline; base_nav holds the baseline gNav11bar_Nav11 per location
base_nav = reset_biophys(pv)
    
print('running')
# "test" models are simulated directly; others presumably go through the on-disk cache
if "test" in pv.name:
    t,v, AP, x_df = get_trace(pv, amp, dur, shape_plot=True)
else:
    AP, x_df = get_cached_df(f"{pv.name}_{1}_{amp}_{dur}", pv, amp, dur, shape_plot=True)

# NOTE(review): x_df is the wide frame here, whereas other calls pass wide_to_long output — confirm get_ap_times accepts both
ap_start_times = get_ap_times(x_df)
# (soma, init, terminal) AP counts scaled to rates
get_aps = lambda ap: (ap["soma"].n*scale, ap["init"].n*scale, ap["comm"].n*scale)
aps_soma, aps_init, aps_comm = get_aps(AP)
print(f"action potentials\namp = {amp} & 100 % -> {aps_soma:5.2f} | {aps_init:5.2f} | {aps_comm:5.2f}")
# NOTE(review): `fr` is not used further in the visible notebook
fr = 1000/np.diff(ap_start_times).mean()

# impaired condition: NaV1.1 at nav_frac of baseline in soma, AIS and nodes
nav_frac = 0.5
# set_nrn_prop(pv, "gNav11bar_Nav11", 0, secs="all", ignore_error=True)
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["somatic"]*nav_frac, secs="somatic")
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["ais"]*nav_frac, secs="ais")
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["nodes"]*nav_frac, secs="nodes")

if "test" in pv.name:
    t,v, AP, x0_df = get_trace(pv, amp, dur, shape_plot=True)
else:
    AP, x0_df = get_cached_df(f"{pv.name}_{nav_frac:.2f}_{amp}_{dur}", pv, amp, dur, shape_plot=True)

aps0_soma, aps0_init, aps0_comm = get_aps(AP)
print(f"amp = {amp} &  {nav_frac*100:.0f} % -> {aps0_soma:5.2f} | {aps0_init:5.2f} | {aps0_comm:5.2f}")

# sanity check: somatic rate rises by at most 20, terminal rate drops, and the terminal never fires more than the AIS
assert aps0_soma<=aps_soma+20 and aps0_comm < aps_comm and aps0_comm <= aps0_init

print("rescue re-dist")

# rescue by redistribution: soma and AIS at (nav_frac + 0.7) of baseline, nodes at 0.4 of baseline
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["somatic"]*(nav_frac+0.7), secs="somatic")
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["ais"]*(nav_frac+0.7), secs="ais")
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["nodes"]*0.4, secs="nodes")

if "test" in pv.name:
    t,v, AP, x_r1_df = get_trace(pv, amp, dur, shape_plot=True)
else:
    AP, x_r1_df = get_cached_df(f"{pv.name}_{(nav_frac+0.7, nav_frac+0.7, 0.4)}_{amp}_{dur}", pv, amp, dur, shape_plot=True)

aps_r1_soma, aps_r1_init, aps_r1_comm = get_aps(AP)
print(f"amp = {amp} & re-dist -> {aps_r1_soma:5.2f} | {aps_r1_init:5.2f} | {aps_r1_comm:5.2f}")

print("rescue up-reg")
# back to impaired
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["somatic"]*nav_frac, secs="somatic")
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["ais"]*nav_frac, secs="ais")
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["nodes"]*nav_frac, secs="nodes")
# rescue by up-reg
soma_up = 1.2
ais_up = 1.3
# NaT is scaled relative to its *current* value (presumably restored by reset_biophys at the
# top of this cell) — re-running only these lines would compound the scaling
set_nrn_prop(pv, "gNaTs2_tbar_NaTs2_t", pv.soma[0].gNaTs2_tbar_NaTs2_t*soma_up, secs="somatic")
set_nrn_prop(pv, "gNaTa_tbar_NaTa_t", pv.axon[0].gNaTa_tbar_NaTa_t*ais_up, secs="ais")

if "test" in pv.name:
    t,v, AP, x_r2_df = get_trace(pv, amp, dur, shape_plot=True)
else:
    AP, x_r2_df = get_cached_df(f"{pv.name}_nat{(soma_up, ais_up)}_{amp}_{dur}", pv, amp, dur, shape_plot=True)

aps_r2_soma, aps_r2_init, aps_r2_comm = get_aps(AP)
print(f"amp = {amp} & up-reg -> {aps_r2_soma:5.2f} | {aps_r2_init:5.2f} | {aps_r2_comm:5.2f}")

CONDITION_LABEL = "Condition"
# Firing rates per condition and site are collected as records and converted once below.
# DataFrame.append was removed in pandas 2.0.
fr_records = []

with sns.plotting_context():
    # NOTE(review): these may already be provided by `from src.settings import *` in the setup cell
    from src.settings import GROUP_COLOR_D,GROUP_COLOR_A,GROUP_COLOR_C,GROUP_COLOR_E, GROUP_COLOR_B
    fig, axes = plt.subplots(ncols=4, sharey=True, sharex=True, figsize=(8,2))
    titles = ["baseline", "50 % $g_{NaV1.1}$", "redistribution\nNav1.1", "upregulation\nNav1.x"]
    traces = [x_df, x0_df, x_r1_df, x_r2_df]
    # NOTE(review): both rescue conditions share GROUP_COLOR_C (B and E are imported but unused) — confirm
    group_colors = [GROUP_COLOR_D, GROUP_COLOR_A, GROUP_COLOR_C, GROUP_COLOR_C]
    # (soma, AIS, terminal) firing rates of each condition
    condition_rates = [(aps_soma, aps_init, aps_comm),
                       (aps0_soma, aps0_init, aps0_comm),
                       (aps_r1_soma, aps_r1_init, aps_r1_comm),
                       (aps_r2_soma, aps_r2_init, aps_r2_comm)]
    for i, (_df, ax, group_color, title, apcounts) in enumerate(zip(traces, axes, group_colors, titles, condition_rates)):
        plot_voltage_trace(wide_to_long(_df), thresh=False, concise=True, ax=ax, palette=[group_color]*2,
                           legend=(i==0), alpha=1,
                           lw=0.5,
                           offset=150
                          )
        if i>0:
            ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
        ax.set_xlabel("")
        ax.set_ylabel("")
        ax.set_title(title)
        
        fr_records.extend([{SECTION_LABEL: SOMA_LABEL, FIRING_RATE_LABEL: apcounts[0], CONDITION_LABEL: title},
                           {SECTION_LABEL: AIS_LABEL, FIRING_RATE_LABEL: apcounts[1], CONDITION_LABEL: title},
                           {SECTION_LABEL: TERMINAL_LABEL, FIRING_RATE_LABEL: apcounts[2], CONDITION_LABEL: title}
                           ])

    sns.despine(fig=fig, left=True, bottom=True)
    save_fig(f"save/recover_{amp}", fig=fig)

fr_df = pd.DataFrame(fr_records)


In [None]:
with sns.plotting_context("paper"):
    # bar chart of AIS vs terminal firing rate for each recovery condition in fr_df
    sns.catplot(
        data=fr_df,
        kind="bar",
        x=CONDITION_LABEL,
        y=FIRING_RATE_LABEL,
        hue=SECTION_LABEL,
        hue_order=[AIS_LABEL, TERMINAL_LABEL],
        palette=SECTION_PALETTE,
        height=2,
        aspect=2,
    )
save_fig(f"save/recover_{amp}_summary")

# Supplementary material

## Redistribution of NaV1.1

Grid search over NaV1.1 at Soma+AIS (scaled together) and NaV1.1 at the Nodes (≤ 50 % of baseline) to find where redistributing NaV1.1 recovers spike **initiation**.

### Theoretical redistribution calculations


In [None]:
reset_biophys(pv)
ais_SA = sum([axon.L*axon.diam for axon in pv.ais])
ais_SA *= 1e-8
ais_total_uS = ais_SA * pv.axon[0].gNav11bar_Nav11 * 1e6

node_SA = sum([node.L*node.diam for node in pv.nodes])
node_SA *= 1e-8
node_total_uS = node_SA * pv.node[0].gNav11bar_Nav11 * 1e6

soma_SA = sum([soma.L*soma.diam for soma in pv.somatic])
soma_SA *= 1e-8
soma_total_uS = soma_SA * pv.soma[0].gNav11bar_Nav11 * 1e6

print(f"Total gNaV in Soma: {soma_total_uS:.5f} uS | Total gNaV in AIS: {ais_total_uS:.5f} uS | Total gNaV in Node: {node_total_uS:.5f} uS")

total_nav = soma_total_uS+ais_total_uS+node_total_uS

frac_loss = 0.4
gnav_node = node_total_uS*frac_loss
gnav_other_gain = node_total_uS-gnav_node
new_ais = ais_total_uS + gnav_other_gain/2
new_soma = soma_total_uS + gnav_other_gain/2
new_ais/ais_total_uS, new_soma/soma_total_uS

### Parameter search

In [None]:
REDIST_AIS_LABEL = f"AIS {NAV_PERC_LABEL}"
REDIST_NODE_LABEL = f"Node {NAV_PERC_LABEL}"

dur = 200
# scales AP counts over `dur` (ms) to a rate per second
scale = 1000/dur

get_aps = lambda ap: (ap["soma"].n*scale, ap["init"].n*scale, ap["comm"].n*scale)

# get the interneuron (object is in cache/memory if called with the same arguments)
pv = get_pv("re-dist", node_spacing=33.0, node_length=1., ais_L=26.5)

base_nav = reset_biophys(pv)

print("rescue re-dist")
amp = 0.75

# One record per (grid point, recording site), converted to a DataFrame once at the end.
# DataFrame.append was removed in pandas 2.0, and growing a frame in a loop is quadratic.
red_records = []

# frac_a scales NaV1.1 at BOTH soma and AIS (the label only mentions the AIS); frac_n scales the nodes
for frac_a, frac_n in tqdm(tuple(product(
                                        np.round(np.arange(0, 1.5, 0.1), 2),
                                        np.round(np.arange(0., 0.6, 0.1), 2)))):
    set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["somatic"]*frac_a, secs="somatic")
    set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["ais"]*frac_a, secs="ais")
    set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["nodes"]*frac_n, secs="nodes")

    key_name = f"{pv.name}_{frac_a:.2f}_{frac_n:.2f}_{amp}_{dur}"
    AP, x_df = get_cached_df(key_name, pv, amp, dur, shape_plot=True)
    aps_r1_soma, aps_r1_init, aps_r1_comm = get_aps(AP)

    red_records.extend({REDIST_AIS_LABEL: 100*(1-frac_a), 
                        REDIST_NODE_LABEL: 100*(1-frac_n), 
                        SECTION_LABEL: label, 
                        FIRING_RATE_LABEL: aps} 
                       for label, aps in zip([SOMA_LABEL, AIS_LABEL, TERMINAL_LABEL],
                                             [aps_r1_soma, aps_r1_init, aps_r1_comm]))

red_df = pd.DataFrame(red_records)
# firing-rate table: rows = node reduction, columns = (site, AIS reduction)
red_df_wide_df = red_df.pivot(index=REDIST_NODE_LABEL, columns=[SECTION_LABEL, REDIST_AIS_LABEL], values=FIRING_RATE_LABEL)


### plot redistribution of Nav1.1

In [None]:
with sns.plotting_context("poster"):
    plt.rcParams.update({"legend.fontsize": 'x-small',
                         "legend.title_fontsize": 'x-small'})

    # AIS and terminal rates only, dropping grid points with no NaV1.1 at the nodes (100 % reduction)
    is_ais_or_term = red_df[SECTION_LABEL].isin([AIS_LABEL, TERMINAL_LABEL])
    has_node_nav = red_df[REDIST_NODE_LABEL] < 100
    _plot_df = red_df[is_ais_or_term & has_node_nav]
    num_node_nav = _plot_df[REDIST_NODE_LABEL].unique().size
    ax = sns.scatterplot(data=_plot_df,
                         x=REDIST_AIS_LABEL, y=FIRING_RATE_LABEL,
                         hue=REDIST_NODE_LABEL,
                         style=SECTION_LABEL,
                         markers={SOMA_LABEL: "s", AIS_LABEL: "d", TERMINAL_LABEL: "o"},
                         size=SECTION_LABEL,
                         size_order=_plot_df[SECTION_LABEL].unique(),
                         palette="mako_r",
                         alpha=0.9,
                         clip_on=False)
    # dashed line at 0 % change of soma/AIS NaV1.1 (baseline)
    ax.axvline(x=0,  ls="--", alpha=0.5, c="k", zorder=-10)
    # hatched band above 200
    ax.fill_between(x=ax.get_xlim(), y1=200, y2=220, alpha=0.1, color="k", ec="None", zorder=-10, hatch="\\")
    ax.legend(loc=(1,0), ncol=1)
    ax.set_ylim(0, 220)
    sns.despine(ax=ax, offset=5)
    save_fig(f"save/redist_{amp}")

## Up-regulation of Nav1.x

Grid search over up-regulation of transient Na (NaV1.x, "NaT") at the soma and the AIS, with NaV1.1 reduced to 50 % in soma, AIS and nodes, to find where spike **initiation** is recovered.

In [None]:
NAT_SOMA_LABEL = "NaT Soma (% baseline)"
NAT_AIS_LABEL = "NaT AIS (% baseline)"
NAT_LABEL = "NaT (% baseline)"

dur = 200
# scales AP counts over `dur` (ms) to a rate per second
scale = 1000/dur

get_aps = lambda ap: (ap["soma"].n*scale, ap["init"].n*scale, ap["comm"].n*scale)

amp = 0.75

# get the interneuron (object is in cache/memory if called with the same arguments)
pv = get_pv("up-reg", node_spacing=33.0, node_length=1., ais_L=26.5)

base_nav = reset_biophys(pv)
# baseline transient Na (NaT) densities, read right after the reset
base_nat_soma = pv.soma[0].gNaTs2_tbar_NaTs2_t
base_nat_ais = pv.axon[0].gNaTa_tbar_NaTa_t


# back to impaired
nav_frac= 0.5
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["somatic"]*nav_frac, secs="somatic")
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["ais"]*nav_frac, secs="ais")
set_nrn_prop(pv, "gNav11bar_Nav11", base_nav["nodes"]*nav_frac, secs="nodes")

# rescue by up-reg
stop_soma = 1.5 # inclusive
step_soma = 0.1
y_scale = 1
stop_ais = stop_soma*y_scale
step_ais = step_soma*y_scale


def _inclusive_grid(start, stop, step):
    """Values from ``start`` to ``stop`` (both included) every ``step``, rounded to 2 decimals.

    Uses linspace with an explicit point count. ``np.arange(start, stop + step, step)``
    with a float step can overshoot by one element due to rounding; for 1.0..1.5 in steps
    of 0.1 it also yields 1.6.
    """
    n_steps = int(round((stop - start)/step))
    return np.round(np.linspace(start, stop, n_steps + 1), 2)


# One record per (grid point, recording site), converted to a DataFrame once at the end.
# DataFrame.append was removed in pandas 2.0, and growing a frame in a loop is quadratic.
res_records = []

for s, a in tqdm(tuple(product(_inclusive_grid(1., stop_soma, step_soma), 
                               _inclusive_grid(1., stop_ais, step_ais)))):
    set_nrn_prop(pv, "gNaTs2_tbar_NaTs2_t", base_nat_soma*s, secs="somatic")
    set_nrn_prop(pv, "gNaTa_tbar_NaTa_t", base_nat_ais*a, secs="ais")
    key_name = f"{pv.name}_{s:.2f}_{a:.2f}_{amp}_{dur}"
    AP, x_df = get_cached_df(key_name, pv, amp, dur, shape_plot=True)
    aps_r2_soma, aps_r2_init, aps_r2_comm = get_aps(AP)
    res_records.extend({NAT_SOMA_LABEL: round(s*100), 
                        NAT_AIS_LABEL: round(a*100), 
                        SECTION_LABEL: label, 
                        FIRING_RATE_LABEL: aps} 
                       for label, aps in zip([SOMA_LABEL, AIS_LABEL, TERMINAL_LABEL],
                                             [aps_r2_soma, aps_r2_init, aps_r2_comm]))

res_df = pd.DataFrame(res_records)
res_df

In [None]:
def plot_heatmap_inner_square(res_df, y_scale=y_scale, cmap="rocket", shrink=1, border=5, fig_kwargs=None):
    """Heatmap of firing rate over the (NaT AIS, NaT soma) up-regulation grid.

    ``res_df`` is pivoted to rows = NaT soma level and columns = (section, NaT AIS level).
    The whole pivot table is drawn as one image stretched over the NaT-AIS extent. Every
    somatic row of ``res_df`` gets a narrow black and a wide white outline.

    Parameters
    ----------
    res_df : pd.DataFrame
        Long-format table with NAT_SOMA_LABEL, NAT_AIS_LABEL, SECTION_LABEL and
        FIRING_RATE_LABEL columns (as built by the grid-search cell).
    y_scale : float
        Factor applied to the NaT-soma axis coordinates. The default is the global
        ``y_scale`` at the time this function is defined.
    cmap : str
        Colormap for the firing rate; the colour scale starts at 0.
    shrink : float
        Passed to ``fig.colorbar``.
    border : int or {"full", "None"}
        "full" keeps all four spines, "None" removes them, and any other value is used as
        the ``offset`` of the default left/bottom spines.
    fig_kwargs : dict, optional
        Keyword arguments for ``plt.subplots``.

    Returns
    -------
    fig, ax
    """
    if fig_kwargs is None:
        fig_kwargs = {}
    wide_df = res_df.pivot(index=NAT_SOMA_LABEL, columns=[SECTION_LABEL, NAT_AIS_LABEL], values=FIRING_RATE_LABEL)
    # NOTE(review): the pivot columns are ordered section-first, so the three sections'
    #  AIS sweeps are drawn one after another across the x extent rather than interleaved
    #  within each AIS cell — confirm this is the intended layout.
    x = wide_df.columns.levels[1]  # NaT AIS levels
    y = wide_df.index              # NaT soma levels
    x_step = np.diff(x)[0]
    y_step = np.diff(y)[0]

    # imshow (origin="upper") draws the first row at the top; flip the ascending pivot
    # index so the highest NaT soma level ends up at the top of the axis
    df = wide_df.iloc[::-1]
    # marker coordinates come from the somatic rows
    dfmark = res_df[res_df[SECTION_LABEL] == SOMA_LABEL]

    # plotting
    fig, ax = plt.subplots(**fig_kwargs)
    # pad by half a grid step so tick values sit at cell centres
    ext = [x.min()-x_step/2., x.max()+x_step/2.,
           y_scale*(y.min()-y_step/2.), y_scale*(y.max()+y_step/2.)]
    sm = ax.imshow(df, extent=ext, cmap=cmap, vmin=0)
    ax.set_xticks(x[:], minor=True)
    ax.set_xticks(x[::2])
    ax.set_yticks(y[:]*y_scale, minor=True)
    ax.set_yticks(y[::2]*y_scale)
    ax.set_yticklabels(y[::2])

    dx = x_step/3
    dy = y_step*y_scale
    # NOTE(review): this outline spans three grid cells; a single-cell outline would use x_step
    dxbig = x_step*3
    for xi, yi in dfmark[[NAT_AIS_LABEL, NAT_SOMA_LABEL]].to_numpy():
        rec = plt.Rectangle((xi-dx/2., yi*y_scale-dy/2.), dx, dy, fill=False,
                            edgecolor="k", lw=0.1)
        recbig = plt.Rectangle((xi-dxbig/2., yi*y_scale-dy/2.), dxbig, dy, fill=False,
                               edgecolor="w", lw=0.5)
        ax.add_artist(rec)
        ax.add_artist(recbig)

    cbar = fig.colorbar(sm, label=FIRING_RATE_LABEL, shrink=shrink)
    if border != "full":
        cbar.outline.set_visible(False)

    if border == "full":
        # black border all around
        sns.despine(ax=ax, top=False, right=False)
    elif border == "None":
        # no border
        sns.despine(ax=ax, left=True, bottom=True)
    else:
        # offset basic border
        sns.despine(ax=ax, offset=border)

    ax.set_xlabel(NAT_AIS_LABEL.replace("(%", "\n(%"))
    ax.set_ylabel(NAT_SOMA_LABEL.replace("(%", "\n(%"))

    return fig, ax

with sns.plotting_context("notebook"):
    # plotted window of the grid: NaT soma <= 140 %, NaT AIS <= 160 %
    in_window = (res_df[NAT_SOMA_LABEL] <= 140) & (res_df[NAT_AIS_LABEL] <= 160)
    fig, ax = plot_heatmap_inner_square(
        res_df[in_window],
        shrink=0.6,
        border=5,
        cmap="bone",
        fig_kwargs=dict(figsize=(3, 3)),
    )
    # blue outline of the cell spanning NaT AIS 125-135 % and NaT soma 115-125 %
    recbig = plt.Rectangle((125, 115), 10, 10, fill=False, edgecolor="b", lw=2)
    ax.add_artist(recbig)
    save_fig(f"save/upreg_heatmap_{amp}")


## Input-Firing rate (IF) curves

In [None]:
# 100 ms current steps; `scale` turns AP counts into per-second rates (assuming ms)
dur = 100
scale = 1000/dur

# eleven step amplitudes from 0.0 to 1.0 in 0.1 increments
amps = np.round(np.linspace(0.0, 1.0, 11), 2)


def get_aps(ap):
    """Return the (soma, initiation site, terminal) firing rates for an AP mapping.

    ``ap[site].n`` is multiplied by the global ``scale``, looked up at call time.
    """
    soma_rate = ap["soma"].n*scale
    init_rate = ap["init"].n*scale
    comm_rate = ap["comm"].n*scale
    return soma_rate, init_rate, comm_rate

def get_if_curve(pv, dur, amps):
    """Firing rate vs. injected current (IF curve) for `pv`.

    For each amplitude in `amps`, a `dur`-long step is loaded from the cache or simulated
    via ``get_cached_df``, keyed on ``pv.name``, the amplitude and `dur`. The AP counts at
    the soma, initiation site (AIS) and terminal are converted to rates.

    Parameters
    ----------
    pv : cell object with a ``name`` attribute (from ``get_pv``)
    dur : float
        Stimulus duration; rates are AP counts times ``1000/dur``.
    amps : iterable of float
        Current amplitudes to simulate.

    Returns
    -------
    pd.DataFrame
        Long format with CURRENT_LABEL, FIRING_RATE_LABEL and SECTION_LABEL columns,
        three rows per amplitude, after ``convert_dtypes``.
    """
    # scale by *this* call's duration, not the global `scale` (which follows the last
    # global `dur` that was set)
    rate_scale = 1000/dur
    sites = (("soma", SOMA_LABEL), ("init", AIS_LABEL), ("comm", TERMINAL_LABEL))

    records = []
    for amp in tqdm(amps):
        key_name = f"{pv.name}_{amp}_{dur}"
        AP, _ = get_cached_df(key_name, pv, amp, dur, shape_plot=False)
        for site, section_label in sites:
            records.append({
                CURRENT_LABEL: amp,
                FIRING_RATE_LABEL: AP[site].n*rate_scale,
                SECTION_LABEL: section_label,
            })

    # DataFrame.append was removed in pandas 2.0; build the frame once
    amp_df = pd.DataFrame(records, columns=[CURRENT_LABEL, FIRING_RATE_LABEL, SECTION_LABEL])
    return amp_df.convert_dtypes()

pv = get_pv(name="io", node_spacing=33, node_length=1., ais_L=26.5)
pv_old_name = pv.name
base_nav = reset_biophys(pv)

# each entry is a section list name or a tuple of them that are changed together
nav_locs = [("somatic", "ais", "nodes")]
fracs = [0.1, 0.3, 0.5, 1]
percs = perc_decrease(np.array(fracs))

if_frames = []
for frac, nav_loc in product(fracs, nav_locs):
    reset_biophys(pv)
    for _nav_loc in ((nav_loc,) if isinstance(nav_loc, str) else nav_loc):
        set_relative_nav11bar(pv, frac, at=_nav_loc, base=base_nav[_nav_loc])

    # temporarily rename the cell so the cache keys encode the Nav1.1 change;
    # restore the name even if a simulation raises
    pv.name = f"{pv_old_name}_{frac}_{nav_loc}"
    try:
        new_amp_df = get_if_curve(pv, dur, amps)
    finally:
        pv.name = pv_old_name

    new_amp_df[NAV_FRAC_LABEL] = frac
    new_amp_df[NAV_SECTIONS_LABEL] = format_nav_loc(nav_loc)
    if_frames.append(new_amp_df)

# DataFrame.append was removed in pandas 2.0; concatenate once instead
amp_df = pd.concat(if_frames, ignore_index=True) if if_frames else pd.DataFrame()
amp_df

### plot IF

In [None]:
def plot_if_curve(amp_df, **kwargs):
    """Draw firing rate vs. current, coloured by % Nav1.1 reduction and styled by section.

    ``style_order`` lists only the AIS and terminal sections. Extra keyword arguments
    (palette, hue_order, ax, ...) are forwarded to ``sns.lineplot``.
    """
    plot_kwargs = dict(
        data=amp_df,
        x=CURRENT_LABEL,
        y=FIRING_RATE_LABEL,
        hue=NAV_PERC_LABEL,
        style=SECTION_LABEL,
        style_order=[AIS_LABEL, TERMINAL_LABEL],
    )
    sns.lineplot(**plot_kwargs, **kwargs)

# % Nav1.1 reduction (computed by perc_decrease) used as the hue
amp_df[NAV_PERC_LABEL] = perc_decrease(amp_df).astype(int)

with sns.plotting_context("notebook"):
    # reductions to show, one colour each
    plot_percs = 0, 50, 70
    color_palette = [
        GROUP_COLOR_D,
        GROUP_COLOR_A,
        sns.dark_palette(GROUP_COLOR_A, n_colors=11)[7],  # darker shade of GROUP_COLOR_A
    ]
    fig, ax = plt.subplots(figsize=(4, 2))
    plot_if_curve(amp_df, palette=color_palette, hue_order=plot_percs, lw=4, alpha=0.7)
    plt.legend(loc=(0.05, 0.8), ncol=2)
    save_fig("save/if_curve")

## TTX

Simulate TTX by scaling down **all** sodium conductances (Nap, NaT and Nav1.1) together

In [None]:
# 100 ms steps; `scale` turns AP counts into per-second rates (assuming ms)
dur = 100
scale = 1000/dur


def get_aps(ap):
    """Return the (soma, initiation site, terminal) firing rates; uses the global `scale`."""
    return ap["soma"].n*scale, ap["init"].n*scale, ap["comm"].n*scale


pv = get_pv(node_spacing=33.0, node_length=1., ais_L=26.5)

base_nav = reset_biophys(pv)
# the TTX sweep also scales Nav1.1 in the basal dendrites
base_nav["basal"] = pv.dend[0].gNav11bar_Nav11

# baseline conductances, read from the first section of each section list
soma_sec, ais_sec, dend_sec = pv.soma[0], pv.axon[0], pv.dend[0]
base_nap = {
    "somatic": soma_sec.gNap_Et2bar_Nap_Et2,
    "ais": ais_sec.gNap_Et2bar_Nap_Et2,
    "basal": dend_sec.gNap_Et2bar_Nap_Et2,
}
base_nat = {
    "somatic": soma_sec.gNaTs2_tbar_NaTs2_t,
    "basal": dend_sec.gNaTs2_tbar_NaTs2_t,
}
base_nata = ais_sec.gNaTa_tbar_NaTa_t


def set_sodium_fraction(pv, frac):
    """Scale all sodium conductances of `pv` (Nap, NaTs2, NaTa, Nav1.1) to `frac` x baseline.

    Baselines are read from the globals `base_nap`, `base_nat`, `base_nata` and `base_nav`
    recorded in the setup cell, so repeated calls do not compound.
    """
    for secs in ("somatic", "ais", "basal"):
        set_nrn_prop(pv, "gNap_Et2bar_Nap_Et2", base_nap[secs]*frac, secs=secs)
    for secs in ("somatic", "basal"):
        set_nrn_prop(pv, "gNaTs2_tbar_NaTs2_t", base_nat[secs]*frac, secs=secs)
    set_nrn_prop(pv, "gNaTa_tbar_NaTa_t", base_nata*frac, secs="ais")
    for secs in ("somatic", "ais", "nodes", "basal"):
        set_nrn_prop(pv, "gNav11bar_Nav11", base_nav[secs]*frac, secs=secs)


# fraction of sodium conductance blocked: 0.00, 0.05, ..., 1.00 (inclusive).
# np.arange(0, 1.1, 0.05) would also yield 1.05, i.e. a *negative* conductance.
ttx_arr = np.round(np.linspace(0, 1, 21), 2)

ttx_records = []
for amp in tqdm([0.25, 0.5, 0.75]):
    for ttx in tqdm(ttx_arr):
        na_act_frac = 1-ttx
        set_sodium_fraction(pv, na_act_frac)

        key_name = f"ttx_{pv.name}_{na_act_frac}_{amp}_{dur}"
        AP, x_df = get_cached_df(key_name, pv, amp, dur, shape_plot=True)
        ap_start_times = get_ap_times(x_df)

        aps_soma, aps_init, aps_comm = get_aps(AP)
        for section, rate in ((SOMA_LABEL, aps_soma), (AIS_LABEL, aps_init), (TERMINAL_LABEL, aps_comm)):
            ttx_records.append({
                CURRENT_LABEL: amp,
                FIRING_RATE_LABEL: rate,
                SECTION_LABEL: section,
                TTX_LABEL: ttx,
            })

# get_pv returns the same object for the same arguments, so put the baseline back
set_sodium_fraction(pv, 1.0)

# DataFrame.append was removed in pandas 2.0; build the frame once
ttx_df = pd.DataFrame(ttx_records)
ttx_df


In [None]:
with sns.plotting_context("notebook"):
    # x-axis in percent blocked
    TTX_PERC = f"{TTX_LABEL} (%)"
    ttx_df[TTX_PERC] = 100*ttx_df[TTX_LABEL]

    fig, ax = plt.subplots(figsize=(4, 2))
    # NOTE(review): soma, AIS and terminal rows share each (x, hue) point and no `style`
    #  is set, so seaborn draws their mean with an error band — confirm this is intended
    sns.lineplot(data=ttx_df, x=TTX_PERC, y=FIRING_RATE_LABEL, hue=CURRENT_LABEL, ax=ax)
    ax.set_xlim(0, 100)
    ax.set_ylim(0)
    # minor ticks every 10 % blocked and every 50 units of firing rate
    ax.set_xticks(range(0, 100, 10), minor=True)
    ax.set_yticks(range(0, 200, 50), minor=True)
    save_fig("save/ttx")

## Nav1.1 mutation

As in Berecki et al. 2019 
(Petrou model)

In [None]:
# nav1.1 mutation: firing of the same cell with wild-type Nav1.1 and with all Nav1.1
# conductance moved to the mutant mechanism (Nav11m), at two stimulus amplitudes

dur = 100
scale = 1000/dur
amp1 = 0.75
amp2 = 0.5


def get_aps(ap):
    """Return the (soma, initiation site, terminal) firing rates; uses the global `scale`."""
    return ap["soma"].n*scale, ap["init"].n*scale, ap["comm"].n*scale


# get the interneuron (object is in cache/memory if called with the same arguments)
pv = get_pv("mut", node_spacing=33.0, node_length=1., ais_L=26.5)

# NOTE(review): the mutant conductances set below remain on this cached object; confirm
#  reset_biophys also zeroes gNav11bar_Nav11m, otherwise re-running this cell gives a
#  "wild-type" run that still contains the mutant channel
base_nav = reset_biophys(pv)

# --- wild type (100 % Nav1.1) ---
amp = amp1
print('running')
t, v, AP, x_df = get_trace(pv, amp, dur, shape_plot=True)
ap_start_times = get_ap_times(x_df)
aps_soma, aps_init, aps_comm = get_aps(AP)
print(f"action potentials\namp = {amp} & 100 % -> {aps_soma:5.2f} | {aps_init:5.2f} | {aps_comm:5.2f}")
# mean firing rate from the inter-spike intervals (assumes AP times in ms)
fr = 1000/np.diff(ap_start_times).mean()

amp = amp2
t, v, AP, xa2_df = get_trace(pv, amp, dur, shape_plot=True)
apsa2_soma, apsa2_init, apsa2_comm = get_aps(AP)
print(f"amp = {amp} & 100 % -> {apsa2_soma:5.2f} | {apsa2_init:5.2f} | {apsa2_comm:5.2f}")

# --- mutant: wild-type Nav1.1 removed, mutant channel at the wild-type baseline density ---
amp = amp1
for nav_secs in ("somatic", "ais", "nodes"):
    set_nrn_prop(pv, "gNav11bar_Nav11", base_nav[nav_secs]*0, secs=nav_secs)
for nav_secs in ("somatic", "ais", "nodes"):
    set_nrn_prop(pv, "gNav11bar_Nav11m", base_nav[nav_secs], secs=nav_secs)
t, v, AP, x0_df = get_trace(pv, amp, dur, True)
aps0_soma, aps0_init, aps0_comm = get_aps(AP)
print(f"amp = {amp} & mutant -> {aps0_soma:5.2f} | {aps0_init:5.2f} | {aps0_comm:5.2f}")

# TODO: previously checked (disabled) expectation:
#  aps0_soma <= aps_soma+20 and aps0_comm < aps_comm and aps0_comm <= aps0_init

amp = amp2
t, v, AP, x0a2_df = get_trace(pv, amp, dur, shape_plot=True)
aps0a2_soma, aps0a2_init, aps0sa2_comm = get_aps(AP)
print(f"amp = {amp} & mutant -> {aps0a2_soma:5.2f} | {aps0a2_init:5.2f} | {aps0sa2_comm:5.2f}")




## Input resistance

In [None]:
pv = get_pv("ir")

# Input resistance = magnitude of the input impedance at 0 Hz (DC), at the middle of the soma.
# Impedance.compute uses the membrane state present when it is called, so the model is
# initialised (at h.v_init) *before* computing; a nonzero frequency would give the
# impedance at that frequency instead of the input resistance.
IMPEDANCE_FREQ_HZ = 0
h.finitialize(h.v_init)

# calculate input resistance using NEURON's impedance object
imp = h.Impedance()
imp.loc(0.5, sec=pv.soma[0])
imp.compute(IMPEDANCE_FREQ_HZ)
r = imp.input(0.5, sec=pv.soma[0])
print(f"Input resistance = {r:.2f} MΩ")

# Maximum propagation distance of action potential

Propagation is measured for the *N*th action potential generated at the soma (`action_potential_num` below) rather than the first, to account for sustained firing


In [None]:
# 1-based number of the somatic AP whose propagation is measured (a later AP, so that
# sustained firing rather than only the first spike is assessed)
action_potential_num = 7
ap_idx = action_potential_num - 1

prop_df = pd.DataFrame(
    index=pd.Index(fractions, name=NAV_FRAC_LABEL),
    columns=pd.Index(formatted_nav_loc_changes, name=NAV_SECTIONS_LABEL),
)

# ±5 window around the chosen AP of the first stored simulation.
#  note that the section doesn't matter as long as the fraction is 1 (i.e. a control sim)
first_result = list(amp_results.values())[0]
ap_times = get_ap_times(first_result["df"])
ap_time = ap_times[ap_idx]
time_window = (ap_time - 5, ap_time + 5)

for result in amp_results.values():
    frac, nav_loc, long_df = result[NAV_FRAC_LABEL], result[NAV_SECTIONS_LABEL], result["df"]
    # maximum propagation distance within the window
    max_prop_distance_idx, max_prop_distance = get_max_propagation(long_df, time=time_window)
    prop_df.loc[frac, nav_loc] = max_prop_distance

prop_df = prop_df.fillna(0)
prop_long_df = (
    prop_df.reset_index(drop=False)
    .melt(id_vars=[NAV_FRAC_LABEL], value_name=MAX_PROP_LABEL)
    .convert_dtypes()
)

prop_long_df.head()

In [None]:
def plot_max_propagation(prop_long_df, nav_changes, palette, figsize=(3, 2)):
    """Plot maximum AP propagation vs. % Nav1.1 reduction for a subset of section changes.

    `nav_changes` sets both the hue and style order; `palette` supplies one colour per
    entry. Returns (fig, ax).
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.lineplot(data=prop_long_df,
                 x=NAV_PERC_LABEL,
                 y=MAX_PROP_LABEL,
                 style=NAV_SECTIONS_LABEL,
                 style_order=nav_changes,
                 hue=NAV_SECTIONS_LABEL,
                 hue_order=nav_changes,
                 palette=palette,
                 alpha=1.,
                 lw=2,
                 ax=ax)
    ax.set_ylim(0)
    fig.tight_layout()
    ax.legend(title=NAV_SECTIONS_LABEL, loc=(0.01, 0.1))
    return fig, ax


# the % reduction must come from prop_long_df itself (not from another cell's table)
prop_long_df[NAV_PERC_LABEL] = perc_decrease(prop_long_df)
with sns.plotting_context("notebook"):
    figsize = (3, 2)
    palette = sns.color_palette("husl", n_colors=len(nav_loc_changes))

    # NOTE(review): `amp` in the file names is the last global `amp` that was set (in a
    #  top-to-bottom run, amp2 from the Nav1.1 mutation cell); confirm it matches the
    #  amplitude used to build `amp_results`
    fig, ax = plot_max_propagation(prop_long_df, select_nav_loc_changes, palette[::2], figsize)
    save_fig(f"save/fig_prop_{amp}")

    fig, ax = plot_max_propagation(prop_long_df, other_nav_changes, palette[1::2], figsize)
    save_fig(f"save/fig_prop_{amp}_supp")