In [1]:

import os
import math
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import GroupKFold
try:
    from sklearn.model_selection import StratifiedGroupKFold
except ImportError:
    StratifiedGroupKFold = None
from sklearn.metrics import classification_report, accuracy_score

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [None]:

from imblearn.over_sampling import BorderlineSMOTE
from imblearn.under_sampling import TomekLinks, EditedNearestNeighbours
from sklearn.preprocessing import QuantileTransformer, RobustScaler
from sklearn.decomposition import PCA
from sklearn.feature_selection import mutual_info_classif, RFE

print("Model 6 imbalance handling libraries loaded")



Model 6 imbalance handling libraries loaded


In [None]:
# Configuration
DEFAULT_DATASET_ROOT = Path("./Datasets")
DATASET_ROOT = Path(os.getenv("DATASET_ROOT", DEFAULT_DATASET_ROOT))
STATES = ["STRESS", "AEROBIC", "ANAEROBIC"]
TARGET_FS = 4.0
WINDOW_SECONDS = 60
WINDOW_STEP_SECONDS = 30
MIN_LABEL_COVERAGE = 0.6
SEED = 42
MAX_SUBJECTS = None
APPLY_CHANNEL_NORMALIZATION = True
APPLY_DIFF_CHANNELS = True

APPLY_TEMPORAL_AUG = False

TEMPORAL_AUG_COUNTS = {}
LABEL_SMOOTHING = 0.05
EMA_DECAY = 0.995
TWO_STAGE_THRESHOLD = 0.4


GROUP_SPLIT = True
NUM_FOLDS = 5
FOLD_INDEX = 0
USE_STRATIFIED_GROUP_SPLIT = True


In [None]:

STRESS_STAGE_ORDER_S = ["Stroop", "TMCT", "Real Opinion", "Opposite Opinion", "Subtract"]
STRESS_STAGE_ORDER_F = ["TMCT", "Real Opinion", "Opposite Opinion", "Subtract"]
STRESS_TAG_PAIRS_S = [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)]
STRESS_TAG_PAIRS_F = [(2, 3), (4, 5), (6, 7), (8, 9)]
STRESS_PHASES = {"Stroop", "TMCT", "Real Opinion", "Opposite Opinion", "Subtract"}
STRESS_LEVEL_BOUNDS = {"low": 3.0, "moderate": 6.0}
STRESS_LEVEL_PHASE_BOUNDS = {
    "Stroop": {"low": 2.5, "moderate": 5.0},
    "Opposite Opinion": {"low": 2.5, "moderate": 5.5},
    "Real Opinion": {"low": 2.8, "moderate": 5.5},
    "TMCT": {"low": 2.8, "moderate": 5.8},
    "Subtract": {"low": 2.8, "moderate": 5.8},
}
STRESS_LEVEL_FILES = ["Stress_Level_v1.csv", "Stress_Level_v2.csv"]
MIN_LABEL_COVERAGE = 0.6


def load_stress_levels():
    levels = {}
    for fname in STRESS_LEVEL_FILES:
        path = Path(fname)
        if not path.exists():
            continue
        df = pd.read_csv(path, index_col=0)
        df.columns = [str(c).strip() for c in df.columns]
        for subject, row in df.iterrows():
            subj = str(subject).strip()
            levels[subj] = {
                col: (float(row[col]) if not pd.isna(row[col]) else np.nan)
                for col in df.columns
            }
    return levels


STRESS_LEVELS = load_stress_levels()


def base_subject_id(subject: str) -> str:
    return subject.split("_")[0]


def read_signal(path: Path):
    with open(path, "r") as f:
        start_line = f.readline().strip()
        if not start_line:
            raise ValueError(f"Missing start timestamp in {path}")
        start_ts = pd.to_datetime(start_line.split(",")[0])
        fs_line = f.readline().strip()
        if not fs_line:
            raise ValueError(f"Missing sample rate in {path}")
        fs = float(fs_line.split(",")[0])
        data = np.genfromtxt(f, delimiter=",")
    data = np.asarray(data, dtype=float)
    data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
    if data.ndim == 0:
        data = data.reshape(1, 1)
    return fs, data.squeeze(), start_ts


def read_tags(path: Path, start_ts: pd.Timestamp):
    if not path.exists():
        return []
    df = pd.read_csv(path, header=None)
    tags = []
    for ts_str in df[0].astype(str):
        ts = pd.to_datetime(ts_str)
        tags.append((ts - start_ts).total_seconds())
    return [(t, t) for t in tags]


def stress_intervals_from_tags(tags, subject):
    if not tags:
        return []
    times = [t for t, _ in tags]
    if subject.startswith("S"):
        idx_pairs = STRESS_TAG_PAIRS_S
        stage_order = STRESS_STAGE_ORDER_S
    else:
        idx_pairs = STRESS_TAG_PAIRS_F
        stage_order = STRESS_STAGE_ORDER_F
    base_id = base_subject_id(subject)
    spans = []
    for stage, (i, j) in zip(stage_order, idx_pairs):
        if i < len(times) and j < len(times) and times[j] > times[i]:
            level = STRESS_LEVELS.get(base_id, {}).get(stage)
            spans.append({"start": times[i], "end": times[j], "stage": stage, "stress_level": level})
    return spans


def active_intervals_from_tags(tags):
    if len(tags) < 2:
        return []
    spans = []
    for (a, _), (b, _) in zip(tags[:-1], tags[1:]):
        if b > a:
            spans.append({"start": a, "end": b, "stage": "active", "stress_level": 0.0})
    return spans


def stress_bucket(level: float = None, phase: str = None) -> str:
    if phase in {"aerobic", "anaerobic", "rest", "active"}:
        return "no_stress"
    if level is None or pd.isna(level) or level <= 0:
        return "no_stress"
    bounds = STRESS_LEVEL_PHASE_BOUNDS.get(phase, STRESS_LEVEL_BOUNDS)
    if level <= bounds["low"]:
        return "low_stress"
    if level <= bounds["moderate"]:
        return "moderate_stress"
    return "high_stress"


def resample_to_rate(signal: np.ndarray, src_fs: float, tgt_fs: float) -> np.ndarray:
    if signal.ndim == 1:
        signal = signal[:, None]
    signal = np.nan_to_num(signal, nan=0.0, posinf=0.0, neginf=0.0)
    src_len = signal.shape[0]
    duration = src_len / src_fs
    tgt_len = int(duration * tgt_fs)
    if tgt_len <= 0:
        return np.zeros((0, signal.shape[1]), dtype=np.float32)
    src_t = np.linspace(0, duration, src_len, endpoint=False)
    tgt_t = np.linspace(0, duration, tgt_len, endpoint=False)
    resampled = np.vstack([
        np.interp(tgt_t, src_t, signal[:, i])
        for i in range(signal.shape[1])
    ]).T.astype(np.float32)
    resampled = np.nan_to_num(resampled, nan=0.0, posinf=0.0, neginf=0.0)
    if resampled.shape[1] == 1:
        return resampled[:, 0]
    return resampled


