In [None]:
%load_ext autoreload
%autoreload

# !nrnivmodl mechanisms
import bluepyopt as bpopt
import bluepyopt.ephys as ephys

import pprint
pp = pprint.PrettyPrinter(indent=2)

import matplotlib.pyplot as plt
%matplotlib notebook
import MEAutility as mu

import json
import numpy
import time
import numpy as np
import LFPy
import neuroplotlib as nplt

## Define extracellular electrodes

In [None]:
# Probe geometry: 20 contacts laid out on a straight line.
#   column 0 stays at 0, column 1 is evenly spaced over [-500, 1000],
#   column 2 is fixed at 20 (presumably x/y/z in micrometres - TODO confirm
#   against the MEAutility conventions).
mea_positions = np.zeros((20, 3))
mea_positions[:, 1] = np.linspace(-500, 1000, 20)
mea_positions[:, 2] = 20

# Wrap the raw positions in a MEAutility probe object ...
probe = mu.return_mea(
    info={'pos': mea_positions, 'center': False, 'plane': 'xy'},
)
# ... and hand it to LFPy as the extracellular recording electrode.
electrode = LFPy.RecExtElectrode(
    probe=probe,
    method='linesource',
)

In [None]:
# Draw the probe layout as a visual sanity check of the contact positions.
# `ax` keeps whatever plot_probe returns (presumably a matplotlib Axes - confirm).
ax = mu.plot_probe(probe)

## Setup protocols, fitness, and evaluator

In [None]:
# Reference parameter values, used below both to compute the target feature
# values and to run the fitness protocols.
# Keys have the form '<parameter name>.<location>'; the locations that appear
# here are 'apical', 'axonal' and 'somatic'.
# (Presumably the published "release" parameter set of the L5PC model - confirm.)
release_params = {
    # --- apical ---
    'gNaTs2_tbar_NaTs2_t.apical': 0.026145,
    'gSKv3_1bar_SKv3_1.apical': 0.004226,
    'gImbar_Im.apical': 0.000143,
    # --- axonal ---
    'gNaTa_tbar_NaTa_t.axonal': 3.137968,
    'gK_Tstbar_K_Tst.axonal': 0.089259,
    'gamma_CaDynamics_E2.axonal': 0.002910,
    'gNap_Et2bar_Nap_Et2.axonal': 0.006827,
    'gSK_E2bar_SK_E2.axonal': 0.007104,
    'gCa_HVAbar_Ca_HVA.axonal': 0.000990,
    'gK_Pstbar_K_Pst.axonal': 0.973538,
    'gSKv3_1bar_SKv3_1.axonal': 1.021945,
    'decay_CaDynamics_E2.axonal': 287.198731,
    'gCa_LVAstbar_Ca_LVAst.axonal': 0.008752,
    # --- somatic ---
    'gamma_CaDynamics_E2.somatic': 0.000609,
    'gSKv3_1bar_SKv3_1.somatic': 0.303472,
    'gSK_E2bar_SK_E2.somatic': 0.008407,
    'gCa_HVAbar_Ca_HVA.somatic': 0.000994,
    'gNaTs2_tbar_NaTs2_t.somatic': 0.983955,
    'decay_CaDynamics_E2.somatic': 210.485284,
    'gCa_LVAstbar_Ca_LVAst.somatic': 0.000333
}

In [7]:
import l5pc_model
import l5pc_evaluator

# Which feature set to fit. "extra" wires the extracellular electrode into the
# protocols and the simulator; the author lists 'soma' and 'bap' as alternatives.
feature_set = "extra"  # 'soma'/'bap'

morphology = ephys.morphologies.NrnFileMorphology('morphology/C060114A7.asc', do_replace_axon=True)

# Read the parameter configuration through a context manager so the file
# handle is closed deterministically (the bare `json.load(open(...))` leaked it).
# NOTE(review): `param_configs` is not referenced again in the visible notebook.
with open('config/parameters.json') as param_file:
    param_configs = json.load(param_file)

parameters = l5pc_model.define_parameters()
mechanisms = l5pc_model.define_mechanisms()

l5pc_cell = ephys.models.LFPyCellModel('l5pc',
                                       v_init=-65.,
                                       morph=morphology,
                                       mechs=mechanisms,
                                       params=parameters)

# Only non-frozen parameters are exposed to the evaluator / optimiser.
param_names = [param.name for param in l5pc_cell.params.values() if not param.frozen]

# Protocols and simulator both need the electrode only for the "extra" feature
# set, so a single branch configures the two together.
if feature_set == "extra":
    fitness_protocols = l5pc_evaluator.define_protocols(electrode)
    sim = ephys.simulators.LFPySimulator(LFPyCellModel=l5pc_cell, cvode_active=True, electrode=electrode)
