In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

from copy import deepcopy
import logging
import math
from matplotlib.cm import get_cmap, ScalarMappable
from matplotlib import colors
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
import os
import pandas as pd
import scipy
from scipy.spatial.distance import pdist, squareform
from sklearn.metrics import r2_score
import sys

sys.path.append('../../')
from data_utils import *

sys.path.append('../../../DeLASE')
from delase import *
from utils import numpy_torch_conversion
from stability_estimation import *
from parameter_choosing import *
from performance_metrics import *

sys.path.append('/om2/user/eisenaj/code/repos/jPCA')
from jPCA import jPCA

plt.style.use('../../sci_style.py')

# Load Data

In [3]:
# Recording session to analyze; switch sessions by swapping which assignment is commented out.
# session = 'MrJones-Anesthesia-20160109-01'
session = 'Mary-Anesthesia-20160912-02'
# Directory holding per-session results (not read by any of the cells shown here).
results_dir = '/scratch2/weka/millerlab/eisenaj/ChaoticConsciousness/session_results'

In [4]:
# Resolve the dataset class for this session, then load the requested session variables.
all_data_dir = '/scratch2/weka/millerlab/eisenaj/datasets/anesthesia/mat'
data_class = get_data_class(session, all_data_dir)

# Disable HDF5 file locking before load_session_data runs.
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

variables = ['electrodeInfo', 'lfp', 'lfpSchema', 'sessionInfo', 'trialInfo', 'unitInfo']
session_vars, T, N, dt = load_session_data(session, all_data_dir, variables, data_class=data_class, verbose=False)

# Unpack the loaded variables in the same order they are listed in `variables`.
(electrode_info, lfp, lfp_schema,
 session_info, trial_info, unit_info) = (session_vars[name] for name in variables)

In [5]:
def _last_if_array(value):
    """Return the final element when `value` is a NumPy array, otherwise `value` unchanged."""
    if isinstance(value, np.ndarray):
        return value[-1]
    return value


# session_info may store these either as a single value or as an array of values;
# when it is an array, the last entry is the one used.
eyes_open = _last_if_array(session_info['eyesOpen'])
eyes_close = _last_if_array(session_info['eyesClose'])

In [6]:
# (start, end) boundaries of each experimental section; the recording ends at T*dt.
section_times = {
    'wake': (0, session_info['drugStart'][0]),
    'induction': (session_info['drugStart'][0], eyes_close),
    'anesthesia': (eyes_close, session_info['drugEnd'][1]),
    'recovery': (session_info['drugEnd'][1], T*dt),
}
# Section names in the order they were defined above.
sections = list(section_times)

In [7]:
# Tone onset / offset times from trial_info with NaN entries removed.
# These are later compared against session_info['drugStart'] / eyes_close and divided by dt
# to obtain sample indices, so they are presumably in the same time units as dt — TODO confirm.
tone_on = trial_info['cpt_toneOn'][~np.isnan(trial_info['cpt_toneOn'])]
tone_off = trial_info['cpt_toneOff'][~np.isnan(trial_info['cpt_toneOff'])]

