## 4) Compare optimization results

In this notebook, we use the previously optimized parameters to compare the performance of the different optimization strategies (`soma`, `bap`, and `extra`) against the ground-truth parameters.

In [None]:
%load_ext autoreload
%autoreload

#!nrnivmodl mechanisms
import bluepyopt as bpopt
import bluepyopt.ephys as ephys

import matplotlib.pyplot as plt
import MEAutility as mu
import json
import numpy
import time
import numpy as np
import LFPy
from pathlib import Path
import pandas as pd
import os
import pickle

%matplotlib notebook

In [None]:
import l5pc_model
import l5pc_evaluator
import l5pc_plot

In [None]:
sample_id = 3  # row position in random.csv; valid range [0, ..., n_samples]
offspring_size, max_ngen = 250, 50  # optimization settings, also encoded in the checkpoint file names
channels = [0, 6, 7, 10, 15]  # channel indices passed to prepare_optimization
nchannels = len(channels)

In [None]:
random_params_file = 'config/params/random.csv'
# One row per random ground-truth parameter set, indexed by the CSV's 'index' column.
random_params = pd.read_csv(filepath_or_buffer=random_params_file, index_col='index')
# Select the sample by row position (not by index label); the result is a Series.
gt_params = random_params.iloc[sample_id, :]

In [None]:
# Ground-truth parameter values of the selected sample (rich display as the cell's last expression).
gt_params

In [None]:
# Plain dict of ground-truth values keyed by the column names of random.csv.
# Rebuilt from `random_params` (not from `gt_params` itself), so re-running this
# cell works: a dict has no `.to_dict()`.
gt_params = random_params.iloc[sample_id].to_dict()

In [None]:
# Checkpoint files written by the three optimization strategies for this sample.
checkpoints_folder = Path('checkpoints/')


def _checkpoint_file(strategy):
    """Return the checkpoint path of `strategy` ('soma', 'bap' or 'extra') for the current settings."""
    return (checkpoints_folder / f'random_{sample_id}'
            / f'{strategy}_off{offspring_size}_ngen{max_ngen}_{nchannels}chan.pkl')


cp_soma_file = _checkpoint_file('soma')
cp_bap_file = _checkpoint_file('bap')
cp_extra_file = _checkpoint_file('extra')

