# Generate experimental features - Cell 1 - 010321

In [None]:
import json
import matplotlib.pyplot as plt

import bluepyopt as bpopt
import bluepyopt.ephys as ephys

import model
import evaluator
import time
import neuron
import plotting
import MEAutility as mu
from pprint import pprint
import numpy as np

import sys
import shutil

from pathlib import Path
import os

%matplotlib notebook

## 1) Generate features and protocols with BluePyEfe

In [None]:
# Make the helpers in ../efeatures_extraction importable (the
# `extraction_tools` imports below presumably come from there — confirm).
# Only append once, so re-running this cell does not keep growing sys.path.
efeatures_extraction_dir = '../efeatures_extraction'
if efeatures_extraction_dir not in sys.path:
    sys.path.append(efeatures_extraction_dir)

In [None]:
# Root folder of this cell's experimental data.
data_folder = Path("../data/experimental/cell1_210301/")

In [None]:
# Extracellular data for this cell lives in its own subfolder.
extra_folder = data_folder / "extracellular"

In [None]:
# Load the probe description and build a MEAutility probe from it.
with (extra_folder / "probe.json").open() as f:
    probe_info = json.load(f)
    # Override the stored pitch (presumably in um, MEAutility convention — confirm).
    probe_info["pitch"] = 17.5
    probe = mu.return_mea(info=probe_info)

In [None]:
# Extracellular action-potential template and electrode locations.
# NOTE(review): eap is later reduced with np.ptp along axis 1, so it is
# presumably (n_channels, n_samples) — confirm.
eap = np.load(extra_folder / "template.npy")
locations = np.load(extra_folder / "locations.npy")
fs = 20000  # sampling frequency, presumably in Hz — confirm

In [None]:
# Plot the EAP template on the probe layout.
ax_eap = mu.plot_mea_recording(eap, probe)

In [None]:
# Output folder for BluePyEfe features/protocols and input folder with the
# patch-clamp .wcp files. Both are derived from `data_folder`, so the cell's
# directory is spelled out in only one place.
efeatures_output_directory = data_folder / "efeatures"
ephys_dir = data_folder / "patch_data"

In [None]:
from bluepyefe.extract import read_recordings, extract_efeatures_at_targets, compute_rheobase,\
    group_efeatures, create_feature_protocol_files, convert_legacy_targets
from bluepyefe.plotting import plot_all_recordings_efeatures

from extraction_tools import build_wcp_metadata, wcp_reader, get_targets, ecodes_wcp_timings

In [None]:
# select files for different repetitions

In [None]:
# Run 1 is left out because its rheobase differs from the other runs.
runs = [2, 3, 4]

# eCode protocol name -> index used in the .wcp file name
# (cell1_run<run>.<index>.wcp).
ecode_to_index = dict(
    IDthres=0,
    firepattern=1,
    IV=2,
    IDrest=3,
    APWaveform=4,
    HyperDepol=5,
    sAHP=6,
    PosCheops=7,
)

# One {ecode: .wcp path} mapping per repetition, in the order of `runs`.
files_list = [
    {
        ecode: Path(ephys_dir) / f"cell1_run{run}.{index}.wcp"
        for ecode, index in ecode_to_index.items()
    }
    for run in runs
]

In [None]:
# define timings for this experiment
# Per-eCode stimulus timing markers: 'ton'/'toff' presumably mark stimulus
# onset/offset and 'tmid'/'tmid2'/'t1'..'t4' intermediate steps.
# NOTE(review): values are presumably in ms — confirm against the protocol.
ecodes_cell1_timings = {
    "IDthres": dict(ton=200, toff=470),
    "firepattern": dict(ton=500, toff=4100),
    "IV": dict(ton=250, toff=3250),
    "IDrest": dict(ton=200, toff=1550),
    "APWaveform": dict(ton=150, toff=200),
    "HyperDepol": dict(ton=200, toff=920, tmid=650),
    "sAHP": dict(ton=200, toff=1125, tmid=450, tmid2=675),
    "PosCheops": dict(ton=1000, t1=9000, t2=10500, t3=14500, t4=16000, toff=18660),
}

