# LSTM + ResNet Stress Classifier

End-to-end exploration of a convolutional + recurrent model for stress-level prediction.

In [1]:

import os
import math
from pathlib import Path
from typing import List, Tuple

import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import GroupKFold
from sklearn.metrics import classification_report, accuracy_score

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


In [2]:
# Configuration
DEFAULT_DATASET_ROOT = Path("./Datasets")
DATASET_ROOT = Path(os.getenv("DATASET_ROOT", DEFAULT_DATASET_ROOT))
STATES = ["STRESS", "AEROBIC", "ANAEROBIC"]
TARGET_FS = 4.0
WINDOW_SECONDS = 60
WINDOW_STEP_SECONDS = 30
MIN_LABEL_COVERAGE = 0.6
SEED = 42
MAX_SUBJECTS = None  # limit per state if desired
APPLY_CHANNEL_NORMALIZATION = True
APPLY_DIFF_CHANNELS = True
APPLY_SMOTE = True

# Cross-subject revalidation setup
GROUP_SPLIT = True
NUM_FOLDS = 5
FOLD_INDEX = 0


In [3]:

# Helper functions for Empatica-format data
STRESS_STAGE_ORDER_S = ["Stroop", "TMCT", "Real Opinion", "Opposite Opinion", "Subtract"]
STRESS_STAGE_ORDER_F = ["TMCT", "Real Opinion", "Opposite Opinion", "Subtract"]
STRESS_TAG_PAIRS_S = [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)]
STRESS_TAG_PAIRS_F = [(2, 3), (4, 5), (6, 7), (8, 9)]
STRESS_PHASES = {"Stroop", "TMCT", "Real Opinion", "Opposite Opinion", "Subtract"}
STRESS_LEVEL_BOUNDS = {"low": 3.0, "moderate": 6.0}
STRESS_LEVEL_PHASE_BOUNDS = {
    "Stroop": {"low": 2.5, "moderate": 5.0},
    "Opposite Opinion": {"low": 2.5, "moderate": 5.5},
    "Real Opinion": {"low": 2.8, "moderate": 5.5},
    "TMCT": {"low": 2.8, "moderate": 5.8},
    "Subtract": {"low": 2.8, "moderate": 5.8},
}
STRESS_LEVEL_FILES = ["Stress_Level_v1.csv", "Stress_Level_v2.csv"]
MIN_LABEL_COVERAGE = 0.6


def load_stress_levels():
    levels = {}
    for fname in STRESS_LEVEL_FILES:
        path = Path(fname)
        if not path.exists():
            continue
        df = pd.read_csv(path, index_col=0)
        df.columns = [str(c).strip() for c in df.columns]
        for subject, row in df.iterrows():
            subj = str(subject).strip()
            levels[subj] = {
                col: (float(row[col]) if not pd.isna(row[col]) else np.nan)
                for col in df.columns
            }
    return levels


STRESS_LEVELS = load_stress_levels()


def base_subject_id(subject: str) -> str:
    return subject.split("_")[0]


def read_signal(path: Path):
    with open(path, "r") as f:
        start_line = f.readline().strip()
        if not start_line:
            raise ValueError(f"Missing start timestamp in {path}")
        start_ts = pd.to_datetime(start_line.split(",")[0])
        fs_line = f.readline().strip()
        if not fs_line:
            raise ValueError(f"Missing sample rate in {path}")
        fs = float(fs_line.split(",")[0])
        data = np.genfromtxt(f, delimiter=",")
    data = np.asarray(data, dtype=float)
    if data.ndim == 0:
        data = data.reshape(1, 1)
    return fs, data.squeeze(), start_ts


def read_tags(path: Path, start_ts: pd.Timestamp):
    if not path.exists():
        return []
    df = pd.read_csv(path, header=None)
    tags = []
    for ts_str in df[0].astype(str):
        ts = pd.to_datetime(ts_str)
        tags.append((ts - start_ts).total_seconds())
    return [(t, t) for t in tags]


def stress_intervals_from_tags(tags, subject):
    if not tags:
        return []
    times = [t for t, _ in tags]
    if subject.startswith("S"):
        idx_pairs = STRESS_TAG_PAIRS_S
        stage_order = STRESS_STAGE_ORDER_S
    else:
        idx_pairs = STRESS_TAG_PAIRS_F
        stage_order = STRESS_STAGE_ORDER_F
    base_id = base_subject_id(subject)
    spans = []
    for stage, (i, j) in zip(stage_order, idx_pairs):
        if i < len(times) and j < len(times) and times[j] > times[i]:
            level = STRESS_LEVELS.get(base_id, {}).get(stage)
            spans.append({"start": times[i], "end": times[j], "stage": stage, "stress_level": level})
    return spans


def active_intervals_from_tags(tags):
    if len(tags) < 2:
        return []
    spans = []
    for (a, _), (b, _) in zip(tags[:-1], tags[1:]):
        if b > a:
            spans.append({"start": a, "end": b, "stage": "active", "stress_level": 0.0})
    return spans


