In [7]:
import os
import importlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp

from utils import read_intan as ri
from utils import signal_preprocessing as spp
from utils import plot_chuncked_data as pcd

In [8]:
# Session directory holding the raw .rhd recordings.
# NOTE(review): absolute path to an external volume - adjust per machine.
raw_path = "/Volumes/CHOO'S SSD/LINK/eeg/s-eeg/d21"
# npz_path = './data/compressed/attempt1'
# Experiment name is the final path component ('d21' here). normpath drops a
# trailing separator first, so ".../d21/" also yields 'd21' instead of ''.
exp_name = os.path.basename(os.path.normpath(raw_path))
# Passed to ri.rhd2dataframe: keys correspond to the 'pin_<key>' column names
# seen in raw_df; the 'inNN' values are presumably Intan input channel names
# (TODO confirm against utils.read_intan).
pin_map = {
    1: 'in11', 2: 'in10', 3: 'in9', 4: 'in8',
    5: 'in23', 6: 'in22', 7: 'in21', 8: 'in20',
    9: 'in12', 10: 'in13', 11: 'in14', 12: 'in15',
    13: 'in16', 14: 'in17', 15: 'in18', 16: 'in19',
}

In [9]:
# Load every RHD file under raw_path into one frame: a 'time' column plus one
# 'pin_<n>' column per mapped electrode (see the displayed head).
raw_df, metadata = ri.rhd2dataframe(
    raw_path,
    pin_map,
)

raw_df.head(n=10)

Found 29 RHD files in /Volumes/CHOO'S SSD/LINK/eeg/s-eeg/d21
Successfully loaded data from 250630 pilo_250630_155325.rhd
Successfully loaded data from 250630 pilo_250630_155825.rhd
Successfully loaded data from 250630 pilo_250630_160325.rhd
Successfully loaded data from 250630 pilo_250630_160825.rhd
Successfully loaded data from 250630 pilo_250630_161325.rhd
Successfully loaded data from 250630 pilo_250630_161825.rhd
Successfully loaded data from 250630 pilo_250630_162325.rhd
Successfully loaded data from 250630 pilo_250630_162825.rhd
Successfully loaded data from 250630 pilo_250630_163325.rhd
Successfully loaded data from 250630 pilo_250630_163825.rhd
Successfully loaded data from 250630 pilo_250630_164325.rhd
Successfully loaded data from 250630 pilo_250630_164825.rhd
Successfully loaded data from 250630 pilo_250630_165325.rhd
Successfully loaded data from 250630 pilo_250630_165825.rhd
Successfully loaded data from 250630 pilo_250630_170325.rhd
Successfully loaded data from 250630 pi

Unnamed: 0,time,pin_4,pin_3,pin_2,pin_1,pin_9,pin_7,pin_6,pin_5
0,0.0,-15.405,-14.04,45.63,-18.72,357.24,-74.88,-7.995,-54.21
1,5e-05,-19.89,-14.43,45.63,-21.45,348.855,-80.925,-8.775,-54.21
2,0.0001,-22.815,-23.985,45.63,-24.96,349.44,-80.925,-7.215,-54.6
3,0.00015,-19.11,-15.015,45.63,-21.255,357.24,-80.925,-1.755,-54.21
4,0.0002,-15.99,-11.505,45.63,-17.355,348.855,-81.9,-4.875,-45.825
5,0.00025,-11.895,6.825,51.87,-16.965,352.56,-75.075,-0.975,-52.065
6,0.0003,-11.895,6.24,61.035,-8.775,357.24,-75.855,7.41,-54.21
7,0.00035,-4.68,10.92,62.205,-8.775,357.24,-72.345,8.775,-41.145
8,0.0004,3.705,18.72,62.205,-8.775,358.8,-71.565,7.605,-39.585
9,0.00045,3.705,10.92,62.205,-7.02,363.48,-66.3,7.605,-39.585


In [10]:
df = raw_df.copy(deep=True)  # independent copy, so raw_df is never mutated downstream

In [11]:
# Bipolar montages by group name: each (a, b) pair becomes a 'pin_a-pin_b'
# column holding pin_a minus pin_b (matches the 'lat' output below).
mosaic_groups = dict(
    lat=[(4, 1), (1, 5), (5, 6), (6, 7)],
    med=[(4, 2), (2, 3), (3, 9)],
    ref_shift=[(4, 7), (1, 7), (2, 7), (5, 7), (3, 7), (6, 7), (9, 7)],
)

# df = spp.apply_common_average_reference(df)
# df = spp.zscore_normalize(df)
mosaic_data = spp.get_mosaic_df(df, mosaic_groups)

In [None]:
mosaic_data["lat"].head(n=10)  # first rows of the lateral bipolar chain