In [None]:
# Display the per-repetition {ecode: .wcp path} mappings for inspection.
files_list

In [None]:
# Build the recording metadata for this cell from the .wcp files and timings.
# repetition_as_different_cells=False presumably keeps all runs as repetitions
# of one cell (keyed "cell1_010321") — confirm in extraction_tools.
files_metadata = build_wcp_metadata(cell_id="cell1_010321", files_list=files_list, 
                                    ecode_timings=ecodes_cell1_timings, 
                                    repetition_as_different_cells=False)
pprint(files_metadata["cell1_010321"])

In [None]:
# Load the recordings listed in the metadata using the custom WCP reader.
cells = read_recordings(
    files_metadata=files_metadata,
    recording_reader=wcp_reader
)

In [None]:
# define target features for different protocols
targets = get_targets(ecodes_cell1_timings)

In [None]:
# Per-protocol tolerances for the BluePyEfe targets (presumably in % of
# rheobase around each target amplitude — confirm in bluepyefe).
# Protocols not listed here are left untouched.
protocol_tolerances = {
    "firepattern": [20, 40],
    "IDrest": [20],
    "PosCheops": [50],
    "HyperDepol": [30],
    "APWaveform": [40],
    "sAHP": [40],
}
for protocol_name, tolerances in protocol_tolerances.items():
    targets[protocol_name]["tolerances"] = tolerances

In [None]:
# Convert the legacy-style targets dict with bluepyefe's convert_legacy_targets
# (presumably into the format expected by extract_efeatures_at_targets — confirm).
targets = convert_legacy_targets(targets)

In [None]:
# Inspect the converted targets.
pprint(targets)

In [None]:
# Extract e-features at the requested targets and report how long it took.
# perf_counter is monotonic and high-resolution, unlike the wall-clock
# time.time(), so it is the right clock for measuring an interval.
t_start = time.perf_counter()
extract_efeatures_at_targets(
    cells,
    targets,
)
t_stop = time.perf_counter()
print(f"Elapsed time {t_stop - t_start}")

In [None]:
# Compute each cell's rheobase from its IDthres recordings.
compute_rheobase(
    cells, 
    protocols_rheobase=['IDthres']
)

In [None]:
# Only the first cell is reported. The metadata above uses a single cell_id,
# so presumably there is exactly one cell — verify.
print(f"Cell rheobase: {cells[0].rheobase}")

In [None]:
# Print the relative amplitudes (amp_rel, amp2_rel) of every HyperDepol recording.
# The loop names `cell` / `recording` are kept deliberately: a later cell
# reads the `recording` left over from this loop.
for cell in cells:
    for recording in cell.recordings:
        if recording.protocol_name != "HyperDepol":
            continue
        print(recording.protocol_name, recording.amp_rel, recording.amp2_rel)

In [None]:
# plt.figure()
# for cell in cells:
#     for recording in cell.recordings:
#         if recording.protocol_name == "IV":
#             if recording.amp_rel == 0:
#                 plt.plot(recording.t, recording.voltage, label=f"{np.round(recording.amp_rel)}")
# #             print(recording.protocol_name, recording.amp_rel, recording.amp2_rel)
# #         else:
# #             print(recording.protocol_name, recording.amp_rel, recording.amp2_rel)
# plt.legend()

In [None]:
# NOTE(review): `recording` is the loop variable leaked from the HyperDepol
# printing loop above. It holds the last recording of the last cell, whatever
# its protocol, so this cell only works after that loop has run. Consider an
# explicit lookup instead.
recording.repetition

In [None]:
# Group the extracted features per protocol/target. use_global_rheobase=True
# presumably normalises amplitudes by one shared rheobase — confirm in bluepyefe.
protocols = group_efeatures(cells, targets, use_global_rheobase=True)

In [None]:
# Build the feature/protocol definitions and write them to the efeatures folder.
# threshold_nvalue_save=1 presumably keeps features backed by at least one
# value — confirm in bluepyefe.
efeatures, protocol_definitions, current = create_feature_protocol_files(
    cells,
    protocols,
    output_directory=efeatures_output_directory,
    threshold_nvalue_save=1,
    write_files=True,
)

