In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import h5py

import pynwb 
import spikeinterface.full as si
from natsort import natsorted

import sync_mea_patch as smp
from nwb_conversion_tools.utils.spike_interface import write_recording, get_nwb_metadata, add_electrical_series

from nwbwidgets import nwb2widget

%matplotlib inline

In [None]:
# Standard Python imports
from datetime import datetime, timedelta
from dateutil.tz import tzlocal
import numpy as np
import pandas
# Set pandas rendering option to avoid very wide tables in the html docs
pandas.set_option("display.max_colwidth", 30)
pandas.set_option("display.max_rows", 10)

# Import main NWB file class
from pynwb import NWBFile

from pynwb.file import Subject
# Import icephys TimeSeries types used
from pynwb.icephys import CurrentClampStimulusSeries, CurrentClampSeries
# Import I/O class used for reading and writing NWB files
from pynwb import NWBHDF5IO
# Import additional core datatypes used in the example
from pynwb.core import DynamicTable, VectorData

In [None]:
def get_recording_start_time(mea_file):
    """Read the recording start time stored in an MEA HDF5 file.

    The first entry of the file's ``time`` dataset is parsed as text: the
    start time is the substring located one character after ``"start:"`` and
    before the first ``";"``, e.g. ``"start: 2021-10-06 14:30:15;"``.

    Parameters
    ----------
    mea_file : str or pathlib.Path
        Path to the MEA ``.h5`` file.

    Returns
    -------
    datetime.datetime
        Start time parsed with ``datetime.fromisoformat``.

    Raises
    ------
    ValueError
        If the extracted substring is not a valid ISO-format date string.
    """
    # Open inside a context manager so the HDF5 handle is released even if
    # reading fails (the file was previously left open).
    with h5py.File(mea_file, "r") as f:
        raw_header = f["time"][0]

    # The entry may come back as bytes or as str; only bytes need decoding.
    date_str = raw_header.decode() if isinstance(raw_header, bytes) else raw_header

    # Offsets are computed on the full string and applied to its first line.
    value_start = date_str.find("start:") + len("start:") + 1
    value_end = date_str.find(";")
    first_line = date_str.split("\n")[0]

    return datetime.fromisoformat(first_line[value_start:value_end])

In [None]:
# Relative path pointing two directory levels up; the experimental data
# folder is resolved underneath it.
base_dir = Path("../../")

In [None]:
# Name of the recorded cell; also used further down as the NWB identifier
# and as the stem of the output file name.
cell_name = "cell1_211006_3148"
# Folder for this cell under <base_dir>/experimental_data.
experimental_folder = base_dir / "experimental_data" / cell_name

In [None]:
# Sub-folders of the cell folder that are scanned for MEA and patch files.
mea_folder = experimental_folder / "mea_data"
patch_folder = experimental_folder / "patch_data"

In [None]:
# Collect, in sorted order, the MEA files whose name contains "raw.h5" and
# the patch files whose name contains "wcp".
mea_files = sorted(path for path in mea_folder.iterdir() if "raw.h5" in path.name)
patch_files = sorted(path for path in patch_folder.iterdir() if "wcp" in path.name)

In [None]:
# Display the sorted list of MEA files found above.
mea_files

In [None]:
# A run is identified by the part of the MEA file name before the first "."
runs = [mea_path.name.partition(".")[0] for mea_path in mea_files]
print(runs)

In [None]:
# Start time of the first MEA file (in sorted order); used below as the
# session start time and as the reference for per-run time offsets.
start_time = get_recording_start_time(mea_files[0])

In [None]:
# Create Subject
subject = Subject(subject_id=cell_name.split("_")[-1], 
                  description="Cortical embryonic cell culture",
                  species="Wistar Rat")

# Create an ICEphysFile
nwbfile = NWBFile(
    session_description="Simultaneous patch-clamp/HD-MEA recording using E-CODE protocols.",
    identifier=cell_name,
    session_start_time=start_time,
    subject=subject
)

In [None]:
# add basic metadata
nwbfile.experimenter = "Alessio Buccino"
nwbfile.lab = "Bio Engineering Laboratory (BEL) - Department of Bio Systems Science and Engineeering (D-BSSE)"
nwbfile.institution = "ETH Zurich"
nwbfile.experiment_description = experiment_description

In [None]:
# Add a device
device = nwbfile.create_device(name='MultiClamp 700B amplifier (Axon Instruments) - Axon Digidata 1440A (Axon Instruments)')

