#  running PCA + saving results

**results dictionary:**
- dict
    - session
        - trial df
        - full
            - ['evals', 'evecs', 'PR', 'SS stim', 'SS time', 'project', 'stim labels', 'time labels', 'state labels', 'ss state']
        - state 0/1
            - ['evals', 'evecs', 'PR', 'SS stim', 'SS time', 'project', 'stim labels', 'time labels']

# PARAMS
    session_id = 1139846596
    amplitude_cutoff_maximum = 0.1
    presence_ratio_minimum = 0.9
    isi_violations_maximum = 0.5
    region = 'VISp'
    post_stim_dur = 0 (or 0.5 for full 750ms)

# setup

In [1]:
import warnings
warnings.filterwarnings("ignore")
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
%matplotlib inline
import plotly.express as px
from sklearn.metrics import silhouette_score
from data.load_data import *
from PCA_utils import *
from collections import defaultdict
import time
import pickle
%load_ext autoreload
%autoreload 2

matching trials across contexts
- `trials_id` in `trial_df` has unique id for each trial type in presentation order; repeats for each condition i.e. **relative**
- made an additional index of unique trial identifiers (`abs_trial_id`), different for active vs passive; called **absolute**

TODOs:
- run analysis on active vs passive as states (skip HMM states import)
- run analysis on all neurons (not just VISp)
    - FD wants VISpme2
- plots / interpretations
- cell type weights


ipad todos
- AvP: SS plot,

In [16]:
# os.listdir returns entries in arbitrary order; sort so that sessions are
# always processed (and stored in the results dict) in the same order
HMM_files = sorted(os.listdir('../analysis_data'))
# keep files whose name contains both 'AP' and 'new'
full_HMMs = [x for x in HMM_files if 'AP' in x and 'new' in x]  # HMM run on active and passive together

In [18]:
st = time.time()

cache = load_cache_behavior_neuropixel()

# session id -> output of create_session_dict, collected over every HMM file
full_dict = defaultdict(dict)

for HMM_filename in full_HMMs:

    # the session id is the 10-character slice [2:12] of the file name
    session_id = HMM_filename[2:12]
    print('working on session', session_id, '...')

    session = cache.get_ecephys_session(session_id)

    # HMM output for this session, indexed by stimulus presentation id
    hmm_path = os.path.join('../analysis_data', HMM_filename)
    trial_df = pd.read_feather(hmm_path).set_index('stimulus_presentations_id')
    # store states as strings so classifier functions don't get confused
    trial_df['state'] = trial_df['state'].map(lambda s: str(int(s)))

    # turn relative trial ids (which repeat for active + passive) into
    # absolute/unique ones: twice as many ids as relative ids, each id
    # repeated for 4 consecutive rows
    n_relative_ids = trial_df.trials_id.unique().size
    trial_df['abs_trial_id'] = np.repeat(np.arange(2 * n_relative_ids), 4)

    # get FRs
    counts_df = get_spikes(cache, session, trial_df, region='VISp', post_stim_dur=0)

    # do PCA
    session_dict = create_session_dict(trial_df, counts_df)
    full_dict[session_id] = session_dict

print('time elapsed:', time.time()-st)

working on session 1053925378 ...
working on session 1064415305 ...
working on session 1081090969 ...
working on session 1108334384 ...
working on session 1115356973 ...
time elapsed: 366.2976927757263


In [19]:
with open('HMM_dict_VISp_new.pkl', 'wb') as file:
    pickle.dump(full_dict, file)

In [None]:
# same but states are active and passive instead of HMM-defined
st = time.time()

cache = load_cache_behavior_neuropixel()
# NOTE: this rebinds `full_dict`, replacing the HMM-state results built by the
# HMM loop in the current kernel
full_dict = defaultdict(dict)  # dictionary to hold data from all sessions

# 'AP' alone may match more than one file per session, so drop repeated ids
# (dict.fromkeys keeps first-seen order) to avoid processing a session twice
session_ids = list(dict.fromkeys(x[2:12] for x in HMM_files if 'AP' in x))

for session_id in session_ids:

    print('working on session', session_id, '...')
    session = cache.get_ecephys_session(session_id)

    trial_df = get_trial_df(session)
    # state '0' = active, '1' = passive (logical NOT of the `active` column)
    trial_df['state'] = (~trial_df.active).apply(int).apply(str)  # <- here is where states are defined

    # turn relative trial ids (repeat for active+passive) into absolute/unique
    num_trial_ids = len(trial_df.trials_id.unique())
    unique_trial_ids = np.arange(0,num_trial_ids*2)
    # each id covers 4 consecutive rows; np.repeat raises if len(trial_df) != 8 * num_trial_ids
    trial_df['abs_trial_id'] = np.repeat(unique_trial_ids, 4)

    # get FRs
    counts_df = get_spikes(cache, session, trial_df, region='VISp', post_stim_dur=0)

    # do PCA
    session_dict = create_session_dict(trial_df, counts_df)

    full_dict[session_id] = session_dict

