In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter, freqz
from scipy import signal as ss
plt.style.use('ggplot')

In [47]:
def load_kilosort_arrays(kilosort_folder, recording, verbose):
    if verbose:
        print('Loading Kilosort arrays: {}\n\n'.format(recording))
    path = os.path.join(kilosort_folder, recording)
    spike_clusters = np.load(os.path.join(path, 'spike_clusters.npy'))
    spike_times = np.load(os.path.join(path, 'spike_times.npy'))
    cluster_groups = pd.read_csv(os.path.join(path, 'cluster_groups.csv'), sep='\t')

    return spike_clusters, spike_times, cluster_groups


def load_raw_data(kilosort_folder, recording, num_channels):
    path = os.path.join(kilosort_folder, recording, recording) + '.dat'
    temp_data = np.memmap(path, dtype=np.int16)
    adjusted_len = int(len(temp_data) / num_channels)  # adjust for number of channels

    raw_data = np.memmap(path, dtype=np.int16, shape=(adjusted_len, num_channels))
    return raw_data


def get_good_cluster_numbers(cluster_groups):
    good_clusters = cluster_groups[cluster_groups['group'] == 'good']
    return good_clusters['cluster_id'].values


def gen_good_spikes_df(spike_times, spike_clusters, good_cluster_numbers):
    all_spikes_df = pd.DataFrame({'spike_time': spike_times.flatten(),
                                  'spike_cluster': spike_clusters.flatten()})
    good_spikes_df = all_spikes_df[all_spikes_df['spike_cluster'].isin(good_cluster_numbers)]
    return good_spikes_df


def gen_spiketimes_series(good_spikes_df, cluster, num_spikes, last_spikes):
    if len(good_spikes_df) < num_spikes:
        num_spikes = len(good_spikes_df)
    if last_spikes:
        spike_times = good_spikes_df[good_spikes_df['spike_cluster'] == cluster].iloc[-num_spikes:]
    else:
        spike_times = good_spikes_df[good_spikes_df['spike_cluster'] == cluster].iloc[num_spikes:]
    spike_times.index = range(len(spike_times))
    return spike_times


def extract_waveforms(num_spikes, num_samples, num_channels,
                      spike_times, raw_data):
    waveform_window = np.arange(-num_samples / 2, num_samples / 2)
    cols = [''.join(['Chan_', str(num)]) for num in range(0, num_channels)]

    if num_spikes >= len(spike_times['spike_time']):
        num_spikes = len(spike_times['spike_time']) - 20

    empty_template = np.zeros((num_spikes, num_samples, num_channels))
    for spike in range(num_spikes):
        start_index = int(spike_times['spike_time'].iloc[spike] + waveform_window[0])
        end_index = int((spike_times['spike_time'].iloc[spike] + waveform_window[-1]) + 1)
        waveform = raw_data[start_index:end_index, :]
        empty_template[spike, :, :] = waveform[:, :]

    waveform_per_chan = np.mean(empty_template, axis=0)
    waveform_per_chan = pd.DataFrame(waveform_per_chan, columns=cols)
    return waveform_per_chan


def choose_channel(df, method, broken_chans):
    '''
    Choose either channel with max or minumum values
    method == 'max' or 'min'
    '''
    if broken_chans:
        for chan in broken_chans:
            df.drop('Chan_{}'.format(str(chan)), inplace=True, axis=1)
    if method.lower() == 'max':
        chan = df.apply(np.max, axis=0)
        selected_chan = df.loc[:, chan.idxmax()]
        chan = chan.idxmax()
    elif method.lower() == 'min':
        chan = df.apply(np.min, axis=0)
        selected_chan = df.loc[:, chan.idxmin()]
        chan = chan.idxmin()
    else:
        raise ValueError('Unable to parse channel selection method.\nEnter \'min\' or\'max\'')

    return selected_chan, chan



def merge_dfs(df_list, broadcast, **kwargs):
    df = pd.concat(df_list)
    df.index = range(len(df))
    if broadcast:
        df['recording'] = kwargs['recording']
    return df


In [48]:
verbose = True
recording = '2018-05-01_01'
kilosort_folder = r"F:\CIT_WAY"
temp_folder=r'F:\CIT_WAY\csvs\temp'
num_channels = 32
num_spikes = 1100
last_spikes = False
num_samples = 240
spike_selection_method = 'min'
broken_channels = [22]

In [49]:
spike_clusters, spike_times, cluster_groups = load_kilosort_arrays(recording=recording,
                                                                           kilosort_folder=kilosort_folder, verbose=verbose)

