## Imports

In [81]:
from pathlib import Path
from itertools import compress
import re

import mne
from mne.io import read_raw_snirf
from mne.preprocessing.nirs import scalp_coupling_index, beer_lambert_law, temporal_derivative_distribution_repair as tddr
from mne import Epochs, events_from_annotations
from mne.filter import filter_data

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# matplotlib.use('Qt5Agg')
# %matplotlib qt
# %matplotlib inline

## Load Data

In [56]:
# Base path to BIDS dataset
data_path = Path("data/nemo-bids")

# Pick one subject and task
subject = "sub-133"
task = "empe"

# Construct the file path
snirf_file = data_path / subject / "nirs" / f"{subject}_task-{task}_nirs.snirf"

# Read the SNIRF file
raw = read_raw_snirf(snirf_file, preload=True)
print("Data loaded successfully!")

Loading C:\Users\humag\Masaüstü\Research Internship\fNIRS_NEMO_Data_Preprocessing_and_Analysis\data\nemo-bids\sub-133\nirs\sub-133_task-empe_nirs.snirf
Reading 0 ... 155979  =      0.000 ...  3119.580 secs...
Data loaded successfully!


## Preprocessing

In [58]:
# SCI
sci = mne.preprocessing.nirs.scalp_coupling_index(raw)
raw.info["bads"] = list(compress(raw.ch_names, sci < 0.8))   # mark all channels with a SCI less than 0.8 as bad
raw.interpolate_bads()    # default method -> 'nearest'    copies nearest channel

# since raw data is already od data I want to change the name to avoid confusion
raw_od = raw

# tddr
raw_od = tddr(raw_od)

# beer lambert law - Convert OD to HbO and HbR
raw_hb = beer_lambert_law(raw_od, ppf=6.0)

# filtering
raw_hb = raw_hb.filter(l_freq=0.01, h_freq=0.1,
                       l_trans_bandwidth=0.004,
                       h_trans_bandwidth=0.01,
                       verbose=False)
# extract events
events, event_dict = mne.events_from_annotations(raw_hb)

# Create metadata manually for epochs
event_ids = [e[-1] for e in events]
labels = [k for id_ in event_ids for k, v in event_dict.items() if v == id_]
metadata = pd.DataFrame({'condition': labels, 'subject': ['sub-133'] * len(labels)})

# epochs
epochs = mne.Epochs(raw_hb, events, event_id=event_dict,
                    metadata=metadata,
                    tmin=-5.0, tmax=12.0,
                    baseline=(None, 0),
                    preload=True,
                    reject=dict(hbo=80e-6),
                    verbose=False)

Setting channel interpolation method to {'fnirs': 'nearest'}.


  raw.interpolate_bads()    # default method -> 'nearest'    copies nearest channel


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']


In [63]:
print(events[:5])
print(event_dict)

[[19823     0     3]
 [21212     0     1]
 [22438     0     4]
 [23551     0     2]
 [24671     0     3]]
{'HANV': 1, 'HAPV': 2, 'LANV': 3, 'LAPV': 4}


## Feature engineering

In [64]:
# Only pick HbO channels
hbo_epochs = epochs.copy().pick(picks="hbo")

X = hbo_epochs.get_data()[:, :, 250:-1]  # from 0s to 12s

# Why 250? → 5s × 50Hz = 250
print(X.shape)  # (40, 24, 600)

(40, 24, 600)


In [65]:
ch_names = hbo_epochs.ch_names  # ['S5_D3 hbo', ..., 'S10_D8 hbo']
n_trials, n_channels, n_times = X.shape

# Flatten time series for each trial
X_flat = X.reshape(n_trials, n_channels * n_times)

# Create column names like: 'S5_D3 hbo_t1', ..., 'S5_D3 hbo_t600'
columns = [f'{ch}_t{t+1}' for ch in ch_names for t in range(n_times)]

# Make dataframe (each channels 600 HBO values)
df = pd.DataFrame(X_flat, columns=columns)
df.head(2)

