In [1]:
%matplotlib inline

from pathlib import Path
import h5py
from types import SimpleNamespace
from joblib import delayed, Parallel

import numpy as np
import pandas as pd
from scipy import signal
pd.set_option('display.max_rows', 30)

import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='whitegrid', palette='muted')

## added TreeMazeanalyses folder using the following command
## conda develop /home/alexgonzalez/Documents/TreeMazeAnalyses2
import TreeMazeAnalyses2.Utils.robust_stats as rs

from TreeMazeAnalyses2.Analyses.subject_info import SubjectInfo, SubjectSessionInfo
from TreeMazeAnalyses2.Analyses import spatial_functions as spatial_funcs
import TreeMazeAnalyses2.Analyses.open_field_functions as of_funcs

from importlib import reload

import ipywidgets as widgets
from ipywidgets import interact, fixed, interact_manual

Load data

In [3]:
# Location of the dataset: one HDF5 file with one `session_<k>` group per session.
data_path = Path('/home/alexgonzalez/Documents/data/butler_hardcastle')
file_name = 'grid_cell_data.h5'

behav_data = {}    # session_id -> behavioral samples (columns used below: 1 = x, 2 = y, 3 = ha)
spikes = {}        # session_id -> binned spikes, transposed so that axis 0 indexes neurons
neural_data = {}   # session_id -> same arrays as `spikes`; kept because later cells index it
n_neuron_counter = 0

# The context manager closes the file even if reading one of the sessions raises.
with h5py.File(data_path / file_name, "r") as f:
    n_sessions = len(f)
    print(f'Number of sessions in data = {n_sessions}')

    for session_id in range(n_sessions):
        # groups in the file are 1-indexed, the dictionaries here are 0-indexed
        group_name = f'session_{session_id+1}'
        behav_data[session_id] = f[f'{group_name}/behavioral_data'][:]
        spikes[session_id] = f[f'{group_name}/neural_data'][:].T
        neural_data[session_id] = spikes[session_id]
        n_neuron_counter += spikes[session_id].shape[0]

n_neurons = n_neuron_counter
print(f'Number of neurons = {n_neurons}')

Number of sessions in data = 48
Number of neurons = 94


Two main variables: 

1. `behav_data` that contains all the behavioral variables for each session:
    a. time, x, y, ha
    b. for accessing x on session 1: `behav_data[1][:,1]` 
2. `spikes` (read from the `neural_data` dataset of each session in the h5 file) that contains the binned spikes of the neurons recorded in that session.
    a. for accessing neuron 1 of session 3: `spikes[3][1,:]`

### Explore behavior for each session

In [4]:
@interact(session_id=(0,n_sessions-1))
def explore_behavior(session_id):
    """Plot the raw y-vs-x trace (columns 2 and 1 of `behav_data`) for the chosen session."""
    session_samples = behav_data[session_id]

    fig, axis = plt.subplots()
    axis.plot(session_samples[:, 1], session_samples[:, 2])

