In [39]:
# Import statements
import os
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.patches as patches
from matplotlib.transforms import Bbox
from matplotlib.figure import Figure
from matplotlib.colors import LinearSegmentedColormap
import sys
from glob import glob
from pymatreader import read_mat
from scipy import signal
from scipy import stats
from scipy.ndimage import uniform_filter1d

import itertools

#import matlab.engine
# Test out matlab spike detection code
#eng = matlab.engine.start_matlab()
#eng.eval("startup", nargout=0)

# Importing custom python files
import consts
import importlib
importlib.reload(consts)

from open_ephys.analysis import Session

%matplotlib tk

f = os.sep

In [40]:
# Read in pickle file
interm_data_path =  f + 'home' +f+ 'pierfier' +f+ 'Projects' +f+ 'Pierre Fabris' +f+ 'PV DBS neocortex' +f+ 'Interm_Data' +f+ 'flicker.pkl'
savefig_path = f + 'home' +f+ 'pierfier' +f+ 'Dropbox' +f+ 'RKC-HanLab' +f+ 'Pierre PV DBS Project Dropbox' +f+ 'Materials' +f+ 'Plots' +f+ 'Flicker' +f+ 'Summary' +f
df = pd.read_pickle(interm_data_path)
print(df.columns)

Index(['frame_time', 'raw_trace', 'sub_vm', 'detrend_vm', 'trace_noise',
       'interp_time', 'interp_raw_trace', 'interp_raw_detrend_trace',
       'interp_spike_raster', 'spike_amp_raster', 'flicker_raster',
       'stim_raster', 'interp_subvm', 'stim_freq', 'flicker_freq', 'mouse_id',
       'session_id', 'trial_id', 'stim_param', 'fov_id', 'roi_id'],
      dtype='object')


In [41]:
# Loop through and plot all of the trials for both 40 Hz and 140 Hz
test_freqs = df['stim_freq'].unique()
cm = 0.394
for freq in test_freqs:
    freq_df = df[df['stim_freq'] == freq]
    plot_df = freq_df.pivot(index=['mouse_id', 'session_id', 'fov_id', 'trial_id'], columns='interp_time', values='interp_raw_detrend_trace')
    stim_df = freq_df.pivot(index=['mouse_id', 'session_id', 'fov_id', 'trial_id'], columns='interp_time', values='stim_raster')
    
    # Grab the number of trials for each multi-index
    num_trial_df = plot_df.groupby(level=['mouse_id', 'session_id', 'fov_id'], sort=False).size()

    # Remove fovs that had no trials
    filtered_groups = num_trial_df[num_trial_df > 0].index
    num_trial_df = num_trial_df.loc[filtered_groups]

    print(num_trial_df)

    # Create neuron boundaries
    neuron_bound = np.array([0.5])
    for sz in num_trial_df.values:
        print(sz)
        neuron_bound = np.append(neuron_bound, neuron_bound[-1] + sz)

    #print(stim_df.iloc[0])

    #print('Summed mean')
    print(stim_df.mean(axis=0).sum())
    #print('Summed std')
    #print(stim_df.std(axis=0).sum())
    avg_stim_raster = stim_df.mean(axis=0).to_numpy()
    timeline = plot_df.columns.values

    t = plot_df.columns.values[avg_stim_raster > 0]
    #plt.figure()
    #plt.plot(t)
    #plt.title(freq)
    block_idx = np.arange(plot_df.shape[0]) + 1 # Create row index array
    
    # Find the onset and offset of each burst period
    gap_start_idx = consts.find_start_idx(stim_df.iloc[0].values)
    gap_end_idx = avg_stim_raster.shape[0] - consts.find_start_idx(stim_df.iloc[0].values[::-1]) - 1
    gap_start_idx = gap_start_idx.astype(int)
    gap_end_idx = gap_end_idx.astype(int)[::-1]
    
    # Plot the heatmap
    fig, ax = plt.subplots()
    
    plt.rcParams['font.size'] = 7
    ax.set_position(Bbox.from_extents(1, 1, 11*cm, 9*cm))
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    surf_plot = ax.pcolormesh(timeline, block_idx, plot_df.values)
    cbar = plt.colorbar(surf_plot, label='Normalized Vm')

    # Plot the vertical stimulation onset and offset
    t_idx = np.concatenate((gap_start_idx, gap_end_idx))
    t = plot_df.columns.values[t_idx]
    ax.vlines(x=t, ymin=1, ymax=np.max(block_idx), linestyles='--', colors=consts.stim_color)

    # Plot the stimulation time bars
    stim_y = np.max(block_idx) + 2
    stim_h = 2
    for start_idx, end_idx in zip(gap_start_idx, gap_end_idx):
        anch_y = stim_y - stim_h/2
        stim_w = plot_df.columns.values[end_idx] - plot_df.columns.values[start_idx]
        rect_patch = patches.Rectangle((plot_df.columns.values[start_idx], anch_y), stim_w, stim_h, facecolor=consts.stim_color)
        ax.add_patch(rect_patch)
    ax.hlines(y=stim_y, xmin=np.min(timeline), xmax=np.max(timeline), linestyles='-', colors=consts.stim_color)

       
    # Plot all of the neuron boundaries
    ax.hlines(y=neuron_bound, xmin=np.min(timeline), xmax=np.max(timeline), colors='k')
    
    # Rename the Yticks
    #TODO calculate the midpoint of all neuronbounds
    
    ax.set_yticks(ytick_loc)
    ax.set_yticklabels(np.arange(neuron_bound.shape[0]) + 1) 
    
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Neuron # (Black outlines indicate all blocks per neuron)')
    
    ax.set_title(str(freq) + 'Hz')
    fig.savefig(savefig_path + 'Summary_Blocks_Vm_' + str(freq) + '.png', format='png')
    fig.savefig(savefig_path + 'Summary_Blocks_Vm_' + str(freq) + '.pdf', format='pdf')
    

plt.show()

mouse_id        session_id  fov_id
109558_Vb_male  20240308    2          8
                            3         15
                20240311    1          6
109567_Vb_male  20240311    1          7
                            4          4
                            5          6
                20240411    2          3
                            3          7
                            1          5
                            4          7
                            5          5
                20240424    3         11
                            4         10
dtype: int64
8
15
6
7
4
6
3
7
5
7
5
11
10
140.0


  num_trial_df = plot_df.groupby(level=['mouse_id', 'session_id', 'fov_id'], sort=False).size()
  num_trial_df = plot_df.groupby(level=['mouse_id', 'session_id', 'fov_id'], sort=False).size()


mouse_id        session_id  fov_id
109558_Vb_male  20240311    1          5
                            2          4
109567_Vb_male  20240311    2          5
                            4          3
                            5          7
                20240411    1          7
                            4          6
                            5          7
                            3          5
                20240424    4         10
                            3         10
dtype: int64
5
4
5
3
7
7
6
7
5
10
10
40.0


In [42]:
plt.close('all')