# In[ ]: (Jupyter cell marker, commented out so the file runs as a plain Python script)
import os
import numpy as np
import pandas as pd
from xml.etree import ElementTree as ET
from statistics import mean
from tqdm import tqdm

# Root folder scanned by main(): <ROOT_DIR>/<group>/<participant>/gait/<pace>/*.mvnx
ROOT_DIR = "."

# Segment indices into each frame's flat <position> vector; the spatial code
# reads columns 3*k and 3*k+1 for segment k.
# NOTE(review): presumably the 0-based LeftFoot / RightFoot segments of the
# Xsens body model — confirm against the file's segment list.
L_FOOT_ID = 21
R_FOOT_ID = 17

def extract_foot_contacts_simple(mvnx_path):
    """Parse the per-frame foot-contact flags from an MVNX file.

    Every ``<footContacts>`` element inside a frame becomes one row of
    integers. Downstream code reads columns 0/2 as the left/right heel flags
    and columns 1/3 as the left/right toe flags.

    The XML namespace is taken from the root element, as in
    ``extract_segment_positions_and_time``, rather than hard-coded, so files
    with a different namespace, or none, still parse. Frames whose ``type``
    attribute is present and not ``"normal"`` are skipped, mirroring the
    position parser so both build rows from the same frame set.

    Args:
        mvnx_path: Path to the ``.mvnx`` XML file.

    Returns:
        Integer ``np.ndarray`` with one row per frame that has contacts, or an
        empty ``(0, 4)`` array if none were found or parsing failed (a warning
        is printed in that case).
    """
    try:
        root = ET.parse(mvnx_path).getroot()
        # "{uri}mvnx" -> "{uri}"; empty string when the document has no namespace.
        prefix = root.tag[: root.tag.index("}") + 1] if root.tag.startswith("{") else ""
        rows = []
        for frame in root.findall(f".//{prefix}frame"):
            # Skip non-"normal" frames (kept in sync with the position parser).
            if frame.get("type", "normal") != "normal":
                continue
            fc = frame.find(f".//{prefix}footContacts")
            if fc is not None and fc.text:
                rows.append(list(map(int, fc.text.strip().split())))
        return np.array(rows) if rows else np.empty((0, 4), dtype=int)
    except Exception as e:
        # Best-effort batch processing: report and return an empty result.
        print(f"⚠️  Contact‑parse error in {mvnx_path}: {e}")
        return np.empty((0, 4), dtype=int)

def extract_segment_positions_and_time(mvnx_path):
    """Parse per-frame segment positions and timestamps from an MVNX file.

    A frame contributes a row only if it has a non-empty direct
    ``<position>`` child and a ``time`` attribute. ``time`` is divided by
    1000 (treated as milliseconds -> seconds).

    Frames whose ``type`` attribute is present and not ``"normal"`` are
    skipped. The contact parser only yields rows for frames that carry
    ``<footContacts>``, and ``analyze_file`` uses contact-row indices to
    index the arrays returned here. Any extra frames kept here would shift
    every position/time row relative to the contacts.

    Args:
        mvnx_path: Path to the ``.mvnx`` XML file.

    Returns:
        ``(positions, times)``: a float array with one row per kept frame and
        a 1-D float array of times in seconds. Returns ``(empty (0, 69),
        empty (0,))`` if nothing was found or parsing failed (a warning is
        printed in that case).
    """
    try:
        root = ET.parse(mvnx_path).getroot()
        # "{uri}mvnx" -> "{uri}"; empty string when the document has no
        # namespace. The previous split/strip logic turned the bare tag
        # "mvnx" into a bogus namespace URI and then matched nothing.
        prefix = root.tag[: root.tag.index("}") + 1] if root.tag.startswith("{") else ""
        pos_rows, t_rows = [], []
        for fr in root.findall(f".//{prefix}frame"):
            # NOTE(review): MVNX exports are understood to begin with
            # calibration frames (type "identity"/"tpose"/...) that have a
            # position and time but no foot contacts. Frames without a type
            # attribute are kept.
            if fr.get("type", "normal") != "normal":
                continue
            pos_el = fr.find(f"{prefix}position")
            t_ms = fr.get("time")
            if pos_el is not None and pos_el.text and t_ms:
                pos_rows.append(list(map(float, pos_el.text.split())))
                t_rows.append(float(t_ms) / 1000.0)
        return (
            np.array(pos_rows) if pos_rows else np.empty((0, 69)),
            np.array(t_rows) if t_rows else np.empty(0),
        )
    except Exception as e:
        # Best-effort batch processing: report and return empty results.
        print(f"⚠️  Position/time parse error in {mvnx_path}: {e}")
        return np.empty((0, 69)), np.empty(0)

