In [None]:
# Move to the parent directory — presumably the repository root, so that
# `import epych` in the next cell resolves to the local package (confirm).
# NOTE(review): not idempotent — re-running this cell moves up one more level.
%cd ..

In [None]:
import epych
import glob
import h5py
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import quantities as pq

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Show INFO-level progress messages ("Finished ERPs/CSDs of ...") emitted by
# the processing loops further down.
logging.basicConfig(level=logging.INFO)

In [None]:
def epoch_intervals(intervals, epoch):
    """Return the ``(start, stop)`` times of every row flagged by ``epoch``.

    Parameters
    ----------
    intervals : mapping of column name -> array-like
        Interval table (e.g. an h5py group) providing the ``epoch`` column and
        ``'start_time'`` / ``'stop_time'`` columns of the same length.
    epoch : str
        Name of the indicator column; rows whose value is truthy are selected.

    Returns
    -------
    numpy.ndarray
        Array with one row per selected interval and columns ``(start, stop)``.
    """
    selected = intervals[epoch][:].astype(bool)
    starts = intervals['start_time'][selected]
    stops = intervals['stop_time'][selected]
    return np.stack([starts, stops], axis=-1)

In [None]:
def trial_intervals(intervals):
    """Return one ``(start, stop)`` span per trial.

    Rows sharing a ``trial_num`` are grouped. Each span runs from the
    ``start_time`` of the trial's first row to the ``stop_time`` of its last
    row, in row order. This is not min/max, so rows are assumed to be
    time-ordered within a trial.

    Parameters
    ----------
    intervals : mapping of column name -> array-like
        Interval table (e.g. an h5py group) with ``'trial_num'``,
        ``'start_time'`` and ``'stop_time'`` columns of equal length.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_trials, 2)``, ordered by ascending trial number. The shape
        is ``(0, 2)`` when there are no rows, so ``[:, 0]`` is always valid.
    """
    trial_nums = intervals['trial_num'][:].astype(int)
    # Read each column once. Indexing an h5py dataset one element at a time
    # issues a separate read per access.
    starts = intervals['start_time'][:]
    stops = intervals['stop_time'][:]

    startstops = []
    for trial in np.unique(trial_nums):
        rows = np.flatnonzero(trial_nums == trial)
        startstops.append((starts[rows[0]], stops[rows[-1]]))
    # reshape keeps the two-column layout even when no trials were found.
    return np.array(startstops).reshape(-1, 2)

In [None]:
def trial_stimulus_intervals(intervals, n_stimuli=5):
    """Group interval rows by their position within each trial.

    For every trial (rows sharing a ``trial_num``), the k-th row in row order
    contributes its ``(start_time, stop_time)`` to bucket ``k`` (0-based).

    Parameters
    ----------
    intervals : mapping of column name -> array-like
        Interval table (e.g. an h5py group) with ``'trial_num'``,
        ``'start_time'`` and ``'stop_time'`` columns of equal length.
    n_stimuli : int, default 5
        Positions ``0 .. n_stimuli-1`` are always present in the result, even
        if empty. Trials with more rows add further positions instead of
        raising ``KeyError``.

    Returns
    -------
    dict[int, numpy.ndarray]
        Position -> array of shape ``(n, 2)`` holding ``(start, stop)`` rows.
        Empty positions have shape ``(0, 2)``.
    """
    trial_nums = intervals['trial_num'][:].astype(int)
    # Read each column once instead of element-by-element (h5py round-trips).
    starts = intervals['start_time'][:]
    stops = intervals['stop_time'][:]

    stimuli = {k: [] for k in range(n_stimuli)}
    for trial in np.unique(trial_nums):
        for position, row in enumerate(np.flatnonzero(trial_nums == trial)):
            stimuli.setdefault(position, []).append((starts[row], stops[row]))
    # reshape keeps empty buckets two-dimensional so callers can use [:, 0].
    return {k: np.array(v).reshape(-1, 2) for k, v in stimuli.items()}