print('time elapsed:', time.time()-st)

working on session 1053925378 ...
working on session 1064415305 ...
working on session 1081090969 ...
working on session 1108334384 ...
working on session 1115356973 ...


# expanded code
(before everything was function-ized and put in PCA_utils.py)

### spikes

In [None]:
# get spike times for all units recorded
spike_times = session.spike_times # dict

# get unit metadata for this session + apply quality metrics
# PARAMS
all_units = session.get_units(
    amplitude_cutoff_maximum = 0.1, 
    presence_ratio_minimum = 0.9,
    isi_violations_maximum = 0.5
)

# merge to channel data to match units to brain regions
channels = cache.get_channel_table()
unit_channels = all_units.merge(channels, left_on='peak_channel_id', right_index=True)

# to filter by region
units_df = unit_channels.loc[unit_channels.structure_acronym=='VISp'] #PARAMS

for each trial, count spikes in stim windows and average across the 4 stims

design matrix will look like this for each condition/state

| | unit 1 | unit 2 | unit 3 | ... | unit N |
|--|--|--|--|--|--|
| trial 1 |
| trial 2 |
| trial 3 |
| ... |
| trial M |

In [None]:
post_stim_dur = 0 #PARAMS

In [None]:
# get spike counts within each stim pres window
spike_mat = np.zeros([len(trial_df), len(units_df)])

for i,unit in enumerate(units_df.index): #for each neuron 1:N...
    spikes = spike_times[unit]
    counts = [] #initialize column vector

    for start,end in zip(trial_df.start_time, trial_df.end_time): #for each stim presentation...
        startInd = np.searchsorted(spikes, start)
        endInd = np.searchsorted(spikes, end+post_stim_dur)
        rel_spike_times = spikes[startInd:endInd]-start #relative spike times in this window
        count = len(rel_spike_times)
        counts.append(count) #append spike counts for this stim pres
        
    spike_mat[:,i] = counts #add column vector of FRs for this neuron to spike matrix
%matplotlib inline
plt.figure(figsize=(14,2))
plt.plot(spike_mat, alpha=.5);

In [None]:
counts_df = pd.DataFrame(
    data = spike_mat,
    index = trial_df.index,
    columns = units_df.index
)
counts_df['abs_trial_id'] = trial_df.abs_trial_id

## PCA

In [None]:
session_dict = create_session_dict(trial_df, counts_df)

In [None]:
plot_state_pcas(session_dict)

In [None]:
plot_full_pca(session_dict)

### old functions

In [None]:
def get_dm(counts_df, state_stims, agg_over_tr=True):
    '''Build a max-normalized design matrix from spike counts.

    Args:
        counts_df: DataFrame of spike counts; must contain an
            'abs_trial_id' column when agg_over_tr is True.
        state_stims: row labels of counts_df to keep (passed to .loc).
        agg_over_tr: if True, average the selected rows within each
            'abs_trial_id' before normalizing; if False, normalize the
            selected rows as they are (every column is kept).

    Returns:
        DataFrame in which each column is divided by its own maximum.
    '''
    selected = counts_df.loc[state_stims]
    if agg_over_tr:
        # one row per trial: mean over the rows sharing an abs_trial_id
        selected = selected.groupby('abs_trial_id').agg('mean')
    return selected / selected.max()

In [None]:
def do_pca(dm):
    '''PCA via eigendecomposition of the covariance matrix of `dm`.

    Args:
        dm: array-like, features X samples (each row is one feature,
            each column one observation), as expected by np.cov.

    Returns:
        evals: (n_features,) real eigenvalues sorted from largest to
            smallest (variance along each principal component).
        evecs: (n_features, n_features) array; column k is the unit-norm
            eigenvector belonging to evals[k].
    '''
    # np.cov returns a 0-d array for a single feature; keep it a matrix
    cov_mat = np.atleast_2d(np.cov(dm))

    # a covariance matrix is symmetric, so use eigh: it always returns real
    # values (np.linalg.eig may return complex ones and gives no ordering)
    evals, evecs = np.linalg.eigh(cov_mat)

    # eigh sorts ascending; callers take evecs[:, 0:3] as the leading PCs and
    # cumsum(evals) as variance explained, so put the largest first
    order = np.argsort(evals)[::-1]
    return evals[order], evecs[:, order]

In [None]:
def participation_ratio(evals):
    '''Participation ratio of a spectrum: (sum of evals)^2 / sum of evals^2.

    Args:
        evals: NumPy array of eigenvalues.

    Returns:
        The ratio as a scalar.
    '''
    total = np.sum(evals)
    sum_of_squares = np.sum(evals**2)
    return (total**2) / sum_of_squares

