In this demo notebook, we'll create some hybrid data from the [single phase 3A data set](http://data.cortexlab.net/singlePhase3/) provided by Nick Steinmetz and the Cortex Lab.

In [1]:
import datetime
import glob
import hashlib
import importlib
import os
import os.path as op
import shutil
import sys
import zipfile

import numpy as np

# add folder containing our library to sys.path
factory_dir = op.abspath("../..")
if factory_dir not in sys.path:
    sys.path.insert(0, factory_dir)

import factory.generate.shift
import factory.generate.generators
import factory.io.phy
import factory.io.raw
import factory.io.gt
from factory.io.logging import log

In [2]:
%matplotlib notebook

In [3]:
# directory, inside the current working directory, that will hold this demo's hybrid output
hybrid_dir = op.abspath(op.join(".", "cortex_demo"))

# create it if it is missing; `exist_ok` removes the race between testing `op.isdir` and creating the directory
os.makedirs(hybrid_dir, exist_ok=True)

Let's now set our parameters. Copying the raw file would take too long, so we assume that `raw_source_file` and `raw_target_file` already exist.

In [4]:
class Parameter:
    """Empty attribute container; session parameters are attached to an instance as plain attributes."""

# container for every setting used below; `_write_config` later dumps these attributes to a parameter file
params = Parameter()

# --- required parameters ---

# path to file containing raw source data
params.raw_source_file = r"F:/CortexLab/singlePhase3/data/Hopkins_20160722_g0_t0.imec.ap_CAR.bin"
# path to file to contain hybrid data (lives in `hybrid_dir`)
params.raw_target_file = op.join(hybrid_dir, "Hopkins_20160722_g0_t0.imec.ap_CAR.GT.bin")
# type of raw data, as a numpy dtype
params.data_type = np.int16
# sample rate in Hz
params.sample_rate = 30000
# directory containing output from KiloSort
params.data_directory = "F:/CortexLab/singlePhase3/data"
# type of output from KiloSort, in this case phy input
params.output_type = "phy"
# probe layout
params.probe_type = "npix3a"
# indices (cluster labels) of ground-truth units
params.ground_truth_units = [36, 83, 199, 243, 267, 283, 464, 1074, 1159]

# --- optional parameters ---

# random seed, for reproducibility
params.random_seed = 10191
# algorithm to generate hybrid data
params.generator_type = "steinmetz"
# number of singular values to use in the construction of artificial units
params.num_singular_values = 8
# number of channels to shift the units by
params.channel_shift = 20
# scale factor for randomly-generated jitter
params.time_jitter = 500
# minimum amplitude scale factor
params.amplitude_scale_min = 0.75
# maximum amplitude scale factor
params.amplitude_scale_max = 1.5
# number of samples to take before an event timestep
params.samples_before = 40
# number of samples to take after an event timestep
params.samples_after = 40
# threshold a channel must exceed to be considered part of an event
params.event_threshold = -30
# point in the raw file at which the data starts (passed as `offset` to np.memmap)
params.offset = 0
# absolutely do NOT copy this huge file
params.copy = False
# whether or not to overwrite a target file if it already exists
params.overwrite = True
# start time of raw data file, in sample units
params.start_time = 0
# log messages to the screen
params.verbose = True

# maximum number of events used per ground-truth unit; surplus events are randomly dropped in the generation loop
SPIKE_LIMIT = 25000

Let's also import our probe module and set the random seed accordingly.

In [5]:
# probe layout module, bound to the name `probe` used by the cells below (same layout as params.probe_type)
import factory.probes.npix3a as probe

# seed NumPy's global random number generator so every random draw below is repeatable
np.random.seed(params.random_seed)

Let's make sure we're operating on the same files before we do anything else.