In [None]:
# Indicator columns of the `passive_glo` interval table. Each one becomes an
# epoch "type" in passiveglo_epochs and gets its own ERP / CSD below.
CONDITIONS = ["go_gloexp", "lo_gloexp", "go_seqctl", "seqctl"]
# Seconds passed to `recording.epoch` as the pre-/post-epoch windows
# (presumably signal kept around each condition epoch — see epych).
# PRETRIAL_SECONDS is also the upper bound given to `baseline_correct`.
PRETRIAL_SECONDS = 0.5
POSTTRIAL_SECONDS = 0.5

In [None]:
def _areas_matching(probe_areas, markers):
    """Yield ``(channel_index, area)`` for each area containing any marker."""
    for index, name in enumerate(probe_areas):
        if any(marker in name for marker in markers):
            yield (index, name)

def hippocampal_areas(probe_areas):
    """Yield ``(channel_index, area)`` for areas containing "DG-" or "CA"."""
    yield from _areas_matching(probe_areas, ("DG-", "CA"))

def visual_areas(probe_areas):
    """Yield ``(channel_index, area)`` for areas containing "VIS"."""
    yield from _areas_matching(probe_areas, ("VIS",))

def subcortical_areas(probe_areas):
    """Yield ``(channel_index, area)`` for hippocampal, "MB", "SCi" or "POST" areas."""
    yield from _areas_matching(probe_areas, ("DG-", "CA", "MB", "SCi", "POST"))

In [None]:
def probe_electrode_metadata(electrodes, probe, channels):
    """Look up area names and positions for selected channels of one probe.

    Parameters
    ----------
    electrodes : h5py.Group or mapping of column name -> array
        NWB electrodes table with ``group_name`` and ``location`` (bytes),
        ``probe_horizontal_position`` and ``probe_vertical_position`` columns.
    probe : str
        Electrode group name; compared against ``group_name`` after encoding.
    channels : index array or slice
        Selects channels among this probe's rows. Positions are relative to
        the probe, not to the whole table.

    Returns
    -------
    dict
        ``'areas'``: decoded location strings. ``'horizontal'`` /
        ``'vertical'``: positions as ``quantities`` arrays tagged in mm.
    """
    probe_rows = np.flatnonzero(electrodes['group_name'][:] == probe.encode())

    def _column(name):
        # Restrict to this probe's rows first, then to the requested channels.
        return electrodes[name][probe_rows][channels]

    area_names = [raw.decode() for raw in _column('location')]
    # NOTE(review): positions are tagged as millimetres; confirm the file
    # stores them in mm rather than µm.
    return {
        'areas': np.array(area_names),
        'horizontal': _column('probe_horizontal_position') * pq.mm,
        'vertical': _column('probe_vertical_position') * pq.mm,
    }

In [None]:
# Directory holding the NWB session files. Set the NWB_DATA_DIR environment
# variable to point elsewhere instead of editing this cell.
NWB_DATA_DIR = os.environ.get("NWB_DATA_DIR", "/mnt/data")
# Sorted because glob order depends on the filesystem. This keeps processing
# and plot order reproducible across runs and machines.
NWB_FILES = sorted(glob.glob(os.path.join(NWB_DATA_DIR, "*.nwb")))

In [None]:
# Paths of NWB files that nwbfile_recording skipped for having fewer than the
# required number of probes (presumably pilot sessions, going by the name).
PILOT_FILES = []