# Add an intracellular electrode
electrode = nwbfile.create_icephys_electrode(
    name="elec0",
    description='Whole-cell patch pipette',
    device=device
)


In [None]:
# Convert every run: synchronise patch and MEA data, then add each patch sweep
# to the NWB icephys tables with this hierarchy:
#   intracellular recording -> simultaneous recording (one sweep)
#   -> sequential recording (all consecutive sweeps of one protocol)
#   -> repetition (all protocols of one run)

# Per-sweep metadata, added as custom table columns once all rows exist
protocol_names = []
run_names = []
# Sweep counter that keeps increasing across runs
current_sweep = 0

for i, run in enumerate(runs):
    print(f"\n\nAdding run {run}\n\n")
    mea_file_run = [m for m in mea_files if run in m.name][0]
    # Protocol tags (second "."-separated field of the file name), natural-sorted
    patch_files_protocols = natsorted([p.name.split(".")[1] for p in patch_files if run in p.name])
    patch_files_run = []
    for prot in patch_files_protocols:
        patch_files_run.append([p for p in patch_files if prot in p.name and run in p.name][0])

    # `timestamps` returned here is not used below: every sweep carries its own
    # time vector, which is kept in a separate variable (`sweep_timestamps`).
    subrec, patch, timestamps, ttl_mea_sync = smp.sync_patch_mea(mea_file_run, patch_files_run, patch_ttl_channel=2,
                                                             correct_mea_times=True, verbose=True,
                                                             remove_blank_mea_channels=False, average_sweeps=False,
                                                             return_single_sweeps=True)
    electrodes = subrec.get_property("electrode")
    print(electrodes[:5])

    # Offset of this run relative to the session start. total_seconds() is used
    # because timedelta.seconds drops whole days and microseconds.
    start_time_run = get_recording_start_time(mea_file_run)
    tdelta = (start_time_run - start_time).total_seconds()
    subrec.set_times(subrec.get_times() + tdelta)
    # Only needed by the (currently disabled) ecephys export below
    starting_time_ecephys = subrec.get_times()[0]

    last_protocol = None
    sweeps_in_protocol = []
    sequences_in_run = []

    for p in patch:
        # Index 3 of the sweep data is used as stimulus, index 0 as response
        stim = p["data"][3]
        resp = p["data"][0]
        # Sweep names look like "<run>.<x>_<protocol>...": take the second
        # "_"-separated token of the second "."-separated field
        protocol_name = p['name'].split('.')[1].split("_")[1]
        run_name = p['name'].split('.')[0]
        # Shift sweep times onto the session clock
        sweep_timestamps = p["time"] + tdelta

        # Create an ic-ephys stimulus
        stimulus = CurrentClampStimulusSeries(
            name=f"stimulus_{p['name']}",
            data=stim,
            timestamps=sweep_timestamps,
            electrode=electrode,
            gain=1.,
            sweep_number=current_sweep
        )

        # Create an ic-response
        response = CurrentClampSeries(
            name=f"response_{p['name']}",
            data=resp,
            resolution=np.nan,
            timestamps=sweep_timestamps,
            electrode=electrode,
            gain=1.,
            sweep_number=current_sweep
        )

        # Create recording
        ir_index = nwbfile.add_intracellular_recording(
            electrode=electrode,
            stimulus=stimulus,
            response=response
        )

        # Create simultaneous recording (only one rec in our case)
        sweep_index = nwbfile.add_icephys_simultaneous_recording(recordings=[ir_index])

        # Protocol changed: close the previous protocol's group before the
        # current sweep starts a new one.
        if last_protocol is not None and protocol_name != last_protocol:
            sequence_index = nwbfile.add_icephys_sequential_recording(
                simultaneous_recordings=sweeps_in_protocol,
                stimulus_type=last_protocol
            )
            sequences_in_run.append(sequence_index)
            sweeps_in_protocol = []
        # The current sweep always belongs to the group that is now open
        sweeps_in_protocol.append(sweep_index)
        last_protocol = protocol_name

        current_sweep += 1
        protocol_names.append(protocol_name)
        run_names.append(run_name)

    # Flush the last protocol of the run, which no later protocol change closes
    if sweeps_in_protocol:
        sequence_index = nwbfile.add_icephys_sequential_recording(
            simultaneous_recordings=sweeps_in_protocol,
            stimulus_type=last_protocol
        )
        sequences_in_run.append(sequence_index)

    # Add a list of sequential recordings table indices as a repetition
    run_index = nwbfile.add_icephys_repetition(sequential_recordings=sequences_in_run)

