# Sleep Stage Prediction - Ablation Study
## Comparing HR-only vs HR+Motion vs HR+Motion+Steps

This notebook allows you to toggle different data modalities to see their impact on prediction accuracy.

**Dataset**: [Motion and Heart Rate from Wrist-Worn Wearable](https://physionet.org/content/sleep-accel/1.0.0/)

## 1. Setup

In [None]:
!pip install -q optuna xgboost scikit-learn pandas numpy

In [None]:
import glob
import os
import warnings
from pathlib import Path
from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import optuna
import pandas as pd
import seaborn as sns
import xgboost as xgb
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from sklearn.metrics import f1_score, classification_report, confusion_matrix, accuracy_score
from sklearn.model_selection import GroupKFold
from sklearn.preprocessing import LabelEncoder

# NOTE(review): this silences every warning for the rest of the session,
# including numerical and deprecation warnings raised by the libraries below.
warnings.filterwarnings('ignore')

# Print library versions so that results can be tied to a specific environment.
print(f"Optuna version: {optuna.__version__}")
print(f"XGBoost version: {xgb.__version__}")

## 2. Ablation Study Configuration

**Toggle these flags to run different experiments:**

In [None]:
# ================================================================
# ABLATION STUDY CONFIGURATION
# ================================================================
# Every switch a reader is expected to tune lives in this cell; the rest of
# the notebook only reads these names.

# --- Data modalities -------------------------------------------------------
INCLUDE_HR = True        # Heart rate features (baseline)
INCLUDE_MOTION = True    # Accelerometer features (recommended!)
INCLUDE_STEPS = False    # Step count features (marginal benefit)

# --- Classification --------------------------------------------------------
NUM_CLASSES = 5          # 5 = full (Wake/N1/N2/N3/REM), 3 = simplified (Wake/NREM/REM)
USE_CLASS_WEIGHTS = False

# --- Optuna ----------------------------------------------------------------
N_TRIALS = 100

# Echo the active configuration so it is recorded in the cell output.
_rule = "=" * 50
print(
    _rule,
    "ABLATION STUDY CONFIG",
    _rule,
    f"  Heart Rate:  {'ON' if INCLUDE_HR else 'OFF'}",
    f"  Motion:      {'ON' if INCLUDE_MOTION else 'OFF'}",
    f"  Steps:       {'ON' if INCLUDE_STEPS else 'OFF'}",
    f"  Classes:     {NUM_CLASSES}",
    f"  Trials:      {N_TRIALS}",
    _rule,
    sep="\n",
)

## 3. Download Dataset

⚠️ **Warning**: The motion data is ~2 GB. This cell downloads only the modalities enabled in the configuration above (the labels are always downloaded).

In [None]:
# Local root for everything fetched from PhysioNet; created on first run.
DATA_DIR = Path("./sleep_data")
DATA_DIR.mkdir(exist_ok=True)

# Base URL of the dataset's file tree; each modality lives in a sub-directory.
DATASET_URL = "https://physionet.org/files/sleep-accel/1.0.0"

# The wget calls below mirror one remote sub-directory each into DATA_DIR
# (-r recursive, -np never ascend to the parent, -nH / --cut-dirs drop the
# host and leading path components, -P sets the local target directory).
# NOTE(review): -q suppresses all wget output, so a failed download is silent;
# the later loading cells are where a missing file would first be noticed.

# Always need labels
print("Downloading labels...")
!wget -q -r -np -nH --cut-dirs=3 -P {DATA_DIR} {DATASET_URL}/labels/

# Heart rate (small, ~4MB)
if INCLUDE_HR:
    print("Downloading heart rate data (~4MB)...")
    !wget -q -r -np -nH --cut-dirs=3 -P {DATA_DIR} {DATASET_URL}/heart_rate/

# Motion (large, ~2GB)
if INCLUDE_MOTION:
    print("Downloading motion data (~2GB - this will take a while)...")
    !wget -q -r -np -nH --cut-dirs=3 -P {DATA_DIR} {DATASET_URL}/motion/

# Steps (small)
if INCLUDE_STEPS:
    print("Downloading steps data...")
    !wget -q -r -np -nH --cut-dirs=3 -P {DATA_DIR} {DATASET_URL}/steps/

print("\nDownload complete!")

## 4. Paths and Constants

In [None]:
# One sub-directory of DATA_DIR per modality, plus the label files.
LABELS_DIR = DATA_DIR / "labels"
HEART_RATE_DIR = DATA_DIR / "heart_rate"
MOTION_DIR = DATA_DIR / "motion"
STEPS_DIR = DATA_DIR / "steps"

# Numeric stage code -> display name. Code 4 is intentionally absent:
# only the codes listed here are treated as valid stages downstream.
SLEEP_STAGE_NAMES = dict(zip(
    (0, 1, 2, 3, 5),
    ('Wake', 'N1', 'N2', 'N3', 'REM'),
))

# Length of one scored epoch, in seconds.
EPOCH_DURATION = 30

## 5. Data Loading Functions

In [None]:
def get_subject_ids() -> List[str]:
    """Return the sorted subject IDs found in the labels directory.

    A subject ID is the part of a ``<id>_labeled_sleep.txt`` file name that
    precedes the ``_labeled_sleep`` suffix.
    """
    pattern = str(LABELS_DIR / "*_labeled_sleep.txt")
    ids = []
    for label_path in glob.glob(pattern):
        ids.append(Path(label_path).stem.replace('_labeled_sleep', ''))
    ids.sort()
    return ids


def load_sleep_labels(subject_id: str) -> pd.DataFrame:
    """Read one subject's scored sleep stages.

    The label file is space-separated with no header. The result has the
    columns ``time_offset``, ``sleep_stage`` and a constant ``subject_id``.
    """
    labels = pd.read_csv(
        LABELS_DIR / f"{subject_id}_labeled_sleep.txt",
        sep=' ',
        header=None,
        names=['time_offset', 'sleep_stage'],
    )
    return labels.assign(subject_id=subject_id)


def load_heart_rate_data(subject_id: str) -> Optional[pd.DataFrame]:
    """Load heart rate data for a subject.

    Args:
        subject_id: Subject identifier used as the file-name prefix.

    Returns:
        A frame with columns ``timestamp`` and ``heart_rate`` read from the
        comma-separated, header-less ``<id>_heartrate.txt`` file, or ``None``
        when that file does not exist.
    """
    hr_file = HEART_RATE_DIR / f"{subject_id}_heartrate.txt"
    if not hr_file.exists():
        # Callers treat None as "modality unavailable for this subject".
        return None
    return pd.read_csv(hr_file, header=None, names=['timestamp', 'heart_rate'])


def load_motion_data(subject_id: str) -> Optional[pd.DataFrame]:
    """Load accelerometer data for a subject.

    Args:
        subject_id: Subject identifier used as the file-name prefix.

    Returns:
        A frame with columns ``timestamp``, ``accel_x``, ``accel_y`` and
        ``accel_z`` read from the space-separated, header-less
        ``<id>_acceleration.txt`` file, or ``None`` when that file does not
        exist.
    """
    motion_file = MOTION_DIR / f"{subject_id}_acceleration.txt"
    if not motion_file.exists():
        # Callers treat None as "modality unavailable for this subject".
        return None
    # Motion data is space-separated: timestamp, x, y, z
    return pd.read_csv(
        motion_file,
        sep=' ',
        header=None,
        names=['timestamp', 'accel_x', 'accel_y', 'accel_z'],
    )


def load_steps_data(subject_id: str) -> Optional[pd.DataFrame]:
    """Load step count data for a subject.

    Args:
        subject_id: Subject identifier used as the file-name prefix.

    Returns:
        A frame with columns ``timestamp`` and ``steps`` read from the
        comma-separated, header-less ``<id>_steps.txt`` file, or ``None``
        when that file does not exist.
    """
    steps_file = STEPS_DIR / f"{subject_id}_steps.txt"
    if not steps_file.exists():
        # Callers treat None as "modality unavailable for this subject".
        return None
    return pd.read_csv(steps_file, header=None, names=['timestamp', 'steps'])


# Sanity check that the label files are present: 0 subjects here means the
# labels directory is empty or missing.
subject_ids = get_subject_ids()
print(f"Found {len(subject_ids)} subjects")

## 6. Feature Extraction Functions

In [None]:
def extract_hr_features(hr_values: np.ndarray, timestamps: np.ndarray) -> Dict[str, float]:
    """Summarise the heart-rate samples that fall inside one epoch.

    Args:
        hr_values: Heart-rate samples for the epoch.
        timestamps: Sample times matching ``hr_values`` one-to-one.

    Returns:
        A dict of ``hr_``-prefixed features: mean, std, min, max, range,
        median, count, cv (std / mean), iqr, skew, rmssd (root mean square of
        successive differences), pnn50 (fraction of successive differences
        larger than 5 in absolute value) and slope (linear trend of heart rate
        over time). With fewer than two samples every feature is NaN except
        ``hr_count``, which is the number of samples received.
    """
    prefix = 'hr_'
    hr_values = np.asarray(hr_values)
    timestamps = np.asarray(timestamps)

    if len(hr_values) < 2:
        # Key order matches the original so downstream column order is unchanged.
        return {
            f'{prefix}mean': np.nan, f'{prefix}std': np.nan, f'{prefix}min': np.nan,
            f'{prefix}max': np.nan, f'{prefix}range': np.nan, f'{prefix}median': np.nan,
            f'{prefix}rmssd': np.nan, f'{prefix}pnn50': np.nan, f'{prefix}slope': np.nan,
            f'{prefix}count': len(hr_values), f'{prefix}cv': np.nan,
            f'{prefix}skew': np.nan, f'{prefix}iqr': np.nan
        }

    features = {}
    mean = np.mean(hr_values)
    std = np.std(hr_values)

    features[f'{prefix}mean'] = mean
    features[f'{prefix}std'] = std
    features[f'{prefix}min'] = np.min(hr_values)
    features[f'{prefix}max'] = np.max(hr_values)
    features[f'{prefix}range'] = features[f'{prefix}max'] - features[f'{prefix}min']
    features[f'{prefix}median'] = np.median(hr_values)
    features[f'{prefix}count'] = len(hr_values)

    # Coefficient of variation is undefined for a non-positive mean.
    features[f'{prefix}cv'] = std / mean if mean > 0 else np.nan

    q75, q25 = np.percentile(hr_values, [75, 25])
    features[f'{prefix}iqr'] = q75 - q25

    # Skewness is undefined for a constant signal (std == 0).
    if std > 0:
        features[f'{prefix}skew'] = np.mean(((hr_values - mean) / std) ** 3)
    else:
        features[f'{prefix}skew'] = np.nan

    # At least two samples are guaranteed here, so hr_diff is never empty.
    hr_diff = np.diff(hr_values)
    features[f'{prefix}rmssd'] = np.sqrt(np.mean(hr_diff ** 2))
    features[f'{prefix}pnn50'] = np.sum(np.abs(hr_diff) > 5) / len(hr_diff)

    if len(timestamps) >= 2:
        try:
            slope, _ = np.polyfit(timestamps - timestamps[0], hr_values, 1)
            features[f'{prefix}slope'] = slope
        except (TypeError, ValueError, np.linalg.LinAlgError):
            # Only fit failures are tolerated (mismatched lengths, bad values,
            # non-converging least squares); KeyboardInterrupt and
            # SystemExit now propagate instead of being swallowed.
            features[f'{prefix}slope'] = np.nan
    else:
        features[f'{prefix}slope'] = np.nan

    return features

In [None]:
def extract_motion_features(accel_x: np.ndarray, accel_y: np.ndarray,
                            accel_z: np.ndarray,
                            still_threshold: float = 0.01) -> Dict[str, float]:
    """
    Extract motion features from accelerometer data within an epoch.

    These features are critical for distinguishing Wake from REM:
    - Wake: movement present
    - REM: muscle atonia (paralysis), no movement

    Args:
        accel_x, accel_y, accel_z: Per-axis acceleration samples of equal length.
        still_threshold: A sample-to-sample change in magnitude below this
            value counts as "still". It is expressed in the sensor's units;
            the default of 0.01 keeps the previous hard-coded behaviour.

    Returns:
        A dict of ``motion_``-prefixed features. With fewer than two samples
        every feature is NaN except ``motion_count``, which is the number of
        samples received.
    """
    prefix = 'motion_'

    if len(accel_x) < 2:
        # Key order matches the original so downstream column order is unchanged.
        return {
            f'{prefix}mag_mean': np.nan, f'{prefix}mag_std': np.nan,
            f'{prefix}mag_max': np.nan, f'{prefix}mag_min': np.nan,
            f'{prefix}activity_count': np.nan, f'{prefix}pct_still': np.nan,
            f'{prefix}zero_crossings': np.nan, f'{prefix}entropy': np.nan,
            f'{prefix}x_std': np.nan, f'{prefix}y_std': np.nan, f'{prefix}z_std': np.nan,
            f'{prefix}count': len(accel_x)
        }

    features = {}

    # Compute magnitude: sqrt(x² + y² + z²)
    magnitude = np.sqrt(accel_x**2 + accel_y**2 + accel_z**2)

    # Basic magnitude stats
    features[f'{prefix}mag_mean'] = np.mean(magnitude)
    features[f'{prefix}mag_std'] = np.std(magnitude)
    features[f'{prefix}mag_max'] = np.max(magnitude)
    features[f'{prefix}mag_min'] = np.min(magnitude)
    features[f'{prefix}count'] = len(magnitude)

    # Activity count (classic actigraphy metric):
    # sum of absolute sample-to-sample differences in magnitude.
    mag_diff = np.abs(np.diff(magnitude))
    features[f'{prefix}activity_count'] = np.sum(mag_diff)

    # Fraction of sample-to-sample steps that count as "still".
    features[f'{prefix}pct_still'] = np.mean(mag_diff < still_threshold)

    # Zero crossings (movement frequency indicator):
    # sign changes of the de-meaned magnitude, normalised by sample count.
    mag_centered = magnitude - np.mean(magnitude)
    zero_crossings = np.sum(np.abs(np.diff(np.sign(mag_centered))) > 0)
    features[f'{prefix}zero_crossings'] = zero_crossings / len(magnitude)

    # Shannon entropy of the magnitude distribution over 10 equal-width bins.
    # Bin counts are normalised to probabilities that sum to 1. The previous
    # density=True histogram gave values that depended on the bin width and
    # could be negative, so it was not an entropy.
    counts, _ = np.histogram(magnitude, bins=10)
    probs = counts[counts > 0] / counts.sum()
    features[f'{prefix}entropy'] = -np.sum(probs * np.log(probs))

    # Per-axis variability
    features[f'{prefix}x_std'] = np.std(accel_x)
    features[f'{prefix}y_std'] = np.std(accel_y)
    features[f'{prefix}z_std'] = np.std(accel_z)

    return features

In [None]:
def extract_steps_features(steps_df: pd.DataFrame, epoch_start: float,
                           epoch_end: float,
                           window_size: float = 600.0) -> Dict[str, float]:
    """
    Estimate the steps taken during one epoch from coarse step-count windows.

    Each row of ``steps_df`` is treated as a window that starts at
    ``timestamp`` and lasts ``window_size`` seconds, with its steps spread
    evenly over the window. NOTE(review): that ``timestamp`` marks the window
    *start* and that windows are 10 minutes long are assumptions carried over
    from the original code - confirm against the dataset description.

    Each window contributes ``steps * overlap / window_size``, where
    ``overlap`` is the number of seconds it shares with the epoch. The previous
    version added the full count of every touching window, so an epoch on a
    window boundary was counted twice.

    Args:
        steps_df: Frame with ``timestamp`` and ``steps`` columns, or ``None``.
        epoch_start: Epoch start time, in the same units as ``timestamp``.
        epoch_end: Epoch end time (exclusive).
        window_size: Duration of one step-count window in seconds.

    Returns:
        ``steps_count`` (estimated steps in the epoch) and ``steps_rate``
        (steps per minute). Both are NaN when no step data is available and 0
        when no window overlaps the epoch.
    """
    prefix = 'steps_'

    if steps_df is None or len(steps_df) == 0:
        return {f'{prefix}count': np.nan, f'{prefix}rate': np.nan}

    epoch_length = epoch_end - epoch_start
    if epoch_length <= 0:
        # A zero-length epoch cannot contain any steps.
        return {f'{prefix}count': 0.0, f'{prefix}rate': 0.0}

    window_start = steps_df['timestamp'].to_numpy(dtype=float)
    window_steps = steps_df['steps'].to_numpy(dtype=float)

    # Seconds shared between each window and the epoch; negative values mean
    # the two intervals are disjoint and are clipped to zero.
    overlap = (np.minimum(window_start + window_size, epoch_end)
               - np.maximum(window_start, epoch_start))
    overlap = np.clip(overlap, 0.0, None)

    # nansum mirrors pandas' default of skipping missing step counts.
    count = float(np.nansum(window_steps * (overlap / window_size)))

    return {
        f'{prefix}count': count,
        f'{prefix}rate': count / epoch_length * 60,  # steps/min
    }

## 7. Epoch Alignment and Feature Extraction

In [None]:
def process_subject(subject_id: str) -> pd.DataFrame:
    """
    Build the per-epoch feature table for one subject.

    Every row of the subject's label file becomes one output row covering
    ``[time_offset, time_offset + EPOCH_DURATION)``. Features are added only
    for modalities that are enabled in the config; heart-rate and motion
    features additionally require the subject's file to exist.
    """
    labels_df = load_sleep_labels(subject_id)

    # Disabled modalities are never read from disk.
    hr_df = load_heart_rate_data(subject_id) if INCLUDE_HR else None
    motion_df = load_motion_data(subject_id) if INCLUDE_MOTION else None
    steps_df = load_steps_data(subject_id) if INCLUDE_STEPS else None

    have_hr = INCLUDE_HR and hr_df is not None
    have_motion = INCLUDE_MOTION and motion_df is not None

    def _samples_in(frame: pd.DataFrame, start, end) -> pd.DataFrame:
        # Half-open window: start inclusive, end exclusive.
        ts = frame['timestamp']
        return frame[(ts >= start) & (ts < end)]

    records = []
    for _, label_row in labels_df.iterrows():
        start = label_row['time_offset']
        end = start + EPOCH_DURATION

        record = {
            'time_offset': start,
            'sleep_stage': label_row['sleep_stage'],
            'subject_id': subject_id
        }

        # An empty window yields empty arrays, for which the extractors
        # return their all-NaN feature set, so no separate branch is needed.
        if have_hr:
            hr_window = _samples_in(hr_df, start, end)
            record.update(extract_hr_features(
                hr_window['heart_rate'].values,
                hr_window['timestamp'].values
            ))

        if have_motion:
            motion_window = _samples_in(motion_df, start, end)
            record.update(extract_motion_features(
                motion_window['accel_x'].values,
                motion_window['accel_y'].values,
                motion_window['accel_z'].values
            ))

        # The steps extractor handles a missing frame itself.
        if INCLUDE_STEPS:
            record.update(extract_steps_features(steps_df, start, end))

        records.append(record)

    return pd.DataFrame(records)

## 8. Prepare Complete Dataset

In [None]:
def prepare_dataset() -> pd.DataFrame:
    """Load every subject, filter unusable epochs and add temporal context.

    Steps:
      1. Run ``process_subject`` for each subject; a failing subject is
         reported and skipped.
      2. Keep only epochs whose stage code is one of 0, 1, 2, 3, 5.
      3. Drop epochs with too few heart-rate (< 2) or motion (< 10) samples.
      4. Add rolling / lag / lead / diff features per subject.
      5. Drop epochs whose temporal features are undefined (start and end of
         each recording).

    Only the temporal columns are checked in step 5. Epochs that are merely
    missing a per-epoch statistic (for example ``hr_skew`` for a constant
    signal) are kept and left to the downstream NaN handling. The previous
    blanket ``dropna()`` discarded them too.

    NOTE(review): the temporal features are computed after filtering, so a lag
    of one row is not always exactly one epoch.

    Returns:
        The combined per-epoch feature table.

    Raises:
        ValueError: If no subject could be processed.
        KeyError: If an enabled modality produced no feature columns.
    """
    subject_ids = get_subject_ids()
    print(f"Processing {len(subject_ids)} subjects...")

    all_features = []

    for i, subject_id in enumerate(subject_ids):
        print(f"  [{i+1}/{len(subject_ids)}] {subject_id}", end="")

        try:
            features_df = process_subject(subject_id)
        except Exception as e:
            # Best effort: one bad subject must not abort the whole run.
            print(f" - ERROR: {e}")
            continue
        all_features.append(features_df)
        print(f" - {len(features_df)} epochs")

    if not all_features:
        # pd.concat would fail with an opaque "No objects to concatenate".
        raise ValueError(
            f"No subjects could be processed - check that the dataset was downloaded to {LABELS_DIR.parent}"
        )

    dataset = pd.concat(all_features, ignore_index=True)

    # Filter invalid sleep stages
    valid_stages = [0, 1, 2, 3, 5]
    dataset = dataset[dataset['sleep_stage'].isin(valid_stages)].copy()

    # Fail early with a clear message if an enabled modality produced nothing.
    required_cols = []
    if INCLUDE_HR:
        required_cols += ['hr_count', 'hr_mean', 'hr_std']
    if INCLUDE_MOTION:
        required_cols += ['motion_count', 'motion_mag_mean', 'motion_activity_count']
    missing_cols = [c for c in required_cols if c not in dataset.columns]
    if missing_cols:
        raise KeyError(
            f"Expected feature columns are missing (was the modality downloaded?): {missing_cols}"
        )

    # Drop rows with too few raw samples for reliable statistics
    if INCLUDE_HR:
        dataset = dataset[dataset['hr_count'] >= 2].copy()
        dataset = dataset.dropna(subset=['hr_mean', 'hr_std'])
    if INCLUDE_MOTION:
        dataset = dataset[dataset['motion_count'] >= 10].copy()
        dataset = dataset.dropna(subset=['motion_mag_mean'])

    # Add temporal context features
    print("\nAdding temporal context features...")
    dataset = dataset.sort_values(['subject_id', 'time_offset']).reset_index(drop=True)

    temporal_cols = []  # names of the columns added below

    # HR temporal features
    if INCLUDE_HR:
        dataset['hr_mean_roll_5'] = dataset.groupby('subject_id')['hr_mean'].transform(
            lambda x: x.rolling(window=5, center=True, min_periods=1).mean()
        )
        temporal_cols.append('hr_mean_roll_5')
        for lag in [1, 2, 4]:
            dataset[f'hr_mean_lag_{lag}'] = dataset.groupby('subject_id')['hr_mean'].shift(lag)
            temporal_cols.append(f'hr_mean_lag_{lag}')
        for lead in [1, 2]:
            dataset[f'hr_mean_lead_{lead}'] = dataset.groupby('subject_id')['hr_mean'].shift(-lead)
            temporal_cols.append(f'hr_mean_lead_{lead}')
        dataset['hr_diff_1'] = dataset['hr_mean'] - dataset.groupby('subject_id')['hr_mean'].shift(1)
        temporal_cols.append('hr_diff_1')

    # Motion temporal features
    if INCLUDE_MOTION:
        dataset['motion_mag_roll_5'] = dataset.groupby('subject_id')['motion_mag_mean'].transform(
            lambda x: x.rolling(window=5, center=True, min_periods=1).mean()
        )
        dataset['motion_activity_roll_5'] = dataset.groupby('subject_id')['motion_activity_count'].transform(
            lambda x: x.rolling(window=5, center=True, min_periods=1).mean()
        )
        temporal_cols += ['motion_mag_roll_5', 'motion_activity_roll_5']
        for lag in [1, 2]:
            dataset[f'motion_activity_lag_{lag}'] = dataset.groupby('subject_id')['motion_activity_count'].shift(lag)
            temporal_cols.append(f'motion_activity_lag_{lag}')
        dataset['motion_activity_diff_1'] = dataset['motion_activity_count'] - dataset.groupby('subject_id')['motion_activity_count'].shift(1)
        temporal_cols.append('motion_activity_diff_1')

    # Drop NaNs from temporal features (and only from those)
    if temporal_cols:
        dataset = dataset.dropna(subset=temporal_cols)

    total_epochs = len(dataset)
    print(f"\nTotal epochs: {total_epochs}")
    print(f"\nClass distribution:")
    for stage, name in SLEEP_STAGE_NAMES.items():
        count = (dataset['sleep_stage'] == stage).sum()
        # Guard against an empty dataset so the summary never divides by zero.
        pct = count / total_epochs * 100 if total_epochs else 0.0
        print(f"  {name}: {count} ({pct:.1f}%)")

    return dataset


# Load dataset
# This reads and windows every subject's raw files, so it is the slow step of
# the notebook when motion data is enabled.
dataset = prepare_dataset()

In [None]:
# Show features
# A column is a model input if its name starts with one of the modality
# prefixes; time_offset, sleep_stage and subject_id are therefore excluded.
_feature_prefixes = ('hr_', 'motion_', 'steps_')
feature_cols = [name for name in dataset.columns if name.startswith(_feature_prefixes)]

print(f"Total features: {len(feature_cols)}")
print(f"\nFeature columns:")
for name in sorted(feature_cols):
    print(f"  - {name}")

## 9. Optuna Optimization

In [None]:
def create_objective(X, y, groups, n_folds=5, use_class_weights=False, device='cuda'):
    """Create the Optuna objective for tuning an XGBoost classifier.

    The returned objective samples one hyperparameter set per trial, evaluates
    it with ``GroupKFold`` (no group appears in both train and validation) and
    returns the mean macro-F1 over the folds. The running mean is reported
    after every fold so the study's pruner can stop weak trials early.

    Args:
        X: Feature matrix, indexable by integer arrays along the first axis.
        y: Integer-encoded class labels aligned with ``X``.
        groups: Group label per row (here: subject), used for the CV split.
        n_folds: Number of GroupKFold splits.
        use_class_weights: If True, weight each training sample by
            ``len(y) / (n_classes * count_of_its_class)``.
        device: Value passed to XGBoost's ``device`` parameter. Defaults to
            ``'cuda'`` (the previous hard-coded value); pass ``'cpu'`` to run
            without a GPU.

    Returns:
        A callable ``objective(trial) -> float`` for ``study.optimize``.
    """
    # Class statistics are loop-invariant, so compute them once up front.
    classes, class_index, class_counts = np.unique(
        y, return_inverse=True, return_counts=True
    )
    n_classes = len(classes)

    sample_weights = None
    if use_class_weights:
        per_class_weight = len(y) / (n_classes * class_counts)
        sample_weights = per_class_weight[np.ravel(class_index)]

    def objective(trial):
        params = {
            'objective': 'multi:softmax',
            'num_class': n_classes,
            'eval_metric': 'mlogloss',
            'booster': 'gbtree',
            'device': device,
            'lambda': trial.suggest_float('lambda', 1e-8, 10.0, log=True),
            'alpha': trial.suggest_float('alpha', 1e-8, 10.0, log=True),
            'max_depth': trial.suggest_int('max_depth', 3, 10),
            'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.3, log=True),
            'n_estimators': trial.suggest_int('n_estimators', 50, 300),
            'min_child_weight': trial.suggest_int('min_child_weight', 1, 10),
            'subsample': trial.suggest_float('subsample', 0.5, 1.0),
            'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
            'gamma': trial.suggest_float('gamma', 1e-8, 1.0, log=True),
            'random_state': 42,
            'verbosity': 0
        }

        cv = GroupKFold(n_splits=n_folds)
        f1_scores = []

        for fold_idx, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            weights_train = sample_weights[train_idx] if sample_weights is not None else None

            model = xgb.XGBClassifier(**params)
            model.fit(X_train, y_train, sample_weight=weights_train,
                     eval_set=[(X_val, y_val)], verbose=False)

            y_pred = model.predict(X_val)
            f1_scores.append(f1_score(y_val, y_pred, average='macro'))

            # Report the running mean so the pruner compares like with like.
            trial.report(np.mean(f1_scores), fold_idx)
            if trial.should_prune():
                raise optuna.TrialPruned()

        return np.mean(f1_scores)

    return objective

In [None]:
# Prepare features and labels
X = dataset[feature_cols].to_numpy()

stage_codes = dataset['sleep_stage']
if NUM_CLASSES == 3:
    print("*** Using 3-class mode: Wake / NREM / REM ***")
    # N1 (1) and N3 (3) are folded into the N2 code (2), which then stands for NREM.
    y_raw = stage_codes.replace({1: 2, 3: 2}).values
    SLEEP_STAGE_NAMES_USED = {0: 'Wake', 2: 'NREM', 5: 'REM'}
else:
    y_raw = stage_codes.values
    SLEEP_STAGE_NAMES_USED = SLEEP_STAGE_NAMES

# Map the raw stage codes to consecutive integers 0..k-1.
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(y_raw)

# One integer group per subject, used to keep subjects within a single CV fold.
subject_encoder = LabelEncoder()
groups = subject_encoder.fit_transform(dataset['subject_id'].values)

# Fill NaN: each column's missing entries take that column's median,
# or 0 when the whole column is missing.
for col_idx in range(X.shape[1]):
    column = X[:, col_idx]
    median_value = np.nanmedian(column)
    fill_value = 0 if np.isnan(median_value) else median_value
    X[np.isnan(column), col_idx] = fill_value

print(f"Features: {X.shape}")
print(f"Classes: {[SLEEP_STAGE_NAMES_USED.get(c, c) for c in label_encoder.classes_]}")

In [None]:
# Run optimization
# Seeded TPE sampling; the median pruner only starts after 10 complete trials
# and never prunes before the third fold of a trial has been reported.
tpe_sampler = TPESampler(seed=42)
median_pruner = MedianPruner(n_startup_trials=10, n_warmup_steps=2)
study = optuna.create_study(
    direction='maximize',
    sampler=tpe_sampler,
    pruner=median_pruner,
)

objective = create_objective(X, y, groups, n_folds=5, use_class_weights=USE_CLASS_WEIGHTS)

study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True)

