In [1]:
import importlib
import os
import os.path as op
import shutil
import sys

import numpy as np
import scipy.spatial

In [2]:
# NOTE(review): no cell visible in this notebook draws a figure, so selecting a
# plotting backend may be unnecessary; the `notebook` (nbagg) backend also does
# not work in JupyterLab, where `%matplotlib widget` (ipympl) is the replacement.
%matplotlib notebook

In [3]:
# Root of the hybridfactory checkout. Set the HYBRIDFACTORY_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
factory_dir = os.environ.get("HYBRIDFACTORY_DIR", r"C:\Users\Alan\Documents\hybridfactory")

# Work from the checkout and make it the first place imports are resolved from.
# Only insert when it is not already at the front, so re-running this cell does
# not keep prepending duplicate entries to sys.path.
os.chdir(factory_dir)
if sys.path[:1] != [factory_dir]:
    sys.path.insert(0, factory_dir)

In [4]:
import factory.io.raw
import factory.io.gt

In [5]:
# Maximum number of events injected per ground-truth unit; units with more
# events are randomly subsampled down to this many in the main loop.
SPIKE_LIMIT = 25000

# Session parameters are a module on sys.path; the probe module is named by
# the session, e.g. factory.probes.<probe_type>.
params = importlib.import_module("npix-gen-20180510")
probe = importlib.import_module(f"factory.probes.{params.probe_type}")

# Notebook-level overrides of the session settings.
# NOTE(review): with copying disabled the target file is not created here; it
# must already exist, and it is modified in place by the injection loop.
params.copy = False
params.verbose = True

# Seed NumPy's global RNG so event subsampling and time jitter are reproducible.
np.random.seed(params.random_seed)

In [6]:
def _log(msg, stdout, in_progress=False):
    """Print `msg` to standard output when `stdout` is truthy.

    Messages flagged `in_progress` end with " ... " and no newline, so that a
    following "done" (or an error message) lands on the same line; all other
    messages end with a newline.
    """
    if not stdout:
        return
    print(msg, end=(" ... " if in_progress else "\n"))


def _user_dialog(msg, options=("y", "n"), default_option="n"):
    """Ask the user to pick one of `options` and return the choice in lower case.

    The prompt lists the options with the default upper-cased (e.g. ``[y/N]``).
    Answers are stripped and compared case-insensitively. An empty answer
    selects the default; an unrecognized answer is asked again up to three more
    times before falling back to the default.
    """
    default_option = default_option.lower()
    choices = [option.lower() for option in options]
    assert default_option in choices

    # upper-case the (first) default entry in place so the prompt reads like "[y/N]"
    choices[choices.index(default_option)] = default_option.upper()
    valid = [choice.lower() for choice in choices]
    prompt = f"[{'/'.join(choices)}] "

    print(msg, end=" ")
    answer = input(prompt).strip().lower()

    for _retry in range(3):
        if not answer or answer in valid:
            break
        answer = input(prompt).strip().lower()

    if answer and answer in valid:
        return answer
    return default_option

In [7]:
def copy_source_target(params, probe):
    """Open memory maps over the raw source file and the hybrid target file.

    If ``params.copy`` is truthy, the source file is first copied to
    ``params.raw_target_file``; otherwise the target file must already exist
    and is opened for in-place modification.

    Both files are mapped as ``(probe.NCHANS + params.extra_channels, n)``
    arrays of ``params.data_type`` in Fortran order, starting ``params.offset``
    bytes into the file. ``n`` is derived from the size of the source file and
    stored in ``params.num_samples``.

    Parameters
    ----------
    params : module
        Session parameters.
    probe : module
        Probe parameters.

    Returns
    -------
    source : numpy.memmap
        Read-only memory map of source data file.
    target : numpy.memmap
        Read-write memory map of target data file.

    Raises
    ------
    FileExistsError
        If the target file exists, ``params.overwrite`` is falsy and the user
        declines to overwrite it.
    """

    if op.isfile(params.raw_target_file) and not params.overwrite:
        if _user_dialog(f"Target file {params.raw_target_file} exists! Overwrite?") == "y":
            params.overwrite = True
        else:
            print("aborting")
            # A bare `return` here used to hand None to callers that unpack a
            # (source, target) pair, failing later with an opaque TypeError.
            raise FileExistsError(f"Target file {params.raw_target_file} exists and was not overwritten")

    if params.copy:
        _log(f"Copying {params.raw_source_file} to {params.raw_target_file}", params.verbose, in_progress=True)
        shutil.copy2(params.raw_source_file, params.raw_target_file)
        _log("done", params.verbose)

    byte_count = np.dtype(params.data_type).itemsize  # number of bytes in data type
    nrows = probe.NCHANS + params.extra_channels

    # Samples start `params.offset` bytes into the file; counting the header
    # bytes as data would size the map past the end of the file.
    data_bytes = op.getsize(params.raw_source_file) - params.offset
    ncols = data_bytes // (nrows * byte_count)

    params.num_samples = ncols

    source = np.memmap(params.raw_source_file, dtype=params.data_type, offset=params.offset, mode="r",
                       shape=(nrows, ncols), order="F")
    target = np.memmap(params.raw_target_file, dtype=params.data_type, offset=params.offset, mode="r+",
                       shape=(nrows, ncols), order="F")

    return source, target

