In [None]:
# reorganize_and_explore.py
from pathlib import Path
import pandas as pd
import numpy as np
import keras
import matplotlib.pyplot as plt

# Resolve every path relative to the project root so this cell behaves the
# same whether it is launched from notebooks/, scripts/, or the root itself.
notebook_dir = Path.cwd()

# The project root is the nearest directory that contains pyproject.toml,
# searching the working directory first and then each ancestor in turn.
# Iterating over `parents` also tests the filesystem root itself.
project_root = next(
    (
        candidate
        for candidate in (notebook_dir, *notebook_dir.parents)
        if (candidate / "pyproject.toml").exists()
    ),
    None,
)
if project_root is None:
    raise FileNotFoundError("Could not find project root (pyproject.toml not found)")

# Now use relative paths from project root
base_dir = project_root / "fall_detection_data"
processed_dir = base_dir / "processed"
models_dir = base_dir / "models"

# Fail with an explicit message when the data directory is absent; otherwise
# the mkdir below raises a FileNotFoundError that does not say what is missing.
if not base_dir.is_dir():
    raise FileNotFoundError(f"Data directory not found: {base_dir}")

models_dir.mkdir(exist_ok=True)
output_dir = models_dir

print(f"üìÇ Project root: {project_root}")
print(f"üìÇ Data directory: {base_dir}")
print(f"üìÇ Models directory: {models_dir}")
print()

print("=" * 80)
print("CURRENT DIRECTORY STRUCTURE")
print("=" * 80)

# Number of entries previewed per top-level directory.
preview_limit = 5

# Show current structure: one heading per sub-directory of the data folder,
# followed by a short preview of its contents.
for item in sorted(base_dir.iterdir()):
    if not item.is_dir():
        continue

    print(f"\nüìÅ {item.name}/")

    # List the directory once (sorted, so the preview is stable between runs)
    # instead of re-scanning it for the preview, the check and the count.
    children = sorted(item.iterdir())
    for sub in children[:preview_limit]:
        if sub.is_dir():
            # Counts every direct entry of the sub-directory.
            file_count = sum(1 for _ in sub.glob("*"))
            print(f"   üìÅ {sub.name}/ ({file_count} files)")
        else:
            print(f"   üìÑ {sub.name}")

    if len(children) > preview_limit:
        print(f"   ... and {len(children) - preview_limit} more")

print("\n" + "=" * 80)
print("PROPOSED REORGANIZATION")
print("=" * 80)

# Reference sketch of the intended on-disk layout. It is only printed; nothing
# in this cell creates, moves or renames files.
# NOTE(review): the sketch names the KFall label folder "labels/", while the
# preprocessing cell later in this file reads from "KFall/label_data" —
# confirm which name is correct.
proposed_structure = """
fall_detection_data/
‚îú‚îÄ‚îÄ KFall/
‚îÇ   ‚îú‚îÄ‚îÄ sensor_data/
‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ SA06/
‚îÇ   ‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ S06T01R01.csv  (KFall format: S##T##R##.csv)
‚îÇ   ‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ S06T02R01.csv
‚îÇ   ‚îÇ   ‚îÇ   ‚îî‚îÄ‚îÄ ...
‚îÇ   ‚îÇ   ‚îî‚îÄ‚îÄ SA07/ ...
‚îÇ   ‚îî‚îÄ‚îÄ labels/
‚îÇ       ‚îú‚îÄ‚îÄ SA06_label.xlsx
‚îÇ       ‚îî‚îÄ‚îÄ SA07_label.xlsx ...
‚îÇ
‚îú‚îÄ‚îÄ SisFall/
‚îÇ   ‚îú‚îÄ‚îÄ SA01/
‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ D01_SA01_R01.txt  (SisFall format: <CODE>_<SUBJECT>_<TRIAL>.txt)
‚îÇ   ‚îÇ   ‚îú‚îÄ‚îÄ F01_SA01_R01.txt
‚îÇ   ‚îÇ   ‚îî‚îÄ‚îÄ ...
‚îÇ   ‚îú‚îÄ‚îÄ SA02/ ...
‚îÇ   ‚îî‚îÄ‚îÄ SE01/ ... (elderly subjects)
‚îÇ
‚îî‚îÄ‚îÄ processed/
    ‚îú‚îÄ‚îÄ kfall_features.pkl
    ‚îú‚îÄ‚îÄ sisfall_features.pkl
    ‚îî‚îÄ‚îÄ fused_dataset.pkl
"""

print(proposed_structure)

print("\n" + "=" * 80)
print("DATASET COMPARISON")
print("=" * 80)

# KFall structure
kfall_sensor = base_dir / "KFall" / "sensor_data"
if kfall_sensor.exists():
    kfall_subjects = sorted(d.name for d in kfall_sensor.iterdir() if d.is_dir())

    # Pick the first CSV (in sorted order) of the first subject that has one.
    # Using next(..., None) avoids an IndexError when the folder has no
    # subject directories or a subject directory has no recordings.
    sample_kfall_file = next(
        (
            csv_path
            for subject in kfall_subjects
            for csv_path in sorted((kfall_sensor / subject).glob("*.csv"))
        ),
        None,
    )

    if sample_kfall_file is None:
        print(f"\nKFall: no CSV recordings found under {kfall_sensor}")
    else:
        sample_kfall = sample_kfall_file.parent
        df_kfall = pd.read_csv(sample_kfall_file)

        print("\nüìä KFALL DATASET:")
        print(f"   Subjects: {len(kfall_subjects)} (SA06-SA38)")
        print(f"   Sampling Rate: 100 Hz (needs upsampling to 200 Hz)")
        print(f"   File Format: S##T##R##.csv")
        print(f"   Columns: {df_kfall.columns.tolist()}")
        print(f"   Data Shape (sample): {df_kfall.shape}")
        print(f"   Has Labels: ‚úÖ Yes (temporal annotations in Excel files)")

# SisFall structure
sisfall_dir = base_dir / "SisFall"
if sisfall_dir.exists():
    sisfall_subjects = sorted(d.name for d in sisfall_dir.iterdir() if d.is_dir())
    adults = [s for s in sisfall_subjects if s.startswith('SA')]
    elderly = [s for s in sisfall_subjects if s.startswith('SE')]

    # Same defensive selection as for KFall: the first .txt recording of the
    # first adult subject that has any.
    sample_sisfall_file = next(
        (
            txt_path
            for subject in adults
            for txt_path in sorted((sisfall_dir / subject).glob("*.txt"))
        ),
        None,
    )

    if sample_sisfall_file is None:
        print(f"\nSisFall: no adult (SA) recordings found under {sisfall_dir}")
    else:
        sample_sisfall = sample_sisfall_file.parent

        # Read SisFall file - more robust parsing
        try:
            # Parse manually: drop the ';' terminator, accept commas or
            # whitespace as separators, keep only rows with all 9 columns.
            data = []
            with open(sample_sisfall_file, 'r') as f:
                for line in f:
                    values = line.strip().replace(';', '').replace(',', ' ').split()
                    if len(values) == 9:  # Should have 9 columns
                        data.append([float(v) for v in values])

            df_sisfall = pd.DataFrame(data)

            print("\nüìä SISFALL DATASET:")
            print(f"   Subjects: {len(sisfall_subjects)} total")
            print(f"     - Adults (SA): {len(adults)} (SA01-SA23)")
            print(f"     - Elderly (SE): {len(elderly)} (SE01-SE15)")
            print(f"   Sampling Rate: 200 Hz ‚úÖ")
            print(f"   File Format: <CODE>_<SUBJECT>_<TRIAL>.txt")
            print(f"   Columns: 9 (ADXL345: 0-2, ITG3200: 3-5, MMA8451Q: 6-8)")
            print(f"   Data Shape (sample): {df_sisfall.shape}")
            print(f"   Has Labels: ‚ùå No (must use Algorithm 1)")
            print(f"   Data Format: Raw bits (needs conversion to physical units)")

        except (OSError, ValueError) as e:
            # OSError: the file could not be read; ValueError: a token was not
            # numeric (or the bytes could not be decoded).
            print(f"\n‚ùå Error reading SisFall file: {e}")
            print("   Will handle this in the preprocessing pipeline")

print("\n" + "=" * 80)
print("ACTIVITIES NEEDED FOR PAPER REPRODUCTION")
print("=" * 80)

# KFall task codes used by the paper, mapped to their descriptions.
print("\nüìã FROM KFALL (Table I):")
kfall_needed = dict(
    T10='Stumble while walking',
    T28='Vertical fall while walking (fainting)',
    T30='Forward fall while walking (trip)',
    T31='Forward fall while jogging (trip)',
    T32='Forward fall while walking (slip)',
    T33='Lateral fall while walking (slip)',
    T34='Backward fall while walking (slip)',
)
print("\n".join(f"   {code}: {desc}" for code, desc in kfall_needed.items()))

# SisFall daily-living (D##) codes used by the paper.
print("\nüìã FROM SISFALL (Table I):")
print("\n   ADL Activities:")
sisfall_adl = dict(
    D01='Walking slowly',
    D02='Walking quickly',
    D03='Jogging slowly',
    D04='Jogging quickly',
    D05='Walking upstairs/downstairs slowly',
    D06='Walking upstairs/downstairs quickly',
    D18='Stumble while walking',
)
print("\n".join(f"   {code}: {desc}" for code, desc in sisfall_adl.items()))

