# Gait Phase Feature Analysis

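This notebook analyzes the distribution of EMG features across the three gait-phase classes:
- **Stance** (Class 0)
- **Swing** (Class 1)
- **None** (Class 2): samples from static modes, plus dynamic-mode samples without a stance/swing label

It also visualizes how the features evolve over the continuous gait cycle, from one heel strike to the next.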

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Make the project root (the parent of the notebook's working directory) importable
# so that the `lib` package below resolves. The membership check keeps the cell
# idempotent: re-running it no longer appends a duplicate sys.path entry each time.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from lib.data_loader import Enabl3sDataLoader
from lib.preprocess import EMGPreprocessor

# Global figure look for the notebook: matplotlib's white-grid seaborn style,
# then seaborn's "notebook" context with every font scaled by 1.2.
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_context(context="notebook", font_scale=1.2)

In [None]:
# Phase_Class id -> display name (ids are assigned in load_data).
PHASE_NAMES = dict(enumerate(["Stance", "Swing", "None"]))

# Values of the 'Mode' column that are kept for analysis.
# NOTE(review): which activity each id stands for is defined by the dataset; confirm.
DYNAMIC_MODES = [1, 2, 3]  # split into Stance / Swing using Label_Phase
STATIC_MODES = [0, 6]      # always labelled "None"

# EMG columns analysed throughout the notebook.
EMG_CHANNELS = ['TA', 'MG', 'RF']
# Target sampling rate handed to Enabl3sDataLoader (presumably Hz).
TARGET_FS = 250

In [None]:
def calculate_gait_phase_cycle(df, fill_value=0.0):
    """Compute a continuous gait-cycle phase for every sample of one recording.

    A sample at row position ``i`` that lies between two consecutive heel
    strikes ``prev_hs <= i < next_hs`` gets ``(i - prev_hs) / (next_hs - prev_hs)``.
    That is 0.0 at a heel strike, rising linearly towards (but never reaching)
    1.0 just before the next one. Row *positions* are used, so the rows of
    ``df`` are assumed to be contiguous and time-ordered.

    Parameters
    ----------
    df : pandas.DataFrame
        One recording (e.g. one circuit). A heel strike is any row whose
        ``Heel_Strike`` value is not NaN.
        NOTE(review): this assumes ``Heel_Strike`` is an event marker that is
        NaN between events. A 0/1 indicator column would make every row a
        "heel strike"; confirm against the loader.
    fill_value : float, default 0.0
        Value used for samples that are not inside a complete
        heel-strike-to-heel-strike cycle (before the first heel strike, or at/after
        the last one). It is also used for every sample when ``Heel_Strike`` or
        ``Toe_Off`` is missing, or fewer than two heel strikes exist. The default
        keeps the historical behaviour, but 0.0 cannot be told apart from a real
        heel strike; pass ``np.nan`` to mark these samples explicitly.

    Returns
    -------
    numpy.ndarray
        Float array of length ``len(df)``. Values are in ``[0, 1)``, or ``fill_value``.
    """
    n_samples = len(df)
    phase_cycle = np.full(n_samples, fill_value, dtype=float)

    # Toe_Off is not used in the computation, but both event columns are required.
    if 'Heel_Strike' not in df.columns or 'Toe_Off' not in df.columns:
        return phase_cycle

    hs_positions = np.flatnonzero(df['Heel_Strike'].notna().to_numpy())
    if len(hs_positions) < 2:
        return phase_cycle

    # For each sample i, count the heel strikes at positions <= i. A count k in
    # [1, len(hs_positions) - 1] places the sample in cycle [hs[k-1], hs[k]).
    sample_idx = np.arange(n_samples)
    k = np.searchsorted(hs_positions, sample_idx, side='right')
    in_cycle = (k > 0) & (k < len(hs_positions))

    idx = sample_idx[in_cycle]
    prev_hs = hs_positions[k[in_cycle] - 1]
    next_hs = hs_positions[k[in_cycle]]
    # hs_positions is strictly increasing, so next_hs - prev_hs >= 1: no zero division.
    phase_cycle[idx] = (idx - prev_hs) / (next_hs - prev_hs)
    return phase_cycle

def load_data(subject, data_root="../../data"):
    """Load, preprocess and label all circuits of one subject.

    Circuits 1..N are loaded through ``Enabl3sDataLoader``, where N is the number
    of ``*_raw.csv`` files in ``<data_root>/<subject>/Raw``. The EMG channels are
    filtered and rectified. A per-circuit ``Gait_Phase_Cycle`` is added, rows are
    restricted to ``DYNAMIC_MODES + STATIC_MODES``, and ``Phase_Class`` /
    ``Phase_Name`` are assigned.

    Parameters
    ----------
    subject : str
        Subject folder name, e.g. ``"AB156"``.
    data_root : str
        Directory that contains one folder per subject.

    Returns
    -------
    pandas.DataFrame
        The labelled samples. An empty frame from the loader is returned as-is.
    """
    print(f"Loading {subject}...")
    loader = Enabl3sDataLoader(data_root, subject, target_fs=TARGET_FS)
    raw_dir = os.path.join(data_root, subject, 'Raw')

    # NOTE(review): assumes circuits are numbered 1..N without gaps, one per
    # *_raw.csv file. Confirm against Enabl3sDataLoader's file naming.
    circuit_files = [f for f in os.listdir(raw_dir) if f.endswith('_raw.csv')]
    circuits_to_load = range(1, len(circuit_files) + 1)

    load_cols = EMG_CHANNELS + ['Mode', 'Heel_Strike', 'Toe_Off']
    df = loader.load_dataset_batch(circuits_to_load, load_cols)

    if df.empty:
        return df

    # Time-dependent steps run on the full recording, *before* any rows are
    # dropped. Filtering modes first would splice non-adjacent samples together:
    # the EMG filter would see artificial jumps, and heel strikes on either side
    # of a removed segment would be treated as one gait cycle.
    # NOTE(review): the EMG filter still spans circuit boundaries of the
    # concatenated frame; consider filtering per Circuit_ID if that matters.
    preprocessor = EMGPreprocessor()
    df[EMG_CHANNELS] = preprocessor.apply_filter(df[EMG_CHANNELS].values)
    df[EMG_CHANNELS] = preprocessor.rectify(df[EMG_CHANNELS].values)

    # Gait phase per circuit (Circuit_ID presumably added by the loader). The
    # result is written back by row position, so a duplicate index (e.g. from
    # concatenated circuits) cannot misalign it.
    gait_phase = np.zeros(len(df))
    for positions in df.groupby('Circuit_ID').indices.values():
        gait_phase[positions] = calculate_gait_phase_cycle(df.iloc[positions])
    df['Gait_Phase_Cycle'] = gait_phase

    # Keep only the analysed modes. .copy() makes the later column assignments
    # act on a real frame rather than a slice (avoids SettingWithCopyWarning).
    df = df[df['Mode'].isin(DYNAMIC_MODES + STATIC_MODES)].copy()

    # Default class 2 ("None"). Dynamic-mode samples are split by Label_Phase.
    # NOTE(review): Label_Phase is not requested in load_cols, so the loader is
    # presumed to derive it; the mapping 1 -> Stance, 0 -> Swing is assumed.
    df['Phase_Class'] = 2
    is_dynamic = df['Mode'].isin(DYNAMIC_MODES)
    df.loc[is_dynamic & (df['Label_Phase'] == 1), 'Phase_Class'] = 0
    df.loc[is_dynamic & (df['Label_Phase'] == 0), 'Phase_Class'] = 1

    df['Phase_Name'] = df['Phase_Class'].map(PHASE_NAMES)

    return df

In [None]:
# Load Data
# Load every circuit of one subject (preprocessed and labelled), then preview it.
df = load_data(subject="AB156")
print(f"Loaded {df.shape[0]} samples")
df.head()

## 1. Class Distribution

In [None]:
# Display order follows the class ids in PHASE_NAMES (0 -> Stance, 1 -> Swing, 2 -> None),
# so the plot stays consistent if the class names are edited in the config cell.
phase_order = [PHASE_NAMES[k] for k in sorted(PHASE_NAMES)]

fig, ax = plt.subplots(figsize=(8, 4))
sns.countplot(data=df, x='Phase_Name', order=phase_order, ax=ax)
ax.set(title="Sample Count by Phase", xlabel="Phase", ylabel="Samples")
plt.show()

## 2. Feature Distributions

In [None]:
# Rectified EMG is used directly as an amplitude proxy, which shows the raw
# amplitude envelope per phase; windowed features could be swapped in here.
feat_cols = EMG_CHANNELS
phase_order = [PHASE_NAMES[k] for k in sorted(PHASE_NAMES)]

# One panel per feature column, rather than a hard-coded 3.
fig, axes = plt.subplots(1, len(feat_cols), figsize=(5 * len(feat_cols), 5), squeeze=False)
for ax, ch in zip(axes[0], feat_cols):
    # showfliers=False: with per-sample data the outlier markers can number in
    # the thousands and swamp the boxes. Quartiles and whiskers are unaffected.
    sns.boxplot(data=df, x='Phase_Name', y=ch, order=phase_order, showfliers=False, ax=ax)
    ax.set(title=f"{ch} Amplitude", xlabel="Phase", ylabel="Rectified EMG")
fig.tight_layout()
plt.show()

## 3. Gait Phase Cycle Analysis
Rectified EMG averaged over the gait cycle (0% → 100%, from one heel strike to the next) for the dynamic modes (Stance and Swing samples).

In [None]:
N_CYCLE_BINS = 50  # resolution of the gait-cycle average (bins across 0-100 %)

# Stance/Swing samples only. calculate_gait_phase_cycle writes 0.0 for samples
# outside a complete heel-strike-to-heel-strike cycle, and those would pile up
# in the first bin. Dropping exact zeros removes them; the cost is the single
# heel-strike sample of each cycle.
dynamic_df = df[df['Phase_Class'].isin([0, 1]) & (df['Gait_Phase_Cycle'] > 0)].copy()

# Fixed edges over [0, 1], so bin positions do not depend on the observed range.
bin_edges = np.linspace(0.0, 1.0, N_CYCLE_BINS + 1)
dynamic_df['Cycle_Bin'] = pd.cut(
    dynamic_df['Gait_Phase_Cycle'], bins=bin_edges, labels=False, include_lowest=True
)
# Bin centre as % of the gait cycle. This is the x-axis, so each panel averages
# over N_CYCLE_BINS points instead of millions of unique raw phase values.
dynamic_df['Cycle_Pct'] = (dynamic_df['Cycle_Bin'] + 0.5) * (100.0 / N_CYCLE_BINS)

fig, axes = plt.subplots(1, len(EMG_CHANNELS), figsize=(5 * len(EMG_CHANNELS), 6), squeeze=False)
for ax, ch in zip(axes[0], EMG_CHANNELS):
    # errorbar="sd": shows the spread across samples. A bootstrapped CI of the
    # mean would be a hairline at this sample count, and slow to compute.
    sns.lineplot(data=dynamic_df, x='Cycle_Pct', y=ch, errorbar="sd", ax=ax)
    ax.set(title=f"{ch} over Gait Cycle", xlabel="Gait Cycle (%)", ylabel="Rectified EMG")
fig.tight_layout()
plt.show()