In [9]:
class PCA:
    """Minimal SVD-based PCA that runs on NumPy arrays or, optionally, torch tensors.

    After `fit`, the following attributes are available:
        mean_               -- per-feature mean of the training data (same array type as the
                               converted training data; reused by `transform`)
        U, S, V             -- NumPy arrays from the thin SVD of the centred training data;
                               the columns of `V` are the principal axes
        explained_variance_ -- NumPy array, S**2 / (n_rows - 1), truncated to `n_components`

    Parameters
    ----------
    n_components : int or None
        Number of leading components kept by `transform` and `explained_variance_`.
        None keeps all of them.
    use_torch, device, dtype
        Forwarded to `numpy_torch_conversion`; when `use_torch` is True the SVD is
        computed with `torch.linalg.svd`.
    """

    def __init__(self, n_components=None, use_torch=False, device='cpu', dtype='torch.DoubleTensor'):
        self.n_components = n_components
        self.use_torch = use_torch
        self.device = device
        self.dtype = dtype

    def compute_multidim_mean(self, data):
        """Average `data` over every axis except the last one."""
        # Plain Python ints (not np.int64) so the axis tuple suits both NumPy and torch.
        return data.mean(axis=tuple(range(data.ndim - 1)))

    def fit(self, data):
        """Learn the feature mean and principal axes from `data` (rows are observations)."""
        data = numpy_torch_conversion(data, self.use_torch, self.device, self.dtype)
        # Remember the training mean so transform() centres new data the same way.
        self.mean_ = self.compute_multidim_mean(data)
        data_centered = data - self.mean_
        if self.use_torch:
            U, S, Vh = torch.linalg.svd(data_centered, full_matrices=False)
            self.U = U.cpu().numpy()
            self.S = S.cpu().numpy()
            self.V = Vh.cpu().numpy().T
        else:
            U, S, Vh = np.linalg.svd(data_centered, full_matrices=False)
            self.U = U
            self.S = S
            self.V = Vh.T

        # Computed from the stored NumPy singular values so the attribute has the same
        # type as U/S/V in both NumPy and torch modes.
        self.explained_variance_ = ((self.S**2)/(data.shape[0] - 1))[:self.n_components]

    def transform(self, data):
        """Project `data` onto the first `n_components` principal axes learned by `fit`.

        The data are centred with the mean stored by `fit`, not with their own mean, so
        new observations land in the same coordinate frame as the training data.
        Raises AttributeError if called before `fit`.
        """
        data = numpy_torch_conversion(data, self.use_torch, self.device, self.dtype)
        data_centered = data - self.mean_
        # NOTE(review): self.V is a NumPy array even when use_torch is True; confirm the
        # tensor @ ndarray product behaves as intended on the torch path.
        return (data_centered) @ self.V[:, :self.n_components]

    def fit_transform(self, data):
        """Fit on `data` and return its projection."""
        self.fit(data)
        return self.transform(data)

In [10]:
# Window geometry around each event, in samples: every trajectory is cut from
# lfp[t - leadup - p + 1 : t + post] where t is the event's sample index.
leadup = 2000
post = 2000
# Passed as the second argument to embed_signal — presumably the delay-embedding depth; TODO confirm.
p = 10
# Alternative, shorter window:
# leadup = 500
# post = 1500

In [11]:
# Areas to extract trajectories for; each name is matched against electrode_info['area']
# (the special name 'all' would select every electrode).
areas = ['vlPFC', 'FEF', '7b', 'CPB']

In [12]:
# Surrogate ("random") event times, matched in number to the real tone events of each state.
# Wake times are drawn from (0, drugStart[0]); anesthesia times from (eyes_close, drugEnd[1]),
# i.e. the same window used to count the anesthesia tones here and to select anesthesia
# events during trajectory extraction. (The anesthesia draw previously stopped at
# drugStart[1], which did not match that window.)
# NOTE: the draws use NumPy's global RNG; seed it beforehand for reproducible surrogate times.
num_wake_samples = np.sum(tone_on <= session_info['drugStart'][0])
num_anesthesia_samples = np.sum(np.logical_and(tone_on > eyes_close, tone_on < session_info['drugEnd'][1]))
random_times = np.sort(np.random.uniform(0, session_info['drugStart'][0], size=(num_wake_samples,)))
random_times = np.hstack([random_times, np.sort(np.random.uniform(eyes_close, session_info['drugEnd'][1], size=(num_anesthesia_samples,)))])

In [14]:
# Cut an embedded LFP window around every event, grouped as
#   lfp_traj_<state>[area][event_set] -> array with one entry per event,
# where event_set is 'tone' (real tone onsets) or 'random' (surrogate times).
lfp_traj_wake = {}
lfp_traj_anesthesia = {}

event_sets = [('tone', tone_on), ('random', random_times)]
iterator = tqdm(total=len(areas)*(len(tone_on) + len(random_times)))

