In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt; plt.style.use('ggplot')

In [2]:
# --- where the spike-sorted data lives --------------------------------------
kilosort_folder = r'C:\Users\Rory\raw_data\CIT_WAY\dat_files\cat'
recording = r'2018-05-01_01'
sep = '\\'  # separator load_data places between kilosort_folder and recording

# --- waveform-extraction settings --------------------------------------------
num_spikes_for_averaging = 1000  # spikes averaged per cluster
num_channels = 32
num_samples_per_waveform = 120

# --- output folder for per-cluster plots and the bad_waves.txt log ----------
temp_folder = r'C:\Users\Rory\Documents\Work\temp'

In [3]:
# One column name per channel: 'Chan_1' ... 'Chan_<num_channels>' (kept in step with num_channels)
cols = ['Chan_' + str(num) for num in range(1, num_channels + 1)]
# Raw recording: <kilosort_folder>/<recording>/<recording>.dat
path = os.path.join(kilosort_folder, recording, recording) + '.dat'
# Integer sample offsets around each spike time, e.g. -60 ... 59 for 120 samples
waveform_window = np.arange(num_samples_per_waveform) - num_samples_per_waveform // 2

In [4]:
def load_kilosort_arrays(recording):
    '''
    Load the kilosort/phy output files for one recording into numpy arrays and a pandas DataFrame.

    The files are read relative to the current working directory. load_data changes into the
    recording's folder before calling this function.

    Parameters:
        recording       = name of the recording being analysed (only used in the error message)
    Returns:
        spike_clusters  = numpy array from spike_clusters.npy identifying the cluster each spike arose from
        spike_times     = numpy array from spike_times.npy giving the time, in samples, of each spike
        cluster_groups  = pandas DataFrame read from the tab-separated cluster_groups.csv; the rest of this
                          notebook expects 'cluster_id' and 'group' columns
    Raises:
        ValueError if spike_times and spike_clusters do not contain the same number of spikes
    '''
    spike_clusters = np.load('spike_clusters.npy')
    spike_times = np.load('spike_times.npy')
    cluster_groups = pd.read_csv('cluster_groups.csv', sep='\t')
    # Data-quality check: one cluster label per spike time. Compare element counts so that a
    # (n, 1) array and an (n,) array are treated as the same length. An explicit raise is used
    # instead of `assert` so that the check still runs under `python -O`.
    if spike_times.size != spike_clusters.size:
        raise ValueError('Array lengths do not match in recording {}'.format(recording))
    return spike_clusters, spike_times, cluster_groups


def load_data(recording, kilosort_folder, verbose, sep):
    '''
    Change into one recording's spike-sorting folder and load its kilosort/phy outputs.

    Side effect: the process-wide working directory is changed with os.chdir and left there.

    Parameters:
        recording       = name of the recording being analysed; also the name of its sub-directory
        kilosort_folder = root directory holding one sub-directory per recording; each sub-directory
                          contains the files written during spike sorting with kilosort and phy
        verbose         = if True, print which recording is being loaded
        sep             = directory delimiter placed between kilosort_folder and recording, e.g. '/'
    Returns:
        spike_clusters, spike_times, cluster_groups exactly as returned by load_kilosort_arrays
    '''
    if verbose:
        print('\nLoading Data:\t{}\n'.format(recording))
    recording_dir = sep.join((kilosort_folder, recording))
    os.chdir(recording_dir)
    loaded = load_kilosort_arrays(recording)
    clusters, times, groups = loaded
    return clusters, times, groups


def get_good_cluster_numbers(cluster_groups_df):
    '''
    Return the ids of clusters labelled 'good' during kilosort/phy spike sorting.

    Parameters:
        cluster_groups_df = pandas DataFrame with 'cluster_id' and 'group' columns
                            (as read from cluster_groups.csv during data loading)
    Returns:
        numpy array of the 'cluster_id' values whose 'group' equals 'good'
    '''
    is_good = cluster_groups_df['group'] == 'good'
    return cluster_groups_df.loc[is_good, 'cluster_id'].values