Unnamed: 0,S1_D1 hbo_t1,S1_D1 hbo_t2,S1_D1 hbo_t3,S1_D1 hbo_t4,S1_D1 hbo_t5,S1_D1 hbo_t6,S1_D1 hbo_t7,S1_D1 hbo_t8,S1_D1 hbo_t9,S1_D1 hbo_t10,...,S10_D8 hbo_t591,S10_D8 hbo_t592,S10_D8 hbo_t593,S10_D8 hbo_t594,S10_D8 hbo_t595,S10_D8 hbo_t596,S10_D8 hbo_t597,S10_D8 hbo_t598,S10_D8 hbo_t599,S10_D8 hbo_t600
0,-1.41128e-07,-1.419696e-07,-1.428005e-07,-1.436206e-07,-1.444298e-07,-1.452279e-07,-1.46015e-07,-1.467908e-07,-1.475555e-07,-1.483089e-07,...,1.249216e-07,1.257428e-07,1.265623e-07,1.273802e-07,1.281964e-07,1.290108e-07,1.298233e-07,1.306339e-07,1.314425e-07,1.32249e-07
1,-9.442113e-09,-9.975142e-09,-1.05133e-08,-1.105656e-08,-1.160487e-08,-1.215823e-08,-1.271658e-08,-1.327993e-08,-1.38482e-08,-1.442136e-08,...,-1.028862e-07,-1.026222e-07,-1.023605e-07,-1.02101e-07,-1.018437e-07,-1.015887e-07,-1.01336e-07,-1.010856e-07,-1.008376e-07,-1.005919e-07


In [66]:
# Add labels from metadata
df['label'] = hbo_epochs.metadata['condition'].map(event_dict)

# Optional: add subject ID
df['subject'] = hbo_epochs.metadata['subject']

In [67]:
df.head()

Unnamed: 0,S1_D1 hbo_t1,S1_D1 hbo_t2,S1_D1 hbo_t3,S1_D1 hbo_t4,S1_D1 hbo_t5,S1_D1 hbo_t6,S1_D1 hbo_t7,S1_D1 hbo_t8,S1_D1 hbo_t9,S1_D1 hbo_t10,...,S10_D8 hbo_t593,S10_D8 hbo_t594,S10_D8 hbo_t595,S10_D8 hbo_t596,S10_D8 hbo_t597,S10_D8 hbo_t598,S10_D8 hbo_t599,S10_D8 hbo_t600,label,subject
0,-1.41128e-07,-1.419696e-07,-1.428005e-07,-1.436206e-07,-1.444298e-07,-1.452279e-07,-1.46015e-07,-1.467908e-07,-1.475555e-07,-1.483089e-07,...,1.265623e-07,1.273802e-07,1.281964e-07,1.290108e-07,1.298233e-07,1.306339e-07,1.314425e-07,1.32249e-07,3,sub-133
1,-9.442113e-09,-9.975142e-09,-1.05133e-08,-1.105656e-08,-1.160487e-08,-1.215823e-08,-1.271658e-08,-1.327993e-08,-1.38482e-08,-1.442136e-08,...,-1.023605e-07,-1.02101e-07,-1.018437e-07,-1.015887e-07,-1.01336e-07,-1.010856e-07,-1.008376e-07,-1.005919e-07,1,sub-133
2,4.832059e-08,4.847087e-08,4.861552e-08,4.875452e-08,4.888786e-08,4.901548e-08,4.913743e-08,4.925366e-08,4.936418e-08,4.946894e-08,...,-1.018578e-07,-1.022616e-07,-1.026671e-07,-1.030743e-07,-1.034831e-07,-1.038935e-07,-1.043054e-07,-1.047188e-07,4,sub-133
3,-5.593837e-09,-6.21848e-09,-6.848186e-09,-7.482923e-09,-8.122568e-09,-8.767094e-09,-9.416401e-09,-1.00704e-08,-1.072901e-08,-1.139213e-08,...,-3.075147e-08,-3.080947e-08,-3.086588e-08,-3.092065e-08,-3.097374e-08,-3.102514e-08,-3.107478e-08,-3.112263e-08,2,sub-133
4,-8.244687e-08,-8.327193e-08,-8.409038e-08,-8.49021e-08,-8.570694e-08,-8.650473e-08,-8.729536e-08,-8.807867e-08,-8.885453e-08,-8.96228e-08,...,1.605167e-07,1.602209e-07,1.599317e-07,1.596489e-07,1.593726e-07,1.591029e-07,1.588398e-07,1.585834e-07,3,sub-133


In [68]:
df.shape

(40, 14402)

# Steps for all subjects

In [90]:
# Path to dataset
data_path = Path("data/nemo-bids")
task = "empe"

# List of all subject folders
subject_dirs = sorted([p.name for p in (data_path).glob("sub-*")])
print(f"Found {len(subject_dirs)} subjects.")

all_dfs = []