In [None]:
def create_session_dict(trial_df, counts_df):
    '''Run PCA separately on the trials of each state and collect the results.

    Args:
        trial_df: DataFrame with (at least) the columns 'state',
            'abs_trial_id', 'image_int', 'image_name' and 'active'; its
            index labels are used to select rows of counts_df.
        counts_df: spike-count DataFrame passed on to get_dm (must contain
            an 'abs_trial_id' column).

    Returns:
        dict with
            'trial df' : trial_df itself (top level, as read by
                plot_state_pcas), and
            one entry per unique value of trial_df.state, each a dict with
                keys 'evals', 'evecs', 'PR', 'SS stim', 'SS time',
                'project', 'stim labels', 'time labels', '% active',
                'trial df'.
    '''
    # top-level trial table: plot_state_pcas looks up session_dict['trial df']
    session_dict = {'trial df': trial_df}

    states = trial_df.state.unique()

    for state in states:
        state_inds = trial_df.loc[trial_df.state==state].index
        # trial-averaged, max-normalized design matrix (trials X units)
        state_dm = get_dm(counts_df, state_inds)
        evals, evecs = do_pca(state_dm.T)

        pr = participation_ratio(evals)

        # mean-center the data + project onto the first 3 columns of evecs
        state_cent = state_dm - state_dm.mean()
        project = np.dot(np.transpose(evecs[:,0:3]), state_cent.T)

        # stim analysis: one image label per trial, taken from the second
        # level of the (abs_trial_id, image) group index
        state_trials = trial_df.loc[state_inds]
        int_labels = state_trials.groupby(['abs_trial_id', 'image_int']).agg('max').index.get_level_values(1).to_numpy()
        img_labels = state_trials.groupby(['abs_trial_id', 'image_name']).agg('max').index.get_level_values(1).to_numpy()
        ss_stim = silhouette_score(X=state_dm,labels=img_labels)

        # time analysis
        trial_labels = state_trials.groupby('abs_trial_id').agg('max').index.to_numpy()
        trial_splits = np.array_split(trial_labels, 3) # split time into 1st, 2nd, 3rd chunks for classification
        splits_labels = []
        for split in range(3):
            splits_labels += ([str(split)] * len(trial_splits[split]))
        ss_time = silhouette_score(X=state_dm,labels=splits_labels)

        # fraction of this state's rows that come from the active block
        state_active = state_trials.active
        prop_active = sum(state_active) / len(state_active)

        session_dict[state] = {
            'evals' : evals,
            'evecs' : evecs,
            'PR' : pr,
            'SS stim' : ss_stim,
            'SS time' : ss_time,
            'project' : project,
            'stim labels' : int_labels,
            'time labels' : trial_labels,
            '% active' : prop_active,
            'trial df' : trial_df  # kept for backward compatibility
        }

    return session_dict

In [None]:
def plot_state_pcas(session_dict, session_id=None):
    '''Scatter PC1 vs PC2 for every state, colored by stimulus and by time.

    Draws two figures with one panel per state: the first colored by
    'stim labels', the second by 'time labels'.

    Args:
        session_dict: dict holding one entry per state (each with keys
            'project', 'stim labels', 'time labels') and, optionally, a
            top-level 'trial df' whose `state` column lists the states.
        session_id: label for the first figure's title. Defaults to the
            notebook-level `session_id` variable, if one exists.
    '''
    if 'trial df' in session_dict:
        states = session_dict['trial df'].state.unique()
    else:
        # no top-level trial table: treat every key except 'full' as a state
        states = [key for key in session_dict if key != 'full']

    if session_id is None:
        # backward compatible with the old reliance on a global variable
        session_id = globals().get('session_id', '')

    # backend selection (%matplotlib ...) is left to the notebook; a line
    # magic inside a function body is not valid Python.
    # squeeze=False keeps the axes 2-D so indexing also works for one state.
    f1,ax1 = plt.subplots(1,len(states), figsize=(8,3.2), squeeze=False)
    f2,ax2 = plt.subplots(1,len(states), figsize=(8,3), squeeze=False)

    # panels are ordered by the integer value of the state label
    for col, state in enumerate(sorted(states, key=int)):

        s = session_dict[state]

        ax1[0, col].scatter(s['project'][0], s['project'][1], c=s['stim labels'])
        ax1[0, col].set(xlabel='PC1',ylabel='PC2',title='state '+state)

        ax2[0, col].scatter(s['project'][0], s['project'][1], c=s['time labels'])
        ax2[0, col].set(xlabel='PC1',ylabel='PC2',title='state '+state)

    f1.suptitle(str(session_id) + '\ncolored by stimulus')
    f1.tight_layout()
    f2.suptitle('colored by time')
    f2.tight_layout()