best_trial = study.best_trial
print(f"\n{'='*50}")
print(f"Best trial:")
print(f"  Macro F1: {best_trial.value:.4f}")
print(f"  Params: {best_trial.params}")

## 10. Final Evaluation

In [None]:
# Final evaluation
# Re-run grouped cross-validation with the best hyperparameters and pool the
# out-of-fold predictions for reporting.
# NOTE(review): these are the same folds that were used to select the
# hyperparameters, so the scores below are likely optimistic - a held-out
# subject set or nested CV would give an unbiased estimate.

# The parameter set does not change between folds, so build it once.
params = study.best_params.copy()
params['objective'] = 'multi:softmax'
params['num_class'] = len(np.unique(y))
params['random_state'] = 42
params['verbosity'] = 0

# Mirror the tuning setup: if class weighting was used to select the
# hyperparameters, the evaluated models must be trained with it as well.
if USE_CLASS_WEIGHTS:
    _classes, _class_index, _class_counts = np.unique(
        y, return_inverse=True, return_counts=True
    )
    eval_sample_weights = (len(y) / (len(_classes) * _class_counts))[np.ravel(_class_index)]
else:
    eval_sample_weights = None

cv = GroupKFold(n_splits=5)
all_y_true = []
all_y_pred = []

for train_idx, val_idx in cv.split(X, y, groups):
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]
    weights_train = eval_sample_weights[train_idx] if eval_sample_weights is not None else None

    model = xgb.XGBClassifier(**params)
    model.fit(X_train, y_train, sample_weight=weights_train)

    y_pred = model.predict(X_val)
    all_y_true.extend(y_val)
    all_y_pred.extend(y_pred)