def stress_bucket(level: float = None, phase: str = None) -> str:
    if phase in {"aerobic", "anaerobic", "rest", "active"}:
        return "no_stress"
    if level is None or pd.isna(level) or level <= 0:
        return "no_stress"
    bounds = STRESS_LEVEL_PHASE_BOUNDS.get(phase, STRESS_LEVEL_BOUNDS)
    if level <= bounds["low"]:
        return "low_stress"
    if level <= bounds["moderate"]:
        return "moderate_stress"
    return "high_stress"


def resample_to_rate(signal: np.ndarray, src_fs: float, tgt_fs: float) -> np.ndarray:
    if signal.ndim == 1:
        signal = signal[:, None]
    src_len = signal.shape[0]
    duration = src_len / src_fs
    tgt_len = int(duration * tgt_fs)
    if tgt_len <= 0:
        return np.zeros((0, signal.shape[1]))
    src_t = np.linspace(0, duration, src_len, endpoint=False)
    tgt_t = np.linspace(0, duration, tgt_len, endpoint=False)
    resampled = np.vstack([
        np.interp(tgt_t, src_t, signal[:, i])
        for i in range(signal.shape[1])
    ]).T
    if resampled.shape[1] == 1:
        return resampled[:, 0]
    return resampled


def window_intervals(duration: float, win_s: int, step_s: int):
    windows = []
    t = 0.0
    while t + win_s <= duration:
        windows.append((t, t + win_s))
        t += step_s
    return windows


def assign_label(win, intervals):
    start, end = win
    length = end - start
    best_label = None
    best_cov = 0.0
    best_span = None
    for label, spans in intervals.items():
        overlap = 0.0
        span_choice = None
        span_overlap = 0.0
        for span in spans:
            a = span["start"] if isinstance(span, dict) else span[0]
            b = span["end"] if isinstance(span, dict) else span[1]
            inter = max(0.0, min(end, b) - max(start, a))
            if inter > 0:
                overlap += inter
                if inter > span_overlap:
                    span_overlap = inter
                    span_choice = span
        coverage = overlap / length
        if coverage > best_cov:
            best_cov = coverage
            best_label = label
            best_span = span_choice
    if best_cov >= MIN_LABEL_COVERAGE and best_label is not None:
        return best_label, best_span
    return None, None


def make_label_intervals(state: str, subject: str, tags, duration: float):
    rest_span = [{"start": 0.0, "end": duration, "stage": "rest", "stress_level": 0.0}]
    if state == "STRESS":
        stress_spans = stress_intervals_from_tags(tags, subject)
        if not stress_spans:
            return {"rest": rest_span}
        return {"stress": stress_spans, "rest": rest_span}
    active = active_intervals_from_tags(tags)
    label = "aerobic" if state == "AEROBIC" else "anaerobic"
    if not active:
        return {label: rest_span, "rest": rest_span}
    return {label: active, "rest": rest_span}


def load_subject_state(state: str, subject: str):
    folder = DATASET_ROOT / state / subject
    if not folder.exists():
        raise FileNotFoundError(folder)
    fs_eda, eda_raw, start_ts = read_signal(folder / "EDA.csv")
    temp_path = folder / "TEMP.csv"
    if temp_path.exists():
        fs_temp, temp_raw, _ = read_signal(temp_path)
    else:
        fs_temp, temp_raw = fs_eda, np.zeros_like(eda_raw)
    fs_acc, acc_raw, _ = read_signal(folder / "ACC.csv")
    acc_raw = np.atleast_2d(acc_raw)
    acc_mag = np.linalg.norm(acc_raw, axis=1)
    bvp_path = folder / "BVP.csv"
    if bvp_path.exists():
        fs_bvp, bvp_raw, _ = read_signal(bvp_path)
    else:
        fs_bvp, bvp_raw = None, None
    tags = read_tags(folder / "tags.csv", start_ts)
    sensors = {
        "EDA": np.asarray(eda_raw, dtype=float),
        "TEMP": np.asarray(temp_raw, dtype=float),
        "ACC_MAG": acc_mag,
    }
    if bvp_raw is not None:
        sensors["BVP"] = np.asarray(bvp_raw, dtype=float)
    fs_map = {"EDA": fs_eda, "TEMP": fs_temp, "ACC_MAG": fs_acc}
    if fs_bvp:
        fs_map["BVP"] = fs_bvp
    duration = len(sensors["EDA"]) / fs_eda
    return {"sensors": sensors, "fs": fs_map, "tags": tags, "duration": duration}


In [4]:
EXPECTED_LEN = int(WINDOW_SECONDS * TARGET_FS)
BASE_CHANNELS = ["EDA", "TEMP", "ACC", "BVP"]


def _slice_or_pad(signal: np.ndarray, start: int, end: int) -> np.ndarray:
    length = end - start
    if signal is None or len(signal) == 0:
        return np.zeros(length, dtype=np.float32)
    if end > len(signal):
        pad = end - len(signal)
        segment = signal[start: len(signal)]
        if pad > 0:
            segment = np.concatenate([segment, np.zeros(pad, dtype=segment.dtype)])
        return segment.astype(np.float32)
    return signal[start:end].astype(np.float32)


