In [None]:
# Move the kernel's working directory up one level so later relative paths
# resolve from the parent directory. Not idempotent: re-running this cell
# without restarting the kernel moves up one more level each time.
%cd ..

In [None]:
import epych
import glob
import h5py
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import quantities as pq

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)

In [None]:
def epoch_intervals(intervals, epoch):
    """Return the start/stop times of the rows flagged by the ``epoch`` column.

    Parameters
    ----------
    intervals : mapping
        Table-like object indexed by column name. It must provide ``epoch``,
        ``'start_time'`` and ``'stop_time'`` entries; the ``epoch`` entry is
        read in full and cast to bool to act as a row mask.
    epoch : str
        Name of the flag column selecting the rows of interest.

    Returns
    -------
    numpy.ndarray
        Array whose last axis holds ``(start_time, stop_time)`` for each
        selected row.
    """
    selected = intervals[epoch][:].astype(bool)
    onsets = intervals['start_time'][selected]
    offsets = intervals['stop_time'][selected]
    # Pair each onset with its offset along a new trailing axis.
    return np.stack((onsets, offsets), axis=-1)

In [None]:
def trial_intervals(intervals):
    """Return one ``(start, stop)`` pair per distinct trial number.

    For every unique value in ``intervals['trial_num']`` (cast to int), the
    start time of the first matching row is paired with the stop time of the
    last matching row.

    Parameters
    ----------
    intervals : mapping
        Table-like object providing ``'trial_num'``, ``'start_time'`` and
        ``'stop_time'`` entries that support integer indexing.

    Returns
    -------
    numpy.ndarray
        One row per unique trial number, in ascending trial-number order.
    """
    trial_ids = intervals['trial_num'][:].astype(int)
    spans = []
    for trial_id in np.unique(trial_ids):
        rows = np.nonzero(trial_ids == trial_id)[0]
        first_row, last_row = rows[0], rows[-1]
        onset = intervals['start_time'][first_row]
        offset = intervals['stop_time'][last_row]
        spans.append((onset, offset))
    return np.array(spans)

In [None]:
def trial_stimulus_intervals(intervals):
    """Group interval rows by their position within each trial.

    Rows sharing a ``trial_num`` are enumerated in table order; the k-th row
    of every trial is collected under key ``k``.

    Parameters
    ----------
    intervals : mapping
        Table-like object providing ``'trial_num'``, ``'start_time'`` and
        ``'stop_time'`` entries that support integer indexing.

    Returns
    -------
    dict[int, numpy.ndarray]
        Maps within-trial position to an ``(n, 2)`` array of
        ``(start_time, stop_time)`` rows. Keys 0-4 are always present; a
        position with no rows yields an empty ``(0, 2)`` array so callers can
        still slice columns. Positions beyond 4 are added on demand instead
        of raising ``KeyError``.
    """
    trial_nums = intervals['trial_num'][:].astype(int)
    # Keys 0-4 are pre-populated so the result always exposes them, as before.
    stimuli = {k: [] for k in range(5)}
    for trial in np.unique(trial_nums):
        for position, row in enumerate(np.nonzero(trial_nums == trial)[0]):
            # setdefault lets trials with more than five rows extend the dict.
            stimuli.setdefault(position, []).append(
                (intervals['start_time'][row], intervals['stop_time'][row])
            )
    # reshape(-1, 2) keeps populated arrays unchanged and turns an empty list
    # into shape (0, 2) rather than (0,).
    return {k: np.array(v).reshape(-1, 2) for k, v in stimuli.items()}

In [None]:
# Names of the flag columns in the passive_glo interval table that select
# condition-specific rows; the same strings are reused as epoch "type" labels
# and as the second element of the keys in `erps`.
CONDITIONS = ["go_gloexp", "lo_gloexp", "go_seqctl", "seqctl"]
# Padding passed to `recording.epoch` before each trial; also the upper bound
# given to `baseline_correct`. Presumably seconds, per the name -- confirm
# against the epych API.
PRETRIAL_SECONDS = 0.5
# Padding passed to `recording.epoch` after each trial (presumably seconds).
POSTTRIAL_SECONDS = 0.5

In [None]:
def hippocampal_areas(probe_areas):
    """Lazily yield ``(index, area)`` for labels containing "DG-" or "CA".

    ``probe_areas`` is an iterable of area-label strings; ``index`` is the
    label's position in that iterable.
    """
    tags = ("DG-", "CA")
    yield from (
        (channel, area)
        for channel, area in enumerate(probe_areas)
        if any(tag in area for tag in tags)
    )

def visual_areas(probe_areas):
    """Lazily yield ``(index, area)`` for labels containing "VIS".

    ``probe_areas`` is an iterable of area-label strings; ``index`` is the
    label's position in that iterable.
    """
    yield from (
        (channel, area)
        for channel, area in enumerate(probe_areas)
        if "VIS" in area
    )

def subcortical_areas(probe_areas):
    """Lazily yield ``(index, area)`` for labels matching the subcortical tags.

    A label matches when it contains any of "DG-", "CA", "MB", "SCi" or
    "POST". ``probe_areas`` is an iterable of area-label strings; ``index`` is
    the label's position in that iterable.
    """
    tags = ("DG-", "CA", "MB", "SCi", "POST")
    yield from (
        (channel, area)
        for channel, area in enumerate(probe_areas)
        if any(tag in area for tag in tags)
    )

In [None]:
def probe_electrode_metadata(electrodes, probe, channels):
    """Collect area labels and probe positions for selected channels of a probe.

    Parameters
    ----------
    electrodes : mapping
        Electrode table providing ``'group_name'``, ``'location'``,
        ``'probe_horizontal_position'`` and ``'probe_vertical_position'``.
    probe : str
        Probe name; compared against ``group_name`` after encoding to bytes.
    channels : index-like
        Indexes into the subset of electrodes belonging to ``probe``.

    Returns
    -------
    dict
        ``'areas'`` (array of decoded location strings), plus ``'horizontal'``
        and ``'vertical'`` positions multiplied by ``pq.mm``.
    """
    on_probe = electrodes['group_name'][:] == probe.encode()
    indices = on_probe.nonzero()[0]

    def column(name):
        # Restrict to this probe's electrodes first, then to the requested channels.
        return electrodes[name][indices][channels]

    areas = np.array([label.decode() for label in column('location')])
    return {
        'areas': areas,
        'horizontal': column('probe_horizontal_position') * pq.mm,
        'vertical': column('probe_vertical_position') * pq.mm,
    }

In [None]:
# All NWB files directly under /mnt/data (absolute, machine-specific path).
# glob returns matches in arbitrary, filesystem-dependent order, so sort them
# to make the processing and plotting order reproducible across runs.
NWB_FILES = sorted(glob.glob('/mnt/data/*.nwb'))

In [None]:
# Files skipped by `nwbfile_recording` because they contain fewer probes than
# its `required_probes` argument; appended to as files are processed.
PILOT_FILES = []

In [None]:
def probe_lfps(nwb, electrodes, probes):
    """Build one ``ContinuousLfp`` signal per probe from an NWB file.

    Parameters
    ----------
    nwb : mapping
        Open NWB file; LFP series are read from
        ``acquisition/probe_<p>_lfp/probe_<p>_lfp_data`` where ``<p>`` is the
        probe's position in ``probes``.
    electrodes : mapping
        Electrode table providing ``'group_name'``, ``'id'``, ``'location'``,
        ``'probe_horizontal_position'`` and ``'probe_vertical_position'``.
    probes : sequence of str
        Probe names; each is matched (encoded to bytes) against ``group_name``.

    Returns
    -------
    dict
        Maps probe name to ``epych.signals.lfp.ContinuousLfp``.
    """
    signals = {}
    for p, probe in enumerate(probes):
        lfp_name = "probe_%d_lfp" % p
        lfp_series = nwb['acquisition/'][lfp_name + "/" + lfp_name + "_data"]

        timestamps = lfp_series["timestamps"][:]
        # Mean spacing between consecutive timestamps.
        steps = timestamps[1:] - timestamps[:-1]
        dt = steps.mean()
        # Transpose the stored data and append a trailing singleton axis; the
        # first axis must then match the channel table (asserted below).
        probe_data = np.transpose(lfp_series["data"])[:, :, np.newaxis]

        on_probe = electrodes["group_name"][:] == probe.encode()
        lfp_channels = lfp_series['electrodes'][:]

        def select(column):
            # This probe's electrodes first, then the channels of the LFP series.
            return electrodes[column][on_probe][lfp_channels]

        probe_channels = pd.DataFrame(
            data={
                "horizontal": select("probe_horizontal_position"),
                "location": select("location"),
                "vertical": select("probe_vertical_position"),
            },
            columns=["horizontal", "location", "vertical"],
            index=select("id"),
        )
        assert probe_data.shape[0] == len(probe_channels)

        signals[probe] = epych.signals.lfp.ContinuousLfp(probe_channels, probe_data, dt, timestamps)
    return signals

In [None]:
def passiveglo_epochs(glo_intervals):
    """Assemble a single epoch table from the passive_glo interval table.

    Rows are emitted in three groups, in this order:

    1. one ``"trial"`` row per trial (from ``trial_intervals``);
    2. one row per flagged interval for each name in ``CONDITIONS``
       (from ``epoch_intervals``), labelled with the condition name;
    3. one ``"stim<k>"`` row per interval at within-trial position ``k``
       (from ``trial_stimulus_intervals``).

    Parameters
    ----------
    glo_intervals : mapping
        Interval table accepted by the three helper functions above.

    Returns
    -------
    pandas.DataFrame
        Columns ``start``, ``end`` and ``type``.
    """
    # Seed with an empty float array so the concatenated columns are float
    # arrays even when a group contributes no rows.
    starts = [np.array([])]
    ends = [np.array([])]
    types = []

    def add(label, spans):
        # A group with no rows may arrive as a 1-D empty array; reshaping to
        # (-1, 2) lets it be column-sliced instead of raising IndexError.
        spans = np.reshape(spans, (-1, 2))
        starts.append(spans[:, 0])
        ends.append(spans[:, 1])
        types.extend([label] * spans.shape[0])

    add("trial", trial_intervals(glo_intervals))

    for condition in CONDITIONS:
        add(condition, epoch_intervals(glo_intervals, condition))

    for k, v in trial_stimulus_intervals(glo_intervals).items():
        add("stim%d" % k, v)

    epochs = {
        "start": np.concatenate(starts, axis=0),
        "end": np.concatenate(ends, axis=0),
        "type": types,
    }
    return pd.DataFrame(data=epochs, columns=["start", "end", "type"])

In [None]:
def nwbfile_recording(nwb_file, required_probes=6):
    """Load one NWB file into an ``epych`` recording, or skip it.

    Parameters
    ----------
    nwb_file : str
        Path of the NWB (HDF5) file to open read-only.
    required_probes : int, optional
        Minimum number of distinct probe names the electrode table must
        contain. Files with fewer are recorded in the module-level
        ``PILOT_FILES`` list and skipped.

    Returns
    -------
    epych.recording.Recording or None
        The recording built from the passive_glo epochs and per-probe LFP
        signals, or ``None`` when the file has too few probes.
    """
    # The context manager closes the file on every exit path, including errors.
    with h5py.File(nwb_file, 'r') as nwb:
        electrodes = nwb['general']['extracellular_ephys']['electrodes']
        probes = sorted({name.decode() for name in electrodes['group_name'][:]})
        if len(probes) < required_probes:
            # Record each skipped file once, even if this function is called
            # again for the same path (e.g. when a notebook cell is re-run).
            if nwb_file not in PILOT_FILES:
                PILOT_FILES.append(nwb_file)
            return None

        glo_intervals = nwb['intervals']['passive_glo']
        epochs = passiveglo_epochs(glo_intervals)
        signals = probe_lfps(nwb, electrodes, probes)

        units = {"start": pq.second, "end": pq.second}
        # Empty table indexed by "trial"; its role is defined by epych's
        # Recording constructor -- TODO confirm against the epych API.
        trials = pd.DataFrame(columns=["trial"]).set_index("trial")
        return epych.recording.Recording(epochs, trials, units, **signals)

In [None]:
# ERP results keyed by (nwb_file, condition); filled in by the processing loop.
erps = {}

In [None]:
# Files for which `nwbfile_recording` returned a recording (not skipped).
GOOD_NWB_FILES = []
# NOTE(review): not referenced by any other cell shown in this notebook.
good_recordings = []

In [None]:
# Compute one ERP per (file, condition) and store it in `erps`.
for nwb_file in NWB_FILES:
    recording = nwbfile_recording(nwb_file)
    if recording is None:
        # Too few probes; nwbfile_recording has already noted it in PILOT_FILES.
        continue

    # Guard so that re-running this cell does not list the same file twice
    # (which would duplicate every plot in the plotting cell).
    if nwb_file not in GOOD_NWB_FILES:
        GOOD_NWB_FILES.append(nwb_file)

    for cond in CONDITIONS:
        condition_epochs = recording.intervals["type"] == cond
        trial_epochs = recording.intervals["type"] == "trial"
        sampling = recording.epoch(
            condition_epochs, trial_epochs, PRETRIAL_SECONDS, POSTTRIAL_SECONDS
        ).baseline_correct(0, PRETRIAL_SECONDS)
        erps[(nwb_file, cond)] = sampling.erp()
        # Drop the epoched data before moving to the next condition.
        del sampling
    # Drop the full recording before loading the next file.
    del recording

In [None]:
# Symmetric limit passed as -vmin / +vmax to every ERP plot so that all
# figures share the same scale.
ERP_PLOT_LIMIT = 1e-4

# Plot every stored ERP, grouped by condition (the condition name is printed
# before its group of figures).
for cond in CONDITIONS:
    print(cond)
    for nwb_file in GOOD_NWB_FILES:
        erps[(nwb_file, cond)].plot(vmin=-ERP_PLOT_LIMIT, vmax=ERP_PLOT_LIMIT)

In [None]:
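# NOTE: disabled (commented-out) per-probe ERP computation. It calls helpers
# (trial_intervals_with_event, nearest_indices, epoch_timeseries,
# correct_baseline) that are not defined in the cells shown in this notebook,
# and it keys `erps` by (cond, nwb_file, probe) rather than the
# (nwb_file, cond) keys used by the live loop, so it will not run as-is.
# for cond in CONDITIONS: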
#     for f, nwb_file in enumerate(NWB_FILES):
#         nwb = h5py.File(nwb_file, 'r')
#         electrodes = nwb['general']['extracellular_ephys']['electrodes']
#         probes = sorted([probe.decode() for probe in set(electrodes['group_name'][:])])
#         glo_intervals = nwb['intervals']['passive_glo']
#         intervals = trial_intervals_with_event(glo_intervals, cond)
#         if len(probes) < 6:
#             PILOT_FILES.append(nwb_file)

#         for p, probe in enumerate(probes):
#             lfp_name = "probe_" + str(p) + "_lfp"
#             probe_lfps = nwb['acquisition'][lfp_name][lfp_name + '_data']
            
#             lfp_hz = int(1 / (probe_lfps['timestamps'][1] - probe_lfps['timestamps'][0]))
#             assert lfp_hz > 0
#             baseline_samples = int(lfp_hz * PRETRIAL_SECONDS)

#             start_samples = nearest_indices(probe_lfps['timestamps'][:], intervals[:, 0] - PRETRIAL_SECONDS)
#             stop_samples = nearest_indices(probe_lfps['timestamps'][:], intervals[:, 1] + POSTTRIAL_SECONDS)
#             trial_length = (stop_samples - start_samples).min()
#             assert trial_length > 0
#             for t in range(len(stop_samples)):
#                 if stop_samples[t] - start_samples[t] > trial_length:
#                     stop_samples[t] = start_samples[t] + trial_length

#             trial_lfps = epoch_timeseries(probe_lfps['data'], zip(start_samples, stop_samples))
#             trial_lfps = np.swapaxes(trial_lfps, 0, 1)
            
#             erps[(cond, nwb_file, probe)] = correct_baseline(trial_lfps, 0, baseline_samples-1).mean(axis=-1)
#             logging.info("Calculated ERP for %s" % probe)
#         logging.info("Calculated ERPs for %s" % nwb_file)
#         nwb.close()
#     logging.info("Calculated ERPs for %s" % cond)

In [None]:
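# NOTE: disabled (commented-out) plotting of visual-area ERPs. It calls helpers
# (trial_intervals_with_event, event_intervals, plot_heatmap, mark_areas) that
# are not defined in the cells shown in this notebook, and it expects `erps`
# keyed by (cond, nwb_file, probe), which the live loop does not produce.
# for cond in CONDITIONS: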
#     fig, axes = plt.subplots(len(GOOD_NWB_FILES), 6, figsize=(6 * 15, 15))
#     fig.suptitle(cond)

#     for f, nwb_file in enumerate(GOOD_NWB_FILES):
#         nwb = h5py.File(nwb_file, 'r')
#         electrodes = nwb['general']['extracellular_ephys']['electrodes']
#         probes = sorted([probe.decode() for probe in set(electrodes['group_name'][:])])
#         glo_intervals = nwb['intervals']['passive_glo']
#         trial_intervals = trial_intervals_with_event(glo_intervals, cond)
#         condstim_intervals = event_intervals(glo_intervals, cond)
#         stim_onset, stim_offset = (condstim_intervals - trial_intervals[:, np.newaxis, 0]).mean(axis=0)

#         for p, probe in enumerate(probes):
#             lfp_name = "probe_" + str(p) + "_lfp"
#             probe_lfps = nwb['acquisition'][lfp_name][lfp_name + '_data']
#             probe_electrodes = probe_electrode_metadata(electrodes, probe, probe_lfps['electrodes'][:])
#             lfp_hz = int(1 / (probe_lfps['timestamps'][1] - probe_lfps['timestamps'][0]))
#             visual_channels = list(visual_areas(probe_electrodes['areas']))
#             areas = [a for (c, a) in visual_channels]
#             cs = [c for (c, a) in visual_channels]
#             probe_area = os.path.commonprefix(areas)

#             plot_heatmap(fig, axes[f, p], erps[(cond, nwb_file, probe)][cs], os.path.basename(nwb_file) + '/' + probe_area, vmin=-1e-4, vmax=1e-4)
#             mark_areas(axes[f, p], areas)
#             axes[f, p].axvline((stim_onset + PRETRIAL_SECONDS) * lfp_hz, linestyle='--', color='lightgreen')
#             axes[f, p].axvline((stim_offset + PRETRIAL_SECONDS) * lfp_hz, linestyle='--', color='red')

#             trial_samples = erps[(cond, nwb_file, probe)].shape[1]
#             xtick_locs = np.linspace(0, trial_samples, 20)
#             xticks = np.linspace(0., trial_samples / lfp_hz, 20) - stim_onset - PRETRIAL_SECONDS
#             xticks = ["%0.2f" % t for t in xticks]
#             axes[f, p].set_xticks(xtick_locs, xticks)

#         nwb.close()

#     fig.tight_layout()
#     plt.show()
#     fig.savefig(cond + '_visual_erps.pdf')
#     plt.close(fig)

In [None]:
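# NOTE: disabled (commented-out) plotting of subcortical-area ERPs. It calls
# helpers (trial_intervals_with_event, event_intervals, plot_heatmap,
# mark_areas) that are not defined in the cells shown in this notebook, and it
# expects `erps` keyed by (cond, nwb_file, probe), which the live loop does
# not produce.
# for cond in CONDITIONS: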
#     fig, axes = plt.subplots(len(GOOD_NWB_FILES), 6, figsize=(6 * 15, 15))
#     fig.suptitle(cond)

#     for f, nwb_file in enumerate(GOOD_NWB_FILES):
#         nwb = h5py.File(nwb_file, 'r')
#         electrodes = nwb['general']['extracellular_ephys']['electrodes']
#         probes = sorted([probe.decode() for probe in set(electrodes['group_name'][:])])
#         glo_intervals = nwb['intervals']['passive_glo']
#         trial_intervals = trial_intervals_with_event(glo_intervals, cond)
#         condstim_intervals = event_intervals(glo_intervals, cond)
#         stim_onset, stim_offset = (condstim_intervals - trial_intervals[:, np.newaxis, 0]).mean(axis=0)

#         for p, probe in enumerate(probes):
#             lfp_name = "probe_" + str(p) + "_lfp"
#             probe_lfps = nwb['acquisition'][lfp_name][lfp_name + '_data']
#             probe_electrodes = probe_electrode_metadata(electrodes, probe, probe_lfps['electrodes'][:])
#             lfp_hz = int(1 / (probe_lfps['timestamps'][1] - probe_lfps['timestamps'][0]))

#             subcortical_channels = list(subcortical_areas(probe_electrodes['areas']))
#             areas = [a for (c, a) in subcortical_channels]
#             cs = [c for (c, a) in subcortical_channels]
#             probe_area = os.path.commonprefix(areas)

#             plot_heatmap(fig, axes[f, p], erps[(cond, nwb_file, probe)][cs], os.path.basename(nwb_file) + '/' + probe_area, vmin=-1e-4, vmax=1e-4)
#             mark_areas(axes[f, p], areas)
#             axes[f, p].axvline((stim_onset + PRETRIAL_SECONDS) * lfp_hz, linestyle='--', color='lightgreen')
#             axes[f, p].axvline((stim_offset + PRETRIAL_SECONDS) * lfp_hz, linestyle='--', color='red')

#             trial_samples = erps[(cond, nwb_file, probe)].shape[1]
#             xtick_locs = np.linspace(0, trial_samples, 20)
#             xticks = np.linspace(0., trial_samples / lfp_hz, 20) - stim_onset - PRETRIAL_SECONDS
#             xticks = ["%0.2f" % t for t in xticks]
#             axes[f, p].set_xticks(xtick_locs, xticks)

#         nwb.close()

#     fig.tight_layout()
#     plt.show()
#     fig.savefig(cond + '_subcortical_erps.pdf')
#     plt.close(fig)