In [None]:
def plot_full_pca(session_dict, session_id=None):
    '''Scatter PC1 vs PC2 of the full-session PCA in three panels.

    The panels are colored by 'stim labels', 'time labels' and
    'state labels' of session_dict['full'].

    Args:
        session_dict: dict with a 'full' entry holding 'project',
            'stim labels', 'time labels' and 'state labels'.
        session_id: figure title. Defaults to the notebook-level
            `session_id` variable, if one exists.
    '''
    if session_id is None:
        # backward compatible with the old reliance on a global variable
        session_id = globals().get('session_id', '')

    s = session_dict['full']
    # exactly three panels are drawn below, so create three axes (the old
    # code sized the grid from an undefined `states` variable)
    f,ax = plt.subplots(1,3, figsize=(12,3.2))

    ax[0].scatter(s['project'][0], s['project'][1], c=s['stim labels'])
    ax[0].set(xlabel='PC1',ylabel='PC2',title='colored by stim')

    ax[1].scatter(s['project'][0], s['project'][1], c=s['time labels'])
    ax[1].set(xlabel='PC1',ylabel='PC2',title='colored by trial')

    ax[2].scatter(s['project'][0], s['project'][1], c=s['state labels'])
    ax[2].set(xlabel='PC1',ylabel='PC2',title='colored by state')

    f.suptitle(str(session_id))
    f.tight_layout()

## split by stim pres within trial

for each trial, count spikes in stim windows
each row will be 1 stim presentation (i.e. dims will be like stim df)

design matrix will look like this for each condition/state

| | unit 1 | unit 2 | unit 3 | ... | unit N |
|--|--|--|--|--|--|
| trial 1 stim 1|
| trial 1 stim 2 |
| trial 1 stim 3 |
| trial 1 stim 4 |
| trial 2 stim 1 |
| ... |
| trial M stim 4 |

**i.e. spike_mat, before averaging within trial**

In [None]:
# get_dm(..., agg_over_tr=False) keeps every column of counts_df, so each
# per-stim design matrix has counts_df.shape[1] features
n_features = counts_df.shape[1]
all_evals = np.zeros((4, n_features)) #stim X features
all_evecs = np.zeros((4, n_features, n_features))

for stim in range(4):
    # every 4th row, starting at `stim`: the stim-th presentation of each trial
    stim_inds = counts_df.index[stim::4]
    stim_dm = get_dm(counts_df, stim_inds, agg_over_tr=False)
    evals, evecs = do_pca(stim_dm.T)

    all_evals[stim,:] = evals
    all_evecs[stim,:,:] = evecs

    # image shown on each selected presentation
    # (assumes counts_df and trial_df share the same index -- TODO confirm)
    img_labels = trial_df.loc[stim_inds, 'image_name'].to_numpy()

    print('stim', stim+1)
    print('    participation ratio:', participation_ratio(evals))
    print('    silhouette score:', silhouette_score(X=stim_dm,labels=img_labels))

In [None]:
%matplotlib inline
f,ax = plt.subplots(1,1)

for stim in range(4):
    frac_expl = np.cumsum(all_evals[stim])/np.sum(all_evals[stim])
    ax.plot(frac_expl)
    ax.set(xlabel='Number of PCs', ylabel='Fraction of Variance Explained')
    ax.legend(labels=[1,2,3,4])

In [None]:
%matplotlib widget
f = plt.figure()
# ax = f.add_subplot(projection='3d')

for stim in range(4):
    stim_cent = stim_dm - stim_dm.mean() # mean-center the data
    project = np.dot(np.transpose(all_evecs[stim,:,0:3]), stim_cent.T) # project into PC space
    ax = f.add_subplot(2, 2, stim+1, projection='3d')
    ax.scatter(project[0], project[1], project[2], c=list(stims)*2);
    ax.set(xlabel='PC1', ylabel='PC2', zlabel='PC3', title='colored by stim');
plt.show()

In [None]:
np.zeros(counts_df.shape).shape

In [None]:
full_stim_dm = np.zeros(counts_df.shape)
for stim in range(4):
    stim_inds = counts_df.index[stim::4]
    stim_dm = get_dm(counts_df, stim_inds, agg_over_tr=False)
    
    start_ind = stim * len(stim_inds)
    end_ind = stim * len(stim_inds) + len(stim_inds)
    
    full_stim_dm[start_ind:end_ind, :] = stim_dm

In [None]:
evals, evecs = do_pca(full_stim_dm.T)
participation_ratio(evals)

In [None]:
evals.shape

In [None]:
evecs.shape

In [None]:
# mean-center each feature (column). full_stim_dm is a NumPy array, so
# .mean() without axis=0 would subtract one grand mean from every entry.
stim_cent = full_stim_dm - full_stim_dm.mean(axis=0)
project = np.dot(np.transpose(evecs[:,0:3]), stim_cent.T) # project into PC space

In [None]:
import seaborn as sns
from matplotlib.colors import ListedColormap

In [None]:
cmap = ListedColormap(sns.color_palette('bright'))

In [None]:
project.shape

In [None]:
# plot and color by stim
%matplotlib widget
f = plt.figure()
ax = f.add_subplot(projection='3d')

stim_cent = full_stim_dm - full_stim_dm.mean() # mean-center the data
project = np.dot(np.transpose(evecs[:,0:3]), stim_cent.T) # project into PC space
ax.scatter(project[0], project[1], project[2], c=np.repeat(np.arange(1,5), len(stim_inds)));
ax.set(xlabel='PC1', ylabel='PC2', zlabel='PC3', title='colored by stim order');
plt.show()