# SisFall fall (F##) codes used by the paper.
print("\n   Fall Activities:")
sisfall_falls = dict(
    F01='Fall forward while walking (slip)',
    F02='Fall backward while walking (slip)',
    F03='Lateral fall while walking (slip)',
    F04='Fall forward while walking (trip)',
    F05='Fall forward while jogging (trip)',
    F06='Vertical fall while walking (fainting)',
)
print("\n".join(f"   {code}: {desc}" for code, desc in sisfall_falls.items()))

print("\n" + "=" * 80)
print("NEXT STEPS")
print("=" * 80)
print("""
1. ‚úÖ Data is properly organized
2. ‚è≠Ô∏è  Implement preprocessing pipeline:
   - Load and convert SisFall raw bits to physical units
   - Upsample KFall from 100Hz to 200Hz
   - Apply Algorithm 1 for temporal segmentation
   - Extract features according to Table I
3. ‚è≠Ô∏è  Z-score normalization and dataset fusion
4. ‚è≠Ô∏è  Build and train FallNet""")

In [None]:
# %% [markdown]
# # Fall Detection Data Preprocessing Pipeline - CORRECTED VERSION
# 
# This notebook implements the preprocessing methodology from the paper:
# "A novel Feature extraction method for Pre-Impact Fall detection system using Deep learning and wearable sensors"
#
# Key fixes:
# - Removed Sp - 3 bug
# - Fixed transitional window logic (no duplicates)
# - Proper ADL extraction before falls
# - Correct stumble/recovery processing

# %% [markdown]
## 1. Setup and Imports

# %%
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.interpolate import CubicSpline
from sklearn.preprocessing import StandardScaler
import pickle
import json
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
# Silence every warning category for the rest of the session. This also hides
# deprecation/runtime warnings raised by the numerical libraries used below.
warnings.filterwarnings('ignore')

# Set style for better plots
plt.style.use('seaborn-v0_8-darkgrid')  # matplotlib style sheet, applied globally
sns.set_palette("husl")  # default seaborn colour palette

print("‚úÖ Imports complete")

# %% [markdown]
## 2. Define Paths and Configuration

# %%
# Locate the project root no matter which project sub-directory is the cwd.
current_dir = Path.cwd()

# Climb towards the filesystem root until a directory holding pyproject.toml
# turns up.
project_root = current_dir
while True:
    if (project_root / "pyproject.toml").exists():
        break
    project_root = project_root.parent
    if project_root.parent == project_root:  # arrived at the filesystem root
        # Nothing found: fall back to the parent of a "notebooks" directory,
        # or to the working directory itself.
        if current_dir.name == "notebooks":
            project_root = current_dir.parent
        else:
            project_root = current_dir
        break

# Every data location hangs off <project root>/fall_detection_data.
base_dir = project_root.joinpath("fall_detection_data")
kfall_sensor_dir = base_dir.joinpath("KFall", "sensor_data")
kfall_labels_dir = base_dir.joinpath("KFall", "label_data")
sisfall_dir = base_dir.joinpath("SisFall")
processed_dir = base_dir.joinpath("processed")

# Stop immediately, with the full context, if the data folder is missing.
if not base_dir.exists():
    raise FileNotFoundError(
        f"Data directory not found: {base_dir}\n"
        f"Expected structure: {project_root}/fall_detection_data/\n"
        f"Current directory: {current_dir}\n"
        f"Project root: {project_root}"
    )

# Start from an empty output folder: delete the regular files directly inside
# it (sub-directories are left alone), or create it when it does not exist.
print("üßπ Cleaning processed directory...")
if not processed_dir.exists():
    processed_dir.mkdir(exist_ok=True)
else:
    for f in (entry for entry in processed_dir.glob("*") if entry.is_file()):
        f.unlink()
        print(f"  Deleted: {f.name}")

print("\n‚úÖ Directories configured:")
for caption, location in (
    ("Project root", project_root),
    ("Data directory", base_dir),
    ("KFall sensor data", kfall_sensor_dir),
    ("KFall labels", kfall_labels_dir),
    ("SisFall data", sisfall_dir),
    ("Output", processed_dir),
):
    print(f"   {caption}: {location}")

# %% [markdown]
## 3. Activity Mappings and Labels

# %%
# Activity codes selected from Table I of the paper.
kfall_fall_activities = [f"T{task}" for task in (28, 30, 31, 32, 33, 34)]
kfall_stumble = ['T10']

# SisFall ADL code -> class name (slow/quick variants share one class).
sisfall_adl_map = {
    code: activity
    for activity, codes in (
        ('Walking', ('D01', 'D02')),
        ('Jogging', ('D03', 'D04')),
        ('Walking_stairs_updown', ('D05', 'D06')),
        ('Stumble_while_walking', ('D18',)),
    )
    for code in codes
}

sisfall_falls = [f"F{number:02d}" for number in range(1, 7)]

# Class name -> integer id for the 8-class classifier; the id is the position
# of the name in this sequence.
label_map = {
    class_name: class_id
    for class_id, class_name in enumerate((
        'Walking',
        'Jogging',
        'Walking_stairs_updown',
        'Stumble_while_walking',
        'Fall_Recovery',
        'Fall_Initiation',
        'Impact',
        'Aftermath',
    ))
}

# Integer id -> class name, for reporting.
reverse_label_map = dict((class_id, class_name) for class_name, class_id in label_map.items())

print("üìã Label Mapping:")
for label_id, label_name in reverse_label_map.items():
    print(f"   {label_id}: {label_name}")

# Persist the mapping straight away so it exists even if a later cell fails.
(processed_dir / "label_map.json").write_text(json.dumps(label_map, indent=2))
print("\n‚úÖ Label map saved")

# %% [markdown]
## 4. Data Loading Functions

# %%
def load_sisfall_file(filepath):
    """Load a SisFall recording and convert it from raw bits to physical units.

    Each usable line holds nine numeric values separated by commas and/or
    whitespace, optionally terminated by ';'. Lines that do not contain exactly
    nine values, or that contain a non-numeric token, are skipped.

    Args:
        filepath: Path of the SisFall ``.txt`` recording.

    Returns:
        A float array of shape ``(n_rows, 6)``: columns 0-2 are the ADXL345
        accelerometer readings and columns 3-5 the ITG3200 gyroscope readings,
        each scaled by the factors below. The last three input columns are
        not used. Returns ``None`` when the file has no usable row.

    Raises:
        OSError: If the file cannot be opened.
    """
    rows = []
    with open(filepath, 'r') as f:
        for line in f:
            values = line.strip().replace(';', '').replace(',', ' ').split()
            if len(values) != 9:
                continue
            try:
                rows.append([float(v) for v in values])
            except ValueError:
                # A corrupt token invalidates only this line, not the file.
                continue

    if not rows:
        return None

    data = np.array(rows)

    # Convert to physical units
    converted = np.empty((data.shape[0], 6))

    # ADXL345 (columns 0-2): +/-16 g range, 13-bit resolution
    adxl_factor = (2 * 16) / (2**13)
    converted[:, 0:3] = data[:, 0:3] * adxl_factor

    # ITG3200 (columns 3-5): +/-2000 deg/s range, 16-bit resolution
    itg_factor = (2 * 2000) / (2**16)
    converted[:, 3:6] = data[:, 3:6] * itg_factor

    return converted

def load_kfall_file(filepath):
    """Load the six inertial channels of a KFall CSV recording.

    Args:
        filepath: Path of the KFall ``.csv`` file.

    Returns:
        An array of shape ``(n_rows, 6)`` with the columns ``AccX, AccY, AccZ,
        GyrX, GyrY, GyrZ`` in that order, or ``None`` when the file cannot be
        read, cannot be parsed, or lacks one of those columns.
    """
    try:
        df = pd.read_csv(filepath)
        return df[['AccX', 'AccY', 'AccZ', 'GyrX', 'GyrY', 'GyrZ']].values
    except (OSError, KeyError, ValueError):
        # OSError: unreadable file; KeyError: missing column; ValueError covers
        # pandas' parser/empty-data errors and undecodable bytes. Anything else
        # (e.g. KeyboardInterrupt) is deliberately allowed to propagate.
        return None

print("‚úÖ Data loading functions defined")

# %% [markdown]
## 5. Upsampling Function (KFall 100 Hz → 200 Hz)

# %%
def upsample_to_200hz(data, original_freq=100):
    """Resample a multi-channel recording to 200 Hz with a cubic spline.

    Args:
        data: 2-D array of shape ``(n_samples, n_features)`` sampled uniformly
            at ``original_freq``.
        original_freq: Sampling rate of ``data`` in Hz (default 100).

    Returns:
        Array of shape ``(round(n_samples * 200 / original_freq), n_features)``
        covering the same time span, starting at t = 0.

    Raises:
        ValueError: If ``data`` is not 2-D, or has too few samples for
            ``CubicSpline`` to fit (it needs at least two).
    """
    target_freq = 200
    n_samples, n_features = data.shape

    original_time = np.arange(n_samples) / original_freq

    # Build the target grid from an integer sample count. np.arange with a
    # fractional step can yield one sample more or fewer than intended because
    # of floating-point rounding, which made the output length unreliable.
    n_target = int(round(n_samples * target_freq / original_freq))
    target_time = np.arange(n_target) / target_freq

    # axis=0 fits one independent spline per column in a single call.
    # Target points after the last original sample are extrapolated.
    spline = CubicSpline(original_time, data, axis=0)
    return spline(target_time)

