In [1]:
kilosort_folder = r'C:\Users\Rory\raw_data\CIT_WAY\dat_files\cat'
recording = r'2018-05-01_01'
sep = '\\'

In [2]:
import os
import numpy as np
import pandas as pd

def load_kilosort_arrays(recording):
    '''
    Loads arrays generated during kilosort into numpy arrays and pandas DataFrames
    Parameters:
        recording       = name of the recording being analysed
    Returns:
        spike_clusters  = numpy array of len(num_spikes) identifying the cluster from which each spike arrose
        spike_times     = numpy array of len(num_spikes) identifying the time in samples at which each spike occured
        cluster_groups  = pandas DataDrame with one row per cluster and column 'cluster_group' identifying whether
                          that cluster had been marked as 'Noise', 'MUA' or 'Good'
    '''
    spike_clusters = np.load('spike_clusters.npy')
    spike_times = np.load('spike_times.npy')
    cluster_groups = pd.read_csv('cluster_groups.csv', sep='\t')
    try:  # check data quality
        assert np.shape(spike_times.flatten()) == np.shape(spike_clusters)
    except AssertionError:
        AssertionError('Array lengths do not match in recording {}'.format(
            recording))
    return spike_clusters, spike_times, cluster_groups


def load_data(recording, kilosort_folder, verbose, sep):
    '''
    Loads arrays generated during kilosort into numpy arrays and pandas DataFrames
    Parameters:
        recording       = name of the recording being analysed
        kilosort_folder = the name of the root directory in which subdirectories for each recording are stored
                          inside the sub-directories should be the files generated during spike sorting with
                          kilosort and phy
        verbose         = True or False
        sep             = os directory delimeter e.g. '/'
    Returns:
        spike_clusters  = numpy array of len(num_spikes) identifying the cluster from which each spike arrose
        spike_times     = numpy array of len(num_spikes) identifying the time in samples at which each spike occured
        cluster_groups  = pandas DataDrame with one row per cluster and column
                          'cluster_group' identifying whetherthat cluster had been marked as 'Noise', 'MUA' or 'Good'
    '''
    if verbose:
        print('\nLoading Data:\t{}\n'.format(recording))
    os.chdir(sep.join([kilosort_folder, recording]))
    spike_clusters, spike_times, cluster_groups = load_kilosort_arrays(
        recording)
    return spike_clusters, spike_times, cluster_groups


def get_good_cluster_numbers(cluster_groups_df):
    '''
    Takes the cluster_groups pandas DataFrame fomed during data loading and returns a numpy array of cluster
    ids defined as 'Good' during kilosort and phy spike sorting
    Parameters:
        cluster_groups_df   = the pandas DataFrame containing information on which cluster is 'Good', 'Noise' etc.
    Returns:
        A numpy array of 'Good' cluster ids
    '''
    good_clusters_df = cluster_groups_df.loc[cluster_groups_df['group'] == 'good', :]
    return good_clusters_df['cluster_id'].values

In [3]:
spike_clusters, spike_times, cluster_groups = load_data(recording=recording,
                                                        kilosort_folder=kilosort_folder,
                                                        verbose=True,
                                                        sep=sep)


Loading Data:	2018-05-01_01



In [4]:
good_cluster_numbers = get_good_cluster_numbers(cluster_groups)

In [5]:
spike_times_0 = spike_times[spike_clusters==good_cluster_numbers[13]]

In [6]:
extract_st = spike_times[:10000]
num_waveforms_toload = len(extract_st)

In [7]:
num_channels = 32
samples_around_spike = np.arange(-60, 60)
num_waveform_samples = len(samples_around_spike)

In [8]:
empty_template = np.zeros((num_waveforms_toload +1,
                         num_channels,
                         num_waveform_samples))

In [9]:
temp = np.memmap('2018-05-01_01.dat', dtype='int64')
total_len = len(temp)
real_len = int(total_len/num_channels)

In [10]:
mmf = np.memmap('2018-05-01_01.dat', dtype='int64', shape=(32, real_len))

In [11]:
for spike, ind in enumerate(range(num_waveforms_toload)):
    temp_wf = mmf[:, int(extract_st[spike]+samples_around_spike[0]):int(extract_st[spike]+samples_around_spike[-1]+1)]
    empty_template[ind, :, :] = temp_wf[:, :]

In [12]:
plt.plot(empty_template[6, 7, :])

NameError: name 'plt' is not defined

In [13]:
means = np.mean(empty_template, axis=0)

In [14]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

f, a = plt.subplots(nrows=16, ncols=2, figsize=(15,15))
for chan in range(means.shape[0]):
    if chan >= 16:
        col = 1
        chan -= 16
    else:
        col = 0
    a[chan, col].plot(np.arange(-60, 60)/30, means[chan, :])
plt.show()

<Figure size 1500x1500 with 32 Axes>

In [16]:
data = np.fromfile(recording + '.dat')

In [35]:
data.shape[0]/32/30000

1984.7253333333333

In [31]:
reshaped = data.reshape(31, int(data.shape[0]/32))

ValueError: cannot reshape array of size 1905336320 into shape (31,59541760)

In [28]:
reshaped.shape[1]/30000/60

33.07875555555555

In [36]:
mmf.shape

(32, 59541760)