# Setup

In [63]:
import sys
module_path = 'W:/home/nero/phasic_tonic/notebooks/buzsaki_method'
if module_path not in sys.path:
    sys.path.insert(0, module_path)
    
from src.DatasetLoader import DatasetLoader
from src.runtime_logger import logger_setup
from src.utils import *
from src.helper import get_metadata

import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yasa
import seaborn as sns

from pathlib import Path
from scipy.io import loadmat
from scipy.signal import spectrogram
from mne.filter import resample
from scipy.signal import hilbert
from neurodsp.filt import filter_signal
import pynapple as nap

plt.style.use('seaborn-v0_8-white')
custom_params = {"axes.spines.right": False, "axes.spines.top": False}
sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params)

def preprocess(signal, n_down, target_fs=500):
    """
    Downsample and remove artifacts.
    """
    data = resample(signal, down=n_down, method='fft', npad='auto')
    
    # Remove artifacts
    art_std, _ = yasa.art_detect(data, target_fs , window=1, method='std', threshold=4)
    art_up = yasa.hypno_upsample_to_data(art_std, 1, data, target_fs)
    data[art_up] = 0
    
    data -= data.mean()
    return data

def _detect_troughs(signal, thr):
    lidx  = np.where(signal[0:-2] > signal[1:-1])[0]
    ridx  = np.where(signal[1:-1] <= signal[2:])[0]
    thidx = np.where(signal[1:-1] < thr)[0]
    sidx = np.intersect1d(lidx, np.intersect1d(ridx, thidx))+1
    return sidx

def _despine_axes(ax):
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.axes.get_xaxis().set_visible(False)
    ax.axes.get_yaxis().set_visible(False)

def get_start_end(sleep_states, sleep_state_id, target_fs=500):
    seq = get_sequences(np.where(sleep_states == sleep_state_id)[0])
    start = []
    end = []
    for s, e in seq:
        start.append(s/target_fs)
        end.append(e/target_fs)
    return (start, end)

In [64]:
plt.style.use('seaborn-v0_8-white')

CBD_DIR = "W:/home/nero/datasets/CBD"
RGS_DIR = "W:/home/nero/datasets/RGS14"
OS_DIR = "W:/home/nero/datasets/OSbasic/"
CONF = "W:/home/nero/phasic_tonic/configs/dataset_loading.yaml"

fs_cbd = 2500
fs_os = 2500
fs_rgs = 1000

target_fs = 500
n_down_cbd = fs_cbd/target_fs
n_down_rgs = fs_rgs/target_fs
n_down_os = fs_os/target_fs

datasets = {
# 'dataset_name' : {'dir' : '/path/to/dataset', 'pattern_set': 'pattern_set_in_config'}
    "CBD": {"dir": CBD_DIR, "pattern_set": "CBD"},
    "RGS": {"dir": RGS_DIR, "pattern_set": "RGS"},
    "OS": {"dir": OS_DIR, "pattern_set": "OS"}
}

def get_name(m):
    return f"Rat{m['rat_id']}_SD{m['study_day']}_{m['condition']}_{m['treatment']}_posttrial{m['trial_num']}"

Datasets = DatasetLoader(datasets, CONF)
mapped_datasets = Datasets.load_datasets()
n_down = n_down_rgs

In [65]:
test_id = "Rat6_SD1_HC_2_posttrial5"
sleep_states, lfpHPC, lfpPFC = mapped_datasets[test_id]

lfpHPC = loadmat(lfpHPC)['HPC']
lfpPFC = loadmat(lfpPFC)['PFC']

# Load the states
hypnogram = loadmat(sleep_states)['states']
hypnogram = hypnogram.flatten()
sleep_states = np.repeat(hypnogram, target_fs)

# Downsample to 500 Hz. Remove artifacts with yasa.
lfpHPC = preprocess(lfpHPC.flatten(), n_down)
lfpPFC = preprocess(lfpPFC.flatten(), n_down)

lfpHPC = lfpHPC[:len(sleep_states)]
lfpPFC = lfpPFC[:len(sleep_states)]



# Phasic detection

In [66]:
rem_seq = get_sequences(np.where(hypnogram == 5)[0])
rem_idx = [(start * target_fs, (end+1) * target_fs) for start, end in rem_seq]

rem_idx = ensure_duration(rem_idx, min_dur=3)
if len(rem_idx) == 0:
    raise ValueError("No REM epochs greater than min_dur.")

# get REM segments
rem_epochs = get_segments(rem_idx, lfpHPC)

