In [80]:
import pandas as pd
from preprocess import load_kilosort_arrays, get_good_cluster_numbers, gen_df
import os
from functools import partial
import numpy as np





def load_table(table=None, d=None):
    '''Read one database table from disk into a pandas DataFrame.

    params:
        table = table name or path. '.csv' is appended when it is missing.
                An absolute path is used as-is and d is ignored
        d = directory holding the table files; required whenever table
            is not an absolute path

    returns:
        the DataFrame produced by pd.read_csv

    raises:
        AssertionError if table is empty, or if d is needed but not given
        ValueError if reading the file fails with IOError / OSError
    '''
    assert table, 'table to load not specified'

    filename = table if table.endswith('.csv') else table + '.csv'

    # Relative names are resolved against the parent directory.
    if not os.path.isabs(filename):
        assert d, 'no parent diectory given to load_table'
        filename = os.path.join(d, filename)

    try:
        frame = pd.read_csv(filename)
    except (IOError, OSError):
        raise ValueError('\nTable name specified but not found.\n'
                         '            Tried to load: %s' % filename)
    return frame
        
def _check_data(tbl1, tbl2, key1, key2):
    m = '''Bad data in table {t1} and {t2}.
        duplicate values in keys {k1} and {k2}'''.format(t1=tbl1, t2=tbl2,
                                                         k1=key1, k2=key2)
    pks1 = set(tbl1[key1].values)
    pks2 = set(tbl2[key2].values)
    assert not pks1.difference(pks2) and not pks2.difference(pks1), m


def _get_recording_id(rd, r_tbl):
    assert 'dat_filename' in r_tbl.columns, 'Cannot find datfile column'
    name = os.path.basename(rd)
    matches = r_tbl[r_tbl['dat_filename'] == name]['recording_id'].values
    assert len(
        matches) == 1, 'Error finding recording id: {}. multiple IDs found'.format(name)
    return matches[0]

def get_spike_times(p, r_id):
    '''Build the spike table for one recording and tag it with its id.

    params:
        p = path handed to load_kilosort_arrays
        r_id = value written into the recording_id column of every row

    returns:
        the object returned by gen_df, with a recording_id column added

    The three helpers come from the preprocess module; presumably they
    load the kilosort output, select the good clusters and assemble a
    per-spike DataFrame (confirm against preprocess).
    '''
    spike_clusters, spike_times, cluster_groups = load_kilosort_arrays(p)
    good_clusters = get_good_cluster_numbers(cluster_groups)
    spikes = gen_df(spike_clusters, spike_times, good_clusters)
    spikes['recording_id'] = r_id
    return spikes

def _load_dat_data(p, n_chans=32):
    tmp = np.memmap(p, dtype=np.int16)
    shp = int(len(tmp) / n_chans)
    return np.memmap(p, dtype=np.int16,
                     shape=(shp, n_chans))


def _extract_waveforms(spk_tms, raw_data, n_spks=800, n_samps=240, n_chans=32):
    assert len(spk_tms) > n_spks, 'Not ennough spikes'

    window = np.arange(int(-n_samps / 2), int(n_samps / 2))
    wvfrms = np.zeros((n_spks, n_samps, n_chans))
    for i in range(n_spks):
        srt = int(spk_tms.iloc[i] + window[0])
        end = int(spk_tms.iloc[i] + window[-1] + 1)
        srt = srt if srt > 0 else 0
        wvfrms[i, :, :] = raw_data[srt:end, :]
    wvfrms = pd.DataFrame(np.mean(wvfrms, axis=0),
                          columns=[''.join(['Chan_', str(i)]) for
                                   i in range(1, n_chans + 1)])
    norm = wvfrms - np.mean(wvfrms)
    tmp = norm.apply(np.min(axis=0))
    good_chan = tmp.idxmin()
    wvfrms = wvfrms.loc[:, good_chan]
    return wvfrms, good_chan