In [None]:
silhouette_score(
    X=full_stim_dm,
    labels=np.repeat(np.arange(2,6), len(stim_inds)))

## separate DM for each state

In [None]:
# stim presentation ids for active vs passive
state1_stims = trial_df.index[trial_df.active] #active
state2_stims = trial_df.index[~trial_df.active] #passive

In [None]:
# get_dm groups by the 'abs_trial_id' column and averages EVERY other column,
# so metadata such as 'trials_id' must not be a column of the counts table:
# it would enter the design matrix (and the PCA) as if it were a unit.
# errors='ignore' also cleans up a 'trials_id' column left by an earlier run.
unit_counts = counts_df.drop(columns='trials_id', errors='ignore')
state1_dm = get_dm(unit_counts, state1_stims)
state2_dm = get_dm(unit_counts, state2_stims)

In [None]:
state1_dm.shape

In [None]:
np.cov(state1_dm).shape

In [None]:
evals1, evecs1 = do_pca(state1_dm.T)
evals2, evecs2 = do_pca(state2_dm.T)

### participation ratio
"effective dimensionality"

In [None]:
participation_ratio(evals1)

In [None]:
participation_ratio(evals2)

### plots

In [None]:
# variance explained
%matplotlib inline
plt.plot(evals1)
plt.plot(evals2)
plt.xlim((-.5,20));
plt.show()

In [None]:
# variance explained curves (cumulative)
frac_expl1 = np.cumsum(evals1)/np.sum(evals1)
frac_expl2 = np.cumsum(evals2)/np.sum(evals2)

%matplotlib inline
plt.plot(frac_expl1)
plt.plot(frac_expl2)
plt.xlabel('Number of PCs')
plt.ylabel('Fraction of Variance Explained')
# plt.xlim((-1,80));
plt.show()

In [None]:
frac = .90
# np.where(...)[0].min() is the 0-based index of the first PC at which the
# cumulative fraction exceeds `frac`; the number of PCs needed is that index + 1
n_pcs_active = np.where(frac_expl1>frac)[0].min() + 1
n_pcs_passive = np.where(frac_expl2>frac)[0].min() + 1
print(n_pcs_active, 'PCs needed for', frac*100, '% for active condition')
print(n_pcs_passive, 'PCs needed for', frac*100, '% for passive condition')

In [None]:
# state 1 vs state 2 variance explained across all PCs
%matplotlib inline
plt.figure(figsize=(10,5))
plt.plot(frac_expl2 - frac_expl1)
plt.axhline(0, color='k', linestyle='--');
plt.show()
# plt.xlim((-.5,11));

In [None]:
%matplotlib inline
f,ax = plt.subplots(3,1, figsize=(5,5))
for PC in range(3):
    ax[PC].plot(abs(evecs1[:,PC]))
    ax[PC].plot(abs(evecs2[:,PC]))
    ax[PC].set(ylabel="contribution", title=PC);
plt.tight_layout()
plt.xlabel('unit')
plt.show()

In [None]:
# control/check: are high-weighted neurons also high-firing neurons?
%matplotlib inline
plt.scatter(spike_mat.mean(axis=0), abs(evecs1[:,1]), color='b')
plt.scatter(spike_mat.mean(axis=0), abs(evecs2[:,1]), color='b', alpha=.25)
plt.xlabel('mean FR')
plt.ylabel('PC1 weight');

In [None]:
%matplotlib widget

In [None]:
# get list of images in order of presentation

stims = trial_df.loc[state1_stims].groupby('abs_trial_id').agg('max').image_int.values

stims_img = trial_df.loc[state1_stims].groupby(['abs_trial_id', 'image_name']
                                          ).agg('max').index.get_level_values(1).to_numpy()

In [None]:
# mean-center the data
state1_cent = state1_dm - state1_dm.mean()
state2_cent = state2_dm - state2_dm.mean()

# project into PC space
project1 = np.dot(np.transpose(evecs1[:,0:3]), state1_cent.T)
project2 = np.dot(np.transpose(evecs2[:,0:3]), state2_cent.T)

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(project1[0], project1[1], project1[2], c=stims);
ax.set(xlabel='PC1', ylabel='PC2', zlabel='PC3', title='active; colored by stim');
plt.show()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(project2[0], project2[1], project2[2], c=stims)
ax.set(xlabel='PC1', ylabel='PC2', zlabel='PC3', title='passive; colored by stim');
plt.show()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(project1[0], project1[1], project1[2], c=np.arange(0,len(stims)));
ax.set(xlabel='PC1', ylabel='PC2', zlabel='PC3', title='active; colored by trial number');
plt.show()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(project2[0], project2[1], project2[2], c=np.arange(0,len(stims)))
ax.set(xlabel='PC1', ylabel='PC2', zlabel='PC3', title='passive; colored by trial number');
plt.show()

## full DM

In [None]:
full_dm = pd.concat([state1_dm, state2_dm])

