<a href="https://www.kaggle.com/code/axha241419/preprocessing?scriptVersionId=288851252" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# COG-BCI Dataset Preprocessing & Windowing

**Purpose**: Convert raw COG-BCI `.set` files to processed `.pkl` files ready for training.

This notebook:
- Loads raw COG-BCI `.set` files (BIDS format)
- **Detects and interpolates bad channels** via abnormal variance
- Applies preprocessing (band-pass filter, notch filter, average reference, resampling)
- Creates 1-second windows with 0.25s hop
- **Rejects artifact-contaminated windows** via amplitude thresholding
- Computes band-power features (delta, theta, alpha, beta, gamma)
- Computes **next-second labels** for prediction tasks
- Performs subject-wise train/val/test splits
- Computes and saves training normalization statistics
- Saves processed data as `.pkl` files for training pipeline

## 1. Install Dependencies

In [None]:
# Install the processing dependencies (IPython "!" shell escape — only valid inside a notebook).
!pip install mne pandas scipy numpy


## 2. Configuration

In [None]:
import os
import glob
import pickle
import numpy as np
import pandas as pd


# -------- NumPy compatibility patches - MUST run BEFORE importing MNE --------
# NumPy 2.0 deprecated np.trapz (renamed np.trapezoid) and np.in1d (superseded
# by np.isin); later releases may drop the old names. Some MNE versions still
# call them, so recreate them when they are missing.

# Patch 1: np.trapz -> np.trapezoid (same signature: y, x=None, dx=1.0, axis=-1).
# Aliasing the real routine keeps `axis` handling correct; the previous
# hand-written fallback always sliced the last axis regardless of `axis`.
if not hasattr(np, "trapz"):
    np.trapz = np.trapezoid

# Patch 2: np.in1d -> np.isin
if not hasattr(np, "in1d"):
    def _in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):
        """Membership test with np.in1d semantics (always a flat 1-D result)."""
        # np.isin preserves ar1's shape, whereas in1d flattened ar1 first.
        return np.isin(np.ravel(ar1), ar2, assume_unique=assume_unique,
                       invert=invert, kind=kind)

    np.in1d = _in1d

# NOW import MNE and scipy
import mne
from scipy.signal import welch


# Configuration
# The raw COG-BCI recordings are spread across three Kaggle input datasets.
RAW_DATA_DIRS = [
    "/kaggle/input/cognivue-raw-data",
    "/kaggle/input/cog-eeg-dataset",
    "/kaggle/input/cog-eeg-dataset-3",
]

print("Dataset directories:")
print("\n".join(f"  - {d}" for d in RAW_DATA_DIRS))
    
SUBJECT_SPLITS_CSV = "/kaggle/input/cog-config/subject_splits.csv"

# Output tree:
#   PROCESSED_DIR/{train,val,test}/  -> split-wise .pkl files
#   PROCESSED_DIR/done_flags/        -> per-recording completion markers
#   PROCESSED_DIR/train_stats.json   -> per-channel normalisation statistics
PROCESSED_DIR = "/kaggle/working/data/processed"
DONE_FLAGS_DIR = os.path.join(PROCESSED_DIR, "done_flags")
TRAIN_STATS_PATH = os.path.join(PROCESSED_DIR, "train_stats.json")

os.makedirs(PROCESSED_DIR, exist_ok=True)
for split in ("train", "val", "test"):
    os.makedirs(os.path.join(PROCESSED_DIR, split), exist_ok=True)
os.makedirs(DONE_FLAGS_DIR, exist_ok=True)

# --- SAMPLING & WINDOWING ---
RAW_SAMPLING_RATE = 512
PROCESSED_SAMPLING_RATE = 256
WINDOW_SIZE_SEC = 1.0
HOP_SIZE_SEC = 0.25
WINDOW_SIZE_SAMPLES = int(WINDOW_SIZE_SEC * PROCESSED_SAMPLING_RATE)  # 256
HOP_SIZE_SAMPLES = int(HOP_SIZE_SEC * PROCESSED_SAMPLING_RATE)        # 64

# Channels per window expected by the normalisation / training stages.
# NOTE(review): CHANNEL_TO_REGION below lists 59 channel names, one more than
# this value — confirm which channel the recordings actually lack.
NUM_CHANNELS = 58

# --- ARTIFACT REJECTION & BAD CHANNEL THRESHOLDS ---
ARTIFACT_THRESHOLD_UV = 100.0  # µV - reject windows whose peak |amplitude| exceeds this
BAD_CHANNEL_STD_THRESHOLD = 5.0  # std devs above the mean channel variance => bad channel

# --- FREQUENCY BANDS ---
# Ordered: column i of every band-power matrix corresponds to the i-th entry.
FREQUENCY_BANDS = {
    "delta": (1, 4),
    "theta": (4, 8),
    "alpha": (8, 13),
    "beta":  (13, 30),
    "gamma": (30, 45),
}

# --- TASK ID MAPPING (BIDS task labels) ---
TASK_NAME_TO_ID = {
    "nback": 0,
    "matb": 1,
    "pvt":  2,
    "flanker": 3,
}