print("‚úÖ Upsampling function defined")

# %% [markdown]
## 6. Algorithm 1: Temporal Feature Extraction (CORRECTED)

# %%
def extract_temporal_features(data, sampling_freq=200):
    """Locate the fall phases of a recording (Algorithm 1 from the paper).

    The Y-axis acceleration (column 1 of ``data``) is cut into consecutive,
    non-overlapping windows of a quarter of a second. The window with the
    largest standard deviation becomes the segmentation point ``Sp``; all
    phase boundaries are derived from it and clamped to the recording length.

    Args:
        data: 2-D array whose column 1 is the Y-axis acceleration.
        sampling_freq: Sampling rate of ``data`` in Hz (default 200).

    Returns:
        ``None`` when the recording is shorter than one window (or the window
        size would be zero). Otherwise a dict with:
          - ``std_devs``: per-window standard deviations,
          - ``window_positions``: start sample of every window,
          - ``Sp``: index of the window with the largest deviation,
          - ``W_s``: window size in samples,
          - ``adl_end``, ``fall_init_start``, ``fall_init_end``,
            ``transitional_end``, ``impact_start``, ``impact_end``,
            ``aftermath_start``: phase boundaries in samples.
    """
    acc_y = data[:, 1]  # Y-axis
    W_s = sampling_freq // 4  # 50 samples (0.25s) at 200 Hz
    if W_s <= 0:
        return None

    # Every *complete* window counts, including one that ends exactly on the
    # last sample (the previous `range(0, len - W_s, W_s)` dropped it).
    n_windows = len(acc_y) // W_s
    if n_windows == 0:
        return None

    window_positions = [k * W_s for k in range(n_windows)]
    std_devs = [np.std(acc_y[start:start + W_s]) for start in window_positions]

    # Find segmentation point.
    # Paper: "The starting frame of Sw will become Sp" - no offset is applied.
    Sp = int(np.argmax(std_devs))

    n_samples = len(data)

    def _boundary(n_windows_after_sp):
        """Sample index `n_windows_after_sp` windows past Sp, clamped to the data."""
        return min((Sp + n_windows_after_sp) * W_s, n_samples)

    segments = {
        'std_devs': std_devs,
        'window_positions': window_positions,
        'Sp': Sp,
        'W_s': W_s,

        # Phase boundaries (in samples)
        'adl_end': Sp * W_s,
        'fall_init_start': Sp * W_s,
        'fall_init_end': _boundary(4),
        'transitional_end': _boundary(2),
        'impact_start': _boundary(4),
        'impact_end': _boundary(8),
        'aftermath_start': _boundary(8),
    }

    return segments

print("‚úÖ Algorithm 1 implemented (CORRECTED)")

# %% [markdown]
## 7. Feature Extraction Functions (CORRECTED)

# %%
def _stretch_to_window(segment, n_out=200):
    """Linearly resample every channel of `segment` to `n_out` rows."""
    src = np.linspace(0, 1, len(segment))
    dst = np.linspace(0, 1, n_out)
    stretched = np.zeros((n_out, segment.shape[1]))
    for channel in range(segment.shape[1]):
        stretched[:, channel] = np.interp(dst, src, segment[:, channel])
    return stretched


def process_fall_activity(data, rng=None):
    """Cut a fall recording into labelled 200-sample segments (Algorithm 1).

    Up to four segments are produced, each only when the recording is long
    enough for it:
      1. the 200 samples preceding the fall, labelled Jogging when the
         standard deviation of their Y-axis exceeds 0.5 and Walking otherwise;
      2. exactly one Fall_Initiation segment - chosen at random (50/50)
         between the transitional window, stretched to 200 samples, and the
         full fall-initiation window;
      3. an Impact segment;
      4. an Aftermath segment.

    Args:
        data: 2-D array sampled at 200 Hz, one row per sample.
        rng: Optional random source exposing ``random()`` (for example
            ``numpy.random.default_rng(seed)``) so that the window choice is
            reproducible. When omitted, the global ``np.random`` state is used.

    Returns:
        List of ``(segment, label_id)`` tuples; empty when the recording
        cannot be segmented.
    """
    segments = extract_temporal_features(data)
    if segments is None:
        return []

    window = 200  # one second at 200 Hz
    results = []
    fi_start = segments['fall_init_start']

    # 1. ADL phase (before fall) - only when a full second precedes the fall.
    if fi_start >= window:
        adl_segment = data[fi_start - window:fi_start]

        # Determine ADL type based on the spread of the Y-axis acceleration.
        acc_std = np.std(adl_segment[:, 1])
        adl_name = 'Jogging' if acc_std > 0.5 else 'Walking'
        results.append((adl_segment, label_map[adl_name]))

    # 2. Fall Initiation - ONE sample per fall.
    # Randomly choose between the transitional window (0.5s), useful for early
    # detection training, and the full window (1s).
    draw = np.random.random() if rng is None else rng.random()

    if draw < 0.5:
        tw_end = segments['transitional_end']
        if tw_end <= len(data) and (tw_end - fi_start) >= 100:
            tw_segment = data[fi_start:tw_end]
            # The transitional window is shorter than the model input, so it
            # is stretched to the common segment length.
            if len(tw_segment) != window:
                tw_segment = _stretch_to_window(tw_segment, window)
            results.append((tw_segment, label_map['Fall_Initiation']))
    else:
        fi_end = segments['fall_init_end']
        if fi_end <= len(data) and (fi_end - fi_start) >= window:
            fi_segment = data[fi_start:fi_end]
            results.append((fi_segment[:window], label_map['Fall_Initiation']))

    # 3. Impact
    impact_start = segments['impact_start']
    impact_end = segments['impact_end']
    if impact_end <= len(data) and (impact_end - impact_start) >= window:
        impact_segment = data[impact_start:impact_end]
        results.append((impact_segment[:window], label_map['Impact']))

    # 4. Aftermath
    aftermath_start = segments['aftermath_start']
    if len(data) - aftermath_start >= window:
        aftermath_segment = data[aftermath_start:aftermath_start + window]
        results.append((aftermath_segment, label_map['Aftermath']))

    return results


def process_stumble_activity(data):
    """
    Cut a stumble recording into a stumble segment and a recovery segment.

    A stumble is a temporary loss of balance from which the subject recovers
    without falling. Both segments are located with Algorithm 1 and are only
    emitted when the recording is long enough for them.

    Returns a list of (segment, label_id) tuples; empty when the recording
    cannot be segmented.
    """
    segments = extract_temporal_features(data)
    if segments is None:
        return []

    n_samples = len(data)
    begin = segments['fall_init_start']
    middle = segments['transitional_end']
    finish = segments['impact_end']

    samples = []

    # The imbalance event itself: from the segmentation point to the end of
    # the transitional window.
    if not (middle > n_samples or middle - begin < 100):
        event = data[begin:middle]
        if len(event) >= 200:
            samples.append((event[:200], label_map['Stumble_while_walking']))
        else:
            # Too short for the model input: stretch it linearly to 200 rows.
            src_axis = np.linspace(0, 1, len(event))
            dst_axis = np.linspace(0, 1, 200)
            stretched = np.column_stack(
                [np.interp(dst_axis, src_axis, event[:, ch]) for ch in range(6)]
            )
            samples.append((stretched, label_map['Stumble_while_walking']))

    # The recovery period that follows the stumble.
    if not (finish > n_samples or finish - middle < 200):
        recovery = data[middle:finish]
        samples.append((recovery[:200], label_map['Fall_Recovery']))

    return samples


def process_adl_activity(data, label_name):
    """Split an ADL recording into non-overlapping 1-second windows.

    Only the first 20 seconds (4000 samples at 200 Hz) are used, as per the
    paper. Every complete 200-sample window inside that span is returned,
    including one that ends exactly on the last usable sample (the previous
    loop bound dropped it, yielding 19 windows from a 20-second recording).

    Args:
        data: 2-D array sampled at 200 Hz, one row per sample.
        label_name: Key of ``label_map`` naming the activity class.

    Returns:
        List of ``(segment, label_id)`` tuples, each segment 200 rows long.

    Raises:
        KeyError: If ``label_name`` is not a key of ``label_map``.
    """
    label = label_map[label_name]

    window = 200  # one second at 200 Hz
    max_samples = min(len(data), 4000)  # 20 seconds at 200Hz

    n_windows = max_samples // window
    return [
        (data[k * window:(k + 1) * window], label)
        for k in range(n_windows)
    ]

print("‚úÖ Feature extraction functions defined (CORRECTED)")

# %% [markdown]
## 8. Process KFall Dataset