for subject in subject_dirs:
    try:
        print(f"Processing {subject}...")

        snirf_file = data_path / subject / "nirs" / f"{subject}_task-{task}_nirs.snirf"
        raw = mne.io.read_raw_snirf(snirf_file, preload=True, verbose=False)

        # 1. SCI-based bad channel detection
        sci = scalp_coupling_index(raw)
        raw.info["bads"] = list(compress(raw.ch_names, sci < 0.8))
        raw.interpolate_bads()  # default 'nearest'

        # 2. Motion correction
        raw_od = raw  # already OD
        raw_od = tddr(raw_od)

        # 3. Convert to HbO/HbR
        raw_hb = beer_lambert_law(raw_od, ppf=6.0)

        # 4. Band-pass filtering
        raw_hb = raw_hb.filter(l_freq=0.01, h_freq=0.1,
                               l_trans_bandwidth=0.004,
                               h_trans_bandwidth=0.01,
                               verbose=False)

        # 5. Extract events and build metadata
        events, event_dict = mne.events_from_annotations(raw_hb)
        event_ids = [e[-1] for e in events]
        labels = [k for id_ in event_ids for k, v in event_dict.items() if v == id_]
        metadata = pd.DataFrame({'condition': labels, 'subject': [subject] * len(labels)})

        # 6. Epoching
        epochs = mne.Epochs(raw_hb, events, event_id=event_dict,
                            metadata=metadata,
                            tmin=-5.0, tmax=12.0,
                            baseline=(None, 0),
                            preload=True,
                            reject=dict(hbo=80e-6),
                            verbose=False)

        # 7. Pick only HbO channels
        hbo_epochs = epochs.copy().pick(picks="hbo")

        # 8. Cut to 0–12s post-stimulus (drop pre-stim)
        X = hbo_epochs.get_data()[:, :, 250:-1]  # (trials, channels, 600 samples)

        ch_names = hbo_epochs.ch_names
        n_trials, n_channels, n_times = X.shape
        X_flat = X.reshape(n_trials, n_channels * n_times)

        # Build column names like: 'S5_D3 hbo_t1', ..., 'S10_D8 hbo_t600'
        columns = [f'{ch}_t{t+1}' for ch in ch_names for t in range(n_times)]
        df = pd.DataFrame(X_flat, columns=columns)

        # Add label (1-4)
        # event_dict: {'HANV': 1, 'HAPV': 2, 'LANV': 3, 'LAPV': 4}
        df['label'] = hbo_epochs.metadata['condition'].map(event_dict) 

        # Add subject ID
        df['subject'] = subject

        all_dfs.append(df)

    except Exception as e:
        print(f"Error processing {subject}: {e}")

# Combine all subjects into one big DataFrame
df_all_subjects = pd.concat(all_dfs, ignore_index=True)
print(" All subjects processed and combined.")


Found 31 subjects.
Processing sub-101...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-105...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-107...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-108...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-109...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-112...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-113...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-119...
Setting channel interpolation method to {'fnirs': 'nearest'}.


  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-120...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-121...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-123...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-124...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-125...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-126...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-127...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-129...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-130...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-131...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-133...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-134...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-139...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-140...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-141...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-142...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-143...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-144...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-145...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-146...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-147...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-148...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
Processing sub-149...
Setting channel interpolation method to {'fnirs': 'nearest'}.
Interpolating bad channels.
    Automatic origin fit: head of radius 168.7 mm


  raw.interpolate_bads()  # 'nearest'
  raw.interpolate_bads()  # 'nearest'


Used Annotations descriptions: ['HANV', 'HAPV', 'LANV', 'LAPV']
✅ All subjects processed and combined.


In [91]:
df_all_subjects.tail()

