In [15]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
from scipy.interpolate import interp1d

from cca.data_util import sum_over_chunks

In [21]:
def load_sabes_data(filename, bin_width_s=.100):
    #Load MATLAB file
    with h5py.File(filename, "r") as f:
        #Get channel names (e.g. M1 001 or S1 001)
        num_channels = f['chan_names'].shape[1]
        chan_names = []
        for i in range(num_channels):
            chan_names.append(f[f['chan_names'][0,i]].value.tobytes()[::2].decode())
        #Get M1 and S1 indices
        M1_indices = [i for i in range(num_channels) if chan_names[i].split(' ')[0] == 'M1']
        S1_indices = [i for i in range(num_channels) if chan_names[i].split(' ')[0] == 'S1']
        #Get time
        t = f['t'][0,:]
        #Individually process M1 and S1 indices
        result = {}
        for indices in (M1_indices, S1_indices):
            if len(indices) == 0:
                continue
            #Get region (M1 or S1)
            region = chan_names[indices[0]].split(" ")[0]
            #Perform binning
            num_channels = len(indices)
            num_sorted_units = f["spikes"].shape[0] - 1 #The FIRST one is the 'hash' -- ignore!
            d = num_channels * num_sorted_units #d is the output dimension (total # of sorted units)
            max_t = t[-1]
            num_bins = int(np.ceil((max_t - t[0]) / bin_width_s))
            binned_spikes = np.zeros((num_bins, d), dtype=np.int)
            for chan_idx in indices: #0,...,95, for example 
                for unit_idx in range(1, num_sorted_units): #ignore hash!
                    spike_times = f[f["spikes"][unit_idx, chan_idx]].value
                    if spike_times.shape == (2,):
                        #ignore this case (no data)
                        continue
                    spike_times = spike_times[0, :] #flatten
                    spike_times = spike_times[spike_times < max_t] #get rid of extraneous t vals
                    bin_idx = np.floor((spike_times - t[0]) / bin_width_s).astype(np.int)
                    bin_idx_unique, counts = np.unique(bin_idx, return_counts=True)
                    #make sure to ignore the hash here...
                    binned_spikes[bin_idx_unique, chan_idx * num_sorted_units + unit_idx - 1] += counts
            result[region] = binned_spikes
        #Get cursor position
        #Cursor position values line up with the END of each bin
        cursor_pos = f["cursor_pos"][:].T
        cursor_pos_interp = interp1d(np.arange(len(cursor_pos)) * 0.004, cursor_pos, axis=0)
        result["cursor"] = cursor_pos_interp(np.arange(len(binned_spikes))*bin_width_s + .5*bin_width_s)
        return result



In [22]:
filename = "/home/davidclark/Project_data/NHP_reaches/indy_20160627_01.mat"
results = load_sabes_data(filename)
X = results["M1"]
cursor = results["cursor"]

ValueError: A value in x_new is above the interpolation range.