In [5]:
def create_good_spikes_df(good_cluster_numbers, spike_clusters, spike_times):
    '''
    Pair every spike time with its cluster and keep only spikes from the given clusters.

    Parameters:
        good_cluster_numbers = iterable of cluster ids to keep
        spike_clusters       = array of cluster ids, one per spike (flattened before use)
        spike_times          = array of spike times in samples, one per spike (flattened before use)
    Returns:
        pandas DataFrame with columns 'spike_time' and 'spike_cluster', keeping the original row index
    '''
    all_spikes = pd.DataFrame({
        'spike_time': spike_times.flatten(),
        'spike_cluster': spike_clusters.flatten(),
    })
    keep = all_spikes['spike_cluster'].isin(good_cluster_numbers)
    return all_spikes.loc[keep, :]

In [6]:
def create_spiketimes_series(df, cluster, num_spikes):
    '''
    Return the first num_spikes spike times of one cluster, re-indexed 0..n-1.

    Parameters:
        df         = DataFrame with 'spike_cluster' and 'spike_time' columns
        cluster    = cluster id to select
        num_spikes = maximum number of spike times to return (fewer if the cluster has fewer)
    Returns:
        pandas Series named 'spike_time' with a fresh 0-based index
    '''
    in_cluster = df['spike_cluster'] == cluster
    first_spikes = df.loc[in_cluster, 'spike_time'].iloc[:num_spikes]
    return first_spikes.reset_index(drop=True)

In [7]:
def load_raw_data(path, n_channels=None):
    '''
    Memory-map a raw int16 .dat recording as a read-only (samples, channels) array.

    Samples are assumed to be stored interleaved by channel (sample 0 of every channel, then
    sample 1 of every channel, ...), which is what reshaping the flat int16 stream to
    (n_samples, n_channels) implies. Any trailing partial frame at the end of the file is ignored.

    Parameters:
        path       = path to the .dat file
        n_channels = number of interleaved channels; defaults to the notebook's num_channels setting
    Returns:
        read-only numpy.memmap of dtype int16 and shape (n_samples, n_channels)
    '''
    if n_channels is None:
        n_channels = num_channels  # module-level configuration constant
    bytes_per_frame = np.dtype(np.int16).itemsize * n_channels
    n_samples = os.path.getsize(path) // bytes_per_frame
    # mode='r': the raw recording must never be modified; np.memmap's default 'r+' opens it writable.
    return np.memmap(path, dtype=np.int16, mode='r', shape=(n_samples, n_channels))

In [8]:
def get_waveforms_from_raw_data(num_spikes_for_averaging, num_samples_per_waveform,
                                num_channels, spike_times, waveform_window,
                                raw_data):
    '''
    Average the raw-data snippets around a set of spike times, separately for every channel.

    Parameters:
        num_spikes_for_averaging = maximum number of spikes to average; if fewer spike times are
                                   supplied, all of them are used
        num_samples_per_waveform = number of samples per snippet (should equal len(waveform_window))
        num_channels             = number of channels to extract, taken from column 0 of raw_data onwards
        spike_times              = sequence / pandas Series of spike times in samples
        waveform_window          = sample offsets relative to each spike; only its first and last
                                   values are used to bound the snippet
        raw_data                 = 2d array indexed as [sample, channel]
    Returns:
        pandas DataFrame of shape (num_samples_per_waveform, num_channels) holding the mean waveform,
        with columns 'Chan_1' ... 'Chan_<num_channels>'
    Raises:
        ValueError if none of the selected spikes has a complete window inside raw_data
    '''
    selected_times = np.asarray(spike_times)[:num_spikes_for_averaging]
    n_recorded_samples = raw_data.shape[0]

    snippets = []
    for spike_time in selected_times:
        start_index = int(spike_time + waveform_window[0])  # start of waveform in raw data
        end_index = int(spike_time + waveform_window[-1] + 1)  # end (exclusive) of waveform in raw data
        # Spikes too close to either end of the recording would give a short or wrapped slice.
        if start_index < 0 or end_index > n_recorded_samples:
            continue
        snippets.append(raw_data[start_index:end_index, 0:num_channels])

    if not snippets:
        raise ValueError('no spike has a complete waveform window inside the raw data')

    waveforms = np.zeros((len(snippets), num_samples_per_waveform, num_channels))
    for row, snippet in enumerate(snippets):
        waveforms[row, :, :] = snippet  # add extracted waveform to 3d matrix
    mean_waveform = waveforms.mean(axis=0)

    channel_names = ['Chan_' + str(num) for num in range(1, num_channels + 1)]
    return pd.DataFrame(mean_waveform, columns=channel_names)