# --- CHANNEL -> REGION MAP ---
CHANNEL_TO_REGION = {
    "Fp1": "frontal", "Fp2": "frontal", "AF3": "frontal", "AF4": "frontal",
    "AF7": "frontal", "AF8": "frontal", "F1": "frontal", "F2": "frontal",
    "F3": "frontal", "F4": "frontal", "F5": "frontal", "F6": "frontal",
    "F7": "frontal", "F8": "frontal", "Fz": "fronto-central", "FC1": "fronto-central",
    "FC2": "fronto-central", "FC3": "fronto-central", "FC4": "fronto-central",
    "FC5": "fronto-central", "FC6": "fronto-central", "FCz": "fronto-central",
    "C1": "central", "C2": "central", "C3": "central", "C4": "central",
    "C5": "central", "C6": "central", "Cz": "central", "T7": "temporal-left",
    "TP7": "temporal-left", "FT7": "temporal-left", "T8": "temporal-right",
    "TP8": "temporal-right", "FT8": "temporal-right", "CP1": "parietal",
    "CP2": "parietal", "CP3": "parietal", "CP4": "parietal", "CP5": "parietal",
    "CP6": "parietal", "CPz": "parietal", "P1": "parietal", "P2": "parietal",
    "P3": "parietal", "P4": "parietal", "P5": "parietal", "P6": "parietal",
    "P7": "parietal", "P8": "parietal", "Pz": "parietal", "PO3": "parietal",
    "PO4": "parietal", "PO7": "parietal", "PO8": "parietal", "POz": "parietal",
    "O1": "occipital", "O2": "occipital", "Oz": "occipital",
}

REGIONS = {
    "frontal": 0, "fronto-central": 1, "central": 2,
    "temporal-left": 3, "temporal-right": 4,
    "parietal": 5, "occipital": 6,
}

# Derived lookups: channel -> region id, canonical channel order, channel -> index.
CHANNEL_TO_REGION_ID = {ch: REGIONS.get(reg, 0) for ch, reg in CHANNEL_TO_REGION.items()}
CHANNEL_NAMES = list(CHANNEL_TO_REGION.keys())
CHANNEL_INDEX = {ch: i for i, ch in enumerate(CHANNEL_NAMES)}

print("Configuration loaded:")
print(f"  - Artifact threshold: {ARTIFACT_THRESHOLD_UV} µV")
print(f"  - Bad channel detection: {BAD_CHANNEL_STD_THRESHOLD} std from mean variance")
print(f"  - Window size: {WINDOW_SIZE_SEC}s ({WINDOW_SIZE_SAMPLES} samples)")
print(f"  - Hop size: {HOP_SIZE_SEC}s ({HOP_SIZE_SAMPLES} samples)")
print(f"  - Channels: {len(CHANNEL_NAMES)} in region map, {NUM_CHANNELS} expected per window")



In [None]:
## 2.5 Performance Optimization

# Enable parallel processing
import os

# NOTE: NumPy/SciPy/MNE were already imported in the configuration cell, so
# these thread settings may not reach thread pools those libraries have
# already created; processes started afterwards do inherit them.
os.environ['MNE_LOGGING_LEVEL'] = 'WARNING'
os.environ['OMP_NUM_THREADS'] = '4'  # OpenMP threads
os.environ['MKL_NUM_THREADS'] = '4'  # Intel MKL threads

# MNE reads MNE_LOGGING_LEVEL when its logging is initialised at import time,
# which has already happened, so set the level on the live MNE logger as well.
mne.set_log_level("WARNING")

# -1 = let MNE/joblib use every available CPU core.
N_JOBS = -1

print(f"Parallel processing enabled: using {os.cpu_count()} CPU cores")
print(f"MNE will use n_jobs={N_JOBS} for filtering and resampling")

## 3. Subject Split Loading

In [None]:
def load_subject_splits(csv_path):
    """Read the subject -> split assignment table.

    The CSV needs ``subject_id`` and ``split`` columns, for example::

        subject_id,split
        sub-01,train
        sub-02,train
        ...
        sub-26,test

    Returns a dict mapping each subject ID to its split ("train" / "val" /
    "test"). If a subject is listed more than once, the last row wins.
    """
    table = pd.read_csv(csv_path)
    return dict(zip(table["subject_id"], table["split"]))

# subject_id -> "train" / "val" / "test", used to route every recording.
SUBJECT_SPLITS = load_subject_splits(SUBJECT_SPLITS_CSV)
print(f"Subject splits: {SUBJECT_SPLITS}")



## 4. Preprocessing Functions