def detect_heel_strikes(contacts):
    """Locate every heel strike of both feet.

    A heel strike is a frame where a heel-contact flag rises from 0 to 1.
    Column 0 is read for the left foot and column 2 for the right foot.

    Args:
        contacts: 2-D array of contact flags, one row per frame.

    Returns:
        List of ``(frame_index, foot)`` tuples (``foot`` is ``"L"`` or
        ``"R"``), ordered by frame index; ties keep left before right.
    """
    found = []
    for side, column in (("L", 0), ("R", 2)):
        heel = contacts[:, column]
        if heel.size >= 2:
            rising = (heel[:-1] == 0) & (heel[1:] == 1)
            # +1: the strike is the frame *after* the last 0.
            found += [(frame, side) for frame in np.flatnonzero(rising) + 1]
    found.sort(key=lambda event: event[0])  # stable: L stays ahead of R on ties
    return found


def detect_toe_offs(contacts):
    """Locate every toe off of both feet.

    A toe off is a frame where a toe-contact flag drops from 1 to 0.
    Column 1 is read for the left foot and column 3 for the right foot.

    Args:
        contacts: 2-D array of contact flags, one row per frame.

    Returns:
        List of ``(frame_index, foot)`` tuples ordered by frame index.
    """
    transitions = []
    for foot, col in (("L", 1), ("R", 3)):
        toe = contacts[:, col]
        if toe.size < 2:
            continue
        lifted = np.flatnonzero((toe[:-1] == 1) & (toe[1:] == 0)) + 1
        transitions.extend((frame, foot) for frame in lifted)
    return sorted(transitions, key=lambda ev: ev[0])



def robust_mean_width(widths, minval=0.03, maxval=0.3):
    """Average the widths lying strictly between ``minval`` and ``maxval``.

    Values outside the open interval are discarded as outliers.

    Returns:
        The numpy mean of the retained widths, or ``None`` when fewer than
        two widths survive the filter.
    """
    kept = list(filter(lambda w: minval < w < maxval, widths))
    if len(kept) < 2:
        return None
    return np.mean(kept)

def compute_support_phases_from_strides(contacts, heel_strikes):
    """Average double/single-support percentages over same-foot strides.

    A stride spans a heel strike and the heel strike two events later, and is
    used only when both events belong to the same foot. Within each stride
    window a foot counts as loaded when its heel OR its toe is in contact:
    columns 0/1 for the left foot, 2/3 for the right (the same columns used
    by ``detect_heel_strikes`` / ``detect_toe_offs``). The previous
    heel-only test treated heel-off stance (toe still down) as swing, which
    under-counted double support.

    Args:
        contacts: 2-D array of contact flags with columns
            (L heel, L toe, R heel, R toe), one row per frame.
        heel_strikes: ``(frame_index, foot)`` tuples sorted by frame index,
            e.g. the output of ``detect_heel_strikes``.

    Returns:
        ``(double_support_pct, single_support_L_pct, single_support_R_pct)``.
        Each value is the mean over valid strides rounded to 2 decimals. All
        three are ``None`` when fewer than 4 heel strikes are given; an
        individual value is ``None`` when no valid stride exists.
    """
    if len(heel_strikes) < 4:
        return None, None, None

    ds_vals, ssL_vals, ssR_vals = [], [], []

    # Pair each event with the one two positions later (i, i + 2).
    for (start, foot_a), (end, foot_b) in zip(heel_strikes, heel_strikes[2:]):
        if foot_a != foot_b:
            continue  # not a full stride (same foot to same foot)

        window = contacts[start:end]
        if window.shape[0] < 2:
            continue

        left_on = (window[:, 0] == 1) | (window[:, 1] == 1)
        right_on = (window[:, 2] == 1) | (window[:, 3] == 1)
        total = len(window)

        ds_vals.append(100 * np.sum(left_on & right_on) / total)
        ssL_vals.append(100 * np.sum(left_on & ~right_on) / total)
        ssR_vals.append(100 * np.sum(~left_on & right_on) / total)

    return (
        round(np.mean(ds_vals), 2) if ds_vals else None,
        round(np.mean(ssL_vals), 2) if ssL_vals else None,
        round(np.mean(ssR_vals), 2) if ssR_vals else None,
    )



def get_progression_unit_vector(pos, L_idx, R_idx):
    """Unit vector from the first to the last midpoint of the two feet.

    Only the first two coordinates at ``L_idx`` / ``R_idx`` of each row are
    used, so the direction is 2-D. NOTE(review): presumably the horizontal
    plane — confirm the axis convention of the data.

    Args:
        pos: 2-D array of positions, one row per frame.
        L_idx: Column where the left-foot coordinates start.
        R_idx: Column where the right-foot coordinates start.

    Returns:
        Length-2 unit vector, or ``np.array([1, 0])`` when the start and end
        midpoints coincide.
    """
    start = (pos[0, L_idx:L_idx + 2] + pos[0, R_idx:R_idx + 2]) / 2
    end = (pos[-1, L_idx:L_idx + 2] + pos[-1, R_idx:R_idx + 2]) / 2
    displacement = end - start
    length = np.linalg.norm(displacement)
    return np.array([1, 0]) if length == 0 else displacement / length