In [8]:
def construct_artificial_events(source, params, probe, unit_times):
    """Generate artificial events for one unit using the session's generator.

    The generator module is ``factory.generators.<params.generator_type>`` and
    must provide ``generate(source, params, probe, unit_times)`` returning a
    pair; that pair is returned unchanged.

    Parameters
    ----------
    source : numpy.memmap
        Memory map of data file.
    params : module
        Session parameters (reads ``generator_type``; the whole module is also
        passed on to the generator).
    probe : module
        Probe parameters.
    unit_times : numpy.ndarray
        Array of firing times for this unit.

    Returns
    -------
    art_events : numpy.ndarray or None
        Tensor, num_channels x num_samples x num_events, constructed by `generator`.
        The injection loop in this notebook treats None as "no waveforms crossed
        threshold".
    channels : numpy.ndarray
        Channels on which the original events occur.
    """

    # e.g., factory.generators.steinmetz
    generator = importlib.import_module(f"factory.generators.{params.generator_type}")
    art_events, channels = generator.generate(source, params, probe, unit_times)

    return art_events, channels

In [10]:
def shift_channels(channels, params, probe):
    """Shift a subset of the channels along the probe's channel map.

    Every channel is moved ``params.channel_shift`` positions along
    ``probe.channel_map``: forwards when that keeps all of them on the map,
    backwards otherwise. The shift is rejected if it would move a channel off
    the map, onto an unconnected channel, or change the pairwise distances
    between the channels.

    Parameters
    ----------
    channels : numpy.ndarray
        Input channels to be shifted.
    params : module
        Session parameters (reads ``channel_shift`` and ``verbose``).
    probe : module
        Probe parameters (reads ``channel_map``, ``connected`` and
        ``channel_positions``).

    Returns
    -------
    shifted_channels : numpy.ndarray or None
        Channels shifted by some constant factor, or None if the shift is
        rejected (the reason is logged when ``params.verbose`` is set).
    """
    map_size = probe.channel_map.size

    # inverse_channel_map[probe.channel_map] == [0, 1, ..., map_size - 1], i.e. it maps
    # a channel number to its position in the channel map
    inverse_channel_map = np.zeros(map_size, dtype=np.int64)
    inverse_channel_map[probe.channel_map] = np.arange(map_size)

    positions = inverse_channel_map[channels]

    # shift forwards if every shifted position stays below map_size, otherwise backwards
    if positions.max() < map_size - params.channel_shift:
        shifted_positions = positions + params.channel_shift
    else:
        shifted_positions = positions - params.channel_shift

    # make sure our shifted positions fall in the range [0, map_size). This must be
    # checked before indexing: a negative position would silently wrap around to the
    # far end of the channel map instead of failing.
    if shifted_positions.min() < 0 or shifted_positions.max() >= map_size:
        _log(f"channel shift of {params.channel_shift} places events outside of probe range", params.verbose)
        return None

    shifted_channels = probe.channel_map[shifted_positions]

    # make sure our shifted channels don't land on unconnected channels
    # (explicit checks rather than `assert`, which is stripped under `python -O`)
    if np.intersect1d(shifted_channels, probe.channel_map[~probe.connected]).size > 0:
        _log(f"channel shift of {params.channel_shift} places events on unconnected channels", params.verbose)
        return None

    # make sure our shifted channels don't alter spatial relationships
    channel_distance = scipy.spatial.distance.pdist(probe.channel_positions[positions, :])
    shifted_distance = scipy.spatial.distance.pdist(probe.channel_positions[inverse_channel_map[shifted_channels], :])

    if not np.isclose(channel_distance, shifted_distance).all():
        _log(f"channel shift of {params.channel_shift} alters spatial relationship between channels", params.verbose)
        return None

    return shifted_channels