#     # Add Ecephys
#     es_key = f"ElectricalSeries_{run}"
#     metadata_ecephys = get_nwb_metadata(subrec)
#     metadata_ecephys["Ecephys"]["Device"][0]["name"] = "Mea1k HD-MEA"
#     metadata_ecephys["Ecephys"]["Device"][0]["description"] = "Mea1k HD-MEA device with 26'400 electrodes. 1024 recorded simultaneously."
#     metadata_ecephys["Ecephys"]["ElectrodeGroup"][0]["device"] = "Mea1k HD-MEA"
#     metadata_ecephys['Ecephys'][es_key] = {
#             'name': es_key,
#             'description': f"HD-MEA extracellular recording for {run}"
#         }

#     if i == 0:
#         nwbfile = write_recording(subrec, nwbfile=nwbfile, metadata=metadata_ecephys,
#                                   es_key=es_key, use_times=True)
#     else:
#         add_electrical_series(subrec, nwbfile=nwbfile, starting_time=starting_time_ecephys,
#                               metadata=metadata_ecephys, es_key=es_key, use_times=True)


# Custom columns are added after all rows exist, one value per added sweep
nwbfile.intracellular_recordings.add_column(
    name='protocol_name',
    data=protocol_names,
    description='E-CODE protocol name'
)
nwbfile.intracellular_recordings.add_column(
    name='run',
    data=run_names,
    description='Run number'
)

In [None]:
# Write the assembled NWB file to "<cell_name>.nwb" (relative path)
nwb_path = f"{cell_name}.nwb"
with NWBHDF5IO(nwb_path, "w") as nwb_io:
    nwb_io.write(nwbfile)

In [None]:
# Channel id referenced only by the disabled extracellular plot below
max_chan = '885'

# Re-open the written file and plot every current-clamp response, coloured by
# the single character that follows "run" in the acquisition name.
with NWBHDF5IO(f"{cell_name}.nwb", "r") as io:
    read_nwbfile = io.read()
#     max_id = read_nwbfile.electrodes["channel_id"].data.index(max_chan)

    fig, axs = plt.subplots(nrows=2, sharex=True)
    for acq_name in read_nwbfile.acquisition:
        acq = read_nwbfile.acquisition[acq_name]
        run_pos = acq_name.find("run")
        run_id = acq_name[run_pos + 3:run_pos + 4]
        if not isinstance(acq, pynwb.icephys.CurrentClampSeries):
            # Any other acquisition is only reported, not plotted
            print(acq_name, acq.data.shape)
#             axs[1].plot(acq.timestamps[()], acq.data[:, max_id], color=f"C{run_id}")
            continue
        axs[0].plot(acq.timestamps[()], acq.data[()], color=f"C{run_id}")

## Test one run

In [None]:
# Pick the first run to exercise the conversion steps on a single run.
run = runs[0]

In [None]:
# MEA file of the selected run (first file whose name contains the run id)
mea_file_run = [mea_path for mea_path in mea_files if run in mea_path.name][0]
# Protocol tags of this run's patch files, in natural sort order
patch_files_protocols = natsorted(
    [patch_path.name.split(".")[1] for patch_path in patch_files if run in patch_path.name]
)
# One patch file per protocol tag: the first file matching both tag and run
patch_files_run = [
    [patch_path for patch_path in patch_files if protocol in patch_path.name and run in patch_path.name][0]
    for protocol in patch_files_protocols
]

In [None]:
# Display the MEA file selected for this run.
mea_file_run

In [None]:
# Display the patch files selected for this run, one per protocol.
patch_files_run

In [None]:
# Synchronise the selected run's patch files with its MEA recording,
# keeping every sweep separate (no averaging).
subrec, patch, timestamps, ttl_mea_sync = smp.sync_patch_mea(
    mea_file_run,
    patch_files_run,
    patch_ttl_channel=2,
    correct_mea_times=True,
    verbose=True,
    remove_blank_mea_channels=False,
    average_sweeps=False,
    return_single_sweeps=True,
)

In [None]:
# Toggle for the optional diagnostic plot in the next cell (off by default).
plot_signals = False

