In [1]:
%cd ../..

/home/eli/AnacondaProjects/epych


In [2]:
%env DASK_LOGGING__DISTRIBUTED=CRITICAL
%env SPYTMPDIR=/mnt/data/tmp_storage
%env SPYLOGLEVEL=CRITICAL
%env SPYPARLOGLEVEL=CRITICAL

env: DASK_LOGGING__DISTRIBUTED=CRITICAL
env: SPYTMPDIR=/mnt/data/tmp_storage
env: SPYLOGLEVEL=CRITICAL
env: SPYPARLOGLEVEL=CRITICAL


In [3]:
import collections
import glob
import functools
import logging
import math
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import pynwb
import quantities as pq
import scipy

import epych
from epych.statistics import alignment

[striatum:35039] shmem: mmap: an error occurred while determining whether or not /tmp/ompi.striatum.1000/jf.0/3774283776/shared_mem_cuda_pool.striatum could be created.
[striatum:35039] create_and_attach: unable to create shared memory BTL coordinating structure :: size 134217728 


In [4]:
%matplotlib inline

In [5]:
logging.basicConfig(level=logging.INFO)

In [6]:
# Names of the condition columns that passiveglo_epochs looks up in the task
# intervals table; each one becomes an epoch "type" label.
CONDITIONS = ["go_gloexp", "lo_gloexp", "go_seqctl", "lo_rndctl", "igo_seqctl"]
# Passed as `before=` / `after=` to recording.epoch() when cutting out trials.
PRETRIAL_SECONDS = np.array(0.5) * pq.second
POSTTRIAL_SECONDS = np.array(0.5) * pq.second

In [7]:
# Previously saved alignment summary; visual_align indexes `aligner.stats` by
# visual-area prefix.
# NOTE(review): presumably pickle-backed -- only load files from a trusted source.
aligner = epych.statistics.alignment.AlignmentSummary.unpickle("/mnt/data/DRAFT/000253/visual_alignment/")

In [8]:
# Running count of probes seen per visual-area prefix. visual_align reads the
# current count as the index it hands to align(), and the subject loop bumps it
# for cached subjects too, so the count advances in subject-processing order.
probe_area_counter = collections.Counter()

In [9]:
def visual_align(probe, signal):
    """Align the visual-cortex channels of one probe signal.

    The area is the common prefix of every channel location containing "VIS".
    The number of probes already seen for that area (tracked in the module-level
    `probe_area_counter`) is used as the index passed to the area's alignment
    statistic, and the counter is then incremented.

    Args:
        probe: Unused; accepted because the function is applied with
            `smap(..., keys=True)`.
        signal: Signal exposing `channels.location` and `select_channels`.

    Returns:
        Whatever `aligner.stats[area].align` returns for the VIS-only signal.
    """
    visual_locations = [name for name in signal.channels.location if "VIS" in name]
    area = os.path.commonprefix(visual_locations)

    # Take the index before bumping the counter so the first probe of an area
    # is aligned with index 0.
    occurrence = probe_area_counter[area]
    probe_area_counter[area] = occurrence + 1

    visual_mask = ["VIS" in name for name in signal.channels.location]
    visual_signal = signal.select_channels(visual_mask)
    return aligner.stats[area].align(occurrence, visual_signal)

In [10]:
# One directory per subject; the glob pattern ends in '/', so every entry keeps
# its trailing slash.
NWB_SUBJECTS = glob.glob('/mnt/data/DRAFT/000253/sub-*/')

In [11]:
# NOTE(review): none of these three are referenced in the visible part of the
# notebook -- presumably accumulators for later cells; confirm or remove.
NUM_TRIALS = 0
ODDBALL_ONSET = 0.
ODDBALL_OFFSET = 0.

In [12]:
# Subject directories skipped by subjectdir_recording because they had fewer
# than REQUIRED_PROBES per-probe NWB files.
PILOT_FILES = []
REQUIRED_PROBES = 6

In [13]:
def epoch_intervals(intervals, epoch):
    """Return the (start, stop) times of the intervals flagged for `epoch`.

    Args:
        intervals: Mapping holding a flag array under `epoch` plus
            'start_time' and 'stop_time' arrays.
        epoch: Key of the flag array; its values are cast to bool and used as
            a mask.

    Returns:
        Array whose last axis has length 2: start time, then stop time.
    """
    selected = intervals[epoch][:].astype(bool)
    starts = intervals['start_time'][selected]
    stops = intervals['stop_time'][selected]
    return np.stack([starts, stops], axis=-1)

In [14]:
def trial_intervals(intervals):
    """Return one (start, stop) pair per trial.

    A trial starts at the 'start_time' of its first row and stops at the
    'stop_time' of its last row, where rows are grouped by the integer-cast
    'trial_num' values.

    Args:
        intervals: Mapping with 'trial_num', 'start_time' and 'stop_time'
            arrays.

    Returns:
        Array of (start, stop) pairs ordered by ascending trial number.
    """
    trial_ids = intervals['trial_num'][:].astype(int)

    def span_of(trial_id):
        # First-axis indices of the rows that belong to this trial.
        rows = np.where(trial_ids == trial_id)[0]
        first, last = rows[0], rows[-1]
        return (intervals['start_time'][first], intervals['stop_time'][last])

    return np.array([span_of(trial_id) for trial_id in np.unique(trial_ids)])