In [None]:
def _load_checkpoint(path):
    """Unpickle one optimization checkpoint and close the file handle afterwards.

    NOTE: unpickling can execute arbitrary code -- only load checkpoints you produced.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


cp_soma = _load_checkpoint(cp_soma_file)
cp_bap = _load_checkpoint(cp_bap_file)
cp_extra = _load_checkpoint(cp_extra_file)

In [None]:
# Hall of fame of each checkpoint; entry [0] is decoded as the best parameter set further down.
hof_soma, hof_bap, hof_extra = (checkpoint['halloffame'] for checkpoint in (cp_soma, cp_bap, cp_extra))

In [None]:
# Generation counter stored in the 'extra' checkpoint (presumably the last generation the run reached -- confirm).
cp_extra['generation']

In [None]:
# Build the 'extra' evaluator; it is reused below to decode hall-of-fame entries
# and to simulate all parameter sets.
prep = l5pc_evaluator.prepare_optimization(
    'extra',
    sample_id,
    offspring_size=offspring_size,
    channels=channels,
    map_function=None,
)
evaluator, fitness_calculator, fitness_protocols = (
    prep[key] for key in ('evaluator', 'objectives_calculator', 'protocols')
)

In [None]:
# Decode the top hall-of-fame individual of each strategy into a {param: value} dict.
best_soma = evaluator.param_dict(hof_soma[0])
best_bap = evaluator.param_dict(hof_bap[0])
best_extra = evaluator.param_dict(hof_extra[0])
# `best_params` remains an alias of the last decoded set, for any later cell that uses it.
best_params = best_extra

In [None]:
# Best parameters found by the 'soma' strategy, keyed by parameter name.
best_soma

In [None]:
# Best parameters found by the 'bap' strategy, keyed by parameter name.
best_bap

In [None]:
# Best parameters found by the 'extra' strategy, keyed by parameter name.
best_extra

In [None]:
# Ground-truth parameters (dict) to compare against the optimized sets.
gt_params

In [None]:
def _relative_errors(reference, estimate):
    """Return {param: |(reference - estimate) / reference|} for every key of `reference`.

    `estimate` must contain every key of `reference`. A zero reference value is not
    guarded against (division by zero).
    """
    return {name: np.abs((ref_value - estimate[name]) / ref_value)
            for name, ref_value in reference.items()}


rel_error_soma = _relative_errors(gt_params, best_soma)
rel_error_bap = _relative_errors(gt_params, best_bap)
rel_error_extra = _relative_errors(gt_params, best_extra)

In [None]:
# Per-parameter relative error of the 'soma' solution w.r.t. the ground truth.
rel_error_soma

In [None]:
# Per-parameter relative error of the 'bap' solution w.r.t. the ground truth.
rel_error_bap

In [None]:
# Per-parameter relative error of the 'extra' solution w.r.t. the ground truth.
rel_error_extra

In [None]:
# Summed relative error per strategy, printed in the order: soma, bap, extra.
print(*(np.sum(list(errors.values()))
        for errors in (rel_error_soma, rel_error_bap, rel_error_extra)))

In [None]:
def plot_evolution(logbook, color, label=None, ax=None, stat='min'):
    """Plot one per-generation statistic of an optimization logbook.

    Parameters
    ----------
    logbook : iterable of records
        One record per generation; each record must support item access by 'gen'
        and by `stat` (presumably a DEAP logbook -- confirm).
    color : str
        Matplotlib colour of the line.
    label : str, optional
        Legend label of the line.
    ax : matplotlib Axes, optional
        Axes to draw on. When omitted, a new figure/axes is created and labelled.
    stat : str, default 'min'
        Logbook statistic to plot (e.g. 'min', 'avg', 'max').

    Returns
    -------
    The Axes that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots()
        ax.set_xlabel('generation')
        ax.set_ylabel(stat)

    records = list(logbook)
    gens = np.array([record['gen'] for record in records])
    values = np.array([record[stat] for record in records])

    ax.plot(gens, values, color=color, label=label)
    return ax

In [None]:
# Overlay the per-generation curves of the three strategies on one axes.
ax = None
for logbook, color, label in ((cp_soma['logbook'], 'C0', 'soma'),
                              (cp_bap['logbook'], 'C1', 'bap'),
                              (cp_extra['logbook'], 'C2', 'extra')):
    ax = plot_evolution(logbook, color=color, label=label, ax=ax)
ax.legend()

In [None]:
# Simulate every fitness protocol for the ground truth and for each strategy's best parameters.
original_responses, best_responses_soma, best_responses_bap, best_responses_extra = [
    evaluator.run_protocols(protocols=fitness_protocols.values(), param_values=param_set)
    for param_set in (gt_params, best_soma, best_bap, best_extra)
]

In [None]:
# Ground truth and all three optimized models; list order: original, bap, soma, extra.
l5pc_plot.plot_multiple_responses([original_responses, best_responses_bap, best_responses_soma, best_responses_extra])

In [None]:
# Ground truth vs. the 'extra' (extracellular) solution only.
l5pc_plot.plot_multiple_responses([original_responses, best_responses_extra])

## Compare extracellular action potentials

In [None]:
## HELPER FUNCTIONS ##
def _construct_somatic_efel_trace(responses, somatic_recording_name, stim_start, stim_end):
    """Build the single-trace dict that eFEL expects from one recorded response.

    Returns None when `somatic_recording_name` is absent from `responses` or maps to
    None. Otherwise returns a dict with 'T' (the response's 'time'), 'V' (its
    'voltage') and one-element 'stim_start' / 'stim_end' lists.
    """
    response = responses[somatic_recording_name] if somatic_recording_name in responses else None
    if response is None:
        return None

    return {
        'T': response['time'],
        'V': response['voltage'],
        'stim_start': [stim_start],
        'stim_end': [stim_end],
    }