def window_intervals(duration: float, win_s: int, step_s: int):
    windows = []
    t = 0.0
    while t + win_s <= duration:
        windows.append((t, t + win_s))
        t += step_s
    return windows


def assign_label(win, intervals):
    start, end = win
    length = end - start
    best_label = None
    best_cov = 0.0
    best_span = None
    for label, spans in intervals.items():
        overlap = 0.0
        span_choice = None
        span_overlap = 0.0
        for span in spans:
            a = span["start"] if isinstance(span, dict) else span[0]
            b = span["end"] if isinstance(span, dict) else span[1]
            inter = max(0.0, min(end, b) - max(start, a))
            if inter > 0:
                overlap += inter
                if inter > span_overlap:
                    span_overlap = inter
                    span_choice = span
        coverage = overlap / length
        if coverage > best_cov:
            best_cov = coverage
            best_label = label
            best_span = span_choice
    if best_cov >= MIN_LABEL_COVERAGE and best_label is not None:
        return best_label, best_span
    return None, None


def make_label_intervals(state: str, subject: str, tags, duration: float):
    rest_span = [{"start": 0.0, "end": duration, "stage": "rest", "stress_level": 0.0}]
    if state == "STRESS":
        stress_spans = stress_intervals_from_tags(tags, subject)
        if not stress_spans:
            return {"rest": rest_span}
        return {"stress": stress_spans, "rest": rest_span}
    active = active_intervals_from_tags(tags)
    label = "aerobic" if state == "AEROBIC" else "anaerobic"
    if not active:
        return {label: rest_span, "rest": rest_span}
    return {label: active, "rest": rest_span}


def load_subject_state(state: str, subject: str):
    folder = DATASET_ROOT / state / subject
    if not folder.exists():
        raise FileNotFoundError(folder)
    fs_eda, eda_raw, start_ts = read_signal(folder / "EDA.csv")
    temp_path = folder / "TEMP.csv"
    if temp_path.exists():
        fs_temp, temp_raw, _ = read_signal(temp_path)
    else:
        fs_temp, temp_raw = fs_eda, np.zeros_like(eda_raw)
    fs_acc, acc_raw, _ = read_signal(folder / "ACC.csv")
    acc_raw = np.atleast_2d(acc_raw)
    acc_mag = np.linalg.norm(acc_raw, axis=1)
    bvp_path = folder / "BVP.csv"
    if bvp_path.exists():
        fs_bvp, bvp_raw, _ = read_signal(bvp_path)
    else:
        fs_bvp, bvp_raw = None, None
    tags = read_tags(folder / "tags.csv", start_ts)
    sensors = {
        "EDA": np.nan_to_num(np.asarray(eda_raw, dtype=float), nan=0.0, posinf=0.0, neginf=0.0),
        "TEMP": np.nan_to_num(np.asarray(temp_raw, dtype=float), nan=0.0, posinf=0.0, neginf=0.0),
        "ACC_MAG": np.nan_to_num(acc_mag, nan=0.0, posinf=0.0, neginf=0.0),
    }
    if bvp_raw is not None:
        sensors["BVP"] = np.nan_to_num(np.asarray(bvp_raw, dtype=float), nan=0.0, posinf=0.0, neginf=0.0)
    fs_map = {"EDA": fs_eda, "TEMP": fs_temp, "ACC_MAG": fs_acc}
    if fs_bvp:
        fs_map["BVP"] = fs_bvp
    duration = len(sensors["EDA"]) / fs_eda
    return {"sensors": sensors, "fs": fs_map, "tags": tags, "duration": duration}


In [8]:
EXPECTED_LEN = int(WINDOW_SECONDS * TARGET_FS)
BASE_CHANNELS = ["EDA", "TEMP", "ACC", "BVP"]
PHASE_ENCODING = {
    "baseline": 0,
    "rest": 0,
    "stress": 1,
    "stroop": 2,
    "tmct": 3,
    "real opinion": 4,
    "opposite opinion": 5,
    "subtract": 6,
    "aerobic": 7,
    "anaerobic": 8,
    "active": 9,
}
NUMERIC_STABILITY_EPS = 1e-6


def sanitize_array(array, dtype=np.float32):
    if array is None:
        return None
    return np.nan_to_num(np.asarray(array, dtype=dtype), nan=0.0, posinf=0.0, neginf=0.0)


def safe_corrcoef(a, b, eps=NUMERIC_STABILITY_EPS):
    if a is None or b is None or not len(a) or not len(b):
        return 0.0
    a_std = float(np.std(a))
    b_std = float(np.std(b))
    if a_std < eps or b_std < eps:
        return 0.0
    corr = np.corrcoef(a, b)[0, 1]
    if np.isnan(corr) or np.isinf(corr):
        return 0.0
    return float(corr)


def _slice_or_pad(signal: np.ndarray, start: int, end: int) -> np.ndarray:
    length = end - start
    if signal is None or len(signal) == 0:
        return np.zeros(length, dtype=np.float32)
    if end > len(signal):
        pad = end - len(signal)
        segment = signal[start: len(signal)]
        if pad > 0:
            segment = np.concatenate([segment, np.zeros(pad, dtype=segment.dtype)])
    else:
        segment = signal[start:end]
    segment = np.asarray(segment, dtype=np.float32)
    return sanitize_array(segment)


def extract_respiratory_signal(bvp, fs=4.0):
    from scipy.signal import butter, filtfilt, welch
    from scipy.integrate import trapezoid
    if bvp is None:
        return np.zeros(EXPECTED_LEN, dtype=np.float32), 0.0
    if not len(bvp):
        return np.zeros_like(bvp), 0.0
    ny = fs / 2
    low, high = 0.15 / ny, 0.4 / ny
    b, a = butter(4, [max(low, 0.001), min(high, 0.99)], btype='band')
    filtered = filtfilt(b, a, bvp)
    freqs, psd = welch(bvp, fs=fs, nperseg=min(128, len(bvp)))
    mask = (freqs >= 0.15) & (freqs <= 0.4)
    resp_power = trapezoid(psd[mask], freqs[mask]) if mask.any() else 0.0
    return sanitize_array(filtered), float(resp_power)