# %%
def process_kfall_dataset():
    """Turn every relevant KFall recording into labelled 200x6 segments.

    Walks the subject directories under ``kfall_sensor_dir``, keeps only the
    fall and stumble tasks listed in ``kfall_fall_activities`` and
    ``kfall_stumble``, upsamples each recording to 200 Hz and segments it.
    Subjects and files are visited in sorted order so the output order is
    stable between runs. A file that fails during processing is reported and
    skipped.

    Returns:
        ``(X, y)``: ``X`` has shape ``(n_segments, 200, 6)`` and ``y`` holds
        the matching label ids. When nothing was extracted, ``X`` has shape
        ``(0, 200, 6)`` and ``y`` shape ``(0,)``.
    """
    print("="*80)
    print("PROCESSING KFALL DATASET")
    print("="*80)

    X_data = []
    y_labels = []

    subjects = sorted(d for d in kfall_sensor_dir.iterdir() if d.is_dir())
    print(f"Found {len(subjects)} subjects")

    for subject_dir in tqdm(subjects, desc="Processing KFall subjects"):
        for file in sorted(subject_dir.glob("*.csv")):
            # Extract activity code from filename: S06T10R01.csv -> T10
            filename = file.stem
            if len(filename) < 6:
                continue
            activity_code = filename[3:6]  # e.g., T10, T28

            # Decide how to process the file *before* reading it, so that
            # recordings of unused tasks are never loaded.
            if activity_code in kfall_fall_activities:
                extract = process_fall_activity
            elif activity_code in kfall_stumble:
                extract = process_stumble_activity
            else:
                continue

            data = load_kfall_file(file)
            if data is None or len(data) < 100:
                continue

            try:
                features = extract(upsample_to_200hz(data))

                for segment, label in features:
                    if segment.shape == (200, 6):
                        X_data.append(segment)
                        y_labels.append(label)
            except Exception as e:
                # Per-file boundary: report the problem and keep going.
                print(f"  Error processing {file.name}: {e}")
                continue

    if not X_data:
        # Keep the documented rank even when nothing was extracted.
        return np.empty((0, 200, 6)), np.empty((0,), dtype=int)

    return np.array(X_data), np.array(y_labels)

# Execute the KFall pipeline and report the shapes of what it produced.
print(f"\n{'=' * 80}")
X_kfall, y_kfall = process_kfall_dataset()
print("\n‚úÖ KFall processed:")
for tag, array in (("X", X_kfall), ("y", y_kfall)):
    print(f"   {tag} shape: {array.shape}")

# %% [markdown]
## 9. Process SisFall Dataset

# %%
def process_sisfall_dataset():
    """Turn every relevant SisFall recording into labelled 200x6 segments.

    Walks the ``SA*``/``SE*`` subject directories under ``sisfall_dir``, keeps
    only the ADL codes in ``sisfall_adl_map`` and the fall codes in
    ``sisfall_falls``, and segments each recording. Subjects and files are
    visited in sorted order so the output order is stable between runs. A file
    that fails while loading or processing is reported and skipped.

    Returns:
        ``(X, y)``: ``X`` has shape ``(n_segments, 200, 6)`` and ``y`` holds
        the matching label ids. When nothing was extracted, ``X`` has shape
        ``(0, 200, 6)`` and ``y`` shape ``(0,)``.
    """
    print("="*80)
    print("PROCESSING SISFALL DATASET")
    print("="*80)

    X_data = []
    y_labels = []

    subjects = sorted(
        d for d in sisfall_dir.iterdir()
        if d.is_dir() and d.name.startswith(('SA', 'SE'))
    )
    print(f"Found {len(subjects)} subjects")

    for subject_dir in tqdm(subjects, desc="Processing SisFall subjects"):
        for file in sorted(subject_dir.glob("*.txt")):
            # Extract activity code: D01_SA01_R01.txt -> D01
            parts = file.stem.split('_')
            if len(parts) < 2:
                continue
            activity_code = parts[0]

            # Filter on the activity code *before* reading the file, so that
            # recordings of unused activities are never loaded.
            is_adl = activity_code in sisfall_adl_map
            if not is_adl and activity_code not in sisfall_falls:
                continue

            try:
                data = load_sisfall_file(file)
                if data is None or len(data) < 200:
                    continue

                if is_adl:
                    features = process_adl_activity(data, sisfall_adl_map[activity_code])
                else:
                    features = process_fall_activity(data)

                for segment, label in features:
                    if segment.shape == (200, 6):
                        X_data.append(segment)
                        y_labels.append(label)
            except Exception as e:
                # Per-file boundary: report the problem (as the KFall loader
                # does) rather than dropping the file without a trace.
                print(f"  Error processing {file.name}: {e}")
                continue

    if not X_data:
        # Keep the documented rank even when nothing was extracted.
        return np.empty((0, 200, 6)), np.empty((0,), dtype=int)

    return np.array(X_data), np.array(y_labels)

# Execute the SisFall pipeline and report the shapes of what it produced.
print(f"\n{'=' * 80}")
X_sisfall, y_sisfall = process_sisfall_dataset()
print("\n‚úÖ SisFall processed:")
for tag, array in (("X", X_sisfall), ("y", y_sisfall)):
    print(f"   {tag} shape: {array.shape}")

# %% [markdown]
## 10. Z-Score Normalization and Dataset Fusion

# %%
def normalize_and_fuse(X_kfall, y_kfall, X_sisfall, y_sisfall):
    """
    Standardise both datasets, concatenate them, and standardise the result.

    Each dataset is z-scored per feature channel on its own, the two are
    stacked (KFall first), and the stacked array is z-scored once more. The
    paper states: "Z-score standardization was again employed before the
    final data was fed into the network to normalize the features extracted
    from the two datasets".

    Returns (X_fused_norm, y_fused, scaler_final), where scaler_final is the
    StandardScaler fitted in the last step.
    NOTE(review): scaler_final is fitted on data that was already normalised
    per dataset, so applying it alone to raw sensor values would presumably
    not reproduce this pipeline - confirm before using it for inference.
    """
    rule = "=" * 80
    print(rule)
    print("NORMALIZING AND FUSING DATASETS")
    print(rule)

    # KFall defines the window length and channel count used for everything.
    n_kfall, ts, feat = X_kfall.shape
    n_sisfall = X_sisfall.shape[0]

    def _zscore(windows, n_windows):
        """Fit a fresh scaler on the flattened rows; return (scaled, scaler)."""
        fitted = StandardScaler()
        scaled = fitted.fit_transform(windows.reshape(-1, feat))
        return scaled.reshape(n_windows, ts, feat), fitted

    print("\nStep 1: Normalize each dataset separately")
    X_kfall_norm, _ = _zscore(X_kfall, n_kfall)
    print(f"  KFall normalized: {X_kfall_norm.shape}")

    X_sisfall_norm, _ = _zscore(X_sisfall, n_sisfall)
    print(f"  SisFall normalized: {X_sisfall_norm.shape}")

    print("\nStep 2: Fuse datasets")
    X_fused = np.concatenate([X_kfall_norm, X_sisfall_norm], axis=0)
    y_fused = np.concatenate([y_kfall, y_sisfall], axis=0)
    print(f"  Fused dataset: {X_fused.shape}")

    print("\nStep 3: Normalize fused dataset")
    X_fused_norm, scaler_final = _zscore(X_fused, -1)
    print(f"  Final normalized: {X_fused_norm.shape}")

    # Sanity check on the overall statistics of the result.
    print("\nVerification:")
    print(f"  Mean: {X_fused_norm.mean():.6f} (should be ~0)")
    print(f"  Std:  {X_fused_norm.std():.6f} (should be ~1)")

    return X_fused_norm, y_fused, scaler_final

# Standardise and merge the two datasets, then report the final shapes.
X_final, y_final, scaler = normalize_and_fuse(X_kfall, y_kfall, X_sisfall, y_sisfall)

print("\n‚úÖ Final fused dataset:")
for tag, array in (("X", X_final), ("y", y_final)):
    print(f"   {tag} shape: {array.shape}")

# %% [markdown]
## 11. Verify Data Quality

# %%
from collections import Counter

rule = "=" * 80
print(rule)
print("DATA QUALITY VERIFICATION")
print(rule)

# 1. How many samples each class contributes.
counts = Counter(y_final)
n_total = len(y_final)
print("\n1. Class Distribution:")
for cls_id, count in sorted(counts.items()):
    pct = count / n_total * 100
    print(f"   {cls_id}: {reverse_label_map[cls_id]:30s} - {count:5d} ({pct:5.2f}%)")

print(f"\n   Total samples: {n_total}")

# 2. Variance of the Acc-Y channel (column 1) within each class.
print("\n2. Variance Test (Acc-Y axis):")
print(f"   {'Class':<35s} {'Variance':<12s}")
print("   " + "-" * 50)

variances = [
    (reverse_label_map[cls_id], X_final[y_final == cls_id][:, :, 1].var())
    for cls_id in sorted(counts)
]
for name, var in variances:
    print(f"   {name:<35s} {var:>10.4f}")

# 3. The same figures, largest variance first.
variances_sorted = sorted(variances, key=lambda pair: pair[1], reverse=True)
print("\n3. Variance Ranking:")
for position, (name, var) in enumerate(variances_sorted, 1):
    print(f"   {position}. {name:<35s}: {var:.4f}")

# The check expects Fall_Initiation among the two most variable classes.
fall_init_rank = next(
    position
    for position, (class_name, _) in enumerate(variances_sorted, 1)
    if class_name == 'Fall_Initiation'
)

if fall_init_rank > 2:
    print(f"\n‚ùå FAIL: Fall_Initiation ranked #{fall_init_rank} (should be #1 or #2)")
    print("   Segmentation may be incorrect!")
else:
    print(f"\n‚úÖ PASS: Fall_Initiation ranked #{fall_init_rank} (should be #1 or #2)")

# 4. Non-finite values would indicate a broken recording or conversion.
print("\n4. Data Integrity:")
print(f"   NaN values: {np.isnan(X_final).sum()}")
print(f"   Inf values: {np.isinf(X_final).sum()}")