def extract_enhanced_channels(eda, temp, acc, bvp, fs=4.0):
    """
    PHASE 1.2: Extract enhanced channel features.
    Adds 8 new channels for stress-relevant patterns.
    """
    from scipy.signal import hilbert
    
    channels = [eda, temp, acc, bvp]  # Original 4 channels
    
    # First-order derivatives (rate of change)
    channels.extend([
        np.diff(eda, prepend=eda[0]),
        np.diff(temp, prepend=temp[0]),
        np.diff(acc, prepend=acc[0]),
        np.diff(bvp, prepend=bvp[0])
    ])
    
    # EDA decomposition (tonic = slow baseline, phasic = SCR responses)
    window_tonic = int(10 * fs)  # 10 second moving average
    eda_tonic = np.convolve(eda, np.ones(window_tonic)/window_tonic, mode='same')
    eda_phasic = eda - eda_tonic
    channels.extend([eda_tonic, eda_phasic])
    
    # Second-order derivative (acceleration of EDA - captures SCR onset speed)
    eda_diff = np.diff(eda, prepend=eda[0])
    eda_accel = np.diff(eda_diff, prepend=eda_diff[0])
    channels.append(eda_accel)
    
    # Moving averages for trend detection
    window_short = int(5 * fs)   # 5-second trend
    window_long = int(15 * fs)   # 15-second trend
    eda_ma_short = np.convolve(eda, np.ones(window_short)/window_short, mode='same')
    eda_ma_long = np.convolve(eda, np.ones(window_long)/window_long, mode='same')
    channels.extend([eda_ma_short, eda_ma_long])
    
    # BVP envelope (HRV proxy via Hilbert transform)
    bvp_envelope = np.abs(hilbert(bvp))
    channels.append(bvp_envelope)
    
    # Cross-channel interaction (autonomic coordination)
    eda_bvp_interaction = eda * bvp
    channels.append(eda_bvp_interaction)
    
    # Smoothed ACC (remove high-frequency noise)
    acc_smoothed = np.convolve(acc, np.ones(int(3*fs))/(3*fs), mode='same')
    channels.append(acc_smoothed)
    
    return np.stack(channels, axis=0)  # 16 total channels