In [None]:
# Inspect some generated entries (keys presumably "<protocol>_<amplitude>").
efeatures["firepattern_120"]

In [None]:
efeatures.keys()

In [None]:
efeatures["IDrest_200"]

## 2) Convert to BPO format and append extra features

In [None]:
from extraction_tools import convert_to_bpo_format, append_extrafeatures_to_json, compute_extra_features

In [None]:
# Protocols to keep when converting the BluePyEfe output to BluePyOpt (BPO) format.
protocols_of_interest = [
    "IDrest_150",
    "IDrest_250",
    "IDrest_300",
    "IV_-100",
    "IV_-20",
    "APWaveform_260",
]

# Drop the sag features from IV_-20.
exclude_features = {"IV_-20": ["sag_amplitude", "sag_ratio1", "sag_ratio2"]}

# BluePyEfe output files (conversion inputs) ...
in_protocol_path, in_efeatures_path = (
    efeatures_output_directory / "protocols.json",
    efeatures_output_directory / "features.json",
)
# ... and their BPO-format counterparts (conversion outputs).
out_protocol_path, out_efeatures_path = (
    efeatures_output_directory / "protocols_BPO.json",
    efeatures_output_directory / "features_BPO.json",
)

In [None]:
# Convert the selected protocols/features to BPO format. The outputs are
# presumably written to out_protocol_path / out_efeatures_path, since the
# extra features are later appended to out_efeatures_path. std_from_mean=0.2
# presumably sets std as 20% of the mean where needed — confirm in extraction_tools.
protocols_dict, efeatures_dict = convert_to_bpo_format(in_protocol_path, in_efeatures_path, 
                                                       out_protocol_path, out_efeatures_path, 
                                                       protocols_of_interest=protocols_of_interest, 
                                                       exclude_features=exclude_features,
                                                       std_from_mean=0.2)

In [None]:
pprint(efeatures_dict)

## Threshold EAP and extract extracellular features

In [None]:
# Minimum peak-to-peak EAP amplitude for a channel to be kept (uV, per the name).
thresh_uV = 5

In [None]:
import copy

# Keep only channels whose peak-to-peak EAP amplitude exceeds the threshold.
# NOTE(review): ptp is taken along axis 1, so eap is assumed to be
# (n_channels, n_samples) and locations to be indexed per channel — confirm.
amp_eap = np.ptp(eap, 1)
above_tr = np.where(amp_eap > thresh_uV)
eap_above = eap[above_tr]
extra_features_above = compute_extra_features(eap_above, fs, upsample=10)

# Build a probe restricted to the retained channels. `probe.info` is
# (presumably) the full probe's own dict. Assigning "pos" on it directly would
# silently overwrite the original probe's metadata, so work on a deep copy.
probe_info = copy.deepcopy(probe.info)
probe_info["pos"] = locations[above_tr].tolist()
probe_above = mu.return_mea(info=probe_info)

In [None]:
# The thresholding cell above already ran
# compute_extra_features(eap_above, fs, upsample=10). Reuse that result
# rather than repeating the same computation on identical input.
extra_features = extra_features_above

In [None]:
# List the names of the computed extracellular features.
pprint(extra_features.keys())

In [None]:
# Append the extracellular features to the IDrest_300 protocol in the BPO
# features file and get back the updated features dict.
efeatures_dict = append_extrafeatures_to_json(extra_features, protocol_name="IDrest_300",
                                              efeatures_path=out_efeatures_path)

In [None]:
pprint(efeatures_dict)

In [None]:
# Plot one of the extracellular features on the thresholded probe.
f = plotting.plot_feature_map_w_colorbar(extra_features["pos_peak_diff"], probe_above, 
                                         feature_name="pos_peak_diff", label="time (s)")

In [None]:
# save probe_above.json
# Save the thresholded probe description as probe_BPO.json next to the BPO
# feature files. The `with` block ensures the file is flushed and closed.
with (efeatures_output_directory / "probe_BPO.json").open("w") as probe_file:
    json.dump(probe_above.info, probe_file)