In [15]:
def trial_stimulus_intervals(intervals, num_stimuli=5):
    """Group interval (start, stop) pairs by their position within a trial.

    Rows are grouped by the integer-cast 'trial_num'; the n-th row of each
    trial is filed under key n.

    Args:
        intervals: Mapping with 'trial_num', 'start_time' and 'stop_time'
            arrays.
        num_stimuli: Number of positions that always get a key in the result,
            even when no trial reaches them. Defaults to 5.

    Returns:
        Dict mapping position -> array of (start, stop) pairs. A position that
        no trial reaches maps to an empty array of shape (0, 2), so column
        indexing such as `v[:, 0]` still works. Positions at or beyond
        `num_stimuli` get their own keys instead of raising KeyError.
    """
    trial_nums = intervals['trial_num'][:].astype(int)
    stimuli = {position: [] for position in range(num_stimuli)}
    for trial in np.unique(trial_nums):
        for position, row in enumerate(np.nonzero(trial_nums == trial)[0]):
            # setdefault tolerates trials longer than `num_stimuli`.
            stimuli.setdefault(position, []).append(
                (intervals['start_time'][row], intervals['stop_time'][row])
            )
    return {
        position: (np.array(spans) if spans else np.empty((0, 2)))
        for position, spans in stimuli.items()
    }

In [16]:
def _as_millimetres(values):
    """Return `values` as a millimetre quantity without applying the unit twice."""
    if isinstance(values, pq.Quantity):
        return values.rescale(pq.mm)
    return values * pq.mm


def probe_lfps(nwb_files, electrodes):
    """Load the raw LFP signal of every probe file.

    Args:
        nwb_files: Paths of per-probe NWB files. The position `p` of a file in
            this list selects its acquisition entry "probe_<p>_lfp_data".
        electrodes: Mapping with 'group_name', 'id', 'location',
            'probe_horizontal_position' and 'probe_vertical_position' arrays
            covering all probes. Positions may be bare numbers or already
            carry length units.

    Returns:
        Dict mapping each file's first electrode-group name to an
        `epych.signals.lfp.RawLfp`.

    Raises:
        ValueError: 'Invalid timestamps!' when the mean spacing between
            consecutive timestamps exceeds 1.
    """
    signals = {}
    for p, nwb_file in enumerate(nwb_files):
        with pynwb.NWBHDF5IO(path=nwb_file, mode="r") as io:
            nwb = io.read()

            # NOTE(review): assumes the list position matches the probe number
            # in the acquisition key -- confirm this holds after the caller
            # filters files.
            probe_lfp = "probe_%d_lfp_data" % p
            lfps = nwb.acquisition[probe_lfp]
            timestamps = lfps.timestamps[:] * pq.second
            dt = (timestamps[1:] - timestamps[:-1]).mean()
            if dt > 1.:
                raise ValueError('Invalid timestamps!')

            probe = list(nwb.electrode_groups.keys())[0]
            probe_electrodes = electrodes["group_name"][:] == probe
            # The caller may already have attached millimetre units; multiplying
            # by pq.mm again would yield mm**2, so attach the unit idempotently.
            channels = {
                "horizontal": _as_millimetres(electrodes["probe_horizontal_position"][probe_electrodes]),
                "id": electrodes["id"][probe_electrodes],
                "location": electrodes["location"][probe_electrodes],
                "vertical": _as_millimetres(electrodes["probe_vertical_position"][probe_electrodes]),
            }
            # Reorder the per-probe metadata to match the LFP data columns.
            channels = {k: v[lfps.electrodes['local_index'][:]] for k, v in channels.items()}
            probe_channels = pd.DataFrame(data=channels, columns=["horizontal", "location", "vertical"], index=channels['id'])
            signals[probe] = epych.signals.lfp.RawLfp(probe_channels, lfps.data[:] * pq.volt, dt, timestamps, channels_dim=1, time_dim=0)
    return signals

In [17]:
def passiveglo_epochs(glo_intervals):
    """Build the epoch table for a passive GLO session.

    Rows are appended in this order: whole trials (type "trial"), then one
    group per entry of CONDITIONS (type = condition name), then one group per
    within-trial stimulus position (type "stim<k>").

    Args:
        glo_intervals: Task intervals mapping accepted by trial_intervals,
            epoch_intervals and trial_stimulus_intervals.

    Returns:
        pandas.DataFrame with columns "start", "end" and "type".
    """
    epochs = {"start": np.array([]), "end": np.array([]), "type": []}

    def append_epochs(label, spans):
        # Column 0 holds start times and column 1 stop times; one label per row.
        epochs["start"] = np.concatenate((epochs["start"], spans[:, 0].squeeze()), axis=0)
        epochs["end"] = np.concatenate((epochs["end"], spans[:, 1].squeeze()), axis=0)
        epochs["type"] = epochs["type"] + [label] * spans.shape[0]

    append_epochs("trial", trial_intervals(glo_intervals))

    for condition in CONDITIONS:
        append_epochs(condition, epoch_intervals(glo_intervals, condition))

    for position, spans in trial_stimulus_intervals(glo_intervals).items():
        append_epochs("stim%d" % position, spans)

    return pd.DataFrame(data=epochs, columns=epochs.keys())

