# 09: Train/Test Split

Split model-ready data into training, validation, and test sets.

## Strategy

Because IMU data has time-series characteristics, the following split strategies are relevant:
- **Stratified split**: Preserves class distribution across splits
- **Optional athlete-based split**: Ensures unseen athletes in test set
- **Optional time-based split**: Ensures temporal separation

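Standard split: **60% train, 20% val, 20% test** — obtained by first holding out 20% as the test set, then splitting the remaining 80% into 75% train / 25% val (0.75 × 0.8 = 0.6, 0.25 × 0.8 = 0.2).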


In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import sys
sys.path.append('../../src')

from sledhead_imu.prep.split_train_test import split_train_test

# Load model-ready data produced by step 08 and report its size and label
# balance. On success this cell defines the notebook globals `X` (features)
# and `y` (labels, read below via its 'severity' column), which the split
# cells further down consume.
data_dir = Path('../data')
model_ready_dir = data_dir / '08_model_ready_build' / 'model_ready_data'
splits_dir = data_dir / '09_splits'

# The glob only confirms that step 08 wrote *some* feature files; the
# specific feature/label pair loaded below is fixed by name.
feature_files = list(model_ready_dir.glob('features_*.csv'))
print(f"Found {len(feature_files)} feature files")

if not feature_files:
    print("No feature files found. Run 08_model_ready_build.ipynb first.")
else:
    # Fixed RF feature/label pair (not selected from `feature_files`).
    feature_file = model_ready_dir / 'features_rf.csv'
    label_file = model_ready_dir / 'labels_rf.csv'

    if feature_file.exists() and label_file.exists():
        X_loaded = pd.read_csv(feature_file)
        y_loaded = pd.read_csv(label_file)

        # Both frames get a default RangeIndex from read_csv, so features and
        # labels are presumably paired by row position. A length mismatch
        # would misalign samples in every split, so refuse to publish X / y.
        if len(X_loaded) != len(y_loaded):
            raise ValueError(
                f"Row count mismatch: {feature_file.name} has "
                f"{len(X_loaded)} rows but {label_file.name} has "
                f"{len(y_loaded)} rows"
            )
        X, y = X_loaded, y_loaded

        print("\nLoaded data:")
        print(f"  Samples: {len(X)}")
        print(f"  Features: {X.shape[1]}")
        print(f"  Label distribution:\n{y['severity'].value_counts().sort_index()}")

        # The split cells need at least one sample per train/val/test split.
        if len(X) < 3:
            print(f"\n⚠️  WARNING: Only {len(X)} samples available.")
            print("Not enough data for train/val/test split.")
            print("Need at least 3 samples (ideally 10+)")
        else:
            print("\n✓ Sufficient data for split")
    else:
        print("\nFiles not found:")
        print(f"  Features: {feature_file.exists()}")
        print(f"  Labels: {label_file.exists()}")


In [None]:
# Split the loaded real data into train/val/test (60/20/20) and persist each
# split under `splits_dir/<split>/` as X_<split>.csv and y_<split>.csv.
#
# The ratios come from two chained calls to split_train_test:
#   1. test_size=0.2              -> (train+val) vs test
#   2. test_size=0.25 of the 80%  -> train vs val
# The cell is skipped unless the loading cell defined X and y with >= 3 rows.

if not ('X' in locals() and 'y' in locals() and len(X) >= 3):
    print("\n⚠️  Skipping split - insufficient data")
else:
    X_train_val, X_test, y_train_val, y_test = split_train_test(
        X, y, test_size=0.2, random_state=42
    )
    X_train, X_val, y_train, y_val = split_train_test(
        X_train_val, y_train_val, test_size=0.25, random_state=42
    )

    # Ordered split name -> (features, labels); drives reporting and saving.
    split_frames = {
        'train': (X_train, y_train),
        'val': (X_val, y_val),
        'test': (X_test, y_test),
    }
    n_total = len(X)

    print("\n✓ Splits created:")
    for split_name, (X_part, _) in split_frames.items():
        share = len(X_part) / n_total * 100
        print(f"  {split_name.capitalize()}: {len(X_part)} samples ({share:.1f}%)")

    # Create every output folder before any file is written.
    train_dir = splits_dir / 'train'
    val_dir = splits_dir / 'val'
    test_dir = splits_dir / 'test'
    splits_dir.mkdir(parents=True, exist_ok=True)
    for folder in (train_dir, val_dir, test_dir):
        folder.mkdir(exist_ok=True)

    for index, (split_name, (X_part, y_part)) in enumerate(split_frames.items()):
        folder = splits_dir / split_name
        X_part.to_csv(folder / f'X_{split_name}.csv', index=False)
        y_part.to_csv(folder / f'y_{split_name}.csv', index=False)
        # Only the first "saved" message is preceded by a blank line.
        lead = "\n" if index == 0 else ""
        print(f"{lead}✓ Saved {split_name} split to {folder}")

    print("\nLabel distribution by split:")
    for split_name, (_, y_part) in split_frames.items():
        counts = y_part['severity'].value_counts().sort_index()
        print(f"\n  {split_name.capitalize()}:\n{counts}")


In [None]:
# Optional: Generate synthetic data for demonstration.
# Runs only when the loading cell did not provide at least 3 real samples;
# it builds a 100-sample synthetic dataset and runs the same two-stage
# 60/20/20 split as the real-data cell.
#
# NOTE: the synthetic splits are written to the same `splits_dir` the
# real-data cell uses, replacing any CSVs already there. Labels are drawn
# independently of the features, so nothing trained on them is meaningful.

if 'X' not in locals() or 'y' not in locals() or len(X) < 3:
    print("Generating synthetic data for demonstration...")

    # Global legacy seeding is kept (rather than np.random.default_rng) so the
    # generated values, and therefore the saved splits, match earlier runs.
    np.random.seed(42)
    n_samples = 100

    # Column order and draw order must stay fixed: each entry consumes
    # values from the shared global RNG in sequence.
    X_synth = pd.DataFrame({
        'time_above_2.0g': np.random.exponential(10, n_samples),
        'time_above_3.0g': np.random.exponential(3, n_samples),
        'g_seconds_2.0g': np.random.exponential(15, n_samples),
        'g_seconds_3.0g': np.random.exponential(5, n_samples),
        'num_peaks_over_3g': np.random.poisson(5, n_samples),
        'num_peaks_over_4g': np.random.poisson(2, n_samples),
        'longest_2g_duration': np.random.exponential(0.1, n_samples),
        'run_duration': np.random.uniform(60, 180, n_samples),
        'accel_mean': np.random.uniform(1, 3, n_samples),
        'accel_std': np.random.uniform(0.5, 2, n_samples),
        'gyro_mean': np.random.uniform(0, 50, n_samples),
        'gyro_std': np.random.uniform(0, 20, n_samples),
        'jerk_mean': np.random.uniform(10, 100, n_samples),
        'dominant_freq': np.random.uniform(0.01, 0.5, n_samples),
        'accel_max': np.random.uniform(2, 10, n_samples),
        'accel_range': np.random.uniform(1, 8, n_samples),
        'gyro_max': np.random.uniform(20, 200, n_samples),
        'jerk_max': np.random.uniform(100, 2000, n_samples),
        'highest_peak_g': np.random.uniform(2, 9, n_samples),
        'num_symptoms': np.random.randint(0, 5, n_samples),
        'accel_min': np.random.uniform(0, 1, n_samples),
    })

    # Synthetic severity labels 0-5 with a skew toward low severities.
    y_synth = pd.DataFrame({
        'severity': np.random.choice([0, 1, 2, 3, 4, 5], n_samples, p=[0.3, 0.25, 0.2, 0.15, 0.05, 0.05])
    })

    print(f"✓ Created {n_samples} synthetic samples")

    # Stage 1: 80/20 (train+val) vs test; stage 2: 75/25 train vs val.
    X_train_val, X_test, y_train_val, y_test = split_train_test(
        X_synth, y_synth, test_size=0.2, random_state=42
    )
    X_train, X_val, y_train, y_val = split_train_test(
        X_train_val, y_train_val, test_size=0.25, random_state=42
    )

    synth_splits = {
        'train': (X_train, y_train),
        'val': (X_val, y_val),
        'test': (X_test, y_test),
    }

    print("\n✓ Splits created from synthetic data:")
    for split_name, (X_part, _) in synth_splits.items():
        print(f"  {split_name.capitalize()}: {len(X_part)} ({len(X_part)/len(X_synth)*100:.1f}%)")

    splits_dir.mkdir(parents=True, exist_ok=True)
    for split_name, (X_part, y_part) in synth_splits.items():
        out_dir = splits_dir / split_name
        out_dir.mkdir(exist_ok=True)
        X_part.to_csv(out_dir / f'X_{split_name}.csv', index=False)
        y_part.to_csv(out_dir / f'y_{split_name}.csv', index=False)

    # Report the real destination (previously hard-coded as data/09_splits/).
    print(f"\n✓ Saved splits to {splits_dir}/")
    print("⚠️  These splits contain SYNTHETIC data and replace any real splits "
          "in that directory; do not use them for real model evaluation.")

    print("\nLabel distribution:")
    for split_name, (_, y_part) in synth_splits.items():
        print(f"\n  {split_name.capitalize()}:\n{y_part['severity'].value_counts().sort_index()}")
