# Exploratory analysis

Notebook for exploratory analysis on tetrode data in the mPFC of a rat doing the parametric working memory (PWM) task.

Session: `data_sdc_20190902_145404_fromSD`

Current data storage:
* raw .dat, .rec, .mda, .bin and preprocessed .bin files are located on scratch (path not filled in yet)

* sorted data is located on bucket `Y:\jbreda\ephys\post_sort_analysis\sorted_pre_bdata`

* in a sorted folder: 
    * folder for each .bin bundle & cluster notes, matlab struct w/ spike info from scraped phy, matlab struct w/ behavior info scraped from bdata
    * in .bin bundle folder you will find curated kilosort output, mask info as npy and preprocessed .bin that was run

see [jbreda_PWM_ephys_analysis](https://github.com/Brody-Lab/jbreda_PWM_ephys_analysis) for more info on how this info was obtained


**TODO**
* spk struct --> data frame
* make utils.py
* turn df cell into function


## Libs & fxs

In [3]:
# libraries

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import scipy.io as spio
from utils import *

## Behavior

In [4]:
# Load the behavior struct for this session and pull out the per-trial parsed
# events. The path is a raw string so the Windows backslashes stay literal
# (sequences such as "\j" or "\p" are invalid escapes in a normal string).
beh_path = r'Y:\jbreda\ephys\post_sort_analysis\sorted_pre_bdata\data_sdc_20190902_145404_fromSD\protocol_info.mat'
beh_dict = load_nested_mat(beh_path)
beh_dict = beh_dict['behS']  # keep only the behavior struct
parsed_events_dict = beh_dict['parsed_events']  # indexed by trial number below

In [None]:
# Build a per-trial behavior DataFrame (one row per trial) from the behavior
# struct (beh_dict) and the parsed state-machine events (parsed_events_dict).

# initialize df
beh_df = pd.DataFrame()

# make any adjustments to make things more meaningful/readable
prev_side_adj = np.roll(beh_dict['prev_side'], 1)  # n-1 trial info
# 114 == ord('r'); every other value is labeled LEFT
prev_side_adj = np.where(prev_side_adj == 114, 'RIGHT', 'LEFT')
# trial 0 doesn't have a previous (np.roll wrapped the last trial into this
# slot); note this is the string 'NaN', not a float NaN
prev_side_adj[0] = 'NaN'

# Derive the outcome masks from the raw numeric history BEFORE relabeling, so
# the hit test in the event loop below still works once the column holds
# strings.
hit_raw = pd.Series(beh_dict['hit_history'])
is_hit = (hit_raw == 1.0).to_numpy()
is_miss = (hit_raw == 0.0).to_numpy()
is_viol = hit_raw.isnull().to_numpy()

# relabel 1 -> "hit", 0 -> "miss", missing -> "viol" (other values untouched)
hit_labels = hit_raw.astype(object)
hit_labels.loc[is_hit] = "hit"
hit_labels.loc[is_miss] = "miss"
hit_labels.loc[is_viol] = "viol"
beh_df['hit_hist'] = hit_labels.to_numpy()

# get n_trial length items into df
beh_df['delay'] = beh_dict['delay']
beh_df['pair_hist'] = beh_dict['pair_history']
beh_df['correct_side'] = beh_dict['correct_side']
beh_df['prev_side'] = prev_side_adj
beh_df['aud1_sigma'] = beh_dict['aud1_sigma']
beh_df['aud2_sigma'] = beh_dict['aud2_sigma']

# initialize space; NaN is the value kept for trials without the event
n_trials = len(parsed_events_dict)
c_poke = np.full(n_trials, np.nan)
hit_state = np.full(n_trials, np.nan)
aud1_time = np.full(n_trials, np.nan)
aud2_time = np.full(n_trials, np.nan)

# iterate over items from state matrix
for trial in range(n_trials):
    trial_events = parsed_events_dict[trial]

    # every trial has a center poke
    c_poke[trial] = trial_events['states']['cp'][0]

    # not all trials will have sound/hit time/etc, pull out info for hits;
    # non-hit trials keep the NaN they were initialized with
    if is_hit[trial]:
        hit_state[trial] = trial_events['states']['hit_state'][0]
        aud1_time[trial] = trial_events['waves']['stimAUD1'][0]
        aud2_time[trial] = trial_events['waves']['stimAUD2'][0]

# add to df
beh_df['c_poke'] = c_poke
beh_df['hit_state'] = hit_state
beh_df['aud1_time'] = aud1_time
beh_df['aud2_time'] = aud2_time

beh_df

## Ephys 

In [None]:
# Load the curated spike struct for this session. The path is a raw string so
# the Windows backslashes stay literal (sequences such as "\j" or "\k" are
# invalid escapes in a normal string).
spks_path = r'Y:\jbreda\ephys\post_sort_analysis\sorted_pre_bdata\data_sdc_20190902_145404_fromSD\ksphy_clusters_foranalysis.mat'
spks_dict = spio.loadmat(spks_path, squeeze_me=True)  # squeeze_me drops unit matrix dimensions
spks_dict = spks_dict['PWMspkS']  # keep only the spike struct

In [None]:
"proof of cell"

# need to f/u w/ Tyler on these shapes

# Sanity-check plot of one entry from the 'waves_mn' field of the spike struct.
# NOTE(review): tt_num is used for BOTH levels of indexing below — confirm
# that is intended (presumably first index = cell, second = channel/tetrode).
tt_num = 3
wave_snippet = spks_dict['waves_mn']
plt.plot(wave_snippet[tt_num][tt_num])

In [None]:
# Label every cell as "multi" or "single" from the 'mua' / 'single' flags of
# the spike struct; a cell carrying neither flag aborts the cell with an error.
n_cells = len(spks_dict["event_ts_fsm"])
trode_num = spks_dict["trodenum"]
spk_qual = []
for cell in range(n_cells):
    # the multi-unit flag is checked first, so it wins if both flags are set
    if spks_dict["mua"][cell] == 1:
        spk_qual.append("multi")
        continue
    if spks_dict["single"][cell] != 1:
        raise TypeError("cell not marked as multi or single")
    spk_qual.append("single")
        

In [None]:
from spykes.plot.neurovis import NeuroVis

def initiate_neurons(spk_data, sess_date="20190902"):
    """Wrap each spike train in ``spk_data["event_ts_fsm"]`` in a NeuroVis object.

    Parameters
    ----------
    spk_data : mapping-like
        Must provide an ``"event_ts_fsm"`` entry holding one sequence of spike
        times per neuron (fsm = behavior time).
    sess_date : str
        Session label appended to every neuron's name.

    Returns
    -------
    list
        One NeuroVis instance per spike train, in input order, named
        ``"<1-based index> <sess_date>"``.
    """
    spike_trains = spk_data["event_ts_fsm"]  # fsm = behavior time

    # number neurons from 1 so the names match the 1-based cell numbering
    return [
        NeuroVis(spk_times, name='{} {}'.format(unit_number, sess_date))
        for unit_number, spk_times in enumerate(spike_trains, start=1)
    ]

In [None]:
# One NeuroVis object per entry of spks_dict["event_ts_fsm"], default session label.
neuronL = initiate_neurons(spks_dict)

In [None]:
# PSTH for a single neuron using the 'c_poke' column of beh_df as the event
# and the 'delay' column as the condition.
event = 'c_poke'
condition = 'delay'
window = [-500, 10000]
binsize = 50
neuron_number = 1  # 1-based; converted to a 0-based list index below
neuron = neuronL[neuron_number - 1]

plt.figure(figsize=(10, 5))
psth = neuron.get_psth(
    event=event,
    conditions=condition,
    df=beh_df,
    window=window,
    binsize=binsize,
    event_name='Center Poke',
)

In [None]:
# Raster for the third neuron in neuronL, using the 'aud1_time' column of
# beh_df as the event.
window = [-500, 1000]
binsize = 10
neuron = neuronL[2]

plt.figure(figsize=(10, 8))
raster = neuron.get_raster(
    event='aud1_time',
    df=beh_df,
    window=window,
    binsize=binsize,
)

# Plotting 2-second-delay, left hit trials

In [None]:
# Keep trials with correct side LEFT, a delay of 2, and a "hit" outcome.
# NOTE(review): 'correct_side' is copied into beh_df unmodified, whereas
# 'prev_side' had to be mapped from a numeric code (114) to 'RIGHT'/'LEFT' —
# confirm 'correct_side' really holds the string 'LEFT', otherwise this
# selection is empty.
beh_df_d2_hl = beh_df[
    (beh_df['correct_side'] == 'LEFT')
    & (beh_df['delay'] == 2)
    & (beh_df['hit_hist'] == 'hit')
]