def sample_entropy_signal(bvp, fs=4.0):
    from scipy.signal import find_peaks
    bvp = sanitize_array(bvp)
    peaks, _ = find_peaks(bvp, distance=max(int(0.5 * fs), 1), prominence=0.5 * np.std(bvp))
    if len(peaks) < 5:
        return np.zeros_like(bvp)
    ibi = np.diff(peaks) / fs
    if len(ibi) < 5:
        return np.zeros_like(bvp)
    window = max(int(10 * fs), 1)
    entropy_signal = np.zeros_like(bvp)
    for i in range(len(bvp)):
        start = max(0, i - window)
        segment_peaks = peaks[(peaks >= start) & (peaks < i)]
        if len(segment_peaks) < 5:
            continue
        seg_ibi = np.diff(segment_peaks) / fs
        if len(seg_ibi) < 5:
            continue
        m = 2
        r = max(0.2 * np.std(seg_ibi), NUMERIC_STABILITY_EPS)
        count_m = 0
        count_m1 = 0
        for j in range(len(seg_ibi) - m):
            template = seg_ibi[j:j + m]
            for k in range(j + 1, len(seg_ibi) - m):
                if np.max(np.abs(template - seg_ibi[k:k + m])) <= r:
                    count_m += 1
                    if np.abs(seg_ibi[j + m] - seg_ibi[k + m]) <= r:
                        count_m1 += 1
        if count_m > 0 and count_m1 > 0:
            entropy_signal[i] = -np.log(count_m1 / count_m)
    return sanitize_array(entropy_signal)


def add_wavelet_channels(raw_channels, wavelet='db4', level=3):
    import pywt
    raw_channels = sanitize_array(raw_channels)
    wavelet_channels = []
    for ch in raw_channels:
        coeffs = pywt.wavedec(ch, wavelet, level=level)
        for coeff in coeffs[1:]:
            coeff = sanitize_array(coeff)
            if len(coeff) < EXPECTED_LEN:
                coeff = np.pad(coeff, (0, EXPECTED_LEN - len(coeff)), mode='edge')
            elif len(coeff) > EXPECTED_LEN:
                coeff = coeff[:EXPECTED_LEN]
            wavelet_channels.append(coeff.astype(np.float32))
    if not wavelet_channels:
        return np.zeros((0, EXPECTED_LEN), dtype=np.float32)
    stacked = np.vstack(wavelet_channels)
    return sanitize_array(stacked)


def extract_window_features(eda, temp, acc, bvp):
    eda = sanitize_array(eda)
    temp = sanitize_array(temp)
    acc = sanitize_array(acc)
    bvp = sanitize_array(bvp)
    signals = [(eda, 'eda'), (temp, 'temp'), (acc, 'acc'), (bvp, 'bvp')]
    feats = []
    for signal, _ in signals:
        feats.extend([
            float(np.mean(signal)),
            float(np.std(signal)),
            float(np.min(signal)),
            float(np.max(signal)),
            float(np.percentile(signal, 25)),
            float(np.percentile(signal, 75)),
        ])
        diff = np.diff(signal) if len(signal) > 1 else np.zeros(1, dtype=np.float32)
        feats.extend([
            float(np.mean(diff)) if len(diff) else 0.0,
            float(np.std(diff)) if len(diff) else 0.0,
            float(np.max(np.abs(diff))) if len(diff) else 0.0,
        ])
    feats.append(safe_corrcoef(eda, bvp))
    feats.append(safe_corrcoef(eda, acc))
    return sanitize_array(np.array(feats, dtype=np.float32))


def phase_id_from_label(label: str) -> int:
    if label is None:
        return 0
    key = label.lower()
    return PHASE_ENCODING.get(key, 0)


def extract_enhanced_channels(eda, temp, acc, bvp, fs=4.0):
    from scipy.signal import hilbert
    eda = sanitize_array(eda)
    temp = sanitize_array(temp)
    acc = sanitize_array(acc)
    bvp = sanitize_array(bvp)
    base_channels = [eda, temp, acc, bvp]
    channels = base_channels.copy()
    channels.extend([
        np.diff(eda, prepend=eda[0]),
        np.diff(temp, prepend=temp[0]),
        np.diff(acc, prepend=acc[0]),
        np.diff(bvp, prepend=bvp[0]),
    ])
    window_tonic = int(10 * fs)
    eda_tonic = np.convolve(eda, np.ones(window_tonic) / window_tonic, mode='same')
    eda_phasic = eda - eda_tonic
    channels.extend([eda_tonic, eda_phasic])
    eda_diff = np.diff(eda, prepend=eda[0])
    eda_accel = np.diff(eda_diff, prepend=eda_diff[0])
    channels.append(eda_accel)
    window_short = int(5 * fs)
    window_long = int(15 * fs)
    eda_ma_short = np.convolve(eda, np.ones(window_short) / window_short, mode='same')
    eda_ma_long = np.convolve(eda, np.ones(window_long) / window_long, mode='same')
    channels.extend([eda_ma_short, eda_ma_long])
    bvp_envelope = np.abs(hilbert(bvp))
    channels.append(bvp_envelope)
    channels.append(eda * bvp)
    acc_smoothed = np.convolve(acc, np.ones(int(3 * fs)) / (3 * fs), mode='same')
    channels.append(acc_smoothed)
    resp_signal, _ = extract_respiratory_signal(bvp, fs)
    channels.append(resp_signal)
    entropy_signal = sample_entropy_signal(bvp, fs)
    channels.append(entropy_signal)
    wavelet_extra = add_wavelet_channels(np.array(base_channels))
    channels.extend(list(wavelet_extra))
    channels = [sanitize_array(ch) for ch in channels]
    stacked = np.stack(channels, axis=0).astype(np.float32)
    return sanitize_array(stacked)