def _setup_efel(threshold=None, interp_step=None, double_settings=None, int_settings=None):
    """Reset eFEL's settings, then apply the requested ones.

    Application order: spike threshold, 'interp_step', the entries of
    `double_settings`, then the entries of `int_settings`. Arguments left as None
    are skipped.
    """
    import efel

    efel.reset()
    if threshold is not None:
        efel.setThreshold(threshold)

    double_pairs = []
    if interp_step is not None:
        double_pairs.append(('interp_step', interp_step))
    if double_settings is not None:
        double_pairs.extend(double_settings.items())
    for name, value in double_pairs:
        efel.setDoubleSetting(name, value)

    if int_settings is not None:
        for name, value in int_settings.items():
            efel.setIntSetting(name, value)
            

def _get_peak_times(responses, somatic_recording_name, stim_start, stim_end, raise_warnings=False, **efel_kwargs):
    """Return eFEL's 'peak_time' feature for one recorded response.

    Returns None when the recording is missing from `responses` or is None.
    eFEL is configured with `_setup_efel(**efel_kwargs)` before the extraction and
    reset afterwards.
    """
    trace = _construct_somatic_efel_trace(responses, somatic_recording_name, stim_start, stim_end)
    if trace is None:
        return None

    _setup_efel(**efel_kwargs)

    import efel
    feature_values = efel.getFeatureValues([trace], ['peak_time'], raise_warnings=raise_warnings)
    result = feature_values[0]['peak_time']
    efel.reset()
    return result

def _interpolate_response(response, fs=20.):
    """Resample a response onto a uniform time grid by linear interpolation.

    Parameters
    ----------
    response : dict
        'time': 1-D sample times; 'voltage': 1-D trace or array with time on the
        last axis (e.g. channels x samples).
    fs : float
        Target rate in samples per unit of 'time' (kHz when 'time' is in ms).

    Returns
    -------
    dict with 'time' = arange(min(time), max(time), 1 / fs) (end point excluded)
    and the interpolated 'voltage'.
    """
    from scipy.interpolate import interp1d

    times = response['time']
    # Time is the last axis; axis=-1 covers both 1-D traces and (channels, samples) arrays.
    interpolator = interp1d(times, response['voltage'], axis=-1)
    new_times = np.arange(np.min(times), np.max(times), 1. / fs)

    return {'time': new_times, 'voltage': interpolator(new_times)}


def _filter_response(response, fcut=[0.5, 6000], order=2, filt_type='lfilter'):
    """Butterworth-filter a response along its time (last) axis.

    Parameters
    ----------
    response : dict
        'time': sample times, assumed to be in ms (the rate is computed as
        1000 / mean(dt), i.e. in Hz). 'voltage': 1-D trace or array with time on
        the last axis.
    fcut : float or length-2 sequence
        A scalar gives a high-pass filter with that cutoff (Hz); a pair gives a
        band-pass filter.
    order : int
        Butterworth filter order.
    filt_type : str
        'filtfilt' for forward-backward filtering; any other value uses `lfilter`.

    Returns
    -------
    dict with the original 'time' and the filtered 'voltage'.
    """
    import scipy.signal as ss

    fs = 1 / np.mean(np.diff(response['time'])) * 1000
    nyquist = fs / 2.
    trace = np.asarray(response['voltage'])

    # np.floating instead of the `np.float` alias, which NumPy 1.24 removed
    # (accessing it raises AttributeError).
    if isinstance(fcut, (float, int, np.floating, np.integer)):
        btype = 'highpass'
        band = fcut / nyquist
    else:
        assert isinstance(fcut, (list, tuple, np.ndarray)) and len(fcut) == 2
        btype = 'bandpass'
        band = np.array(fcut) / nyquist

    b, a = ss.butter(order, band, btype=btype)

    # Time is the last axis for both 1-D and (channels, samples) inputs.
    if filt_type == 'filtfilt':
        filtered = ss.filtfilt(b, a, trace, axis=-1)
    else:
        filtered = ss.lfilter(b, a, trace, axis=-1)

    return {'time': response['time'], 'voltage': filtered}


def _upsample_wf(waveforms, upsample):
    """Upsample `waveforms` by the integer factor `upsample` along its last axis (polyphase resampling)."""
    from scipy.signal import resample_poly

    last_axis = waveforms.ndim - 1
    return resample_poly(waveforms, upsample, 1, axis=last_axis)