def project_point_onto_vector(p, origin, vec):
    """Split the 2-D offset of ``p`` from ``origin`` into two components.

    Only the first two coordinates of ``p`` and ``origin`` are used.

    Returns:
        ``(forward, lateral)``: the offset's component along ``vec``
        (normalised here), and its component along ``vec`` rotated +90°
        counter-clockwise.
    """
    unit = vec / np.linalg.norm(vec)
    normal = np.array([-unit[1], unit[0]])
    offset = p[:2] - origin[:2]
    return np.dot(offset, unit), np.dot(offset, normal)

def compute_spatial_params_progression(events, pos):
    """Step lengths, stride lengths and step widths along the walking direction.

    Each event's foot position (left-foot columns for ``"L"``, right-foot
    columns otherwise) is projected onto the start-to-end progression
    direction, measured from the initial feet midpoint.

    Args:
        events: ``(frame_index, foot)`` tuples in temporal order.
        pos: 2-D position array indexed by frame.

    Returns:
        ``(step_lengths, stride_lengths, step_widths)``:
        - step lengths: absolute forward distance between consecutive events;
        - stride lengths: absolute forward distance between events i and i+2,
          taken at even i only (non-overlapping);
        - step widths: absolute lateral distance between consecutive events.
    """
    left_col = L_FOOT_ID * 3
    right_col = R_FOOT_ID * 3
    heading = get_progression_unit_vector(pos, left_col, right_col)
    start_mid = (pos[0, left_col:left_col + 2] + pos[0, right_col:right_col + 2]) / 2

    projected = []
    for frame, side in events:
        col = left_col if side == 'L' else right_col
        projected.append(project_point_onto_vector(pos[frame, col:col + 2], start_mid, heading))

    forward = [fwd for fwd, _ in projected]
    lateral = [lat for _, lat in projected]

    step_lengths = [abs(nxt - cur) for cur, nxt in zip(forward, forward[1:])]
    stride_lengths = [abs(forward[k + 2] - forward[k]) for k in range(0, len(forward) - 2, 2)]
    step_widths = [abs(nxt - cur) for cur, nxt in zip(lateral, lateral[1:])]
    return step_lengths, stride_lengths, step_widths

