In [None]:
# =============================================================
# PHASE 1: COMPACT LSTM MODEL ARCHITECTURE (~26K PARAMETERS WITH DEFAULTS)
# =============================================================
# Architecture:
# - Compact branches: 2 conv layers + pooling (no residual blocks)
# - Raw EMG: time + frequency branches are defined but NOT fed to the LSTM
# - RMS/LMS: extracts ONLY time domain features
# - Wiener TD: extracts ONLY time domain features
# - Wiener FFT: extracts ONLY frequency domain features
# - Two-layer unidirectional LSTM with increased hidden size
# - Simple mean pooling instead of attention
# - Strong regularization (dropout=0.8, weight_decay=0.075)
# - Session-based data split (all users in training)
# =============================================================

import torch
import torch.nn as nn
import numpy as np

class UltraCompactBranch(nn.Module):
    """Ultra-compact branch: 2 conv layers + pooling (no residual blocks).

    Two strided Conv1d + BatchNorm + ReLU layers (each roughly halves the length),
    global average pooling over time, a linear projection to ``d_model`` and dropout.

    Parameter count is ``40 * in_channels + 456 + 17 * d_model``: 768 for
    in_channels=1 / d_model=16 (632 with d_model=8) and 968 for a 6-channel input
    with d_model=16.

    Args:
        in_channels: Number of input channels per window.
        d_model: Output feature width.
    """
    def __init__(self, in_channels=1, d_model=16):
        super().__init__()
        # First conv: extract basic features
        self.conv1 = nn.Conv1d(in_channels, 8, kernel_size=5, stride=2, padding=2)
        self.bn1 = nn.BatchNorm1d(8)
        # Second conv: refine features
        self.conv2 = nn.Conv1d(8, 16, kernel_size=3, stride=2, padding=1)
        self.bn2 = nn.BatchNorm1d(16)
        self.relu = nn.ReLU()
        # Global pooling + projection
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Linear(16, d_model)
        self.dropout = nn.Dropout(0.3)  # Dropout after projection

    def forward(self, x):
        """
        Args:
            x: (batch, seq_len, channels, length) or (batch, channels, length).

        Returns:
            (batch, seq_len, d_model) for 4-D input, otherwise (batch, d_model).
        """
        lead_shape = None
        if x.dim() == 4:
            batch_size, seq_len, channels, length = x.shape
            lead_shape = (batch_size, seq_len)
            # Fold the sequence axis into the batch axis. reshape (unlike view) also
            # accepts non-contiguous inputs such as transposed/permuted tensors.
            x = x.reshape(batch_size * seq_len, channels, length)

        x = self.relu(self.bn1(self.conv1(x)))  # (N, 8, length/2)
        x = self.relu(self.bn2(self.conv2(x)))  # (N, 16, length/4)
        x = self.pool(x).squeeze(-1)  # (N, 16)
        x = self.dropout(self.proj(x))  # (N, d_model)

        if lead_shape is not None:
            x = x.reshape(*lead_shape, x.shape[-1])  # (batch, seq_len, d_model)
        return x

class RawEMGTimeBranch(UltraCompactBranch):
    """Time-domain branch for the raw EMG signal.

    A single-input-channel UltraCompactBranch; ``d_model`` sets the output width.
    """

    def __init__(self, d_model=16):
        super().__init__(d_model=d_model, in_channels=1)

class RawEMGFreqBranch(nn.Module):
    """Raw EMG frequency-domain branch - ultra-compact version.

    Takes the magnitude of the real FFT of each raw EMG window, linearly resamples
    it to ``fft_bins`` bins (a 100-sample window gives 51 rFFT bins -> 64), then
    applies two strided Conv1d + BatchNorm + ReLU layers, global average pooling, a
    linear projection to ``d_model`` and dropout.

    Args:
        d_model: Output feature width.
        fft_bins: Number of spectrum bins fed to the conv stack.
    """
    def __init__(self, d_model=16, fft_bins=64):
        super().__init__()
        self.fft_bins = fft_bins
        # Process FFT bins directly with compact conv
        self.conv1 = nn.Conv1d(1, 8, kernel_size=5, stride=2, padding=2)  # 64 -> 32
        self.bn1 = nn.BatchNorm1d(8)
        self.conv2 = nn.Conv1d(8, 16, kernel_size=3, stride=2, padding=1)  # 32 -> 16
        self.bn2 = nn.BatchNorm1d(16)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Linear(16, d_model)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """
        Args:
            x: Raw EMG as (batch, seq_len, 1, length), (batch, 1, length) or
               (batch, length).

        Returns:
            (batch, seq_len, d_model) for 4-D input, otherwise (batch, d_model).
        """
        lead_shape = None
        if x.dim() == 4:
            batch_size, seq_len, channels, length = x.shape
            lead_shape = (batch_size, seq_len)
            # reshape (not view) so non-contiguous inputs are accepted too.
            x = x.reshape(batch_size * seq_len, channels, length)
        elif x.dim() == 2:
            # (batch, length): insert the channel axis. Inserting a leading axis
            # instead would only be correct when batch == 1.
            x = x.unsqueeze(1)

        # One-sided magnitude spectrum along time: (N, 1, length // 2 + 1)
        x_fft_mag = torch.abs(torch.fft.rfft(x, dim=-1))

        # Linearly resample to the fixed number of bins the conv stack expects
        if x_fft_mag.shape[-1] != self.fft_bins:
            x_fft_mag = torch.nn.functional.interpolate(
                x_fft_mag, size=self.fft_bins, mode='linear', align_corners=False
            )  # (N, 1, fft_bins)

        x = self.relu(self.bn1(self.conv1(x_fft_mag)))  # (N, 8, fft_bins/2)
        x = self.relu(self.bn2(self.conv2(x)))  # (N, 16, fft_bins/4)
        x = self.pool(x).squeeze(-1)  # (N, 16)
        x = self.dropout(self.proj(x))  # (N, d_model)

        if lead_shape is not None:
            x = x.reshape(*lead_shape, x.shape[-1])  # (batch, seq_len, d_model)
        return x

class RMSLMSBranch(UltraCompactBranch):
    """Time-domain branch for the RMS/LMS-filtered EMG signal.

    A single-input-channel UltraCompactBranch; ``d_model`` sets the output width.
    """

    def __init__(self, d_model=16):
        super().__init__(d_model=d_model, in_channels=1)

class WienerTDBranch(UltraCompactBranch):
    """Time-domain branch for the Wiener-filtered EMG signal.

    A single-input-channel UltraCompactBranch; ``d_model`` sets the output width.
    """

    def __init__(self, d_model=16):
        super().__init__(d_model=d_model, in_channels=1)

class WienerFFTBranch(nn.Module):
    """Wiener FFT frequency-domain branch - ultra-compact version.

    Consumes pre-computed FFT magnitude bins (the dataset CSVs hold 64 per window)
    and applies two strided Conv1d + BatchNorm + ReLU layers, global average
    pooling, a linear projection to ``d_model`` and dropout. The adaptive pooling
    makes the layers independent of the bin count; ``fft_bins`` is recorded as an
    attribute for reference and consistency with RawEMGFreqBranch.

    Args:
        d_model: Output feature width.
        fft_bins: Expected number of FFT bins per window (informational).
    """
    def __init__(self, d_model=16, fft_bins=64):
        super().__init__()
        self.fft_bins = fft_bins
        # Process pre-computed FFT bins directly
        self.conv1 = nn.Conv1d(1, 8, kernel_size=5, stride=2, padding=2)  # 64 -> 32
        self.bn1 = nn.BatchNorm1d(8)
        self.conv2 = nn.Conv1d(8, 16, kernel_size=3, stride=2, padding=1)  # 32 -> 16
        self.bn2 = nn.BatchNorm1d(16)
        self.relu = nn.ReLU()
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.proj = nn.Linear(16, d_model)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """
        Args:
            x: FFT bins as (batch, seq_len, bins) or (batch, bins).

        Returns:
            (batch, seq_len, d_model) for 3-D input, otherwise (batch, d_model).
        """
        lead_shape = None
        if x.dim() == 3:
            batch_size, seq_len, features = x.shape
            lead_shape = (batch_size, seq_len)
            # reshape (not view) so non-contiguous inputs are accepted too.
            x = x.reshape(batch_size * seq_len, 1, features)
        else:
            x = x.unsqueeze(1)  # (batch, 1, bins)

        x = self.relu(self.bn1(self.conv1(x)))  # (N, 8, bins/2)
        x = self.relu(self.bn2(self.conv2(x)))  # (N, 16, bins/4)
        x = self.pool(x).squeeze(-1)  # (N, 16)
        x = self.dropout(self.proj(x))  # (N, d_model)

        if lead_shape is not None:
            x = x.reshape(*lead_shape, x.shape[-1])  # (batch, seq_len, d_model)
        return x

class IMUBranch(UltraCompactBranch):
    """IMU branch: an UltraCompactBranch over 6 input channels.

    The dataset loads IMU columns ax, ay, az, gx, gy, gz; ``d_model`` sets the
    output width.
    """

    def __init__(self, d_model=16):
        super().__init__(d_model=d_model, in_channels=6)

class EMGLSTMModel(nn.Module):
    """
    Phase 1: Compact multi-branch LSTM model for EMG gesture classification.

    Parameter budget with the defaults (num_classes=2, d_model=16, hidden_size=32,
    num_layers=2): 25,866 parameters in total. 1,536 of them belong to the two
    raw-EMG branches, which are constructed but never called in ``forward``, so
    24,330 parameters take part in the computation.

    Architecture:
    - Compact branches: 2 conv layers + pooling (no residual blocks)
    - Raw EMG: time + frequency branches are built (kept so state_dicts keep their
      keys and for diagnostics) but are NOT part of the feature vector
    - RMS/LMS: time domain features (filtered)
    - Wiener TD: time domain features (filtered)
    - Wiener FFT: frequency domain features (filtered)
    - IMU: 6-channel branch
    - Stacked LSTM (two-layer unidirectional by default)
    - Mean pooling over time instead of attention
    - Dropout ``dropout`` (default 0.8) before the classifier and ``lstm_dropout``
      (default 0.2) between stacked LSTM layers

    Args:
        num_classes: Number of output logits.
        d_model: Output width of every branch.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability on the pooled LSTM output.
        sequence_length: Stored as ``self.sequence_length``; ``forward`` does not use it.
        bidirectional: Whether the LSTM is bidirectional.
        lstm_dropout: Dropout between stacked LSTM layers (ignored when num_layers == 1).
    """
    def __init__(self, num_classes=2, d_model=16, hidden_size=32, num_layers=2,
                 dropout=0.8, sequence_length=8, bidirectional=False, lstm_dropout=0.2):
        super().__init__()

        self.sequence_length = sequence_length
        self.d_model = d_model

        # Raw EMG branches: constructed (in this order, so parameter initialisation and
        # state_dict keys stay the same) but deliberately NOT used in forward(); the
        # LSTM relies on the filtered branches and IMU instead. Their parameters still
        # appear in self.parameters() without ever receiving gradients.
        self.raw_time_branch = RawEMGTimeBranch(d_model=d_model)
        self.raw_freq_branch = RawEMGFreqBranch(d_model=d_model, fft_bins=64)

        # RMS/LMS: time domain only (filtered envelope)
        self.rms_lms_branch = RMSLMSBranch(d_model=d_model)

        # Wiener: time domain only (motion-artifact reduced)
        self.wiener_td_branch = WienerTDBranch(d_model=d_model)

        # Wiener: frequency domain only (motion-artifact reduced FFT bins)
        self.wiener_fft_branch = WienerFFTBranch(d_model=d_model, fft_bins=64)

        # IMU branch
        self.imu_branch = IMUBranch(d_model=d_model)

        # Concatenated filtered features: rms_lms + wiener_td + wiener_fft + imu
        feature_dim = 4 * d_model

        self.lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # nn.LSTM only applies dropout between layers, so it is meaningless for 1 layer
            dropout=lstm_dropout if num_layers > 1 else 0,
            bidirectional=bidirectional,
            batch_first=True
        )

        lstm_output_dim = hidden_size * 2 if bidirectional else hidden_size

        # Classification head (mean pooling over time has no parameters)
        self.dropout_cls = nn.Dropout(dropout)
        self.classifier = nn.Linear(lstm_output_dim, num_classes)

    def forward(self, batch):
        """
        Args:
            batch: Dictionary containing:
                - 'rms_lms': (batch, seq_len, 1, window_size) - RMS/LMS filtered
                - 'wiener_td': (batch, seq_len, 1, window_size) - Wiener time domain
                - 'wiener_fft': (batch, seq_len, 64) - Wiener frequency domain
                - 'imu': (batch, seq_len, 6, imu_samples) - IMU data
                Other keys ('raw', 'spectral_raw', 'spectral_wiener', ...) may be
                present but are not used.

        Returns:
            logits: (batch, num_classes)
        """
        # Branch order matches the concatenation order (and the order in which the
        # branch dropouts draw random numbers).
        combined_features = torch.cat([
            self.rms_lms_branch(batch['rms_lms']),        # filtered time domain (RMS/LMS)
            self.wiener_td_branch(batch['wiener_td']),    # filtered time domain (Wiener)
            self.wiener_fft_branch(batch['wiener_fft']),  # filtered frequency domain (Wiener)
            self.imu_branch(batch['imu']),                # IMU
        ], dim=2)  # (batch, seq_len, 4*d_model)

        lstm_out, _ = self.lstm(combined_features)  # (batch, seq_len, lstm_output_dim)

        # Mean over time, then dropout + linear classifier
        pooled = self.dropout_cls(lstm_out.mean(dim=1))  # (batch, lstm_output_dim)
        return self.classifier(pooled)  # (batch, num_classes)

# Counts below are for EMGLSTMModel's defaults (d_model=16, hidden_size=32, num_layers=2,
# num_classes=2): 25,866 parameters in total, 1,536 of them in the two raw-EMG branches
# that forward() never calls.
print("✓ Phase 1: Compact LSTM model architecture defined (~26K parameters with defaults)")
print("\nArchitecture Summary:")
print("  - Compact branches: 2 conv layers + pooling (no residual blocks)")
print("  - Raw EMG: time + frequency branches defined but NOT fed to the LSTM (~1.5K unused parameters)")
print("  - RMS/LMS: extracts ONLY time domain features (filtered)")
print("  - Wiener TD: extracts ONLY time domain features (filtered)")
print("  - Wiener FFT: extracts ONLY frequency domain features (filtered)")
print("  - Two-layer unidirectional LSTM (increased capacity: d_model=16, hidden_size=32)")
print("  - Simple mean pooling (0 parameters)")
print("  - Strong regularization (dropout=0.8, weight_decay=0.075)")
print("  - Increased model capacity to address underfitting (~24K parameters used in forward)")