Unnamed: 0,S1_D1 hbo_t1,S1_D1 hbo_t2,S1_D1 hbo_t3,S1_D1 hbo_t4,S1_D1 hbo_t5,S1_D1 hbo_t6,S1_D1 hbo_t7,S1_D1 hbo_t8,S1_D1 hbo_t9,S1_D1 hbo_t10,...,S10_D8 hbo_t593,S10_D8 hbo_t594,S10_D8 hbo_t595,S10_D8 hbo_t596,S10_D8 hbo_t597,S10_D8 hbo_t598,S10_D8 hbo_t599,S10_D8 hbo_t600,label,subject
1198,-4.053866e-08,-4.124364e-08,-4.19498e-08,-4.265702e-08,-4.336523e-08,-4.407438e-08,-4.478438e-08,-4.549517e-08,-4.620669e-08,-4.691891e-08,...,-2.661735e-08,-2.679198e-08,-2.696465e-08,-2.713535e-08,-2.730408e-08,-2.747078e-08,-2.763546e-08,-2.779808e-08,3,sub-149
1199,-2.608959e-08,-2.623023e-08,-2.636795e-08,-2.650277e-08,-2.663465e-08,-2.676362e-08,-2.688966e-08,-2.701273e-08,-2.713282e-08,-2.724996e-08,...,-4.736868e-08,-4.752574e-08,-4.768331e-08,-4.784143e-08,-4.800008e-08,-4.815923e-08,-4.831891e-08,-4.84791e-08,2,sub-149
1200,-4.286421e-08,-4.302318e-08,-4.318004e-08,-4.333484e-08,-4.34876e-08,-4.363829e-08,-4.378692e-08,-4.39335e-08,-4.407805e-08,-4.422057e-08,...,-2.589318e-08,-2.565082e-08,-2.54097e-08,-2.516989e-08,-2.493146e-08,-2.469448e-08,-2.445903e-08,-2.422514e-08,4,sub-149
1201,-1.328963e-08,-1.332561e-08,-1.336543e-08,-1.340915e-08,-1.345685e-08,-1.350859e-08,-1.356442e-08,-1.362443e-08,-1.368866e-08,-1.375718e-08,...,2.333366e-07,2.338698e-07,2.344028e-07,2.349356e-07,2.354682e-07,2.360004e-07,2.365321e-07,2.370634e-07,1,sub-149
1202,1.443433e-08,1.439158e-08,1.43457e-08,1.429662e-08,1.424432e-08,1.418878e-08,1.412995e-08,1.406786e-08,1.400247e-08,1.393378e-08,...,1.109637e-07,1.109368e-07,1.109052e-07,1.108689e-07,1.108278e-07,1.10782e-07,1.107314e-07,1.106759e-07,3,sub-149


In [92]:
df_all_subjects['label'].value_counts()

label
3    301
4    301
2    301
1    300
Name: count, dtype: int64

## windowing

In [93]:
# Identify HBO time-series columns (exclude subject and label)
signal_cols = df_all_subjects.columns.difference(['subject', 'label']).tolist()

#################### Order channels correctly ############################
# Extract unique channel names (e.g., S1_D1) from time-series columns
def sort_key(ch):
    nums = re.findall(r'\d+', ch)  # extract numbers like ['1', '1']
    return list(map(int, nums))    # turn into [1, 1]

channel_names = sorted(
    {re.match(r'(.*) hbo_t\d+', col).group(1) for col in signal_cols},
    key=sort_key
)

# Sanity check: 24 channels expected
assert len(channel_names) == 24, f"Expected 24 channels, got {len(channel_names)}"

# Sort time steps per channel to maintain consistent order
channel_to_times = {
    ch: sorted([col for col in signal_cols if col.startswith(ch)]) for ch in channel_names
}

############################### Windowing ##################################

# Feature extraction
features = []
for _, row in df_all_subjects.iterrows():
    trial_data = np.array([
        row[channel_to_times[ch]].values for ch in channel_names
    ])  # Shape: (24, 600)
    
    # Split into 3 windows
    means = []
    for w in range(3):
        window = trial_data[:, w*200:(w+1)*200]  # 200 samples per window
        window_mean = window.mean(axis=1)        # shape: (24,)
        means.append(window_mean)
    
    # Concatenate into (72,) vector
    trial_features = np.concatenate(means)
    features.append(trial_features)

# Create proper column names like S1_D1_w1, S1_D1_w2, S1_D1_w3 ...
feature_names = [f"{ch}_w{w+1}" for ch in channel_names for w in range(3)]

# Final DataFrame
df = pd.DataFrame(features, columns=feature_names) # hbo channels (72 columns)
df['label'] = df_all_subjects['label'].values # add label
df['subject'] = df_all_subjects['subject'].values # add subject 

# bring subject and label to front
cols = ['subject', 'label'] + feature_names
df = df[cols]

df.head()

