# Sleep Stage Prediction using Heart Rate Data
## Optuna + XGBoost Hyperparameter Optimization

This notebook predicts sleep stages (Wake, N1, N2, N3, REM) from heart rate data using:
- **XGBoost** for classification
- **Optuna** for hyperparameter tuning
- **HRV features** extracted from heart rate time series

**Dataset**: [Motion and heart rate from a wrist-worn wearable and labeled sleep from polysomnography](https://physionet.org/content/sleep-accel/1.0.0/) (PhysioNet, `sleep-accel` v1.0.0)

## 1. Setup and Installation

In [None]:
# Install required packages
!pip install -q optuna xgboost scikit-learn pandas numpy

In [None]:
import os
import glob
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Tuple, Dict, List, Optional

import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner

import xgboost as xgb
from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.metrics import f1_score, classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder

# Silence every warning for the rest of the session to keep cell output short.
# NOTE: this also hides numerical and deprecation warnings from the libraries used below.
warnings.filterwarnings('ignore')

# Record the library versions in the notebook output.
print(f"Optuna version: {optuna.__version__}")
print(f"XGBoost version: {xgb.__version__}")

## 2. Download Dataset from PhysioNet

The dataset is hosted on PhysioNet. We download only the `heart_rate` and `labels` folders, not the motion data, which is 2 GB+.

In [None]:
# Create data directory
# (exist_ok=True makes re-running this cell safe when the folder already exists)
DATA_DIR = Path("./sleep_data")
DATA_DIR.mkdir(exist_ok=True)

# Download dataset using wget
# Note: PhysioNet requires accepting data use agreement for some datasets
DATASET_URL = "https://physionet.org/files/sleep-accel/1.0.0"

# wget flags used below:
#   -q            quiet (no progress output)
#   -r -np        recurse into the folder but never ascend to its parent
#   -nH           do not create a directory named after the host
#   --cut-dirs=3  drop the leading "files/sleep-accel/1.0.0" path components
#   -P DATA_DIR   save everything under DATA_DIR
# so the files end up in DATA_DIR/heart_rate and DATA_DIR/labels.

# Download heart_rate folder
!wget -q -r -np -nH --cut-dirs=3 -P {DATA_DIR} {DATASET_URL}/heart_rate/

# Download labels folder
!wget -q -r -np -nH --cut-dirs=3 -P {DATA_DIR} {DATASET_URL}/labels/

# The message below is printed unconditionally; the two file counts are the
# actual check that the downloads produced any files.
print("Download complete!")
print(f"\nHeart rate files: {len(list((DATA_DIR / 'heart_rate').glob('*.txt')))}")
print(f"Label files: {len(list((DATA_DIR / 'labels').glob('*.txt')))}")

### Alternative: Load from Google Drive

If you've already downloaded the dataset, you can mount Google Drive and point to it:

In [None]:
# Uncomment if using Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# DATA_DIR = Path('/content/drive/MyDrive/path/to/your/data')

## 3. Configuration

In [None]:
# Paths
# Both folders are expected directly under DATA_DIR.
HEART_RATE_DIR = DATA_DIR / "heart_rate"
LABELS_DIR = DATA_DIR / "labels"

# Sleep stage mapping: numeric stage code -> display name.
# Code 4 is not part of the mapping. prepare_dataset() keeps only rows whose
# stage is one of 0, 1, 2, 3, 5.
SLEEP_STAGE_NAMES = {
    0: 'Wake',
    1: 'N1',
    2: 'N2',
    3: 'N3',
    5: 'REM'
}

# PSG epoch duration
# Each label row is treated as covering [time_offset, time_offset + EPOCH_DURATION).
EPOCH_DURATION = 30  # seconds

## 4. Data Loading Functions

In [None]:
def load_heart_rate_data(subject_id: str) -> pd.DataFrame:
    """Read one subject's heart rate file into a DataFrame.

    The file ``<subject_id>_heartrate.txt`` in ``HEART_RATE_DIR`` is parsed as
    header-less, comma-separated text. The result has the columns
    ``timestamp``, ``heart_rate`` and a constant ``subject_id`` column.
    """
    source_path = HEART_RATE_DIR / f"{subject_id}_heartrate.txt"
    column_names = ['timestamp', 'heart_rate']
    readings = pd.read_csv(source_path, names=column_names, header=None)
    # Tag every sample with its subject so frames can be concatenated later.
    readings['subject_id'] = subject_id
    return readings


def load_sleep_labels(subject_id: str) -> pd.DataFrame:
    """Read one subject's sleep-stage labels into a DataFrame.

    The file ``<subject_id>_labeled_sleep.txt`` in ``LABELS_DIR`` is parsed as
    header-less, space-separated text. The result has the columns
    ``time_offset``, ``sleep_stage`` and a constant ``subject_id`` column.
    """
    source_path = LABELS_DIR / f"{subject_id}_labeled_sleep.txt"
    column_names = ['time_offset', 'sleep_stage']
    stages = pd.read_csv(source_path, names=column_names, header=None, sep=' ')
    # Tag every epoch with its subject so frames can be concatenated later.
    stages['subject_id'] = subject_id
    return stages


def get_subject_ids() -> List[str]:
    """Return the sorted subject IDs that have a heart rate file.

    An ID is the file stem of ``*_heartrate.txt`` in ``HEART_RATE_DIR`` with
    the ``_heartrate`` part removed.
    """
    pattern = str(HEART_RATE_DIR / "*_heartrate.txt")
    found_ids = []
    for file_name in glob.glob(pattern):
        found_ids.append(Path(file_name).stem.replace('_heartrate', ''))
    found_ids.sort()
    return found_ids


# Test loading
# Sanity check: list the subjects discovered on disk and preview the first five IDs.
subject_ids = get_subject_ids()
print(f"Found {len(subject_ids)} subjects")
print(f"Subject IDs: {subject_ids[:5]}...")

## 5. HRV Feature Extraction

We extract heart rate variability (HRV) features from each 30-second epoch:

| Feature | Description |
|---------|-------------|
| hr_mean, hr_std, hr_median | Basic statistics |
| hr_min, hr_max, hr_range | Range metrics |
| hr_cv, hr_iqr, hr_skew | Distribution metrics |
| hr_rmssd | Root Mean Square of Successive Differences |
| hr_pnn50 | Proportion of successive absolute HR differences > 5 BPM (a pNN50-style measure computed on heart rate, not on RR intervals) |
| hr_slope | Linear trend within epoch |

In [None]:
def extract_hrv_features(hr_values: np.ndarray, timestamps: np.ndarray) -> Dict[str, float]:
    """
    Extract HRV-style summary features from the heart rate samples of one epoch.

    Args:
        hr_values: 1-D array of heart rate samples belonging to the epoch.
        timestamps: 1-D array of sample times matching ``hr_values`` element
            for element; only used for the within-epoch slope.

    Returns:
        A dict with the keys ``hr_mean``, ``hr_std``, ``hr_min``, ``hr_max``,
        ``hr_range``, ``hr_median``, ``hr_count``, ``hr_cv``, ``hr_iqr``,
        ``hr_skew``, ``hr_rmssd``, ``hr_pnn50`` and ``hr_slope``. When fewer
        than two samples are supplied, every value except ``hr_count`` is NaN.
    """
    if len(hr_values) < 2:
        # Too few samples for any variability statistic: report only the count.
        return {
            'hr_mean': np.nan, 'hr_std': np.nan, 'hr_min': np.nan,
            'hr_max': np.nan, 'hr_range': np.nan, 'hr_median': np.nan,
            'hr_rmssd': np.nan, 'hr_pnn50': np.nan, 'hr_slope': np.nan,
            'hr_count': len(hr_values), 'hr_cv': np.nan,
            'hr_skew': np.nan, 'hr_iqr': np.nan
        }

    features = {}

    # Basic statistics (np.std is the population standard deviation, ddof=0)
    features['hr_mean'] = np.mean(hr_values)
    features['hr_std'] = np.std(hr_values)
    features['hr_min'] = np.min(hr_values)
    features['hr_max'] = np.max(hr_values)
    features['hr_range'] = features['hr_max'] - features['hr_min']
    features['hr_median'] = np.median(hr_values)
    features['hr_count'] = len(hr_values)

    # Coefficient of variation (undefined for a non-positive mean)
    features['hr_cv'] = features['hr_std'] / features['hr_mean'] if features['hr_mean'] > 0 else np.nan

    # Interquartile range
    q75, q25 = np.percentile(hr_values, [75, 25])
    features['hr_iqr'] = q75 - q25

    # Skewness: mean cubed z-score, undefined when all samples are equal
    if features['hr_std'] > 0:
        z_scores = (hr_values - features['hr_mean']) / features['hr_std']
        features['hr_skew'] = np.mean(z_scores ** 3)
    else:
        features['hr_skew'] = np.nan

    # HRV-like features on successive differences. At least two samples are
    # guaranteed here, so hr_diff is never empty.
    hr_diff = np.diff(hr_values)

    # RMSSD
    features['hr_rmssd'] = np.sqrt(np.mean(hr_diff ** 2))

    # pNN50 (adapted: fraction of successive differences > 5 BPM)
    features['hr_pnn50'] = np.sum(np.abs(hr_diff) > 5) / len(hr_diff)

    # Slope of a degree-1 least-squares fit of HR against elapsed time
    if len(timestamps) >= 2:
        try:
            slope, _ = np.polyfit(timestamps - timestamps[0], hr_values, 1)
            features['hr_slope'] = slope
        except (np.linalg.LinAlgError, TypeError, ValueError):
            # Only fitting failures (bad shapes/types, non-converging solve)
            # yield NaN; KeyboardInterrupt and SystemExit now propagate.
            features['hr_slope'] = np.nan
    else:
        features['hr_slope'] = np.nan

    return features

## 6. Align HR Data to Sleep Epochs

In [None]:
def align_hr_to_epochs(hr_df: pd.DataFrame, labels_df: pd.DataFrame) -> pd.DataFrame:
    """
    Build one feature row per labelled epoch from a subject's heart rate samples.

    For every row of ``labels_df`` the samples of ``hr_df`` whose ``timestamp``
    lies in ``[time_offset, time_offset + EPOCH_DURATION)`` are summarised with
    ``extract_hrv_features``. Epochs without samples still produce a row (with
    NaN statistics). ``time_offset``, ``sleep_stage`` and ``subject_id`` are
    copied from the label row.
    """
    offsets = labels_df['time_offset']
    first_epoch_start = offsets.min()
    last_epoch_end = offsets.max() + EPOCH_DURATION

    # Coarse pre-filter: keep only samples near the labelled span, padded by
    # one epoch on each side, so the per-epoch masks scan fewer rows.
    near_labels = (
        (hr_df['timestamp'] >= first_epoch_start - EPOCH_DURATION)
        & (hr_df['timestamp'] <= last_epoch_end + EPOCH_DURATION)
    )
    candidates = hr_df[near_labels].copy()

    epoch_rows = []
    for _, label in labels_df.iterrows():
        start = label['time_offset']
        stop = start + EPOCH_DURATION

        # Half-open interval so a sample on a boundary belongs to one epoch only.
        in_epoch = (candidates['timestamp'] >= start) & (candidates['timestamp'] < stop)
        samples = candidates[in_epoch]

        if samples.empty:
            epoch_features = extract_hrv_features(np.array([]), np.array([]))
        else:
            epoch_features = extract_hrv_features(
                samples['heart_rate'].values,
                samples['timestamp'].values
            )

        epoch_features.update(
            time_offset=start,
            sleep_stage=label['sleep_stage'],
            subject_id=label['subject_id'],
        )
        epoch_rows.append(epoch_features)

    return pd.DataFrame(epoch_rows)

## 7. Prepare Complete Dataset

In [None]:
def _add_temporal_features(dataset: pd.DataFrame) -> List[str]:
    """Add per-subject rolling/lag/lead/difference columns in place; return the new column names."""
    original_cols = set(dataset.columns)
    mean_by_subject = dataset.groupby('subject_id')['hr_mean']
    std_by_subject = dataset.groupby('subject_id')['hr_std']

    # 1. Rolling averages (trend context) - centered window of 5 epochs.
    #    hr_std_roll_5 is the rolling *mean* of the per-epoch standard deviation.
    dataset['hr_mean_roll_5'] = mean_by_subject.transform(
        lambda x: x.rolling(window=5, center=True, min_periods=1).mean()
    )
    dataset['hr_std_roll_5'] = std_by_subject.transform(
        lambda x: x.rolling(window=5, center=True, min_periods=1).mean()
    )

    # 2. Lag features (past context).
    # NOTE: epochs with too few HR samples were already removed, so "lag k"
    # means "k retained epochs earlier", which spans more than k * 30 s
    # wherever such a gap exists.
    for lag in (1, 2, 4):
        dataset[f'hr_mean_lag_{lag}'] = mean_by_subject.shift(lag)
        dataset[f'hr_std_lag_{lag}'] = std_by_subject.shift(lag)

    # 3. Lead features (future context) - only valid for offline analysis.
    for lead in (1, 2):
        dataset[f'hr_mean_lead_{lead}'] = mean_by_subject.shift(-lead)

    # 4. Rate of change (is HR dropping or rising?)
    dataset['hr_diff_1'] = dataset['hr_mean'] - dataset['hr_mean_lag_1']
    dataset['hr_diff_2'] = dataset['hr_mean'] - dataset['hr_mean_lag_2']

    # 5. Variability change
    dataset['hr_std_diff_1'] = dataset['hr_std'] - dataset['hr_std_lag_1']

    return [col for col in dataset.columns if col not in original_cols]


def prepare_dataset() -> pd.DataFrame:
    """Load every subject, build per-epoch features and add temporal context.

    Subjects that fail to load are reported and skipped. Rows are kept only if
    their stage is a key of ``SLEEP_STAGE_NAMES`` and the epoch has at least
    two HR samples. After the temporal features are added, rows at the edges
    of each recording (where a lag/lead value does not exist) are dropped.

    Returns:
        One row per retained epoch, sorted by ``subject_id`` and ``time_offset``.
        Base features such as ``hr_skew`` may still be NaN.

    Raises:
        ValueError: if no subject could be loaded at all.
    """
    subject_ids = get_subject_ids()
    print(f"Processing {len(subject_ids)} subjects...")

    all_features = []

    for position, subject_id in enumerate(subject_ids, start=1):
        print(f"  [{position}/{len(subject_ids)}] {subject_id}", end="")

        try:
            hr_df = load_heart_rate_data(subject_id)
            labels_df = load_sleep_labels(subject_id)
            features_df = align_hr_to_epochs(hr_df, labels_df)
            all_features.append(features_df)
            print(f" - {len(features_df)} epochs")
        except Exception as e:
            # Best effort: one unreadable subject must not abort the whole run.
            print(f" - ERROR: {e}")
            continue

    if not all_features:
        # pd.concat([]) would raise an unhelpful "No objects to concatenate".
        raise ValueError(
            "No subject data could be loaded; check that the heart rate and "
            "label files were downloaded."
        )

    dataset = pd.concat(all_features, ignore_index=True)

    # Filter invalid sleep stages (single source of truth: SLEEP_STAGE_NAMES)
    dataset = dataset[dataset['sleep_stage'].isin(list(SLEEP_STAGE_NAMES))].copy()

    # Drop rows with insufficient HR samples
    dataset = dataset[dataset['hr_count'] >= 2].copy()
    dataset = dataset.dropna(subset=['hr_mean', 'hr_std'])

    # =============================================================
    # ADD TEMPORAL CONTEXT FEATURES (Critical for sleep staging!)
    # =============================================================
    print("\nAdding temporal context features...")

    # Sort by subject and time (CRITICAL for lag/lead to work correctly)
    dataset = dataset.sort_values(['subject_id', 'time_offset']).reset_index(drop=True)

    temporal_cols = _add_temporal_features(dataset)

    # Drop only the NaNs created by shifting (edges of each subject's
    # recording). A blanket dropna() would also discard epochs whose base
    # features are NaN, e.g. hr_skew for a constant-HR epoch.
    dataset = dataset.dropna(subset=temporal_cols)

    total = len(dataset)
    print(f"\nTotal epochs: {total}")
    print(f"\nClass distribution:")
    for stage, name in SLEEP_STAGE_NAMES.items():
        count = (dataset['sleep_stage'] == stage).sum()
        pct = count / total * 100 if total else 0.0
        print(f"  {name}: {count} ({pct:.1f}%)")

    return dataset


# Load dataset
# Runs the full pipeline once; `dataset` is the module-level frame used by the cells below.
dataset = prepare_dataset()

## 8. Optuna Objective Function

In [None]:
# ================================================================
# CONFIGURATION - Adjust these settings
# ================================================================
# USE_CLASS_WEIGHTS is passed to create_objective(..., use_class_weights=...) when the study is set up.
# NUM_CLASSES is checked when the labels are prepared: the value 3 merges N1/N2/N3 into a single
# NREM class; any other value keeps the original stage codes.
USE_CLASS_WEIGHTS = False  # Set True to balance classes (hurts accuracy without motion data)
NUM_CLASSES = 5            # Set to 3 for simplified Wake/NREM/REM classification

def create_objective(X, y, groups, n_folds=5, use_class_weights=False, device='cuda'):
    """
    Create Optuna objective function for XGBoost tuning.
    Uses GroupKFold to prevent subject leakage.

    Args:
        X: 2-D feature array, indexed by row with integer index arrays.
        y: 1-D array of class labels aligned with the rows of ``X``.
        groups: 1-D array of group (subject) identifiers aligned with ``X``.
        n_folds: number of GroupKFold splits evaluated per trial.
        use_class_weights: when True, each training sample is weighted by
            ``len(y) / (n_classes * count_of_its_class)``, computed over all of ``y``.
        device: value of the XGBoost ``device`` parameter. Defaults to
            ``'cuda'`` (the previous hard-coded value); pass ``'cpu'`` to tune
            on a machine without a GPU.

    Returns:
        A function ``objective(trial)`` that returns the mean macro F1 over the
        folds and raises ``optuna.TrialPruned`` when the pruner stops the trial.
    """
    # Fixed for the lifetime of the objective, so compute it once, not per trial.
    n_classes = len(np.unique(y))

    # Optionally calculate balanced class weights, one weight per sample
    sample_weights = None
    if use_class_weights:
        _, class_index, class_counts = np.unique(y, return_inverse=True, return_counts=True)
        sample_weights = len(y) / (n_classes * class_counts[class_index])

    def objective(trial):
        # Hyperparameter search space
        params = {
            'objective': 'multi:softmax',
            'num_class': n_classes,
            'eval_metric': 'mlogloss',
            'booster': 'gbtree',  # Fixed to gbtree (dart is 10-50x slower)
            'device': device,
            'lambda': trial.suggest_float('lambda', 1e-8, 10.0, log=True),
            'alpha': trial.suggest_float('alpha', 1e-8, 10.0, log=True),
            'max_depth': trial.suggest_int('max_depth', 3, 10),
            'learning_rate': trial.suggest_float('learning_rate', 1e-3, 0.3, log=True),
            'n_estimators': trial.suggest_int('n_estimators', 50, 300),
            'min_child_weight': trial.suggest_int('min_child_weight', 1, 10),
            'subsample': trial.suggest_float('subsample', 0.5, 1.0),
            'colsample_bytree': trial.suggest_float('colsample_bytree', 0.5, 1.0),
            'gamma': trial.suggest_float('gamma', 1e-8, 1.0, log=True),
            'random_state': 42,
            'verbosity': 0
        }

        # Cross-validation with GroupKFold: a subject is never in both
        # the training and the validation part of a fold.
        cv = GroupKFold(n_splits=n_folds)
        f1_scores = []

        for fold_idx, (train_idx, val_idx) in enumerate(cv.split(X, y, groups)):
            X_train, X_val = X[train_idx], X[val_idx]
            y_train, y_val = y[train_idx], y[val_idx]
            weights_train = sample_weights[train_idx] if sample_weights is not None else None

            model = xgb.XGBClassifier(**params)
            # eval_set only tracks mlogloss on the validation fold; no early
            # stopping is configured, so all n_estimators trees are built.
            model.fit(X_train, y_train, sample_weight=weights_train,
                      eval_set=[(X_val, y_val)],
                      verbose=False)

            y_pred = model.predict(X_val)
            f1_scores.append(f1_score(y_val, y_pred, average='macro'))

            # Pruning: report the running mean after each fold so the pruner
            # can stop unpromising trials before all folds are evaluated.
            trial.report(np.mean(f1_scores), fold_idx)
            if trial.should_prune():
                raise optuna.TrialPruned()

        return np.mean(f1_scores)

    return objective

## 9. Run Optuna Optimization

In [None]:
# Prepare features and labels
feature_cols = [c for c in dataset.columns if c.startswith('hr_')]
# Take an explicit float copy: X is modified in place below, and the array
# returned by `.values` may be a view of the DataFrame (read-only under pandas
# Copy-on-Write) rather than an independent, writable array.
X = dataset[feature_cols].to_numpy(dtype=float, copy=True)

print(f"Features ({len(feature_cols)} total): {feature_cols}")

# Encode labels to consecutive integers 0..k-1 as required by the classifier
label_encoder = LabelEncoder()

# Optional: Simplify to 3 classes (Wake/NREM/REM) for testing
if NUM_CLASSES == 3:
    print("\n*** Using simplified 3-class mode: Wake / NREM / REM ***")
    # Map N1, N2, N3 -> NREM (use value 2)
    y_raw = dataset['sleep_stage'].replace({1: 2, 3: 2}).values
    SLEEP_STAGE_NAMES_USED = {0: 'Wake', 2: 'NREM', 5: 'REM'}
else:
    y_raw = dataset['sleep_stage'].values
    SLEEP_STAGE_NAMES_USED = SLEEP_STAGE_NAMES

y = label_encoder.fit_transform(y_raw)

# Subject groups (integer id per subject, used by GroupKFold)
subject_encoder = LabelEncoder()
groups = subject_encoder.fit_transform(dataset['subject_id'].values)

# Fill NaN with column medians (medians are computed ignoring NaN)
col_medians = np.nanmedian(X, axis=0)
nan_rows, nan_cols = np.nonzero(np.isnan(X))
X[nan_rows, nan_cols] = col_medians[nan_cols]

print(f"\nFeatures shape: {X.shape}")
print(f"Classes: {label_encoder.classes_} -> {[SLEEP_STAGE_NAMES_USED.get(c, c) for c in label_encoder.classes_]}")
print(f"Subjects: {len(np.unique(groups))}")
print(f"Class weights: {'ENABLED' if USE_CLASS_WEIGHTS else 'DISABLED'}")

In [None]:
# Create and run study
N_TRIALS = 100  # Adjust based on compute budget

# The study maximizes the value returned by the objective (mean macro F1).
# The sampler is seeded for repeatable suggestions. The median pruner is
# configured with 10 startup trials and 2 warm-up steps; the objective reports
# one step per CV fold.
study = optuna.create_study(
    direction='maximize',
    sampler=TPESampler(seed=42),
    pruner=MedianPruner(n_startup_trials=10, n_warmup_steps=2)
)

# Bind the data to the objective; class weighting follows USE_CLASS_WEIGHTS.
objective = create_objective(X, y, groups, n_folds=5, use_class_weights=USE_CLASS_WEIGHTS)

study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True)