def update_neurons(n_tbl, new_data, r_id, baseline_end=108000000, min_spikes=800):
    '''Replace the neurons of one recording in the neurons table.

    params:
        n_tbl = existing neurons table with neuron_id and recording_id
                columns; may be empty. It is not modified
        new_data = per-spike table with cluster_id, spike_times and
                   recording_id columns
        r_id = recording whose existing rows in n_tbl are replaced
        baseline_end = spikes with spike_times below this value count
                       towards the has_bl flag
        min_spikes = has_bl is 1 when a cluster has more than this many
                     such spikes, else 0

    returns:
        DataFrame holding the rows of n_tbl that belong to other
        recordings, followed by one row per (cluster_id, recording_id)
        of new_data. New rows get consecutive neuron_id values after
        the largest remaining id, or starting at 0 when no rows remain.
    '''

    def _reformat(spikes):
        # has_bl is based on the NUMBER of spikes before baseline_end;
        # summing the spike times themselves is not a spike count.
        in_baseline = spikes[spikes['spike_times'] < baseline_end]
        n_in_baseline = in_baseline.groupby('cluster_id')['spike_times'].size()
        has_bl = (n_in_baseline > min_spikes).astype(int).rename('has_bl')

        per_cluster = spikes.drop('spike_times', axis=1).drop_duplicates()
        per_cluster = per_cluster.set_index('cluster_id').join(has_bl).reset_index()
        # Clusters with no spike before baseline_end are missing from
        # the count and would otherwise come out of the join as NaN.
        per_cluster['has_bl'] = per_cluster['has_bl'].fillna(0).astype(int)
        return per_cluster

    new_rows = _reformat(new_data)

    if len(n_tbl) > 0:
        # Filtering and set_index both return new objects, so the
        # caller's table is left untouched.
        kept = n_tbl[n_tbl['recording_id'] != r_id].set_index('neuron_id')
    else:
        kept = n_tbl

    if len(kept) == 0:
        out = new_rows
        out.index = pd.RangeIndex(0, len(new_rows))
    else:
        # Continue after the largest surviving id rather than the last
        # row, so ids stay unique even if the table is not sorted.
        next_id = int(kept.index.max()) + 1
        new_rows.index = pd.RangeIndex(next_id, next_id + len(new_rows))
        out = pd.concat([kept, new_rows])
    out.index.name = 'neuron_id'
    return out.reset_index()

In [78]:
# Recording directory to process and the directory holding the db tables.
# NOTE(review): both are machine-specific absolute paths.
rd = '/media/ruairi/UBUNTU/CIT_WAY/dat_files/cat/2018-07-27'

db_dir = '/home/ruairi/data/db'

# Load every table with load_table, resolving names against db_dir.
tlbs_to_load = ['neurons', 'spike_times', 'recordings', 'experiments']
loader = partial(load_table, d=db_dir)
try:
    nrns, spktms, rcds, exp = list(map(loader, tlbs_to_load))
    # The neurons and spike_times tables must share the same neuron_id set.
    _check_data(nrns, spktms, 'neuron_id', 'neuron_id')
except (ValueError, AssertionError) as err:
    print('Error on data import')
    # Still a ValueError, but now carrying the reason for the failure.
    raise ValueError('Error on data import: {}'.format(err))

        
def get_waveforms(spike_data, r_id, rd):
    '''Run _extract_waveforms on the spike times of every cluster.

    Needs the _extract_waveforms definition that accepts a `ret`
    argument to be the one in effect.

    params:
        spike_data = per-spike table with cluster_id and spike_times columns
        r_id = not used by this function
        rd = recording directory; the raw data is read from
             <rd>/<basename of rd>.dat

    returns:
        (waveforms, chans): waveforms has columns
        ['cluster_id', 'sample', 'value'] and comes from the ret='data'
        pass; chans has columns ['cluster_id', 'channel'] and comes from
        the ret='' pass
    '''
    dat_path = os.path.join(rd, os.path.basename(rd)) + '.dat'
    raw_data = _load_dat_data(p=dat_path)

    def _per_cluster(ret):
        # One groupby/apply pass over the clusters for the requested output.
        by_cluster = spike_data.groupby('cluster_id')['spike_times']
        result = by_cluster.apply(_extract_waveforms,
                                  raw_data=raw_data, ret=ret)
        return result.apply(pd.Series).reset_index()

    waveforms = _per_cluster('data')
    chans = _per_cluster('')
    chans.columns = ['cluster_id', 'channel']
    waveforms.columns = ['cluster_id', 'sample', 'value']
    return waveforms, chans


In [8]:
# Find the recording_id whose dat_filename matches the basename of rd.
r_id = _get_recording_id(rd=rd, r_tbl=rcds)