In [None]:
def probe_lfps(nwb, electrodes, probes):
    """Wrap each probe's LFP acquisition series as an ``epych`` RawLfp signal.

    Parameters
    ----------
    nwb : h5py.File
        Open NWB file. LFPs are read from
        ``acquisition/probe_<i>_lfp/probe_<i>_lfp_data``.
    electrodes : h5py.Group
        The NWB electrodes table (``general/extracellular_ephys/electrodes``).
    probes : sequence of str
        Probe group names. The i-th name is paired with ``probe_<i>_lfp``.

    Returns
    -------
    dict
        Maps each probe name to an ``epych.signals.lfp.RawLfp``. Each is built
        from the series' ``data`` (time on axis 0, channels on axis 1), the
        mean timestamp spacing, and the timestamps. Its channel table holds
        horizontal / location / vertical columns indexed by electrode id.
    """
    acquisition = nwb['acquisition/']
    group_names = electrodes["group_name"][:]
    column_sources = (
        ("horizontal", "probe_horizontal_position"),
        ("id", "id"),
        ("location", "location"),
        ("vertical", "probe_vertical_position"),
    )

    signals = {}
    for index, probe in enumerate(probes):
        # NOTE(review): the LFP group is chosen by the probe's position in
        # `probes`, not by its name. This assumes the file numbers its LFP
        # groups in the same order as `probes`; confirm.
        series_name = f"probe_{index}_lfp"
        series = acquisition[f"{series_name}/{series_name}_data"]
        timestamps = series["timestamps"][:]
        # Sampling interval estimated as the mean gap between timestamps.
        dt = np.diff(timestamps).mean()

        on_probe = group_names == probe.encode()
        lfp_rows = series["electrodes"][:]
        # NOTE(review): `lfp_rows` indexes the *per-probe* subset here. In the
        # NWB schema an ElectricalSeries' `electrodes` region refers to rows of
        # the full electrodes table. If these files follow that, this selects
        # wrong rows (or goes out of bounds) for probes after the first; confirm.
        channels = {
            key: electrodes[field][on_probe][lfp_rows]
            for key, field in column_sources
        }
        probe_channels = pd.DataFrame(
            data=channels,
            columns=["horizontal", "location", "vertical"],
            index=channels["id"],
        )
        signals[probe] = epych.signals.lfp.RawLfp(
            probe_channels, series["data"], dt, timestamps,
            channels_dim=1, time_dim=0,
        )
    return signals

In [None]:
def passiveglo_epochs(glo_intervals):
    """Build the epoch table from the ``passive_glo`` interval table.

    Rows are emitted in this order:

    * one ``"trial"`` row per trial (spans from ``trial_intervals``);
    * for each name in ``CONDITIONS``, one row per interval flagged by that
      column, typed with the condition name;
    * for each within-trial position ``k`` from ``trial_stimulus_intervals``,
      one ``"stim<k>"`` row per trial that has a k-th interval.

    Parameters
    ----------
    glo_intervals : h5py.Group or mapping of column name -> array-like
        The ``intervals/passive_glo`` table.

    Returns
    -------
    pandas.DataFrame
        Columns ``start`` and ``end`` (float) and ``type`` (str).
    """
    pieces = [(trial_intervals(glo_intervals), "trial")]
    pieces += [(epoch_intervals(glo_intervals, condition), condition) for condition in CONDITIONS]
    pieces += [(times, "stim%d" % k) for k, times in trial_stimulus_intervals(glo_intervals).items()]

    # Collect the pieces and concatenate once, rather than re-growing arrays
    # for every piece.
    starts, ends, types = [], [], []
    for times, label in pieces:
        # An empty selection may come back with shape (0,). Reshaping keeps it
        # two-column so the [:, 0] / [:, 1] indexing cannot raise.
        bounds = np.asarray(times, dtype=float).reshape(-1, 2)
        starts.append(bounds[:, 0])
        ends.append(bounds[:, 1])
        types.extend([label] * bounds.shape[0])

    return pd.DataFrame(
        data={"start": np.concatenate(starts), "end": np.concatenate(ends), "type": types},
        columns=["start", "end", "type"],
    )