# Combine the REM indices with the corresponding downsampled segments
rem = {seq:seg for seq, seg in zip(rem_seq, rem_epochs)}

w1 = 5.0
w2 = 12.0
nfilt = 11
thr_dur = 900

trdiff_list = []
rem_eeg = np.array([])
eeg_seq = {}
sdiff_seq = {}
tridx_seq = {}
filt = np.ones((nfilt,))
filt = filt / filt.sum()

for idx in rem:
    start, end = idx

    epoch = rem[idx]
    epoch = filter_signal(epoch, target_fs, 'bandpass', (w1,w2), remove_edges=False)
    epoch = hilbert(epoch)

    inst_phase = np.angle(epoch)
    inst_amp = np.abs(epoch)

    # trough indices
    tridx = _detect_troughs(inst_phase, -3)

    # differences between troughs
    trdiff = np.diff(tridx)

    # smoothed trough differences
    sdiff_seq[idx] = np.convolve(trdiff, filt, 'same')

    # dict of trough differences for each REM period
    tridx_seq[idx] = tridx

    eeg_seq[idx] = inst_amp

    # differences between troughs
    trdiff_list += list(trdiff)

    # amplitude of the entire REM sleep
    rem_eeg = np.concatenate((rem_eeg, inst_amp)) 

trdiff = np.array(trdiff_list)
trdiff_sm = np.convolve(trdiff, filt, 'same')

# potential candidates for phasic REM:
# the smoothed difference between troughs is less than
# the 10th percentile:
thr1 = np.percentile(trdiff_sm, 10)
# the minimum smoothed difference in the candidate phREM is less than
# the 5th percentile
thr2 = np.percentile(trdiff_sm, 5)
# the peak amplitude is larger than the mean of the amplitude
# of the REM EEG.
thr3 = rem_eeg.mean()

phasicREM = {rem_idx:[] for rem_idx in rem.keys()}

for rem_idx in tridx_seq:
    rem_start, rem_end = rem_idx
    offset = rem_start * target_fs

    # trough indices
    tridx = tridx_seq[rem_idx]

    # smoothed trough interval
    sdiff = sdiff_seq[rem_idx]

    # amplitude of the REM epoch
    eegh = eeg_seq[rem_idx]

    # get the candidates for phREM
    cand_idx = np.where(sdiff <= thr1)[0]
    cand = get_sequences(cand_idx)

    for start, end in cand:
        # Duration of the candidate in milliseconds
        dur = ( (tridx[end]-tridx[start]+1)/target_fs ) * 1000
        if dur < thr_dur:
            continue # Failed Threshold 1
        
        min_sdiff = np.min(sdiff[start:end])
        if min_sdiff > thr2:
            continue # Failed Threshold 2
        
        mean_amp =  np.mean(eegh[tridx[start]:tridx[end]+1])
        if mean_amp < thr3:
            continue # Failed Threshold 3
        
        t_a = tridx[start] + offset
        t_b = np.min((tridx[end] + offset, rem_end * target_fs))
        
        ph_idx = (t_a, t_b+1)
        phasicREM[rem_idx].append(ph_idx)

phasic = []
for rem_idx in phasicREM:
    phasic += phasicREM[rem_idx]

# Pynapple

Make interval sets for sleep states

In [67]:
state_epochs = {}

start, end = get_start_end(sleep_states=sleep_states, sleep_state_id=1)
wake_interval = nap.IntervalSet(start=start, end=end)
state_epochs['wake'] = wake_interval

start, end = get_start_end(sleep_states=sleep_states, sleep_state_id=3)
nrem_interval = nap.IntervalSet(start=start, end=end)
state_epochs['nrem'] = nrem_interval

start, end = get_start_end(sleep_states=sleep_states, sleep_state_id=5)
rem_interval = nap.IntervalSet(start=start, end=end)
state_epochs['rem'] = rem_interval

rem_interval

            start        end
       0      664    742.998
       1      993   1034
       2     1078   1298
       3     1593   1629
       4     2154   2268
       5     3451   3516
       6     3576   3614
       7     4856   4870
       8     5001   5021
       9     5304   5439
      10     5720   5767
      11     6000   6141
      12     7496   7588
      13     7858   7949
      14    10134  10162
shape: (15, 2), time unit: sec.

In [100]:
state_epochs['wake']

           start     end
0            0      1.998
1           60     67.998
2           743    769.998
3           867    877.998
4          1034     1053
5          1298     1313
6          1629     1664
7          1810     1844
8          2050     2069
9          2268     3000
          ...