In [None]:
def detect_and_interpolate_bad_channels(raw):
    """Interpolate EEG channels whose variance is abnormally high.

    A good EEG channel is flagged when its variance exceeds
    ``mean + BAD_CHANNEL_STD_THRESHOLD * std`` of the per-channel variances of
    the good EEG channels. Newly flagged channels are merged with any already
    listed in ``raw.info['bads']``, and all of them are interpolated via
    ``raw.interpolate_bads(reset_bads=True)``.

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded recording; modified in place.

    Returns
    -------
    mne.io.Raw
        The same object, after interpolation.
    """
    # Use explicit indices so variance row i maps to raw.ch_names[eeg_picks[i]].
    # Indexing raw.ch_names with the row number of get_data(picks='eeg') is
    # wrong whenever bads or non-EEG channels are present, because that pick
    # skips them.
    eeg_picks = mne.pick_types(raw.info, meg=False, eeg=True, exclude="bads")
    if len(eeg_picks) == 0:
        return raw

    variances = np.var(raw.get_data(picks=eeg_picks), axis=1)
    threshold = np.mean(variances) + BAD_CHANNEL_STD_THRESHOLD * np.std(variances)
    detected = [raw.ch_names[eeg_picks[i]] for i in np.flatnonzero(variances > threshold)]
    if detected:
        print(f"  Detected {len(detected)} bad channels: {detected}")

    # Keep pre-existing bads as well. If they were left in info['bads'], later
    # picks="eeg" calls would silently drop them instead of interpolating them.
    already_bad = list(raw.info["bads"])
    bads = already_bad + [ch for ch in detected if ch not in already_bad]
    if bads:
        raw.info["bads"] = bads
        raw.interpolate_bads(reset_bads=True, verbose=False)

    return raw

def preprocess_raw(raw):
    """Filter, clean, re-reference and resample a raw EEG recording in place.

    Steps, in order:

    1. Zero-phase FIR band-pass, 1–40 Hz.
    2. Notch filter at 60/120/180 Hz (only frequencies below Nyquist).
    3. Variance-based bad-channel detection and interpolation.
    4. Average reference.
    5. Resample to PROCESSED_SAMPLING_RATE (e.g. 512 → 256 Hz).

    Filtering and resampling run with ``n_jobs=N_JOBS``. Returns the processed
    Raw object.
    """
    raw.filter(l_freq=1, h_freq=40, method="fir", phase="zero",
               n_jobs=N_JOBS, verbose=False)

    # A notch at or above Nyquist cannot be applied, so keep only the valid
    # harmonics instead of failing on low-rate recordings.
    # NOTE(review): after the 40 Hz low-pass these notches are largely
    # redundant, and the mains frequency depends on the recording site
    # (50 vs 60 Hz) — confirm the intended value.
    nyquist = raw.info["sfreq"] / 2.0
    notch_freqs = [f for f in (60, 120, 180) if f < nyquist]
    if notch_freqs:
        raw.notch_filter(freqs=notch_freqs, n_jobs=N_JOBS, verbose=False)

    raw = detect_and_interpolate_bad_channels(raw)

    raw.set_eeg_reference("average", projection=False, verbose=False)
    if raw.info["sfreq"] != PROCESSED_SAMPLING_RATE:
        raw.resample(PROCESSED_SAMPLING_RATE, n_jobs=N_JOBS, verbose=False)
    return raw

def compute_bandpower(window_data, sfreq, bands=None, psd_scale=1e12):
    """Compute log10 mean band power for every channel of one EEG window.

    Parameters
    ----------
    window_data : array_like, shape (n_channels, n_samples)
        EEG window in volts.
    sfreq : float
        Sampling rate in Hz.
    bands : dict[str, tuple[float, float]] | None
        Ordered ``{name: (low_hz, high_hz)}`` with inclusive edges. ``None``
        uses the module-level FREQUENCY_BANDS.
    psd_scale : float
        Factor applied to the Welch PSD before taking the log. The default
        1e12 converts V²/Hz to µV²/Hz. In V²/Hz, EEG PSDs are roughly
        1e-13–1e-10, the same order as the 1e-10 epsilon below. That epsilon
        would swamp the signal and push most features toward -10, the
        "invalid" sentinel checked by the windowing code. A uniform scale only
        shifts each log value by a constant.

    Returns
    -------
    np.ndarray, shape (n_channels, n_bands), float32
        ``log10(mean PSD in band + 1e-10)``. The value -10.0 marks a channel
        that is all zeros or all non-finite, a band with no frequency bins, or
        a band whose PSD is identically zero.
    """
    if bands is None:
        bands = FREQUENCY_BANDS

    data = np.asarray(window_data, dtype=np.float64)
    n_channels, n_samples = data.shape
    bp_features = np.full((n_channels, len(bands)), -10.0, dtype=np.float32)

    # A channel is usable if it has at least one finite and one non-zero
    # sample (NaN counts as non-zero, as in the original per-channel check).
    usable = np.isfinite(data).any(axis=1) & np.any(data, axis=1)
    if n_samples == 0 or not usable.any():
        return bp_features

    # One Welch call for all usable channels; a single segment spans the window.
    freqs, psd = welch(data[usable], fs=sfreq, nperseg=n_samples, axis=-1)
    # Scale before nan_to_num so inf stays representable; clean NaN/negatives.
    psd = np.nan_to_num(psd * psd_scale, nan=0.0)
    psd[psd < 0] = 0.0

    for band_idx, (low, high) in enumerate(bands.values()):
        in_band = (freqs >= low) & (freqs <= high)
        if not in_band.any():
            continue  # no frequency bins in this band -> keep the sentinel
        band_vals = psd[:, in_band]
        has_power = band_vals.any(axis=1)
        log_power = np.log10(band_vals.mean(axis=1) + 1e-10)
        bp_features[usable, band_idx] = np.where(has_power, log_power, -10.0)

    return bp_features

