# Create window datasets

In [1]:
import os
import numpy as np
import pandas as pd
from imblearn.under_sampling import RandomUnderSampler

# Data locations and windowing parameters used by the rest of this cell.
# Root of the ROAMM folder; adjust this path for the machine the notebook runs on.
roamm_root = r"/gpfs1/pi/djangraw/mindless_reading/ROAMM"
# Directory that is scanned for per-subject folders (names starting with 's')
ml_data_root = os.path.join(roamm_root, 'subject_ml_data')
# Seed handed to the random undersampler so class balancing is reproducible
random_seed = 42
# Sampling rate (presumably in Hz -- confirm against the recording setup) and
# window duration; the window length in samples is int(sfreq * window_seconds)
sfreq = 256
window_seconds = 0.25

# Build one balanced, windowed dataset per subject and save it under
# <subject_dir>/window_datasets as three files: windowed data, windowed
# labels, and the column names describing the feature axis.
all_subjects = sorted(
    d for d in os.listdir(ml_data_root)
    if d.startswith('s') and os.path.isdir(os.path.join(ml_data_root, d))
)
# Number of samples per window; constant for every subject.
window_size = int(sfreq * window_seconds)
bool_cols = ['is_blink', 'is_saccade', 'is_fixation', 'is_mw', 'first_pass_reading']

# Defined up front so `df` exists even when every subject is skipped.
df = pd.DataFrame()
for subject_id in all_subjects:
    subject_dir = os.path.join(ml_data_root, subject_id)
    save_dir = os.path.join(subject_dir, 'window_datasets')
    os.makedirs(save_dir, exist_ok=True)
    # Three files (data, labels, column names) mean this subject is done.
    if len(os.listdir(save_dir)) == 3:
        print(f'Windowed data and labels for subject {subject_id} have already been saved.')
        continue
    # Sort so runs are concatenated in a deterministic order; os.listdir
    # returns entries in arbitrary order.
    pkl_files = sorted(f for f in os.listdir(subject_dir) if f.endswith('.pkl'))

    # make sure each subject has 5 runs of data
    if len(pkl_files) != 5:
        raise ValueError(f"Subject {subject_id} has {len(pkl_files)} runs instead of 5")

    # Collect this subject's runs in a fresh list so that no samples from
    # previously processed subjects leak into this subject's dataset.
    run_frames = []
    for pkl_file in pkl_files:
        run_df = pd.read_pickle(os.path.join(subject_dir, pkl_file))
        # Explicit copy: we assign columns below and must not write into a
        # view of the unfiltered frame (avoids SettingWithCopyWarning).
        run_df = run_df[run_df['first_pass_reading'] == 1].copy()
        # convert bool cols explicitly to avoid pandas warning
        for col in bool_cols:
            run_df[col] = run_df[col] == True

        # keep only samples earlier than 2 time units before page end
        # (presumably seconds -- confirm the units of 'time'/'page_end')
        mask = run_df['time'] < run_df['page_end'] - 2
        run_frames.append(run_df[mask])

    # Concatenate once per subject instead of growing a frame in the loop.
    df = pd.concat(run_frames)
    # add subject id to the dataframe
    df['subject_id'] = subject_id
    print(f'Subject {subject_id} has been loaded.')

    # normalize pupil size features by this subject's own median
    df['blink_interp_LPupil_norm'] = df['blink_interp_LPupil'] / df['blink_interp_LPupil'].median()
    df['blink_interp_RPupil_norm'] = df['blink_interp_RPupil'] / df['blink_interp_RPupil'].median()

    windowed_data = []
    windowed_labels = []

    # Process data in non-overlapping chunks of window_size rows.
    # NOTE(review): windows are cut from the concatenated, filtered frame, so
    # a window can straddle a page/run boundary -- confirm this is intended.
    for start in range(0, len(df), window_size):
        window = df.iloc[start:start + window_size]
        # Skip the trailing window if it is too small
        if len(window) < window_size:
            continue
        # Skip windows with mixed labels
        labels_in_window = window['is_mw'].unique()
        if len(labels_in_window) > 1:
            continue

        # Keep the window as a 2D array (window_size x n_features)
        windowed_data.append(window.values)
        # Use the consistent label
        windowed_labels.append(labels_in_window[0])

    if not windowed_data:
        raise ValueError(f"Subject {subject_id} has no complete windows with a consistent label")

    # Use RandomUnderSampler on flattened data, then recover 3D structure
    n_features = windowed_data[0].shape[1]
    windowed_data_flat = [w.flatten() for w in windowed_data]
    undersampler = RandomUnderSampler(random_state=random_seed)
    X_resampled_flat, y_resampled = undersampler.fit_resample(windowed_data_flat, windowed_labels)
    # Recover 3D array: (n_samples, window_size, n_features)
    X_resampled = np.array(X_resampled_flat).reshape(-1, window_size, n_features)
    X = np.transpose(X_resampled, (0, 2, 1))  # (N, num_channels, window_size)
    y = np.array(y_resampled, dtype=int)

    # column names describe the feature (channel) axis of X
    col_names = df.columns.tolist()

    # save windowed data and labels
    np.save(os.path.join(save_dir, f'{subject_id}_{window_size}windowed_data.npy'), X)
    np.save(os.path.join(save_dir, f'{subject_id}_{window_size}windowed_labels.npy'), y)

    # for col names, only save one copy
    file_path = os.path.join(save_dir, f'{subject_id}_col_names.npy')
    if not os.path.exists(file_path):
        np.save(file_path, col_names)

    print(f'Windowed data and labels for subject {subject_id} have been saved.')

ERROR:tornado.general:Uncaught exception in ZMQStream callback
Traceback (most recent call last):
  File "/gpfs1/pi/djangraw/hsun11/miniconda3/envs/iclr26/lib/python3.12/site-packages/traitlets/traitlets.py", line 632, in get
    value = obj._trait_values[self.name]
            ~~~~~~~~~~~~~~~~~^^^^^^^^^^^
KeyError: '_control_lock'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/gpfs1/pi/djangraw/hsun11/miniconda3/envs/iclr26/lib/python3.12/site-packages/zmq/eventloop/zmqstream.py", line 575, in _log_error
    f.result()
  File "/gpfs1/pi/djangraw/hsun11/miniconda3/envs/iclr26/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 301, in dispatch_control
    async with self._control_lock:
               ^^^^^^^^^^^^^^^^^^
  File "/gpfs1/pi/djangraw/hsun11/miniconda3/envs/iclr26/lib/python3.12/site-packages/traitlets/traitlets.py", line 687, in __get__
    return t.cast(G, self.get(obj, cls))  # the G should encode 

Windowed data and labels for subject s10014 have already been saved.
Windowed data and labels for subject s10052 have already been saved.
Windowed data and labels for subject s10059 have already been saved.
Windowed data and labels for subject s10073 have already been saved.
Windowed data and labels for subject s10081 have already been saved.
Windowed data and labels for subject s10084 have already been saved.
Windowed data and labels for subject s10085 have already been saved.
Windowed data and labels for subject s10089 have already been saved.
Windowed data and labels for subject s10094 have already been saved.
Windowed data and labels for subject s10100 have already been saved.
Windowed data and labels for subject s10103 have already been saved.
Windowed data and labels for subject s10110 have already been saved.
Windowed data and labels for subject s10111 have already been saved.
Windowed data and labels for subject s10115 have already been saved.
Windowed data and labels for subje