In [None]:
# Optional overlay of the patch sweeps (data index 3) and the band-pass
# filtered MEA trace of channel '885'.
if plot_signals:
    plt.figure()
    rec_f = si.bandpass_filter(subrec)
    tr_max = rec_f.get_traces(channel_ids=['885'])

    for sweep in patch:
        sweep_time, sweep_signal = sweep["time"], sweep["data"][3]
        plt.plot(sweep_time, sweep_signal, alpha=0.5, label=sweep["name"])

    extra_times = rec_f.get_times()
    plt.plot(extra_times, tr_max[:, 0], alpha=0.5, color="C1", label="extra")

In [None]:
# Single-run test of the icephys conversion.
#
# The sweeps go into a scratch NWB file instead of the shared `nwbfile`: that
# one already holds the 'protocol_name' and 'run' columns after the full
# conversion, so adding rows and re-adding those columns here would fail.
test_nwbfile = NWBFile(
    session_description="Simultaneous patch-clamp/HD-MEA recording using E-CODE protocols.",
    identifier=f"{cell_name}_{run}_test",
    session_start_time=start_time
)
test_device = test_nwbfile.create_device(
    name='MultiClamp 700B amplifier (Axon Instruments) - Axon Digidata 1440A (Axon Instruments)'
)
test_electrode = test_nwbfile.create_icephys_electrode(
    name="elec0",
    description='Whole-cell patch pipette',
    device=test_device
)

# Per-sweep metadata, added as custom table columns once all rows exist
protocol_names = []
run_names = []

last_protocol = None
sweeps_in_protocol = []
sequences_in_run = []
current_sweep = 0

for p in patch:
    # Index 3 of the sweep data is used as stimulus, index 0 as response
    stim = p["data"][3]
    resp = p["data"][0]
    # Series name: "<second '.'-field>-<second '-'-field>" of the sweep name
    name = f"{p['name'].split('.')[1]}-{p['name'].split('-')[1]}"
    protocol_name = p['name'].split('.')[1].split("_")[1]
    run_name = p['name'].split('.')[0]

    # Create an ic-ephys stimulus
    stimulus = CurrentClampStimulusSeries(
        name=f"stimulus_{name}",
        data=stim,
        timestamps=p["time"],
        electrode=test_electrode,
        gain=1.,
        sweep_number=current_sweep
    )

    # Create an ic-response
    response = CurrentClampSeries(
        name=f"response_{name}",
        data=resp,
        resolution=np.nan,
        timestamps=p["time"],
        electrode=test_electrode,
        gain=1.,
        sweep_number=current_sweep
    )

    # Create recording
    ir_index = test_nwbfile.add_intracellular_recording(
        electrode=test_electrode,
        stimulus=stimulus,
        response=response
    )

    # Create simultaneous recording (only one rec in our case)
    sweep_index = test_nwbfile.add_icephys_simultaneous_recording(recordings=[ir_index])

    # Protocol changed: close the previous protocol's group before the current
    # sweep starts a new one.
    if last_protocol is not None and protocol_name != last_protocol:
        sequence_index = test_nwbfile.add_icephys_sequential_recording(
            simultaneous_recordings=sweeps_in_protocol,
            stimulus_type=last_protocol
        )
        sequences_in_run.append(sequence_index)
        sweeps_in_protocol = []
    # The current sweep always belongs to the group that is now open
    sweeps_in_protocol.append(sweep_index)
    last_protocol = protocol_name

    current_sweep += 1
    protocol_names.append(protocol_name)
    run_names.append(run_name)

# Flush the last protocol, which no later protocol change closes
if sweeps_in_protocol:
    sequence_index = test_nwbfile.add_icephys_sequential_recording(
        simultaneous_recordings=sweeps_in_protocol,
        stimulus_type=last_protocol
    )
    sequences_in_run.append(sequence_index)

# (D) Add a list of sequential recordings table indices as a repetition
run_index = test_nwbfile.add_icephys_repetition(sequential_recordings=sequences_in_run)

test_nwbfile.intracellular_recordings.add_column(
    name='protocol_name',
    data=protocol_names,
    description='E-CODE protocol name'
)
test_nwbfile.intracellular_recordings.add_column(
    name='run',
    data=run_names,
    description='Run number'
)

# Show the resulting single-run table
test_nwbfile.intracellular_recordings.to_dataframe()

In [None]:
# Display the intracellular recordings table of `nwbfile` as a DataFrame.
nwbfile.intracellular_recordings.to_dataframe()