# Extract Spike Times

Code to extract spike times from cluster files and create the 'times_' files. 

In [1]:
import os

import h5py
import numpy as np

In [2]:
# Add the local copy of the convnwb repository to the Python path, so that it can be imported
import sys
conv_path = '/home1/tom.donoghue/repos/convnwb'
sys.path.append(conv_path)
from convnwb.io import get_files
from convnwb.session import SDB

## Settings

In [3]:
# Define base data path, under which the subject & session directories are located
base_path = '/scratch/tom.donoghue/'

In [4]:
# Define subject information: the subject code and the session label to process
subj = 'YDR'
session = '3'

In [5]:
# Initialize directory object, which provides the paths for this subject & session
db = SDB(subj, session, base_path)

In [6]:
# Check the directory for the split files, which is the directory searched for sorting results
db.split_files

PosixPath('/scratch/tom.donoghue/YDR/session_3/split_files')

In [7]:
# Check the directory for the sorting files, which is where the 'times_' files get saved
db.sorting

PosixPath('/scratch/tom.donoghue/YDR/session_3/sorting')

## Extract Spike Times

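Loop through the split file directories, extract spike time data from each channel's sorting results file (`sort_cat.h5`), and save the sorted spikes to a 'times_' file in the sorting directory.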

In [8]:
# Running count of the output files written so far, used as the suffix of each output file name
# Note: this name must not be reused inside the loops below, or the file numbering is clobbered
ind = 0
for subdir, dirs, files in os.walk(db.split_files):

    for file in files:

        # Only the sorting results file of each channel is processed
        if file != 'sort_cat.h5':
            continue

        # Get relevant names: the channel directory is one level above the sorting results file
        parent_directory = os.path.dirname(subdir)
        channel_name = os.path.basename(parent_directory)

        # Extract the sorting results, loading into memory so the file can be closed right away
        with h5py.File(os.path.join(subdir, file), 'r') as sorted_data:
            sorted_groups = np.array(sorted_data.get('groups'))
            sorted_idx = np.array(sorted_data.get('index'))
            sorted_classes = np.array(sorted_data.get('classes'))

        # Get the spiking results from the raw data
        with h5py.File(os.path.join(parent_directory, f'data_{channel_name}.h5'), 'r') as raw_data:
            raw_data_neg = raw_data.get('neg')
            spike_waveforms = np.array(raw_data_neg.get('spikes')) * 0.25 # Scale because of Blackrock
            spike_times = np.array(raw_data_neg.get('times'))

        # Each row of the groups array is used as a (class, cluster) pair
        #   Keep only the rows assigned to a cluster (cluster label > 0)
        #   Indexing with a boolean row mask keeps this 2D, including when only one row is kept
        valid_classes_clusters = sorted_groups[sorted_groups[:, 1] > 0, :]

        # Check for channels with no clusters detected in the channel
        if len(valid_classes_clusters) < 1:

            cluster_msg = 'no clusters, skipping.'

        else:

            cluster_msg = 'found clusters, extracting.'
            spike_classes = np.zeros(len(spike_times))
            spike_clusters = np.zeros(len(spike_times))

            # Gets the classes for all the initial spikes detected, and maps them to clusters
            #   `sorted_idx[k]` is the position, among the detected spikes, of the k-th sorted spike
            for k, spike_ind in enumerate(sorted_idx):

                spike_classes[spike_ind] = sorted_classes[k]

                for class_id, cluster_id in valid_classes_clusters[:, :2]:
                    if spike_classes[spike_ind] == class_id:
                        spike_clusters[spike_ind] = cluster_id

            # Get rid of non-sorted 'spikes', keeping only spikes whose class belongs to a cluster
            is_sorted = np.isin(spike_classes, valid_classes_clusters[:, 0])
            spike_waveforms = spike_waveforms[is_sorted, :]
            spike_times = spike_times[is_sorted]
            spike_clusters = spike_clusters[is_sorted]
            spike_classes = spike_classes[is_sorted]

            # Save out the sorted spike data, with the context manager closing the file on exit
            file_name = f'times_{channel_name}_{ind}.h5'
            with h5py.File(str(db.sorting / file_name), 'w') as hf:
                g1 = hf.create_group('spike_data_sorted')
                g1.create_dataset('spike_waveforms', data=spike_waveforms)
                g1.create_dataset('spike_times', data=spike_times)
                g1.create_dataset('spike_clusters', data=spike_clusters)
                g1.create_dataset('spike_classes', data=spike_classes)

            ind += 1

        # Print out status
        print('Processing: ', channel_name, ' - ', cluster_msg)

Processing:  chan_10  -  found clusters, extracting.
Processing:  chan_11  -  found clusters, extracting.
Processing:  chan_12  -  found clusters, extracting.
Processing:  chan_13  -  found clusters, extracting.
Processing:  chan_14  -  found clusters, extracting.
Processing:  chan_15  -  found clusters, extracting.
Processing:  chan_16  -  found clusters, extracting.
Processing:  chan_17  -  found clusters, extracting.
Processing:  chan_18  -  found clusters, extracting.
Processing:  chan_19  -  found clusters, extracting.
Processing:  chan_1  -  found clusters, extracting.
Processing:  chan_20  -  found clusters, extracting.
Processing:  chan_21  -  found clusters, extracting.
Processing:  chan_22  -  found clusters, extracting.
Processing:  chan_23  -  found clusters, extracting.
Processing:  chan_24  -  found clusters, extracting.
Processing:  chan_25  -  found clusters, extracting.
Processing:  chan_26  -  found clusters, extracting.
Processing:  chan_27  -  found clusters, extrac