def analyze_file(path):
    """Compute spatiotemporal gait metrics for a single MVNX recording.

    Heel strikes from the foot-contact stream define steps. The first and
    last ``trim`` events are dropped before step/stride timing and spatial
    metrics are computed.

    Args:
        path: Path to the ``.mvnx`` file.

    Returns:
        A dict of metrics (keys in ``results`` below; ``None`` for any metric
        that could not be computed), or ``None`` when the file has fewer than
        two usable frames or fewer than two heel strikes.
    """
    print(f"📂  Processing {path}")
    contacts = extract_foot_contacts_simple(path)
    pos, times = extract_segment_positions_and_time(path)
    if contacts.shape[0] < 2 or times.size < 2:
        print("   ⚠️  too few frames – skipped")
        return None
    # Heel-strike indices come from `contacts` but also index `times` and
    # `pos`, so all three must share one length. Truncate on ANY mismatch.
    # The old "> 10 frames" threshold let smaller mismatches through, and an
    # event near the end could then index past `times`/`pos`.
    if contacts.shape[0] != times.size:
        m = min(contacts.shape[0], times.size)
        contacts, pos, times = contacts[:m], pos[:m], times[:m]

    # All heel strikes as (index, foot), sorted by index
    events = detect_heel_strikes(contacts)
    if len(events) < 2:
        print("   ⚠️  no heel strikes – skipped")
        return None

    # Counts before trimming
    n_steps_total = len(events) - 1
    n_strides_total = max(0, (len(events) - 1) // 2)

    # Remove first and last `trim` events for stability (kept when too few)
    trim = 2
    events_trim = events[trim:-trim] if len(events) > 2 * trim else events
    n_steps_trim = len(events_trim) - 1
    n_strides_trim = max(0, (len(events_trim) - 1) // 2)

    trimmed_times = [times[idx] for idx, _ in events_trim]

    # --- Temporal metrics ---
    # Step = consecutive events; stride = events i and i+2 at even i
    # (non-overlapping), the same pairing used for the stride lengths.
    step_times = [b - a for a, b in zip(trimmed_times, trimmed_times[1:])]
    stride_times = [trimmed_times[i + 2] - trimmed_times[i]
                    for i in range(0, len(trimmed_times) - 2, 2)]

    mean_step_time = mean(step_times) if step_times else None
    mean_stride_time = mean(stride_times) if stride_times else None

    # --- Spatial metrics ---
    step_lengths, stride_lengths, step_widths = compute_spatial_params_progression(events_trim, pos)
    mean_stride_length = mean(stride_lengths) if stride_lengths else None
    mean_step_width = robust_mean_width(step_widths)

    gait_speed = (mean_stride_length / mean_stride_time) if (mean_stride_length and mean_stride_time) else None
    # NOTE(review): support phases use the untrimmed `events`, unlike every
    # other metric here — confirm this is intended.
    ds_pct, ssL_pct, ssR_pct = compute_support_phases_from_strides(contacts, events)
    trial_time = trimmed_times[-1] - trimmed_times[0] if len(trimmed_times) > 1 else None
    cadence = round(n_steps_trim / (trial_time / 60), 2) if trial_time and n_steps_trim > 0 else None

    results = {
        "n_steps_total": n_steps_total,
        "n_strides_total": n_strides_total,
        "n_steps_trimmed": n_steps_trim,
        "n_strides_trimmed": n_strides_trim,

        "gait_speed_mps": round(gait_speed, 3) if gait_speed else None,
        "cadence_spm": cadence,

        "step_time_mean_s": round(mean_step_time, 3) if mean_step_time else None,
        "step_time_sd_s": round(np.std(step_times), 3) if step_times else None,

        "stride_time_mean_s": round(mean_stride_time, 3) if mean_stride_time else None,
        "stride_time_sd_s": round(np.std(stride_times), 3) if stride_times else None,

        "step_length_mean_m": round(mean(step_lengths), 3) if step_lengths else None,
        "step_length_sd_m": round(np.std(step_lengths), 3) if step_lengths else None,

        "stride_length_mean_m": round(mean(stride_lengths), 3) if stride_lengths else None,
        "stride_length_sd_m": round(np.std(stride_lengths), 3) if stride_lengths else None,

        # NOTE(review): the mean uses the outlier-filtered widths
        # (robust_mean_width) but the SD uses all widths — confirm intent.
        "step_width_mean_m_orth": round(mean_step_width, 3) if mean_step_width else None,
        "step_width_sd_m_orth": round(np.std(step_widths), 3) if step_widths else None,

        "double_support_pct": ds_pct,
        "single_support_L_pct": ssL_pct,
        "single_support_R_pct": ssR_pct,
    }

    return results

def main(root_dir=None, out_fn="gait_analysis_global.csv"):
    """Analyse every MVNX recording under a directory tree and export a CSV.

    Expected layout: ``<root>/<group>/<participant>/<gait>/<pace>/*.mvnx``,
    where the ``gait`` directory name is matched case-insensitively. Each
    successfully analysed file becomes one CSV row: group, participant, pace
    and file name, followed by the metrics from ``analyze_file``.

    Args:
        root_dir: Directory holding one sub-directory per group. Defaults to
            the module-level ``ROOT_DIR``, read at call time.
        out_fn: Output CSV path.
    """
    root_dir = ROOT_DIR if root_dir is None else root_dir
    all_rows = []
    # Listings are sorted because os.listdir order is arbitrary; sorting
    # makes the CSV row order reproducible across runs and machines.
    for grp in tqdm(sorted(os.listdir(root_dir)), desc="Groups"):
        p_grp = os.path.join(root_dir, grp)
        if not os.path.isdir(p_grp):
            continue
        for subj in tqdm(sorted(os.listdir(p_grp)), desc=f"{grp} Participants", leave=False):
            p_subj = os.path.join(p_grp, subj)
            if not os.path.isdir(p_subj):
                continue
            # Require a *directory* named "gait": a stray file with that name
            # used to be picked and made os.listdir() crash the whole batch.
            gait_dir = next(
                (os.path.join(p_subj, d) for d in sorted(os.listdir(p_subj))
                 if d.lower() == "gait" and os.path.isdir(os.path.join(p_subj, d))),
                None,
            )
            if gait_dir is None:
                continue
            for pace in sorted(os.listdir(gait_dir)):
                p_pace = os.path.join(gait_dir, pace)
                if not os.path.isdir(p_pace):
                    continue
                for fx in sorted(os.listdir(p_pace)):
                    if not fx.lower().endswith(".mvnx"):
                        continue
                    res = analyze_file(os.path.join(p_pace, fx))
                    if res is None:
                        continue
                    meta = {
                        "group": grp,
                        "participant": subj,
                        "pace_condition": pace.lower(),
                        "source_file": fx
                    }
                    all_rows.append({**meta, **res})
    if not all_rows:
        print("❌  No valid MVNX files processed.")
        return

    df = pd.DataFrame(all_rows)
    df.to_csv(out_fn, index=False, float_format="%.4f")
    print(f"\n✅  Exported {len(df)} rows → {out_fn}")

# Run the batch export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