stage_names = [SLEEP_STAGE_NAMES_USED.get(s, str(s)) for s in label_encoder.classes_]
# Encoded labels are 0..k-1; passing them explicitly keeps the report aligned
# with stage_names even if a class never shows up in the pooled results.
class_labels = np.arange(len(stage_names))

print("="*50)
print("ABLATION STUDY RESULTS")
print("="*50)
print(f"Data: HR={'ON' if INCLUDE_HR else 'OFF'}, Motion={'ON' if INCLUDE_MOTION else 'OFF'}, Steps={'ON' if INCLUDE_STEPS else 'OFF'}")
print(f"Classes: {NUM_CLASSES}")
print("="*50)
print(f"\nOverall Accuracy: {accuracy_score(all_y_true, all_y_pred):.4f}")
print(f"Macro F1 Score:   {f1_score(all_y_true, all_y_pred, average='macro'):.4f}")
print("\nClassification Report:")
print(classification_report(all_y_true, all_y_pred, labels=class_labels, target_names=stage_names))

In [None]:
# Confusion matrix
# Fixed labels keep the matrix square and aligned with stage_names even when
# a class is absent from the pooled predictions.
cm = confusion_matrix(all_y_true, all_y_pred, labels=np.arange(len(stage_names)))

# Row-normalise (each true class sums to 1). A class with no true samples
# keeps an all-zero row instead of producing NaNs from a division by zero.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm, row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals > 0,
)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=stage_names, yticklabels=stage_names, ax=axes[0])
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')
axes[0].set_title('Confusion Matrix (Counts)')