In [11]:
def jitter_events(unit_times, params):
    """Move firing times in time so artificial events don't sit on the originals.

    A random half of the events (``unit_times.size // 2``) are moved later and
    the rest earlier. Each event moves by ``params.sample_rate // 1000`` samples
    (1 ms) plus the absolute value of a normal draw with scale
    ``params.time_jitter // 2``. Uses NumPy's global random state.

    Parameters
    ----------
    unit_times : numpy.ndarray
        Firing times for this unit (sample indices), to be jittered.
    params : module
        Session parameters (reads ``sample_rate``, ``time_jitter``,
        ``samples_before``, ``samples_after``, ``num_samples``, ``verbose``).

    Returns
    -------
    jittered_times : numpy.ndarray or None
        Jittered firing times with the dtype of `unit_times`. Returns None
        (and logs the reason) if any event's window
        ``[t - samples_before, t + samples_after]`` would fall outside
        ``[0, params.num_samples)``.
    """

    isi_samples = params.sample_rate // 1000  # number of samples in 1 ms
    scale = params.time_jitter // 2
    n_later = unit_times.size // 2

    # normally-distributed jitter factor, with an absmin of `isi_samples`, applied
    # symmetrically in both directions (the backward branch previously added a second
    # `isi_samples` inside the abs, pushing earlier events ~2 ms away instead of ~1 ms)
    jitter_later = isi_samples + np.abs(np.random.normal(loc=0, scale=scale, size=n_later))
    jitter_earlier = -(isi_samples + np.abs(np.random.normal(loc=0, scale=scale,
                                                             size=unit_times.size - n_later)))

    # leaves a window of 2 ms around `unit_times` so units don't fire right on top of each other
    jitter = np.random.permutation(np.hstack((jitter_later, jitter_earlier)))

    if np.issubdtype(unit_times.dtype, np.integer):
        # Do the arithmetic in signed 64-bit: firing times may be stored unsigned
        # (e.g. uint64), where a negative jitter would wrap around instead of moving
        # the event earlier, and would also slip past the bounds check below.
        jittered_times = unit_times.astype(np.int64) + jitter.astype(np.int64)
    else:
        jittered_times = unit_times + jitter.astype(unit_times.dtype)

    # explicit check rather than `assert`, which is stripped under `python -O`;
    # the window end is inclusive, so it must be strictly below num_samples
    window_start = jittered_times - params.samples_before
    window_end = jittered_times + params.samples_after
    if (window_start < 0).any() or (window_end >= params.num_samples).any():
        _log(f"time jitter of {params.time_jitter} and sample window places events outside of sample range",
             params.verbose)
        return None

    # every value is now within [0, num_samples), so casting back is lossless
    return jittered_times.astype(unit_times.dtype, copy=False)

In [12]:
# Reader module for the spike-sorting output, e.g. factory.io.phy, factory.io.kilosort, ...
# NOTE(review): the name `io` shadows the standard-library `io` module in this namespace.
io = importlib.import_module(f"factory.io.{params.output_type}")

# Parallel arrays: the injection loop selects a unit's firing times with
# `event_times[event_clusters == unit_id]`.
event_times, event_clusters = (
    io.load_event_times(params.data_directory),
    io.load_event_clusters(params.data_directory),
)

In [13]:
# Per-unit ground truth (center channels, firing times, unit labels), appended
# to by the injection loop and saved at the end of the notebook.
gt_channels, gt_times, gt_labels = [], [], []

In [14]:
source, target = copy_source_target(params, probe)