Loading Kilosort arrays: 2018-05-01_01




In [50]:
raw_data = load_raw_data(recording=recording,
                                 kilosort_folder=kilosort_folder,
                                 num_channels=num_channels)
good_cluster_numbers = get_good_cluster_numbers(cluster_groups)
good_spikes_df = gen_good_spikes_df(spike_times=spike_times,
                                            spike_clusters=spike_clusters,
                                            good_cluster_numbers=good_cluster_numbers)

In [75]:
df_list = []
for clusters in good_cluster_numbers:
    spiketimes_series = gen_spiketimes_series(good_spikes_df=good_spikes_df,
                                                          cluster=clusters,
                                                          num_spikes=num_spikes,
                                                          last_spikes=last_spikes)
    waveform_per_chan = extract_waveforms(num_spikes=num_spikes,
                                                      num_samples=num_samples,
                                                      num_channels=num_channels,
                                                      spike_times=spiketimes_series,
                                                      raw_data=raw_data)
    one_channel, chan = choose_channel(df=waveform_per_chan,
                                                   method=spike_selection_method,
                                                   broken_chans=broken_channels)
    cluster_df = pd.DataFrame(one_channel).transpose()
    df_list.append(cluster_df)

In [86]:
df_merged = merge_dfs(df_list, broadcast=False, recording=recording)
df_merged.set_index(good_cluster_numbers, inplace=True)
df_merged.index.name = 'Clusters'
df_merged.columns.name = 'Sample time point'
df_merged.to_csv('F:\CIT_WAY\waveform_values_array.csv')

In [84]:
df_merged

Sample time point,0,1,2,3,4,5,6,7,8,9,...,230,231,232,233,234,235,236,237,238,239
Clusters,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
20,-6.766364,-3.228182,-2.001818,-4.191818,-4.14,-2.102727,-2.451818,-4.268182,-1.958182,-2.284545,...,6.523636,4.106364,7.21,6.481818,6.563636,8.334545,8.627273,7.758182,6.16,8.134545
22,-0.588182,-1.01,-2.565455,-2.794545,1.064545,-0.603636,-2.763636,-1.935455,-2.984545,-1.768182,...,17.679091,17.677273,16.270909,14.659091,17.831818,17.834545,17.035455,16.113636,14.865455,12.794545
24,0.154545,-3.001818,-2.642727,-1.373636,-1.862727,0.244545,-1.782727,-3.62,-2.044545,-0.85,...,5.59,6.335455,6.946364,8.868182,8.24,8.776364,6.994545,5.557273,3.251818,5.424545
33,3.211818,2.966364,2.295455,2.789091,2.837273,0.269091,-1.83,1.923636,2.532727,-0.367273,...,10.998182,10.671818,11.35,10.028182,12.112727,11.429091,11.568182,12.605455,10.92,13.231818
57,-6.506364,-6.645455,-5.041818,-8.246364,-10.225455,-9.194545,-6.309091,-7.777273,-7.669091,-6.001818,...,-0.81,-1.394545,-1.102727,-1.694545,0.452727,-0.832727,-3.227273,-2.531818,0.628182,1.308182
79,-2.821818,-1.880909,-1.75,-4.599091,-3.284545,-1.582727,-3.307273,-3.184545,-2.772727,-3.552727,...,11.726364,12.600909,11.424545,9.378182,10.200909,10.770909,11.83,12.634545,10.094545,10.412727
82,-1.631818,-3.225455,-5.061818,-7.17,-5.506364,-4.319091,-5.782727,-7.599091,-7.572727,-5.161818,...,2.920909,1.942727,3.816364,2.586364,2.313636,1.661818,1.364545,2.099091,4.223636,2.572727
123,-11.291818,-6.982727,-6.121818,-7.840909,-8.240909,-4.234545,-3.964545,-5.183636,-5.342727,-3.234545,...,16.717273,16.430909,18.782727,18.26,18.430909,19.830909,18.561818,16.154545,19.38,19.677273
128,-6.94,-7.700909,-3.966364,-5.582727,-6.410909,-4.225455,-5.770909,-8.518182,-8.259091,-6.47,...,10.608182,8.330909,10.494545,9.533636,7.201818,12.358182,9.998182,9.319091,12.335455,11.742727
129,-0.527273,-1.161818,-0.075455,-2.278182,0.094545,-4.461818,0.416364,-1.965455,0.373636,0.108182,...,2.739091,1.852727,3.335455,6.4,0.344545,1.099091,7.557273,4.099091,0.947273,1.608182