In [None]:
evals_full, evecs_full = do_pca(full_dm.T)

In [None]:
plt.figure()
plt.plot(evals_full)
plt.show()

In [None]:
plt.figure()
plt.plot(np.cumsum(evals_full)/np.sum(evals_full));
plt.show()

In [None]:
frac_expl_full = np.cumsum(evals_full)/np.sum(evals_full)

In [None]:
# first 0-based index above `frac`, +1 to report a number of PCs
print(np.where(frac_expl_full>frac)[0].min() + 1, 'PCs needed for', frac*100, '% for full (active+passive)')

In [None]:
full_cent = full_dm - full_dm.mean()
project_full = np.dot(np.transpose(evecs_full[:,0:3]), full_cent.T)

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(project_full[0], project_full[1], project_full[2], c=list(stims)*2);
ax.set(xlabel='PC1', ylabel='PC2', zlabel='PC3', title='full; colored by stim');
plt.show()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(project_full[0], project_full[1], project_full[2], c=full_dm.index);
ax.set(xlabel='PC1', ylabel='PC2', zlabel='PC3', title='full; colored by time');
plt.show()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# full_dm stacks state1_dm on top of state2_dm, so label rows by block length
# (np.repeat needs integer counts; len(...) / 4 is a float and raises)
state_colors = np.repeat([0,1], [len(state1_dm), len(state2_dm)])
ax.scatter(project_full[0], project_full[1], project_full[2], c=state_colors);
ax.set(xlabel='PC1', ylabel='PC2', zlabel='PC3', title='full; colored by state');
plt.show()

## dim metrics

### silhouette score
+2 to stims bc something gets weird when 0 and 1 are included as classes; can use image names instead

In [None]:
silhouette_score(
    X=state1_dm,
    labels=stims_img)

In [None]:
silhouette_score(
    X=state2_dm,
    labels=stims_img)

In [None]:
silhouette_score(
    X=full_dm,
    labels=all_stims+2)

### decoding

kozleo

"For sample decoding experiments, we first smoothed trials (as above, with 10ms Gaussian kernel), and then took average rates per neuron in 50ms time bins and attempted to decode sample type from the population activity at each time point. For our classifier, we used a linear Support Vector Machine and standard regularization (C = 1 in SciKit Learn). For each session we used a 10-fold cross-validation to calculate decoding accuracy. (That is, we held out 10% of trials as a test set, and trained on the remaining 90% of the data. We did this for 10 non-overlapping test sets, and then took the average test accuracy). We performed a decoding analysis on each session separately, and then took the average and standard error across sessions for our result."

In [None]:
from sklearn.model_selection import KFold
from sklearn.svm import SVC

In [None]:
# ran for each timepoint

In [None]:
def run_svc(y, X, svc_type="linear", n_split=10, C=1):
    """
        Train SVM classifier on data and return test accuracy.

    Args:
        y: (array-like) Labels with dims (samples,)
        X: (array-like) Covariates with dims (samples, features);
            features are typically some function of neural data
        svc_type: (str) Option for SVC type. Options are "linear" or "rbf".
            Default is "linear".
        n_split: (int) Number of splits for cross-validation. Default is 10.
        C: (float) Regularization parameter for SVC. Default is 1.

    Returns:
        test_acc: (float) Fraction of held-out samples predicted correctly,
            pooled over all cross-validation folds.

    Raises:
        ValueError: if svc_type is neither "linear" nor "rbf".
    """
    if svc_type == "linear":
        svc_kwargs = {"kernel": "linear", "C": C}
    elif svc_type == "rbf":
        svc_kwargs = {"C": C}  # SVC's default kernel
    else:
        # fail early instead of hitting an unbound `clf` inside the loop
        raise ValueError(
            'svc_type must be "linear" or "rbf", got {!r}'.format(svc_type))

    # positional indexing below needs plain arrays (a pandas Series would be
    # indexed by label)
    X = np.asarray(X)
    y = np.asarray(y)

    # fixed seed -> the same folds on every call
    kf = KFold(n_splits=n_split, shuffle=True, random_state=0)
    y_true = []
    y_pred = []

    for train_index, test_index in kf.split(X):
        clf = SVC(**svc_kwargs)
        clf.fit(X[train_index], y[train_index])

        y_true.append(y[test_index])
        y_pred.append(clf.predict(X[test_index]))

    true = np.concatenate(y_true)
    pred = np.concatenate(y_pred)

    test_acc = (pred == true).sum() / true.size
    return test_acc

In [None]:
all_stims = trial_df.groupby('abs_trial_id').agg('max').image_int.values

print('decoding stim:')
run_svc(
    X=full_dm.to_numpy(),
    y=all_stims+2)

In [None]:
# one label per TRIAL, in the row order of full_dm (which is indexed by
# abs_trial_id); trial_df has one row per stimulus presentation
trial_is_active = (trial_df.groupby('abs_trial_id')['active'].first()
                   .loc[full_dm.index].to_numpy().astype(bool))