# Summary of the best trial found
print(f"\n{'='*50}")
print(f"Best trial:")
print(f"  Macro F1: {study.best_trial.value:.4f}")
print(f"  Params: {study.best_trial.params}")

## 10. Final Evaluation

In [None]:
# Final evaluation with GroupKFold
from sklearn.metrics import accuracy_score

# Best hyperparameters plus the fixed settings; the same dict is used for every fold.
params = {
    **study.best_params,
    'objective': 'multi:softmax',
    'num_class': len(np.unique(y)),
    'random_state': 42,
    'verbosity': 0,
}

# NOTE: GroupKFold with 5 splits on the same X, y, groups yields the same folds
# that were used during tuning, so these scores are optimistic.
cv = GroupKFold(n_splits=5)
all_y_true, all_y_pred = [], []

for train_idx, val_idx in cv.split(X, y, groups):
    X_train, y_train = X[train_idx], y[train_idx]
    X_val, y_val = X[val_idx], y[val_idx]

    model = xgb.XGBClassifier(**params)
    model.fit(X_train, y_train)  # No sample weights for final eval
    y_pred = model.predict(X_val)

    # Pool out-of-fold predictions so the report covers every epoch once.
    all_y_true.extend(y_val)
    all_y_pred.extend(y_pred)

# Classification report
stage_names = [SLEEP_STAGE_NAMES_USED.get(s, str(s)) for s in label_encoder.classes_]
print("Classification Report:")
print(classification_report(all_y_true, all_y_pred, target_names=stage_names))