sns.heatmap(cm_normalized, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=stage_names, yticklabels=stage_names, ax=axes[1])
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')
axes[1].set_title('Confusion Matrix (Normalized)')

plt.suptitle(f"HR={'ON' if INCLUDE_HR else 'OFF'}, Motion={'ON' if INCLUDE_MOTION else 'OFF'}, Steps={'ON' if INCLUDE_STEPS else 'OFF'}")
plt.tight_layout()
plt.show()

## 11. Feature Importance

In [None]:
# Train final model
# Fit once on all data with the best hyperparameters; this model is used for
# the importance plot and is the one saved to disk.
final_params = study.best_params.copy()
final_params['objective'] = 'multi:softmax'
final_params['num_class'] = len(np.unique(y))
final_params['random_state'] = 42
final_params['verbosity'] = 0

# Mirror the tuning setup: apply the same class weighting that was used while
# selecting the hyperparameters.
if USE_CLASS_WEIGHTS:
    _classes, _class_index, _class_counts = np.unique(
        y, return_inverse=True, return_counts=True
    )
    final_sample_weights = (len(y) / (len(_classes) * _class_counts))[np.ravel(_class_index)]
else:
    final_sample_weights = None

final_model = xgb.XGBClassifier(**final_params)
final_model.fit(X, y, sample_weight=final_sample_weights)