# 5. Every sample must be 200 time steps by 6 channels.
print("\n5. Shape Verification:")
print("   Expected: (N, 200, 6)")
print(f"   Actual:   {X_final.shape}")
if X_final.shape[1:] == (200, 6):
    print("   ‚úÖ PASS")
else:
    print("   ‚ùå FAIL")

# %% [markdown]
## 12. Visualize Samples

# %%
# One panel per class id (0-7) in a 4x2 grid, showing the three acceleration
# channels of the first sample found for that class.
fig, axes = plt.subplots(4, 2, figsize=(16, 14))
axes = axes.flatten()

for class_id, panel in enumerate(axes):
    class_indices = np.where(y_final == class_id)[0]
    if len(class_indices) == 0:
        # No sample of this class: the panel stays empty.
        continue

    sample_idx = class_indices[0]
    sample_data = X_final[sample_idx]

    time = np.arange(200) / 200  # seconds, for 200 samples at 200 Hz
    for column, channel in enumerate(('Acc-X', 'Acc-Y', 'Acc-Z')):
        panel.plot(time, sample_data[:, column], label=channel, alpha=0.7, linewidth=1)

    panel.set_title(f'Class {class_id}: {reverse_label_map[class_id]}',
                    fontsize=11, fontweight='bold')
    panel.set_xlabel('Time (s)')
    panel.set_ylabel('Normalized Acc')
    panel.legend(loc='upper right', fontsize=8)
    panel.grid(True, alpha=0.3)