def compute_dominant_labels(bandpower, channel_names):
    """Derive the next-second target labels from one window's band power.

    Parameters
    ----------
    bandpower : np.ndarray, shape (n_channels, 5)
        log10 band power as returned by ``compute_bandpower``. Columns are
        delta, theta, alpha, beta, gamma.
    channel_names : list[str]
        Channel names in the same row order as ``bandpower``.

    Returns
    -------
    dict
        ``y_channel``: row index with the largest summed log power.
        ``y_region``: region id of that channel (0 if unknown).
        ``y_band``: band index with the largest summed log power.
        ``y_state``: 0 focused, 1 drift, 2 drowsy, 3 neutral. This is a
        heuristic on linear band-power ratios summed over channels.
    """
    bandpower = np.asarray(bandpower)

    # Dominant channel / band: argmax of summed log10 power (geometric-mean
    # comparison). Definition unchanged.
    dom_ch_idx = int(np.argmax(np.sum(bandpower, axis=1)))
    if dom_ch_idx < len(channel_names):
        region_id = CHANNEL_TO_REGION_ID.get(channel_names[dom_ch_idx], 0)
    else:
        region_id = 0  # row has no known channel name
    dom_band_idx = int(np.argmax(np.sum(bandpower, axis=0)))

    # The state heuristic is defined on power ratios, so undo the log before
    # summing over channels. Ratios of summed log10 values are meaningless.
    # Those sums are also negative whenever PSD < 1 unit²/Hz (e.g. V²/Hz), in
    # which case the `alpha > 1e-10` guard never held and every window was
    # labelled neutral.
    linear_band_power = np.power(10.0, bandpower.astype(np.float64)).sum(axis=0)
    delta, theta, alpha, beta = (float(v) for v in linear_band_power[:4])

    state_id = 3  # neutral
    if alpha > 1e-10:
        ratio_theta_alpha = theta / alpha
        ratio_beta_alpha = beta / alpha
        if 1.0 <= ratio_theta_alpha <= 1.5 and ratio_beta_alpha > 1.2:
            state_id = 0  # Focused
        elif ratio_theta_alpha > 1.5:
            state_id = 1  # Drift
        elif (delta + theta) > 2 * alpha:
            state_id = 2  # Drowsy

    return {
        "y_channel": dom_ch_idx,
        "y_region": region_id,
        "y_band": dom_band_idx,
        "y_state": state_id,
    }

## 5. Window Creation with Artifact Rejection & Next-Second Labels

In [None]:
 def create_windows_from_raw(raw, task_id, subject_id, split_name):
    """
    raw: preprocessed MNE Raw, sfreq=256, EEG channels
    task_id: int
    split_name: "train"/"val"/"test"

    Returns list of sample dicts:
    {
        "X":   np.array (n_channels, 256) raw window
        "bp":  np.array (n_channels, 5) bandpower for same window
        "task_idx": int
        y_channel, y_region, y_band, y_state from NEXT window
    }
    """
    data = raw.get_data(picks="eeg")  # (n_channels, n_samples)
    n_channels, n_samples = data.shape
    channel_names = [ch for ch in raw.ch_names if ch in CHANNEL_TO_REGION]

    samples = []
    artifacts_rejected = 0
    bad_bp_rejected = 0
    start = 0

    while start + WINDOW_SIZE_SAMPLES * 2 <= n_samples:
        # current window [t, t+1]
        w_cur = data[:, start: start + WINDOW_SIZE_SAMPLES]
        # next window [t+1, t+2] for labels
        w_next = data[:, start + WINDOW_SIZE_SAMPLES: start + 2 * WINDOW_SIZE_SAMPLES]

        # 1) Artifact rejection (amplitude)
        max_amplitude_uv = np.abs(w_cur).max() * 1e6
        if max_amplitude_uv > ARTIFACT_THRESHOLD_UV:
            artifacts_rejected += 1
            start += HOP_SIZE_SAMPLES
            continue

        # 2) Bandpower computation with safety
        bp_cur = compute_bandpower(w_cur, PROCESSED_SAMPLING_RATE).astype(np.float32)
        bp_next = compute_bandpower(w_next, PROCESSED_SAMPLING_RATE)

        # If bandpower is completely default (all -10) for either window, skip
        if not np.any(bp_cur > -9.9) or not np.any(bp_next > -9.9):
            bad_bp_rejected += 1
            start += HOP_SIZE_SAMPLES
            continue

        # 3) Labels from next window
        labels = compute_dominant_labels(bp_next, channel_names)

        sample = {
            "X": w_cur.astype(np.float32),  # (n_channels, 256)
            "bp": bp_cur,
            "task_idx": int(task_id),
            "y_channel": labels["y_channel"],
            "y_region": labels["y_region"],
            "y_band": labels["y_band"],
            "y_state": labels["y_state"],
            "subject_id": subject_id,
            "split": split_name,
            "n_channels": n_channels,
            "channel_names": channel_names,
        }
        samples.append(sample)

        start += HOP_SIZE_SAMPLES

    if artifacts_rejected > 0:
        print(f" Rejected {artifacts_rejected} artifact-contaminated windows")
    if bad_bp_rejected > 0:
        print(f" Rejected {bad_bp_rejected} windows due to invalid bandpower")

    return samples