def build_sequence_dataset(states: List[str] = STATES, max_subjects: int = 0):
    sequences = []
    labels = []
    subjects = []
    feature_vectors = []
    phase_ids = []
    for state in states:
        state_dir = DATASET_ROOT / state
        if not state_dir.exists():
            continue
        subject_ids = sorted([p.name for p in state_dir.iterdir() if p.is_dir()])
        if max_subjects and max_subjects > 0:
            subject_ids = subject_ids[:max_subjects]
        for subj in tqdm(subject_ids, desc=f"{state}"):
            try:
                info = load_subject_state(state, subj)
            except Exception as exc:
                print(f"Skip {state}/{subj}: {exc}")
                continue
            sensors = info["sensors"]
            fs_map = info["fs"]
            tags = info["tags"]
            duration = info["duration"]

            eda = resample_to_rate(sensors["EDA"], fs_map["EDA"], TARGET_FS)
            temp = resample_to_rate(
                sensors.get("TEMP", np.zeros_like(eda)),
                fs_map.get("TEMP", TARGET_FS),
                TARGET_FS,
            ) if "TEMP" in sensors else np.zeros_like(eda)
            acc = resample_to_rate(sensors["ACC_MAG"], fs_map["ACC_MAG"], TARGET_FS)
            if "BVP" in sensors and "BVP" in fs_map:
                bvp = resample_to_rate(sensors["BVP"], fs_map["BVP"], TARGET_FS)
            else:
                bvp = np.zeros_like(eda)

            eda = sanitize_array(eda)
            temp = sanitize_array(temp)
            acc = sanitize_array(acc)
            bvp = sanitize_array(bvp)

            intervals = make_label_intervals(state, subj, tags, duration)
            windows = window_intervals(duration, WINDOW_SECONDS, WINDOW_STEP_SECONDS)

            for win in windows:
                label_name, span_meta = assign_label(win, intervals)
                if label_name is None or span_meta is None:
                    continue
                start_idx = int(round(win[0] * TARGET_FS))
                end_idx = start_idx + EXPECTED_LEN
                eda_win = sanitize_array(_slice_or_pad(eda, start_idx, end_idx))
                temp_win = sanitize_array(_slice_or_pad(temp, start_idx, end_idx))
                acc_win = sanitize_array(_slice_or_pad(acc, start_idx, end_idx))
                bvp_win = sanitize_array(_slice_or_pad(bvp, start_idx, end_idx))

                stress_stage = span_meta.get("stage") if isinstance(span_meta, dict) else None
                stress_level = span_meta.get("stress_level") if isinstance(span_meta, dict) else None
                if label_name == "stress":
                    if stress_level is None or np.isnan(stress_level):
                        continue
                else:
                    stress_level = 0.0
                phase_label = stress_stage if stress_stage else label_name
                stress_class = stress_bucket(stress_level, phase_label)

                tensor = sanitize_array(extract_enhanced_channels(eda_win, temp_win, acc_win, bvp_win, TARGET_FS))
                stats = sanitize_array(extract_window_features(eda_win, temp_win, acc_win, bvp_win))
                sequences.append(tensor)
                labels.append(stress_class)
                subjects.append(base_subject_id(subj))
                feature_vectors.append(stats)
                phase_ids.append(phase_id_from_label(phase_label))

    sequences = sanitize_array(np.stack(sequences))
    labels = np.array(labels)
    subjects = np.array(subjects)
    feature_vectors = sanitize_array(np.stack(feature_vectors))
    phase_ids = np.array(phase_ids, dtype=np.int64)
    return sequences, labels, subjects, feature_vectors, phase_ids


In [9]:
sequences, labels, subjects, feature_vectors, phase_ids = build_sequence_dataset(max_subjects=MAX_SUBJECTS if MAX_SUBJECTS else 0)

num_channels = sequences.shape[1]
print("=" * 80)
print("DATASET STATISTICS")
print("=" * 80)
print(f"Raw sequences: {sequences.shape}")
print(f"Aux features: {feature_vectors.shape}")
print(f"Phase IDs: {np.unique(phase_ids)}")
print(f"Label distribution:")
label_dist = pd.Series(labels).value_counts().sort_index()
for label, count in label_dist.items():
    pct = 100 * count / len(labels)
    print(f"  {label:20s}: {count:5d} ({pct:5.1f}%)")

if APPLY_CHANNEL_NORMALIZATION and sequences.size:
    print("" + "=" * 80)
    print("APPLYING SUBJECT-SPECIFIC BASELINE NORMALIZATION")
    print("=" * 80)
    normalized = sequences.copy()
    for subject in np.unique(subjects):
        subject_mask = subjects == subject
        rest_mask = subject_mask & (labels == 'no_stress')
        if rest_mask.sum() > 0:
            baseline_mean = sequences[rest_mask].mean(axis=(0, 2), keepdims=True)
            baseline_std = sequences[rest_mask].std(axis=(0, 2), keepdims=True) + 1e-6
            print(f"  {subject:8s}: {rest_mask.sum():4d} rest windows → baseline")
        else:
            baseline_mean = sequences[subject_mask].mean(axis=(0, 2), keepdims=True)
            baseline_std = sequences[subject_mask].std(axis=(0, 2), keepdims=True) + 1e-6
            print(f"  {subject:8s}: {subject_mask.sum():4d} total windows (NO REST DATA)")
        normalized[subject_mask] = (sequences[subject_mask] - baseline_mean) / baseline_std
    sequences = normalized
    print("✓ Subject-specific normalization complete")

print("" + "=" * 80)
print(f"FINAL DATASET: {sequences.shape}")
print(f"Feature matrix: {feature_vectors.shape}")
print(f"Subjects: {len(np.unique(subjects))}")
print("=" * 80)


STRESS:   0%|          | 0/37 [00:00<?, ?it/s]

Skip STRESS/f14_a: No columns to parse from file


AEROBIC:   0%|          | 0/31 [00:00<?, ?it/s]

ANAEROBIC:   0%|          | 0/32 [00:00<?, ?it/s]

DATASET STATISTICS
Raw sequences: (6788, 30, 240)
Aux features: (6788, 38)
Phase IDs: [0 2 3 5 9]
Label distribution:
  high_stress         :   462 (  6.8%)
  low_stress          :    91 (  1.3%)
  moderate_stress     :   664 (  9.8%)
  no_stress           :  5571 ( 82.1%)
APPLYING SUBJECT-SPECIFIC BASELINE NORMALIZATION
  S01     :  155 rest windows → baseline
  S02     :  188 rest windows → baseline
  S03     :  122 rest windows → baseline
  S04     :  140 rest windows → baseline
  S05     :  136 rest windows → baseline
  S06     :  124 rest windows → baseline
  S07     :  124 rest windows → baseline
  S08     :  135 rest windows → baseline
  S09     :  137 rest windows → baseline
  S10     :  137 rest windows → baseline
  S11     :  130 rest windows → baseline
  S12     :   69 rest windows → baseline
  S13     :  140 rest windows → baseline
  S14     :  143 rest windows → baseline
  S15     :  138 rest windows → baseline
  S16     :  146 rest windows → baseline
  S17     :  140 rest