In [6]:
def md5sum(filename):
    """
    Compute the MD5 checksum of a file.

    The file is read in fixed-size chunks, so files much larger than memory can be hashed.

    Parameters
    ----------
    filename : str
        Path to the file to hash.

    Returns
    -------
    digest : str
        Hexadecimal MD5 digest of the file's contents.
    """

    # chunked hashing as suggested in https://stackoverflow.com/questions/22058048/hashing-a-file-in-python#22058673
    block_bytes = 64 * 1024  # 64 KiB per read

    digest = hashlib.md5()
    with open(filename, "rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. at end of file
        for block in iter(lambda: stream.read(block_bytes), b""):
            digest.update(block)

    return digest.hexdigest()

In [9]:
# raw recording and sorting output that must be present in params.data_directory
filenames = [
    "Hopkins_20160722_g0_t0.imec.ap_CAR.bin",
    "spike_clusters.npy",
    "spike_templates.npy",
    "spike_times.npy",
    "templates.npy",
]

# report (without raising) every expected file that is missing
for filename in filenames:
    if op.isfile(op.join(params.data_directory, filename)):
        continue
    print(f"{filename} not found")

In [11]:
# the raw source recording must be exactly the file this demo was built against
assert md5sum(params.raw_source_file) == "eb93a041e52eba844aed148ac9718998"

# this is somewhat faster than copying a ~80G file
# (the target is compared against the source's hash: before generation it must be identical to the source)
if md5sum(params.raw_target_file) != "eb93a041e52eba844aed148ac9718998":
    print(f"{params.raw_target_file} is different! overwriting...", end="")
    file_size_bytes = op.getsize(params.raw_source_file)
    byte_count = np.dtype(params.data_type).itemsize  # number of bytes in data type
    nrows = probe.NCHANS
    ncols = file_size_bytes // (nrows * byte_count)  # number of samples per channel

    # map both files as (channels x samples) in column-major order; only the target is writable
    source = np.memmap(params.raw_source_file, dtype=params.data_type, offset=params.offset, mode="r",
                       shape=(nrows, ncols), order="F")
    target = np.memmap(params.raw_target_file, dtype=params.data_type, offset=params.offset, mode="r+",
                       shape=(nrows, ncols), order="F")

    # load up firing times
    # NOTE(review): presumably row 1 of the loaded array holds the firing times of a previous run; confirm in factory.io.gt
    firings_true = factory.io.gt.load_gt_units(params.data_directory).astype(np.int64)
    # reset the target to match up with the source
    factory.io.raw.reset_target(source, target, params.samples_before, params.samples_after, firings_true[1, :])
    # save
    del source, target
    print("done")

# the sorting output must also be the exact files the expected hybrid checksum was computed from
assert md5sum(op.join(params.data_directory, "spike_clusters.npy")) == "d6d49ccbb9e34edc286c161541b681b3"
assert md5sum(op.join(params.data_directory, "spike_templates.npy")) == "218e7748281db5e95babb6b3ebc182c8"

assert md5sum(op.join(params.data_directory, "spike_times.npy")) == "938bc213d15ba9aa2cc9cc84a403c314"
assert md5sum(op.join(params.data_directory, "templates.npy")) == "8328de406b19e0afd35d9aa49c1d6858"

C:\Users\Alan\Documents\hybridfactory\notebooks\demos\cortex_demo\Hopkins_20160722_g0_t0.imec.ap_CAR.GT.bin is different! overwriting...done


Define some functions (really just copy them from `generate.py`).

In [12]:
def _legal_params():
    required_params = {"raw_source_file": None,
                       "raw_target_file": None,
                       "data_type": [np.int16],
                       "sample_rate": None,
                       "output_type": ["kilosort", "phy", "jrc"],
                       "data_directory": None,
                       "probe_type": ["npix3a", "eMouse", "hh2_arseny"],
                       "ground_truth_units": None}

    optional_params = {"random_seed": None,
                       "generator_type": ["steinmetz"],
                       "num_singular_values": 6,
                       "channel_shift": None,  # depends on probe
                       "time_jitter": 500,
                       "amplitude_scale_min": 0.75,
                       "amplitude_scale_max": 2.,
                       "samples_before": 40,
                       "samples_after": 40,
                       "event_threshold": -30,
                       "offset": 0,
                       "copy": True,
                       "overwrite": False,
                       "start_time": 0}

    return required_params, optional_params


def _write_param(fh, param, param_val):
    if param == "data_type":
        if param_val == np.int16:  # no other data types supported yet
            param_val = "np.int16"
    elif isinstance(param_val, str):  # enclose string in quotes
        param_val = f'r"{param_val}"'
    elif isinstance(param_val, np.ndarray):  # numpy doesn't do roundtripping
        param_val = param_val.tolist()
    elif param_val is None:
        return

    print(f"{param} = {param_val}", file=fh)


def _write_config(filename, params):
    """
    Save the session parameters as an importable Python file.

    Parameters
    ----------
    filename : str
        Path of the parameter file to create; an existing file is overwritten.
    params : object
        Session parameters; every name listed by `_legal_params` must be set as an attribute.
    """

    required_params, optional_params = _legal_params()

    # section heading -> names written under it, in file order
    sections = (("# required parameters\n", required_params),
                ("\n# optional parameters\n", optional_params))

    with open(filename, "w") as fh:
        # the data type is written as `np.int16`, so the file needs numpy
        print("import numpy as np\n", file=fh)

        for heading, names in sections:
            print(heading, file=fh)
            for name in names:
                _write_param(fh, name, params.__dict__[name])

        print(f"# automatically generated on {datetime.datetime.now()}", file=fh)

In [13]:
def _hybrid_target_path(raw_source_file, target_dir):
    """Return the path in `target_dir` for the hybrid version of `raw_source_file` (".GT" before the extension)."""

    # split the file name only, so a dot in a directory name is never mistaken for the extension separator
    stem, extension = op.splitext(op.basename(raw_source_file))
    return op.join(target_dir, f"{stem}.GT{extension}")


def copy_source_target(params, probe):
    """
    Memory-map each raw source file together with its hybrid target file.

    `params.raw_source_file` is expanded with `glob`. When it matches more than one file, a target
    name is derived for each source (same file name with ".GT" added before the extension, or at the
    end if there is no extension) and `params.start_time` must supply one start time per file.
    Otherwise `params.raw_target_file` and `params.start_time` are used as they are. If the pattern
    matches nothing, nothing is yielded.

    Side effects: optionally copies each source onto its target (`params.copy`), sets
    `params.num_samples` to the number of samples of the file being processed, and may set
    `params.overwrite` after asking the user.

    Parameters
    ----------
    params : module
        Session parameters.
    probe : module
        Probe parameters.

    Yields
    ------
    source : numpy.memmap
        Read-only memory map of source data file, channels x samples.
    target : numpy.memmap
        Writable memory map of target data file, channels x samples.
    start_time : int
        Start time of this file, taken from `params.start_time`.
    """

    raw_source_files = glob.glob(params.raw_source_file)

    if len(raw_source_files) > 1:
        assert len(raw_source_files) == len(params.start_time)

        # hybrid files all go into the parent directory of the data directory
        # NOTE(review): the original comment called this the "directory containing params file"; confirm
        target_dir = op.dirname(params.data_directory)
        raw_target_files = [_hybrid_target_path(rsf, target_dir) for rsf in raw_source_files]
        start_times = params.start_time
    else:
        raw_target_files = [params.raw_target_file]
        start_times = [params.start_time]

    for k, raw_source_file in enumerate(raw_source_files):
        start_time = start_times[k]
        raw_target_file = raw_target_files[k]

        # NOTE(review): `_user_dialog` and `_err_exit` are not defined in the visible part of this file;
        # make sure they are in scope before running with `params.overwrite` unset
        if op.isfile(raw_target_file) and not params.overwrite:
            if _user_dialog(f"Target file {raw_target_file} exists! Overwrite?") == "y":
                params.overwrite = True  # also applies to every remaining file
            else:
                _err_exit("aborting", 0)

        if params.copy:
            log(f"Copying {raw_source_file} to {raw_target_file}", params.verbose, in_progress=True)
            shutil.copy2(raw_source_file, raw_target_file)
            log("done", params.verbose)

        file_size_bytes = op.getsize(raw_source_file)
        byte_count = np.dtype(params.data_type).itemsize  # number of bytes in data type
        nrows = probe.NCHANS
        ncols = file_size_bytes // (nrows * byte_count)  # number of samples per channel

        params.num_samples = ncols

        # without a copy, the target must already exist and be at least as large as the source
        source = np.memmap(raw_source_file, dtype=params.data_type, offset=params.offset, mode="r",
                           shape=(nrows, ncols), order="F")
        target = np.memmap(raw_target_file, dtype=params.data_type, offset=params.offset, mode="r+",
                           shape=(nrows, ncols), order="F")

        yield source, target, start_time

In [14]:
def unit_channels_union(unit_mask, params, probe):
    """
    Collect every channel lying in the neighborhood of any channel on which this unit's events are centered.

    Parameters
    ----------
    unit_mask : numpy.ndarray
        Boolean array of events to take for this unit.
    params : module
        Session parameters.
    probe : module
        Probe parameters.

    Returns
    -------
    channels : numpy.ndarray
        Channels on which unit events occur.
    """

    data_dir = params.data_directory
    center_indices = factory.io.jrc.load_event_channel_indices(data_dir)
    neighbor_indices = factory.io.jrc.load_channel_neighbor_indices(data_dir)

    # both index sets refer to positions among the connected channels only
    connected_channels = probe.channel_map[probe.connected]

    # center channel of every event
    event_centers = connected_channels[center_indices]

    # center channel -> all channels of its neighborhood (the first entry of each column is the center itself)
    neighborhoods = {hood[0]: set(hood) for hood in connected_channels[neighbor_indices].T}

    # merge the neighborhoods of the distinct centers seen for this unit
    unit_centers = np.unique(event_centers[unit_mask])
    merged = set.union(*[neighborhoods[center] for center in unit_centers])

    return np.array(list(merged))

In [15]:
def scale_events(events, params, probe):
    """
    Multiply each event by a randomly scaled, piecewise-linear envelope.

    One scale factor per event is drawn uniformly from
    [`params.amplitude_scale_min`, `params.amplitude_scale_max`). The envelope of an event rises
    linearly from 0 to that factor over the samples before the event's peak sample, then decays
    linearly to 0 at the last sample; it is applied identically to every channel.

    Parameters
    ----------
    events : numpy.ndarray
        Tensor, num_channels x num_samples x num_events.
    params : module
        Session parameters.
    probe : module
        Probe parameters (not used here).

    Returns
    -------
    scaled_events : numpy.ndarray
        Tensor, num_channels x num_samples x num_events, scaled.
    """

    num_samples, num_events = events.shape[1], events.shape[2]

    # sample index at which each event reaches its largest absolute value on any channel
    peak_samples = np.abs(events).max(axis=0).argmax(axis=0)

    factors = np.random.uniform(params.amplitude_scale_min, params.amplitude_scale_max, size=num_events)

    envelopes = []
    for peak, factor in zip(peak_samples, factors):
        rising = np.linspace(0, factor, peak)
        # one extra point is generated and dropped so that the factor is not repeated
        falling = np.linspace(factor, 0, num_samples - peak + 1)[1:]
        envelopes.append(np.concatenate((rising, falling)).reshape(1, -1))

    # 1 x num_samples x num_events, broadcast over channels
    return np.stack(envelopes, axis=2) * events

In [16]:
# Main generation loop: for every ground-truth unit, build artificial events from the unit's real events,
# move them to other channels and other times, and add them to the target file.
log("Loading event times and cluster IDs", params.verbose, in_progress=True)

# NOTE: from here on `io` is factory.io.phy, not the standard-library `io` module
import factory.io.phy as io
event_times = io.load_event_times(params.data_directory)
event_clusters = io.load_event_clusters(params.data_directory)
log("done", params.verbose)

# one entry per unit that was successfully written
gt_channels = []
gt_times = []
gt_labels = []

for source, target, start_time in copy_source_target(params, probe):
    # keep only events whose whole window [t - samples_before, t + samples_after] lies inside this file
    time_mask = (event_times - params.samples_before >= start_time) & (event_times - start_time +
                                                                       params.samples_after < target.shape[1])

    for unit_id in params.ground_truth_units:
        unit_mask = (event_clusters == unit_id) & time_mask
        num_events = np.where(unit_mask)[0].size

        if num_events > SPIKE_LIMIT:  # if more events than limit, select some to ignore
            falsify = np.random.choice(np.where(unit_mask)[0], size=num_events-SPIKE_LIMIT, replace=False)
            unit_mask[falsify] = False
        elif num_events == 0:
            log(f"No events found for unit {unit_id}", params.verbose)
            continue

        # generate artificial events for this unit
        log(f"Generating ground truth for unit {unit_id}", params.verbose, in_progress=True)

        # event times relative to the start of this file
        unit_times = event_times[unit_mask] - start_time
        unit_windows = factory.io.raw.unit_windows(source, unit_times, params.samples_before, params.samples_after)
        unit_windows[probe.channel_map[~probe.connected], :, :] = 0  # zero out the unconnected channels

        # pick the channels that carry this unit's events
        if params.output_type == "jrc":
            unit_channels = unit_channels_union(unit_mask, params, probe)
        else:
            unit_channels = factory.generate.generators.threshold_events(unit_windows, params.event_threshold)

        if unit_channels is None:
            log("no channels found for unit", params.verbose)
            continue

        # now create subarray for just appropriate channels
        events = unit_windows[unit_channels, :, :]  # num_channels x num_samples x num_events

        # actually generate the data
        if params.generator_type == "steinmetz":
            # NOTE: `num_events` is the count taken before SPIKE_LIMIT was applied
            if num_events < params.num_singular_values:
                log("not enough events to generate!", params.verbose)
                continue
            art_events = factory.generate.generators.steinmetz(events, params.num_singular_values)
        else:
            raise NotImplementedError(f"generator '{params.generator_type}' does not exist!")

        # apply a random amplitude envelope to every artificial event
        art_events = scale_events(art_events, params, probe)

        log("done", params.verbose)

        # shift channels
        log("Shifting channels", params.verbose, in_progress=True)
        shifted_channels = factory.generate.shift.shift_channels(unit_channels, params, probe)

        if shifted_channels is None:
            continue  # cause is logged in `shift_channels`

        log("done", params.verbose)

        # jitter events
        log("Jittering events", params.verbose, in_progress=True)
        jittered_times = factory.generate.shift.jitter_events(unit_times, params)
        log("done", params.verbose)

        # jittering can fail as well; skip the unit in that case
        if jittered_times is None:
            continue

        # write to file
        log("Writing events to file", params.verbose, in_progress=True)
        for i, jittered_center in enumerate(jittered_times):
            # sample indices of the window around the jittered time, both ends included
            jittered_samples = np.arange(jittered_center - params.samples_before,
                                         jittered_center + params.samples_after + 1, dtype=jittered_center.dtype)

            # add the artificial event on top of whatever the target already holds there
            shifted_window = factory.io.raw.read_roi(target, shifted_channels, jittered_samples)
            perturbed_data = shifted_window + art_events[:, :, i]

            factory.io.raw.write_roi(target, shifted_channels, jittered_samples, perturbed_data)

        log("done", params.verbose)

        # per event, the shifted channel carrying the largest absolute amplitude
        cc_indices = np.abs(art_events).max(axis=1).argmax(axis=0)
        # NOTE(review): the +1 presumably converts to 1-based channel numbers; confirm against factory.io.gt
        center_channels = shifted_channels[cc_indices] + 1

        gt_channels.append(center_channels)
        gt_times.append(jittered_times + start_time)  # back to absolute sample times
        gt_labels.append(unit_id)

    # finished writing, flush to file
    del source, target

# save everything for later
dirname = params.data_directory

# save ground-truth units for validation
filename = factory.io.gt.save_gt_units(dirname, gt_channels, gt_times, gt_labels)
log(f"Firing times and labels saved to {filename}.", params.verbose)

# save parameter file for later reuse
filename = op.join(dirname, f"params-demo.py")
_write_config(filename, params)
log(f"Parameter file to recreate this run saved at {filename}.", params.verbose)

Loading event times and cluster IDs ... done
Generating ground truth for unit 36 ... done
Shifting channels ... done
Jittering events ... done
Writing events to file ... done
Generating ground truth for unit 83 ... done
Shifting channels ... done
Jittering events ... done
Writing events to file ... done
Generating ground truth for unit 199 ... no channels found for unit
Generating ground truth for unit 243 ... done
Shifting channels ... done
Jittering events ... done
Writing events to file ... done
Generating ground truth for unit 267 ... done
Shifting channels ... channel shift of 20 places events on unconnected channels
Generating ground truth for unit 283 ... no channels found for unit
Generating ground truth for unit 464 ... done
Shifting channels ... done
Jittering events ... done
Writing events to file ... done
Generating ground truth for unit 1074 ... done
Shifting channels ... done
Jittering events ... done
Writing events to file ... done
Generating ground truth for unit 1159 .

In [17]:
# the generated hybrid file must hash to the known value, i.e. this run reproduced the reference result exactly
assert md5sum(params.raw_target_file) == "f0a179225e0b173e55efa0265d5dea2c"

Success! You now have a reproducible hybrid ground-truth file.