## 6. Scan Raw Files (BIDS) and Process All Subjects

In [None]:
def parse_bids_info_from_path(path):
    """Extract the subject ID and task family from a COG-BCI recording path.

    Expected layout, e.g.::

        /kaggle/input/cog-eeg-dataset-3/sub-21/ses-S1/eeg/zeroBACK.set
        /kaggle/input/cog-eeg-dataset-3/sub-21/ses-S1/eeg/MATBeasy.set
        /kaggle/input/cog-eeg-dataset-3/sub-21/ses-S1/eeg/Flanker.set
        /kaggle/input/cog-eeg-dataset-3/sub-21/ses-S1/eeg/PVT.set

    Parameters
    ----------
    path : str or os.PathLike
        Path to a ``.set`` file.

    Returns
    -------
    subject_id : str | None
        First path component starting with "sub-" (e.g. "sub-21"), or None if
        the path has no such component.
    task_name : str | None
        One of "nback", "matb", "flanker", "pvt" (keys of TASK_NAME_TO_ID), or
        None for an unrecognised file name. Callers skip such files.
    """
    path = os.fspath(path)
    parts = path.split(os.sep)

    # default=None: a path without a "sub-" folder used to raise StopIteration,
    # which escaped the caller's per-file error handling and aborted the run.
    subject_id = next((p for p in parts if p.startswith("sub-")), None)

    # File stem, lower-cased, e.g. "zeroback", "matbeasy", "flanker".
    stem = os.path.splitext(os.path.basename(path))[0].lower()

    # Keyword -> task family, checked in order; first match wins.
    keyword_to_task = (
        ("back", "nback"),      # zeroBACK / oneBACK / twoBACK
        ("matb", "matb"),       # MATBeasy / MATBmed / MATBdiff
        ("flanker", "flanker"),
        ("pvt", "pvt"),
    )
    task_name = next((task for key, task in keyword_to_task if key in stem), None)

    return subject_id, task_name

def find_all_eeg_files(raw_dirs):
    """Recursively collect EEGLAB ``.set`` files from several directories.

    Parameters
    ----------
    raw_dirs : iterable of str
        Root directories to scan, including all subdirectories.

    Returns
    -------
    list[str]
        Matching paths, grouped by directory in the order given. Each
        directory's matches are sorted, because ``glob`` returns files in a
        filesystem-dependent order. Sorting keeps the processing order, and
        therefore the sample order in the output pickles, reproducible.
    """
    all_files = []
    for raw_dir in raw_dirs:
        if not os.path.isdir(raw_dir):
            print(f"  WARNING: directory not found: {raw_dir}")
        pattern = os.path.join(raw_dir, "**", "*.set")
        files = sorted(glob.glob(pattern, recursive=True))
        all_files.extend(files)
        # normpath strips a trailing separator, which would make basename "".
        print(f"  Found {len(files)} files in {os.path.basename(os.path.normpath(raw_dir))}")
    return all_files

# Collect every .set recording from all configured dataset roots.
all_files = find_all_eeg_files(RAW_DATA_DIRS)
n_found = len(all_files)
print(f"\nTotal EEG files across all datasets: {n_found}")



## 7. First Pass: Build All Samples in Memory to Compute Train Mean/Std

In [None]:

#  First Pass with Progress Tracking (RESUMABLE)


from tqdm import tqdm
import time

all_samples = []
skipped_files = []
processing_stats = {
    'total_files': len(all_files),
    'processed': 0,
    'resumed': 0,
    'skipped': 0,
    'total_windows': 0,
    'start_time': time.time()
}

# Each finished file's windows are pickled next to its done flag, so an
# interrupted run can resume without losing them (all_samples is in-memory
# only). Set this to False to save disk space; every file is then reprocessed
# on a rerun.
CACHE_FILE_SAMPLES = True


def _checkpoint_paths(path, subject_id):
    """Return ``(flag_path, cache_path)`` identifying one recording.

    The key combines the file stem, subject and session. Sessions of the same
    subject share file names (e.g. ses-S1/eeg/zeroBACK.set and
    ses-S2/eeg/zeroBACK.set), so a filename + subject key made every session
    after the first look "already done" and silently skipped it.
    """
    session_id = next((p for p in path.split(os.sep) if p.startswith("ses-")), "nosession")
    stem = os.path.splitext(os.path.basename(path))[0]
    key = f"{stem}_{subject_id}_{session_id}"
    return (os.path.join(DONE_FLAGS_DIR, f"{key}.done"),
            os.path.join(DONE_FLAGS_DIR, f"{key}.samples.pkl"))


def _load_cached_samples(flag_path, cache_path):
    """Return the cached windows of a finished file, or None if unavailable."""
    if not (CACHE_FILE_SAMPLES and os.path.exists(flag_path) and os.path.exists(cache_path)):
        return None
    try:
        with open(cache_path, "rb") as f:
            return pickle.load(f)  # written by this notebook, not external input
    except Exception as e:  # best effort: any unreadable cache -> reprocess
        print(f" ! Could not reload {os.path.basename(cache_path)} ({e}); reprocessing")
        return None