else:
    fitness_protocols = l5pc_evaluator.define_protocols()
    sim = ephys.simulators.LFPySimulator(LFPyCellModel=l5pc_cell, cvode_active=True)

# Simulate the cell with the release parameters to obtain the target feature
# values; the call returns the feature file and the responses it computed.
feature_file, resp = l5pc_evaluator.compute_feature_values(release_params, l5pc_cell, fitness_protocols, sim,
                                                           feature_set=feature_set, probe=probe, channels=None,
                                                           feature_folder='config/features')

fitness_calculator = l5pc_evaluator.define_fitness_calculator(fitness_protocols, feature_file=feature_file,
                                                              probe=probe)

evaluator = ephys.evaluators.CellEvaluator(
    cell_model=l5pc_cell,
    param_names=param_names,
    fitness_protocols=fitness_protocols,
    fitness_calculator=fitness_calculator,
    sim=sim)

Running bAP


  'not disabling banner' % nrnpy_path)


Running Step3


  'not disabling banner' % nrnpy_path)


Running Step2


  'not disabling banner' % nrnpy_path)


Running Step1


  'not disabling banner' % nrnpy_path)


Step3 Num features: 10
Step3.soma.AP_height
17.017958284437842
Step3.soma.AHP_slow_time
0.21586736652864358
Step3.soma.ISI_CV
0.09902926906713798
Step3.soma.doublet_ISI
21.100000000004798
Step3.soma.adaptation_index2
0.007140518671968469
Step3.soma.mean_frequency
16.08767784426209
Step3.soma.AHP_depth_abs_slow
-55.903842685412464
Step3.soma.AP_width
0.8593749999994067
Step3.soma.time_to_first_spike
10.600000000093019
Step3.soma.AHP_depth_abs
-56.75947804849142
Step2 Num features: 10
Step2.soma.AP_height
26.49756615115496
Step2.soma.AHP_slow_time
0.1494249779232803
Step2.soma.ISI_CV
0.03602948430867007
Step2.soma.doublet_ISI
44.60000000001014
Step2.soma.adaptation_index2
-0.001441296000788097
Step2.soma.mean_frequency
8.846333975132124
Step2.soma.AHP_depth_abs_slow
-60.81958798459943
Step2.soma.AP_width
0.8882352941170197
Step2.soma.time_to_first_spike
23.000000000095838
Step2.soma.AHP_depth_abs
-60.693804249171926
Step1 Num features: 150
Step1.soma.AP_height
27.891857515758772
Step1.so

In [None]:
# List the name of the first feature attached to each objective of the
# fitness calculator.
for objective in fitness_calculator.objectives:
    print(objective.features[0].name)

In [None]:
# Run every fitness protocol with the release parameters and report the
# elapsed wall-clock time in seconds.
t_start = time.time()
LFPy_responses = evaluator.run_protocols(protocols=fitness_protocols.values(), 
                                         param_values=release_params)
t_end = time.time()
print(t_end-t_start)

In [None]:
# Display the responses returned by run_protocols as the cell output.
LFPy_responses

### Visualize waveforms / features

In [None]:
# First feature of the last objective; the cells below use it as the
# extracellular feature (assumes the extracellular objective is the last one
# - TODO confirm).
extrafeat = fitness_calculator.objectives[-1].features[0]

In [None]:
# Evaluate the feature on the simulated responses. With return_waveforms=True
# the call yields two values, unpacked here as the feature value and the
# waveforms it was computed from.
feat, mean_wf = extrafeat.calculate_feature(LFPy_responses, return_waveforms=True)

In [None]:
# Inspect the waveform array dimensions (later cells iterate over its rows,
# presumably one row per recording channel - confirm).
mean_wf.shape

In [None]:
# Trough (`tr`) and peak (`pk`) indices for each waveform.
# NOTE(review): this calls a private helper (leading underscore), which may
# change between bluepyopt versions; the meaning of the second positional
# argument (False) is not visible here - TODO confirm.
tr, pk = bpopt.ephys.extra_features_utils._get_trough_and_peak_idx(mean_wf, False)

In [None]:
# Overlay the first ten waveforms, each scaled by its own maximum absolute
# value, and mark the detected trough ('--') and peak ('-.') positions in the
# same colour as the trace.
plt.figure()
cmap = plt.get_cmap('viridis')
idxs = range(0, 10)
for i, mw in enumerate(mean_wf):
    if i not in idxs:
        continue  # only the selected rows are drawn
    # Colour encodes the row position within the full waveform array.
    color = cmap(i/len(mean_wf))
    plt.plot(mw / np.max(np.abs(mw)), color=color, lw=0.3)
    for marker_pos, line_style in ((tr[i], '--'), (pk[i], '-.')):
        plt.axvline(marker_pos, ls=line_style, color=color)