# Also show accuracy
overall_accuracy = accuracy_score(all_y_true, all_y_pred)
print(f"\nOverall Accuracy: {overall_accuracy:.4f}")

In [None]:
# Confusion matrix visualization
import matplotlib.pyplot as plt
import seaborn as sns

cm = confusion_matrix(all_y_true, all_y_pred)
# Row-normalize so each row shows how one true stage was distributed over predictions.
cm_normalized = cm.astype(float) / cm.sum(axis=1, keepdims=True)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: raw counts; right panel: row-normalized fractions.
panels = (
    (axes[0], cm, 'd', 'Confusion Matrix (Counts)'),
    (axes[1], cm_normalized, '.2f', 'Confusion Matrix (Normalized)'),
)
for ax, matrix, number_format, title in panels:
    sns.heatmap(matrix, annot=True, fmt=number_format, cmap='Blues',
                xticklabels=stage_names, yticklabels=stage_names, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(title)

plt.tight_layout()
plt.show()

## 11. Feature Importance

In [None]:
# Train final model on all data
import matplotlib.pyplot as plt  # local import so this cell also runs on its own

final_params = study.best_params.copy()
final_params['objective'] = 'multi:softmax'
final_params['num_class'] = len(np.unique(y))
final_params['random_state'] = 42
final_params['verbosity'] = 0

# `sample_weights` only existed as a local inside create_objective(), so using
# it here raised NameError. Rebuild it with the same formula,
# len(y) / (n_classes * class_count), following USE_CLASS_WEIGHTS.
# y is label-encoded to 0..k-1, so np.bincount(y)[y] is each sample's class count.
if USE_CLASS_WEIGHTS:
    class_counts = np.bincount(y)
    sample_weights = len(y) / (len(class_counts) * class_counts[y])
else:
    sample_weights = None

final_model = xgb.XGBClassifier(**final_params)
final_model.fit(X, y, sample_weight=sample_weights)

# Plot feature importance, most important first
importance = final_model.feature_importances_
sorted_idx = np.argsort(importance)[::-1]

plt.figure(figsize=(10, 6))
plt.bar(range(len(importance)), importance[sorted_idx])
plt.xticks(range(len(importance)), [feature_cols[i] for i in sorted_idx], rotation=45, ha='right')
plt.xlabel('Feature')
plt.ylabel('Importance')
plt.title('XGBoost Feature Importance')
plt.tight_layout()
plt.show()

print("\nTop features:")
for i in sorted_idx[:5]:
    print(f"  {feature_cols[i]}: {importance[i]:.4f}")

## 12. Optuna Visualization

In [None]:
# Optimization history
# Plots the objective value (mean macro F1) of the study's trials.
from optuna.visualization import plot_optimization_history, plot_param_importances

fig = plot_optimization_history(study)
fig.show()

In [None]:
# Hyperparameter importance
# Relies on plot_param_importances imported in the optimization-history cell.
fig = plot_param_importances(study)
fig.show()

## 13. Save Results

In [None]:
# Save best parameters
# (a single-row CSV with one column per tuned hyperparameter)
pd.DataFrame([study.best_params]).to_csv('best_params.csv', index=False)

# Save model
# (the final model trained on all data in the feature-importance cell)
final_model.save_model('sleep_stage_model.json')

# Save optimization history
# (one row per trial, as returned by study.trials_dataframe())
study.trials_dataframe().to_csv('optimization_history.csv', index=False)

# All three files are written to the current working directory.
print("Results saved!")
print("  - best_params.csv")
print("  - sleep_stage_model.json")
print("  - optimization_history.csv")

In [None]:
# Download files (Colab)
# Only a missing google.colab package means "not in Colab". A failure of the
# download itself is no longer swallowed and misreported by a bare `except:`.
try:
    from google.colab import files
except ImportError:
    print("Not running in Colab - files saved to current directory")
else:
    for result_file in ('best_params.csv', 'sleep_stage_model.json', 'optimization_history.csv'):
        files.download(result_file)