In [151]:
%matplotlib inline
import math
import operator
import os

import numpy as np
from matplotlib import pyplot as plt

from heka_reader import Bundle

# Default size, in inches (width, height), of every figure drawn below.
plt.rcParams['figure.figsize'] = (20, 12)

In [152]:
def get_spike_indices(trace, edge, num_stds):
    """
    Gets indices of threshold crossings (spikes) in an extracellular recording.
    Adapted from Bartosz Telenczuk: https://github.com/btel/SpikeSort

    The threshold is ``num_stds`` standard deviations of the whole trace,
    measured from zero rather than from the trace mean.
    NOTE(review): this assumes the trace is roughly zero-centred -- confirm
    for recordings with a DC offset.

    Parameters
    ----------
    trace : array-like
        1-D signal; lists are accepted as well as numpy arrays.
    edge : {'rising', 'falling'}
        'rising' finds upward crossings of +threshold;
        'falling' finds downward crossings of -threshold.
    num_stds : float
        Threshold expressed in multiples of the trace's standard deviation.

    Returns
    -------
    numpy.ndarray of int
        Indices ``i`` such that the crossing happens between samples ``i``
        and ``i + 1``. Empty when there is no crossing or when the trace has
        fewer than two samples.

    Raises
    ------
    ValueError
        If ``edge`` is neither 'rising' nor 'falling'.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if edge not in ('rising', 'falling'):
        raise ValueError("edge must be 'rising' or 'falling', got %r" % (edge,))

    trace = np.asarray(trace)
    if trace.size < 2:
        # No pair of consecutive samples -> no crossing possible (and np.std
        # of an empty array would warn and return nan).
        return np.array([], dtype=np.intp)

    thresh = np.std(trace) * num_stds

    if edge == 'rising':
        # Sample i below +thresh and sample i+1 above it.
        crossing = (trace[:-1] < thresh) & (trace[1:] > thresh)
    else:
        # Mirror image: sample i above -thresh and sample i+1 below it.
        thresh = -thresh
        crossing = (trace[:-1] > thresh) & (trace[1:] < thresh)

    return np.flatnonzero(crossing)

In [153]:
def fit_sine_polar(x, y):
    """
    Fits a sine wave with a frequency of 1 (one cycle per 2*pi) to polar data:
    ``y ~= amplitude * sin(x + phase) + bias``.
    Adapted from Ed Tate: http://exnumerus.blogspot.com/2010/04/how-to-fit-sine-wave-example-in-python.html

    The model is solved by linear least squares in the linearised form
    ``y = a*sin(x) + b*cos(x) + c``, from which
    ``amplitude = |(a, b)|``, ``phase = atan2(b, a)`` and ``bias = c``.

    Parameters
    ----------
    x : array-like of float
        Angles in radians.
    y : array-like of float
        Observed value at each angle (same length as ``x``).

    Returns
    -------
    (phase, amplitude, bias) : tuple of float
        ``phase`` in radians within [-pi, pi], a non-negative ``amplitude``
        and the constant offset ``bias``.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # Design matrix with one row [sin(x_i), cos(x_i), 1] per sample. Plain
    # arrays replace the deprecated np.matrix.
    design = np.column_stack([np.sin(x), np.cos(x), np.ones_like(x)])

    # rcond=None uses the machine-precision cutoff and avoids numpy's
    # FutureWarning about the legacy default.
    coef = np.linalg.lstsq(design, y, rcond=None)[0]
    sin_coef, cos_coef, bias = coef[0], coef[1], coef[2]

    # a*sin(x) + b*cos(x) == A*sin(x + phi) with a = A*cos(phi), b = A*sin(phi).
    phase = math.atan2(cos_coef, sin_coef)
    amplitude = np.linalg.norm([sin_coef, cos_coef], 2)

    return (phase, amplitude, bias)

In [158]:
# Recording to analyse. The data root defaults to the original author's
# machine but can be overridden with the SPIKE_DATA_DIR environment variable,
# so the notebook can be re-run elsewhere without editing this cell.
DATA_DIR = os.environ.get(
    'SPIKE_DATA_DIR',
    r'C:\Users\tomlinsa\Downloads\Alex_stim\spike_analysis\data',
)
ds_cell = os.path.join(DATA_DIR, 'DS traces', 'R2P_161117_01.dat')
# Alternative recording:
# ds_cell = os.path.join(DATA_DIR, 'R2P_160817_01-jump_offalpha.dat')

bundle = Bundle(ds_cell)
data = bundle.data

In [159]:
# Trigger onset per sweep: the LAST rising crossing of a 15-std threshold on
# the trigger trace. The index path [0, 5, sweep, 2] is presumably
# [group, series, sweep, trace] in heka_reader -- TODO confirm.
# A sweep with no crossing gets onset None instead of raising IndexError
# (indexing [-1] into an empty result), and is reported below.
sweeps  = list(range(12))
indices = [[0, 5, sweep, 2] for sweep in sweeps]
traces  = [data[index] for index in indices]

onsets = []
for trace in traces:
    crossings = get_spike_indices(trace, 'rising', 15)
    onsets.append(crossings[-1] if len(crossings) else None)

missing_trigger = [sweep for sweep, onset in zip(sweeps, onsets) if onset is None]
if missing_trigger:
    print('No trigger crossing found in sweeps:', missing_trigger)

IndexError: list index out of range

In [None]:
# Spike detection per sweep: falling crossings of a 4-std threshold on the
# recording trace [0, 1, sweep, 0], then the spike count per sweep.
# NOTE(review): alignment to the trigger onset is currently disabled; to
# enable it, trim each trace with the `onsets` computed in the trigger cell:
# traces  = [trace[onset:] for trace, onset in zip(traces, onsets)]
sweeps = list(range(12))
indices = [[0, 1, sweep, 0] for sweep in sweeps]
traces = [data[index] for index in indices]

spike_indices = [get_spike_indices(trace, 'falling', 4) for trace in traces]
num_spikes = list(map(len, spike_indices))

In [None]:
# Direction tuning: spike count per stimulus direction, fitted with a
# one-cycle sine  y = amplitude * sin(theta + phase) + bias.
# Stimulus directions in degrees, in the order the sweeps were recorded.
dirs = [0, 180, 30, 210, 60, 240, 90, 270, 120, 300, 150, 330]

# np.deg2rad always yields floats; `np.array(dirs) / 180 * pi` truncated the
# angles to 0 or pi under Python 2 integer division.
directions = np.deg2rad(dirs)
assert len(directions) == len(num_spikes), 'need exactly one direction per sweep'

phase, amplitude, bias = fit_sine_polar(directions, num_spikes)

# Fitted curve sampled over one full revolution.
x = np.arange(0, 2*math.pi+0.05, 0.05)
y_est = amplitude * np.sin(x + phase) + bias

# The fit peaks where theta + phase == pi/2; wrap the angle into [0, 2*pi).
pref_dir = (math.pi / 2 - phase) % (2 * math.pi)
pref_fit = amplitude * math.sin(pref_dir + phase) + bias

plt.polar(x, y_est, '-k')
plt.polar(directions, num_spikes, '.b')
# Preferred-direction marker: a ray from the origin reaching 20% past the peak.
plt.polar([pref_dir, 0], [pref_fit*1.2, 0], '-r')
plt.title('Direction tuning (spike count per sweep)')

# Preferred direction in degrees.
math.degrees(pref_dir)