def _save_checkpoint(samples, flag_path, cache_path):
    """Persist a finished file's windows (when caching), then its done flag."""
    if CACHE_FILE_SAMPLES:
        tmp_path = cache_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(samples, f)
        os.replace(tmp_path, cache_path)  # never leave a half-written cache
    # Written last, so an existing flag implies a complete cache.
    with open(flag_path, "w") as f:
        f.write(str(len(samples)))


print(f"Processing {len(all_files)} files...\n")

for idx, path in enumerate(tqdm(all_files, desc="Processing files")):
    subject_id, task_name = parse_bids_info_from_path(path)

    if subject_id not in SUBJECT_SPLITS:
        skipped_files.append({'path': path, 'subject': subject_id, 'reason': 'not_in_splits'})
        processing_stats['skipped'] += 1
        continue

    split_name = SUBJECT_SPLITS[subject_id]

    if task_name is None or task_name not in TASK_NAME_TO_ID:
        skipped_files.append({'path': path, 'subject': subject_id, 'reason': 'unknown_task'})
        processing_stats['skipped'] += 1
        continue

    task_id = TASK_NAME_TO_ID[task_name]

    # ---------- per-file checkpoint ----------
    # A finished file must still contribute its windows on resume; skipping it
    # outright would drop its data from the output pickles.
    flag_path, cache_path = _checkpoint_paths(path, subject_id)
    cached = _load_cached_samples(flag_path, cache_path)
    if cached is not None:
        for s in cached:
            s["split"] = split_name  # follow the current split table
        all_samples.extend(cached)
        processing_stats['resumed'] += 1
        processing_stats['total_windows'] += len(cached)
        continue
    # ----------------------------------------

    # Only print every 10th file to reduce clutter
    if idx % 10 == 0:
        print(f"\n[{idx+1}/{len(all_files)}] {os.path.basename(path)} | {subject_id} | {task_name} | {split_name}")

    try:
        raw = mne.io.read_raw_eeglab(path, preload=True, verbose=False)

        present = set(raw.ch_names)
        available_channels = [ch for ch in CHANNEL_NAMES if ch in present]

        if len(available_channels) < 50:
            skipped_files.append({
                'path': path,
                'subject': subject_id,
                'reason': f'insufficient_channels_{len(available_channels)}'
            })
            processing_stats['skipped'] += 1
            continue

        raw.pick(available_channels)
        raw = preprocess_raw(raw)
        subject_samples = create_windows_from_raw(raw, task_id, subject_id, split_name)
        del raw  # release the preloaded recording before the next file

        all_samples.extend(subject_samples)
        processing_stats['processed'] += 1
        processing_stats['total_windows'] += len(subject_samples)

        # A failed checkpoint write must not discard windows already in memory.
        try:
            _save_checkpoint(subject_samples, flag_path, cache_path)
        except OSError as e:
            print(f" ! Checkpoint not saved for {os.path.basename(path)}: {e}")

        # Print summary every 10 files
        if idx % 10 == 0:
            elapsed = time.time() - processing_stats['start_time']
            rate = processing_stats['processed'] / elapsed if elapsed > 0 else 0
            print(f" ✓ {len(subject_samples)} windows | Total: {processing_stats['total_windows']} | Rate: {rate:.2f} files/sec")

    except Exception as e:  # per-file boundary: log, record and move on
        print(f" ✗ ERROR: {str(e)[:80]}")
        skipped_files.append({'path': path, 'subject': subject_id, 'reason': str(e)[:100]})
        processing_stats['skipped'] += 1

# Final summary
elapsed_total = time.time() - processing_stats['start_time']
print(f"\n{'='*70}")
print("PROCESSING COMPLETE")
print(f" Total time: {elapsed_total/60:.1f} minutes ({elapsed_total:.0f} seconds)")
print(f" Files processed: {processing_stats['processed']}/{processing_stats['total_files']}")
print(f" Files resumed from cache: {processing_stats['resumed']}")
print(f" Files skipped: {processing_stats['skipped']}")
print(f" Total windows generated: {processing_stats['total_windows']}")
if processing_stats['processed'] > 0:
    print(f" Average rate: {processing_stats['processed']/elapsed_total:.2f} files/sec")
    print(f" Average time per file: {elapsed_total/processing_stats['processed']:.1f} seconds")
print(f"{'='*70}\n")

if len(skipped_files) > 0:
    print("Skipped files breakdown:")
    skip_reasons = {}
    for sf in skipped_files:
        reason = sf['reason'].split('_')[0]
        skip_reasons[reason] = skip_reasons.get(reason, 0) + 1
    for reason, count in skip_reasons.items():
        print(f" - {reason}: {count} files")


## Debugging Helpers (disabled: the cells below are wrapped in string literals)

In [None]:
# Disabled debug helper: the code is wrapped in a string literal, so this cell
# is a no-op. Remove the surrounding triple quotes to print how many train
# windows exist in all_samples and how many have each (channels, samples) shape.
"""# DEBUG: inspect shapes of all TRAIN windows
train_windows = [s["X"] for s in all_samples if s["split"] == "train"]
print("Total train windows:", len(train_windows))

shape_counts = {}
for w in train_windows:
    shape_counts[w.shape] = shape_counts.get(w.shape, 0) + 1

print("Unique shapes and counts:")
for shp, cnt in shape_counts.items():
    print(f"  {shp}: {cnt}")
"""