Unnamed: 0,subject,label,S1_D1_w1,S1_D1_w2,S1_D1_w3,S1_D2_w1,S1_D2_w2,S1_D2_w3,S2_D1_w1,S2_D1_w2,...,S9_D6_w3,S9_D8_w1,S9_D8_w2,S9_D8_w3,S10_D7_w1,S10_D7_w2,S10_D7_w3,S10_D8_w1,S10_D8_w2,S10_D8_w3
0,sub-101,3,-7.120779e-08,-1.591482e-08,4.037125e-08,1.059506e-07,-1.591482e-08,5.576864e-09,-5.74561e-08,1.158969e-09,...,-6.338075e-08,4.680183e-08,-6.375286e-08,1.427296e-08,4.680183e-08,-7.726111e-09,1.427296e-08,-7.726111e-09,1.427296e-08,-7.726111e-09
1,sub-101,1,-7.105098e-08,-1.428847e-07,-1.66149e-07,-1.067971e-07,-1.428847e-07,-1.549473e-07,-1.100738e-07,-9.783376e-08,...,-9.825674e-08,-1.156452e-07,-1.293089e-07,-1.468958e-07,-1.156452e-07,-1.101812e-07,-1.468958e-07,-1.101812e-07,-1.468958e-07,-1.101812e-07
2,sub-101,4,6.854295e-08,1.563502e-07,1.207582e-07,1.066198e-07,1.563502e-07,1.385606e-07,1.21298e-07,4.591666e-08,...,4.139978e-08,2.738268e-08,7.855582e-08,1.718898e-08,2.738268e-08,4.968935e-09,1.718898e-08,4.968935e-09,1.718898e-08,4.968935e-09
3,sub-101,2,1.206235e-07,1.431428e-07,2.246828e-08,2.433219e-08,1.431428e-07,1.236011e-07,1.035702e-07,1.868319e-09,...,-1.115965e-07,-2.759947e-08,-1.808735e-07,5.235214e-08,-2.759947e-08,1.151776e-08,5.235214e-08,1.151776e-08,5.235214e-08,1.151776e-08
4,sub-101,3,-1.319688e-07,-1.376738e-07,-2.00167e-07,-1.70707e-07,-1.376738e-07,-6.784068e-08,-1.378276e-07,-6.926143e-08,...,-1.588279e-07,-1.109203e-07,-2.869832e-07,-1.373283e-07,-1.109203e-07,-1.356308e-07,-1.373283e-07,-1.356308e-07,-1.373283e-07,-1.356308e-07


In [94]:
df.columns.unique() # 72 + 2 cols

Index(['subject', 'label', 'S1_D1_w1', 'S1_D1_w2', 'S1_D1_w3', 'S1_D2_w1',
       'S1_D2_w2', 'S1_D2_w3', 'S2_D1_w1', 'S2_D1_w2', 'S2_D1_w3', 'S2_D3_w1',
       'S2_D3_w2', 'S2_D3_w3', 'S3_D1_w1', 'S3_D1_w2', 'S3_D1_w3', 'S3_D2_w1',
       'S3_D2_w2', 'S3_D2_w3', 'S3_D3_w1', 'S3_D3_w2', 'S3_D3_w3', 'S3_D4_w1',
       'S3_D4_w2', 'S3_D4_w3', 'S4_D2_w1', 'S4_D2_w2', 'S4_D2_w3', 'S4_D4_w1',
       'S4_D4_w2', 'S4_D4_w3', 'S5_D3_w1', 'S5_D3_w2', 'S5_D3_w3', 'S5_D4_w1',
       'S5_D4_w2', 'S5_D4_w3', 'S6_D5_w1', 'S6_D5_w2', 'S6_D5_w3', 'S6_D6_w1',
       'S6_D6_w2', 'S6_D6_w3', 'S7_D5_w1', 'S7_D5_w2', 'S7_D5_w3', 'S7_D7_w1',
       'S7_D7_w2', 'S7_D7_w3', 'S8_D5_w1', 'S8_D5_w2', 'S8_D5_w3', 'S8_D6_w1',
       'S8_D6_w2', 'S8_D6_w3', 'S8_D7_w1', 'S8_D7_w2', 'S8_D7_w3', 'S8_D8_w1',
       'S8_D8_w2', 'S8_D8_w3', 'S9_D6_w1', 'S9_D6_w2', 'S9_D6_w3', 'S9_D8_w1',
       'S9_D8_w2', 'S9_D8_w3', 'S10_D7_w1', 'S10_D7_w2', 'S10_D7_w3',
       'S10_D8_w1', 'S10_D8_w2', 'S10_D8_w3'],
      dtype='objec

In [95]:
df['label'].value_counts()

label
3    301
4    301
2    301
1    300
Name: count, dtype: int64

In [99]:
print("Shape:", df.shape)
print("Unique labels:", df['label'].unique())
print("Subjects:", df['subject'].nunique(), "total trials:", df.shape[0])

Shape: (1203, 74)
Unique labels: [3 1 4 2]
Subjects: 31 total trials: 1203


## save to csv's

In [98]:
df.to_csv("all_subs_windowed.csv", index=False)
df_all_subjects.to_csv("all_subs_600timepoints.csv", index=False)