Unnamed: 0,time,pin_4-pin_1,pin_1-pin_5,pin_5-pin_6,pin_6-pin_7
0,0.0,3.315,35.49,-46.215,66.885
1,5e-05,1.56,32.76,-45.435,72.15
2,0.0001,2.145,29.64,-47.385,73.71
3,0.00015,2.145,32.955,-52.455,79.17
4,0.0002,1.365,28.47,-40.95,77.025
5,0.00025,5.07,35.1,-51.09,74.1
6,0.0003,-3.12,45.435,-61.62,83.265
7,0.00035,4.095,32.37,-49.92,81.12
8,0.0004,12.48,30.81,-47.19,79.17
9,0.00045,10.725,32.565,-47.19,73.905


: 

In [None]:
def plot_channels(df, time_min=None, time_max=None):
    """Plot each channel of ``df`` in its own stacked subplot.

    Every column other than ``'time'`` is treated as a channel. All subplots
    share the x-axis and use the same symmetric y-limits ``(-y_limit, y_limit)``,
    where ``y_limit`` is the largest finite absolute sample value inside the
    plotted window, rounded up to the next multiple of 50 (never less than 50).

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``'time'`` column (seconds); all other columns are
        plotted and must be numeric.
    time_min, time_max : float or None, optional
        Inclusive bounds applied to ``df['time']`` before plotting. ``None``
        leaves that side unbounded. When given, the bound is also used as the
        exact x-axis limit; otherwise the data's own min/max time is used.

    Returns
    -------
    fig : matplotlib.figure.Figure
    axs : numpy.ndarray of matplotlib.axes.Axes
        One Axes per channel, top to bottom in column order.

    Raises
    ------
    ValueError
        If ``df`` has no channel columns, or no rows fall inside the requested
        time window.
    """
    y_step = 50.0  # y-limits snap up to a multiple of this value

    df_plot = df
    if time_min is not None:
        df_plot = df_plot[df_plot['time'] >= time_min]
    if time_max is not None:
        df_plot = df_plot[df_plot['time'] <= time_max]

    channels = [col for col in df_plot.columns if col != 'time']
    if not channels:
        raise ValueError("plot_channels: df has no channel columns besides 'time'")
    if df_plot.empty:
        raise ValueError(
            f"plot_channels: no samples with time in [{time_min}, {time_max}]"
        )

    # Size the axes from finite samples only: a single NaN would otherwise make
    # int(np.ceil(nan)) raise, and an all-zero window would give ylim (0, 0).
    abs_vals = np.abs(df_plot[channels].to_numpy(dtype=float))
    finite_vals = abs_vals[np.isfinite(abs_vals)]
    abs_max = finite_vals.max() if finite_vals.size else 0.0
    y_limit = int(max(y_step, np.ceil(abs_max / y_step) * y_step))

    n_ch = len(channels)
    # squeeze=False keeps axs 2-D even for one channel; take its single column.
    fig, axs = plt.subplots(
        n_ch, 1, sharex=True, figsize=(10, 2 * n_ch), squeeze=False
    )
    axs = axs[:, 0]
    for ax, ch in zip(axs, channels):
        ax.plot(df_plot['time'], df_plot[ch])
        ax.set_ylim(-y_limit, y_limit)
        ax.set_ylabel(ch)

    # Exact x-range with no padding: requested bounds if given, else data extent.
    x_start = time_min if time_min is not None else df_plot['time'].min()
    x_end = time_max if time_max is not None else df_plot['time'].max()
    for ax in axs:
        ax.set_xlim(x_start, x_end)
        ax.margins(x=0)
    axs[-1].set_xlabel('time (s)')
    fig.tight_layout()
    return fig, axs

# Plot window in seconds, applied to every montage group below.
start_time = 200
duration = 15  # Duration in seconds to plot
end_time = start_time + duration

for group, mosaic_df in mosaic_data.items():
    # mosaic_df = spp.remove_artifacts_hilbert_kurtosis(mosaic_df)
    # Presumably a 1-100 Hz, 4th-order band-pass over the whole trace; cropping
    # happens afterwards inside plot_channels, so filter edge transients should
    # fall outside the plotted window (TODO confirm bandpass_filter semantics).
    mosaic_df = spp.bandpass_filter(mosaic_df, 1, 100, order=4)
    # mosaic_df = spp.resample_dataframe(mosaic_df, 1000)
    fig, axs = plot_channels(mosaic_df, start_time, end_time)
    # Title each figure with its montage group; without this the per-group
    # figures are distinguishable only by their y-axis channel names.
    fig.suptitle(f"{exp_name} | {group} montage, {start_time}-{end_time} s")
    fig.tight_layout()
    plt.show()


Estimated sampling rate: 20000.0 Hz