In [9]:
def choose_max_amp_channel(df):
    '''
    Return the channel (column) whose largest value is the highest across all channels.

    Parameters:
        df = DataFrame with one column per channel
    Returns:
        the selected column as a pandas Series, named after the channel
    '''
    per_channel_peak = df.max(axis=0)
    best_channel = per_channel_peak.idxmax()
    return df[best_channel]

In [10]:
def create_df_for_peak_finding(chan_df):
    '''
    Build the table that up_down_up uses from a single-channel average waveform.

    Parameters:
        chan_df = pandas Series of waveform values (e.g. the output of choose_max_amp_channel)
    Returns:
        pandas DataFrame indexed like chan_df with columns
            'y_values'       = the waveform values
            'reverse_change' = y[i] - y[i+1] (NaN for the last sample)
            'diff_reverse'   = 'increase' where reverse_change > 0 (the next sample is lower),
                               otherwise 'decline' (this includes the final NaN row)
    '''
    ave_waveform = pd.DataFrame({'y_values': chan_df})
    # Difference the column itself so that a Series, not a one-column DataFrame, is assigned.
    ave_waveform['reverse_change'] = ave_waveform['y_values'].diff(periods=-1)
    # 'decline' must match the label that up_down_up filters on.
    ave_waveform['diff_reverse'] = np.where(ave_waveform['reverse_change'] > 0, 'increase', 'decline')
    return ave_waveform

In [11]:
def up_down_up(df, cluster, thresh=0.1, samples_per_ms=30):
    '''
    Measure spike-width features of an average waveform that has a positive peak and a trough.

    Parameters:
        df             = DataFrame with 'y_values' and 'diff_reverse' columns (as built by
                         create_df_for_peak_finding); rows not labelled 'increase' in 'diff_reverse'
                         are those where the next sample is not lower
        cluster        = cluster id; stored as a string in the output
        thresh         = fraction of the peak amplitude a pre-peak sample must reach to count as baseline
        samples_per_ms = divisor that converts sample counts to widths; 30 keeps the previous
                         hard-coded value (presumably a 30 kHz recording, giving ms - confirm)
    Returns:
        one-row DataFrame with columns
            'cluster'  = str(cluster)
            'SW_peak'  = |peak sample - baseline sample| / samples_per_ms
            'SW_troff' = |trough sample - baseline sample| / samples_per_ms
            'SW_base'  = |return-to-baseline sample - baseline sample| / samples_per_ms
    Raises:
        ValueError if no baseline sample exists before the peak, or if the waveform never climbs
        back to the baseline value after the trough
    '''
    y_values = df['y_values']
    peak_amp = y_values.max()
    peak_sample = y_values.idxmax()
    min_sample = y_values.idxmin()

    baseline_amp = peak_amp * thresh
    # Baseline candidates: before the peak, on the rising phase (not labelled 'increase'),
    # and at least thresh * peak high. The lowest such sample is the baseline.
    rising = df['diff_reverse'] != 'increase'
    candidates = y_values[(y_values < peak_amp)
                          & (df.index < peak_sample)
                          & rising
                          & (y_values >= baseline_amp)]
    if candidates.empty:
        raise ValueError('no baseline sample before the peak for cluster {}'.format(cluster))
    baseline_sample = candidates.idxmin()

    # NOTE(review): "return to baseline" is interpreted as the first sample after the trough whose
    # value is back at or above the baseline sample's value - confirm this matches the intended SW_base.
    baseline_value = y_values.loc[baseline_sample]
    recovered = y_values[(df.index > min_sample) & (y_values >= baseline_value)]
    if recovered.empty:
        raise ValueError('waveform never returns to baseline for cluster {}'.format(cluster))
    return_baseline_sample = recovered.index.min()

    SW_peak = np.absolute(peak_sample - baseline_sample) / samples_per_ms
    SW_troff = np.absolute(min_sample - baseline_sample) / samples_per_ms
    SW_base = np.absolute(return_baseline_sample - baseline_sample) / samples_per_ms

    # Wrap scalars in lists: a DataFrame built only from scalars needs an explicit index.
    clu_df = pd.DataFrame({'cluster': [str(cluster)], 'SW_peak': [SW_peak],
                           'SW_troff': [SW_troff], 'SW_base': [SW_base]})
    return clu_df