plt.suptitle('Sample Segments from Each Activity Class',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()

figure_path = processed_dir / 'sample_visualization.png'
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úÖ Visualization saved to {figure_path}")

# %% [markdown]
## 13. Save Processed Data

# %%
rule = "=" * 80
print(rule)
print("SAVING PROCESSED DATA")
print(rule)

# Feature windows and their labels, as NumPy binaries.
for array_file, array in (("X_data.npy", X_final), ("y_labels.npy", y_final)):
    np.save(processed_dir / array_file, array)

# The fitted scaler and the label map, pickled.
for pickle_file, payload in (("scaler.pkl", scaler), ("label_map.pkl", label_map)):
    with open(processed_dir / pickle_file, 'wb') as f:
        pickle.dump(payload, f)

# The label map once more, in a human-readable form.
with open(processed_dir / "label_map.json", 'w') as f:
    f.write(json.dumps(label_map, indent=2))

print(f"\n‚úÖ Saved to {processed_dir}/")
print(f"   üìÑ X_data.npy: {X_final.shape}")
print(f"   üìÑ y_labels.npy: {y_final.shape}")
for artefact in ("scaler.pkl", "label_map.pkl", "label_map.json"):
    print(f"   üìÑ {artefact}")

# %% [markdown]
## 14. Final Summary

# %%
print("\n" + "="*80)
print("PREPROCESSING PIPELINE COMPLETE")
print("="*80)

# Human-readable recap built from values computed by the earlier cells:
# sample counts per dataset, final array shapes, overall mean/std of the
# normalised data, and the Fall_Initiation variance rank (the run is reported
# as ready for training only when that rank is 1 or 2).
summary = f"""
‚úÖ Successfully processed {len(y_final):,} samples

Dataset Breakdown:
  - KFall samples: {len(y_kfall):,}
  - SisFall samples: {len(y_sisfall):,}
  
Data Shape:
  - Features: {X_final.shape}
  - Labels: {y_final.shape}
  
Normalization:
  - Mean: {X_final.mean():.6f}
  - Std: {X_final.std():.6f}
  
Quality Check:
  - Fall_Initiation rank: #{fall_init_rank} (should be ‚â§2)
  - Status: {'‚úÖ READY FOR TRAINING' if fall_init_rank <= 2 else '‚ùå NEEDS REVIEW'}

Next Steps:
  1. Load data with: X = np.load('{processed_dir}/X_data.npy')
  2. Train FallNet model
  3. Evaluate on stratified K-fold cross-validation
"""

print(summary)

# %%

In [None]:
# Start from the arrays written by the preprocessing cell, so re-running this
# cell always begins with the original 8-class labels.
X_data = np.load(processed_dir / "X_data.npy")
y_labels = np.load(processed_dir / "y_labels.npy")

# ----------------------------------------------------------------------------
# STEP 1: fold Aftermath (7) into Impact (6)
# ----------------------------------------------------------------------------
print("Merging Impact and Aftermath classes...")
y_labels[y_labels == 7] = 6

# ----------------------------------------------------------------------------
# STEP 2: drop the Fall_Recovery class (4)
# ----------------------------------------------------------------------------
print("\n" + "="*80)
print("REMOVING FALL_RECOVERY CLASS")
print("="*80)

from collections import Counter

# Snapshot of the situation before anything is removed.
counts_before = Counter(y_labels)
print("\nBefore removal:")
print(f"  Total samples: {len(y_labels):,}")
print(f"  Fall_Recovery (class 4): {counts_before[4]} samples")

# Keep every sample whose label is not Fall_Recovery.
mask = y_labels != 4
removed_count = (y_labels == 4).sum()
X_data = X_data[mask]
y_labels_temp = y_labels[mask]

print(f"\n‚úÖ Removed {removed_count} Fall_Recovery samples")

# Close the gap left by class 4: ids above it move down by one (5->4, 6->5).
y_labels = y_labels_temp - (y_labels_temp > 4)

original_total = len(y_labels) + removed_count
print("\nAfter removal:")
print(f"  Total samples: {len(y_labels):,}")
print(f"  Removed: {removed_count} samples ({removed_count/original_total*100:.2f}%)")

# ============================================================================
# STEP 3: Update label map (NOW 6 CLASSES: 0-5)
# ============================================================================
label_map = {
    'Walking': 0,
    'Jogging': 1,
    'Walking_stairs_updown': 2,
    'Stumble_while_walking': 3,
    'Fall_Initiation': 4,      # Was 5, now 4 ‚Üê SHIFTED DOWN!
    'Impact_Aftermath': 5,     # Was 6, now 5 ‚Üê SHIFTED DOWN!
}
reverse_label_map = {v: k for k, v in label_map.items()}

print(f"\n‚úÖ Updated to 6 classes (0-5):")
for name, idx in sorted(label_map.items(), key=lambda x: x[1]):
    print(f"  Class {idx}: {name}")

y_categorical = keras.utils.to_categorical(y_labels, num_classes=6)  # ‚Üê HERE!
print(f"y_categorical shape: {y_categorical.shape}")

# ============================================================================
# DIAGNOSTICS
# ============================================================================
print("\n" + "="*80)
print("POST-REMOVAL DATA DIAGNOSTICS")
print("="*80)

# 1. Class distribution
class_counts = Counter(y_labels)
print("\n1. Class Distribution (6 classes):")
for cls_idx in sorted(class_counts.keys()):
    count = class_counts[cls_idx]
    pct = count / len(y_labels) * 100
    print(f"   Class {cls_idx} ({reverse_label_map[cls_idx]:30s}): {count:5d} ({pct:5.2f}%)")

# Imbalance = largest class / smallest class. The "before" figure is derived
# from the counts captured prior to dropping Fall_Recovery instead of being a
# hard-coded number that goes stale whenever the data changes.
max_count = max(class_counts.values())
min_count = min(class_counts.values())
imbalance_before = max(counts_before.values()) / min(counts_before.values())
print(f"\nImbalance ratio: {max_count/min_count:.2f}x (was {imbalance_before:.1f}x with Fall_Recovery)")

# 2. Per-class signal statistics (channel 1, plotted below as 'Acc-Y')
print("\n2. Per-Class Signal Statistics (Acc-Y axis):")
print(f"   {'Class':<35s} {'Mean':<10s} {'Std':<10s} {'Min':<10s} {'Max':<10s}")
print(f"   {'-'*75}")
variances = []
for cls_idx in sorted(class_counts.keys()):
    class_samples = X_data[y_labels == cls_idx]
    acc_y = class_samples[:, :, 1]  # Y-axis acceleration

    mean_val = acc_y.mean()
    std_val = acc_y.std()
    min_val = acc_y.min()
    max_val = acc_y.max()

    print(f"   {reverse_label_map[cls_idx]:<35s} {mean_val:>8.4f}  {std_val:>8.4f}  {min_val:>8.2f}  {max_val:>8.2f}")

    # Collect the variance here so section 3 does not have to boolean-index
    # (and therefore copy) X_data a second time for every class.
    variances.append((reverse_label_map[cls_idx], acc_y.var(), cls_idx))

# 3. Variance ranking, highest variance first
print("\n3. Variance Ranking (Fall_Initiation should be #1):")
variances.sort(key=lambda x: x[1], reverse=True)
for i, (name, var, idx) in enumerate(variances, 1):
    print(f"   {i}. {name:<35s}: {var:.4f}")

# 4. Visualize the first sample of each class on a 3x2 grid
# Divisor used to turn a sample index into seconds on the x axis.
# NOTE(review): 200 is kept from the original code — confirm it matches the
# sampling rate of the resampled windows.
PLOT_SAMPLING_RATE_HZ = 200

fig, axes = plt.subplots(3, 2, figsize=(15, 10))
axes = axes.flatten()
critical_classes = [
    label_map['Walking'],
    label_map['Fall_Initiation'],
    label_map['Impact_Aftermath'],
    label_map['Stumble_while_walking'],
    label_map['Jogging'],
    label_map['Walking_stairs_updown']
]
for i, cls_idx in enumerate(critical_classes):
    if cls_idx in class_counts:
        sample_idx = np.where(y_labels == cls_idx)[0][0]
        sample_data = X_data[sample_idx]

        # Size the time axis from the window itself so the plot also works
        # when the window length is not 200 samples.
        time = np.arange(sample_data.shape[0]) / PLOT_SAMPLING_RATE_HZ
        for channel, axis_name in enumerate(('Acc-X', 'Acc-Y', 'Acc-Z')):
            axes[i].plot(time, sample_data[:, channel], label=axis_name, alpha=0.7, linewidth=1)

        axes[i].set_title(f'{reverse_label_map[cls_idx]}', fontsize=11, fontweight='bold')
        axes[i].set_xlabel('Time (s)')
        axes[i].set_ylabel('Normalized Acc')
        axes[i].legend(fontsize=8)
        axes[i].grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\n" + "="*80)
print("‚úÖ DATA READY FOR TRAINING (6 CLASSES)")
print("="*80)

In [None]:
# %% [markdown]
# # Save Preprocessed 6-Class Data
# Save the cleaned dataset after merging Impact/Aftermath and removing Fall_Recovery

# %%
import numpy as np
from pathlib import Path

# Reuse the project-relative `processed_dir` resolved at the top of the
# notebook. The previous hard-coded "~/repos/..." path silently redirected the
# output to one specific machine layout and overrode the project-root lookup.
processed_dir.mkdir(parents=True, exist_ok=True)

print("="*80)
print("SAVING PREPROCESSED 6-CLASS DATA")
print("="*80)

# Arrays: windows, integer labels (0-5) and their one-hot encoding
save_path_X = processed_dir / "X_data_6class.npy"
save_path_y = processed_dir / "y_labels_6class.npy"
save_path_y_cat = processed_dir / "y_categorical_6class.npy"

np.save(save_path_X, X_data)
np.save(save_path_y, y_labels)
np.save(save_path_y_cat, y_categorical)

print(f"\n‚úÖ Saved preprocessed data:")
print(f"   X_data:        {save_path_X}")
print(f"   y_labels:      {save_path_y}")
print(f"   y_categorical: {save_path_y_cat}")

print(f"\nSaved shapes:")
print(f"   X_data:        {X_data.shape}")
print(f"   y_labels:      {y_labels.shape}")
print(f"   y_categorical: {y_categorical.shape}")

# Label mapping (class name -> class id), saved in two formats:
#   * .npy  - NumPy pickles the dict into an object array, so it must be
#             loaded with allow_pickle=True (kept for existing loaders)
#   * .json - plain text, readable without unpickling
label_map_path = processed_dir / "label_map_6class.npy"
np.save(label_map_path, label_map)

import json
label_map_json_path = processed_dir / "label_map_6class.json"
with open(label_map_json_path, 'w') as f:
    json.dump(label_map, f, indent=2)
print(f"\n   label_map:     {label_map_path}")
print(f"   label_map:     {label_map_json_path}")

print("\n" + "="*80)
print("‚úÖ ALL DATA SAVED SUCCESSFULLY")
print("="*80)
print("\nTo load this data in future notebooks:")
print("```python")
print("X_data = np.load(processed_dir / 'X_data_6class.npy')")
print("y_labels = np.load(processed_dir / 'y_labels_6class.npy')")
print("y_categorical = np.load(processed_dir / 'y_categorical_6class.npy')")
print("label_map = np.load(processed_dir / 'label_map_6class.npy', allow_pickle=True).item()")
print("```")

In [None]:
# %% [markdown]
# # FallNet Training Pipeline
# CNN-LMU ensemble for fall detection with 6 classes

# %%
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score
import warnings
warnings.filterwarnings('ignore')
from keras_lmu import LMU
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU available: {tf.config.list_physical_devices('GPU')}")

# %% [markdown]
## 1. FallNet Model Architecture

# %%
class FallNet:
    """
    FallNet: CNN-LMU ensemble for pre-impact fall detection.

    Each ``build_*`` method wires one or more branches onto a fresh Keras
    ``Input`` of shape ``input_shape``, stores the resulting model on
    ``self.model`` and returns it. Every branch ends in a softmax over
    ``n_classes``, so a branch can be used on its own or averaged with
    another branch. Call ``compile_model`` after one of the builders.
    """

    # Adam learning rate used when compile_model() is not given one.
    DEFAULT_LEARNING_RATE = 5e-4

    def __init__(self, input_shape=(200, 6), n_classes=6):
        """
        Args:
            input_shape: (timesteps, features) = (200, 6)
            n_classes: Number of output classes (6)
        """
        self.input_shape = input_shape
        self.n_classes = n_classes
        self.model = None

    def build_lmu_branch(self, inputs):
        """LMU branch: LMU -> Dense(512) -> BatchNorm -> Dropout -> Dense(128) -> softmax.

        Args:
            inputs: Keras tensor the branch is attached to.

        Returns:
            Softmax output tensor with ``n_classes`` units.
        """
        x = LMU(
            memory_d=2,
            order=64,
            theta=200.0,  # same value as the default number of timesteps (200)
            hidden_cell=layers.LSTMCell(128),  # nonlinear hidden component of the LMU
            kernel_regularizer=keras.regularizers.l2(1e-6),  # L2 only
            dropout=0.2,
            name='lmu'
        )(inputs)

        # Wider Dense layer to interpret the memory
        x = layers.Dense(512, activation='relu', name='lmu_dense1')(x)
        x = layers.BatchNormalization()(x)  # intended to speed up convergence
        x = layers.Dropout(0.3)(x)

        x = layers.Dense(128, activation='relu', name='lmu_dense2')(x)

        lmu_output = layers.Dense(self.n_classes, activation='softmax')(x)
        return lmu_output

    def build_lstm_branch(self, inputs):
        """LSTM branch: LSTM(256) -> three Dense/Dropout stages (128, 64, 32) -> softmax.

        Args:
            inputs: Keras tensor the branch is attached to.

        Returns:
            Softmax output tensor with ``n_classes`` units.
        """
        x = layers.LSTM(
            units=256,
            activation='tanh',
            return_sequences=False,  # only the final state feeds the dense head
            name='lstm_layer'
        )(inputs)

        x = layers.Dense(128, activation='relu', name='lstm_dense1')(x)
        x = layers.Dropout(0.2, name='lstm_dropout1')(x)

        x = layers.Dense(64, activation='relu', name='lstm_dense2')(x)
        x = layers.Dropout(0.2, name='lstm_dropout2')(x)

        x = layers.Dense(32, activation='relu', name='lstm_dense3')(x)
        x = layers.Dropout(0.2, name='lstm_dropout3')(x)

        lstm_output = layers.Dense(
            self.n_classes,
            activation='softmax',
            name='lstm_output'
        )(x)

        return lstm_output

    def build_cnn_branch(self, inputs):
        """CNN branch: Conv1D(128, k=3) -> MaxPool -> GlobalAveragePool -> Dense(64) -> softmax.

        Args:
            inputs: Keras tensor the branch is attached to.

        Returns:
            Softmax output tensor with ``n_classes`` units.
        """
        x = layers.Conv1D(
            filters=128,
            kernel_size=3,
            activation='relu',
            padding='same',
            name='conv1d_layer'
        )(inputs)

        x = layers.MaxPooling1D(pool_size=2, name='maxpool_layer')(x)

        # Collapse the time axis so the dense head sees one vector per window
        x = layers.GlobalAveragePooling1D()(x)
        x = layers.Dropout(0.2, name='cnn_dropout1')(x)

        x = layers.Dense(64, activation='relu', name='cnn_dense2')(x)
        x = layers.Dropout(0.2, name='cnn_dropout2')(x)

        cnn_output = layers.Dense(
            self.n_classes,
            activation='softmax',
            name='cnn_output'
        )(x)

        return cnn_output

    def build_cnn_only(self):
        """Build CNN-only model (no temporal component)"""
        inputs = layers.Input(shape=self.input_shape, name='input')
        cnn_output = self.build_cnn_branch(inputs)

        self.model = models.Model(
            inputs=inputs,
            outputs=cnn_output,
            name='FallNet_CNN_Only'
        )
        return self.model

    def build_lstm_only(self):
        """Build LSTM-only model (temporal encoding via gates)"""
        inputs = layers.Input(shape=self.input_shape, name='input')
        lstm_output = self.build_lstm_branch(inputs)

        self.model = models.Model(
            inputs=inputs,
            outputs=lstm_output,
            name='FallNet_LSTM_Only'
        )
        return self.model

    def build_lmu_only(self):
        """Build LMU-only model (temporal encoding via Legendre polynomials)"""
        inputs = layers.Input(shape=self.input_shape, name='input')
        lmu_output = self.build_lmu_branch(inputs)

        self.model = models.Model(
            inputs=inputs,
            outputs=lmu_output,
            name='FallNet_LMU_Only'
        )
        return self.model

    def build_ensemble(self):
        """Build the ensemble: the LMU and CNN branch softmax outputs are averaged."""
        inputs = layers.Input(shape=self.input_shape, name='input')

        lmu_output = self.build_lmu_branch(inputs)
        cnn_output = self.build_cnn_branch(inputs)

        ensemble_output = layers.Average(name='ensemble_average')([lmu_output, cnn_output])

        # Named after the branches that are actually averaged (LMU + CNN); the
        # previous name said "LSTM" although no LSTM branch is part of it.
        self.model = models.Model(
            inputs=inputs,
            outputs=ensemble_output,
            name='FallNet_CNN_LMU'
        )

        return self.model

    def compile_model(self, learning_rate=None):
        """Compile ``self.model`` with Adam, label-smoothed cross-entropy and metrics.

        Args:
            learning_rate: Adam learning rate. ``None`` selects
                ``DEFAULT_LEARNING_RATE`` (5e-4).

        Returns:
            The compiled model (the same object as ``self.model``).

        Raises:
            ValueError: If no ``build_*`` method has been called yet.
        """
        if self.model is None:
            raise ValueError("Model not built yet. Call one of the build_*() methods first.")

        # Honour the caller's learning rate. It used to be computed and then
        # discarded in favour of a hard-coded Adam(5e-4) in compile().
        if learning_rate is None:
            learning_rate = self.DEFAULT_LEARNING_RATE
        optimizer = keras.optimizers.Adam(learning_rate=learning_rate)

        self.model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
            metrics=[
                'accuracy',
                keras.metrics.Precision(name='precision'),
                keras.metrics.Recall(name='recall')
            ]
        )

        return self.model