# 2 = active, 3 = passive. np.where avoids indexing the array with its own
# integer values, which only overwrote a few elements.
states = np.where(trial_is_active, 2, 3)

print('decoding state:')
run_svc(
    X=full_dm.to_numpy(),
    y=states)

In [None]:
# trial identifier of every row of the design matrix (its abs_trial_id index)
trials = full_dm.index.to_numpy()
coded_trials = np.zeros_like(trials)
# label each trial with the upper edge of its block of 100 (100, 200, ...);
# going from the last block down lets smaller caps overwrite larger ones
for t in reversed(range(0, int(np.ceil(len(trials)/100)))):
    cap = t*100+100
    coded_trials[trials<=cap] = cap
print('decoding time (blocks of 100 trials):')
run_svc(
    X=full_dm.to_numpy(),
    y=coded_trials)

### other

In [None]:
# visualize how many trials we're actually using
all_stims = session.stimulus_presentations
all_stims['analyzed'] = 0
all_stims.loc[trial_df.index, 'analyzed'] = 1
plt.figure(figsize=(20,2))
plt.plot(all_stims.analyzed);

ideas
- run on gray screens instead of stim windows
- run stim 1 vs stim 2 vs stim 3 vs stim 4
- project passive onto active PC space + vice versa
- UMAP

noise corr stuff says decodeability might not change but direction of noise fluctuations is changing

In [None]:
import xarray as xr
xr.DataArray(
    data = spike_mat,
    dims = ('trial_num', 'unit')
)

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
stimulus_presentations = session.stimulus_presentations
change_times = stimulus_presentations[stimulus_presentations['active']&
                            stimulus_presentations['is_change']]['start_time'].values

In [None]:
#first let's sort our units by depth
unit_channels = unit_channels.sort_values('probe_vertical_position', ascending=False)

#now we'll filter them
good_unit_filter = ((unit_channels['snr']>1)&
                    (unit_channels['isi_violations']<1)&
                    (unit_channels['firing_rate']>0.1))

good_units = unit_channels.loc[good_unit_filter]
spike_times = session.spike_times

In [None]:
#Convenience function to compute the PSTH
def makePSTH(spikes, startTimes, windowDur, binSize=0.001):
    """Peri-stimulus time histogram averaged over windows.

    Args:
        spikes: sorted array of spike times.
        startTimes: NumPy array with the start time of each window.
        windowDur: duration of each window.
        binSize: width of the histogram bins.

    Returns:
        (rate, bins): mean spike count per bin across windows divided by
        binSize, and the bin edges relative to the window start.
    """
    bins = np.arange(0, windowDur + binSize, binSize)
    totals = np.zeros(bins.size - 1)
    for onset in startTimes:
        # slice of spikes that fall inside [onset, onset + windowDur)
        lo, hi = np.searchsorted(spikes, [onset, onset + windowDur])
        hist, _ = np.histogram(spikes[lo:hi] - onset, bins)
        totals += hist

    rate = totals / startTimes.size / binSize
    return rate, bins

In [None]:
#Here's where we loop through the units in our area of interest and compute their PSTHs
area_of_interest = 'VISp'
area_change_responses = []
area_units = good_units[good_units['structure_acronym']==area_of_interest]
time_before_change = 1
duration = 2.5
for iu, unit in area_units.iterrows():
    unit_spike_times = spike_times[iu]
    unit_change_response, bins = makePSTH(unit_spike_times,
                                          change_times-time_before_change,
                                          duration, binSize=0.01)
    area_change_responses.append(unit_change_response)
area_change_responses = np.array(area_change_responses)

In [None]:
#Plot the results
fig, ax = plt.subplots(1,2)
fig.set_size_inches([12,4])

clims = [np.percentile(area_change_responses, p) for p in (0.1,99.9)]
im = ax[0].imshow(area_change_responses, clim=clims)
ax[0].set_title('Active Change Responses for {}'.format(area_of_interest))
ax[0].set_ylabel('Unit number, sorted by depth')
ax[0].set_xlabel('Time from change (s)')
ax[0].set_xticks(np.arange(0, bins.size-1, 20))
_ = ax[0].set_xticklabels(np.round(bins[:-1:20]-time_before_change, 2))

ax[1].plot(bins[:-1]-time_before_change, np.mean(area_change_responses, axis=0), 'k')
ax[1].set_title('{} population active change response (n={})'\
                .format(area_of_interest, area_change_responses.shape[0]))
ax[1].set_xlabel('Time from change (s)')
ax[1].set_ylabel('Firing Rate')

# old code

## setup

    git status
    git pull origin main
    git push origin chloe

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import seaborn as sns
import pandas as pd
pd.set_option('display.max_columns', None)

from data.load_data import *

In [None]:
cache = load_cache_behavior_neuropixel()

__vbn precomputed tables__: 