for i, area in enumerate(areas):

    lfp_traj_wake[area] = {}
    lfp_traj_anesthesia[area] = {}

    # 'all' keeps every electrode; any other name keeps only electrodes labelled with that area.
    unit_indices = (
        np.arange(len(electrode_info['area']))
        if area == 'all'
        else np.where(electrode_info['area'] == area)[0]
    )

    for time_locs, time_loc_array in event_sets:

        lfp_tone_wake = []
        lfp_tone_anesthesia = []
        num_wake = 0
        num_anesthesia = 0
        for t in time_loc_array:
            # Wake: before the first drug start and late enough that the lead-up
            # (plus the p - 1 extra samples taken for the embedding) fits in the recording.
            is_wake = t < session_info['drugStart'][0] and t > ((leadup + p - 1)*dt)
            # Anesthesia: between eyes closing and the end of the drug infusion.
            is_anesthesia = (not is_wake) and t > eyes_close and t < session_info['drugEnd'][1]

            if is_wake or is_anesthesia:
                t = int(t/dt)  # event time -> sample index
                window = embed_signal(lfp[t - leadup - p + 1:t + post, unit_indices], p, use_torch=False)
                if is_wake:
                    lfp_tone_wake.append(window)
                    num_wake += 1
                else:
                    lfp_tone_anesthesia.append(window)
                    num_anesthesia += 1

            iterator.update()

        lfp_tone_wake = np.array(lfp_tone_wake)
        lfp_tone_anesthesia = np.array(lfp_tone_anesthesia)

        lfp_traj_wake[area][time_locs] = lfp_tone_wake
        lfp_traj_anesthesia[area][time_locs] = lfp_tone_anesthesia
iterator.close()

  0%|          | 0/5400 [00:00<?, ?it/s]

# jPCA

In [38]:
# Area and event set whose trajectories are handed to the jPCA fits.
# time_loc_type is a key of lfp_traj_wake[area] / lfp_traj_anesthesia[area]
# ('tone' or 'random', as created by the extraction loop).
area = 'vlPFC'
time_loc_type = 'random'

In [46]:
# Sample offsets relative to the event: -leadup, ..., post - 1 (leadup + post values).
time_vals = np.arange(-leadup, post, 1)
# First and last offsets; passed to JPCA.fit as tstart / tend.
tstart = -leadup
tend = post - 1

In [173]:
# Passed to jPCA both as num_jpcs (constructor) and num_pcs (fit).
num_dims = 6

In [207]:
# Number of trials drawn (without replacement) for each jPCA fit.
# Fixed-size alternative:
# num_trials_wake = 100
# num_trials_anesthesia = 100

# Default: use every extracted window for the chosen area and event set.
num_trials_wake = len(lfp_traj_wake[area][time_loc_type])
num_trials_anesthesia = len(lfp_traj_anesthesia[area][time_loc_type])

In [208]:
# When True, each trial is z-scored with its own overall mean and standard deviation
# before being passed to the jPCA fit.
normalize = True

In [209]:
# Fit jPCA to the wake trajectories.
jpca_wake = jPCA.JPCA(num_jpcs=num_dims)

# Draw num_trials_wake trials without replacement (this also shuffles their order).
wake_trials = lfp_traj_wake[area][time_loc_type]
wake_pick = np.random.choice(np.arange(wake_trials.shape[0], dtype=int), size=(num_trials_wake, ), replace=False)
datas = list(wake_trials[wake_pick])
if normalize:
    # z-score every trial with its own overall mean and standard deviation
    datas = [(trial - trial.mean())/trial.std() for trial in datas]

wake_fit = jpca_wake.fit(datas, num_pcs=num_dims, times=time_vals, tstart=tstart, tend=tend, align_axes_to_data=False, verbose=True)
projected_wake, full_data_var_wake, pca_var_capt_wake, jpca_var_capt_wake = wake_fit

# Stack the per-trial projections into one array (trial on axis 0) and average over trials.
projected_wake = np.stack(projected_wake, 0)
projected_mean_wake = projected_wake.mean(axis=0)

Subtracting means
Performing PCA
PCA complete!
Preprocessing complete!
Calculating skew symmetric matrix
Optimization failed.
Desired error not necessarily achieved due to precision loss.
Complete!


In [None]:
# Fit jPCA to the anesthesia trajectories.
jpca_anesthesia = jPCA.JPCA(num_jpcs=num_dims)

# Draw num_trials_anesthesia trials without replacement (this also shuffles their order).
anesthesia_trials = lfp_traj_anesthesia[area][time_loc_type]
anesthesia_pick = np.random.choice(np.arange(anesthesia_trials.shape[0], dtype=int), size=(num_trials_anesthesia, ), replace=False)
datas = list(anesthesia_trials[anesthesia_pick])
if normalize:
    # z-score every trial with its own overall mean and standard deviation
    datas = [(trial - trial.mean())/trial.std() for trial in datas]

anesthesia_fit = jpca_anesthesia.fit(datas, num_pcs=num_dims, times=time_vals, tstart=tstart, tend=tend, align_axes_to_data=False, verbose=True)
projected_anesthesia, full_data_var_anesthesia, pca_var_capt_anesthesia, jpca_var_capt_anesthesia = anesthesia_fit