print("‚úÖ FallNet class defined")

# %% [markdown]
## 2. Build and Display Model

# %%
banner = "=" * 80
print("\n" + banner)
print("BUILDING FALLNET MODEL")
print(banner)

# Instantiate the 6-class FallNet and build the ensemble inside an explicit
# CPU device scope.
with tf.device('/CPU:0'):
    fallnet = FallNet(input_shape=(200, 6), n_classes=6)
    fallnet.build_ensemble()

# compile_model() returns the model that was just built, now compiled
model = fallnet.compile_model()

# Display architecture
print("\n")
model.summary()

# Count parameters
def count_parameters(model):
    """Return ``(trainable, non_trainable)`` parameter counts as plain ints.

    Args:
        model: Object exposing ``trainable_weights`` and
            ``non_trainable_weights``, each an iterable of weights with a
            ``shape`` attribute.

    Returns:
        Tuple of two ints. An empty weight list counts as ``0`` (previously
        ``np.sum([])`` produced the float ``0.0``, which printed as "0.0").
    """
    def _total(weights):
        # The product of a weight's shape is its number of scalar parameters.
        return sum(int(np.prod(w.shape)) for w in weights)

    return _total(model.trainable_weights), _total(model.non_trainable_weights)

trainable, non_trainable = count_parameters(model)

print("\n" + "="*80)
print("MODEL PARAMETERS")
print("="*80)
# Labels are left-aligned in a 15-character column so the counts line up.
for row_label, row_value in (
    ("Trainable:", trainable),
    ("Non-trainable:", non_trainable),
    ("Total:", trainable + non_trainable),
):
    print(f"{row_label:<15}{row_value:,}")

# %% [markdown]
## 3. Training Configuration

# %%
# Hyper-parameters shared by every cross-validation fold
BATCH_SIZE, EPOCHS, K_FOLDS = 128, 50, 5

print("\n" + "="*80)
print("TRAINING CONFIGURATION")
print("="*80)
# Labels are padded to 12 characters so the values line up.
for setting_label, setting_value in (
    ("Batch size:", BATCH_SIZE),
    ("Max epochs:", EPOCHS),
    ("K-Folds:", K_FOLDS),
):
    print(f"{setting_label:<12}{setting_value}")
print(f"Using data from previous cell (6 classes, {len(y_labels):,} samples)")

# %% [markdown]
## 4. Verify Data Before Training

# %%
print("\n" + "="*80)
print("PRE-TRAINING VERIFICATION")
print("="*80)

# Report the shape of every array that feeds training
print(f"‚úÖ Data shapes:")
for array_label, array in (
    ("X_data:", X_data),
    ("y_labels:", y_labels),
    ("y_categorical:", y_categorical),
):
    print(f"   {array_label:<15}{array.shape}")

n_unique_labels = len(np.unique(y_labels))
print(f"\n‚úÖ Classes: {n_unique_labels} (should be 6)")
print(f"‚úÖ Label range: {y_labels.min()}-{y_labels.max()} (should be 0-5)")
n_model_outputs = model.output_shape[-1]
print(f"‚úÖ Model output: {n_model_outputs} (should be 6)")

# Hard stops: sample counts must agree, and labels and model must both be 6-class
assert X_data.shape[0] == y_labels.shape[0] == y_categorical.shape[0], "Shape mismatch!"
assert n_unique_labels == 6, "Should have 6 classes!"
assert y_labels.max() == 5, "Max label should be 5!"
assert n_model_outputs == 6, "Model should output 6 classes!"

print("\n‚úÖ All checks passed - ready to train!")

# %% [markdown]
## 5. K-Fold Cross-Validation Training

# %%
from sklearn.utils.class_weight import compute_class_weight

# ---------------------------------------------------------------------------
# 5.5 Class weights
# They depend only on the full label vector, so they are computed once here
# instead of being recomputed (and re-printed) inside every fold.
# ---------------------------------------------------------------------------
print("\n" + "="*80)
print("CALCULATING CLASS WEIGHTS")
print("="*80)

# 1. Balanced weights
weight_classes = np.unique(y_labels)
class_weights_array = compute_class_weight(
    class_weight='balanced',
    classes=weight_classes,
    y=y_labels
)

# 2. Key each weight by its actual class id rather than by its position, so
#    the mapping stays correct even if the class ids are not exactly 0..N-1.
class_weights = {
    int(cls_id): float(weight)
    for cls_id, weight in zip(weight_classes, class_weights_array)
}

# 3. Manual boosts on top of the balanced weights, looked up by class name so
#    they follow the label map (these are the former indices 0, 3 and 4).
CLASS_WEIGHT_BOOSTS = {
    'Walking': 1.5,
    'Stumble_while_walking': 3.0,   # the class the original author flagged as hardest
    'Fall_Initiation': 1.2,         # safety-critical class
}
for boosted_name, boost in CLASS_WEIGHT_BOOSTS.items():
    class_weights[label_map[boosted_name]] *= boost

# 4. Cap weights to prevent instability
MAX_WEIGHT = 5.0
class_weights = {k: min(w, MAX_WEIGHT) for k, w in class_weights.items()}

print("\nTargeted Class Weights:")
for cls_idx in sorted(class_weights):
    print(f"  {reverse_label_map[cls_idx]:<30s}: {class_weights[cls_idx]:.2f}x")

# ---------------------------------------------------------------------------
# Cross-validation
# ---------------------------------------------------------------------------
skf = StratifiedKFold(n_splits=K_FOLDS, shuffle=True, random_state=42)

fold_results = []
fold_histories = []

print("\n" + "="*80)
print("STARTING K-FOLD CROSS-VALIDATION")
print("="*80)

for fold, (train_idx, val_idx) in enumerate(skf.split(X_data, y_labels), 1):
    print(f"\n{'='*80}")
    print(f"FOLD {fold}/{K_FOLDS}")
    print(f"{'='*80}")

    # Split data
    X_train, X_val = X_data[train_idx], X_data[val_idx]
    y_train, y_val = y_categorical[train_idx], y_categorical[val_idx]

    print(f"Train: {X_train.shape[0]:,} samples | Val: {X_val.shape[0]:,} samples")

    # Build a fresh LMU-only model for this fold, sized from the data itself
    fallnet_fold = FallNet(
        input_shape=tuple(X_data.shape[1:]),
        n_classes=y_categorical.shape[1]
    )
    model_fold = fallnet_fold.build_lmu_only()
    model_fold = fallnet_fold.compile_model()

    # Single source of truth for this fold's checkpoint file name
    checkpoint_path = output_dir / f'fallnet_lmu_boosted_fold_{fold}.keras'

    # Callbacks for THIS fold.
    # NOTE: EarlyStopping restores the best-val_loss weights into model_fold,
    # while the checkpoint on disk holds the best-val_accuracy epoch, so the
    # metrics reported below and the saved file may come from different epochs.
    fold_callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=20,
            restore_best_weights=True,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=10,
            min_lr=1e-7,
            verbose=1
        ),
        ModelCheckpoint(
            filepath=str(checkpoint_path),
            monitor='val_accuracy',
            save_best_only=True,
            mode='max',
            verbose=1
        )
    ]

    print(f"\nTraining fold {fold}...")
    history = model_fold.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        batch_size=BATCH_SIZE,
        epochs=EPOCHS,
        class_weight=class_weights,
        callbacks=fold_callbacks,
        verbose=1
    )

    # Evaluate on the held-out fold. The metrics are aggregated over all
    # validation samples, so the training batch size is used here instead of
    # the much slower batch_size=2.
    val_loss, val_acc, val_precision, val_recall = model_fold.evaluate(
        X_val, y_val, batch_size=BATCH_SIZE, verbose=0
    )
    val_f1 = 2 * (val_precision * val_recall) / (val_precision + val_recall) if (val_precision + val_recall) > 0 else 0

    print(f"\n{'='*50}")
    print(f"Fold {fold} Results:")
    print(f"{'='*50}")
    print(f"Loss:      {val_loss:.4f}")
    print(f"Accuracy:  {val_acc:.4f}")
    print(f"Precision: {val_precision:.4f}")
    print(f"Recall:    {val_recall:.4f}")
    print(f"F1-Score:  {val_f1:.4f}")

    # Store results
    fold_results.append({
        'fold': fold,
        'val_loss': val_loss,
        'val_accuracy': val_acc,
        'val_precision': val_precision,
        'val_recall': val_recall,
        'val_f1': val_f1
    })

    fold_histories.append(history.history)

    # Report the file ModelCheckpoint actually wrote (the old message printed
    # a different, non-existent file name).
    print(f"‚úÖ Model saved: {checkpoint_path.name}")