In [9]:
# Build the spike table for this recording; every row is tagged with r_id.
spike_data = get_spike_times(p=rd, r_id=r_id)

In [11]:
# Preview the first rows of the spike table.
spike_data.head()

Unnamed: 0,cluster_id,spike_times,recording_id
8,5,737,9
9,98,758,9
10,60,832,9
11,26,836,9
12,45,858,9


In [13]:
# Memory-map <rd>/<basename of rd>.dat as int16 rows of 32 channels.
raw_data = _load_dat_data(os.path.join(rd, os.path.basename(rd)) + '.dat', 32)

In [82]:
import matplotlib.pyplot as plt

def _extract_waveforms(spk_tms, raw_data, ret='data', n_spks=800, n_samps=240, n_chans=32):
    assert len(spk_tms) > n_spks, 'Not ennough spikes'
    spk_tms = spk_tms.values
    window = np.arange(int(-n_samps / 2), int(n_samps / 2))
    wvfrms = np.zeros((n_spks, n_samps, n_chans))
    for i in range(n_spks):
        srt = int(spk_tms[i] + window[0])
        end = int(spk_tms[i] + window[-1] + 1)
        srt = srt if srt > 0 else 0
        wvfrms[i, :, :] = raw_data[srt:end, :]
    wvfrms = pd.DataFrame(np.mean(wvfrms, axis=0),
                          columns=range(1, n_chans + 1))
    norm = wvfrms - np.mean(wvfrms)
    tmp = norm.apply(np.min, axis=0)
    good_chan = tmp.idxmin()
    wvfrms = wvfrms.loc[:, good_chan]
    if ret == 'data':
        return wvfrms
    else:
        return good_chan

In [74]:
# Scratch run of the same two groupby/apply passes that get_waveforms
# performs, using the module-level raw_data: t holds the ret='data'
# results and c the ret='' results, one pass per cluster_id each.
f = partial(_extract_waveforms, raw_data=raw_data, ret='data')
v = partial(_extract_waveforms, raw_data=raw_data, ret='')
t = spike_data.groupby('cluster_id')['spike_times'].apply(f, raw_data=raw_data).apply(pd.Series).reset_index()
c = spike_data.groupby('cluster_id')['spike_times'].apply(v, raw_data=raw_data).apply(pd.Series).reset_index()

In [83]:
# Per-cluster mean waveforms and per-cluster channel, read from the .dat file under rd.
waveforms, chans = get_waveforms(spike_data=spike_data, r_id=r_id, rd=rd)

In [84]:
# Preview the waveform table (columns: cluster_id, sample, value).
waveforms.head()

Unnamed: 0,cluster_id,sample,value
0,1,0,54.73
1,1,1,53.97875
2,1,2,53.5075
3,1,3,52.30875
4,1,4,55.79625


In [85]:
# Preview the channel chosen for each cluster (columns: cluster_id, channel).
chans.head()

Unnamed: 0,cluster_id,channel
0,1,6
1,3,20
2,4,6
3,5,16
4,6,2


In [86]:
# Preview the spike table again before merging in the channels.
spike_data.head()

Unnamed: 0,cluster_id,spike_times,recording_id
8,5,737,9
9,98,758,9
10,60,832,9
11,26,836,9
12,45,858,9


In [87]:
# Preview the neurons table loaded from db_dir.
nrns.head()

Unnamed: 0,neuron_id,recording_id,channel,has_bl


In [91]:
# Row and column count before the merge, for comparison afterwards.
spike_data.shape

(2354509, 3)

In [95]:
# Attach each cluster's channel to every one of its spikes.
spike_data2 = spike_data.merge(chans, on='cluster_id')

In [96]:
# Shape after the merge; compare the row count with spike_data.shape.
spike_data2.shape

(2354509, 4)

In [97]:
# Show only the first rows instead of echoing the whole merged table.
spike_data2.head(10)

Unnamed: 0,cluster_id,spike_times,recording_id,channel
0,5,737,9,16
1,5,3306,9,16
2,5,4083,9,16
3,5,4921,9,16
4,5,10259,9,16
5,5,13783,9,16
6,5,15310,9,16
7,5,16089,9,16
8,5,17652,9,16
9,5,20980,9,16