def build_sequence_dataset(states: List[str] = STATES, max_subjects: int = 0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    sequences = []
    labels = []
    subjects = []
    for state in states:
        state_dir = DATASET_ROOT / state
        if not state_dir.exists():
            continue
        subject_ids = sorted([p.name for p in state_dir.iterdir() if p.is_dir()])
        if max_subjects and max_subjects > 0:
            subject_ids = subject_ids[:max_subjects]
        for subj in tqdm(subject_ids, desc=f"{state}"):
            try:
                info = load_subject_state(state, subj)
            except Exception as exc:
                print(f"Skip {state}/{subj}: {exc}")
                continue
            sensors = info["sensors"]
            fs_map = info["fs"]
            tags = info["tags"]
            duration = info["duration"]

            eda = resample_to_rate(sensors["EDA"], fs_map["EDA"], TARGET_FS)
            temp = resample_to_rate(sensors.get("TEMP", np.zeros_like(eda)), fs_map.get("TEMP", TARGET_FS), TARGET_FS) if "TEMP" in sensors else np.zeros_like(eda)
            acc = resample_to_rate(sensors["ACC_MAG"], fs_map["ACC_MAG"], TARGET_FS)
            if "BVP" in sensors and "BVP" in fs_map:
                bvp = resample_to_rate(sensors["BVP"], fs_map["BVP"], TARGET_FS)
            else:
                bvp = np.zeros_like(eda)

            intervals = make_label_intervals(state, subj, tags, duration)
            windows = window_intervals(duration, WINDOW_SECONDS, WINDOW_STEP_SECONDS)

            for win in windows:
                label_name, span_meta = assign_label(win, intervals)
                if label_name is None or span_meta is None:
                    continue
                start_idx = int(round(win[0] * TARGET_FS))
                end_idx = start_idx + EXPECTED_LEN
                eda_win = _slice_or_pad(eda, start_idx, end_idx)
                temp_win = _slice_or_pad(temp, start_idx, end_idx)
                acc_win = _slice_or_pad(acc, start_idx, end_idx)
                bvp_win = _slice_or_pad(bvp, start_idx, end_idx)

                stress_stage = span_meta.get("stage") if isinstance(span_meta, dict) else None
                stress_level = span_meta.get("stress_level") if isinstance(span_meta, dict) else None
                if label_name == "stress":
                    if stress_level is None or np.isnan(stress_level):
                        continue
                else:
                    stress_level = 0.0
                phase_label = stress_stage if stress_stage else label_name
                stress_class = stress_bucket(stress_level, phase_label)

                # PHASE 1.2: Extract enhanced channels (16 total)
                tensor = extract_enhanced_channels(eda_win, temp_win, acc_win, bvp_win, TARGET_FS)
                
                sequences.append(tensor)
                labels.append(stress_class)
                subjects.append(base_subject_id(subj))
    
    sequences = np.stack(sequences).astype(np.float32)
    labels = np.array(labels)
    subjects = np.array(subjects)
    return sequences, labels, subjects

In [5]:
sequences, labels, subjects = build_sequence_dataset(max_subjects=MAX_SUBJECTS if MAX_SUBJECTS else 0)

# Define channel names for 16-channel input
channel_names = [
    'EDA', 'TEMP', 'ACC', 'BVP',                          # Original (4)
    'EDA_diff', 'TEMP_diff', 'ACC_diff', 'BVP_diff',      # First derivatives (4)
    'EDA_tonic', 'EDA_phasic',                             # EDA decomposition (2)
    'EDA_accel',                                           # Second derivative (1)
    'EDA_ma_short', 'EDA_ma_long',                         # Moving averages (2)
    'BVP_envelope',                                        # BVP envelope (1)
    'EDA_BVP_interaction',                                 # Cross-channel (1)
    'ACC_smoothed'                                         # Smoothed ACC (1)
]

print("="*80)
print("DATASET STATISTICS")
print("="*80)
print(f"\nRaw sequences: {sequences.shape}")
print(f"Channels: {len(channel_names)}")
print(f"Channel names: {channel_names}")
print(f"\nLabel distribution:")
label_dist = pd.Series(labels).value_counts().sort_index()
for label, count in label_dist.items():
    pct = 100 * count / len(labels)
    print(f"  {label:20s}: {count:5d} ({pct:5.1f}%)")

# PHASE 1.1: Subject-Specific Baseline Normalization (HIGHEST IMPACT: +5-7% macro F1)
if APPLY_CHANNEL_NORMALIZATION and sequences.size:
    print("\n" + "="*80)
    print("APPLYING SUBJECT-SPECIFIC BASELINE NORMALIZATION")
    print("="*80)
    normalized = sequences.copy()
    
    for subject in np.unique(subjects):
        subject_mask = subjects == subject
        rest_mask = subject_mask & (labels == 'no_stress')
        
        if rest_mask.sum() > 0:
            baseline_mean = sequences[rest_mask].mean(axis=(0, 2), keepdims=True)
            baseline_std = sequences[rest_mask].std(axis=(0, 2), keepdims=True) + 1e-6
            print(f"  {subject:8s}: {rest_mask.sum():4d} rest windows → baseline")
        else:
            baseline_mean = sequences[subject_mask].mean(axis=(0, 2), keepdims=True)
            baseline_std = sequences[subject_mask].std(axis=(0, 2), keepdims=True) + 1e-6
            print(f"  {subject:8s}: {subject_mask.sum():4d} total windows (NO REST DATA)")
        
        normalized[subject_mask] = (sequences[subject_mask] - baseline_mean) / baseline_std
    
    sequences = normalized
    print("\n✓ Subject-specific normalization complete")

print("\n" + "="*80)
print(f"FINAL DATASET: {sequences.shape}")
print(f"Subjects: {len(np.unique(subjects))}")
print("="*80)

STRESS:   0%|          | 0/37 [00:00<?, ?it/s]

Skip STRESS/f14_a: No columns to parse from file


AEROBIC:   0%|          | 0/31 [00:00<?, ?it/s]

ANAEROBIC:   0%|          | 0/32 [00:00<?, ?it/s]

DATASET STATISTICS

Raw sequences: (6788, 16, 240)
Channels: 16
Channel names: ['EDA', 'TEMP', 'ACC', 'BVP', 'EDA_diff', 'TEMP_diff', 'ACC_diff', 'BVP_diff', 'EDA_tonic', 'EDA_phasic', 'EDA_accel', 'EDA_ma_short', 'EDA_ma_long', 'BVP_envelope', 'EDA_BVP_interaction', 'ACC_smoothed']

Label distribution:
  high_stress         :   462 (  6.8%)
  low_stress          :    91 (  1.3%)
  moderate_stress     :   664 (  9.8%)
  no_stress           :  5571 ( 82.1%)

APPLYING SUBJECT-SPECIFIC BASELINE NORMALIZATION
  S01     :  155 rest windows → baseline
  S02     :  188 rest windows → baseline
  S03     :  122 rest windows → baseline
  S04     :  140 rest windows → baseline
  S05     :  136 rest windows → baseline
  S06     :  124 rest windows → baseline
  S07     :  124 rest windows → baseline
  S08     :  135 rest windows → baseline
  S09     :  137 rest windows → baseline
  S10     :  137 rest windows → baseline
  S11     :  130 rest windows → baseline
  S12     :   69 rest windows → baseli

In [6]:

class SequenceDataset(Dataset):
    def __init__(self, data: np.ndarray, labels: np.ndarray):
        self.features = torch.from_numpy(data)
        self.labels = torch.from_numpy(labels).long()

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


In [7]:
class ResNetBlock(nn.Module):
    def __init__(self, channels: int, kernel_size: int = 5, dilation: int = 1):
        super().__init__()
        padding = ((kernel_size - 1) // 2) * dilation
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=padding, dilation=dilation)
        self.bn1 = nn.BatchNorm1d(channels)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size, padding=padding, dilation=dilation)
        self.bn2 = nn.BatchNorm1d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        return self.relu(out)