29         7588    7660
30         7949    8014
31         8036    9352
32         9393    9416
33         9459    9541
34         9750    9809
35         9838    9975
36         10010   10024
37         10069   10081
38         10162   10471
shape: (39, 2), time unit: sec.

In [68]:
start = []
end = []
for s, e in phasic:
    start.append(s/target_fs)
    end.append(e/target_fs)
phrem_interval = nap.IntervalSet(start=start, end=end)
state_epochs['phrem'] = phrem_interval

In [69]:
state_epochs['phrem']

           start     end
0         680.652  681.958
1         704.254  705.896
2         1100.82  1102.38
3         1120.01  1121.95
4         1169.08  1171.48
5         1192.43  1193.63
6         2190.71  2191.64
7         2250.05  2253.2
8         3451.06  3453.05
9         3453.99  3459.12
          ...
11        3508.24  3509.74
12        5367.44  5368.53
13        5403.55  5405.73
14        5408.43  5409.62
15        6039.91  6041.26
16        6041.59  6050.16
17        7539.55  7541.08
18        7562.05  7565.9
19        7567.56  7568.57
20        7928.62  7929.85
shape: (21, 2), time unit: sec.

Create TsdFrame for LFP recording of HPC and PFC

In [70]:
lfps = np.column_stack([lfpHPC, lfpPFC])
t = np.arange(0, len(lfpHPC)/target_fs, 1/target_fs)
tsd = nap.TsdFrame(t=t, d=lfps, columns=['HPC', 'PFC'])
tsd

Time (s)          HPC        PFC
----------  ---------  ---------
0.0          -80.2983  -198.232
0.002       -111.567   -236.122
0.004        -34.6435  -221.808
0.006        -88.0469  -152.304
0.008       -100.149   -217.981
0.01        -104.189   -187.283
0.012        -73.0968  -184.699
0.014        -51.0978  -198.999
0.016       -103.641   -196.816
0.018        -73.2302  -188.361
...
10470.98      -9.1383   -77.9168
10470.982    -85.6112  -172.53
10470.984    -36.8904  -129.081
10470.986    -53.6948  -306.563
10470.988    -89.8059  -306.069
10470.99     305.219    178.84
10470.992    136.096    -81.247
10470.994    121.219    -27.2737
10470.996     23.0028   -93.1656
10470.998     34.8787  -101.007
dtype: float64, shape: (5235500, 2)

In [71]:
tsd_list = []
for interval in rem_interval:
    tsd_list.append(tsd.restrict(interval))

In [72]:
tsd.restrict(rem_interval)

Time (s)           HPC         PFC
----------  ----------  ----------
664.0          7.38881    63.6957
664.002      -32.2174     84.5307
664.004     -103.122     118.447
664.006      -94.3268     76.2513
664.008      -86.7985     62.4011
664.01       -68.5082     25.1321
664.012      -59.6681     66.8974
664.014      -69.8047     51.4363
664.016      -55.9205     14.6035
664.018      -46.3999     25.0576
...
10161.98     -58.3848    -33.0059
10161.982    -34.7382    -47.1748
10161.984    -60.0722    -92.9125
10161.986    -93.9748   -107.778
10161.988    -73.8834    -58.4848
10161.99     -22.6279     36.2663
10161.992    -50.7517     55.5676
10161.994    -10.4319      4.49775
10161.996      6.6121     19.6085
10161.998     -4.09829   -54.7709
dtype: float64, shape: (580500, 2)

# matplotlib

In [73]:
%matplotlib qt

In [74]:
from matplotlib.colors import LinearSegmentedColormap
from scipy.signal import spectrogram

nsr_seg = 1
perc_overlap = 0.8
vmax = 3000
vmin = 0
cmap = plt.cm.hot

# Define the custom colors
colors = [[0, 0, 0], [0, 1, 1], [0.6, 0, 1], [0.8, 0.8, 0.8]]

# Create a custom colormap
my_map = LinearSegmentedColormap.from_list('brs', colors, N=5)

freq, t, SP = spectrogram(lfpHPC, fs=target_fs, window='hann', 
                          nperseg=int(nsr_seg * target_fs), 
                          noverlap=int(nsr_seg * target_fs * perc_overlap))

ifreq = np.where(freq <= 20)[0]

In [76]:
f, ax = plt.subplots(nrows=5, ncols=1, 
                     sharex=True, figsize=(12, 8),
                     gridspec_kw = {'height_ratios':[1, 8, 1, 8, 8],
                                        'hspace':0.1})