In [None]:
# =============================================================
# IMPORTS AND SETUP
# =============================================================

import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import json
from pathlib import Path
from typing import Dict, List, Tuple
import warnings
# Silences all Python warnings process-wide for the rest of the session.
warnings.filterwarnings('ignore')

# Prefer a TPU when torch_xla is installed and a device can be acquired; otherwise
# fall back to CUDA, then CPU. `except Exception` still covers a missing torch_xla
# (ImportError) or a failed device lookup, but unlike a bare `except:` it does not
# swallow KeyboardInterrupt / SystemExit.
try:
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl
    DEVICE = xm.xla_device()
    USE_TPU = True
    print(f"Using TPU: {DEVICE}")
except Exception:
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    USE_TPU = False
    print(f"Using device: {DEVICE}")

print(f"PyTorch version: {torch.__version__}")


In [None]:
# =============================================================
# DATA LOADING AND PREPROCESSING
# =============================================================
# Loads data from session folders, each containing:
# - emg_raw.csv, emg_rms_lms.csv, emg_wiener.csv, imu.csv
# - labels.csv with start_ms, end_ms, and gesture columns for timestamp-based labeling
# =============================================================

from torch.utils.data import Dataset
from scipy import signal
from scipy.fft import fft, fftfreq

class EMGDataset(Dataset):
    """Dataset for EMG gesture classification with sequential windowing and enhanced frequency features"""

    def __init__(self, data_dir: str, window_size: int = 100, stride: int = 25,
                 sequence_length: int = 8, normalize: bool = True, scalers: Dict = None,
                 use_augmentation: bool = False, augmentation_prob: float = 0.5,
                 sampling_rate: int = 500):
        """
        Build the dataset from every session folder under ``data_dir``.

        Args:
            data_dir: Root directory; each sub-directory is treated as one session.
            window_size: Samples per window (100 samples = 200 ms at 500 Hz).
            stride: Samples between consecutive window starts (25 samples = 50 ms at
                500 Hz, i.e. 75% overlap for 100-sample windows).
            sequence_length: Number of windows grouped into one sequence.
            normalize: Whether to normalize the data after sequencing.
            scalers: Pre-fitted scalers to reuse. When falsy (None or empty), new
                scalers are fitted on this dataset.
            use_augmentation: Whether raw EMG windows are augmented; this happens once,
                while windows are built during construction.
            augmentation_prob: Per-window probability of applying augmentation.
            sampling_rate: EMG sampling rate in Hz (default 500).
        """
        # Windowing / sequencing configuration
        self.data_dir = Path(data_dir)
        self.window_size, self.stride = window_size, stride
        self.sequence_length = sequence_length
        self.sampling_rate = sampling_rate

        # Preprocessing options
        self.normalize = normalize
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob

        # Every sub-directory is one recording session, visited in sorted order
        session_dirs = [entry for entry in self.data_dir.iterdir() if entry.is_dir()]
        self.sessions = sorted(session_dirs)
        print(f"Found {len(self.sessions)} sessions")

        # Windows + per-window labels from all sessions
        self.windows, self.labels = [], []
        self.scalers = scalers or {}
        self._load_all_sessions()

        # Group consecutive windows into fixed-length sequences
        self.sequences, self.sequence_labels = [], []
        self._create_sequences()

        if normalize:
            # Fit fresh scalers only when none were supplied
            if not scalers:
                self._fit_scalers()
            self._normalize_data()

    def _compute_spectral_features(self, signal_window, sampling_rate=500):
        """Compute comprehensive spectral features from a signal window"""
        n = len(signal_window)
        fft_vals = fft(signal_window)
        fft_magnitude = np.abs(fft_vals[:n//2])

        # Normalize to 64 bins to match CSV FFT bins
        if len(fft_magnitude) >= 64:
            indices = np.linspace(0, len(fft_magnitude)-1, 64).astype(int)
            fft_magnitude_64 = fft_magnitude[indices]
        else:
            fft_magnitude_64 = np.interp(np.linspace(0, len(fft_magnitude)-1, 64),
                                         np.arange(len(fft_magnitude)), fft_magnitude)

        freqs = fftfreq(n, 1/sampling_rate)[:n//2]

        # Compute bandpower in different frequency bands
        low_mask = (freqs >= 20) & (freqs <= 60)
        bandpower_low = np.sum(fft_magnitude[low_mask]**2) if np.any(low_mask) else 0.0

        mid_mask = (freqs >= 60) & (freqs <= 150)
        bandpower_mid = np.sum(fft_magnitude[mid_mask]**2) if np.any(mid_mask) else 0.0

        high_mask = (freqs >= 150) & (freqs <= 250)
        bandpower_high = np.sum(fft_magnitude[high_mask]**2) if np.any(high_mask) else 0.0

        # Spectral entropy
        power_spectrum = fft_magnitude**2
        power_spectrum_norm = power_spectrum / (np.sum(power_spectrum) + 1e-10)
        spectral_entropy = -np.sum(power_spectrum_norm * np.log(power_spectrum_norm + 1e-10))

        # Spectral centroid
        if np.sum(power_spectrum) > 0:
            spectral_centroid = np.sum(freqs * power_spectrum) / np.sum(power_spectrum)
        else:
            spectral_centroid = 0.0

        return {
            'fft_magnitude': fft_magnitude_64.astype(np.float32),
            'bandpower_low': np.float32(bandpower_low),
            'bandpower_mid': np.float32(bandpower_mid),
            'bandpower_high': np.float32(bandpower_high),
            'spectral_entropy': np.float32(spectral_entropy),
            'spectral_centroid': np.float32(spectral_centroid)
        }

    def _normalize_window_per_user(self, signal_window, method='z-score'):
        """
        Normalize a signal window to remove user-specific baseline and scale.
        This helps the model learn gesture patterns, not user-specific patterns.

        Args:
            signal_window: 1D numpy array of signal values
            method: 'z-score' (default), 'min-max', or 'robust'

        Returns:
            Normalized signal window
        """
        if method == 'z-score':
            # Standardize: (x - mean) / std
            mean = np.mean(signal_window)
            std = np.std(signal_window)
            if std < 1e-8:  # Avoid division by zero
                return signal_window - mean
            return (signal_window - mean) / std
        elif method == 'min-max':
            # Scale to [0, 1]
            min_val = np.min(signal_window)
            max_val = np.max(signal_window)
            if max_val - min_val < 1e-8:
                return signal_window - min_val
            return (signal_window - min_val) / (max_val - min_val)
        elif method == 'robust':
            # Use median and IQR (more robust to outliers)
            median = np.median(signal_window)
            q75 = np.percentile(signal_window, 75)
            q25 = np.percentile(signal_window, 25)
            iqr = q75 - q25
            if iqr < 1e-8:
                return signal_window - median
            return (signal_window - median) / iqr
        return signal_window

    def _check_timestamp_continuity(self, timestamps, session_name, signal_name, expected_rate_hz=500):
        """Check for timestamp gaps (UDP packet loss detection)"""
        if len(timestamps) < 2:
            return
        intervals = np.diff(timestamps)
        expected_interval = 1000 / expected_rate_hz
        threshold = expected_interval * 1.5
        gaps = intervals > threshold
        if np.any(gaps):
            gap_count = np.sum(gaps)
            max_gap = np.max(intervals)
            gap_pct = 100 * gap_count / len(intervals)
            if gap_pct > 1.0:
                print(f"⚠️  {session_name} {signal_name}: {gap_count} timestamp gaps ({gap_pct:.2f}%), max gap: {max_gap:.2f}ms")

    def _augment_window(self, raw_window, noise_level=0.1):
        """
        Enhanced data augmentation for EMG signals
        Carefully designed to preserve signal characteristics while increasing diversity
        """
        augmented = raw_window.copy()

        # 1. Gaussian noise (simulating sensor noise) - keep original
        noise_std = np.std(raw_window) * noise_level
        noise = np.random.normal(0, noise_std, size=raw_window.shape)
        augmented = augmented + noise

        # 2. Amplitude scaling (simulating different muscle activation levels) - keep original
        scale_factor = np.random.uniform(0.8, 1.2)
        augmented = augmented * scale_factor

        # 3. Time warping (slight stretching/compression) - keep original
        if len(augmented) > 10:
            warp_factor = np.random.uniform(0.95, 1.05)
            original_indices = np.arange(len(augmented))
            warped_indices = original_indices * warp_factor
            warped_indices = np.clip(warped_indices, 0, len(augmented) - 1).astype(int)
            augmented = augmented[warped_indices]

        # 4. Time shifting (small temporal delays) - NEW
        # Shift by up to 5% of window length (preserves signal structure)
        max_shift = max(1, int(len(augmented) * 0.05))
        shift = np.random.randint(-max_shift, max_shift + 1)
        if shift != 0:
            if shift > 0:
                # Shift right: move beginning to end
                augmented = np.concatenate([augmented[shift:], augmented[:shift]])
            else:
                # Shift left: move end to beginning (shift is negative)
                augmented = np.concatenate([augmented[shift:], augmented[:shift]])

        # 5. Frequency domain augmentation (slight frequency shifts) - NEW
        # Apply small frequency shift via phase manipulation (preserves amplitude)
        if len(augmented) > 20:
            # Compute FFT
            fft_signal = np.fft.fft(augmented)
            # Apply small random phase shift (max 5% frequency shift)
            phase_shift = np.random.uniform(-0.05, 0.05) * 2 * np.pi
            freqs = np.fft.fftfreq(len(augmented))
            phase_shift_array = np.exp(1j * phase_shift * freqs)
            fft_signal_shifted = fft_signal * phase_shift_array
            # Convert back to time domain
            augmented = np.real(np.fft.ifft(fft_signal_shifted))

        return augmented.astype(np.float32)

    def _load_all_sessions(self):
        """
        Load every session folder and append its windows and labels to the dataset.

        Each session folder is expected to contain emg_raw.csv, emg_rms_lms.csv,
        emg_wiener.csv, imu.csv and labels.csv (start_ms, end_ms, gesture columns).
        A session that raises while loading is reported with a traceback and skipped.

        Side effects: extends ``self.windows``; sets ``self.labels`` (encoded numpy
        array), ``self.label_encoder`` and ``self.num_classes``.
        """
        all_gestures = []

        for session_dir in self.sessions:
            try:
                # Load CSV files from session folder
                emg_raw = pd.read_csv(session_dir / 'emg_raw.csv')
                emg_rms_lms = pd.read_csv(session_dir / 'emg_rms_lms.csv')
                emg_wiener = pd.read_csv(session_dir / 'emg_wiener.csv')
                imu = pd.read_csv(session_dir / 'imu.csv')
                labels_df = pd.read_csv(session_dir / 'labels.csv')  # start_ms, end_ms, gesture

                # Extract timestamps
                timestamps_raw = emg_raw['timestamp_ms'].values
                timestamps_rms = emg_rms_lms['timestamp_ms'].values
                timestamps_wiener = emg_wiener['timestamp_ms'].values
                timestamps_imu = imu['timestamp_ms'].values

                # Warn about timestamp gaps (dropped samples)
                self._check_timestamp_continuity(timestamps_raw, session_dir.name, 'raw_EMG', expected_rate_hz=500)
                self._check_timestamp_continuity(timestamps_imu, session_dir.name, 'IMU', expected_rate_hz=100)

                # Extract signal data
                raw_signal = emg_raw['ch1'].values
                rms_signal = emg_rms_lms['rms_ch1'].values
                lms_signal = emg_rms_lms['lms_ch1'].values
                wiener_td = emg_wiener['wiener_td_ch1'].values

                # Pre-computed Wiener FFT bins (fft_bin_0 .. fft_bin_63); None if any column is missing
                fft_cols = [f'fft_bin_{i}' for i in range(64)]
                wiener_fft = emg_wiener[fft_cols].values if all(col in emg_wiener.columns for col in fft_cols) else None

                # IMU data
                imu_data = imu[['ax', 'ay', 'az', 'gx', 'gy', 'gz']].values

                # Per-sample labels on the raw EMG timeline (-1 unlabeled, 0 REST, 1 FIST)
                label_map = self._create_label_map(labels_df, timestamps_raw)

                # Create windows
                session_windows, session_labels = self._create_windows(
                    timestamps_raw, raw_signal,
                    timestamps_rms, rms_signal, lms_signal,
                    timestamps_wiener, wiener_td, wiener_fft,
                    timestamps_imu, imu_data,
                    label_map
                )

                self.windows.extend(session_windows)
                self.labels.extend(session_labels)
                all_gestures.extend(session_labels)

            except Exception as e:
                # Best effort: report and skip a broken session instead of aborting the dataset
                print(f"Error loading session {session_dir.name}: {e}")
                import traceback
                traceback.print_exc()

        # Window labels are already integer codes from _create_label_map (REST=0, FIST=1).
        # Fitting the encoder on the observed labels only would silently re-map them
        # whenever this dataset lacks a class (a FIST-only split would encode FIST as 0)
        # and make train/val encodings disagree, so the known codes are always included.
        known_labels = np.union1d([0, 1], np.asarray(all_gestures, dtype=int))
        self.label_encoder = LabelEncoder()
        self.label_encoder.fit(known_labels)
        self.labels = self.label_encoder.transform(all_gestures)
        self.num_classes = len(self.label_encoder.classes_)
        print(f"Classes: {self.label_encoder.classes_}")
        print(f"Total windows: {len(self.windows)}")

    def _create_sequences(self):
        """Create sequences of sequential windows"""
        i = 0
        while i + self.sequence_length <= len(self.windows):
            sequence = self.windows[i:i+self.sequence_length]
            sequence_window_labels = self.labels[i:i+self.sequence_length]

            # Use majority label for the sequence
            unique_labels, counts = np.unique(sequence_window_labels, return_counts=True)
            majority_label_idx = np.argmax(counts)
            label = unique_labels[majority_label_idx]

            # Only accept sequences with reasonable label consistency (>60% same label)
            label_consistency = np.max(counts) / len(sequence_window_labels)
            if label_consistency < 0.60:
                i += self.stride
                continue

            self.sequences.append(sequence)
            self.sequence_labels.append(label)
            i += self.stride

        print(f"Total sequences: {len(self.sequences)} (each with {self.sequence_length} windows)")

    def _create_label_map(self, labels_df: pd.DataFrame, timestamps: np.ndarray) -> np.ndarray:
        """
        Create label array aligned with timestamps from labels.csv
        Uses start_ms, end_ms, and gesture columns to label the data
        """
        label_map = np.full(len(timestamps), -1, dtype=int)

        for _, row in labels_df.iterrows():
            start_ms = row['start_ms']
            end_ms = row['end_ms']
            gesture = row['gesture']

            # Skip MIXED labels - these are block-level summaries, not actual gestures
            if gesture == 'MIXED':
                continue

            # Handle zero-duration labels (common in motion_artifact and electrode_drift blocks)
            if start_ms == end_ms:
                closest_idx = np.argmin(np.abs(timestamps - start_ms))
                closest_time = timestamps[closest_idx]
                window_half_ms = 10  # 10ms on each side = 20ms total
                start_ms = max(timestamps[0], closest_time - window_half_ms)
                end_ms = min(timestamps[-1], closest_time + window_half_ms)

            # Find indices within this time range
            mask = (timestamps >= start_ms) & (timestamps <= end_ms)

            if not np.any(mask):
                closest_idx = np.argmin(np.abs(timestamps - (start_ms + end_ms) / 2))
                mask = np.zeros(len(timestamps), dtype=bool)
                mask[closest_idx] = True

            # Assign labels (don't overwrite existing labels)
            if gesture == 'REST':
                label_map[mask & (label_map == -1)] = 0
            elif gesture == 'FIST':
                label_map[mask & (label_map == -1)] = 1

        return label_map

    def _create_windows(self, ts_raw, raw, ts_rms, rms, lms, ts_wiener, wiener_td, wiener_fft,
                       ts_imu, imu, label_map):
        """Create overlapping windows from time-series data"""
        windows = []
        labels = []

        max_len = len(raw)
        i = 0

        while i + self.window_size <= max_len:
            end_idx = i + self.window_size
            window_labels = label_map[i:end_idx]

            # Skip completely unlabeled windows
            if np.all(window_labels == -1):
                i += self.stride
                continue

            # Calculate label distribution in window
            unique_labels, counts = np.unique(window_labels, return_counts=True)
            labeled_mask = unique_labels != -1
            if not np.any(labeled_mask):
                i += self.stride
                continue

            labeled_unique = unique_labels[labeled_mask]
            labeled_counts = counts[labeled_mask]

            # Only accept windows with reasonable label purity (>70% same label)
            label_purity = np.max(labeled_counts) / np.sum(labeled_counts)
            if label_purity < 0.70:
                i += self.stride
                continue

            # Use majority label
            majority_label_idx = np.argmax(labeled_counts)
            label = labeled_unique[majority_label_idx]

            if label == -1:
                i += self.stride
                continue

            # Extract EMG windows
            raw_window = raw[i:end_idx].copy()

            # Apply augmentation if enabled
            if self.use_augmentation and np.random.random() < self.augmentation_prob:
                raw_window = self._augment_window(raw_window)

            # Align RMS/LMS and Wiener signals using timestamps
            window_start_time = ts_raw[i]
            window_end_time = ts_raw[end_idx-1]

            # RMS: Downsampled envelope (~10Hz) - interpolate
            rms_mask = (ts_rms >= window_start_time) & (ts_rms <= window_end_time)
            rms_indices = np.where(rms_mask)[0]
            if len(rms_indices) >= 2:
                rms_window = np.interp(ts_raw[i:end_idx], ts_rms[rms_indices], rms[rms_indices])
            elif len(rms_indices) == 1:
                rms_window = np.full(self.window_size, rms[rms_indices[0]])
            else:
                nearest_rms_idx = np.argmin(np.abs(ts_rms - window_start_time))
                rms_window = np.full(self.window_size, rms[nearest_rms_idx])

            # LMS: Full rate (500Hz)
            lms_mask = (ts_rms >= window_start_time) & (ts_rms <= window_end_time)
            lms_indices = np.where(lms_mask)[0]
            if len(lms_indices) >= self.window_size:
                lms_window = lms[lms_indices[:self.window_size]]
            elif len(lms_indices) >= 2:
                lms_window = np.interp(ts_raw[i:end_idx], ts_rms[lms_indices], lms[lms_indices])
            elif len(lms_indices) == 1:
                lms_window = np.full(self.window_size, lms[lms_indices[0]])
            else:
                nearest_lms_idx = np.argmin(np.abs(ts_rms - window_start_time))
                lms_window = np.full(self.window_size, lms[nearest_lms_idx])

            # Wiener TD
            wiener_mask = (ts_wiener >= window_start_time) & (ts_wiener <= window_end_time)
            wiener_indices = np.where(wiener_mask)[0]
            if len(wiener_indices) >= self.window_size:
                wiener_td_window = wiener_td[wiener_indices[:self.window_size]]
            elif len(wiener_indices) >= 2:
                wiener_td_window = np.interp(ts_raw[i:end_idx], ts_wiener[wiener_indices], wiener_td[wiener_indices])
            elif len(wiener_indices) == 1:
                wiener_td_window = np.full(self.window_size, wiener_td[wiener_indices[0]])
            else:
                nearest_wiener_idx = np.argmin(np.abs(ts_wiener - window_start_time))
                wiener_td_window = np.full(self.window_size, wiener_td[nearest_wiener_idx])

            # Compute spectral features
            spectral_features_raw = self._compute_spectral_features(raw_window, self.sampling_rate)
            spectral_features_wiener = self._compute_spectral_features(wiener_td_window, self.sampling_rate)

            # Use pre-computed Wiener FFT bins from CSV (64 bins)
            window_center_time = (window_start_time + window_end_time) / 2
            if wiener_fft is not None and len(wiener_fft) > 0:
                fft_time_distances = np.abs(ts_wiener - window_center_time)
                closest_fft_idx = np.argmin(fft_time_distances)
                if closest_fft_idx < len(wiener_fft):
                    wiener_fft_window = wiener_fft[closest_fft_idx]
                else:
                    wiener_fft_window = wiener_fft[-1]
            else:
                wiener_fft_window = np.zeros(64)

            # Extract IMU window (100Hz)
            window_duration_ms = (end_idx - i) / self.sampling_rate * 1000
            imu_samples_needed = int(window_duration_ms / 10)
            imu_mask = (ts_imu >= window_start_time) & (ts_imu <= window_end_time)
            imu_indices = np.where(imu_mask)[0]
            if len(imu_indices) >= imu_samples_needed:
                imu_window = imu[imu_indices[:imu_samples_needed]]
            elif len(imu_indices) > 0:
                imu_window = np.zeros((imu_samples_needed, 6))
                imu_window[:len(imu_indices)] = imu[imu_indices]
            else:
                nearest_idx = np.argmin(np.abs(ts_imu - window_start_time))
                imu_window = np.tile(imu[nearest_idx], (imu_samples_needed, 1))

            # =============================================================
            # STEP 1: PER-WINDOW Z-SCORE NORMALIZATION
            # =============================================================
            # Normalize each window to remove user-specific baselines/scales.
            # This is critical for LOUO - makes signals comparable across different users.
            # This step removes user-specific amplitude differences while preserving
            # the relative signal structure within each window.
            # =============================================================
            raw_window = self._normalize_window_per_user(raw_window, method='z-score')
            rms_window = self._normalize_window_per_user(rms_window, method='z-score')
            lms_window = self._normalize_window_per_user(lms_window, method='z-score')
            wiener_td_window = self._normalize_window_per_user(wiener_td_window, method='z-score')
            # Note: wiener_fft is already in frequency domain, normalize it too
            if len(wiener_fft_window) > 0:
                wiener_fft_window = self._normalize_window_per_user(wiener_fft_window, method='z-score')
            # IMU: normalize each channel separately (imu_window is shape (6, samples))
            imu_window_normalized = np.zeros_like(imu_window)
            for ch in range(imu_window.shape[0]):  # 6 channels
                imu_window_normalized[ch] = self._normalize_window_per_user(imu_window[ch], method='z-score')
            imu_window = imu_window_normalized

            # Pack window data
            window_data = {
                'raw': raw_window.astype(np.float32),
                'rms_lms': np.stack([rms_window, lms_window], axis=0).astype(np.float32),
                'wiener_td': wiener_td_window.astype(np.float32),
                'wiener_fft': wiener_fft_window.astype(np.float32),
                'spectral_features_raw': {
                    'bandpower_low': spectral_features_raw['bandpower_low'],
                    'bandpower_mid': spectral_features_raw['bandpower_mid'],
                    'bandpower_high': spectral_features_raw['bandpower_high'],
                    'spectral_entropy': spectral_features_raw['spectral_entropy'],
                    'spectral_centroid': spectral_features_raw['spectral_centroid']
                },
                'spectral_features_wiener': {
                    'bandpower_low': spectral_features_wiener['bandpower_low'],
                    'bandpower_mid': spectral_features_wiener['bandpower_mid'],
                    'bandpower_high': spectral_features_wiener['bandpower_high'],
                    'spectral_entropy': spectral_features_wiener['spectral_entropy'],
                    'spectral_centroid': spectral_features_wiener['spectral_centroid']
                },
                'imu': imu_window.T.astype(np.float32)
            }

            windows.append(window_data)
            labels.append(label)
            i += self.stride

        return windows, labels

    def _fit_scalers(self):
        """
        Fit the global StandardScalers for STEP 2 (applied by ``_normalize_data``).

        Scalers are fitted on ``self.windows`` and stored in ``self.scalers``.
        Validation/test splits reuse the training split's scalers, and the same
        scalers must be applied at deployment to reproduce the training inputs.

          - 'raw', 'rms_lms', 'wiener_td', 'imu': one mean/std pooled over every
            element of every window (data reshaped to a single column).
          - 'wiener_fft': one mean/std per FFT bin (one row per window).
          - 'spectral_raw', 'spectral_wiener': one mean/std per spectral feature
            (bandpower_low/mid/high, spectral_entropy, spectral_centroid).

        What this step does NOT do: the time-domain streams were already
        z-scored per window when the windows were built (STEP 1). Assuming
        ``_normalize_window_per_user(..., method='z-score')`` standardizes each
        window, they arrive roughly zero-mean/unit-variance, so these pooled
        scalers end up close to the identity. Absolute amplitude differences
        between classes (e.g. REST vs FIST) are therefore not recovered here.
        They survive only in features computed before STEP 1, such as the
        spectral features, which are computed from the window before it is
        z-scored.

        NOTE(review): where windows are built, ``imu_window`` is assembled as
        rows of ``imu[...]`` / ``np.zeros((imu_samples_needed, 6))``, i.e.
        (samples, 6). Yet its "per-channel" z-score loops over
        ``imu_window.shape[0]`` before the ``.T`` is applied. Confirm that the
        normalization runs over the intended axis.

        Raises:
            ValueError: if the dataset has no windows to fit on.
        """
        if not self.windows:
            raise ValueError("Cannot fit scalers: the dataset contains no windows")

        spectral_keys = ('bandpower_low', 'bandpower_mid', 'bandpower_high',
                         'spectral_entropy', 'spectral_centroid')

        def spectral_matrix(group):
            # (n_windows, 5) matrix of one spectral-feature group, in spectral_keys order.
            return np.array([[w[group][k] for k in spectral_keys] for w in self.windows])

        # Every element of every window, pooled into one column per stream.
        raw_data = np.concatenate([w['raw'] for w in self.windows])
        rms_lms_data = np.concatenate([w['rms_lms'].flatten() for w in self.windows])
        wiener_td_data = np.concatenate([w['wiener_td'] for w in self.windows])
        imu_data = np.concatenate([w['imu'].flatten() for w in self.windows])
        # One row per window so that each FFT bin gets its own statistics.
        wiener_fft_data = np.stack([w['wiener_fft'] for w in self.windows])

        self.scalers['raw'] = StandardScaler().fit(raw_data.reshape(-1, 1))
        self.scalers['rms_lms'] = StandardScaler().fit(rms_lms_data.reshape(-1, 1))
        self.scalers['wiener_td'] = StandardScaler().fit(wiener_td_data.reshape(-1, 1))
        self.scalers['wiener_fft'] = StandardScaler().fit(wiener_fft_data)
        self.scalers['imu'] = StandardScaler().fit(imu_data.reshape(-1, 1))
        self.scalers['spectral_raw'] = StandardScaler().fit(spectral_matrix('spectral_features_raw'))
        self.scalers['spectral_wiener'] = StandardScaler().fit(spectral_matrix('spectral_features_wiener'))

    def _normalize_data(self):
        """
        STEP 2: apply the global scalers in ``self.scalers`` to every window, in place.

        ``self.scalers`` is either produced by ``_fit_scalers`` on this dataset or
        passed in from the training split. For each window dict in ``self.windows``:
          - 'raw', 'wiener_td', 'rms_lms', 'imu': scaled element-wise with that
            stream's pooled single-column scaler; array shapes are preserved.
          - 'wiener_fft': scaled per bin.
          - 'spectral_features_raw' / 'spectral_features_wiener': the five
            features are scaled together and the dict is rebuilt with the same keys.

        Entries are reassigned inside each existing window dict. Anything that
        holds references to those dicts therefore sees the scaled values; this
        presumably includes the sequences, which EMGDatasetFiltered builds
        before calling this method. Not idempotent: a second call would scale
        the data again.

        Caveat: this step does not undo STEP 1. The time-domain streams were
        already z-scored per window, so STEP 2 keeps them on a common scale but
        cannot restore absolute amplitude differences between classes.
        """
        spectral_keys = ('bandpower_low', 'bandpower_mid', 'bandpower_high',
                         'spectral_entropy', 'spectral_centroid')

        def pooled(key, values):
            # Element-wise scaling with a single-feature scaler, keeping the input shape.
            return self.scalers[key].transform(values.reshape(-1, 1)).reshape(values.shape)

        def spectral(features, scaler_key):
            # Scale the five features jointly, then rebuild the dict with the same key order.
            row = np.array([features[k] for k in spectral_keys]).reshape(1, -1)
            scaled = self.scalers[scaler_key].transform(row).flatten()
            return dict(zip(spectral_keys, scaled))

        for window in self.windows:
            for key in ('raw', 'rms_lms', 'wiener_td', 'imu'):
                window[key] = pooled(key, window[key])
            window['wiener_fft'] = self.scalers['wiener_fft'].transform(
                window['wiener_fft'].reshape(1, -1)).flatten()
            window['spectral_features_raw'] = spectral(window['spectral_features_raw'], 'spectral_raw')
            window['spectral_features_wiener'] = spectral(window['spectral_features_wiener'], 'spectral_wiener')

    def __len__(self):
        """Return how many sequences (dataset items) this dataset holds."""
        sequence_count = len(self.sequences)
        return sequence_count

    def __getitem__(self, idx):
        """
        Return sequence ``idx`` as a dict of tensors (S = windows in the sequence).

          'raw', 'rms_lms', 'wiener_td': (S, 1, window_length). A channel axis is
              inserted; 'rms_lms' carries only row 0 (the RMS row) of each window.
          'wiener_fft': (S, n_bins)
          'spectral_raw', 'spectral_wiener': (S, 5), in the order bandpower_low,
              bandpower_mid, bandpower_high, spectral_entropy, spectral_centroid.
          'imu': (S, 6, imu_samples)
          'label': 0-d int64 tensor holding the sequence label.
        """
        windows = self.sequences[idx]
        target = self.sequence_labels[idx]
        feature_names = ('bandpower_low', 'bandpower_mid', 'bandpower_high',
                         'spectral_entropy', 'spectral_centroid')

        def stack_field(pick):
            # Stack one field across the sequence's windows -> leading seq_len axis.
            return np.stack([pick(window) for window in windows])

        def spectral_rows(group):
            return stack_field(lambda window: [window[group][name] for name in feature_names])

        return {
            'raw': torch.FloatTensor(stack_field(lambda w: w['raw'])).unsqueeze(1),
            'rms_lms': torch.FloatTensor(stack_field(lambda w: w['rms_lms'][0])).unsqueeze(1),  # RMS row only
            'wiener_td': torch.FloatTensor(stack_field(lambda w: w['wiener_td'])).unsqueeze(1),
            'wiener_fft': torch.FloatTensor(stack_field(lambda w: w['wiener_fft'])),
            'spectral_raw': torch.FloatTensor(spectral_rows('spectral_features_raw')),
            'spectral_wiener': torch.FloatTensor(spectral_rows('spectral_features_wiener')),
            'imu': torch.FloatTensor(stack_field(lambda w: w['imu'])),
            'label': torch.LongTensor([target])[0],
        }

class EMGDatasetFiltered(EMGDataset):
    """EMG dataset restricted to selected session folders (train/validation/test splits).

    Args:
        data_dir: Dataset root; each immediate sub-directory is one session.
        include_sessions: Names of the session folders to load. ``None`` loads
            every session; an empty list loads no sessions, so an empty split
            can never widen to the whole dataset. Requested names that are not
            present under ``data_dir`` are reported with a warning.
        window_size: Window length in samples.
        stride: Step between consecutive window starts, in samples.
        sequence_length: Number of consecutive windows per dataset item.
        normalize: If True, apply the global scalers (STEP 2) after loading.
        scalers: Pre-fitted scalers (e.g. ``train_dataset.scalers``). If falsy
            and ``normalize`` is True, new scalers are fitted on this split's data.
        use_augmentation: Enable window augmentation while windows are built.
        augmentation_prob: Probability of augmenting each window.
        sampling_rate: Raw EMG sampling rate in Hz.
    """
    def __init__(self, data_dir: str, include_sessions: List[str] = None,
                 window_size: int = 100, stride: int = 25, sequence_length: int = 8,
                 normalize: bool = True, scalers: Dict = None,
                 use_augmentation: bool = False, augmentation_prob: float = 0.5,
                 sampling_rate: int = 500):
        # NOTE(review): EMGDataset.__init__ is intentionally not called; the
        # attributes set here must cover everything the inherited helpers read.
        # Verify this whenever EMGDataset changes.
        self.data_dir = Path(data_dir)
        self.window_size = window_size
        self.stride = stride
        self.sequence_length = sequence_length
        self.normalize = normalize
        self.include_sessions = include_sessions
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.sampling_rate = sampling_rate

        # Find all session directories
        all_sessions = sorted(d for d in self.data_dir.iterdir() if d.is_dir())

        # Filter sessions only when a filter was given. Test against None, not
        # truthiness: an empty list must select nothing rather than everything.
        if self.include_sessions is not None:
            wanted = set(self.include_sessions)
            self.sessions = [s for s in all_sessions if s.name in wanted]
            missing = sorted(wanted - {s.name for s in self.sessions})
            if missing:
                print(f"⚠️  {len(missing)} requested session(s) not found in {self.data_dir}: {missing}")
        else:
            self.sessions = all_sessions

        print(f"Found {len(self.sessions)} sessions (filtered)")

        # Load all data
        self.windows = []
        self.labels = []
        self.scalers = scalers if scalers else {}

        self._load_all_sessions()

        # Create sequences from windows. This happens before normalization, so it
        # relies on the sequences referencing the same window dicts that
        # _normalize_data updates in place (presumably how _create_sequences works).
        self.sequences = []
        self.sequence_labels = []
        self._create_sequences()

        # Fit scalers on this split only when none were supplied (training split)
        if normalize and not scalers:
            self._fit_scalers()

        # Normalize data
        if normalize:
            self._normalize_data()

# Summary of what this cell defined.
# NOTE(review): STEP 1 z-scores every window, so (assuming
# _normalize_window_per_user(..., method='z-score') standardizes each window)
# the STEP 2 scalers see data that is already ~zero-mean/unit-variance. STEP 2
# does not restore absolute amplitude differences; those survive only in
# features computed before STEP 1 (e.g. the spectral features).
print("\n".join((
    "✓ EMGDataset and EMGDatasetFiltered classes defined",
    "  - Loads data from session folders",
    "  - Uses labels.csv with start_ms, end_ms, and gesture columns for timestamp-based labeling",
    "  - TWO-STEP NORMALIZATION ENABLED:",
    "    STEP 1: Per-window z-score normalization (removes user-specific baselines/scales)",
    "    STEP 2: Global StandardScaler normalization (preserves REST vs FIST amplitude differences)",
    "    This is critical for LOUO cross-validation and deployment consistency",
    "    StandardScalers are saved and must be applied in deployment to match training normalization",
)))


In [None]:
# =============================================================
# LOSS FUNCTION: Focal Loss
# =============================================================

class FocalLoss(nn.Module):
    """Focal loss for multi-class classification under class imbalance.

    For sample i with true class y_i:
        loss_i = alpha[y_i] * (1 - p_i) ** gamma * CE_i
    where CE_i is the unweighted cross-entropy and p_i = exp(-CE_i) is the
    model's probability for the true class. With gamma=0 and no alpha this
    reduces to plain cross-entropy. Note that 'mean' is a plain mean over
    samples; it is not normalized by the sum of the alpha weights.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        """
        Args:
            alpha: Optional per-class weights, indexed by class label. A tensor is
                used as-is; any other sequence is converted to a float32 tensor.
                It is registered as a buffer, so ``.to(device)`` moves it with the
                module. If None, every class has weight 1.
            gamma: Focusing parameter. Higher gamma focuses more on hard examples.
            reduction: 'mean' or 'sum' over the batch, or 'none' for per-sample losses.

        Raises:
            ValueError: if ``reduction`` is not 'mean', 'sum' or 'none'.
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}")
        if alpha is not None:
            # register_buffer only accepts tensors; accept plain lists/arrays too.
            if not torch.is_tensor(alpha):
                alpha = torch.as_tensor(alpha, dtype=torch.float32)
            self.register_buffer('alpha', alpha)
        else:
            self.alpha = None
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """
        Args:
            inputs: Logits, typically of shape (batch, num_classes).
            targets: Integer class indices, shape (batch,).

        Returns:
            A scalar for 'mean'/'sum', otherwise the per-sample loss tensor.
        """
        # Unweighted CE first: pt must be the true-class probability, which a
        # class-weighted CE would distort.
        ce_loss = nn.functional.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss

        # Per-class weighting, looked up by each sample's true label.
        if self.alpha is not None:
            focal_loss = self.alpha[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        if self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss


print("✓ Focal Loss defined")


In [None]:
# =============================================================
# TRAINING AND EVALUATION FUNCTIONS
# =============================================================

def train_epoch(model, dataloader, criterion, optimizer, device, print_per_class_loss=False,
                max_grad_norm=1.0):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is a dict. Tensor values are moved to ``device``, and
    ``batch['label']`` holds the integer targets. The model is called with the
    whole batch dict and its output is used as logits (argmax over dim 1).

    Args:
        model: Module taking the batch dict and returning logits.
        dataloader: Iterable of batch dicts.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Target device for tensor batch values.
        print_per_class_loss: If True, print the mean *unweighted
            cross-entropy* per true class. This is a diagnostic, not ``criterion``.
        max_grad_norm: Threshold for clipping the total gradient norm before
            each optimizer step; ``None`` disables clipping.

    Returns:
        (avg_loss, accuracy, precision, recall, f1). ``avg_loss`` is the mean
        of the per-batch ``criterion`` values; precision/recall/F1 are
        support-weighted averages.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    from collections import defaultdict

    model.train()
    total_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []
    per_class_loss = defaultdict(list)  # true class -> per-sample CE values (diagnostic only)

    for batch in dataloader:
        # Move to device
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        labels = batch['label']

        # Forward pass
        optimizer.zero_grad()
        logits = model(batch)
        loss = criterion(logits, labels)

        # Per-class diagnostics; every class from torch.unique has at least one sample.
        if print_per_class_loss:
            with torch.no_grad():
                ce_loss = nn.functional.cross_entropy(logits, labels, reduction='none')
                for cls in torch.unique(labels):
                    per_class_loss[cls.item()].extend(ce_loss[labels == cls].cpu().numpy())

        # Backward pass with optional gradient clipping against exploding gradients
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        # Metrics
        total_loss += loss.item()
        num_batches += 1
        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

    if num_batches == 0:
        raise ValueError("train_epoch: dataloader yielded no batches")

    if print_per_class_loss and per_class_loss:
        summary = "  ".join(f"Class {cls}: {np.mean(values):.4f}"
                            for cls, values in sorted(per_class_loss.items()))
        print(f"    Per-class avg loss: {summary}  ")

    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted', zero_division=0)

    return avg_loss, accuracy, precision, recall, f1

def evaluate(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Batches are handled as in ``train_epoch``: dicts whose tensor values are
    moved to ``device``, with targets in ``batch['label']``. The model is left
    in eval mode.

    Returns:
        (avg_loss, accuracy, precision, recall, f1, cm). ``avg_loss`` is the mean
        of the per-batch ``criterion`` values; precision/recall/F1 are
        support-weighted. ``cm`` is a (num_classes, num_classes) confusion matrix
        (rows = true, columns = predicted), where num_classes is the width of
        the model's logits.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    num_classes = None
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            # Move to device
            batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
            labels = batch['label']

            # Forward pass
            logits = model(batch)
            loss = criterion(logits, labels)

            # Metrics
            total_loss += loss.item()
            num_batches += 1
            num_classes = logits.size(1)
            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if num_batches == 0:
        raise ValueError("evaluate: dataloader yielded no batches")

    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(all_labels, all_preds)
    precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='weighted', zero_division=0)
    # Pin the label set to every class the model can output, so cm keeps a fixed
    # shape even when a split lacks a class (e.g. a held-out user with a single
    # gesture) and the model never predicts it.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))

    return avg_loss, accuracy, precision, recall, f1, cm

print("✓ Training and evaluation functions defined")


In [None]:
# =============================================================
# SINGLE USER TEST (Leave-One-User-Out Style)
# =============================================================
# Tests true user generalization: train on all users except userB,
# validate to tune hyperparameters, then test on completely unseen userB.
# This simulates "new person walking in off the street" without computationally
# expensive retraining on previously seen data.
# =============================================================

# Configuration (same as Cell 4)
SAMPLING_RATE_HZ = 500  # Raw EMG sampling rate; used below to convert samples -> ms
SEQUENCE_LENGTH = 8  # Number of sequential windows per sample
WINDOW_SIZE = 100    # Window size in samples (200ms @ 500Hz)
STRIDE = 25          # Stride between windows (50ms @ 500Hz) -> consecutive windows overlap by 75%

# PHASE 1: Compact LSTM Model hyperparameters with increased capacity
# Increased from d_model=8, hidden_size=16, num_layers=1 to address underfitting
D_MODEL = 16         # Feature dimension from each branch (increased from 8→14→16 for better capacity)
HIDDEN_SIZE = 32     # LSTM hidden size (increased from 16→28→32 for better capacity)
NUM_LAYERS = 2       # Number of LSTM layers (increased from 1 to 2)
DROPOUT = 0.82       # Dropout rate (balanced: between 0.8 and 0.85 for stable training)
BIDIRECTIONAL = False # Use unidirectional LSTM (keep unidirectional)

# Training hyperparameters with balanced regularization
BATCH_SIZE = 32      # Smaller batches for more stable gradients
INITIAL_LR = 5e-4    # Lower learning rate (reduced from 1e-3)
WEIGHT_DECAY = 0.075 # Balanced L2 regularization (between 0.05 and 0.1 for stable training)
NUM_EPOCHS = 50
EARLY_STOPPING_PATIENCE = 10
AUGMENTATION_PROB = 0.9  # Very aggressive augmentation (increased from 0.7)

# Derived values for the report, computed rather than hard-coded so they stay
# correct when WINDOW_SIZE / STRIDE / SAMPLING_RATE_HZ change.
WINDOW_MS = WINDOW_SIZE * 1000 // SAMPLING_RATE_HZ
STRIDE_MS = STRIDE * 1000 // SAMPLING_RATE_HZ
WINDOW_OVERLAP_PCT = 100 * (WINDOW_SIZE - STRIDE) / WINDOW_SIZE

print(f"\n{'='*60}")
print("LOUO: COMPACT LSTM MODEL CONFIGURATION")
print(f"{'='*60}")
print(f"  Sequence length: {SEQUENCE_LENGTH} windows")
print(f"  Window size: {WINDOW_SIZE} samples ({WINDOW_MS}ms @ {SAMPLING_RATE_HZ}Hz)")
print(f"  Stride: {STRIDE} samples ({STRIDE_MS}ms, {WINDOW_OVERLAP_PCT:.0f}% overlap)")
print(f"  d_model: {D_MODEL} (increased from 8 to address underfitting)")
print(f"  LSTM hidden size: {HIDDEN_SIZE} (increased from 16 to address underfitting)")
print(f"  LSTM layers: {NUM_LAYERS} (increased from 1 to address underfitting)")
print(f"  Dropout: {DROPOUT}")
print(f"  Weight decay: {WEIGHT_DECAY}")
print(f"  Learning rate: {INITIAL_LR}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Augmentation prob: {AUGMENTATION_PROB}")
print(f"{'='*60}\n")

# Choose the dataset root: Colab notebooks run from /content; anything else is
# treated as the local development machine.
_COLAB_DATA_DIR = '/content/drive/MyDrive/emg_dataset/dataset'
_LOCAL_DATA_DIR = '/Users/gabdomingo/Documents/emg_dataset/dataset'

current_dir = os.getcwd()
is_colab = current_dir.startswith('/content')
DATA_DIR = _COLAB_DATA_DIR if is_colab else _LOCAL_DATA_DIR
if is_colab:
    print("Detected Google Colab environment")

# Fail fast, with setup hints, when the dataset is not where we expect it.
if not os.path.exists(DATA_DIR):
    _hints = [
        f"⚠️  Dataset directory not found: {DATA_DIR}",
        "Please update DATA_DIR to point to your dataset folder",
    ]
    if is_colab:
        _hints += [
            "\nFor Google Colab:",
            "  1. Make sure you've mounted Google Drive",
            f"  2. Check that the dataset is at: {_COLAB_DATA_DIR}",
            "  3. Verify the path contains session subfolders",
        ]
    print("\n".join(_hints))
    raise FileNotFoundError(f"Dataset directory not found: {DATA_DIR}")

print(f"✓ Dataset directory found: {DATA_DIR}")

# Load the complete (un-normalized) dataset once, to enumerate its sessions.
# NOTE(review): this presumably loads and windows every session (as
# EMGDatasetFiltered's constructor does) just to read session names. That may
# be expensive; confirm full_dataset is needed beyond this.
full_dataset = EMGDataset(
    DATA_DIR,
    normalize=False,
    sequence_length=SEQUENCE_LENGTH,
    window_size=WINDOW_SIZE,
    stride=STRIDE,
)
all_sessions = [session_dir.name for session_dir in full_dataset.sessions]

# Group sessions by user
def get_user_from_session(session_name):
    """Return the user token embedded in a session name, or None if there is none.

    The token is whatever follows the LAST occurrence of 'user', cut at the
    first '/' (e.g. '2024_userB' -> 'B'). Everything after 'user' is kept, so
    'userB_day2' -> 'B_day2'. The result can be '' for a name ending in 'user'.
    """
    _, marker, tail = session_name.rpartition('user')
    if not marker:
        return None
    return tail.split('/')[0]

# Map each user token to its session names (order follows all_sessions).
user_sessions = {}
unassigned_sessions = []  # sessions whose name yields no user token
for session_name in all_sessions:
    user = get_user_from_session(session_name)
    if user:
        user_sessions.setdefault(user, []).append(session_name)
    else:
        unassigned_sessions.append(session_name)

all_users = sorted(user_sessions.keys())
print(f"Found {len(all_users)} users: {all_users}")
print(f"Total sessions: {len(all_sessions)}")
# Without this report, such sessions would vanish silently from train, validation and test.
if unassigned_sessions:
    print(f"⚠️  {len(unassigned_sessions)} session(s) have no user token and are excluded "
          f"from every split: {unassigned_sessions}")

# SINGLE USER TEST: hold out userB entirely and train/validate on every other user.
# Training happens once, and validation is used for hyperparameter tuning, so the
# held-out user is never seen before the final test (no costly retraining).
test_user = 'B'  # Fixed test user: userB
print(f"\n{'='*80}")
print(f"SINGLE USER TEST: Testing on user '{test_user}' (completely unseen)")
print(f"{'='*80}")
print("⚠️  Training once on all other users, validating to tune hyperparameters,")
print(f"   then testing on user '{test_user}' (no retraining = no data leakage)")

# Only one held-out user here, but results stay in a list so that code written
# for the multi-user LOUO loop keeps working.
louo_results = []

# The held-out user must exist before anything else is derived from it.
if test_user not in user_sessions:
    raise ValueError(f"❌ ERROR: Test user '{test_user}' not found in dataset. Available users: {all_users}")

# The held-out user's sessions form the test set ...
test_sessions = user_sessions[test_user]

# ... and every other user's sessions form the train/validation pool.
train_val_sessions = []
for user in all_users:
    if user == test_user:
        continue
    train_val_sessions += user_sessions[user]

print(f"  Training/Validation sessions: {len(train_val_sessions)} (from {len(all_users)-1} users)")
print(f"  Test sessions: {len(test_sessions)} (from user '{test_user}')")

# VALIDATION: make sure no held-out-user session ended up in train/validation.
# Users are compared by exact parsed token, so 'A' does not match 'A2' (a
# substring test would).
# NOTE(review): both the split above and these checks derive users from
# get_user_from_session, so they confirm internal consistency only. They cannot
# catch a session that the parser attributes to the wrong user.
def session_belongs_to_user(session_name, user):
    """Return True if ``session_name`` parses to exactly ``user``; False if it has no 'user' token."""
    return 'user' in session_name and get_user_from_session(session_name) == user


# No train/validation session may belong to the held-out user.
test_user_in_train = any(session_belongs_to_user(s, test_user) for s in train_val_sessions)
if test_user_in_train:
    leaking_sessions = [s for s in train_val_sessions if session_belongs_to_user(s, test_user)]
    raise ValueError(f"❌ DATA LEAKAGE DETECTED: Test user '{test_user}' found in training/validation sessions!\n"
                     f"   Leaking sessions: {leaking_sessions}")

# Every held-out session must really belong to the held-out user.
for test_session in test_sessions:
    if session_belongs_to_user(test_session, test_user):
        continue
    raise ValueError(f"❌ ERROR: Test session '{test_session}' does not belong to test user '{test_user}'")

print("  ✓ Data split validated: No leakage detected")

# Split the training users' sessions into train/val by session, so no session
# straddles both sides. The validation split is what hyperparameters are tuned on.
np.random.seed(42)  # Reproducibility (this also seeds numpy's global RNG for the rest of the run)
shuffled_train_val = train_val_sessions.copy()
np.random.shuffle(shuffled_train_val)

# Both sides need at least one session. An empty session list would leave a
# split without data, and depending on how EMGDatasetFiltered treats an empty
# filter, it could even load every session, including the held-out user's.
if len(shuffled_train_val) < 2:
    raise ValueError(
        f"❌ ERROR: Need at least 2 sessions from training users for a train/validation split, "
        f"found {len(shuffled_train_val)}: {shuffled_train_val}"
    )

# 80% train, 20% val (within training users). int() floors, so for >= 2 sessions
# train gets >= 1 session and val gets >= 1 session.
split_idx = int(0.8 * len(shuffled_train_val))
train_sessions = shuffled_train_val[:split_idx]
validation_sessions = shuffled_train_val[split_idx:]

print(f"    Train: {len(train_sessions)} sessions")
print(f"    Val: {len(validation_sessions)} sessions (for hyperparameter tuning)")

# Create datasets. Validation and test reuse the training split's scalers, so all
# three splits are normalized identically (and no statistics come from held-out data).
train_dataset = EMGDatasetFiltered(
    DATA_DIR,
    include_sessions=train_sessions,
    sequence_length=SEQUENCE_LENGTH,
    window_size=WINDOW_SIZE,
    stride=STRIDE,
    use_augmentation=True,
    augmentation_prob=AUGMENTATION_PROB
)

validation_dataset = EMGDatasetFiltered(
    DATA_DIR,
    include_sessions=validation_sessions,
    sequence_length=SEQUENCE_LENGTH,
    window_size=WINDOW_SIZE,
    stride=STRIDE,
    scalers=train_dataset.scalers,
    use_augmentation=False
)

test_dataset = EMGDatasetFiltered(
    DATA_DIR,
    include_sessions=test_sessions,
    sequence_length=SEQUENCE_LENGTH,
    window_size=WINDOW_SIZE,
    stride=STRIDE,
    scalers=train_dataset.scalers,
    use_augmentation=False
)


def _check_loaded_sessions(split_name, dataset, requested_sessions):
    """Raise ValueError if ``dataset`` loaded any session not in ``requested_sessions``.

    This checks what the dataset actually loaded, independently of how the
    session lists were derived. It catches a session filter that widens
    (e.g. an empty list being treated as "all sessions").
    """
    unexpected = sorted({s.name for s in dataset.sessions} - set(requested_sessions))
    if unexpected:
        raise ValueError(
            f"❌ DATA LEAKAGE DETECTED: {split_name} dataset loaded sessions that were not "
            f"requested for it: {unexpected}"
        )


_check_loaded_sessions('train', train_dataset, train_sessions)
_check_loaded_sessions('validation', validation_dataset, validation_sessions)
_check_loaded_sessions('test', test_dataset, test_sessions)

# Calculate class weights
# sklearn 'balanced' weights: n_samples / (n_classes * count) per class, in the
# order of `unique` (the sorted label values present in training).
# NOTE(review): FocalLoss indexes alpha by label VALUE (alpha[targets]), but this
# vector is positional over `unique`. The two only line up if the training labels
# are exactly 0..K-1 with every class present — confirm.
# NOTE(review): train_loader (below) also uses a WeightedRandomSampler with
# inverse-frequency sample weights, so batches are already class-balanced in
# expectation. These loss weights (plus the boost) add a second, compounding
# correction toward the minority class — verify this is intended.
train_labels = train_dataset.sequence_labels
unique, counts = np.unique(train_labels, return_counts=True)
class_weights = compute_class_weight('balanced', classes=unique, y=train_labels)
class_weights = torch.FloatTensor(class_weights).to(DEVICE)

# Enhanced class imbalance handling - BALANCED weighting for minority class
# (binary case only)
if len(unique) == 2:
    # Positions in `unique`. minority_class_idx is also passed to the label encoder
    # as a label value below (same 0..K-1 assumption). majority_class_idx is not
    # used in this block.
    minority_class_idx = np.argmin(counts)
    majority_class_idx = np.argmax(counts)
    class_ratio = min(counts) / max(counts)
    
    # Moderate boosting to address class bias without causing training instability
    if class_ratio < 0.7:  # Back to original threshold for stability
        # Moderate boost factor: max(2.5, 1.2 / class_ratio) - balanced approach
        # This gives reasonable penalty for misclassifying minority class without destabilizing training
        boost_factor = max(2.5, 1.2 / class_ratio)
        # Read before the in-place multiply so the report can show old vs new.
        original_weight = class_weights[minority_class_idx].item()
        class_weights[minority_class_idx] *= boost_factor
        minority_class_name = train_dataset.label_encoder.inverse_transform([minority_class_idx])[0]
        print(f"\n  ⚡ BALANCED: Boosted {minority_class_name} class weight by {boost_factor:.2f}x")
        print(f"     (Original weight: {original_weight:.4f}, New weight: {class_weights[minority_class_idx].item():.4f})")
    
    # Rescale so the weights sum to the number of classes (keeps the loss scale stable).
    class_weights = class_weights / class_weights.sum() * len(class_weights)

def _print_class_distribution(label_encoder, classes, class_counts_, total):
    """Print one '    <name>: <count> samples (<pct>%)' line per class."""
    for cls, count in zip(classes, class_counts_):
        share = 100 * count / total if total > 0 else 0
        name = label_encoder.inverse_transform([cls])[0]
        print(f"    {name}: {count} samples ({share:.1f}%)")


# Training split: class balance after windowing/sequencing.
print(f"\n  Training class distribution:")
_print_class_distribution(train_dataset.label_encoder, unique, counts, len(train_labels))

# Held-out user: class balance of the test split.
test_labels = test_dataset.sequence_labels
unique_test, counts_test = np.unique(test_labels, return_counts=True)
print(f"\n  Test user '{test_user}' class distribution:")
_print_class_distribution(test_dataset.label_encoder, unique_test, counts_test, len(test_labels))

# Per-class metrics are not meaningful if the held-out user has a single class.
if len(unique_test) < 2:
    print(f"  ⚠️  WARNING: Only one class in test set for user '{test_user}'!")

# Training sampler: each sequence is drawn with probability inversely proportional
# to its class frequency, with replacement, len(train_labels) draws per epoch.
class_counts = np.bincount(train_labels)
class_weights_sample = 1.0 / class_counts
sample_weights = class_weights_sample[train_labels]
weighted_sampler = WeightedRandomSampler(
    replacement=True,
    num_samples=len(sample_weights),
    weights=sample_weights,
)

# The training order comes from the sampler; evaluation loaders keep dataset order.
train_loader = DataLoader(train_dataset, sampler=weighted_sampler, batch_size=BATCH_SIZE, num_workers=0)
validation_loader = DataLoader(validation_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=0)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=0)

# Verify the model class signature before initialization: fail early, with a
# hint, if the notebook holds a stale EMGLSTMModel whose constructor lacks any
# keyword passed below.
import inspect
model_signature = inspect.signature(EMGLSTMModel.__init__)
expected_params = ['num_classes', 'd_model', 'hidden_size', 'num_layers', 'dropout', 'sequence_length', 'bidirectional']
actual_params = list(model_signature.parameters.keys())[1:]  # Skip 'self'
# A **kwargs constructor accepts every keyword, so nothing can be "missing".
accepts_any_kwarg = any(
    p.kind is inspect.Parameter.VAR_KEYWORD for p in model_signature.parameters.values()
)
missing_params = [] if accepts_any_kwarg else [p for p in expected_params if p not in actual_params]

if missing_params:
    raise TypeError(
        f"EMGLSTMModel.__init__() is missing parameter(s) {missing_params}. "
        f"Found parameters: {actual_params}\n"
        f"Please re-run the first cell (model definition) to reload the updated class."
    )

# Initialize model
model = EMGLSTMModel(
    num_classes=train_dataset.num_classes,
    d_model=D_MODEL,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
    sequence_length=SEQUENCE_LENGTH,
    bidirectional=BIDIRECTIONAL
).to(DEVICE)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n  Model parameters: {total_params:,} (trainable: {trainable_params:,})")
print(f"  Training samples: {len(train_dataset.sequences):,}")
# Rough capacity-vs-data indicator (ZeroDivisionError if there are no training sequences).
print(f"  Parameters per sample: {total_params / len(train_dataset.sequences):.2f}")

# Loss function and optimizer
# Balanced gamma: 2.2 (between 2.0 and 2.5) to focus on hard examples without over-emphasis
# class_weights (computed earlier, outside this excerpt) is passed as the
# project FocalLoss's per-class alpha.
criterion = FocalLoss(alpha=class_weights, gamma=2.2).to(DEVICE)
print(f"\n  Loss function: Focal Loss with gamma=2.2 (balanced for hard example focus)")
# NOTE(review): zip() pairs the class names built from `unique` positionally
# with class_weights; this assumes `unique` lists every class index in order — verify.
print(f"  Class weights: {dict(zip([train_dataset.label_encoder.inverse_transform([c])[0] for c in unique], class_weights.cpu().numpy()))}")

optimizer = optim.AdamW(model.parameters(), lr=INITIAL_LR, weight_decay=WEIGHT_DECAY)
# LR schedule driven by validation F1 (mode='max'; the training loop calls
# scheduler.step(val_f1)): multiply the LR by 0.7 once more than 5 consecutive
# epochs fail to improve by threshold=0.01 (relative under PyTorch's default
# threshold_mode='rel'), never going below 1e-4.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.7, patience=5,
    min_lr=1e-4, threshold=0.01
)

# Training loop
# Trains for up to NUM_EPOCHS with validation-F1 based LR scheduling and early
# stopping, then restores the weights of the best validation epoch.
import copy

from sklearn.metrics import precision_recall_fscore_support

best_val_f1 = 0
best_model_state = None
patience_counter = 0

print(f"\n  Starting training (validation set used for hyperparameter tuning)...")

for epoch in range(NUM_EPOCHS):
    # Train
    train_loss, train_acc, train_prec, train_rec, train_f1 = train_epoch(
        model, train_loader, criterion, optimizer, DEVICE, print_per_class_loss=False
    )

    # Validate (used for hyperparameter tuning: early stopping, LR scheduling)
    val_loss, val_acc, val_prec, val_rec, val_f1, val_cm = evaluate(
        model, validation_loader, criterion, DEVICE
    )

    # Update learning rate (hyperparameter tuning)
    scheduler.step(val_f1)

    # Early stopping (hyperparameter tuning)
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        # state_dict() returns tensors that share storage with the live
        # parameters/buffers, and dict.copy() is shallow, so a plain copy would
        # keep tracking later optimizer updates. Deep-copy to freeze this epoch.
        best_model_state = copy.deepcopy(model.state_dict())
        patience_counter = 0
    else:
        if best_model_state is None:
            # Validation F1 has not beaten its initial value of 0 yet; keep a
            # fallback snapshot so there is always a state to restore below.
            best_model_state = copy.deepcopy(model.state_dict())
        patience_counter += 1
        if patience_counter >= EARLY_STOPPING_PATIENCE:
            break

    # Print progress every 10 epochs (and after the first) with per-class validation metrics
    if (epoch + 1) % 10 == 0 or epoch == 0:
        # Run the diagnostic pass in eval mode (dropout off, BatchNorm uses its
        # running statistics), then restore the previous mode so the next
        # train_epoch() call sees the model exactly as before.
        was_training = model.training
        model.eval()
        all_val_preds = []
        all_val_labels = []
        with torch.no_grad():
            for batch in validation_loader:
                batch = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
                labels = batch['label']
                logits = model(batch)
                preds = torch.argmax(logits, dim=1)
                all_val_preds.extend(preds.cpu().numpy())
                all_val_labels.extend(labels.cpu().numpy())
        model.train(was_training)

        # labels= keeps index i == class i even if a class is absent from this pass.
        val_prec_per_class, val_rec_per_class, val_f1_per_class, _ = precision_recall_fscore_support(
            all_val_labels, all_val_preds, labels=list(range(train_dataset.num_classes)),
            average=None, zero_division=0
        )
        class_names = [train_dataset.label_encoder.inverse_transform([i])[0] for i in range(len(val_f1_per_class))]

        print(f"    Epoch {epoch+1}: Train F1={train_f1:.4f}, Val F1={val_f1:.4f}")
        print(f"      Val per-class F1: {dict(zip(class_names, val_f1_per_class))}")
        print(f"      Val per-class Recall: {dict(zip(class_names, val_rec_per_class))}")

# Load best model (based on validation performance). best_model_state is only
# None when no epoch ran at all; keep the current weights in that case.
if best_model_state is not None:
    model.load_state_dict(best_model_state)

# =============================================================
# MODEL CONFIDENCE ANALYSIS (NO THRESHOLD TUNING)
# =============================================================
# Analyze model confidence on the validation set without changing predictions.
# We will use argmax over softmax probabilities (equivalent to a 0.5 decision
# boundary in the binary case) and keep thresholds out of the decision path.
# =============================================================
print(f"\n  Analyzing model confidence on validation set (argmax decision)...")

from sklearn.metrics import accuracy_score, f1_score

# Get validation predictions (probabilities + argmax)
val_probs = []
val_labels_list = []
model.eval()
with torch.no_grad():
    for batch in validation_loader:
        batch = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        labels = batch['label']
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        val_probs.extend(probs.cpu().numpy())
        val_labels_list.extend(labels.cpu().numpy())

val_probs = np.array(val_probs)
val_labels_list = np.array(val_labels_list)

# Analyze model confidence (confidence = maximum class probability per sample)
val_max_probs = np.max(val_probs, axis=1)
val_confidence_mean = float(np.mean(val_max_probs))
val_confidence_std = float(np.std(val_max_probs))
val_confidence_median = float(np.median(val_max_probs))
val_low_confidence_pct = float(100 * np.sum(val_max_probs < 0.60) / len(val_max_probs))
val_very_low_confidence_pct = float(100 * np.sum(val_max_probs < 0.55) / len(val_max_probs))

print(f"    Model Confidence Statistics (validation):")
print(f"      Mean confidence: {val_confidence_mean:.4f}")
print(f"      Median confidence: {val_confidence_median:.4f}")
print(f"      Std confidence: {val_confidence_std:.4f}")
print(f"      Low confidence (<0.60): {val_low_confidence_pct:.1f}%")
print(f"      Very low confidence (<0.55): {val_very_low_confidence_pct:.1f}%")

# Validation performance with argmax (no threshold tuning)
val_preds = np.argmax(val_probs, axis=1)
val_acc = accuracy_score(val_labels_list, val_preds)
val_f1 = f1_score(val_labels_list, val_preds, average='weighted', zero_division=0)
print(f"\n    Validation performance (argmax):")
print(f"      Accuracy: {val_acc:.4f}")
print(f"      F1 (weighted): {val_f1:.4f}")

# REST baseline confidence (Option B monitoring)
# Look REST up in the fitted label encoder instead of assuming index 0: if the
# encoder is sklearn's LabelEncoder, classes_ is sorted, so e.g. "FIST" would
# come before "REST". Falls back to index 0 (the previous assumption) when no
# class is named REST (case-insensitive).
rest_class_idx = next(
    (idx for idx, name in enumerate(train_dataset.label_encoder.classes_)
     if str(name).strip().upper() == 'REST'),
    0,
)
rest_mask = val_labels_list == rest_class_idx
if np.any(rest_mask):
    rest_probs = val_probs[rest_mask, rest_class_idx]
    val_rest_conf_mean = float(np.mean(rest_probs))
    val_rest_conf_median = float(np.median(rest_probs))
    val_rest_low_conf_pct = float(100 * np.sum(rest_probs < 0.60) / len(rest_probs))
    print(f"\n    REST baseline (validation, true REST samples):")
    print(f"      Mean REST confidence: {val_rest_conf_mean:.4f}")
    print(f"      Median REST confidence: {val_rest_conf_median:.4f}")
    print(f"      Low REST confidence (<0.60): {val_rest_low_conf_pct:.1f}%")
else:
    val_rest_conf_mean = 0.0
    val_rest_conf_median = 0.0
    val_rest_low_conf_pct = 0.0
    print(f"\n    REST baseline (validation): no REST samples found.")

# For compatibility with export code, define best_threshold as the nominal 0.50
# decision boundary (but DO NOT use it to override argmax decisions).
best_threshold = 0.50
best_f1 = val_f1
best_rest_recall = None  # no threshold search is run, so there is no selected REST recall
best_fist_recall = None

print(f"\n    Note: No threshold tuning performed. Decisions use argmax over softmax.")

# =============================================================
# TEST EVALUATION (ARGMAX DECISION RULE)
# =============================================================
print(f"\n  Evaluating on test set (argmax decision)...")

from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

# Evaluate with the standard evaluate() helper
test_loss, test_acc, test_prec, test_rec, test_f1, test_cm = evaluate(
    model, test_loader, criterion, DEVICE
)

# Collect softmax probabilities and labels for per-class metrics and confidence stats
model.eval()
test_probs = []
test_labels_list = []
with torch.no_grad():
    for batch in test_loader:
        batch = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        labels = batch['label']
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        test_probs.extend(probs.cpu().numpy())
        test_labels_list.extend(labels.cpu().numpy())

test_probs = np.array(test_probs)
test_labels_list = np.array(test_labels_list)

test_preds = np.argmax(test_probs, axis=1)
# Score against every class the model outputs (columns of test_probs) so the
# per-class arrays always have one entry per class and index i means class i.
# Without labels=, sklearn only reports classes that occur, so e.g. a test set
# with only class 1 would give one entry that class_names labels as class 0.
prec_per_class, rec_per_class, f1_per_class, _ = precision_recall_fscore_support(
    test_labels_list, test_preds, labels=list(range(test_probs.shape[1])),
    average=None, zero_division=0
)
class_names = [test_dataset.label_encoder.inverse_transform([i])[0] for i in range(len(f1_per_class))]

# Test confidence statistics (confidence = maximum class probability per sample)
test_max_probs = np.max(test_probs, axis=1)
test_confidence_mean = float(np.mean(test_max_probs))
test_confidence_median = float(np.median(test_max_probs))
test_low_confidence_pct = float(100 * np.sum(test_max_probs < 0.60) / len(test_max_probs))

print(f"    Test set confidence statistics:")
print(f"      Mean confidence: {test_confidence_mean:.4f}")
print(f"      Median confidence: {test_confidence_median:.4f}")
print(f"      Low confidence (<0.60): {test_low_confidence_pct:.1f}%")

# REST baseline confidence on test set (for monitoring reference)
rest_mask_test = test_labels_list == rest_class_idx
if np.any(rest_mask_test):
    rest_probs_test = test_probs[rest_mask_test, rest_class_idx]
    test_rest_conf_mean = float(np.mean(rest_probs_test))
    test_rest_low_conf_pct = float(100 * np.sum(rest_probs_test < 0.60) / len(rest_probs_test))
else:
    test_rest_conf_mean = 0.0
    test_rest_low_conf_pct = 0.0

print(f"\n  ✓ User '{test_user}' Results (ARGMAX DECISION):")
print(f"    Test Accuracy: {test_acc:.4f}")
print(f"    Test F1 (weighted): {test_f1:.4f}")
print(f"    Test Precision (weighted): {test_prec:.4f}")
print(f"    Test Recall (weighted): {test_rec:.4f}")

print(f"\n    Per-Class Test Performance:")
for i, class_name in enumerate(class_names):
    print(f"      {class_name}:")
    print(f"        Precision: {prec_per_class[i]:.4f}")
    print(f"        Recall: {rec_per_class[i]:.4f}")
    print(f"        F1: {f1_per_class[i]:.4f}")

print(f"\n    Confusion Matrix (argmax):")
print(f"      {test_cm}")
print(f"      (Rows = True labels, Columns = Predicted labels)")

# Indices 0/1 are read as REST/FIST here (the ordering this pipeline assumes).
print(f"\n    Safety assessment (REST vs FIST):")
if len(prec_per_class) >= 2:
    print(f"      REST recall: {rec_per_class[0]:.4f} ({rec_per_class[0]*100:.1f}% of REST gestures detected)")
    print(f"      REST precision: {prec_per_class[0]:.4f} ({prec_per_class[0]*100:.1f}% precision - false stops rate)")
    print(f"      FIST recall: {rec_per_class[1]:.4f} ({rec_per_class[1]*100:.1f}% of FIST gestures detected)")

# Store results (single-user test)
# One record for this user's argmax evaluation; scalar metrics are cast to
# Python floats and the confusion matrix to nested lists so the record holds
# plain Python types.
louo_results.append({
    'test_user': test_user,
    'train_samples': len(train_dataset.sequences),
    'val_samples': len(validation_dataset.sequences),
    'test_samples': len(test_dataset.sequences),
    'best_val_f1': best_val_f1,  # best validation F1 seen during early stopping
    'optimal_threshold': float(best_threshold),  # nominal, not tuned
    'test_accuracy': float(test_acc),
    'test_precision': float(test_prec),
    'test_recall': float(test_rec),
    'test_f1': float(test_f1),
    'test_loss': float(test_loss),
    'confusion_matrix': test_cm.tolist(),
    # class name -> number of test sequences of that class
    'test_class_distribution': dict(zip(
        [test_dataset.label_encoder.inverse_transform([c])[0] for c in unique_test],
        counts_test.tolist()
    )),
    'deployment_approach': 'argmax',
    'deployment_threshold': float(best_threshold)
})

# Per-class metrics for export
test_prec_per_class = prec_per_class
test_rec_per_class = rec_per_class

# SINGLE USER TEST RESULTS SUMMARY (argmax record)
# With exactly one record, print it and use its metrics as the "averages";
# otherwise aggregate across all records.
print(f"\n{'='*80}", "SINGLE USER TEST RESULTS SUMMARY", f"{'='*80}", sep="\n")

if len(louo_results) != 1:
    avg_test_acc = np.mean([entry['test_accuracy'] for entry in louo_results])
    avg_test_f1 = np.mean([entry['test_f1'] for entry in louo_results])
    avg_test_prec = np.mean([entry['test_precision'] for entry in louo_results])
    avg_test_rec = np.mean([entry['test_recall'] for entry in louo_results])
    std_test_f1 = np.std([entry['test_f1'] for entry in louo_results]) if len(louo_results) > 1 else 0.0
else:
    r = louo_results[0]
    print(
        f"\nTest User: {r['test_user']}",
        f"  Test Samples: {r['test_samples']}",
        f"  Training Samples: {r['train_samples']}",
        f"  Validation Samples: {r['val_samples']}",
        f"  Best Validation F1: {r['best_val_f1']:.4f}",
        f"  Nominal Threshold (unused for decisions): {r.get('optimal_threshold', 0.5):.2f}",
        f"\n  Test Performance (argmax):",
        f"    Accuracy: {r['test_accuracy']:.4f}",
        f"    F1 Score: {r['test_f1']:.4f}",
        f"    Precision: {r['test_precision']:.4f}",
        f"    Recall: {r['test_recall']:.4f}",
        f"    Confusion Matrix:",
        f"      {np.array(r['confusion_matrix'])}",
        sep="\n",
    )

    # A single result: its metrics are the summary, with no spread.
    avg_test_acc, avg_test_f1, avg_test_prec, avg_test_rec = (
        r['test_accuracy'], r['test_f1'], r['test_precision'], r['test_recall']
    )
    std_test_f1 = 0.0
# Re-run test-set inference, this time also keeping the raw logits.
test_probs = []
test_labels_list = []
test_logits_list = []
model.eval()
with torch.no_grad():
    for batch in test_loader:
        batch = {k: v.to(DEVICE) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        labels = batch['label']
        logits = model(batch)
        probs = torch.softmax(logits, dim=1)
        test_probs.extend(probs.cpu().numpy())
        test_logits_list.extend(logits.cpu().numpy())
        test_labels_list.extend(labels.cpu().numpy())

test_probs = np.array(test_probs)
test_logits = np.array(test_logits_list)
test_labels_list = np.array(test_labels_list)

# Analyze test set confidence. Cast to Python floats exactly like the earlier
# argmax pass: for a float32 probability array (PyTorch's default dtype),
# np.mean/np.median return numpy float32 scalars, which json cannot serialize.
test_max_probs = np.max(test_probs, axis=1)
test_confidence_mean = float(np.mean(test_max_probs))
test_confidence_median = float(np.median(test_max_probs))
test_low_confidence_pct = float(100 * np.sum(test_max_probs < 0.60) / len(test_max_probs))

print(f"    Test set confidence statistics:")
print(f"      Mean confidence: {test_confidence_mean:.4f}")
print(f"      Median confidence: {test_confidence_median:.4f}")
print(f"      Low confidence (<0.60): {test_low_confidence_pct:.1f}%")

# best_threshold is the nominal, untuned 0.50, so the messages below no longer
# claim that a threshold adjustment masks or compensates for low confidence.
if test_confidence_mean < 0.65:
    print(f"\n    ⚠️  WARNING: Test set confidence is LOW (mean={test_confidence_mean:.4f} < 0.65)")
    print(f"       Model is uncertain on test data. This confirms underfitting concerns.")
    print(f"       Decision threshold is the nominal {best_threshold:.2f} (not tuned); nothing compensates for this uncertainty.")
elif test_confidence_mean < 0.75:
    print(f"\n    ⚠️  CAUTION: Test set confidence is MODERATE (mean={test_confidence_mean:.4f} < 0.75)")
    print(f"       Model confidence could be improved (no threshold tuning is applied).")
else:
    print(f"\n    ✓ Test set confidence is GOOD (mean={test_confidence_mean:.4f} ≥ 0.75)")
    print(f"       Model is producing confident predictions on unseen data.")

# Threshold rule on P(FIST) (column 1). With the nominal 0.50 and two softmax
# outputs this matches argmax up to floating-point ties.
test_preds_thresh = (test_probs[:, 1] > best_threshold).astype(int)

# =============================================================
# HYBRID SAFETY APPROACH: Confidence-based fallback + Gesture confirmation
# =============================================================
# For safety-critical wheelchair control (evaluated here for comparison only):
# 1. Predict FIST when P(FIST) > best_threshold (the nominal 0.50, i.e. argmax)
# 2. If confidence is low (max probability < 0.60), treat the window as REST (stop) - fail-safe
# 3. Only output REST after 2 consecutive REST windows (until then, hold the
#    previous output); this prevents false stops from a single misclassified window
# =============================================================

def hybrid_safety_predictions(probs, threshold, confidence_threshold=0.60, rest_confirmation_windows=2):
    """
    Post-process per-window class probabilities with two safety rules.

    Stage 1 (confidence fallback): a window is labelled FIST (1) only when its
    top probability is not below ``confidence_threshold`` AND ``probs[:, 1]``
    exceeds ``threshold``; every other window is labelled REST (0).

    Stage 2 (REST confirmation): the stage-1 labels are scanned in order. A FIST
    label passes straight through and resets the REST streak. A REST label is
    only emitted once ``rest_confirmation_windows`` consecutive REST labels have
    been seen; until then the previous output is repeated (FIST for the very
    first window, which has no previous output).

    Args:
        probs: (N, 2) NumPy array of class probabilities; column 1 is FIST.
        threshold: decision threshold applied to the FIST probability.
        confidence_threshold: minimum max-probability for a window to be trusted.
        rest_confirmation_windows: consecutive REST labels required before REST
            is emitted.

    Returns:
        (predictions, confidence_scores): an (N,) int array of post-processed
        labels and the (N,) array of per-window max probabilities.
    """
    confidence_scores = np.max(probs, axis=1)

    # Stage 1, vectorised. "Not below" rather than ">=" so a NaN confidence is
    # treated like it is by an explicit `< confidence_threshold` test.
    trusted = ~(confidence_scores < confidence_threshold)
    stage_one = np.where(trusted & (probs[:, 1] > threshold), 1, 0)

    # Stage 2: REST only after enough consecutive REST labels; FIST at once.
    predictions = np.zeros(len(probs), dtype=int)
    output_state = 1
    rest_streak = 0
    for idx, label in enumerate(stage_one):
        if label != 0:
            rest_streak = 0
            output_state = 1
        else:
            rest_streak += 1
            if rest_streak >= rest_confirmation_windows:
                output_state = 0
        predictions[idx] = output_state

    return predictions, confidence_scores

# Apply hybrid safety approach to the test-set probabilities
# (fixed settings: 0.60 confidence floor, 2-window REST confirmation).
print(
    "\n  Applying HYBRID SAFETY MECHANISM:",
    "    - Confidence threshold: 0.60 (low confidence → default to REST/stop)",
    "    - REST confirmation: 2 consecutive windows required",
    sep="\n",
)
test_preds_hybrid, test_confidence = hybrid_safety_predictions(
    test_probs,
    best_threshold,
    rest_confirmation_windows=2,
    confidence_threshold=0.60,
)

# Calculate metrics for all three approaches
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

# Per-class metrics and confusion matrices are scored against every class the
# model outputs (the columns of test_probs). Without labels=, sklearn sizes them
# from the classes that happen to occur, so a single-class test set gives
# length-1 arrays (the REST/FIST prints that index [0] and [1] raise IndexError)
# and index i no longer necessarily means class i.
all_class_labels = list(range(test_probs.shape[1]))

# 1. Default threshold (argmax)
test_preds_default = np.argmax(test_probs, axis=1)
test_acc_default = accuracy_score(test_labels_list, test_preds_default)
test_prec_default, test_rec_default, test_f1_default, _ = precision_recall_fscore_support(
    test_labels_list, test_preds_default, average='weighted', zero_division=0
)
test_cm_default = confusion_matrix(test_labels_list, test_preds_default, labels=all_class_labels)

# 2. Threshold rule on P(FIST) at best_threshold (nominal, not tuned)
test_acc_thresh = accuracy_score(test_labels_list, test_preds_thresh)
test_prec_thresh, test_rec_thresh, test_f1_thresh, _ = precision_recall_fscore_support(
    test_labels_list, test_preds_thresh, average='weighted', zero_division=0
)
test_cm_thresh = confusion_matrix(test_labels_list, test_preds_thresh, labels=all_class_labels)

# 3. Hybrid safety approach (with confidence fallback + gesture confirmation)
test_acc_hybrid = accuracy_score(test_labels_list, test_preds_hybrid)
test_prec_hybrid, test_rec_hybrid, test_f1_hybrid, _ = precision_recall_fscore_support(
    test_labels_list, test_preds_hybrid, average='weighted', zero_division=0
)
test_cm_hybrid = confusion_matrix(test_labels_list, test_preds_hybrid, labels=all_class_labels)

# Final results use the threshold rule; the hybrid rule is reported for
# comparison only.
test_acc = test_acc_thresh
test_prec = test_prec_thresh
test_rec = test_rec_thresh
test_f1 = test_f1_thresh
test_cm = test_cm_thresh
# Keep the test loss from evaluate() instead of zeroing it. evaluate() is never
# given best_threshold, so its loss cannot depend on the decision rule applied
# afterwards. Normalized to a Python float.
test_loss = float(test_loss)

# Calculate per-class metrics for all three approaches
prec_default, rec_default, f1_default, _ = precision_recall_fscore_support(
    test_labels_list, test_preds_default, labels=all_class_labels, average=None, zero_division=0
)
prec_thresh, rec_thresh, f1_thresh_class, _ = precision_recall_fscore_support(
    test_labels_list, test_preds_thresh, labels=all_class_labels, average=None, zero_division=0
)
prec_hybrid, rec_hybrid, f1_hybrid_class, _ = precision_recall_fscore_support(
    test_labels_list, test_preds_hybrid, labels=all_class_labels, average=None, zero_division=0
)

# Calculate confidence statistics
# (test_confidence is the per-window max probability from hybrid_safety_predictions.)
low_confidence_count = np.sum(test_confidence < 0.60)
low_confidence_pct = 100 * low_confidence_count / len(test_confidence)

# Relative change of the threshold rule versus argmax, or "n/a" when the argmax
# baseline is 0 (the percentage would be a division by zero -> inf/nan).
rest_recall_change = (
    f"{(rec_thresh[0] - rec_default[0]) / rec_default[0] * 100:+.1f}%" if rec_default[0] > 0 else "n/a"
)
f1_change = (
    f"{(test_f1_thresh - test_f1_default) / test_f1_default * 100:+.1f}%" if test_f1_default > 0 else "n/a"
)

print(f"    Comparison of all approaches:")
print(f"      Default (argmax, threshold=0.50):")
print(f"        Overall: Acc={test_acc_default:.4f}, F1={test_f1_default:.4f}")
print(f"        REST:    Prec={prec_default[0]:.4f}, Rec={rec_default[0]:.4f}, F1={f1_default[0]:.4f}")
print(f"        FIST:    Prec={prec_default[1]:.4f}, Rec={rec_default[1]:.4f}, F1={f1_default[1]:.4f}")
print(f"      Optimal threshold (threshold={best_threshold:.2f}):")
print(f"        Overall: Acc={test_acc_thresh:.4f}, F1={test_f1_thresh:.4f}")
print(f"        REST:    Prec={prec_thresh[0]:.4f}, Rec={rec_thresh[0]:.4f}, F1={f1_thresh_class[0]:.4f}")
print(f"        FIST:    Prec={prec_thresh[1]:.4f}, Rec={rec_thresh[1]:.4f}, F1={f1_thresh_class[1]:.4f}")
print(f"      HYBRID SAFETY (threshold={best_threshold:.2f} + confidence fallback + confirmation):")
print(f"        Overall: Acc={test_acc_hybrid:.4f}, F1={test_f1_hybrid:.4f}")
print(f"        REST:    Prec={prec_hybrid[0]:.4f}, Rec={rec_hybrid[0]:.4f}, F1={f1_hybrid_class[0]:.4f}")
print(f"        FIST:    Prec={prec_hybrid[1]:.4f}, Rec={rec_hybrid[1]:.4f}, F1={f1_hybrid_class[1]:.4f}")
print(f"        Low confidence samples: {low_confidence_count}/{len(test_confidence)} ({low_confidence_pct:.1f}%) → defaulted to REST")
print(f"\n    Comparison: Optimal threshold vs Default:")
print(f"      REST recall: {rec_default[0]:.4f} → {rec_thresh[0]:.4f} ({rest_recall_change})")
print(f"      Overall F1:   {test_f1_default:.4f} → {test_f1_thresh:.4f} ({f1_change})")
# Report the threshold actually used instead of a hard-coded "0.70".
print(f"\n    Note: Threshold {best_threshold:.2f} on P(FIST) used for deployment")
print(f"          (nominal value; no threshold tuning was performed).")

# Store results
# The argmax evaluation earlier in this cell already appended a record for this
# user. Replace any existing record(s) for test_user instead of appending a
# duplicate, so a single-user run keeps exactly one record. The slice assignment
# mutates the same list object.
louo_results[:] = [res for res in louo_results if res.get('test_user') != test_user]
louo_results.append({
    'test_user': test_user,
    'train_samples': len(train_dataset.sequences),
    'val_samples': len(validation_dataset.sequences),
    'test_samples': len(test_dataset.sequences),
    'best_val_f1': best_val_f1,
    'optimal_threshold': float(best_threshold),
    'test_accuracy': float(test_acc),  # threshold-rule metrics
    'test_precision': float(test_prec),
    'test_recall': float(test_rec),
    'test_f1': float(test_f1),
    'test_loss': float(test_loss),
    'confusion_matrix': test_cm.tolist(),  # threshold-rule confusion matrix
    'test_class_distribution': dict(zip(
        [test_dataset.label_encoder.inverse_transform([c])[0] for c in unique_test],
        counts_test.tolist()
    )),
    # Store all approach metrics for comparison
    'test_accuracy_default': float(test_acc_default),
    'test_f1_default': float(test_f1_default),
    'confusion_matrix_default': test_cm_default.tolist(),
    'test_accuracy_thresh': float(test_acc_thresh),
    'test_f1_thresh': float(test_f1_thresh),
    'confusion_matrix_thresh': test_cm_thresh.tolist(),
    'test_accuracy_hybrid': float(test_acc_hybrid),
    'test_f1_hybrid': float(test_f1_hybrid),
    'confusion_matrix_hybrid': test_cm_hybrid.tolist(),
    'hybrid_safety_config': {
        'confidence_threshold': 0.60,
        'rest_confirmation_windows': 2,
        'low_confidence_samples': int(low_confidence_count),
        'low_confidence_percentage': float(low_confidence_pct),
        # Describe the threshold actually used rather than a hard-coded 0.70.
        'note': f'Hybrid approach evaluated for comparison only; deployment uses the threshold rule (threshold={best_threshold:.2f}, nominal, not tuned)'
    },
    'deployment_approach': 'optimal_threshold',  # threshold rule, not hybrid
    'deployment_threshold': float(best_threshold)
})

# Calculate per-class test metrics for detailed analysis (using threshold-adjusted predictions)
from sklearn.metrics import precision_recall_fscore_support

# Scored against every class the model outputs (columns of test_probs) so index
# i always means class i, even when a class is missing from this user's
# labels/predictions.
test_prec_per_class, test_rec_per_class, test_f1_per_class, _ = precision_recall_fscore_support(
    test_labels_list, test_preds_thresh, labels=list(range(test_probs.shape[1])),
    average=None, zero_division=0
)
class_names = [test_dataset.label_encoder.inverse_transform([i])[0] for i in range(len(test_f1_per_class))]

# best_rest_recall is None when no threshold search was run (the current
# pipeline sets it to None); formatting None with :.4f raises TypeError.
val_rest_recall_text = (
    "n/a (no threshold tuning)" if best_rest_recall is None else f"{best_rest_recall:.4f}"
)

print(f"\n  ✓ User '{test_user}' Results (OPTIMAL THRESHOLD APPROACH):")
print(f"    Configuration:")
print(f"      - Threshold: {best_threshold:.2f} (safety-first selection)")
print(f"      - Selection priority: REST recall > FIST recall > Overall F1")
print(f"      - Validation REST recall: {val_rest_recall_text}")
print(f"\n    Test Performance:")
print(f"      Test Accuracy: {test_acc:.4f}")
print(f"      Test F1 (weighted): {test_f1:.4f}")
print(f"      Test Precision (weighted): {test_prec:.4f}")
print(f"      Test Recall (weighted): {test_rec:.4f}")
print(f"\n    Per-Class Test Performance:")
for i, class_name in enumerate(class_names):
    print(f"      {class_name}:")
    print(f"        Precision: {test_prec_per_class[i]:.4f}")
    print(f"        Recall: {test_rec_per_class[i]:.4f}")
    print(f"        F1: {test_f1_per_class[i]:.4f}")
print(f"\n    Confusion Matrix (threshold={best_threshold:.2f}):")
print(f"      {test_cm}")
print(f"      (Rows = True labels, Columns = Predicted labels)")
print(f"\n    Safety assessment:")
# Indices 0/1 are read as REST/FIST; guard like the argmax report does.
if len(test_rec_per_class) >= 2:
    print(f"      - REST recall: {test_rec_per_class[0]:.4f} ({test_rec_per_class[0]*100:.1f}% of REST gestures detected)")
    print(f"      - REST precision: {test_prec_per_class[0]:.4f} ({test_prec_per_class[0]*100:.1f}% precision - very few false stops)")
    print(f"      - FIST recall: {test_rec_per_class[1]:.4f} ({test_rec_per_class[1]*100:.1f}% of FIST gestures detected)")
print(f"      - Note: Threshold selected to maximize REST recall for safety-critical wheelchair control")

# =============================================================
# SINGLE USER TEST RESULTS SUMMARY
# =============================================================

print(f"\n{'='*80}")
print("SINGLE USER TEST RESULTS SUMMARY")
print(f"{'='*80}")

# Since we're testing on a single user, show that user's results.
# louo_results may contain several records for that one user (e.g. when an
# evaluation pass appends instead of replacing, or a cell is re-run). Requiring
# exactly one record made the summary silently print nothing in that case, so
# decide by the number of distinct users and show the most recent record.
recorded_users = {res['test_user'] for res in louo_results}
if len(recorded_users) == 1:
    r = louo_results[-1]
    print(f"\nTest User: {r['test_user']}")
    print(f"  Test Samples: {r['test_samples']}")
    print(f"  Training Samples: {r['train_samples']}")
    print(f"  Validation Samples: {r['val_samples']}")
    print(f"  Best Validation F1: {r['best_val_f1']:.4f}")
    print(f"  Optimal Threshold: {r.get('optimal_threshold', 0.5):.2f}")
    print(f"\n  Test Performance (with threshold={r.get('optimal_threshold', 0.5):.2f}):")
    print(f"    Accuracy: {r['test_accuracy']:.4f}")
    print(f"    F1 Score: {r['test_f1']:.4f}")
    print(f"    Precision: {r['test_precision']:.4f}")
    print(f"    Recall: {r['test_recall']:.4f}")
    print(f"    Confusion Matrix:")
    print(f"      {np.array(r['confusion_matrix'])}")
    if 'test_f1_default' in r:
        print(f"\n  Comparison (default threshold=0.50):")
        print(f"    Accuracy: {r['test_accuracy_default']:.4f}")
        print(f"    F1 Score: {r['test_f1_default']:.4f}")

    # For compatibility with summary code, set these values
    avg_test_acc = r['test_accuracy']
    avg_test_f1 = r['test_f1']
    avg_test_prec = r['test_precision']
    avg_test_rec = r['test_recall']
    std_test_f1 = 0.0  # Single result, no std dev
else:
    # Results for several users (or none): aggregate across all records.
    avg_test_acc = np.mean([r['test_accuracy'] for r in louo_results])
    avg_test_f1 = np.mean([r['test_f1'] for r in louo_results])
    avg_test_prec = np.mean([r['test_precision'] for r in louo_results])
    avg_test_rec = np.mean([r['test_recall'] for r in louo_results])
    std_test_f1 = np.std([r['test_f1'] for r in louo_results]) if len(louo_results) > 1 else 0.0



In [None]:
# =============================================================
# MODEL EXPORT FOR DEPLOYMENT
# =============================================================
# Export the trained model, threshold, and configuration for Raspberry Pi deployment
# =============================================================

import json
import pickle
from pathlib import Path

# Every deployment artifact written below goes into this folder.
deployment_dir = Path('deployment')
deployment_dir.mkdir(exist_ok=True)

_rule = '=' * 80
print(f"\n{_rule}")
print("EXPORTING MODEL FOR DEPLOYMENT")
print(_rule)

# Echo the hyper-parameters that the export records, so a mismatch with the
# training run is visible in the notebook output.
print("\n  Verifying model configuration:")
for _label, _value in (
    ("d_model", D_MODEL),
    ("hidden_size", HIDDEN_SIZE),
    ("num_layers", NUM_LAYERS),
    ("dropout", DROPOUT),
    ("sequence_length", SEQUENCE_LENGTH),
    ("bidirectional", BIDIRECTIONAL),
):
    print(f"    {_label}: {_value}")
print(f"    Total parameters: {total_params:,}")

# Deployment needs one fitted StandardScaler per feature stream; an absent key
# and a key mapped to None are both treated as missing.
required_scalers = ['raw', 'rms_lms', 'wiener_td', 'wiener_fft', 'imu', 'spectral_raw', 'spectral_wiener']
missing_scalers = [name for name in required_scalers if train_dataset.scalers.get(name) is None]

if missing_scalers:
    raise ValueError(f"❌ Missing StandardScalers: {missing_scalers}. Cannot deploy without complete normalization.")
print(f"\n  ✓ All StandardScalers present: {required_scalers}")

# 1. Save model state dict
# The checkpoint bundles the weights, the architecture hyper-parameters used
# to build the network, and the label mapping.
model_path = deployment_dir / 'emg_lstm_model.pt'
model_config = {
    'num_classes': train_dataset.num_classes,
    'd_model': D_MODEL,
    'hidden_size': HIDDEN_SIZE,
    'num_layers': NUM_LAYERS,
    'dropout': DROPOUT,
    'sequence_length': SEQUENCE_LENGTH,
    'bidirectional': BIDIRECTIONAL
}

# Copy every tensor to CPU before saving. A checkpoint taken from a model that
# lives on a GPU stores CUDA tensors, and a plain torch.load() on a CPU-only
# target such as the Raspberry Pi then fails unless map_location is passed.
# Values are replaced in place so the returned OrderedDict (and the _metadata
# that load_state_dict consults) is kept; .cpu() is a no-op for tensors that
# are already on the CPU, and the live `model` itself is not moved.
_cpu_state_dict = model.state_dict()
for _param_name in list(_cpu_state_dict):
    _value = _cpu_state_dict[_param_name]
    if isinstance(_value, torch.Tensor):
        _cpu_state_dict[_param_name] = _value.cpu()

# NOTE(review): 'label_encoder' is an arbitrary pickled Python object, so
# torch.load(..., weights_only=True) (the default in recent PyTorch releases)
# will refuse it; 'class_names' carries the same class order as plain strings.
torch.save({
    'model_state_dict': _cpu_state_dict,
    'model_config': model_config,
    'class_names': train_dataset.label_encoder.classes_.tolist(),
    'label_encoder': train_dataset.label_encoder
}, model_path)
print(f"\n✓ Model saved to: {model_path}")
print(f"  Model config: {model_config}")

# 2. Save deployment configuration
# NOTE(review): test_user is produced by earlier cells; if it is a NumPy scalar
# (e.g. np.int64 / np.str_ taken from an array or DataFrame column) the json
# module cannot encode it, so it is converted to the equivalent built-in value.
_test_user_json = test_user.item() if isinstance(test_user, np.generic) else test_user

deployment_config = {
    'optimal_threshold': float(best_threshold),  # nominal 0.50, not used to gate decisions
    'window_size': WINDOW_SIZE,
    'stride': STRIDE,
    'sequence_length': SEQUENCE_LENGTH,
    'sampling_rate': 500,  # EMG sampling rate
    'imu_sampling_rate': 100,  # IMU sampling rate
    'normalization_method': 'two-step',  # Per-window z-score + global StandardScaler
    'scaler_file': 'standard_scalers.pkl',  # StandardScalers for deployment
    'model_parameters': int(total_params),
    'performance_metrics': {
        'test_accuracy': float(test_acc),
        'test_f1': float(test_f1),
        'test_precision': float(test_prec),
        'test_recall': float(test_rec),
        # Per-class metrics: index 0 is reported as REST, index 1 as FIST.
        'rest_recall': float(test_rec_per_class[0]),
        'rest_precision': float(test_prec_per_class[0]),
        'fist_recall': float(test_rec_per_class[1]),
        'fist_precision': float(test_prec_per_class[1])
    },
    'confidence_metrics': {
        'validation_confidence_mean': float(val_confidence_mean),
        'validation_confidence_median': float(val_confidence_median),
        'validation_low_confidence_pct': float(val_low_confidence_pct),
        'test_confidence_mean': float(test_confidence_mean),
        'test_confidence_median': float(test_confidence_median),
        'test_low_confidence_pct': float(test_low_confidence_pct),
        # Coarse label: < 0.65 LOW, < 0.75 MODERATE, otherwise GOOD.
        'confidence_warning': 'LOW' if test_confidence_mean < 0.65 else ('MODERATE' if test_confidence_mean < 0.75 else 'GOOD')
    },
    'rest_baseline': {
        'validation_mean_confidence': float(val_rest_conf_mean),
        'validation_low_confidence_pct': float(val_rest_low_conf_pct),
        'test_mean_confidence': float(test_rest_conf_mean),
        'test_low_confidence_pct': float(test_rest_low_conf_pct),
        'note': 'For monitoring only; early-session REST can be compared to these values on-device.'
    },
    'deployment_date': str(pd.Timestamp.now()),
    'test_user': _test_user_json,
    'notes': 'Model trained with per-window + global normalization; deployment uses argmax over softmax (no tuned threshold).',
    'threshold_warning': 'Controller must not gate decisions on probability thresholds; use argmax and optional temporal smoothing only.'
}

config_path = deployment_dir / 'deployment_config.json'
# Serialize completely before opening the file: json.dump writes chunk by
# chunk, so an encoding error part-way through would otherwise leave a
# truncated, unparseable deployment_config.json in the package.
_config_text = json.dumps(deployment_config, indent=2)
with open(config_path, 'w') as f:
    f.write(_config_text)
print(f"✓ Configuration saved to: {config_path}")

# 2b. Persist the fitted StandardScalers that deployment applies as STEP 2 of
# the two-step normalization (per-window z-score first, then these scalers).
scaler_path = deployment_dir / 'standard_scalers.pkl'
scaler_data = {
    stream: train_dataset.scalers.get(stream)
    for stream in ('raw', 'rms_lms', 'wiener_td', 'wiener_fft', 'imu', 'spectral_raw', 'spectral_wiener')
}

# Refuse to write an incomplete scaler bundle.
for key, scaler in scaler_data.items():
    if scaler is not None:
        continue
    raise ValueError(f"❌ StandardScaler '{key}' is None. Cannot deploy without complete normalization.")

with scaler_path.open('wb') as f:
    pickle.dump(scaler_data, f)

print(f"\n✓ StandardScalers saved to: {scaler_path}")
print(f"  Scaler keys: {list(scaler_data)}")
print("  Note: These scalers are required for STEP 2 of two-step normalization in deployment")


# 3. Create model summary
# Lightweight index of the package: file locations, approximate weight size
# (assuming float32, i.e. 4 bytes per parameter), the recorded threshold and
# the test metrics copied from the deployment config.
_bytes_per_mb = 1024 * 1024
summary = dict(
    model_file=str(model_path),
    config_file=str(config_path),
    model_size_mb=total_params * 4 / _bytes_per_mb,
    threshold=float(best_threshold),
    performance=deployment_config['performance_metrics'],
)

summary_path = deployment_dir / 'model_summary.json'
with summary_path.open('w') as f:
    json.dump(summary, f, indent=2)
print(f"✓ Model summary saved to: {summary_path}")

# Final human-readable report of what was exported and what remains to be
# done on the device side; emitted as a single newline-joined print.
_rule = '=' * 80
_size_mb = total_params * 4 / (1024 * 1024)
_report = [
    f"\n{_rule}",
    "DEPLOYMENT PACKAGE READY",
    _rule,
    f"Deployment directory: {deployment_dir.absolute()}",
    "\nFiles created:",
    f"  1. {model_path.name} - Trained model weights (d_model={D_MODEL}, hidden_size={HIDDEN_SIZE}, layers={NUM_LAYERS})",
    f"  2. {config_path.name} - Deployment configuration (threshold={best_threshold:.2f})",
    f"  3. {scaler_path.name} - StandardScalers for two-step normalization (STEP 2)",
    f"  4. {summary_path.name} - Model summary",
    "\nDeployment package verification:",
    f"  ✓ Model architecture: {model_config}",
    f"  ✓ StandardScalers: {len(scaler_data)} scalers saved",
    f"  ✓ Optimal threshold: {best_threshold:.2f}",
    "  ✓ Normalization: Two-step (per-window z-score + global StandardScaler)",
    f"  ✓ Model parameters: {total_params:,} ({_size_mb:.2f} MB)",
    "\n⚠️  IMPORTANT: Update inference.py to load and apply StandardScalers!",
    "   The deployment code must apply both normalization steps:",
    "   STEP 1: Per-window z-score normalization (already implemented)",
    "   STEP 2: Global StandardScaler normalization (MUST be added to inference.py)",
    "\nNext steps:",
    "  1. Update inference.py to load and apply standard_scalers.pkl",
    "  2. Copy 'deployment' folder to Raspberry Pi",
    "  3. On Raspberry Pi: pip install -r requirements.txt",
    "  4. Run inference: python inference.py --help",
    "  5. Model is ready for edge deployment! 🚀",
    f"{_rule}\n",
]
print("\n".join(_report))