class MultiScaleLSTMResNet(nn.Module):
    """Phase 2 architecture: multi-scale convolutions + BiLSTM + attention."""

    def __init__(self, input_channels: int, num_classes: int = 4, cnn_channels: int = 64):
        super().__init__()
        self.conv_short = nn.Sequential(
            nn.Conv1d(input_channels, cnn_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(cnn_channels),
            nn.ReLU(),
            ResNetBlock(cnn_channels, kernel_size=3, dilation=1),
            ResNetBlock(cnn_channels, kernel_size=3, dilation=1),
        )
        self.conv_medium = nn.Sequential(
            nn.Conv1d(input_channels, cnn_channels, kernel_size=7, padding=3),
            nn.BatchNorm1d(cnn_channels),
            nn.ReLU(),
            ResNetBlock(cnn_channels, kernel_size=7, dilation=2),
            ResNetBlock(cnn_channels, kernel_size=7, dilation=2),
        )
        self.conv_long = nn.Sequential(
            nn.Conv1d(input_channels, cnn_channels, kernel_size=15, padding=7),
            nn.BatchNorm1d(cnn_channels),
            nn.ReLU(),
            ResNetBlock(cnn_channels, kernel_size=15, dilation=4),
            ResNetBlock(cnn_channels, kernel_size=15, dilation=4),
        )

        merged_channels = cnn_channels * 3
        self.merge = nn.Sequential(
            nn.Conv1d(merged_channels, 128, kernel_size=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
        )

        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=128,
            num_layers=3,
            dropout=0.3,
            batch_first=True,
            bidirectional=True,
        )
        self.attention = nn.MultiheadAttention(embed_dim=256, num_heads=8, dropout=0.2, batch_first=True)
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Sequential(
            nn.Linear(256 + 128, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        feat_short = self.conv_short(x)
        feat_medium = self.conv_medium(x)
        feat_long = self.conv_long(x)
        merged = torch.cat([feat_short, feat_medium, feat_long], dim=1)
        merged = self.merge(merged)

        lstm_in = merged.transpose(1, 2)
        lstm_out, _ = self.lstm(lstm_in)
        attn_out, _ = self.attention(lstm_out, lstm_out, lstm_out)
        temporal_features = attn_out.mean(dim=1)
        cnn_features = self.global_pool(merged).squeeze(-1)
        combined = torch.cat([temporal_features, cnn_features], dim=1)
        return self.classifier(combined)


In [8]:
# Training configuration
EPOCHS = 40
BATCH_SIZE = 32
LR = 1e-3
WEIGHT_DECAY = 1e-4
MAX_GRAD_NORM = 1.0
USE_MIXED_PRECISION = torch.cuda.is_available()

le = LabelEncoder()
encoded_labels = le.fit_transform(labels)
num_classes = len(le.classes_)

print("="*80)
print("LABEL ENCODING")
print("="*80)
print("Classes:", dict(zip(le.classes_, range(num_classes))))
print(f"
Encoded label distribution:")
for i, class_name in enumerate(le.classes_):
    count = (encoded_labels == i).sum()
    print(f"  {i}: {class_name:20s} → {count:5d} samples")


LABEL ENCODING
Classes: {np.str_('high_stress'): 0, np.str_('low_stress'): 1, np.str_('moderate_stress'): 2, np.str_('no_stress'): 3}

Encoded label distribution:
  0: high_stress          →   462 samples
  1: low_stress           →    91 samples
  2: moderate_stress      →   664 samples
  3: no_stress            →  5571 samples


In [9]:
from sklearn.model_selection import train_test_split, GroupKFold
from sklearn.metrics import precision_score, recall_score, f1_score
from scipy.interpolate import interp1d

print("="*80)
print("PHASE 2 PREPARATION: DATA BALANCING + TRAINING CONFIG")
print("="*80)

def time_warp_augment(sequence, warp_factor):
    """Temporal stretch/compress via interpolation."""
    channels, length = sequence.shape
    new_length = int(length * warp_factor)
    warped = []
    for ch in range(channels):
        if new_length < 4:
            warped.append(sequence[ch])
            continue
        f = interp1d(np.arange(length), sequence[ch], kind='cubic', fill_value='extrapolate')
        new_indices = np.linspace(0, length - 1, new_length)
        warped_ch = f(new_indices)
        warped_ch_resampled = np.interp(np.arange(length), np.linspace(0, length - 1, new_length), warped_ch)
        warped.append(warped_ch_resampled)
    return np.array(warped, dtype=np.float32)


def temporal_augmentation(X, y, augment_counts={'low_stress': 6, 'high_stress': 2, 'moderate_stress': 1}):
    """Augment minority classes with physiologically-plausible transforms."""
    X_aug = list(X)
    y_aug = list(y)
    for class_idx in range(num_classes):
        class_name = le.classes_[class_idx]
        if class_name not in augment_counts:
            continue
        class_mask = y == class_idx
        class_samples = X[class_mask]
        if not len(class_samples):
            continue
        aug_factor = augment_counts[class_name]
        print(f"  {class_name}: {class_mask.sum()} samples → augmenting {aug_factor}x")
        for sample in class_samples:
            for _ in range(aug_factor):
                warp_factor = np.random.uniform(0.90, 1.10)
                aug_sample = time_warp_augment(sample, warp_factor)
                noise_std = 0.05 * np.std(sample, axis=1, keepdims=True)
                aug_sample = aug_sample + np.random.normal(0, noise_std, aug_sample.shape).astype(np.float32)
                shift_range = int(0.15 * sample.shape[1])
                if shift_range > 0:
                    shift = np.random.randint(-shift_range, shift_range)
                    aug_sample = np.roll(aug_sample, shift, axis=1)
                X_aug.append(aug_sample)
                y_aug.append(class_idx)
    return np.array(X_aug, dtype=np.float32), np.array(y_aug)


def rebalance_sequences(X, y, target_count=None):
    """Oversample minority classes to match the dominant class count."""
    class_counts = np.bincount(y, minlength=num_classes)
    target = class_counts.max() if target_count is None else target_count
    balanced_X = []
    balanced_y = []
    for class_idx in range(num_classes):
        class_mask = y == class_idx
        class_samples = X[class_mask]
        if not len(class_samples):
            continue
        count = class_samples.shape[0]
        if count >= target:
            balanced_X.append(class_samples[:target])
            balanced_y.append(np.full(target, class_idx))
        else:
            extra = target - count
            extra_idx = np.random.choice(count, extra, replace=True)
            augmented_samples = np.concatenate([class_samples, class_samples[extra_idx]], axis=0)
            balanced_X.append(augmented_samples)
            balanced_y.append(np.full(target, class_idx))
    X_balanced = np.concatenate(balanced_X).astype(np.float32)
    y_balanced = np.concatenate(balanced_y)
    return X_balanced, y_balanced


# Ensure sequences are float32
sequences = sequences.astype(np.float32)

# Subject-aware split for revalidation
if GROUP_SPLIT:
    if NUM_FOLDS < 2:
        raise ValueError("NUM_FOLDS must be >= 2 for GroupKFold.")
    unique_subjects = np.unique(subjects)
    if len(unique_subjects) < NUM_FOLDS:
        raise ValueError(f"Not enough subjects ({len(unique_subjects)}) for NUM_FOLDS={NUM_FOLDS}.")
    gkf = GroupKFold(n_splits=NUM_FOLDS)
    splits = list(gkf.split(sequences, encoded_labels, groups=subjects))
    if FOLD_INDEX >= len(splits):
        raise ValueError(f"FOLD_INDEX {FOLD_INDEX} out of range for {NUM_FOLDS} folds.")
    train_idx, test_idx = splits[FOLD_INDEX]
    split_desc = f"GroupKFold fold {FOLD_INDEX + 1}/{NUM_FOLDS}"
else:
    train_idx, test_idx = train_test_split(
        np.arange(len(sequences)),
        test_size=0.2,
        stratify=encoded_labels,
        random_state=SEED,
    )
    split_desc = "Stratified random 80/20 split"

X_train, X_test = sequences[train_idx], sequences[test_idx]
y_train, y_test = encoded_labels[train_idx], encoded_labels[test_idx]
train_subjects = subjects[train_idx]
test_subjects = subjects[test_idx]
held_out_subjects = test_subjects.copy()

print(f"
Split strategy: {split_desc}")
print(f"  Train windows: {X_train.shape[0]} from {len(np.unique(train_subjects))} subjects")
print(f"  Test windows:  {X_test.shape[0]} from {len(np.unique(test_subjects))} subjects")
print("
Train subjects:", ', '.join(sorted(np.unique(train_subjects))))
print("Test subjects:", ', '.join(sorted(np.unique(test_subjects))))

print(f"
Dataset split overview:")
print(f"  Train: {X_train.shape}")
print(f"  Test:  {X_test.shape}")

print(f"
Train class distribution (BEFORE balancing):")
for i, class_name in enumerate(le.classes_):
    count = (y_train == i).sum()
    pct = 100 * count / len(y_train)
    print(f"  {class_name:20s}: {count:5d} ({pct:5.1f}%)")

print(f"
Test class distribution:")
for i, class_name in enumerate(le.classes_):
    count = (y_test == i).sum()
    pct = 100 * count / len(y_test)
    print(f"  {class_name:20s}: {count:5d} ({pct:5.1f}%)")

if APPLY_SMOTE:
    print("
" + "="*80)
    print("TEMPORAL AUGMENTATION + RESAMPLING")
    print("="*80)
    print("
Step 1: Temporal augmentation for minority classes...")
    X_train, y_train = temporal_augmentation(X_train, y_train)
    print(f"  After augmentation: {X_train.shape}")
    print("
Step 2: Oversampling minorities to match majority...")
    X_train, y_train = rebalance_sequences(X_train, y_train)

print("
" + "="*80)
print("FINAL CLASS DISTRIBUTION")
print("="*80)
print(f"
Train: {X_train.shape}")
for i, class_name in enumerate(le.classes_):
    count = (y_train == i).sum()
    pct = 100 * count / len(y_train)
    print(f"  {class_name:20s}: {count:5d} ({pct:5.1f}%)")

# Create datasets
X_train = X_train.astype(np.float32)
X_test = X_test.astype(np.float32)
train_dataset = SequenceDataset(X_train, y_train)
test_dataset = SequenceDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

class_counts = np.bincount(y_train, minlength=num_classes).clip(min=1)
class_weights = len(y_train) / (num_classes * class_counts)
alpha_tensor = torch.tensor(class_weights, dtype=torch.float32, device=device)

print(f"
Class weights / alpha for focal loss: {dict(zip(le.classes_, class_weights))}")


class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        ce_loss = F.cross_entropy(logits, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        if self.alpha is not None:
            at = self.alpha[targets]
            loss = at * (1 - pt) ** self.gamma * ce_loss
        else:
            loss = (1 - pt) ** self.gamma * ce_loss
        return loss.mean()


# Initialize model + optimizer/scheduler
num_channels = sequences.shape[1]
model = MultiScaleLSTMResNet(input_channels=num_channels, num_classes=num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = FocalLoss(alpha=alpha_tensor, gamma=2.0)
scaler = torch.cuda.amp.GradScaler(enabled=USE_MIXED_PRECISION)

print("
" + "="*80)
print(f"TRAINING: {EPOCHS} epochs (mixed precision: {USE_MIXED_PRECISION})")
print("="*80)

best_val_f1 = 0.0
best_model_state = None
global_step = 0

for epoch in range(1, EPOCHS + 1):
    model.train()
    train_loss = 0.0
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=USE_MIXED_PRECISION):
            logits = model(xb)
            loss = criterion(logits, yb)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
        scaler.step(optimizer)
        scaler.update()
        train_loss += loss.item() * xb.size(0)
        global_step += 1
        scheduler.step(global_step)
    train_loss /= len(train_loader.dataset)

    model.eval()
    val_loss = 0.0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for xb, yb in test_loader:
            xb, yb = xb.to(device), yb.to(device)
            with torch.cuda.amp.autocast(enabled=USE_MIXED_PRECISION):
                logits = model(xb)
                loss = criterion(logits, yb)
            val_loss += loss.item() * xb.size(0)
            preds = torch.argmax(logits, dim=1)
            all_preds.append(preds.cpu().numpy())
            all_targets.append(yb.cpu().numpy())
    val_loss /= len(test_loader.dataset)
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    val_acc = accuracy_score(all_targets, all_preds)
    val_f1_macro = f1_score(all_targets, all_preds, average='macro', zero_division=0)
    val_f1_weighted = f1_score(all_targets, all_preds, average='weighted', zero_division=0)

    if val_f1_macro > best_val_f1:
        best_val_f1 = val_f1_macro
        best_model_state = model.state_dict().copy()

    print(
        f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | "
        f"Val Acc: {val_acc:.3f} | Val F1 (macro): {val_f1_macro:.3f} | Val F1 (weighted): {val_f1_weighted:.3f}"
    )

# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)

print("
" + "="*80)
print("FINAL EVALUATION ON TEST SET")
print("="*80)

model.eval()
all_preds = []
all_targets = []
with torch.no_grad():
    for xb, yb in test_loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        preds = torch.argmax(logits, dim=1)
        all_preds.append(preds.cpu().numpy())
        all_targets.append(yb.cpu().numpy())

all_preds = np.concatenate(all_preds)
all_targets = np.concatenate(all_targets)

test_acc = accuracy_score(all_targets, all_preds)
test_f1_macro = f1_score(all_targets, all_preds, average='macro', zero_division=0)
test_f1_weighted = f1_score(all_targets, all_preds, average='weighted', zero_division=0)
test_precision_macro = precision_score(all_targets, all_preds, average='macro', zero_division=0)
test_recall_macro = recall_score(all_targets, all_preds, average='macro', zero_division=0)

print(f"
OVERALL METRICS:")
print(f"  Accuracy:         {test_acc:.4f} ({test_acc*100:.1f}%)")
print(f"  Macro F1:         {test_f1_macro:.4f} ({test_f1_macro*100:.1f}%)")
print(f"  Weighted F1:      {test_f1_weighted:.4f} ({test_f1_weighted*100:.1f}%)")
print(f"  Macro Precision:  {test_precision_macro:.4f}")
print(f"  Macro Recall:     {test_recall_macro:.4f}")

print("
" + "="*80)
print("PER-CLASS METRICS")
print("="*80)
print("
" + classification_report(all_targets, all_preds, target_names=le.classes_, digits=4, zero_division=0))

from sklearn.metrics import confusion_matrix
cm = confusion_matrix(all_targets, all_preds)
cm_df = pd.DataFrame(cm, index=le.classes_, columns=le.classes_)

print("
" + "="*80)
print("CONFUSION MATRIX")
print("="*80)
print("
" + str(cm_df))

print("
" + "="*80)
print("SUBJECT-LEVEL PERFORMANCE (TEST SET)")
print("="*80)
subject_records = []
for subj in sorted(np.unique(held_out_subjects)):
    subj_mask = held_out_subjects == subj
    subj_true = all_targets[subj_mask]
    subj_pred = all_preds[subj_mask]
    subject_records.append({
        "subject": subj,
        "samples": int(subj_mask.sum()),
        "accuracy": accuracy_score(subj_true, subj_pred) if subj_mask.sum() else 0.0,
        "macro_f1": f1_score(subj_true, subj_pred, average='macro', zero_division=0),
    })
if subject_records:
    subject_df = pd.DataFrame(subject_records)
    print(subject_df.to_string(index=False, formatters={
        "accuracy": "{:.3f}".format,
        "macro_f1": "{:.3f}".format,
    }))
else:
    print("No held-out subjects to report (check split configuration).")

print("
" + "="*80)
print("COMPARISON TO BASELINE")
print("="*80)
print("
BASELINE (before Phase 1):")
print("  Accuracy:    75.9%")
print("  Macro F1:    36.0%")
print(f"
PHASE 2 (multi-scale CNN + focal loss + improved training):")
print(f"  Accuracy:    {test_acc*100:.1f}%  ({(test_acc-0.759)*100:+.1f}pp)")
print(f"  Macro F1:    {test_f1_macro*100:.1f}%  ({(test_f1_macro-0.360)*100:+.1f}pp)")

improvement = (test_f1_macro - 0.360) * 100
if improvement >= 8:
    print(f"
✓ TARGET ACHIEVED! Macro F1 improved by {improvement:+.1f}pp (target: +8-12pp)")
elif improvement >= 5:
    print(f"
✓ Good progress! Macro F1 improved by {improvement:+.1f}pp")
else:
    print(f"
⚠ Improvement: {improvement:+.1f}pp. Continue with next phases if needed.")

print("
" + "="*80)


PHASE 1.3: ADVANCED CLASS BALANCING

Split strategy: GroupKFold fold 1/5
  Train windows: 5426 from 29 subjects
  Test windows:  1362 from 7 subjects

Train subjects: S02, S04, S05, S06, S07, S09, S11, S12, S13, S15, S16, S17, S18, f01, f02, f03, f06, f07, f08, f09, f10, f11, f12, f13, f14, f15, f16, f17, f18
Test subjects: S01, S03, S08, S10, S14, f04, f05

Dataset split overview:
  Train: (5426, 16, 240)
  Test:  (1362, 16, 240)

Train class distribution (BEFORE balancing):
  high_stress         :   433 (  8.0%)
  low_stress          :    83 (  1.5%)
  moderate_stress     :   505 (  9.3%)
  no_stress           :  4405 ( 81.2%)

Test class distribution:
  high_stress         :    29 (  2.1%)
  low_stress          :     8 (  0.6%)
  moderate_stress     :   159 ( 11.7%)
  no_stress           :  1166 ( 85.6%)

APPLYING MULTI-STRATEGY BALANCING

Step 1: Temporal augmentation for minority classes...
  high_stress: 433 samples → augmenting 2x
  low_stress: 83 samples → augmenting 6x
  moder




After SMOTE:
  Train: (10392, 16, 240)
  high_stress         :  2598 ( 25.0%)
  low_stress          :  2598 ( 25.0%)
  moderate_stress     :  2598 ( 25.0%)
  no_stress           :  2598 ( 25.0%)

FINAL CLASS DISTRIBUTION

Train: (10392, 16, 240)
  high_stress         :  2598 ( 25.0%)
  low_stress          :  2598 ( 25.0%)
  moderate_stress     :  2598 ( 25.0%)
  no_stress           :  2598 ( 25.0%)

Class weights for loss: {np.str_('high_stress'): np.float64(1.0), np.str_('low_stress'): np.float64(1.0), np.str_('moderate_stress'): np.float64(1.0), np.str_('no_stress'): np.float64(1.0)}

TRAINING: 20 epochs
Epoch 01 | Train Loss: 0.9673 | Val Loss: 0.7456 | Val Acc: 0.691 | Val F1 (macro): 0.317 | Val F1 (weighted): 0.755
Epoch 02 | Train Loss: 0.6418 | Val Loss: 0.5358 | Val Acc: 0.773 | Val F1 (macro): 0.309 | Val F1 (weighted): 0.788
Epoch 03 | Train Loss: 0.5818 | Val Loss: 0.5362 | Val Acc: 0.757 | Val F1 (macro): 0.350 | Val F1 (weighted): 0.802
Epoch 04 | Train Loss: 0.5383 | Va