def _get_waveforms(response, peak_times, snippet_len_ms):
    """Cut fixed-length snippets around each peak time from a uniformly sampled response.

    Parameters
    ----------
    response : dict
        'time': uniformly spaced sample times (ms, so the derived rate is in kHz).
        'voltage': 1-D trace or (channels, samples) array.
    peak_times : array-like of float
        Reference times of the snippets, in the units of 'time'.
    snippet_len_ms : float or (before, after) pair
        A pair gives the duration kept before and after each peak. A scalar gives
        the total snippet duration; it is converted to samples and split with the
        odd sample (if any) placed before the peak.

    Returns
    -------
    np.ndarray of shape (n_peaks, n_channels, n_samples)
        Samples outside the recording are zero-filled; a peak outside the
        recording yields an all-zero snippet.
    """
    times = np.asarray(response['time'])
    traces = np.asarray(response['voltage'])

    dt = np.diff(times)
    assert np.std(dt) < 0.001 * np.mean(dt), "Sampling frequency must be constant"
    fs = 1. / np.mean(dt)  # kHz

    # Nearest sample to each peak, counted from the first sample of the recording.
    reference_frames = np.round((np.asarray(peak_times, dtype=float) - times[0]) * fs).astype(int)

    if isinstance(snippet_len_ms, (tuple, list, np.ndarray)):
        snippet_len_before = int(snippet_len_ms[0] * fs)
        snippet_len_after = int(snippet_len_ms[1] * fs)
    else:
        # Convert the total duration to samples first, then split it in samples.
        total_samples = int(snippet_len_ms * fs)
        snippet_len_before = (total_samples + 1) // 2
        snippet_len_after = total_samples - snippet_len_before
    snippet_len_total = snippet_len_before + snippet_len_after

    if traces.ndim == 1:
        traces = traces[np.newaxis, :]
    num_channels = traces.shape[0]
    num_frames = len(times)

    waveforms = np.zeros((len(reference_frames), num_channels, snippet_len_total), dtype=traces.dtype)
    for i, frame in enumerate(reference_frames):
        if not 0 <= frame < num_frames:
            continue  # peak outside the recording: keep the all-zero snippet
        start = frame - snippet_len_before
        stop = frame + snippet_len_after
        # Clip the window to the recording and shift the destination by the clipped amount.
        dest_start = max(0, -start)
        dest_stop = snippet_len_total - max(0, stop - num_frames)
        waveforms[i, :, dest_start:dest_stop] = traces[:, max(start, 0):min(stop, num_frames)]

    return waveforms

In [None]:
def calculate_eap(responses, protocol_name, protocols, fs=20, fcut=1,
                  ms_cut=[2, 10], upsample=10, skip_first_spike=True, skip_last_spike=True,
                  raise_warnings=False, verbose=False, **efel_kwargs):
    """Mean extracellular action potential (EAP) of a step protocol.

    Spike peak times are detected with eFEL on '<protocol>.soma.v'. The
    '<protocol>.MEA.LFP' recording is then interpolated to a uniform grid (if it
    is not already uniform), optionally filtered, cut into snippets around each
    spike, averaged and optionally upsampled.

    Parameters
    ----------
    responses : dict
        Simulated responses; each recording holds 'time' and 'voltage'.
    protocol_name : str
        Name of a step protocol (must contain "Step").
    protocols : dict
        Protocols by name; the first stimulus provides step_delay / step_duration.
    fs : float
        Rate (samples per ms, i.e. kHz) used when the LFP time base must be interpolated.
    fcut : float, pair or None
        Cutoff(s) forwarded to `_filter_response`; None disables filtering.
    ms_cut : (before, after)
        Snippet window around each spike, forwarded to `_get_waveforms`.
    upsample : int or None
        Integer upsampling factor of the mean waveform; None disables it.
    skip_first_spike, skip_last_spike : bool
        Drop the first / last spike, each only while more than one spike remains.
    raise_warnings : bool
        Forwarded to eFEL.
    verbose : bool
        Print which processing steps run.
    **efel_kwargs
        Forwarded to `_setup_efel`; 'threshold' defaults to -20 unless given.

    Returns
    -------
    np.ndarray or None
        Mean waveform (channels x samples), or None when the LFP recording is
        missing, the somatic recording is missing, or no spike was detected.
    """
    assert "Step" in protocol_name
    stimulus = protocols[protocol_name].stimuli[0]
    stim_start = stimulus.step_delay
    stim_end = stimulus.step_delay + stimulus.step_duration
    # Default spike-detection threshold; an explicit `threshold=` from the caller wins.
    efel_kwargs.setdefault('threshold', -20)

    somatic_recording_name = f'{protocol_name}.soma.v'
    extra_recording_name = f'{protocol_name}.MEA.LFP'

    response = responses.get(extra_recording_name)
    if response is None:
        return None

    peak_times = _get_peak_times(responses, somatic_recording_name, stim_start, stim_end,
                                 raise_warnings=raise_warnings, **efel_kwargs)
    if peak_times is None or len(peak_times) == 0:
        # No somatic recording or no detected spike: nothing to average.
        return None

    peak_times = np.asarray(peak_times)
    if skip_first_spike and len(peak_times) > 1:
        peak_times = peak_times[1:]
    if skip_last_spike and len(peak_times) > 1:
        peak_times = peak_times[:-1]

    dt = np.diff(response['time'])
    if np.std(dt) > 0.001 * np.mean(dt):
        # Non-uniform time steps: resample so snippets can be cut by sample index.
        assert fs is not None
        if verbose:
            print('interpolate')
        response = _interpolate_response(response, fs=fs)

    if fcut is not None:
        if verbose:
            print('filter enabled')
        response = _filter_response(response, fcut=fcut)
    elif verbose:
        print('filter disabled')

    mean_wf = np.mean(_get_waveforms(response, peak_times, ms_cut), axis=0)

    if upsample is not None:
        if verbose:
            print('upsample')
        assert upsample > 0
        mean_wf = _upsample_wf(mean_wf, int(upsample))

    return mean_wf