print("\n" + "="*80)
print("K-FOLD CROSS-VALIDATION COMPLETE")
print("="*80)

# %% [markdown]
## 6. Aggregate Results

# %%
# One row per fold, columns as stored in fold_results
results_df = pd.DataFrame(fold_results)

section_bar = "=" * 80
print("\n" + section_bar)
print("RESULTS ACROSS ALL FOLDS")
print(section_bar)
print(results_df.to_string(index=False))

print("\n" + section_bar)
print("AVERAGE PERFORMANCE ¬± STD")
print(section_bar)

# Column-wise mean and standard deviation across folds
mean_results = results_df.mean(numeric_only=True)
std_results = results_df.std(numeric_only=True)

# Formatted mean/std table for the validation metrics
reported_metrics = ('val_loss', 'val_accuracy', 'val_precision', 'val_recall', 'val_f1')
metrics_table = [
    {
        'Metric': metric_name,
        'Mean': f"{mean_results[metric_name]:.4f}",
        'Std': f"¬±{std_results[metric_name]:.4f}",
    }
    for metric_name in reported_metrics
]

metrics_df = pd.DataFrame(metrics_table)
print(metrics_df.to_string(index=False))

# %% [markdown]
## 7. Visualize Training History

# %%
# 2x2 grid: one panel per tracked metric, every fold overlaid in each panel
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# (history key, panel title)
metrics = [
    ('loss', 'Loss'),
    ('accuracy', 'Accuracy'),
    ('precision', 'Precision'),
    ('recall', 'Recall')
]

for ax, (metric, title) in zip(axes.flat, metrics):
    # Dedicated loop names so the training loop's `fold` and `history`
    # variables are not overwritten by this plotting cell.
    for fold_no, fold_history in enumerate(fold_histories, 1):
        epochs = range(1, len(fold_history[metric]) + 1)
        ax.plot(epochs, fold_history[metric], label=f'Fold {fold_no} Train', alpha=0.5, linewidth=1)
        ax.plot(epochs, fold_history[f'val_{metric}'], label=f'Fold {fold_no} Val',
                linestyle='--', alpha=0.7, linewidth=1.5)

    ax.set_title(f'{title} Across All Folds', fontsize=13, fontweight='bold')
    ax.set_xlabel('Epoch', fontsize=11)
    ax.set_ylabel(title, fontsize=11)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=7)
    ax.grid(True, alpha=0.3)

# The fold count comes from K_FOLDS so the title cannot drift from the
# configuration (it used to be the literal "5-Fold").
plt.suptitle(f'FallNet Training History - {K_FOLDS}-Fold Cross-Validation',
             fontsize=15, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig(output_dir / 'training_history.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úÖ Training history saved to {output_dir / 'training_history.png'}")

# %% [markdown]
## 8. Detailed Evaluation on Best Fold

# %%
# The best fold is the one with the highest validation F1
best_fold = int(results_df.loc[results_df['val_f1'].idxmax(), 'fold'])

print("\n" + "="*80)
print(f"DETAILED EVALUATION - BEST FOLD #{best_fold}")
print("="*80)
print(f"Best fold F1-Score: {results_df.loc[results_df['fold']==best_fold, 'val_f1'].values[0]:.4f}")

# Load the checkpoint that ModelCheckpoint wrote for the BEST fold. The
# previous code used a file name that is never written
# ('fallnet_boosted_lmu_fold_*') and the stale loop variable `fold`, i.e. the
# last fold rather than the best one.
best_model_path = output_dir / f'fallnet_lmu_boosted_fold_{best_fold}.keras'
best_model = keras.models.load_model(str(best_model_path))

# Get predictions on ALL data.
# NOTE: this includes the samples the best-fold model was trained on, so the
# report below is optimistic compared with the cross-validated metrics.
y_pred_probs = best_model.predict(X_data, verbose=0)
y_pred = np.argmax(y_pred_probs, axis=1)

# Class names ordered by class id, for the report and the confusion matrix
class_names = [reverse_label_map[i] for i in range(len(reverse_label_map))]

print("\n" + "="*80)
print("CLASSIFICATION REPORT (Best Fold on All Data)")
print("="*80)
print(classification_report(y_labels, y_pred, target_names=class_names, digits=4))

# %% [markdown]
## 9. Per-Class Detailed Metrics

# %%
print("\n" + "="*80)
print("PER-CLASS DETAILED METRICS")
print("="*80)

print(f"\n{'Class':<40s} {'Precision':<12s} {'Recall':<12s} {'F1-Score':<12s} {'Support'}")
print("-"*90)

# One-vs-rest scoring: each class in turn is treated as the positive label
for cls_idx in range(6):
    is_true_cls = y_labels == cls_idx
    is_pred_cls = y_pred == cls_idx

    precision = precision_score(is_true_cls, is_pred_cls, zero_division=0)
    recall = recall_score(is_true_cls, is_pred_cls, zero_division=0)
    f1 = f1_score(is_true_cls, is_pred_cls, zero_division=0)
    support = np.sum(is_true_cls)  # number of true samples of this class

    print(f"{reverse_label_map[cls_idx]:<40s} {precision:<12.4f} {recall:<12.4f} {f1:<12.4f} {support}")

# %% [markdown]
## 10. Confusion Matrix

# %%
# Rows are true labels, columns are predicted labels
cm = confusion_matrix(y_labels, y_pred)
confusion_matrix_path = output_dir / 'confusion_matrix.png'

heatmap_options = dict(
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_names,
    yticklabels=class_names,
    cbar_kws={'label': 'Count'},
)

plt.figure(figsize=(12, 10))
sns.heatmap(cm, **heatmap_options)
plt.title('Confusion Matrix - Best Fold (6 Classes)', fontsize=15, fontweight='bold', pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.xticks(rotation=45, ha='right', fontsize=10)
plt.yticks(rotation=0, fontsize=10)
plt.tight_layout()
plt.savefig(confusion_matrix_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úÖ Confusion matrix saved to {confusion_matrix_path}")

# %% [markdown]
## 11. Final Summary

# %%
# One-vs-rest metrics for Fall_Initiation, the safety-critical class
fall_init_idx = label_map["Fall_Initiation"]
fall_init_true = y_labels == fall_init_idx
fall_init_pred = y_pred == fall_init_idx
fall_init_precision = precision_score(fall_init_true, fall_init_pred, zero_division=0)
fall_init_recall = recall_score(fall_init_true, fall_init_pred, zero_division=0)
fall_init_f1 = f1_score(fall_init_true, fall_init_pred, zero_division=0)

print("\n" + "="*80)
print("TRAINING COMPLETE - FINAL SUMMARY")
print("="*80)

# With K folds, each fold validates on ~N/K samples and trains on the other
# ~N*(K-1)/K (the previous 0.8*N//K and 0.2*N//K figures were off by a factor
# of K).
n_samples = len(y_labels)
val_per_fold = n_samples / K_FOLDS
train_per_fold = n_samples - val_per_fold

# File names as written by the ModelCheckpoint callback of the training loop
best_model_file = output_dir / f'fallnet_lmu_boosted_fold_{best_fold}.keras'
all_models_glob = output_dir / 'fallnet_lmu_boosted_fold_*.keras'

# The model description is taken from the loaded best model, so the report
# names the architecture that was actually trained.
summary = f"""
‚úÖ Successfully trained FallNet with {K_FOLDS}-fold cross-validation

Configuration:
  - Model: {best_model.name} ({len(label_map)} classes)
  - Total samples: {n_samples:,}
  - Training samples per fold: ~{train_per_fold:,.0f}
  - Validation samples per fold: ~{val_per_fold:,.0f}

Average Performance ({K_FOLDS}-fold CV):
  - Accuracy:  {mean_results['val_accuracy']:.4f} ¬± {std_results['val_accuracy']:.4f}
  - Precision: {mean_results['val_precision']:.4f} ¬± {std_results['val_precision']:.4f}
  - Recall:    {mean_results['val_recall']:.4f} ¬± {std_results['val_recall']:.4f}
  - F1-Score:  {mean_results['val_f1']:.4f} ¬± {std_results['val_f1']:.4f}

Fall_Initiation Performance (Critical Class):
  - Precision:            {fall_init_precision:.4f}
  - Recall (Sensitivity): {fall_init_recall:.4f}
  - F1-Score:             {fall_init_f1:.4f}

Saved Files:
  - Training history:    {output_dir / 'training_history.png'}
  - Confusion matrix:    {output_dir / 'confusion_matrix.png'}
  - Best model:          {best_model_file}
  - All fold models:     {all_models_glob}
"""

print(summary)

# The summary contains non-ASCII characters, so the encoding is set explicitly
# rather than relying on the platform default.
with open(output_dir / 'training_summary.txt', 'w', encoding='utf-8') as f:
    f.write(summary)

print(f"‚úÖ Summary saved to {output_dir / 'training_summary.txt'}")