# Plot importance
importance = final_model.feature_importances_
sorted_idx = np.argsort(importance)[::-1]  # most important first

plt.figure(figsize=(14, 8))
top_n = min(25, len(importance))
# Reverse so the most important feature is drawn at the top of the chart.
plt.barh(range(top_n), importance[sorted_idx[:top_n]][::-1])
plt.yticks(range(top_n), [feature_cols[i] for i in sorted_idx[:top_n]][::-1])
plt.xlabel('Importance')
plt.title(f'Top {top_n} Feature Importance')
plt.tight_layout()
plt.show()

print("\nTop 15 features:")
for i in sorted_idx[:15]:
    print(f"  {feature_cols[i]}: {importance[i]:.4f}")

## 12. Save Results

In [None]:
# Tag every output file with the active configuration so that runs with
# different modality settings do not overwrite each other.
config_str = f"hr{int(INCLUDE_HR)}_motion{int(INCLUDE_MOTION)}_steps{int(INCLUDE_STEPS)}_{NUM_CLASSES}class"

# Best hyperparameters, the fitted model and the full trial history.
best_params_frame = pd.DataFrame([study.best_params])
best_params_frame.to_csv(f'best_params_{config_str}.csv', index=False)
final_model.save_model(f'model_{config_str}.json')
trials_frame = study.trials_dataframe()
trials_frame.to_csv(f'trials_{config_str}.csv', index=False)

# One-row summary of this run, taken from the pooled out-of-fold predictions.
summary = dict(
    config=config_str,
    include_hr=INCLUDE_HR,
    include_motion=INCLUDE_MOTION,
    include_steps=INCLUDE_STEPS,
    num_classes=NUM_CLASSES,
    accuracy=accuracy_score(all_y_true, all_y_pred),
    macro_f1=f1_score(all_y_true, all_y_pred, average='macro'),
    n_features=len(feature_cols),
    n_samples=len(y),
)
summary_frame = pd.DataFrame([summary])
summary_frame.to_csv(f'summary_{config_str}.csv', index=False)

print(f"Results saved with prefix: {config_str}")

In [None]:
# Download (Colab)
# Outside Colab the google.colab package is not importable; the files then
# simply stay in the working directory.
try:
    from google.colab import files
except ImportError:
    print("Files saved to current directory")
else:
    try:
        files.download(f'summary_{config_str}.csv')
        files.download(f'model_{config_str}.json')
    except Exception as exc:
        # Still best effort, but unlike a bare `except:` this lets
        # KeyboardInterrupt / SystemExit through and reports the reason.
        print(f"Files saved to current directory (download failed: {exc})")