In [None]:
# Mean EAP of each model for the "Step1" protocol, scaled by 1000
# (presumably mV -> uV; confirm the unit of the MEA.LFP recording).
mean_wf_extra, mean_wf_bap, mean_wf_soma, mean_wf_original = (
    calculate_eap(step_responses, "Step1", evaluator.fitness_protocols) * 1000
    for step_responses in (best_responses_extra, best_responses_bap,
                           best_responses_soma, original_responses)
)

In [None]:
def _normalize_per_channel(waveforms):
    """Divide each row (channel) of a 2-D (channels, samples) array by its peak absolute value.

    An all-zero channel divides by zero (NaN result).
    """
    return waveforms / np.max(np.abs(waveforms), axis=1, keepdims=True)


# Each model's EAP is scaled by its OWN per-channel peak so the shapes can be overlaid.
mean_wf_extra_n = _normalize_per_channel(mean_wf_extra)
mean_wf_bap_n = _normalize_per_channel(mean_wf_bap)
mean_wf_soma_n = _normalize_per_channel(mean_wf_soma)
mean_wf_original_n = _normalize_per_channel(mean_wf_original)

In [None]:
# Probe definition stored in this sample's features folder.
probe_file = Path('config', 'features', f'random_{sample_id}', 'probe.json')
probe, electrode = l5pc_evaluator.define_electrode(probe_file)

In [None]:
# Common vertical scale: 1.5x the largest absolute value of the normalized ground-truth EAP.
vscale = np.abs(mean_wf_original_n).max() * 1.5

In [None]:
# Overlay the normalized mean EAPs on the probe layout: ground truth (black) and the
# soma (C0), bap (C1) and extra (C2) solutions. All four use the shared `vscale`, so the
# overlaid traces are drawn on the same vertical scale.
ax_extra = mu.plot_mea_recording(mean_wf_original_n, probe, colors='k', lw=2, vscale=vscale)
mu.plot_mea_recording(mean_wf_soma_n, probe, colors='C0', ax=ax_extra, vscale=vscale)
mu.plot_mea_recording(mean_wf_bap_n, probe, colors='C1', ax=ax_extra, vscale=vscale)
mu.plot_mea_recording(mean_wf_extra_n, probe, colors='C2', ax=ax_extra, vscale=vscale);

## Can the models reproduce BAP-activated calcium spikes?

**TODO:** show some functional output, e.g. whether each optimized model reproduces the BAP-activated calcium spikes of the ground-truth model.