In [12]:
# Display the resolved path of the raw .dat file that load_raw_data will memory-map
path

'C:\\Users\\Rory\\raw_data\\CIT_WAY\\dat_files\\cat\\2018-05-01_01\\2018-05-01_01.dat'

In [13]:
# Memory-map the raw recording, then load the spike-sorting outputs for the same recording
raw_data = load_raw_data(path)
spike_clusters, spike_times, cluster_groups = load_data(
    recording=recording,
    kilosort_folder=kilosort_folder,
    verbose=False,
    sep=sep,
)

# Keep only spikes belonging to clusters marked 'good' in phy
good_cluster_numbers = get_good_cluster_numbers(cluster_groups)
df = create_good_spikes_df(good_cluster_numbers, spike_clusters, spike_times)

In [14]:
def get_cluster_waveforms(recording, good_spikes_df, cluster, num_spikes,
                         num_samples_per_waveform, num_channels, waveform_window,
                         raw_data, temp_folder):
    '''
    Average one cluster's waveform, save a diagnostic plot and measure its spike widths.

    Parameters:
        recording                = name of the recording (used in the plot filename and log messages)
        good_spikes_df           = DataFrame of spikes with 'spike_time' and 'spike_cluster' columns
        cluster                  = cluster id to process
        num_spikes               = maximum number of spikes averaged for the waveform
        num_samples_per_waveform = samples per extracted snippet
        num_channels             = number of channels to extract
        waveform_window          = sample offsets around each spike time
        raw_data                 = raw recording indexed as [sample, channel]
        temp_folder              = folder receiving '<recording>_<cluster>.png' and the bad_waves.txt log
    Returns:
        the one-row DataFrame produced by up_down_up, or None if the width measurement failed;
        failures are appended to <temp_folder>/bad_waves.txt together with the error
    '''
    spike_times = create_spiketimes_series(good_spikes_df,
                                           cluster=cluster,
                                           num_spikes=num_spikes)

    df_all_chans = get_waveforms_from_raw_data(num_spikes, num_samples_per_waveform,
                                               num_channels, spike_times, waveform_window,
                                               raw_data)
    df_max_chan = choose_max_amp_channel(df_all_chans)
    ave_waveform = create_df_for_peak_finding(chan_df=df_max_chan)

    fig, axes = plt.subplots(ncols=2, figsize=(12, 8))
    df_all_chans.plot(ax=axes[0])
    df_max_chan.plot(ax=axes[1])
    axes[0].set(title='{} cluster {}: all channels'.format(recording, cluster),
                xlabel='Sample', ylabel='Mean amplitude (raw units)')
    axes[1].set(title='{} cluster {}: {}'.format(recording, cluster, df_max_chan.name),
                xlabel='Sample', ylabel='Mean amplitude (raw units)')
    # Separator between recording and cluster keeps filenames unambiguous
    # (e.g. '..._01' + '12' vs '..._011' + '2').
    fig.savefig(os.path.join(temp_folder, '{}_{}.png'.format(recording, cluster)))
    plt.close(fig)  # close this figure explicitly so looping over clusters does not leak memory

    try:
        clu_df = up_down_up(df=ave_waveform, cluster=cluster, thresh=0.1)
    except Exception as err:  # best-effort per cluster: record why, then move on
        with open(os.path.join(temp_folder, 'bad_waves.txt'), mode='a') as file:
            file.write('Bad Waveform in {rec}: Cluster {clu} ({err!r})\n'.format(
                rec=recording, clu=cluster, err=err))
        return None
    return clu_df
    

In [15]:
# Arguments shared by every cluster of this recording
waveform_kwargs = dict(recording=recording,
                       good_spikes_df=df,
                       num_spikes=num_spikes_for_averaging,
                       num_samples_per_waveform=num_samples_per_waveform,
                       num_channels=num_channels,
                       waveform_window=waveform_window,
                       raw_data=raw_data,
                       temp_folder=temp_folder)

# Plot and measure each 'good' cluster in turn
for cluster in good_cluster_numbers:
    get_cluster_waveforms(cluster=cluster, **waveform_kwargs)

### TODO:

- Work around spikes with no positive peak
- Turn into script (loop over recordings, loop over clusters)