# NOTE(review): `target` is modified in place. Re-running this cell without
# re-copying the source (params.copy is False above) injects every unit a second
# time, while gt_channels / gt_times / gt_labels keep accumulating across runs.
for unit_id in params.ground_truth_units:
    unit_times = event_times[event_clusters == unit_id]

    # randomly subsample very active units down to SPIKE_LIMIT events
    if unit_times.size > SPIKE_LIMIT:
        unit_times = np.random.choice(unit_times, size=SPIKE_LIMIT, replace=False)

    # generate artificial events for this unit
    _log(f"Generating ground truth for unit {unit_id}", params.verbose, in_progress=True)
    art_events, channel_subset = construct_artificial_events(source, params, probe, unit_times)
    if art_events is None:
        _log("no waveforms crossed threshold; skipping", params.verbose)
        continue
    _log("done", params.verbose)

    # move the events onto a different set of channels
    _log("Shifting channels", params.verbose, in_progress=True)
    shifted_channels = shift_channels(channel_subset, params, probe)
    if shifted_channels is None:
        continue  # cause is logged in `shift_channels`
    _log("done", params.verbose)

    # move the events in time; only report "done" once we know it succeeded
    _log("Jittering events", params.verbose, in_progress=True)
    jittered_times = jitter_events(unit_times, params)
    if jittered_times is None:
        continue  # cause is logged in `jitter_events`
    _log("done", params.verbose)

    # add each artificial event onto the target data in its (channels, samples) window
    _log("Writing events to file", params.verbose, in_progress=True)
    for i, jittered_center in enumerate(jittered_times):
        jittered_samples = np.arange(jittered_center - params.samples_before,
                                     jittered_center + params.samples_after + 1)

        shifted_window = factory.io.raw.read_roi(target, shifted_channels, jittered_samples)
        perturbed_data = shifted_window + art_events[:, :, i]

        factory.io.raw.write_roi(target, shifted_channels, jittered_samples, perturbed_data)

    # push this unit's writes to disk now rather than only when the memmap is released
    target.flush()
    _log("done", params.verbose)

    # center channel of each event: the channel with the largest absolute amplitude
    # over the event window (art_events is channels x samples x events), stored 1-based
    cc_indices = np.abs(art_events).max(axis=1).argmax(axis=0)  # num_events
    center_channels = shifted_channels[cc_indices] + 1

    gt_channels.append(center_channels)
    gt_times.append(jittered_times)
    gt_labels.append(unit_id)

Copying F:\CortexLab\singlePhase3\data\Hopkins_20160722_g0_t0.imec.ap_CAR.bin to C:\Users\Alan\Documents\Data\npix-hybrid\Hopkins_20160722_g0_t0.imec.ap_CAR.GT.bin ... done
Generating ground truth for unit 18 ... done
Shifting channels ... channel shift of 20 places events on unconnected channels
Generating ground truth for unit 36 ... done
Shifting channels ... done
Jittering events ... done
Writing events to file ... done
Generating ground truth for unit 83 ... done
Shifting channels ... done
Jittering events ... done
Writing events to file ... done
Generating ground truth for unit 199 ... no waveforms crossed threshold; skipping
Generating ground truth for unit 243 ... done
Shifting channels ... done
Jittering events ... done
Writing events to file ... done
Generating ground truth for unit 267 ... done
Shifting channels ... channel shift of 20 places events on unconnected channels
Generating ground truth for unit 283 ... no waveforms crossed threshold; skipping
Generating ground tru

In [15]:
# Flush pending writes explicitly instead of relying on the memmap being
# garbage-collected (which only happens once no other references remain),
# then release both maps.
target.flush()
del source, target

In [16]:
# Save the ground-truth units for validation, next to the hybrid data file.
dirname = op.dirname(params.raw_target_file)
filename = factory.io.gt.save_gt_units(
    dirname,
    gt_channels,
    gt_times,
    gt_labels,
)
_log(f"Firing times and labels saved to {filename}.", params.verbose)

Firing times and labels saved to C:\Users\Alan\Documents\Data\npix-hybrid\firings_true.npy.


In [17]:
# Reload the saved ground truth to inspect what was written.
firings_true = np.load(file=filename)

In [24]:
# dimensions of the reloaded ground-truth array
np.shape(firings_true)

(3, 57816)