In [18]:
def subjectdir_recording(subject_dir):
    """Assemble the raw recording (epochs plus per-probe LFPs) for one subject.

    Paths are built with os.path.join, so `subject_dir` works with or without
    a trailing separator.

    Args:
        subject_dir: Directory holding 'passiveglo_task_data.mat', one
            'sub-*_ogen.nwb' file and the per-probe
            'sub-*_ses-*_probe-*.nwb' files.

    Returns:
        An `epych.recording.RawRecording`, or None when the subject has fewer
        than REQUIRED_PROBES probe files (the directory is then appended to
        PILOT_FILES) or when a probe reports invalid timestamps.

    Raises:
        IndexError: If no 'sub-*_ogen.nwb' file is found.
        ValueError: Any ValueError from loading other than the
            'Invalid timestamps!' signal.
    """
    passive_glo = scipy.io.loadmat(os.path.join(subject_dir, 'passiveglo_task_data.mat'))
    epochs = passiveglo_epochs(passive_glo)
    units = {"start": pq.second, "end": pq.second}
    ogen_nwb = glob.glob(os.path.join(subject_dir, 'sub-*_ogen.nwb'))[0]
    with pynwb.NWBHDF5IO(path=ogen_nwb, mode="r") as io:
        nwb = io.read()
        # Copy the electrode table into memory while the file is still open.
        electrodes = {}
        electrodes['group_name'] = nwb.electrodes.group_name[:]
        electrodes['id'] = nwb.electrodes.id[:]
        electrodes['location'] = nwb.electrodes.location[:]
        electrodes['probe_horizontal_position'] = nwb.electrodes.probe_horizontal_position[:] * pq.mm
        electrodes['probe_vertical_position'] = nwb.electrodes.probe_vertical_position[:] * pq.mm

    nwb_files = sorted(glob.glob(os.path.join(subject_dir, 'sub-*_ses-*_probe-*.nwb')))
    nwb_files = [nwbf for nwbf in nwb_files if "test" not in nwbf]
    if len(nwb_files) < REQUIRED_PROBES:
        PILOT_FILES.append(subject_dir)
        return None

    try:
        signals = probe_lfps(nwb_files, electrodes)
        return epych.recording.RawRecording(epochs, pd.DataFrame(columns=["trial"]).set_index("trial"), units, **signals)
    except ValueError as verr:
        # probe_lfps signals unusable timestamps with this exact message; treat
        # that subject as missing and let every other ValueError propagate.
        if verr.args == ('Invalid timestamps!',):
            return None
        raise

In [19]:
def initialize_spectrum(key, signal):
    """Create an empty Hann-tapered power spectrum matching `signal`.

    Args:
        key: Unused; accepted because Summary calls the initializer with it.
        signal: Signal providing `df`, `channels` and `f0`.

    Returns:
        An `epych.statistics.spectrum.PowerSpectrum` built from the signal's
        `df`, `channels` and `f0` with taper="hann".
    """
    return epych.statistics.spectrum.PowerSpectrum(signal.df, signal.channels, signal.f0, taper="hann")

In [20]:
# Compute (or skip, when already cached on disk) the per-area power spectra of
# every subject. Subjects are visited in sorted order so that
# probe_area_counter advances identically on every run.
for s, subject_dir in enumerate(sorted(NWB_SUBJECTS)):
    recording = subjectdir_recording(subject_dir)
    if recording is None:
        # subjectdir_recording returns None for pilot subjects and for
        # subjects with invalid timestamps; there is nothing to process.
        logging.info("Skipping subject %s: no usable recording" % subject_dir)
        continue
    logging.info("Loaded LFPs from subject %s" % subject_dir)
    if os.path.exists(subject_dir + '/psds'):
        logging.info("Calculated oscillatory power spectra for LFPs from subject %s" % subject_dir)
        # Spectra are cached, but the per-area counter must still advance so
        # later subjects get the same alignment index as in a full run.
        for k, v in recording.signals.items():
            area = os.path.commonprefix([loc for loc in v.channels.location if "VIS" in loc])
            probe_area_counter[area] += 1
        del recording
    else:
        sampling = recording.epoch(recording.intervals["type"] == "trial", before=PRETRIAL_SECONDS, after=POSTTRIAL_SECONDS).smap(visual_align, keys=True)
        # Release the raw recording before the spectra are computed.
        del recording

        summary = epych.statistic.Summary(alignment.location_prefix, initialize_spectrum)
        summary.calculate([sampling])
        summary.pickle(subject_dir + '/psds')
        del summary
        del sampling
        logging.info("Calculated oscillatory power spectra for LFPs from subject %s" % subject_dir)

  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cached namespace '%s' version %s because version %s is already loaded."
  warn("Ignoring cach