interactive(children=(IntSlider(value=23, description='session_id', max=47), Output()), _dom_classes=('widget-…

### Define task parameters

These are mostly imported from the subject_info.py parameter structure, and adapted to this specific data set.

In [5]:
time_step = 0.02


def _unit_sum_window(window_type, n_taps):
    """Return a symmetric (fftbins=False) window of `n_taps` samples scaled so its taps sum to one."""
    taps = signal.get_window(window_type, n_taps, fftbins=False)
    return taps / taps.sum()


def _edges_and_centers(lower, upper, width):
    """Return np.arange(lower, upper + width, width) and the centers of the bins between those edges."""
    edges = np.arange(lower, upper + width, width)
    return edges, edges[:-1] + width / 2


task_params = dict(
    time_step=time_step,  # time step

    # pixel params
    x_pix_lims=[0, 150],  # camera field of view x limits [pixels]
    y_pix_lims=[0, 150],  # camera field of view y limits [pixels]
    x_pix_bias=0,  # factor for centering the x pixel position
    y_pix_bias=0,  # factor for centering the y pixel position
    vt_rate=1.0 / 30.0,  # video acquisition frame rate
    xy_pix_rot_rad=0,  # rotation of original xy pix camera to experimenter xy

    # conversion params
    x_pix_mm=10,  # pixels to mm for the x axis [pix/mm]
    y_pix_mm=10,  # pixels to mm for the y axis [pix/mm]
    x_mm_bias=0,  # factor for centering the x mm position
    y_mm_bias=0,  # factor for centering the y mm position
    x_mm_lims=[0, 1500],  # limits on the x axis of the maze [mm]
    y_mm_lims=[0, 1500],  # limits on the y axis of the maze [mm]
    x_cm_lims=[0, 150],  # limits on the x axis of the maze [cm]
    y_cm_lims=[0, 150],  # limits on the y axis of the maze [cm]

    # binning parameters
    mm_bin=30,  # millimeters per bin [mm]
    cm_bin=3,  # cm per bin [cm]
    max_speed_thr=80,  # max speed threshold for allowing valid movement [cm/s]
    min_speed_thr=2,  # min speed threshold for allowing valid movement [cm/s]
    rad_bin=np.deg2rad(10),  # angle radians per bin [rad]
    occ_num_thr=3,  # number of occupation times threshold [bins]
    occ_time_thr=time_step * 3,  # time occupation threshold [sec]
    speed_bin=2,  # speed bin size [cm/s]

    # filtering parameters
    spatial_sigma=2,  # spatial smoothing sigma factor [au]
    spatial_window_size=5,  # number of spatial position bins to smooth [bins]
    temporal_window_size=11,  # smoothing temporal window for filtering [bins]
    temporal_angle_window_size=11,  # smoothing temporal window for angles [bins]
    temporal_window_type='hann',  # window type for temporal window smoothing
)

# derived parameters: normalized temporal smoothing kernels
task_params['filter_coef_'] = _unit_sum_window(task_params['temporal_window_type'],
                                               task_params['temporal_window_size'])
task_params['filter_coef_angle_'] = _unit_sum_window(task_params['temporal_window_type'],
                                                     task_params['temporal_angle_window_size'])

# -- bins --
# angle bins, starting at 0 with stop value 2*pi + rad_bin
_edges, _centers = _edges_and_centers(0, 2*np.pi, task_params['rad_bin'])
task_params['ang_bin_edges_'] = _edges
task_params['ang_bin_centers_'] = _centers
task_params['n_ang_bins'] = len(_centers)

# speed bins between the min and max speed thresholds
_edges, _centers = _edges_and_centers(task_params['min_speed_thr'],
                                      task_params['max_speed_thr'],
                                      task_params['speed_bin'])
task_params['sp_bin_edges_'] = _edges
task_params['sp_bin_centers_'] = _centers
task_params['n_sp_bins'] = len(_centers)

# spatial bins along x (width)
_edges, _centers = _edges_and_centers(task_params['x_cm_lims'][0],
                                      task_params['x_cm_lims'][1],
                                      task_params['cm_bin'])
task_params['x_bin_edges_'] = _edges
task_params['x_bin_centers_'] = _centers
task_params['n_x_bins'] = len(_centers)
task_params['n_width_bins'] = task_params['n_x_bins']

# spatial bins along y (height)
_edges, _centers = _edges_and_centers(task_params['y_cm_lims'][0],
                                      task_params['y_cm_lims'][1],
                                      task_params['cm_bin'])
task_params['y_bin_edges_'] = _edges
task_params['y_bin_centers_'] = _centers
task_params['n_y_bins'] = len(_centers)
task_params['n_height_bins'] = task_params['n_y_bins']

# expose the parameters as attributes (task_params.time_step, ...)
task_params = SimpleNamespace(**task_params)

## Process all behavioral data. 
-- this takes a bit of time

In [169]:
# re-import the module so edits made to it since the first import are picked up
of_funcs = reload(of_funcs)


def _pworker(session_id):
    """
    Process the tracking data of one session.

    Reads x, y and ha (columns 1, 2, 3 of `behav_data[session_id]`), runs them through
    `of_funcs._process_track_data`, rescales the positions and derives speed and
    movement direction with `spatial_funcs.compute_velocity`.

    Returns
    -------
    tuple: (x2, y2, ha2, hd, speed)
    """
    samples = behav_data[session_id]
    x2, y2, ha2 = of_funcs._process_track_data(samples[:, 1], samples[:, 2], samples[:, 3], task_params)

    # convert to cm (in place, as the processed positions are divided by 10)
    x2 /= 10
    y2 /= 10

    speed, hd = spatial_funcs.compute_velocity(x2, y2, task_params.time_step)
    # wrap the direction into the 0 to 2pi range
    hd = np.mod(hd, 2 * np.pi)

    return x2, y2, ha2, hd, speed


# one job per session, spread over 8 workers
res = Parallel(n_jobs=8)(
    delayed(_pworker)(session_id) for session_id in range(n_sessions)
)

In [170]:
# convert each session's behavioral time series (one tuple of arrays per session in `res`)
# into a pandas DataFrame with one column per variable
behav_ts = {
    session_id: pd.DataFrame(np.array(session_res).T, columns=['x','y','ha', 'hd','speed'])
    for session_id, session_res in zip(range(n_sessions), res)
}


In [171]:
# create map function
def get_position_maps(x,y,x_bins,y_bins, time_step=0.02, window_size=5, spatial_sigma=2, occ_num_thr=3):
    """
    Build 2-D occupancy maps from an x, y position time series.

    The positions are binned with `spatial_funcs.histogram_2d` and smoothed with
    `spatial_funcs.smooth_2d_map(map, window_size, spatial_sigma)`.

    Returns a dict with the keys:
        'counts'     : binned sample counts
        'counts_sm'  : smoothed counts
        'valid_mask' : boolean map, True where counts >= occ_num_thr
        'secs'       : smoothed occupancy, i.e. counts * time_step after smoothing
    """
    counts = spatial_funcs.histogram_2d(x, y, x_bins, y_bins)

    return {
        'counts': counts,
        'counts_sm': spatial_funcs.smooth_2d_map(counts, window_size, spatial_sigma),
        'valid_mask': counts >= occ_num_thr,
        'secs': spatial_funcs.smooth_2d_map(counts * time_step, window_size, spatial_sigma),
    }

# build the occupancy maps of every session with the binning / smoothing settings in task_params
behav_maps = {}
for session_id in range(n_sessions):
    session_ts = behav_ts[session_id]
    behav_maps[session_id] = get_position_maps(
        session_ts['x'],
        session_ts['y'],
        task_params.x_bin_edges_,
        task_params.y_bin_edges_,
        time_step=task_params.time_step,
        window_size=task_params.spatial_window_size,
        spatial_sigma=task_params.spatial_sigma,
        occ_num_thr=task_params.occ_num_thr,
    )
      

### compare traces to original

In [172]:
@interact(session_id=(0,n_sessions-1))
def comp_traces(session_id):
    """Show the original trace, the processed trace, and original-vs-processed scatter for a session."""
    raw = behav_data[session_id]
    raw_x, raw_y = raw[:, 1], raw[:, 2]

    processed = behav_ts[session_id]
    proc_x, proc_y = processed['x'], processed['y']

    fig, (ax_raw, ax_proc, ax_rel) = plt.subplots(1, 3, figsize=(15, 5))

    # left: samples as stored in the data file
    ax_raw.plot(raw_x, raw_y, linewidth=0.5)
    ax_raw.set_title('Original Samps')

    # middle: samples after processing
    ax_proc.plot(proc_x, proc_y, linewidth=0.5)
    ax_proc.set_title('Processed Samps')

    # right: original value against processed value, per coordinate
    ax_rel.scatter(raw_x, proc_x, 2, alpha=0.25, label='x')
    ax_rel.scatter(raw_y, proc_y, 2, alpha=0.25, label='y')
    ax_rel.set_title('Relationship')
    ax_rel.legend()


interactive(children=(IntSlider(value=23, description='session_id', max=47), Output()), _dom_classes=('widget-…

In [179]:
def plot_map(map_, ax=None):
    """
    Draw `map_` as a seaborn heatmap with the y axis inverted and equal axis scaling.

    A new figure is created when no axes are passed in.
    Returns (figure, axes).
    """
    if ax is None:
        _, ax = plt.subplots()

    heat_ax = sns.heatmap(map_, ax=ax)
    heat_ax.invert_yaxis()
    heat_ax.axis('equal')
    return heat_ax.figure, heat_ax

@interact(session_id=(0,n_sessions-1), key=behav_maps[0].keys())
def _maps(session_id, key):
    """Show the processed trace of a session next to one of its position maps (selected by `key`)."""
    track = behav_ts[session_id]

    fig, (ax_track, ax_map) = plt.subplots(1, 2, figsize=(12, 5))

    ax_track.plot(track['x'], track['y'], linewidth=0.5)
    plot_map(behav_maps[session_id][key], ax=ax_map)
    

interactive(children=(IntSlider(value=23, description='session_id', max=47), Dropdown(description='key', optio…

### Process spikes
`spikes` contains the binned spike time series at 20 ms increments, for an effective sampling rate of 50 samples per second. The firing rate is a smoothed version of the binned spikes and is easier to use for most analyses.

In [209]:
# shape of the binned spikes of the first session: (neurons, time bins).
# `spikes` is the dict that the load cell fills from the file's neural_data datasets.
spikes[0].shape

(1, 75157)

In [217]:
def get_fr(bin_spikes, time_step=0.02, temporal_smoothing=0.125):
    """
    Convert binned spike counts into smoothed firing rates.

    Parameters
    ----------
    bin_spikes : array-like, shape (n_units, n_timebins)
        Spike counts per time bin, one row per unit.
    time_step : float
        Duration of one time bin; counts are divided by it to obtain rates.
    temporal_smoothing : float
        Length of the smoothing window, in the same units as `time_step`.
        The window has round(temporal_smoothing / time_step) taps.

    Returns
    -------
    fr : np.ndarray, shape (n_units, n_timebins)
        Rates filtered forward and backward (`signal.filtfilt`) with a Hann window
        whose taps are scaled to sum to one.

    Raises
    ------
    ValueError
        If `bin_spikes` is not 2-dimensional, or if the resulting window has no
        non-zero taps (window lengths of 0 or 2 bins), which would otherwise turn
        every output value into NaN when the window is normalized.
    """
    bin_spikes = np.asarray(bin_spikes, dtype=float)
    if bin_spikes.ndim != 2:
        raise ValueError(f'bin_spikes must be 2-d (n_units, n_timebins), got shape {bin_spikes.shape}')

    # define filter: the Hann window tapers to zero at both of its ends
    filter_len = int(np.round(temporal_smoothing / time_step))
    filt_coeff = signal.windows.hann(filter_len)
    coeff_sum = filt_coeff.sum()
    if coeff_sum <= np.finfo(float).eps:
        raise ValueError(
            f'temporal_smoothing={temporal_smoothing} with time_step={time_step} gives a '
            f'{filter_len}-tap Hann window with no non-zero taps; increase temporal_smoothing'
        )
    filt_coeff = filt_coeff / coeff_sum  # unit sum, so a constant rate is left unchanged

    n_units, n_timebins = bin_spikes.shape
    fr = np.zeros((n_units, n_timebins))
    for unit in range(n_units):
        # counts / time_step -> rate; filtfilt applies the filter in both directions
        fr[unit] = signal.filtfilt(filt_coeff, 1, bin_spikes[unit] / time_step)

    return fr

# firing rates of every session: session_id -> array returned by get_fr for that session's spikes
fr = {
    session_id: get_fr(spikes[session_id],
                       time_step=task_params.time_step,
                       temporal_smoothing=0.125)
    for session_id in range(n_sessions)
}

# rate maps: session_id -> array with one map per neuron of that session
neural_maps = {}
map_dims = behav_maps[0]['counts'].shape  # every rate map is allocated with the shape of this count map
for session_id in range(n_sessions):
    session_fr = fr[session_id]
    session_ts = behav_ts[session_id]

    n_session_neurons = session_fr.shape[0]
    session_maps = np.zeros((n_session_neurons, map_dims[0], map_dims[1]))
    for neuron_id, unit_fr in enumerate(session_fr):
        session_maps[neuron_id] = spatial_funcs.firing_rate_2_rate_map(
            unit_fr,
            x=session_ts['x'],
            y=session_ts['y'],
            x_edges=task_params.x_bin_edges_,
            y_edges=task_params.y_bin_edges_,
        )
    neural_maps[session_id] = session_maps

In [213]:
# The neuron slider spans the largest per-session neuron count found in `neural_maps`,
# so no recorded unit is unreachable from the widget.
@interact(session_id=(0,n_sessions-1),
          neuron_id=widgets.IntSlider(min=0,
                                      max=max(m.shape[0] for m in neural_maps.values()) - 1,
                                      step=1, value=0))
def _maps(session_id, neuron_id):
    """
    Plot the rate map of one neuron of one session.

    Sessions differ in their number of neurons; when `neuron_id` is beyond the selected
    session's neurons a message is printed and nothing is drawn.
    """
    n_session_neurons = neural_maps[session_id].shape[0]
    print(f'num neurons = {n_session_neurons}')
    if neuron_id < n_session_neurons:
        f,ax = plt.subplots(figsize=(7,6))
        plot_map(neural_maps[session_id][neuron_id], ax=ax)
        ax.set_title('Rate Map')
    else:
        print(f'neuron_id {neuron_id} is not available for session {session_id}')

interactive(children=(IntSlider(value=23, description='session_id', max=47), IntSlider(value=0, description='n…

### Save Data

In [218]:
import pickle

# output file lives next to the source h5 file
data_path = Path('/home/alexgonzalez/Documents/data/butler_hardcastle')
file_name = 'grid_cell_data.pickle'

# bundle every derived per-session structure into a single dict and pickle it
data = dict(behav_ts=behav_ts,
            behav_maps=behav_maps,
            neural_maps=neural_maps,
            spikes=spikes,
            fr=fr)
with open(data_path / file_name, 'wb') as f:
    pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)


In [219]:
# display the firing-rate dict: session_id -> array of shape (n_units, n_timebins)
fr

{0: array([[0., 0., 0., ..., 0., 0., 0.]]),
 1: array([[50.        , 23.45491503,  6.90983006, ...,  0.        ,
          0.        ,  0.        ]]),
 2: array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 3: array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 4: array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 5: array([[0.        , 0.        , 0.        , ..., 0.95491503, 0.        ,
         0.        ],
        [0.        , 0.        , 0.95491503, ..., 0.        , 0.        ,
         0.        ],
        [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
         0.        ],
        [0.        , 0.        , 0.        , ..., 0.        , 0.        ,
         0.        ]]),
 6: array([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]]),
 7: array([[0.        , 0.        , 0.95491503, ..., 0.        , 0.        ,
         0.        ]]),
 8: array([[0., 0., 0., ...