# Export multimodal datasets to NWB

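This notebook requires these additional packages:

- pynwb
- pillow
- natsort
- multimodalfitting (provides the `mea_patch` synchronization utilities)
- nwb-conversion-tools (version >= 0.11.38)
- spikeinterface (version >= 0.94)
- nwbwidgets (only needed for the optional viewer at the end, but it is imported in the first code cell)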

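The notebook assumes that the experimental data of each cell are available in the `experimental_data/<cell_name>` folder (relative to `base_dir`), with the subfolders:
- `mea_data` (MEA recordings, files whose name contains `raw.h5`)
- `patch_data` (patch-clamp recordings, files whose name contains `wcp`)
- `imaging_data` (only `deconvolved/max_projection.tif`, the max Z projection, is used)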

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
import os
import shutil
from pathlib import Path
from PIL import Image
from natsort import natsorted
from datetime import datetime, timedelta

import multimodalfitting.mea_patch as mp

import pynwb
from pynwb import NWBFile, NWBHDF5IO
from pynwb.file import Subject
from pynwb.icephys import CurrentClampStimulusSeries, CurrentClampSeries

from nwb_conversion_tools.tools.spikeinterface import write_recording, get_nwb_metadata, add_electrical_series
from nwbwidgets import nwb2widget

%matplotlib inline

In [None]:
# Parallel-processing options forwarded to `subrec.save(...)` when the MEA
# recordings are dumped to binary in the export loop.
job_kwargs = {
    "n_jobs": 10,
    "chunk_duration": "1s",
    "progress_bar": True,
}

In [None]:
def get_recording_start_time(mea_file):
    """Return the acquisition start time stored in an MEA ``raw.h5`` file.

    The first element of the file's ``time`` dataset is decoded as text. The
    timestamp is read from its first line, between the ``"start:"`` tag and the
    next ``";"`` (or the end of the line), and parsed with
    ``datetime.fromisoformat``.

    Parameters
    ----------
    mea_file : str or Path
        Path to the MEA HDF5 file.

    Returns
    -------
    datetime.datetime
        Start time of the recording.

    Raises
    ------
    ValueError
        If the first line has no ``"start:"`` tag or the timestamp is not in
        ISO format.
    """
    # Context manager so the HDF5 handle is released (it was previously left open).
    with h5py.File(mea_file, "r") as f:
        time_info = f["time"][0].decode()

    first_line = time_info.split("\n")[0]
    tag = "start:"
    tag_pos = first_line.find(tag)
    if tag_pos < 0:
        raise ValueError(f"No '{tag}' entry found in the 'time' field of {mea_file}")

    # Search for ';' only after the tag, and tolerate it being absent
    # (previously a missing ';' silently chopped off the last character).
    after_tag = first_line[tag_pos + len(tag):]
    date_str = after_tag.split(";")[0].strip()
    return datetime.fromisoformat(date_str)

In [None]:
# Root folder two levels up from the current working directory
# (presumably the repository root when the notebook is launched from its own folder).
base_dir = Path("..") / ".."

In [None]:
# Cell to export. The commented-out name is the other cell handled by the
# cell-specific branches further down in the notebook.
cell_name = "cell1_211006_3148"
# cell_name = "cell1_211011_3436"

# Folder containing the MEA, patch-clamp and imaging data of the selected cell.
experimental_folder = base_dir / "experimental_data" / cell_name

In [None]:
# Per-modality sub-folders of the selected cell.
mea_folder, patch_folder = (
    experimental_folder / "mea_data",
    experimental_folder / "patch_data",
)

In [None]:
# Sorted input files: MEA recordings contain "raw.h5" in their name,
# patch-clamp recordings contain "wcp".
mea_files = sorted(f for f in mea_folder.iterdir() if "raw.h5" in f.name)
patch_files = sorted(f for f in patch_folder.iterdir() if "wcp" in f.name)

In [None]:
# A run is identified by the MEA file name up to its first ".".
runs = [mea_path.name.partition(".")[0] for mea_path in mea_files]
print(runs)

In [None]:
# Run numbers to export for each supported cell.
if cell_name == "cell1_211006_3148":
    run_ids = [1, 2, 3, 4, 5]
elif cell_name == "cell1_211011_3436":
    run_ids = [3, 4, 5, 6]
else:
    # Fail here with a clear message instead of a NameError on `run_ids` below.
    raise ValueError(f"No run selection defined for cell {cell_name!r}")


def _run_number(run_name):
    """Return the integer formed by the trailing digits of `run_name` (e.g. 'run12' -> 12).

    Raises ValueError if `run_name` does not end with a digit.
    """
    digits = run_name[len(run_name.rstrip("0123456789")):]
    if not digits:
        raise ValueError(f"Run name {run_name!r} does not end with a run number")
    return int(digits)


# Keep only the selected runs. All trailing digits are parsed, so 'run10' is run 10, not 0.
runs = [run for run in runs if _run_number(run) in run_ids]
print(runs)

In [None]:
# MEA file of the first selected run; its start time becomes the NWB session start time.
start_file = [path for path in mea_files if runs[0] in path.name][0]

In [None]:
# Session start time = acquisition start of the first selected run.
# All other runs are offset relative to this value in the export loop.
start_time = get_recording_start_time(start_file)

In [None]:
# Optional renaming of run labels in the exported NWB file (original -> exported).
# Cells without an explicit mapping keep their original run names (run_map = None);
# previously such cells left `run_map` undefined and failed later with a NameError.
if cell_name == "cell1_211011_3436":
    run_map = {'run3': 'run1', 'run4': 'run2', 'run5': 'run3', 'run6': 'run4'}
else:
    run_map = None

In [None]:
# Free-text descriptions stored in the NWB file.
experiment_description = f"Simultaneous patch-clamp/HD-MEA recording using E-CODE protocols"
session_description = f"Simultaneous patch-clamp/HD-MEA recording using E-CODE protocols for cell {cell_name} for {len(runs)} experimental runs."

# Subject metadata; the last "_"-separated token of the cell name is used as subject id.
subject = Subject(
    subject_id=cell_name.rsplit("_", 1)[-1],
    description="Rat cortical embryonic cell culture",
    species="Rattus norvegicus",
)

# Top-level NWB file for this cell. Icephys, ecephys and imaging data are added in later cells.
nwbfile = NWBFile(
    session_description=session_description,
    identifier=cell_name,
    session_start_time=start_time,
    subject=subject,
)

In [None]:
# Add basic lab / experiment metadata to the NWB file.
nwbfile.experimenter = "Alessio Buccino and Julian Bartram"
# Fixed typo: "Engineeering" -> "Engineering".
nwbfile.lab = "Bio Engineering Laboratory (BEL) - Department of Bio Systems Science and Engineering (D-BSSE)"
nwbfile.institution = "ETH Zurich"
nwbfile.experiment_description = experiment_description

In [None]:
# Register the patch-clamp amplifier / digitizer chain as an NWB device.
device = nwbfile.create_device(
    name='MultiClamp 700B amplifier (Axon Instruments) - Axon Digidata 1440A (Axon Instruments)'
)

# Whole-cell intracellular electrode attached to that device.
electrode = nwbfile.create_icephys_electrode(
    device=device,
    description='Whole-cell patch pipette',
    name="Patch-clamp electrode",
)

In [None]:
# Load the confocal max-Z projection (assumed to be a single-frame TIFF) and display it.
image_file = experimental_folder / "imaging_data" / "deconvolved" / "max_projection.tif"
# The context manager closes the TIFF file handle once the pixels are copied
# into numpy (previously the handle was left open).
with Image.open(image_file) as img:
    img_array = np.array(img)
plt.matshow(img_array);

In [None]:
# Pixel size of the confocal image: 112.5 nm per pixel, expressed in cm per pixel.
resolution = 112.5 * 1e-7  # cm per pixel
# The inverse is the pixel density (the unit label was previously the incorrect "pixels * cm").
print(1 / resolution, "pixels / cm")

In [None]:
# NWB's Image `resolution` field is the pixel density in pixels per centimetre.
# `resolution` above is the pixel size in cm per pixel, so store its inverse
# (previously the cm-per-pixel value was passed unchanged).
max_z_proj = pynwb.image.GrayscaleImage(
    name="Confocal max z projection",
    data=img_array,
    description="Maximum Z projection of the Z-stack used for morphology reconstruction",
    resolution=1 / resolution,
)

In [None]:
# Wrap the confocal max-Z projection in an Images container...
images = pynwb.base.Images(name="confocal", images=[max_z_proj])

# ...and store it in a dedicated "imaging" processing module.
imaging_module = nwbfile.create_processing_module(
    name="imaging",
    description="processed imaging data",
)
imaging_module.add(images)

In [None]:
# Export every selected run:
#  - patch-clamp sweeps go into the NWB icephys tables
#    (intracellular -> simultaneous -> sequential -> repetition), and
#  - the synchronized MEA recording becomes one ElectricalSeries per run.
protocol_names = []  # per intracellular recording, added as a table column below
run_names = []       # per intracellular recording, added as a table column below
current_sweep = 0    # global sweep counter across all runs

# Scratch folder for the binary copies of the MEA recordings (recreated each time).
bin_folder = Path(f"{cell_name}_bin")
if bin_folder.is_dir():
    shutil.rmtree(bin_folder)
bin_folder.mkdir()

for i, run in enumerate(runs):
    print(f"\n\nAdding run: {run}\n\n")
    # NOTE(review): `run in name` is a substring test, so e.g. 'run1' also matches
    # 'run10' files. This is fine for the single-digit runs selected above.
    mea_file_run = [m for m in mea_files if run in m.name][0]
    patch_files_protocols = natsorted([p.name.split(".")[1] for p in patch_files if run in p.name])
    patch_files_run = []
    for prot in patch_files_protocols:
        patch_files_run.append([p for p in patch_files if prot in p.name and run in p.name][0])

    # Synchronize patch and MEA data (patch TTL on channel 2).
    subrec, patch, timestamps, ttl_mea_sync = mp.sync_patch_mea(mea_file_run, patch_files_run,
                                                                patch_ttl_channel=2,
                                                                correct_mea_times=True, verbose=True,
                                                                remove_blank_mea_channels=False,
                                                                return_patch_single_sweeps=True)
    subrec_bin = subrec.save(folder=bin_folder / run, **job_kwargs)
    start_time_run = get_recording_start_time(mea_file_run)
    # Offset of this run from the session start. total_seconds() keeps the
    # sub-second part and whole days, which timedelta.seconds silently drops.
    tdelta = (start_time_run - start_time).total_seconds()
    subrec_bin.set_times(subrec.get_times() + tdelta)

    last_protocol = None
    sweeps_in_protocol = []  # simultaneous-recording indices of the current protocol
    sequences_in_run = []    # sequential-recording indices of this run

    for p in patch:
        resp = p["data"][0]
        stim = p["data"][3]
        name = str(Path(p['name']).stem)
        if run_map:
            name = name.replace(run, run_map[run])
        protocol_name = name.split('.')[1].split("_")[1]
        run_name = name.split('.')[0]
        timestamps = p["time"] + tdelta

        # Create an ic-ephys stimulus
        stimulus = CurrentClampStimulusSeries(
            name=f"stimulus_{protocol_name}_{run_name}_{current_sweep}",
            data=stim,
            timestamps=timestamps,
            electrode=electrode,
            stimulus_description=protocol_name,
            description="Injected current",
            gain=1e-9,
            sweep_number=current_sweep,
            unit="amperes"
        )

        # Create an ic-response
        response = CurrentClampSeries(
            name=f"response_{protocol_name}_{run_name}_{current_sweep}",
            data=resp,
            resolution=np.nan,
            timestamps=timestamps,
            electrode=electrode,
            stimulus_description=protocol_name,
            description="Recorded somatic membrane potential",
            gain=1e-3,
            sweep_number=current_sweep,
            unit="volts"
        )

        # Create recording
        ir_index = nwbfile.add_intracellular_recording(
            electrode=electrode,
            stimulus=stimulus,
            response=response
        )

        # Create simultaneous recording (only one rec in our case)
        sweep_index = nwbfile.add_icephys_simultaneous_recording(recordings=[ir_index,])

        # Consecutive sweeps of the same protocol form one sequential recording.
        # When the protocol changes, close the previous group first. The current
        # sweep then starts the new group (previously it was dropped).
        if last_protocol is not None and protocol_name != last_protocol:
            sequences_in_run.append(
                nwbfile.add_icephys_sequential_recording(
                    simultaneous_recordings=sweeps_in_protocol,
                    stimulus_type=last_protocol
                )
            )
            sweeps_in_protocol = []
        sweeps_in_protocol.append(sweep_index)
        last_protocol = protocol_name

        current_sweep += 1
        protocol_names.append(protocol_name)
        run_names.append(run_name)

    # Close the last protocol group of the run (previously it was never written).
    if sweeps_in_protocol:
        sequences_in_run.append(
            nwbfile.add_icephys_sequential_recording(
                simultaneous_recordings=sweeps_in_protocol,
                stimulus_type=last_protocol
            )
        )

    # Add the list of sequential recordings of this run as one repetition
    if sequences_in_run:
        run_index = nwbfile.add_icephys_repetition(sequential_recordings=sequences_in_run)

    # Add Ecephys: one ElectricalSeries per run, named after the (possibly remapped) run
    run_name = run_map[run] if run_map else run
    es_key = f"ElectricalSeries_{run_name}"
    print(es_key)
    metadata_ecephys = get_nwb_metadata(subrec)
    metadata_ecephys["Ecephys"]["Device"][0]["name"] = "Mea1k HD-MEA"
    metadata_ecephys["Ecephys"]["Device"][0]["description"] = "Mea1k HD-MEA device with 26'400 electrodes. 1024 recorded simultaneously."
    metadata_ecephys["Ecephys"]["ElectrodeGroup"][0]["device"] = "Mea1k HD-MEA"
    metadata_ecephys['Ecephys'][es_key] = {
            'name': es_key,
            'description': f"HD-MEA extracellular recording for {run}"
        }

    # The first run goes through write_recording (presumably also setting up the
    # ecephys device/electrode metadata); later runs only add their ElectricalSeries.
    if i == 0:
        nwbfile = write_recording(subrec_bin, nwbfile=nwbfile, metadata=metadata_ecephys,
                                  es_key=es_key, use_times=True)
    else:
        add_electrical_series(subrec_bin, nwbfile=nwbfile, metadata=metadata_ecephys,
                              es_key=es_key, use_times=True)


# Per-sweep annotations of the intracellular recordings table
nwbfile.intracellular_recordings.add_column(
    name='protocol_name',
    data=protocol_names,
    description='eCode protocol name'
)
nwbfile.intracellular_recordings.add_column(
    name='run',
    data=run_names,
    description='Run number'
)

In [None]:
# Display the processing modules of the in-memory NWB file (e.g. "imaging").
nwbfile.processing

In [None]:
# Output NWB file, named after the cell and written to the current working directory.
nwb_path = Path(f"{cell_name}.nwb")

In [None]:
# Remove a previously exported file so the next cell writes a fresh one.
if nwb_path.is_file():
    nwb_path.unlink()

In [None]:
# Write the assembled NWB file to disk; the IO object is closed on exit.
with NWBHDF5IO(str(nwb_path), "w") as io:
    io.write(nwbfile)

In [None]:
# Reference MEA channel of each cell, plotted below the patch traces
# (presumably the channel with the strongest signal from the patched cell).
if cell_name == "cell1_211006_3148":
    max_chan = 'ch885'
elif cell_name == "cell1_211011_3436":
    max_chan = 'ch384'
else:
    # Clear error instead of a NameError on `max_chan` below.
    raise ValueError(f"No reference MEA channel defined for cell {cell_name!r}")


def _run_color(series_name):
    """Return the matplotlib 'CN' color for the run number after the first 'run' in `series_name`.

    All digits following 'run' are used (so 'run12' -> 'C12'). Falls back to
    black ('k') when the name contains no run number.
    """
    pos = series_name.find("run")
    tail = series_name[pos + 3:] if pos >= 0 else ""
    digits = tail[:len(tail) - len(tail.lstrip("0123456789"))]
    return f"C{digits}" if digits else "k"


# Sanity check: re-open the written file and overlay, colored by run, the
# patch-clamp responses (top) and the reference MEA channel (bottom).
with NWBHDF5IO(str(nwb_path), "r") as io:
    read_nwbfile = io.read()
    channel_names = list(read_nwbfile.electrodes["channel_name"].data[:])
    max_chan_name = [ch for ch in channel_names if max_chan in ch][0]
    max_id = channel_names.index(max_chan_name)  # row in the electrodes table

    fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True, figsize=(10, 10))
    for acq_name, acq in read_nwbfile.acquisition.items():
        color = _run_color(acq_name)
        if isinstance(acq, pynwb.icephys.CurrentClampSeries):
            ax0.plot(acq.timestamps[()], acq.data[()], color=color, lw=1)
        else:
            # ElectricalSeries columns follow the series' own electrode region,
            # so map the electrodes-table row to the matching data column.
            col = list(acq.electrodes.data[:]).index(max_id)
            ax1.plot(acq.timestamps[()], acq.data[:, col], color=color, lw=0.5, alpha=0.8)

    ax0.set(title="Patch-clamp responses", ylabel="response (stored data)")
    ax1.set(title=f"MEA channel {max_chan_name}", xlabel="time (s)", ylabel="signal (stored data)")

### (optional) View saved NWB dataset with NWBwidgets

In [None]:
# Open the file without a `with` block: datasets of the read file are only
# loaded on access, so the file must stay open while the widget is browsed.
# The handle is closed explicitly in the next cell.
io = NWBHDF5IO(str(nwb_path), "r") 
read_nwbfile = io.read()

# Interactive NWB browser
nwb2widget(read_nwbfile)

In [None]:
# Release the file handle opened for the NWB widget viewer.
io.close()