Inside, you will find a few units tables with precomputed metrics that might save you some time. For example, in the master_units table, there's a column called 'structure_acronym_with_layer' that includes the CCF layer assignment for all cortical units. The 'unit_id' column in this table is the same ID as the index of the units table you get back from the SDK.

^^ not sure how to load

### `session` data
`session.trials` : raw; stim info; start and stop times, image IDs, licks, outcome; 'trial'=stim

`behavior_data` : constructed from `session.running_speed.timestamps`, `session.running_speed.speed`, `session.eye_tracking.timestamps`, `session.eye_tracking.pupil_area`

In [None]:
example_sessions = [1139846596,  1124507277, 1069461581 ]
session_id = example_sessions[0]

In [None]:
# session is the variable that has all the data we need
session = cache.get_ecephys_session(session_id)

## stim

In [None]:
stim_df = session.stimulus_presentations
pres_df = stim_df.loc[(stim_df.stimulus_block==0)|(stim_df.stimulus_block==5)]
change_inds = pres_df.loc[pres_df.is_change==True].index

In [None]:
change_inds

In [None]:
images = sorted( pres_df.image_name.unique() )

image_to_int = dict()
for i, image in enumerate(images):
    if image=='omitted':
        image_to_int[image] = -10
    else:
        image_to_int[image] = i

pres_df['image_int'] = pres_df.image_name.apply( lambda img: image_to_int[img] )

pres_df.head()

In [None]:
# get indices of last 4 images before change
trials = np.zeros((len(change_inds), 4))
for trial,ind in enumerate(change_inds):
    trials[trial,:] = np.arange(ind-4,ind)

In [None]:
trials

In [None]:
# remove cases where 1 of the 4 stims was removed
trials_to_keep = []
for trial in trials:
    if pres_df.loc[trial].omitted.sum() == 0: #if no stims omitted
        trials_to_keep.append(trial)

trial_inds_arr = np.vstack(trials_to_keep)
trial_inds_vec = np.concatenate(trials_to_keep)

In [None]:
pres_df.loc[trial_inds_vec]

In [None]:
df = pres_df.loc[trial_inds_vec].filter([
    'active', 
    'trials_id',
    'start_time',
    'end_time',
    'image_int',
    'image_name'
])

In [None]:
df.loc[df.active==True]

In [None]:
df.loc[df.active==False]

## behavior

In [None]:
def make_behavior_table(session, df):
    '''
    Mean running speed and mean pupil area within each row's time window.

    Args:
        session: object exposing `running_speed` (with `timestamps` and
            `speed`) and `eye_tracking` (with `timestamps` and `pupil_area`).
        df: stim/pres df with selected trials to analyze (4 before change);
            needs `start_time` and `end_time` columns.

    Returns:
        DataFrame with one row per row of df and the columns 'Mean speed'
        and 'Mean pupil area'. Missing pupil values are imputed from the
        neighboring rows.
    '''

    # Get timestamps corresponding to go trials
    trial_start = df.start_time
    trial_stop = df.end_time

    # Get running speed and corresponding timestamps
    running_time = session.running_speed.timestamps
    running_speed = session.running_speed.speed
    mean_speed = [np.nanmean(running_speed[np.logical_and(s1 <= running_time, running_time <= s2)]) for s1, s2 in zip(trial_start, trial_stop)]

    # Get pupil size and corresponding timestamps
    pupil_time = session.eye_tracking.timestamps
    pupil_area = session.eye_tracking.pupil_area
    mean_pupil_area = [np.nanmean(pupil_area[np.logical_and(s1 <= pupil_time, pupil_time <= s2)]) for s1, s2 in zip(trial_start, trial_stop)]
    # impute missing values with the mean of the previous and next rows
    inds = np.where(np.isnan(mean_pupil_area))[0]
    for i in inds:
        # [i-1, i+1] inclusive; nanmean skips the NaN at i itself. max()
        # stops a negative start index from producing an empty slice at i == 0.
        neighbors = mean_pupil_area[max(i - 1, 0):i + 2]
        mean_pupil_area[i] = np.nanmean(neighbors)

    # Get lick counts
    # lick_count = session.trials.apply(lambda row : len(row['lick_times']), axis = 1)

    # Calculate hit rate
    # hit_rate = session.trials.hit.rolling(10).mean().values
    # hit_rate[:9] = 0 #otherwise these will be nans

    # Construct a dataframe
    behavior_data = pd.DataFrame({
                'Mean speed': mean_speed, 
                'Mean pupil area': mean_pupil_area})
    
    return behavior_data

In [None]:
behavior_data = make_behavior_table(session, df)

In [None]:
behavior_data

In [None]:
plt.figure(figsize=(15,5))
plt.plot(behavior_data['Mean speed'])

In [None]:
plt.figure(figsize=(15,5))
plt.plot(session.running_speed.speed)

In [None]:
# make_behavior_table needs the selected-presentations table as its second argument
behavior_data = make_behavior_table(session, df)
# True if any value is still missing after imputation
print(behavior_data.isnull().values.any())
behavior_data.head(1)

In [None]:
session.trials.head(1)