In [10]:
class SequenceDataset(Dataset):
    def __init__(self, sequences: np.ndarray, features: np.ndarray, phase_ids: np.ndarray, labels: np.ndarray):
        self.sequences = torch.from_numpy(sequences)
        self.features = torch.from_numpy(features)
        self.phase_ids = torch.from_numpy(phase_ids).long()
        self.labels = torch.from_numpy(labels).long()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.sequences[idx], self.features[idx], self.phase_ids[idx], self.labels[idx]


In [11]:
class ResNetBlock(nn.Module):
    def __init__(self, channels: int, kernel_size: int = 5, dilation: int = 1):
        super().__init__()
        padding = ((kernel_size - 1) // 2) * dilation
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=padding, dilation=dilation)
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size, padding=padding, dilation=dilation)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return self.relu(out)


class MultiScaleSequenceEncoder(nn.Module):
    def __init__(self, input_channels: int, cnn_channels: int = 64):
        super().__init__()
        self.conv_short = nn.Sequential(
            nn.Conv1d(input_channels, cnn_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(cnn_channels),
            nn.ReLU(),
            ResNetBlock(cnn_channels, kernel_size=3, dilation=1),
            ResNetBlock(cnn_channels, kernel_size=3, dilation=1),
        )
        self.conv_medium = nn.Sequential(
            nn.Conv1d(input_channels, cnn_channels, kernel_size=7, padding=3),
            nn.BatchNorm1d(cnn_channels),
            nn.ReLU(),
            ResNetBlock(cnn_channels, kernel_size=7, dilation=2),
            ResNetBlock(cnn_channels, kernel_size=7, dilation=2),
        )
        self.conv_long = nn.Sequential(
            nn.Conv1d(input_channels, cnn_channels, kernel_size=15, padding=7),
            nn.BatchNorm1d(cnn_channels),
            nn.ReLU(),
            ResNetBlock(cnn_channels, kernel_size=15, dilation=4),
            ResNetBlock(cnn_channels, kernel_size=15, dilation=4),
        )
        merged_channels = cnn_channels * 3
        self.merge = nn.Sequential(
            nn.Conv1d(merged_channels, 128, kernel_size=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=128,
            num_layers=3,
            dropout=0.3,
            batch_first=True,
            bidirectional=True,
        )
        self.attention = nn.MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.2, batch_first=True)
        self.global_pool = nn.AdaptiveAvgPool1d(1)

    def forward(self, x):
        feat_short = self.conv_short(x)
        feat_medium = self.conv_medium(x)
        feat_long = self.conv_long(x)
        merged = torch.cat([feat_short, feat_medium, feat_long], dim=1)
        merged = self.merge(merged)
        lstm_in = merged.transpose(1, 2)
        lstm_out, _ = self.lstm(lstm_in)
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        temporal_features = attn_out.mean(dim=1)
        cnn_features = self.global_pool(merged).squeeze(-1)
        return torch.cat([temporal_features, cnn_features], dim=1)


class PhaseAwareHybridNet(nn.Module):
    def __init__(self, input_channels: int, num_features: int, num_classes: int, num_phases: int = len(PHASE_ENCODING) + 2):
        super().__init__()
        self.sequence_encoder = MultiScaleSequenceEncoder(input_channels)
        self.feature_branch = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.phase_embedding = nn.Embedding(num_phases, 32)
        fusion_in = 384 + 64 + 32
        self.classifier = nn.Sequential(
            nn.Linear(fusion_in, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, sequences, stats, phase_ids):
        seq_feat = self.sequence_encoder(sequences)
        stat_feat = self.feature_branch(stats)
        phase_feat = self.phase_embedding(phase_ids)
        combined = torch.cat([seq_feat, stat_feat, phase_feat], dim=1)
        return self.classifier(combined)


In [None]:

class AsymmetricFocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma_per_class=None, label_smoothing=0.0):
        super().__init__()
        self.alpha = alpha
        self.gamma_per_class = gamma_per_class  
        self.label_smoothing = label_smoothing
    
    def forward(self, logits, targets):
        ce_loss = F.cross_entropy(
            logits, targets, reduction='none', label_smoothing=self.label_smoothing
        )
        pt = torch.exp(-ce_loss)
        
        if self.gamma_per_class is not None:
    
            gamma = self.gamma_per_class[targets]
            focal_weight = (1 - pt) ** gamma
        else:
            focal_weight = (1 - pt) ** 2.0
        
        if self.alpha is not None:
            loss = self.alpha[targets] * focal_weight * ce_loss
        else:
            loss = focal_weight * ce_loss
        return loss.mean()


gamma_per_class = torch.tensor([2.5, 3.0, 2.0, 1.0], dtype=torch.float32, device=device)
print("Asymmetric Focal Loss initialized:")
print("  high_stress (0): gamma=2.5")
print("  low_stress (1): gamma=3.0 (hardest class)")
print("  moderate_stress (2): gamma=2.0")
print("  no_stress (3): gamma=1.0 (easiest class)")

Asymmetric Focal Loss initialized:
  high_stress (0): gamma=2.5
  low_stress (1): gamma=3.0 (hardest class)
  moderate_stress (2): gamma=2.0
  no_stress (3): gamma=1.0 (easiest class)


In [None]:

EPOCHS = 40
BATCH_SIZE = 32
LR = 1e-3
WEIGHT_DECAY = 5e-4
MAX_GRAD_NORM = 1.0
USE_MIXED_PRECISION = torch.cuda.is_available()

le = LabelEncoder()
encoded_labels = le.fit_transform(labels)
num_classes = len(le.classes_)

print("=" * 80)
print("LABEL ENCODING")
print("=" * 80)
print("Classes:", dict(zip(le.classes_, range(num_classes))))
print(f"Encoded label distribution:")
for i, class_name in enumerate(le.classes_):
    count = (encoded_labels == i).sum()
    print(f"  {i}: {class_name:20s} → {count:5d} samples")


LABEL ENCODING
Classes: {np.str_('high_stress'): 0, np.str_('low_stress'): 1, np.str_('moderate_stress'): 2, np.str_('no_stress'): 3}
Encoded label distribution:
  0: high_stress          →   462 samples
  1: low_stress           →    91 samples
  2: moderate_stress      →   664 samples
  3: no_stress            →  5571 samples


In [None]:
from sklearn.model_selection import train_test_split, GroupKFold
from sklearn.metrics import precision_score, recall_score, f1_score
from scipy.interpolate import interp1d
from contextlib import nullcontext
import math

print("=" * 80)
print("PHASE 2 PREPARATION: DATA BALANCING + TRAINING CONFIG")
print("=" * 80)

stress_class_names = {"high_stress", "low_stress", "moderate_stress"}
stress_indices = [i for i, name in enumerate(le.classes_) if name in stress_class_names]
no_stress_idx = int(np.where(le.classes_ == 'no_stress')[0][0]) if 'no_stress' in le.classes_ else None
stress_idx_tensor = torch.tensor(stress_indices, device=device) if stress_indices else None


def time_warp_augment(sequence, warp_factor):
    channels, length = sequence.shape
    new_length = int(length * warp_factor)
    warped = []
    for ch in range(channels):
        if new_length < 4:
            warped.append(sequence[ch])
            continue
        f = interp1d(np.arange(length), sequence[ch], kind='cubic', fill_value='extrapolate')
        new_indices = np.linspace(0, length - 1, new_length)
        warped_ch = f(new_indices)
        warped_ch_resampled = np.interp(
            np.arange(length),
            np.linspace(0, length - 1, new_length),
            warped_ch,
        )
        warped.append(warped_ch_resampled)
    return np.array(warped, dtype=np.float32)


def stats_from_sequence(seq):
    eda, temp, acc, bvp = seq[0], seq[1], seq[2], seq[3]
    return extract_window_features(eda, temp, acc, bvp)


def temporal_augmentation(X, feats, phases, y, augment_counts):
    if not len(X):
        return X, feats, phases, y
    X_aug = list(X)
    F_aug = list(feats)
    P_aug = list(phases)
    y_aug = list(y)
    for class_idx, class_name in enumerate(le.classes_):
        if class_name not in augment_counts:
            continue
        class_indices = np.where(y == class_idx)[0]
        if not len(class_indices):
            continue
        aug_factor = augment_counts[class_name]
        print(f"  {class_name}: {len(class_indices)} samples → augmenting {aug_factor}x")
        for idx in class_indices:
            sample = X[idx]
            phase = phases[idx]
            for _ in range(aug_factor):
                warp_factor = np.random.uniform(0.95, 1.05)
                aug_sample = time_warp_augment(sample, warp_factor)
                noise_std = 0.03 * np.std(sample, axis=1, keepdims=True)
                aug_sample = aug_sample + np.random.normal(0, noise_std, aug_sample.shape).astype(np.float32)
                shift_range = max(1, int(0.1 * sample.shape[1]))
                shift = np.random.randint(-shift_range, shift_range)
                aug_sample = np.roll(aug_sample, shift, axis=1)
                aug_sample = np.nan_to_num(aug_sample, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
                X_aug.append(aug_sample)
                F_aug.append(stats_from_sequence(aug_sample))
                P_aug.append(phase)
                y_aug.append(class_idx)
    return (
        np.array(X_aug, dtype=np.float32),
        np.array(F_aug, dtype=np.float32),
        np.array(P_aug, dtype=np.int64),
        np.array(y_aug),
    )


class ModelEMA:
    def __init__(self, model, decay=0.995):
        self.decay = decay
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        self.backup = None

    def update(self, model):
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            self.shadow[name].mul_(self.decay).add_(param.detach(), alpha=1.0 - self.decay)

    def apply_shadow(self, model):
        self.backup = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        for name, param in model.named_parameters():
            if name in self.shadow:
                param.data.copy_(self.shadow[name])

    def restore(self, model):
        if self.backup is None:
            return
        for name, param in model.named_parameters():
            if name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = None



if GROUP_SPLIT:
    splitter_desc = "GroupKFold"
    if USE_STRATIFIED_GROUP_SPLIT and 'StratifiedGroupKFold' in globals() and StratifiedGroupKFold is not None:
        splitter = StratifiedGroupKFold(n_splits=NUM_FOLDS, shuffle=True, random_state=SEED)
        splitter_desc = "StratifiedGroupKFold"
    else:
        splitter = GroupKFold(n_splits=NUM_FOLDS)
    splits = list(splitter.split(sequences, encoded_labels, groups=subjects))
    if FOLD_INDEX >= len(splits):
        raise ValueError(f"FOLD_INDEX {FOLD_INDEX} out of range for {NUM_FOLDS} folds.")
    train_idx, test_idx = splits[FOLD_INDEX]
    split_desc = f"{splitter_desc} fold {FOLD_INDEX + 1}/{NUM_FOLDS}"
else:
    train_idx, test_idx = train_test_split(
        np.arange(len(sequences)),
        test_size=0.2,
        stratify=encoded_labels,
        random_state=SEED,
    )
    split_desc = "Stratified random 80/20 split"

X_train, X_test = sequences[train_idx], sequences[test_idx]
feat_train, feat_test = feature_vectors[train_idx], feature_vectors[test_idx]
phase_train, phase_test = phase_ids[train_idx], phase_ids[test_idx]
y_train, y_test = encoded_labels[train_idx], encoded_labels[test_idx]
train_subjects = subjects[train_idx]
test_subjects = subjects[test_idx]
held_out_subjects = test_subjects.copy()

print(f"Split strategy: {split_desc}")
print(f"  Train windows: {X_train.shape[0]} from {len(np.unique(train_subjects))} subjects")
print(f"  Test windows:  {X_test.shape[0]} from {len(np.unique(test_subjects))} subjects")
print("Train subjects:", ', '.join(sorted(np.unique(train_subjects))))
print("Test subjects:", ', '.join(sorted(np.unique(test_subjects))))

print(f"Dataset split overview:")
print(f"  Train: {X_train.shape}")
print(f"  Test:  {X_test.shape}")

print(f"Train class distribution (BEFORE balancing):")
for i, class_name in enumerate(le.classes_):
    count = (y_train == i).sum()
    pct = 100 * count / len(y_train)
    print(f"  {class_name:20s}: {count:5d} ({pct:5.1f}%)")

print(f"Test class distribution:")
for i, class_name in enumerate(le.classes_):
    count = (y_test == i).sum()
    pct = 100 * count / len(y_test)
    print(f"  {class_name:20s}: {count:5d} ({pct:5.1f}%)")

if APPLY_TEMPORAL_AUG:
    print("=" * 80)
    print("TEMPORAL AUGMENTATION")
    print("=" * 80)
    X_train, feat_train, phase_train, y_train = temporal_augmentation(
        X_train,
        feat_train,
        phase_train,
        y_train,
        TEMPORAL_AUG_COUNTS,
    )

print("=" * 80)
print("FINAL CLASS DISTRIBUTION")
print("=" * 80)
print(f"Train: {X_train.shape}")
for i, class_name in enumerate(le.classes_):
    count = (y_train == i).sum()
    pct = 100 * count / len(y_train)
    print(f"  {class_name:20s}: {count:5d} ({pct:5.1f}%)")

scaler = StandardScaler()
feat_train = scaler.fit_transform(feat_train).astype(np.float32)
feat_test = scaler.transform(feat_test).astype(np.float32)
feat_train = np.nan_to_num(feat_train, nan=0.0, posinf=0.0, neginf=0.0)
feat_test = np.nan_to_num(feat_test, nan=0.0, posinf=0.0, neginf=0.0)
X_train = np.nan_to_num(X_train.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
X_test = np.nan_to_num(X_test.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
phase_train = phase_train.astype(np.int64)
phase_test = phase_test.astype(np.int64)

def ensure_long_enough(arr, target_len):
    if arr.ndim != 3:
        return arr
    if arr.shape[2] == target_len:
        return arr
    if arr.shape[2] > target_len:
        return arr[:, :, :target_len]
    pad = target_len - arr.shape[2]
    return np.pad(arr, ((0, 0), (0, 0), (0, pad)), mode='edge')

X_train = ensure_long_enough(X_train, EXPECTED_LEN)
X_test = ensure_long_enough(X_test, EXPECTED_LEN)

train_dataset = SequenceDataset(X_train, feat_train, phase_train, y_train)
test_dataset = SequenceDataset(X_test, feat_test, phase_test, y_test)

class_counts = np.bincount(y_train, minlength=num_classes).clip(min=1).astype(np.float32)
class_weights = len(y_train) / (num_classes * class_counts)
class_weights = np.nan_to_num(class_weights, nan=0.0, posinf=0.0, neginf=0.0)
alpha_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)

sample_weights = 1.0 / class_counts[y_train]
sample_weights = np.nan_to_num(sample_weights, nan=0.0, posinf=0.0, neginf=0.0)
sample_weights = torch.tensor(sample_weights, dtype=torch.double)
num_samples = BATCH_SIZE * math.ceil(len(y_train) / BATCH_SIZE)
train_sampler = WeightedRandomSampler(sample_weights, num_samples=num_samples, replacement=True)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

print(f"Class weights / alpha for focal loss: {dict(zip(le.classes_, class_weights))}")


class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0, label_smoothing=0.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.label_smoothing = label_smoothing

    def forward(self, logits, targets):
        ce_loss = F.cross_entropy(
            logits,
            targets,
            reduction='none',
            label_smoothing=self.label_smoothing,
        )
        pt = torch.exp(-ce_loss)
        if self.alpha is not None:
            loss = self.alpha[targets] * (1 - pt) ** self.gamma * ce_loss
        else:
            loss = (1 - pt) ** self.gamma * ce_loss
        return loss.mean()

num_channels = sequences.shape[1]
num_features = feature_vectors.shape[1]
model = PhaseAwareHybridNet(
    input_channels=num_channels,
    num_features=num_features,
    num_classes=num_classes,
).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)

criterion = AsymmetricFocalLoss(alpha=alpha_tensor, gamma_per_class=gamma_per_class, label_smoothing=LABEL_SMOOTHING)
def _null_autocast():
    return nullcontext()

scaler = None
autocast_cm = _null_autocast
if USE_MIXED_PRECISION and device.type == 'cuda':
    try:
        scaler = torch.amp.GradScaler('cuda')

        def autocast_cm():
            return torch.amp.autocast('cuda', enabled=True)
    except TypeError:
        scaler = torch.cuda.amp.GradScaler(enabled=True)

        def autocast_cm():
            return torch.cuda.amp.autocast(enabled=True)
ema = ModelEMA(model, decay=EMA_DECAY) if EMA_DECAY else None

print("=" * 80)
print(f"TRAINING: {EPOCHS} epochs (mixed precision: {USE_MIXED_PRECISION})")
print("=" * 80)

best_val_f1 = 0.0
best_model_state = None
global_step = 0

for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss = 0.0
    for xb, fb, pb, yb in train_loader:
        xb, fb, pb, yb = xb.to(device), fb.to(device), pb.to(device), yb.to(device)
        optimizer.zero_grad(set_to_none=True)
        with autocast_cm():
            logits = model(xb, fb, pb)
            loss = criterion(logits, yb)
        if scaler:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
            optimizer.step()
        if ema:
            ema.update(model)
        train_loss += loss.item() * xb.size(0)
        global_step += 1
        scheduler.step(global_step)
    train_loss /= len(train_loader.dataset)

    model.eval()
    if ema:
        ema.apply_shadow(model)
    val_loss = 0.0
    raw_preds = []
    raw_targets = []
    stage_preds = []
    with torch.no_grad():
        for xb, fb, pb, yb in test_loader:
            xb, fb, pb, yb = xb.to(device), fb.to(device), pb.to(device), yb.to(device)
            with autocast_cm():
                logits = model(xb, fb, pb)
                loss = criterion(logits, yb)
            val_loss += loss.item() * xb.size(0)
            probs = torch.softmax(logits, dim=1)
            preds = torch.argmax(probs, dim=1)
            stage = preds.clone()
            if stress_indices:
                stress_prob = probs[:, stress_indices].sum(dim=1)
                stress_mask = stress_prob >= TWO_STAGE_THRESHOLD
                if stress_mask.any() and stress_idx_tensor is not None:
                    stress_probs = probs[stress_mask][:, stress_indices]
                    best_local = torch.argmax(stress_probs, dim=1)
                    stage[stress_mask] = stress_idx_tensor[best_local]
                if no_stress_idx is not None:
                    stage[~stress_mask] = no_stress_idx
            raw_preds.append(preds.cpu().numpy())
            stage_preds.append(stage.cpu().numpy())
            raw_targets.append(yb.cpu().numpy())
    if ema:
        ema.restore(model)
    val_loss /= len(test_loader.dataset)
    raw_preds = np.concatenate(raw_preds)
    stage_preds = np.concatenate(stage_preds)
    raw_targets = np.concatenate(raw_targets)
    val_acc = accuracy_score(raw_targets, raw_preds)
    val_f1_macro = f1_score(raw_targets, raw_preds, average='macro', zero_division=0)
    val_f1_weighted = f1_score(raw_targets, raw_preds, average='weighted', zero_division=0)

    if val_f1_macro > best_val_f1:
        best_val_f1 = val_f1_macro
        best_model_state = model.state_dict().copy()

    print(
        f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
        f"Val Acc: {val_acc:.3f} | Val F1 (macro): {val_f1_macro:.3f} "
        f"| Val F1 (weighted): {val_f1_weighted:.3f}"
    )

if best_model_state is not None:
    model.load_state_dict(best_model_state)
if ema:
    ema.apply_shadow(model)

print("=" * 80)
print("FINAL EVALUATION ON TEST SET")
print("=" * 80)

model.eval()
raw_preds = []
stage_preds = []
raw_targets = []
with torch.no_grad():
    for xb, fb, pb, yb in test_loader:
        xb, fb, pb, yb = xb.to(device), fb.to(device), pb.to(device), yb.to(device)
        with autocast_cm():
            logits = model(xb, fb, pb)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)
        stage = preds.clone()
        if stress_indices:
            stress_prob = probs[:, stress_indices].sum(dim=1)
            stress_mask = stress_prob >= TWO_STAGE_THRESHOLD
            if stress_mask.any() and stress_idx_tensor is not None:
                stress_probs = probs[stress_mask][:, stress_indices]
                best_local = torch.argmax(stress_probs, dim=1)
                stage[stress_mask] = stress_idx_tensor[best_local]
            if no_stress_idx is not None:
                stage[~stress_mask] = no_stress_idx
        raw_preds.append(preds.cpu().numpy())
        stage_preds.append(stage.cpu().numpy())
        raw_targets.append(yb.cpu().numpy())

if ema:
    ema.restore(model)

raw_preds = np.concatenate(raw_preds)
stage_preds = np.concatenate(stage_preds)
raw_targets = np.concatenate(raw_targets)

test_acc = accuracy_score(raw_targets, raw_preds)
test_f1_macro = f1_score(raw_targets, raw_preds, average='macro', zero_division=0)
test_f1_weighted = f1_score(raw_targets, raw_preds, average='weighted', zero_division=0)
test_precision_macro = precision_score(raw_targets, raw_preds, average='macro', zero_division=0)
test_recall_macro = recall_score(raw_targets, raw_preds, average='macro', zero_division=0)

print("OVERALL METRICS (direct predictions):")
print(f"  Accuracy:         {test_acc:.4f} ({test_acc*100:.1f}%)")
print(f"  Macro F1:         {test_f1_macro:.4f} ({test_f1_macro*100:.1f}%)")
print(f"  Weighted F1:      {test_f1_weighted:.4f} ({test_f1_weighted*100:.1f}%)")
print(f"  Macro Precision:  {test_precision_macro:.4f}")
print(f"  Macro Recall:     {test_recall_macro:.4f}")

stage_acc = accuracy_score(raw_targets, stage_preds)
stage_f1_macro = f1_score(raw_targets, stage_preds, average='macro', zero_division=0)
stage_f1_weighted = f1_score(raw_targets, stage_preds, average='weighted', zero_division=0)
print("Two-stage metrics:")
print(f"  Accuracy:         {stage_acc:.4f} ({stage_acc*100:.1f}%)")
print(f"  Macro F1:         {stage_f1_macro:.4f} ({stage_f1_macro*100:.1f}%)")
print(f"  Weighted F1:      {stage_f1_weighted:.4f} ({stage_f1_weighted*100:.1f}%)")

print("=" * 80)
print("PER-CLASS METRICS")
print("=" * 80)
print("" + classification_report(raw_targets, raw_preds, target_names=le.classes_, digits=4, zero_division=0))

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(raw_targets, raw_preds)
cm_df = pd.DataFrame(cm, index=le.classes_, columns=le.classes_)

print("=" * 80)
print("CONFUSION MATRIX")
print("=" * 80)
print("" + str(cm_df))

print("=" * 80)
print("SUBJECT-LEVEL PERFORMANCE (TEST SET)")
print("=" * 80)
subject_records = []
for subj in sorted(np.unique(held_out_subjects)):
    subj_mask = held_out_subjects == subj
    subj_true = raw_targets[subj_mask]
    subj_pred = stage_preds[subj_mask]
    subject_records.append({
        "subject": subj,
        "samples": int(subj_mask.sum()),
        "accuracy": accuracy_score(subj_true, subj_pred) if subj_mask.sum() else 0.0,
        "macro_f1": f1_score(subj_true, subj_pred, average='macro', zero_division=0),
    })
if subject_records:
    subject_df = pd.DataFrame(subject_records)
    print(subject_df.to_string(index=False, formatters={
        "accuracy": "{:.3f}".format,
        "macro_f1": "{:.3f}".format,
    }))
else:
    print("No held-out subjects to report (check split configuration).")

print("=" * 80)
print("COMPARISON TO BASELINE")
print("=" * 80)
print("BASELINE (before Phase 1):")
print("  Accuracy:    75.9%")
print("  Macro F1:    36.0%")
print(f"PHASE 2+ HYBRID (multi-scale CNN + hybrid features + focal loss):")
print(f"  Accuracy:    {stage_acc*100:.1f}%  ({(stage_acc-0.759)*100:+.1f}pp)")
print(f"  Macro F1:    {stage_f1_macro*100:.1f}%  ({(stage_f1_macro-0.360)*100:+.1f}pp)")

improvement = (stage_f1_macro - 0.360) * 100
if improvement >= 8:
    print(f"✓ TARGET ACHIEVED! Macro F1 improved by {improvement:+.1f}pp (target: +8-12pp)")
elif improvement >= 5:
    print(f"✓ Good progress! Macro F1 improved by {improvement:+.1f}pp")
else:
    print(f"⚠ Improvement: {improvement:+.1f}pp. Continue with next phases if needed.")

print("=" * 80)


PHASE 2 PREPARATION: DATA BALANCING + TRAINING CONFIG
Split strategy: StratifiedGroupKFold fold 1/5
  Train windows: 5511 from 30 subjects
  Test windows:  1277 from 6 subjects
Train subjects: S01, S03, S04, S05, S06, S08, S09, S10, S11, S12, S13, S15, S16, S17, S18, f01, f02, f03, f04, f05, f06, f08, f10, f12, f13, f14, f15, f16, f17, f18
Test subjects: S02, S07, S14, f07, f09, f11
Dataset split overview:
  Train: (5511, 30, 240)
  Test:  (1277, 30, 240)
Train class distribution (BEFORE balancing):
  high_stress         :   351 (  6.4%)
  low_stress          :    53 (  1.0%)
  moderate_stress     :   607 ( 11.0%)
  no_stress           :  4500 ( 81.7%)
Test class distribution:
  high_stress         :   111 (  8.7%)
  low_stress          :    38 (  3.0%)
  moderate_stress     :    57 (  4.5%)
  no_stress           :  1071 ( 83.9%)
FINAL CLASS DISTRIBUTION
Train: (5511, 30, 240)
  high_stress         :   351 (  6.4%)
  low_stress          :    53 (  1.0%)
  moderate_stress     :   607 ( 