# Stack the per-trial projections into one array (trial on axis 0) and average over trials.
projected_anesthesia = np.stack(projected_anesthesia, 0)
projected_mean_anesthesia = projected_anesthesia.mean(axis=0)

Subtracting means
Performing PCA


In [None]:
# False: plot the single trial selected by trial_num; True: plot the across-trial mean trajectory.
plot_mean = False

In [None]:
# Font size used for the subplot titles.
title_size = 15

In [None]:
# Portion of each trajectory to draw. plot_start is in samples relative to the event
# (0 = event onset); plot_length is divided by dt to obtain a number of samples.
# To draw the whole window instead:
# plot_start = -leadup
# plot_length = (leadup + post)*dt
plot_start = 0
plot_length = 1
trial_num = 5

num_plot_samples = int(plot_length/dt)
# Colour encodes the sample offset relative to the event.
norm = colors.Normalize(vmin=plot_start, vmax=plot_start + num_plot_samples)
# plt.get_cmap works across matplotlib versions; plt.cm.get_cmap is deprecated and removed in recent releases.
cmap = plt.get_cmap('RdYlBu_r')


def _draw_jpc_plane(ax, traj, plane, sample_range, title):
    """Draw one jPC plane of a trajectory as colour-coded line segments.

    ax           -- matplotlib Axes to draw on
    traj         -- 2-D array (time, jPC) holding the projected trajectory
    plane        -- zero-based plane index; plane k uses jPC columns 2k and 2k + 1
    sample_range -- row indices of `traj` to draw; each index i draws the segment i -> i + 1
    title        -- title text for the panel

    Reads `cmap`, `norm`, `leadup` and `title_size` from the notebook namespace.
    """
    x_dim, y_dim = 2*plane, 2*plane + 1
    for i in sample_range:
        # i - leadup converts the row index back to an offset relative to the event
        ax.plot(traj[i:i+2, x_dim], traj[i:i+2, y_dim], c=cmap(norm(i - leadup)))
    # Labels and title are set once per panel rather than once per segment.
    ax.set_xlabel(f'jPC{x_dim + 1}')
    ax.set_ylabel(f'jPC{y_dim + 1}')
    ax.set_title(title, fontsize=title_size)


# One column per state: (title label, per-trial projections, mean projection,
# variance captured per jPC, total data variance, number of trials used in the fit).
conditions = [
    ('Wakeful', projected_wake, projected_mean_wake,
     jpca_var_capt_wake, full_data_var_wake, num_trials_wake),
    ('Anesthetic', projected_anesthesia, projected_mean_anesthesia,
     jpca_var_capt_anesthesia, full_data_var_anesthesia, num_trials_anesthesia),
]

sample_range = range(plot_start + leadup, plot_start + leadup + num_plot_samples)
# One row per jPC plane (pairs of consecutive jPCs); 3 rows for 6 jPCs.
num_planes = projected_wake.shape[-1] // 2

fig, axs = plt.subplots(num_planes, 2, figsize=(12, 4*num_planes), sharex=True, sharey=True, squeeze=False)
for col, (state, projected, projected_mean, jpca_var_capt, full_data_var, num_trials) in enumerate(conditions):
    traj = projected_mean if plot_mean else projected[trial_num]
    for plane in range(num_planes):
        var_pct = np.sum(jpca_var_capt[2*plane:2*plane + 2])*100/full_data_var
        if plot_mean:
            heading = f'Mean {state} Trajectory jPC Plane {plane + 1}'
        else:
            heading = f'{state} Trajectory Trial {trial_num} Plane {plane + 1}'
        title = f'{heading}\n({var_pct:.2f}% of Variance Across {num_trials} Trials)'
        _draw_jpc_plane(axs[plane][col], traj, plane, sample_range, title)

plt.tight_layout()

# Shared colorbar in its own axes to the right of the panels.
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.9, 0.25, 0.01, 0.5])
cbar_ticks = np.arange(plot_start, plot_start + num_plot_samples + 1, 100)
cbar = fig.colorbar(sm, cax=cbar_ax, ticks=cbar_ticks)
cbar.ax.set_yticklabels(cbar_ticks, fontsize=12)
cbar.set_label(label=f'Time Relative to {time_loc_type.capitalize()} Event (ms)', fontsize=14)

plt.show()