In [3]:
# Load IPython's autoreload extension; its reload mode is set by `%autoreload 2` in the imports cell.
# Re-running this cell in a kernel where it is already loaded only prints an "already loaded" notice.
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
%autoreload 2
import h5py
from matplotlib import colors
import matplotlib.pyplot as plt
import multiprocessing
from neural_analysis.matIO import loadmat
import numpy as np
import os
import pandas as pd
import ssm
from ssm.util import find_permutation
import sys
import time
from tqdm.auto import tqdm

sys.path.append('../../..')
from ld_utils import slds_compute_eigs
from up_down import get_up_down
from utils import get_binary_stimuli, get_sample_interval, load, save

## Load Neural Data

In [5]:
# Session file to analyse. The default is a machine-specific absolute path; set the
# PROPOFOL_SESSION_FILE environment variable to point elsewhere without editing the notebook.
filename = os.environ.get(
    "PROPOFOL_SESSION_FILE",
    r'/home/adameisen/millerdata/common/datasets/anesthesia/mat/propofolPuffTone/Mary-Anesthesia-20160809-01.mat',
)
print("Loading data ...")
# Wall-clock timer: loading is dominated by disk I/O, which time.process_time() does not count.
start = time.perf_counter()
electrode_info, lfp, lfp_schema, session_info, spike_times, unit_info = loadmat(
    filename,
    variables=['electrodeInfo', 'lfp', 'lfpSchema', 'sessionInfo', 'spikeTimes', 'unitInfo'],
    verbose=False,
)
spike_times = spike_times[0]  # NOTE(review): unwraps an outer container from loadmat — confirm its layout
dt = lfp_schema['smpInterval'][0]  # sampling interval; later cells divide seconds by it to get sample counts
T = lfp.shape[0]  # number of LFP samples (axis 0 is time; axis 1 indexes electrodes)

# The .mat file is also opened directly with h5py to extract stimulus traces; the context
# manager guarantees the HDF5 handle is closed afterwards instead of leaking for the kernel's lifetime.
# NOTE(review): assumes get_binary_stimuli returns in-memory arrays, not lazy h5py datasets — confirm.
with h5py.File(filename, 'r') as h5_file:
    airPuff_binary, audio_binary = get_binary_stimuli(h5_file)

print(f"Data loaded (took {time.perf_counter() - start:.2f} seconds)")

Loading data ...
Data loaded (took 98.39 seconds)


# Fit SLDS Models on Sliding LFP Windows

In [39]:
# --------
# User-guided SLDS parameters
# --------
latent_dim = 32  # number of latent dimensions
transitions = "standard"  # transition class
stride = 30  # s, spacing between consecutive window start times
window = 5  # s, duration of each fitted window

# Round rather than truncate so floating-point error in window/dt cannot silently drop a sample.
length = int(round(window / dt))  # samples per window

# Candidate window starts every `stride` seconds. Keep only those whose full window
# (start_step .. start_step + length) lies inside the recording. This replaces an ad-hoc
# "+ 0.1 s" tolerance that could admit a final window running past the end of the data.
n_samples = lfp.shape[0]
candidate_start_times = np.arange(0, n_samples * dt, stride).astype(int)
candidate_start_steps = np.round(candidate_start_times / dt).astype(int)
start_times = candidate_start_times[candidate_start_steps + length <= n_samples]

# Brain areas whose electrodes are modelled. Other selections used previously:
#   ['vlPFC', 'FEF', 'CPB', '7b']   or   np.unique(electrode_info['area'])
areas = ['vlPFC']
area_mask = pd.Series(electrode_info['area']).isin(areas).to_numpy()
assert len(area_mask) == lfp.shape[1], "electrode_info and lfp disagree on the number of electrodes"
unit_indices = np.flatnonzero(area_mask)
assert unit_indices.size > 0, f"no electrodes found in areas {areas}"
var_names = [f"unit_{unit_num} {electrode_info['area'][unit_num]}" for unit_num in unit_indices]

# --------
# Set the parameters of the SLDS
# --------
emissions_dim = len(unit_indices)  # number of observed dimensions

# Each run writes into a fresh, timestamped directory named after the session and hyperparameters.
data_dir = "../../../__data__/propofol/SLDS/"
timestamp = time.strftime('%b-%d-%Y_%H%M', time.localtime())
session_name = os.path.splitext(os.path.basename(filename))[0]
data_dir = os.path.join(data_dir, f"SLDS_{session_name}_latent_{latent_dim}_window_{window}_stride_{stride}_{timestamp}")
os.makedirs(data_dir, exist_ok=True)

In [41]:
# Anesthesia epoch (seconds) used to label each window.
# NOTE(review): presumably the first infusion's onset and the second infusion's offset — confirm against sessionInfo.
anesthesia_bounds = [session_info['drugStart'][0], session_info['drugEnd'][1]]

# Number of discrete SLDS states per phase. Currently 2 everywhere; kept as a table so
# phases can be tuned independently without editing branch logic.
n_disc_states_by_phase = {
    'wakeful': 2,
    'transition_in': 2,
    'anesthesia': 2,
    'transition_out': 2,
}


def classify_segment(piece_bounds, anesthesia_bounds):
    """Label a [start, end] window (seconds) relative to the anesthesia epoch.

    Returns 'wakeful' (no overlap with the epoch), 'transition_in' (straddles the
    epoch start), 'anesthesia' (entirely inside the epoch), or 'transition_out'
    (overlaps the epoch end, which also covers windows spanning the whole epoch).
    """
    seg_start, seg_end = piece_bounds
    drug_start, drug_end = anesthesia_bounds
    if seg_end <= drug_start or seg_start >= drug_end:
        return 'wakeful'
    if seg_end <= drug_end:
        return 'transition_in' if seg_start < drug_start else 'anesthesia'
    return 'transition_out'


# Convert start times to sample offsets by rounding (truncation can be off by one sample),
# and drop any window that would run past the end of the recording instead of fitting a short slice.
n_samples = lfp.shape[0]
all_start_times = np.asarray(start_times)
all_start_steps = np.round(all_start_times / dt).astype(int)
fits_in_recording = all_start_steps + length <= n_samples
if not fits_in_recording.all():
    tqdm.write(f"Skipping window(s) extending past the end of the recording: {all_start_times[~fits_in_recording]}")
fit_start_times = all_start_times[fits_in_recording]
fit_start_steps = all_start_steps[fits_in_recording]

# Save run-level metadata before fitting so a partially completed run is still interpretable.
run_params = dict(
    latent_dim=latent_dim,
    emissions_dim=emissions_dim,
    transitions=transitions,
    start_times=fit_start_times,
    window=window,
    stride=stride,
    length=length,
    unit_indices=unit_indices,
    var_names=var_names,
    anesthesia_bounds=anesthesia_bounds,
)
save(run_params, os.path.join(data_dir, "run_params"))

for start_time, start_step in tqdm(zip(fit_start_times, fit_start_steps), total=len(fit_start_times)):
    phase = classify_segment([start_time, start_time + window], anesthesia_bounds)
    n_disc_states = n_disc_states_by_phase[phase]

    data = lfp[start_step:start_step + length, unit_indices]
    results = slds_compute_eigs(data, transitions, emissions_dim, n_disc_states, latent_dim, verbose=False)

    results['start_time'] = start_time
    results['start_step'] = int(start_step)
    results['n_disc_states'] = n_disc_states
    results['phase'] = phase

    save(results, os.path.join(data_dir, f"start_time_{start_time}"))

  0%|          | 0/237 [00:00<?, ?it/s]