In [None]:
# Disabled debug helper: wrapped in a string literal, so this cell is a no-op.
# Remove the triple quotes to print the window count per split in all_samples.
"""print("Total samples in all_samples:", len(all_samples))
print("Train samples:", sum(1 for s in all_samples if s["split"] == "train"))
print("Val samples:",   sum(1 for s in all_samples if s["split"] == "val"))
print("Test samples:",  sum(1 for s in all_samples if s["split"] == "test"))
"""

## 8. Compute Train-Set Mean / Std Per Channel 

In [None]:
# ============================================================================
# SECTION 8: Compute Train-Set Mean / Std Per Channel
# ============================================================================

import json
from collections import Counter

print("\n" + "="*70)
print("SECTION 8: COMPUTING TRAIN STATISTICS")
print("="*70)

# Window counts per split in a single pass over all_samples.
split_counts = Counter(s["split"] for s in all_samples)
train_count = split_counts["train"]
val_count = split_counts["val"]
test_count = split_counts["test"]

print("\nTotal windows collected:")
print(f"  Train: {train_count:,}")
print(f"  Val:   {val_count:,}")
print(f"  Test:  {test_count:,}")
print(f"  TOTAL: {len(all_samples):,}\n")

if train_count == 0:
    raise RuntimeError(
        "No training samples found! Check that subjects 01-20 are being processed."
    )

print("Computing statistics from training data...")

# Only train windows with the expected (channels, samples) shape contribute.
EXPECTED_SHAPE = (NUM_CHANNELS, WINDOW_SIZE_SAMPLES)
train_samples = [s for s in all_samples if s["split"] == "train"]
stat_samples = [s for s in train_samples if s["X"].shape == EXPECTED_SHAPE]
dropped = len(train_samples) - len(stat_samples)

print(f"  Using {len(stat_samples):,} windows for statistics")
if dropped > 0:
    print(f"  Dropped {dropped} windows with incorrect shape")

if len(stat_samples) == 0:
    raise RuntimeError(
        f"No train windows with expected shape {EXPECTED_SHAPE}. "
        f"Check NUM_CHANNELS={NUM_CHANNELS} and WINDOW_SIZE_SAMPLES={WINDOW_SIZE_SAMPLES}"
    )

# The statistics are per row, so record which channel each row holds instead
# of CHANNEL_NAMES, which need not match NUM_CHANNELS. Recordings can lack
# different channels, so report when the row layouts disagree.
layout_counts = Counter(tuple(s.get("channel_names", ())) for s in stat_samples)
stats_channel_names = list(layout_counts.most_common(1)[0][0])
if len(layout_counts) > 1:
    print(f"  WARNING: {len(layout_counts)} different channel layouts among train "
          f"windows; saved channel_names follow the most common one")

# Streaming two-pass mean / population std in float64. Stacking every window,
# plus the transpose/reshape copy, held up to three copies of the train set
# in memory at once.
n_values = len(stat_samples) * WINDOW_SIZE_SAMPLES
channel_sum = np.zeros(NUM_CHANNELS, dtype=np.float64)
for s in stat_samples:
    channel_sum += s["X"].sum(axis=1, dtype=np.float64)
mean64 = channel_sum / n_values

channel_sq_dev = np.zeros(NUM_CHANNELS, dtype=np.float64)
for s in stat_samples:
    channel_sq_dev += np.square(s["X"] - mean64[:, None]).sum(axis=1)

train_mean = mean64.astype(np.float32)
train_std = np.sqrt(channel_sq_dev / n_values).astype(np.float32)
# X is in volts, so this floor is 1 µV. A channel flatter than that over the
# whole train set gets std := 1.0 instead of being divided by ~0.
train_std[train_std < 1e-6] = 1.0

print(f"  ✓ Computed mean shape: {train_mean.shape}")
print(f"  ✓ Computed std shape: {train_std.shape}")

# Save stats
stats = {
    "channel_names": stats_channel_names,
    "mean": train_mean.tolist(),
    "std": train_std.tolist(),
    "n_windows": len(stat_samples),
}
with open(TRAIN_STATS_PATH, "w") as f:
    json.dump(stats, f)

print(f"  ✓ Saved stats to: {TRAIN_STATS_PATH}")
print("="*70 + "\n")

## 9. Apply Normalization and Save Split-Wise PKL Files

In [None]:
def normalize_window(window, mean, std):
    """Z-score a window channel by channel.

    window : (n_channels, n_samples) array.
    mean, std : 1-D per-channel statistics. Only the first n_channels entries
        are used, so longer vectors are truncated to the window.

    Returns ``(window - mean) / std``, broadcast along the time axis.
    """
    rows = window.shape[0]
    mu = mean[:rows].reshape(-1, 1)
    sigma = std[:rows].reshape(-1, 1)
    return (window - mu) / sigma