In [None]:
def nwbfile_recording(nwb, required_probes=6):
    """Build an ``epych`` RawRecording from one NWB file, or None for pilot files.

    Parameters
    ----------
    nwb : h5py.File or str or os.PathLike
        An open NWB file, used as-is; the caller keeps ownership. Alternatively
        a path, which is opened read-only here and deliberately left open,
        because the recording is handed h5py datasets from this file
        (presumably read lazily by epych).
    required_probes : int, default 6
        Files with fewer distinct electrode groups (probes) are appended to
        ``PILOT_FILES`` and skipped.

    Returns
    -------
    epych.recording.RawRecording or None
        Epochs from ``passiveglo_epochs`` (start/end tagged in seconds) plus one
        LFP signal per probe, passed as keyword arguments named after the
        probes. None when the file has too few probes.
    """
    # Use the file we were given. The old code ignored `nwb` and reopened the
    # global `nwb_file`, leaking a second h5py handle on every call.
    if isinstance(nwb, (str, os.PathLike)):
        nwb = h5py.File(nwb, "r")
    electrodes = nwb['general']['extracellular_ephys']['electrodes']
    probes = sorted({name.decode() for name in electrodes['group_name'][:]})
    if len(probes) < required_probes:
        # Record the file actually inspected, not a global loop variable.
        PILOT_FILES.append(nwb.file.filename)
        return None
    glo_intervals = nwb['intervals']['passive_glo']

    epochs = passiveglo_epochs(glo_intervals)
    signals = probe_lfps(nwb, electrodes, probes)

    units = {"start": pq.second, "end": pq.second}
    return epych.recording.RawRecording(epochs, pd.DataFrame(columns=["trial"]).set_index("trial"), units, **signals)

In [None]:
# Event-related potentials keyed by (nwb_file, condition), filled by the ERP
# processing loop.
erps = {}

In [None]:
# Paths of NWB files that yielded a recording (enough probes), in processing order.
GOOD_NWB_FILES = []
# NOTE(review): not read or written anywhere else in this section — candidate for removal.
good_recordings = []

In [None]:
# Compute per-condition ERPs for every NWB file that has enough probes.
for nwb_file in NWB_FILES:
    with h5py.File(nwb_file, "r") as nwb:
        recording = nwbfile_recording(nwb)
        if recording is None:
            # Too few probes: nwbfile_recording has added it to PILOT_FILES.
            continue
        # Guard keeps the list duplicate-free when this cell is re-run, which
        # would otherwise duplicate every plot in the cells below.
        if nwb_file not in GOOD_NWB_FILES:
            GOOD_NWB_FILES.append(nwb_file)
        # The trial mask does not depend on the condition, so build it once.
        trial_epochs = recording.intervals["type"] == "trial"
        for cond in CONDITIONS:
            condition_epochs = recording.intervals["type"] == cond
            sampling = recording.epoch(
                condition_epochs, trial_epochs, PRETRIAL_SECONDS, POSTTRIAL_SECONDS
            ).baseline_correct(0, PRETRIAL_SECONDS)
            erps[(nwb_file, cond)] = sampling.erp().smap(lambda sig: sig.median_filter())
            # Release the epoched samples before the next condition.
            del sampling
            logging.info("Finished ERPs of %s condition in %s", cond, nwb_file)
        del recording
        logging.info("Finished ERPs of %s", nwb_file)

In [None]:
# Plot each usable session's ERP, grouped by condition, all with the same
# fixed vmin/vmax so panels are directly comparable.
for cond in CONDITIONS:
    print(cond)
    for nwb_file in GOOD_NWB_FILES:
        erps[nwb_file, cond].plot(
            vmin=-1e-4,
            vmax=1e-4,
        )

In [None]:
# Current source density results keyed by (nwb_file, condition).
csds = {}

In [None]:
# Derive a CSD from every ERP (after epych's downsample(4)) and plot it.
for cond in CONDITIONS:
    print(cond)
    for nwb_file in GOOD_NWB_FILES:
        csds[(nwb_file, cond)] = erps[(nwb_file, cond)].smap(
            lambda sig: sig.downsample(4).current_source_density(depth_column="vertical")
        )
        # NOTE(review): vmin/vmax are copied from the ERP plots; a CSD is on a
        # different scale, so confirm this range is meaningful.
        csds[(nwb_file, cond)].plot(vmin=-1e-4, vmax=1e-4)
        logging.info("Finished CSDs of %s condition in %s", cond, nwb_file)
    # Runs once per condition, after all files. Report the condition, not
    # whichever file happened to be processed last.
    logging.info("Finished CSDs of %s condition", cond)