# Plot sleep states
tmp = ax[0].pcolorfast(tsd["HPC"].t, [0, 1], np.array([hypnogram]), vmin=1, vmax=5)
tmp.set_cmap(my_map)
_despine_axes(ax[0])    

# Plot HPC region
ax[1].plot(tsd["HPC"], color='k')
for epoch in rem_interval:
    start, end = epoch["start"].item(), epoch["end"].item()
    ax[1].axvspan(start, end, facecolor=[0.7, 0.7, 0.8], alpha=0.4)
    
# Plot phasicREM
[ax[1].plot(tsd["HPC"].restrict(phrem_interval[i]), color='r') for i in range(len(phrem_interval))]

#ax[1].set_title("HPC")
ax[1].set_ylabel("mV (HPC)")

# Plot phasicREM as spikes
ax[2].eventplot((phrem_interval["end"]+phrem_interval["start"])/2)
_despine_axes(ax[2])

# Plot PFC region
#ax[3].set_title("PFC")
ax[3].plot(tsd["PFC"], color='orange')
ax[3].set_xlabel("Time (s)")
ax[3].set_ylabel("mV (PFC)")

# Plot spectrogram (theta range)
pcm = ax[4].pcolorfast(t, freq[ifreq], SP[ifreq, :], vmin=vmin, vmax=vmax, cmap=cmap)
ax[4].set_ylabel("Freq. (Hz)")

Text(0, 0.5, 'Freq. (Hz)')

# fastplotlib

In [11]:
%matplotlib inline

In [17]:
import fastplotlib as fpl
from ipywidgets import Layout, VBox, FloatSlider
from sidecar import Sidecar
from workshop_utils.store_model import TimeStore
from IPython.display import display

import warnings

warnings.simplefilter('ignore')
fpl.config.party_parrot = True
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [26]:
data = np.column_stack([tsd.t, tsd['HPC'][:].d])
subsample = 1_000

In [27]:
data.shape

(5235500, 2)

In [30]:
ls_x.get_selected_indices()

array([2058012, 2058013, 2058014, ..., 3366885, 3366886, 3366887],
      dtype=int64)

In [38]:
fig = fpl.Figure((2, 1), size=(800, 400))

graphic = fig[0, 0].add_line(data=data[::subsample])
ls_x = graphic.add_linear_region_selector()

# get the initial selected date of the linear region selector
zoomed_init = ls_x.get_selected_data()

# make a line graphic for displaying zoomed data
zoomed_x = fig[1, 0].add_line(zoomed_init)

@ls_x.add_event_handler("selection")
def set_zoom_x(ev):
    """sets zoomed x selector data"""
    # get the selected data
    ixs = ev.get_selected_indices() * subsample
    selected_data = data[ixs]
    # remove the current zoomed data
    # and update with new selected data
    global zoomed_x

    fig[1, 0].remove_graphic(zoomed_x)
    zoomed_x = fig[1, 0].add_line(selected_data)
    fig[1, 0].auto_scale()

fig.show(maintain_aspect=False)

RFBOutputContext()

JupyterOutputContext(children=(JupyterWgpuCanvas(css_height='400px', css_width='800px'), IpywidgetToolBar(chil…

In [39]:
fig.close()

In [61]:
# test_example = false
import time
import fastplotlib as fpl
import numpy as np
import time

# generate some data
start, stop = 0, 2 * np.pi
increment = (2 * np.pi) / 50

# make a simple sine wave
xs = np.linspace(start, stop, 100)
ys = np.sin(xs)

figure = fpl.Figure()

# plot the image data
sine = figure[0, 0].add_line(ys, name="sine", colors="r")
figure[0, 0].set_title(f"time: 0")


start = time.time_ns()// 1_000_000

iteration = 0
# increment along the x-axis on each render loop :D
def update_line(subplot):
    global increment, start, stop, iteration
    xs = np.linspace(start + increment, stop + increment, 100)
    ys = np.sin(xs)

    start += increment
    stop += increment

    # change only the y-axis values of the line
    subplot["sine"].data[:, 1] = ys
    
    curr = round(time.time_ns() // 1_000_000 - start)
    
    subplot.set_title(f"time: {curr} ms")
    iteration += 1

figure[0, 0].add_animations(update_line)

figure.canvas.set_logical_size(700, 560)

figure[0,0].auto_scale(maintain_aspect=False)

figure.show()

RFBOutputContext()

JupyterOutputContext(children=(JupyterWgpuCanvas(), IpywidgetToolBar(children=(Button(icon='expand-arrows-alt'…

In [62]:
figure.close()