# Z-score every window with the train statistics and bucket it by split.
# NOTE: this overwrites sample["X"] inside all_samples. Re-running this cell
# without re-running the first pass normalises the windows a second time.
split_samples = {name: [] for name in ("train", "val", "test")}

EXPECTED_CHANNELS = train_mean.shape[0]  # should be 58
skipped_bad_shape = 0

for sample in all_samples:
    if sample["X"].shape[0] == EXPECTED_CHANNELS:
        sample["X"] = normalize_window(sample["X"], train_mean, train_std).astype(np.float32)
        split_samples[sample["split"]].append(sample)
    else:
        # Channel count does not match the statistics vector.
        skipped_bad_shape += 1

kept = sum(map(len, split_samples.values()))
print(f"Kept {kept} samples after normalization; "
      f"skipped {skipped_bad_shape} with mismatched channels.")


## 10. Final Step: Save Processed Samples to PKL Files

In [None]:
# ============================================================================
# SECTION 10: SAVE PROCESSED SAMPLES TO .PKL FILES
# ============================================================================

print("\n" + "="*70)
print("SAVING PROCESSED SAMPLES")
print("="*70)

for split_name in ["train", "val", "test"]:
    samples = split_samples.get(split_name, [])

    output_dir = os.path.join(PROCESSED_DIR, split_name)
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, f"{split_name}_data.pkl")

    if len(samples) == 0:
        # A pickle left by an earlier run would otherwise survive and pass the
        # existence-only verification as if it were this run's output.
        if os.path.exists(output_path):
            os.remove(output_path)
            print(f"\n⚠️  {split_name.upper()}: No samples found - removed stale {output_path}")
        else:
            print(f"\n⚠️  {split_name.upper()}: No samples found - SKIPPING")
        continue

    print(f"\n📦 {split_name.upper()}:")
    print(f"   Saving {len(samples):,} samples...")

    # Dump to a temp file and rename, so an interrupted write never leaves a
    # truncated pickle at the final path.
    tmp_path = output_path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(samples, f)
    os.replace(tmp_path, output_path)

    file_size_mb = os.path.getsize(output_path) / (1024 * 1024)

    print("   ✓ Saved successfully")
    print(f"   ✓ File: {output_path}")
    print(f"   ✓ Size: {file_size_mb:.2f} MB")

    # Show sample info
    sample = samples[0]
    print("   ✓ Sample format:")
    print(f"      - X shape: {sample['X'].shape} (normalized EEG)")
    print(f"      - bp shape: {sample['bp'].shape} (bandpower features)")
    print(f"      - Subjects: {len({s['subject_id'] for s in samples})} unique")

print("\n" + "="*70)
print("✅ PROCESSING COMPLETE!")
print("="*70)

# Final summary
total_samples = sum(len(v) for v in split_samples.values())
print("\n📊 FINAL SUMMARY:")
print(f"   Total samples: {total_samples:,}")
print(f"   Train samples: {len(split_samples.get('train', [])):,}")
print(f"   Val samples: {len(split_samples.get('val', [])):,}")
print(f"   Test samples: {len(split_samples.get('test', [])):,}")
print(f"\n📁 Output files saved to: {PROCESSED_DIR}")
print("   - train/train_data.pkl")
print("   - val/val_data.pkl")
print("   - test/test_data.pkl")
print("   - train_stats.json")
print("\n✓ Ready to download from Kaggle Output!")
print("="*70 + "\n")

## 11. Verification

In [None]:
# ============================================================================
# SECTION 11: FINAL VERIFICATION
# ============================================================================

print("\n" + "="*70)
print("VERIFICATION CHECK")
print("="*70)

verification_passed = True

# Check train_stats.json
print("\n1. Train Statistics:")
if os.path.exists(TRAIN_STATS_PATH):
    size_kb = os.path.getsize(TRAIN_STATS_PATH) / 1024
    print(f"   ✓ train_stats.json exists ({size_kb:.2f} KB)")
else:
    print("   ✗ train_stats.json MISSING!")
    verification_passed = False

# Check .pkl files: they must exist and be non-empty.
print("\n2. Processed Data Files:")
for split_name in ["train", "val", "test"]:
    pkl_path = os.path.join(PROCESSED_DIR, split_name, f"{split_name}_data.pkl")
    if not os.path.exists(pkl_path):
        print(f"   ✗ {split_name}_data.pkl MISSING!")
        verification_passed = False
        continue
    size_bytes = os.path.getsize(pkl_path)
    if size_bytes == 0:
        print(f"   ✗ {split_name}_data.pkl is EMPTY (0 bytes)!")
        verification_passed = False
    else:
        print(f"   ✓ {split_name}_data.pkl exists ({size_bytes / (1024*1024):.2f} MB)")

# Check done flags
print("\n3. Processing Flags:")
done_count = len(glob.glob(os.path.join(DONE_FLAGS_DIR, "*.done")))
print(f"   ✓ {done_count} files marked as processed")

print("\n" + "="*70)
if verification_passed:
    print("✅✅✅ ALL CHECKS PASSED - READY TO DOWNLOAD! ✅✅✅")
else:
    print("❌❌❌ VERIFICATION FAILED - CHECK ERRORS ABOVE ❌❌❌")
print("="*70 + "\n")
```