In [None]:
# Scale every row by its own maximum absolute value (axis 1; keepdims keeps the
# divisor broadcastable against mean_wf), then draw the normalised waveforms on
# the probe layout.
# NOTE(review): a row that is entirely zero would divide by zero here.
mean_wf_norm = mean_wf / np.max(np.abs(mean_wf), 1, keepdims=True)
mu.plot_mea_recording(mean_wf_norm, probe)

# Same plot with the raw (un-normalised) amplitudes for comparison.
mu.plot_mea_recording(mean_wf, probe)

### TODO: make sure all values are correct

In [None]:
# Compute the listed extracellular waveform features for mean_wf and show them
# as the cell output. The second argument is extrafeat.fs scaled by 1000
# (presumably a kHz -> Hz conversion - TODO confirm the units expected).
bpopt.ephys.extra_features_utils.calculate_features(mean_wf, extrafeat.fs * 1000, 
                                                    feature_names=['peak_to_valley', 'peak_trough_ratio', 'halfwidth',
                                                                   'neg_peak_relative', 'pos_peak_relative', 
                                                                   'neg_peak_diff', 'pos_peak_diff'])

In [None]:
# Plot the individual samples (star markers) of the waveform at row index 3.
wf = mean_wf[3]
plt.figure()
plt.plot(wf, '*')

In [None]:
def plot_responses(responses):
    """Plot every non-MEA response in its own vertically stacked subplot.

    Responses whose name contains 'MEA' are not plotted; their names are
    printed instead. The remaining responses are drawn in name order, one
    subplot per response, using the 'time' and 'voltage' entries of each
    response.

    Args:
        responses: mapping from response name to an object that supports
            ``response['time']`` and ``response['voltage']``.

    Returns:
        None. If there is nothing to plot (no non-MEA responses), no figure
        is created.
    """
    resp_no_mea = {}
    for resp_name, response in sorted(responses.items()):
        if 'MEA' in resp_name:
            print(resp_name)
        else:
            resp_no_mea[resp_name] = response

    if not resp_no_mea:
        # plt.subplots cannot build a grid with zero rows.
        return

    # squeeze=False always yields a 2-D array of Axes. With the default
    # squeezing, a single response gives back a bare Axes object and
    # indexing it fails.
    fig, axes = plt.subplots(len(resp_no_mea), figsize=(10, 10), squeeze=False)
    for ax, (resp_name, response) in zip(axes[:, 0], sorted(resp_no_mea.items())):
        ax.plot(response['time'], response['voltage'], label=resp_name)
        ax.set_title(resp_name)
    fig.tight_layout()
    fig.show()

In [None]:
# Plot the non-MEA traces obtained with the release parameters.
plot_responses(LFPy_responses)

In [None]:
# plot one MEA response
plt.figure()
mea_response = LFPy_responses['Step3.MEA.LFP']
_ = plt.plot(mea_response['time'], mea_response['voltage'].T)

## Optimization

In [None]:
# Optimisation settings: `offspring_size` is passed to DEAPOptimisation and
# `max_ngen` to opt.run in the cells below.
offspring_size = 250
max_ngen = 100

In [None]:
import multiprocessing

# Pool of 4 worker processes; its `map` is handed to the optimiser as the
# map function used for the evaluations.
# NOTE(review): the pool is never closed in this notebook - call
# pool.close() / pool.join() once the optimisation has finished.
pool = multiprocessing.Pool(processes=4)

map_function = pool.map

In [None]:
# Build the DEAP-based optimiser from the evaluator, the offspring size and
# the pool-backed map function defined above.
opt = bpopt.optimisations.DEAPOptimisation(                                     
    evaluator=evaluator,                                                            
    offspring_size=offspring_size,
    map_function=map_function) 

In [None]:
# Run the optimisation with max_ngen as the generation limit and print the
# elapsed wall-clock time in seconds.
# NOTE(review): cp_filename points into 'checkpoints/'; presumably that
# directory must already exist - confirm before a long run.
t_start = time.time()
final_pop, halloffame, log, hist = opt.run(max_ngen=max_ngen, cp_filename='checkpoints/checkpoint_extra.pkl')
t_stop = time.time()
print('Optimization time', t_stop - t_start)

In [None]:
# Turn the first hall-of-fame entry (presumably the best individual - confirm)
# into a {parameter name: value} dict.
best_params = evaluator.param_dict(halloffame[0])
# PrettyPrinter.pprint writes to stdout itself and returns None, so wrapping
# it in print() only appended a stray "None" line to the output.
pp.pprint(best_params)
# Re-run every fitness protocol with the optimised parameters.
best_responses = evaluator.run_protocols(protocols=fitness_protocols.values(), param_values=best_params)

In [None]:
# Plot the non-MEA traces obtained with the optimised parameters.
plot_responses(best_responses)