In [8]:
import cv2
from rtmlib.tools import RTMPose, YOLOX
import numpy as np
from rtmlib import draw_skeleton,draw_bbox
from scipy.signal import savgol_filter

In [None]:
device = 'cuda'
backend = 'onnxruntime'

#settings
yolo_size = (640,640)
rtm_size = (288,384)
yolo_path = "/mnt/samsung_sata/Body_Proximation/Docker_RtmPose/yolox_onnx/yolox_l_8xb8-300e_humanart-ce1d7a62/end2end.onnx"
rtm_path = "/mnt/samsung_sata/Body_Proximation/Docker_RtmPose/rtmpose_onnx/rtmpose-m_simcc-body7_pt-body7-halpe26_700e-384x288-89e6428b_20230605/end2end.onnx"

det_model=YOLOX(onnx_model=yolo_path,model_input_size = yolo_size ,backend= backend,device = device)
rtm_model=RTMPose(onnx_model=rtm_path,model_input_size = rtm_size ,backend= backend,device = device)

In [None]:
ANTHROPOMETRIC_PARAMS = {
    "Head": {"mass_perc": 0.081, "com_perc_prox": 1.0}, # CoM at 100% from neck joint (top of neck)
    "Trunk": {"mass_perc": 0.497, "com_perc_prox": 0.5}, # Midpoint of hip-shoulder line
    "Thigh": {"mass_perc": 0.100, "com_perc_prox": 0.433}, # From hip
    "Shank": {"mass_perc": 0.0465, "com_perc_prox": 0.433}, # From knee
    "Foot": {"mass_perc": 0.0145, "com_perc_prox": 0.500}, # From ankle
    "UpperArm": {"mass_perc": 0.028, "com_perc_prox": 0.436}, # From shoulder
    "Forearm": {"mass_perc": 0.016, "com_perc_prox": 0.430}, 
}
GRAVITY = 9.81 

In [5]:
# HALPE-26 / body7 mapping from your printout
Keypoint_dict = {"NOSE":0, "L_EYE":1, "R_EYE":2, "L_EAR":3, "R_EAR":4, "L_SHOULDER":5, 
                 "R_SHOULDER":6, "L_ELBOW":7, "R_ELBOW":8, "L_WRIST":9, "R_WRIST":10,
                 "L_HIP":11, "R_HIP":12, "L_KNEE":13, "R_KNEE":14, "L_ANKLE":15, "R_ANKLE":16,
                 "HEAD":17, "NECK":18, "HIP_CENTER":19, "L_BIG_TOE":20, "R_BIG_TOE":21, "L_SMALL_TOE":22,
                 "R_SMALL_TOE":23, "L_HEEL":24, "R_HEEL":25}

In [None]:
#there will be keypoint and score. Keypoint will be in [x,y] and Score is 0.0-1.0 to show confidence of the prediction
def _xyc(kp):
    """Return (x,y,conf) for a keypoint that maybe [x,y] or [x,y,conf]"""
    if kp is None: return None
    a = np.asarray(kp)
    if a.size >= 3:
        return float(a[0]), float(a[1]), float(a[2])
    return float(a[0]), float(a[1]), 1.0

def _valid(kp, min_conf):
    """True if kp exists and (no conf field) or conf >= min_conf."""
    v = _xyc(kp)
    return (v is not None) and (v[2] >= min_conf)

def get_midpoint(p1,p2, min_conf=0.0):
    """Get the midpoint of two mid point of two vectors.
       Confidence is between 0.0-1.0 pose estimator or detector 
       and represents how certain the model is that this keypoint is in the right place
    """
    a = _xyc(p1) ; b = _xyc(p2)
    if a is None or b is None: return None
    if a[2] < min_conf or b[2] < min_conf: return None
    x = 0.5*(a[0]+b[0]); y = 0.5*(a[1]+b[1]); c = min(a[2], b[2])
    return np.array([x, y, c], dtype=np.float32)

def calculate_vector(p_start, p_end):
    """Vector from start to end using first two coords"""
    if p_start is None or p_end is None: return None
    a = np.asarray(p_start)[:2]; b = np.asarray(p_start)[:2] #[:2] is to remove the confidence if the data is parse with the confidence
    if a.size <2 or b.size<2: return None
    return b-a

def calculate_angle_between_vectors(vec1, vec2, min_confidence = 0.5):
    """Calculates angle in degrees between two vectors."""
    if vec1 is None or vec2 is None:
        return None
    #Add 1e-6  to prevent division by zero
    unit_vec1 = vec1 / (np.linalg.norm(vec1) + 1e-6)
    unit_vec2 = vec2 / (np.linalg.norm(vec2) + 1e-6)
    #normalize the vectors to get magnitude of 1
    dot_product = np.dot(unit_vec1, unit_vec2) 
    #Get the cosine of the two angles wjich ranges between -1 to 1
    angle_rad = np.arccos(np.clip(dot_product, -1.0, 1.0))
    #Get the angle
    return np.degrees(angle_rad)

def calculate_segment_angle_horizontal(segment_vector):
    """Calculates angle of a segment vector with the horizontal axis in degrees"""
    if segment_vector is None:
        return None
    return np.degrees(np.arctan2(segment_vector[1],segment_vector[0]))

def smooth_time_series(data, window_length=5, polyorder=2):
    """Smooths 1D time series data using Savitzky-Golay filter."""
    if len(data) < window_length:
        return data # Not enough data to smooth
    return savgol_filter(data, window_length, polyorder)

def numerical_derivative(data, dt, smooth=True):
    """Calcualates numerical derivative of 1D data."""
    if len(data) < 2:
        return np.array(0.0 * len(data))
    if smooth:
        data =  smooth_time_series(data)
    
    derivative = np.gradient(data, dt)
    return derivative

def analyze_angular_velocity(keypoint_sequence, segments_indices, fps, min_confidence=0.3):
    """
    Calculates angular velocity for a given segment over a sequence of frame.
    segment_indices: tuple(proximal_joint_idx, distal_joint_idx)
    Return a list of angular velocities in degrees/sec
    """
    prox_idx, dist_idx = segments_indices
    segment_angles = []

    #edit the varibles above and below might be easier to separate the data earlier so no need to parse data
    for kps_frame in keypoint_sequence:
        kp_prox = kps_frame[prox_idx]
        kp_dist = kps_frame[dist_idx]
        if kp_prox[2] > min_confidence and kp_dist[2]> min_confidence:
            segment_vector = calculate_vector(kp_prox, kp_dist)
            if segment_vector is not None:
                angle = calculate_segment_angle_horizontal(segment_vector)
                segment_angles.append(angle)
            else:
                segment_angles.append(np.nan)
        else:
            segment_angles.append(np.nan)
    
    segment_angles = np.array(segment_angles)
    
    # Handle NaNs by interpolation or carrying forward if needed, or keep as NaN
    # For simplicity, derivative of NaN will be NaN

    if len(segment_angles) < 2:
        return [0.0] * len(segment_angles)
    
    # Unwrap angles to handle jumps from -180 to 180 
    unwrapped_angles = np.unwrap(np.deg2rad(segment_angles)) #unwrap works on radians
    unwrapped_angles_deg = np.rad2deg(unwrapped_angles)

    dt = 1.0/fps
    angular_velocities = numerical_derivative(unwrapped_angles_deg, dt, smooth= True )
    return  angular_velocities.tolist()

def analyze_leg_thigh_angle(keypoints_frame, leg_indices, min_confidence=0.3):
    """
    Calculates the angle between the leg(shank) and thigh.
    Leg_indices: tuple(hip_idx, knee_idx, ankle_idx)
    Returns angle in degrees
    """
    hip_idx, knee_idx, ankle_idx = leg_indices
    kp_hip = keypoints_frame[hip_idx]
    kp_knee = keypoints_frame[knee_idx]
    kp_ankle = keypoints_frame[ankle_idx]

    if kp_hip[2] > min_confidence and kp_knee[2] > min_confidence and kp_[2] > min_confidence:
        thigh_vector = calculate_vector(kp_hip, kp_knee)
        shank_vector = calculate_vector(kp_knee, kp_ankle)

        if thigh_vector is not None and shank_vector is not None:
            angle = calculate_angle_between_vectors(thigh_vector, -shank_vector)
            return angle
    return None

def analyze_hip_movement(kp_sequence,fps, min_confidence=0.3):
    """
    Analyzes hip movement (displacement, velocity of the hip centre)
    Returns lists of displacement (m/frame, assuming pixel meter scale is 1 for now) and velocities (m/s).
    """
    hip_centers_x, hip_centers_y = [] , []
    for kps_frame in kp_sequence:
        l_hip = kps_frame[L_HIP_IDX]
        r_hip = kps_frame[R_HIP_IDX]
        if _valid(l_hip, min_confidence) and _valid(r_hip, min_confidence):
            mid = get_midpoint(l_hip, r_hip,min_confidence)
            if mid is not None:
                hip_centers_x.append(mid[0]); hip_centers_y.append(mid[1])
                continue
        hip_centers_x.append(np.nan); hip_centers_y.append(np.nan)

    #there ought to be similar numbers of x and y
    hip_centers_x = np.array(hip_centers_x)
    hip_centers_y = np.array(hip_centers_y)

    displacements = [0.0]
    velocities_x = [0.0] * len(hip_centers_x)
    velocities_y = [0.0] * len(hip_centers_y)
    speeds = [0.0] * len(hip_centers_x)

    if len(hip_centers_x) > 1:
        for i in range(1, len(hip_centers_x)):
            if not np.isnan(hip_centers_x[i]) and not np.isnan(hip_centers_x[i-1]):
                dx = hip_centers_x[i] - hip_centers_x[i-1]
                dy = hip_centers_y[i] - hip_centers_y[i-1]
                displacements.append(np.sqrt(dx**2 + dy**2))
            else:
                displacements.append(np.nan)

        dt = 1.0 / fps
        velocities_x = numerical_derivative(hip_centers_x, dt, smooth= True)
        velocities_y = numerical_derivative(hip_centers_y, dt, smooth= True)

        for i in range(len(velocities_x)):
            if not np.isnan(velocities_x[i]) and not np.isnan(velocities_y[i]):
                 speeds[i] = np.sqrt(velocities_x[i]**2 + velocities_y[i]**2)
            else:
                speeds[i] = np.nan
    
    return displacements, speeds, list(zip(hip_centers_x, hip_centers_y))


#NEED TO CHOOSE 1 LEg to analyze foor contact maybe one foot/ two foot/ wheeling or not wheeling
#analyse how other leg generates momementum
#Maybe have a condition that varible that analyze wheeling condition
""" Thinking how to implement cause slopes or uneven ground as wheels not detected by the model"""
def analyze_foot_contact(keypoints_frame, foot_indices, y_ground_estimate=None, min_confidence=0.3, vertical_threshold=10, velocity_threshold=0.1):
    """
    Analyzes foot contact status.
    foot_indices: (ankle_idx, heel_idx, big_toe_idx, small_toe_idx)
    y_ground_estimate: y-coordinate of the ground (higher y is lower on screen). If None, uses min foot y.
    Returns a string: "Heel Strike", "Toe Off", "Flat Foot", "Tiptoeing", "Swing", "Unknown".
    This is a heuristic and needs tuning. Assumes foot keypoints are available.
    """
    ankle_idx, heel_idx, toe_idx, _ = foot_indices # Assuming big_toe_idx is primary toe
    
    kp_ankle = keypoints_frame[ankle_idx]
    kp_heel = keypoints_frame[heel_idx]
    kp_toe = keypoints_frame[toe_idx]

    if not (kp_ankle[2] > min_confidence and kp_heel[2] > min_confidence and kp_toe[2] > min_confidence):
        return "Unknown (Low Confidence)"

    # Simple ground estimation if not provided
    if y_ground_estimate is None:
        y_ground_estimate = max(kp_ankle[1], kp_heel[1], kp_toe[1]) # Simplistic, assumes lowest point is ground for this frame

    # Check vertical positions relative to ground (smaller distance means closer to ground)
    heel_on_ground = abs(kp_heel[1] - y_ground_estimate) < vertical_threshold
    toe_on_ground = abs(kp_toe[1] - y_ground_estimate) < vertical_threshold
    ankle_high = kp_ankle[1] < kp_heel[1] - vertical_threshold and kp_ankle[1] < kp_toe[1] - vertical_threshold # Ankle significantly above heel/toe

    # Needs previous frame for velocity - this function is frame-wise, so velocity needs to be passed or calculated outside
    # For simplicity, this example won't use velocity directly for contact states but it's important for real apps.

    if heel_on_ground and not toe_on_ground and kp_toe[1] < kp_heel[1]: # Toe is up
        return "Heel Strike"
    elif toe_on_ground and not heel_on_ground and kp_heel[1] < kp_toe[1]: # Heel is up
        # Could be Toe Off (if previously flat/heel) or Tiptoeing
        if ankle_high and (kp_heel[1] < kp_toe[1] - vertical_threshold/2): # Heel significantly above toe, ankle high
             return "Tiptoeing"
        return "Toe Dominant" # Or pre-swing / late tiptoe
    elif heel_on_ground and toe_on_ground:
        return "Flat Foot"
    elif not heel_on_ground and not toe_on_ground:
        # Check if foot is considerably off ground
        if min(kp_heel[1], kp_toe[1]) < y_ground_estimate - 2 * vertical_threshold:
            return "Swing"
        else: # Close to ground but not clearly touching
            return "Near Ground"
            
    return "Indeterminate"


def analyze_upper_body_posture(keypoints_frame, min_confidence=0.3):
    """
    Analyzes upper body straightness and head position
    Returns torso_angle_vertical, ear_shoulder_hip_angle_left, ear_shoulder_hip_angle_right, head_center
    """
    # Upper Body Straightness
    l_shoulder = keypoints_frame[LEFT_SHOULDER_IDX]
    r_shoulder = keypoints_frame[RIGHT_SHOULDER_IDX]
    l_hip = keypoints_frame[LEFT_HIP_IDX]
    r_hip = keypoints_frame[RIGHT_HIP_IDX]
    l_ear = keypoints_frame[LEFT_EAR_IDX]
    r_ear = keypoints_frame[RIGHT_EAR_IDX]
    nose = keypoints_frame[NOSE_IDX]

    torso_angle_vertical = None
    esh_angle_l = None
    esh_angle_r = None
    head_center = None

    mid_shoulder = get_midpoint(l_shoulder, r_shoulder) if (l_shoulder[2]> min_confidence and r_shoulder[2]>min_confidence) else None
    mid_hip = get_midpoint(l_hip, r_hip) if (l_hip[2]>min_confidence and r_hip[2]>min_confidence) else None

    if mid_shoulder is not None and mid_hip is not None:
        torso_vector = calculate_vector(mid_hip, mid_shoulder) #vector from hip to shoulder
        if torso_vector is not None:
            # Angle with vertical (0, -1) vector (pointing up)
            vertical_ref_vector = np.array([0, 1])
            torso_angle_vertical = calculate_angle_between_vectors(torso_vector, vertical_ref_vector)

    # Ear-Shoulder-Hip angle (a proxy for neck/torso alignment)
    if l_ear[2]>min_confidence and l_shoulder[2]>min_confidence and l_hip[2]>min_confidence:
        vec_shoulder_ear_l = calculate_vector(l_shoulder, l_ear)
        vec_shoulder_hip_l = calculate_vector(l_shoulder, l_hip)
        esp_angle_l = calculate_angle_between_vectors(vec_shoulder_ear_l, vec_shoulder_hip_l)
    
    if r_ear[2]>min_confidence and r_shoulder[2]>min_confidence and r_hip[2]>min_confidence:
        vec_shoulder_ear_r = calculate_vector(r_shoulder, r_ear)
        vec_shoulder_hip_r = calculate_vector(r_shoulder, r_hip)
        esp_angle_r = calculate_angle_between_vectors(vec_shoulder_ear_r, vec_shoulder_hip_r)
    
    # Head Position (centroid of nose, ears)
    head_kps = [nose, l_ear , r_ear]
    valid_head_kps = [kp[:2] for kp in head_kps if kp[2] > min_confidence]
    if len(valid_head_kps) > 0:
        head_center = np.mean(valid_head_kps, axis=0)
    
    return torso_angle_vertical, esh_angle_l, esh_angle_r,head_center

def calculate_segment_com(kp_prox, kp_dist, com_percentage_from_proximal):
    """Calculates the center of mass for a segment"""
    if kp_prox is None or kp_dist is None or len(kp_prox) < 2 or len(kp_dist) < 2:
        return None
    segment_vector = np.array(kp_dist[:2]) - np.array(kp_prox[:2])
    com_loc = np.array(kp_prox[:2]) + com_percentage_from_proximal * segment_vector
    return com_loc

def calculate_total_body_com(keypoints_frame, total_body_mass=None, min_confidence = 0.3):
    """
    Returns normalized total-body CoM (x,y) in image pixel, or None.
    keypoint_frame: (K,2) or (K,3) array for a single person.
    """
    k = np.asarray(keypoints_frame)

    #compound landmarks
    mid_hip = get_midpoint(k[L_HIP, k[R_HIP],min_confidence])
    mid_shoulder = get_midpoint(k[L_SHOULDER],k[R_SHOULDER], min_confidence)
    #check if the body part data is available and accurate if it does not exist check the next valid body part to be used to check
    neck = k[NECK] if _valid(k[NECK],min_confidence) else (
        mid_shoulder if mid_shoulder is not None else 
        get_midpoint(k[L_EAR], k[R_EAR], min_confidence)
    )
    head_top = k[HEAD] if _valid(k[HEAD],min_confidence) else k[NOSE]

    segments = {
        "Head":         (neck, head_top),
        "Trunk":        (mid_hip, mid_shoulder),
        "L_Thigh":      (k[L_HIP], k[L_KNEE]),
        "R_Thigh":      (k[R_HIP], k[R_KNEE]),
        "L_Shank":      (k[L_KNEE], k[L_ANKLE]),
        "R_Shank":      (k[R_KNEE], k[R_ANKLE]),
        "L_Foot":       (k[L_HEEL], k[L_BIG_TOE]),
        "R_Foot":       (k[R_HEEL], k[R_BIG_TOE]),
        "L_UpperArm":   (k[L_SHOULDER], k[L_ELBOW]),
        "R_UpperArm":   (k[R_SHOULDER], k[R_ELBOW]),
        "L_Forearm":    (k[L_ELBOW], k[L_WRIST]),
        "R_Forearm":    (k[R_ELBOW],k[R_WRIST]),
        "L_Hand":       (k[L_WRIST], k[L_WRIST]),
        "R_Hand":       (k[R_WRIST], k[R_WRIST])
    }

    anthro_map = {
        "Head":"Head", "Trunk":"Trunk",
        "L_Thigh":"Thigh", "R_Thigh":"Thigh",
        "L_Shank":"Shank", "R_Shank":"Shank",
        "L_Foot":"Foot",   "R_Foot":"Foot",
        "L_UpperArm":"UpperArm", "R_UpperArm":"UpperArm",
        "L_Forearm":"Forearm",   "R_Forearm":"Forearm",
        "L_Hand":"Hand",         "R_Hand":"Hand",
    }

    weighted = np.array([0.0, 0.0], dtype=np.float64)
    mass_sum = 0.0

    for seg_name, (prox, dist) in segments.items():
        if not (_valid(prox, min_confidence)) and _valid(dist, min_confidence):
            continue
        params = ANTHROPOMETRIC_PARAMS[anthro_map[seg_name]]
        com_seg = calculate_segment_com(prox, dist, params["com_perc_prox"])
        if com_seg is None:
            continue
        weighted += com_seg *params["mass_perc"]
        mass_sum += params["mass_perc"]
    
    if mass_sum > 1e-6:
        return (weighted / mass_sum).astype(np.float32)
    return None

def analyze_momentum_generation(com_sequence, total_body_mass, fps):
    com_x = np.array([c[0] if c is not None else np.nan for c in com_sequence])
    com_y = np.array([c[1] if c is not None else np.nan for c in com_sequence])
    if len(com_x) < 3:
        return [0.0]*len(com_x), [0.0]*len(com_x), [0.0]*len(com_x)
    
    dt = 1.0/fps
    com_vx = numerical_derivative(com_x,dt, smooth= True)
    com_vy = numerical_derivative(com_y,dt, smooth= True)
    com_ax = numerical_derivative(com_vx, dt, smooth= True)
    com_ay = numerical_derivative(com_vy, dt, smooth= True)

    vgrf_proxy  = total_body_mass * (GRAVITY - com_ay)
    return com_ax.tolist(), com_ay.tolist(), vgrf_proxy.tolist()

def _iou_xyxy(a, b):
    # a,b: [x1,y1,x2,y2]
    xa1, ya1, xa2, ya2 = a
    xb1, yb1, xb2, yb2 = b
    inter  = max(0, min(xa2, xb2) - max(xa1, xb1)) * max(0, min(ya2, yb2) - max(ya1, yb1))
    if inter == 0: return 0.0
    area_a = (xa2-xa1)*(ya2-ya2);   area_b = (xb2-xb1)*(yb2-yb2)
    return inter / (area_a + area_b - inter + 1e-6)




In [None]:
def _pelvis_local_frame(k):
    """Return pelvis (origin) and rotation that makes 'up' vertical."""
    mid_hip  = _mid(k[L_HIP], k[R_HIP], 0.0)
    mid_sho  = _mid(k[L_SHOULDER], k[R_SHOULDER], 0.0)
    if mid_hip is None or mid_sho is None: return None, None
    pelvis = mid_hip[:2]
    up = (mid_sho[:2] - pelvis)
    n = np.linalg.norm(up)
    if n < 1e-6: return pelvis, np.eye(2, dtype=np.float32)
    ang = np.arctan2(up[1], up[0])           # angle of up vs +x
    rot = np.array([[np.cos(np.pi/2 - ang), -np.sin(np.pi/2 - ang)],
                    [np.sin(np.pi/2 - ang),  np.cos(np.pi/2 - ang)]], dtype=np.float32)  # rotate so up->+y
    return pelvis, rot

def _side_sign(x, margin=4.0):
    # x < -m => Left; x > m => Right; else Uncertain
    if x < -margin: return -1
    if x >  margin: return  1
    return 0

def _fix_lr_per_frame(kpts, scores, side_margin_px=4.0, conf_thr=0.2):
    """Swap knees/ankles if they appear on the wrong side in pelvis-aligned coordinates."""
    K = kpts.copy()
    S = scores.copy()
    pelvis, rot = _pelvis_local_frame(K)
    if pelvis is None or rot is None:
        return K, S

    def to_local(idx):
        if S[idx] < conf_thr: return None
        p = K[idx]
        return (p - pelvis) @ rot.T

    # Knees
    lk = to_local(L_KNEE); rk = to_local(R_KNEE)
    if lk is not None and rk is not None:
        sl = _side_sign(lk[0], side_margin_px)
        sr = _side_sign(rk[0], side_margin_px)
        if sl == 1 and sr != 1:  # left knee on right side (and right knee not clearly on right)
            K[[L_KNEE, R_KNEE]] = K[[R_KNEE, L_KNEE]]
            S[[L_KNEE, R_KNEE]] = S[[R_KNEE, L_KNEE]]
        elif sr == -1 and sl != -1:  # right knee on left side
            K[[L_KNEE, R_KNEE]] = K[[R_KNEE, L_KNEE]]
            S[[L_KNEE, R_KNEE]] = S[[R_KNEE, L_KNEE]]

    # Ankles (similar)
    la = to_local(L_ANKLE); ra = to_local(R_ANKLE)
    if la is not None and ra is not None:
        sl = _side_sign(la[0], side_margin_px)
        sr = _side_sign(ra[0], side_margin_px)
        if sl == 1 and sr != 1:
            K[[L_ANKLE, R_ANKLE]] = K[[R_ANKLE, L_ANKLE]]
            S[[L_ANKLE, R_ANKLE]] = S[[R_ANKLE, L_ANKLE]]
        elif sr == -1 and sl != -1:
            K[[L_ANKLE, R_ANKLE]] = K[[R_ANKLE, L_ANKLE]]
            S[[L_ANKLE, R_ANKLE]] = S[[R_ANKLE, L_ANKLE]]

    return K, S


class PoseTemporalStabilizer:
    def __init__(self, min_conf=0.3, ema_alpha=0.5, side_hysteresis=3):
        self.min_conf = min_conf
        self.ema_alpha = ema_alpha
        self.side_hysteresis = side_hysteresis
        self.tracks = {}    # id -> dict(last_box, kpts(K,2), scores(K,), side_count{'knee':cnt,'ankle':cnt})
        self.next_id = 0

    def _match_tracks(self, boxes):
        """Greedy IoU matching: returns list of (track_id, box_idx) and new_ids for unmatched boxes."""
        # build IoU matrix
        track_ids = list(self.tracks.keys())
        pairs, used_tracks, used_boxes = [], set(), set()
        for bi, b in enumerate(boxes):
            best_iou, best_tid = 0.0, None
            for tid in track_ids:
                if tid in used_tracks: continue
                iou = _iou_xyxy(self.tracks[tid]['last_box'], b)
                if iou > best_iou:
                    best_iou, best_tid = iou, tid
            if best_tid is not None and best_iou >= 0.2:
                pairs.append((best_tid, bi))
                used_tracks.add(best_tid); used_boxes.add(bi)
        new_boxes = [i for i in range(len(boxes)) if i not in used_boxes]
        return pairs, new_boxes

    def _ema(self, prev, curr):
        return self.ema_alpha*curr + (1.0-self.ema_alpha)*prev

    def update(self, frame, boxes, keypoints, kp_scores):
        """
        boxes: (N,4) xyxy
        keypoints: (N,K,2) ; kp_scores: (N,K)
        Returns stabilized_keypoints, stabilized_scores
        """
        N, K = keypoints.shape[0], keypoints.shape[1]
        # initialize tracks if empty
        if not self.tracks:
            for i in range(N):
                self.tracks[self.next_id] = {
                    'last_box': boxes[i].astype(float),
                    'kpts': keypoints[i].astype(float),
                    'scores': kp_scores[i].astype(float),
                    'side_cnt_knee': 0,
                    'side_cnt_ankle': 0,
                }
                self.next_id += 1
            return keypoints, kp_scores

        pairs, new_boxes = self._match_tracks(boxes)

        # assign new ids
        for bi in new_boxes:
            self.tracks[self.next_id] = {
                'last_box': boxes[bi].astype(float),
                'kpts': keypoints[bi].astype(float),
                'scores': kp_scores[bi].astype(float),
                'side_cnt_knee': 0,
                'side_cnt_ankle': 0,
            }
            self.next_id += 1

        # output buffers
        out_kpts = np.zeros_like(keypoints, dtype=float)
        out_scrs = np.zeros_like(kp_scores, dtype=float)

        # update matched tracks
        for tid, bi in pairs:
            prev = self.tracks[tid]
            curr_k = keypoints[bi].astype(float)
            curr_s = kp_scores[bi].astype(float)

            # 1) Per-frame L/R fix
            curr_k, curr_s = _fix_lr_per_frame(curr_k, curr_s, side_margin_px=4.0, conf_thr=self.min_conf)

            # 2) Low-confidence fallback + EMA smoothing
            k_stab = prev['kpts'].copy()
            s_stab = prev['scores'].copy()
            for j in range(K):
                if curr_s[j] >= self.min_conf:
                    # EMA smooth towards current
                    k_stab[j] = self._ema(prev['kpts'][j], curr_k[j])
                    s_stab[j] = curr_s[j]
                else:
                    # keep last good (could also extrapolate with velocity)
                    k_stab[j] = prev['kpts'][j]
                    s_stab[j] = prev['scores'][j] * 0.95  # decay a bit

            out_kpts[bi] = k_stab
            out_scrs[bi] = s_stab

            # 3) Update track state
            self.tracks[tid]['last_box'] = boxes[bi].astype(float)
            self.tracks[tid]['kpts'] = k_stab
            self.tracks[tid]['scores'] = s_stab

        # Any unmatched tracks are kept as-is (optional: age them out)
        # For unmatched current boxes (already handled as new ids) we returned their raw values above.

        # For boxes that matched, outputs already set; for newly added we should output current raw
        for bi in new_boxes:
            out_kpts[bi] = keypoints[bi]
            out_scrs[bi] = kp_scores[bi]

        return out_kpts, out_scrs

In [None]:
class BiomechanicalAnalyzer:
    def __init__(self, keypoints_map):
        self.kpt_map = keypoints_map
        # You can also store anthropometric params here
        self.ANTHROPOMETRIC_PARAMS = ANTHROPOMETRIC_PARAMS = {
    "Head": {"mass_perc": 0.081, "com_perc_prox": 1.0}, # CoM at 100% from neck joint (top of neck)
    "Trunk": {"mass_perc": 0.497, "com_perc_prox": 0.5}, # Midpoint of hip-shoulder line
    "Thigh": {"mass_perc": 0.100, "com_perc_prox": 0.433}, # From hip
    "Shank": {"mass_perc": 0.0465, "com_perc_prox": 0.433}, # From knee
    "Foot": {"mass_perc": 0.0145, "com_perc_prox": 0.500}, # From ankle
    "UpperArm": {"mass_perc": 0.028, "com_perc_prox": 0.436}, # From shoulder
    "Forearm": {"mass_perc": 0.016, "com_perc_prox": 0.430}, # From elbow
    "Hand": {"mass_perc": 0.006, "com_perc_prox": 0.506} # From wrist
}
    def analyze_frame(self, keypoints_frame, min_confidence=0.3):
        """Analyzes a single frame of keypoints."""
        
        # Call specific analysis methods
        knee_angle = self._calculate_knee_angle(keypoints_frame, "L_HIP", "L_KNEE", "L_ANKLE", min_confidence)
        # ... other analyses
        
        return {"left_knee_angle": knee_angle}

    def _calculate_knee_angle(self, kps, hip_key, knee_key, ankle_key, min_conf):
        hip = kps[self.kpt_map[hip_key]]
        knee = kps[self.kpt_map[knee_key]]
        ankle = kps[self.kpt_map[ankle_key]]

        thigh_vector = calculate_vector(hip, knee, min_conf)
        shank_vector = calculate_vector(knee, ankle, min_conf)
        
        if thigh_vector is not None and shank_vector is not None:
            # Angle between thigh and the inverse of the shank for the interior joint angle
            return calculate_angle_between_vectors(thigh_vector, -shank_vector)
        return None

In [3]:
from collections import deque
import math
import numpy as np
from dataclasses import dataclass

@dataclass
class FootIndices:
    ankle: int
    heel: int
    big_toe: int
    small_toe: int | None = None  # optional

@dataclass
class FootHistory:
    # store (t, x, y) for ankle/heel/toe
    ankle: deque
    heel: deque
    toe: deque
    contact_flags: deque  # bool: any contact this frame
    state_labels: deque   # str labels
    def __init__(self, maxlen=90):
        self.ankle = deque(maxlen=maxlen)
        self.heel = deque(maxlen=maxlen)
        self.toe  = deque(maxlen=maxlen)
        self.contact_flags = deque(maxlen=maxlen)
        self.state_labels  = deque(maxlen=maxlen)

class SkateContactAnalyzer:
    """
    Inline/quad skating detector with:
      - robust wheeling (toe/heel) detection (direction-agnostic),
      - momentum kick (swing bursts),
      - ankle/foot twist + twist_push (torsion to translation),
      - session-level modes (no "Two-Foot Glide" label):
          * single_wheel_skate: exactly one leg wheeling
          * single_leg_skate: exactly one leg grounded & rolling
          * two_foot_skate: both legs grounded & rolling in similar direction
          * two_wheels_skate: both legs wheeling

    Coordinates: y increases downward (image coords).
    """

    def __init__(
        self,
        left_idx: FootIndices,
        right_idx: FootIndices,
        history_len=120,
        min_conf=0.3,
        vertical_thresh=10.0,       # px: "on ground" tolerance versus ground line
        near_thresh_mul=2.0,        # Near-ground = 2x vertical_thresh
        min_glide_speed=40.0,       # px/s: ground-tangent ankle speed for rolling/wheeling
        min_push_speed=80.0,        # px/s: ankle horizontal burst for a kick
        stable_vert_std=4.0,        # px: vertical std to call something "stable"
        wheel_frames=5,             # frames window to confirm wheeling
        auto_select_stance=True,
        # ---- Twist thresholds ----
        twist_window_s=0.25,          # s of history to analyze twist
        min_twist_angle=0.35,         # rad total change (~20°)
        min_twist_omega=3.0,          # rad/s peak angular velocity
        max_trans_speed_during_twist=35.0,  # px/s: keep twist "mostly in place"
        min_post_twist_speed=70.0      # px/s: burst after twist = push
    ):
        self.left_idx  = left_idx
        self.right_idx = right_idx
        self.min_conf = min_conf
        self.vertical_thresh = vertical_thresh
        self.near_thresh = near_thresh_mul * vertical_thresh
        self.min_glide_speed = min_glide_speed
        self.min_push_speed  = min_push_speed
        self.stable_vert_std = stable_vert_std
        self.wheel_frames = wheel_frames
        self.auto_select_stance = auto_select_stance

        self.twist_window_s = float(twist_window_s)
        self.min_twist_angle = float(min_twist_angle)
        self.min_twist_omega = float(min_twist_omega)
        self.max_trans_speed_during_twist = float(max_trans_speed_during_twist)
        self.min_post_twist_speed = float(min_post_twist_speed)

        self.left_hist  = FootHistory(maxlen=history_len)
        self.right_hist = FootHistory(maxlen=history_len)

        # rolling pool of (x, y) points believed to be on the ground for ground-line fit
        self.ground_pts = deque(maxlen=400)
        # cached ground line y = a*x + b (None until enough points)
        self.ground_line = None  # (a, b)

    # ---------- utilities ----------
    @staticmethod
    def _vel_from_hist(hist_deque):
        if len(hist_deque) < 2: return 0.0, 0.0
        (t0, x0, y0), (t1, x1, y1) = hist_deque[-2], hist_deque[-1]
        dt = max(1e-6, t1 - t0)
        return (x1 - x0) / dt, (y1 - y0) / dt

    @staticmethod
    def _horiz_speed(hist_deque):
        vx, vy = SkateContactAnalyzer._vel_from_hist(hist_deque)
        return abs(vx)

    @staticmethod
    def _vert_std(hist_deque, window=15):
        if len(hist_deque) < 3: return float('inf')
        ys = [p[2] for p in list(hist_deque)[-window:]]
        return float(np.std(ys)) if len(ys) >= 3 else float('inf')

    @staticmethod
    def _foot_pitch(heel_xy, toe_xy):
        # angle of foot segment relative to horizontal, radians
        dx = toe_xy[0] - heel_xy[0]
        dy = toe_xy[1] - heel_xy[1]
        return math.atan2(dy, dx)

    def _dist_to_ground(self, x, y):
        # vertical distance to ground line at x
        if self.ground_line is None:
            return 0.0
        a, b = self.ground_line
        y_line = a * x + b
        return y - y_line  # positive if below line (towards screen bottom)

    def _fit_ground_line(self):
        if len(self.ground_pts) < 40:
            return
        pts = np.array(self.ground_pts)
        xs, ys = pts[:,0], pts[:,1]
        cutoff = np.percentile(ys, 35)  # keep lowest 35% y (closest to ground)
        xs_f = xs[ys <= cutoff]
        ys_f = ys[ys <= cutoff]
        if len(xs_f) >= 10:
            a, b = np.polyfit(xs_f, ys_f, 1)
            self.ground_line = (float(a), float(b))

    def _append_hist(self, hist: FootHistory, t, ankle_xy, heel_xy, toe_xy, contact_flag, state_label):
        hist.ankle.append((t, *ankle_xy))
        hist.heel.append((t, *heel_xy))
        hist.toe.append((t, *toe_xy))
        hist.contact_flags.append(bool(contact_flag))
        hist.state_labels.append(state_label)

    # ---------- ground/tangent helpers (direction-agnostic rolling) ----------
    def _ground_axes(self):
        """Unit tangent/normal to the fitted ground line."""
        if self.ground_line is not None:
            a, _ = self.ground_line
            t = np.array([1.0, a], dtype=float)
            t /= (np.linalg.norm(t) + 1e-12)
            n = np.array([-a, 1.0], dtype=float)
            n /= (np.linalg.norm(n) + 1e-12)
        else:
            t = np.array([1.0, 0.0], dtype=float)
            n = np.array([0.0, 1.0], dtype=float)
        return t, n

    def _last_v2d(self, hist_deque):
        if len(hist_deque) < 2:
            return np.array([0.0, 0.0], dtype=float)
        (t0, x0, y0), (t1, x1, y1) = hist_deque[-2], hist_deque[-1]
        dt = max(1e-6, t1 - t0)
        return np.array([(x1 - x0)/dt, (y1 - y0)/dt], dtype=float)

    def _proj_speeds(self, ankle_hist):
        """(|v_t|, |v_n|) = speeds along ground tangent and normal (direction-agnostic)."""
        v = self._last_v2d(ankle_hist)
        t_hat, n_hat = self._ground_axes()
        vt = abs(float(v @ t_hat))
        vn = abs(float(v @ n_hat))
        return vt, vn

    def _contact_ratios_from_states(self, foot_hist: FootHistory, frames: int):
        """
        Over last `frames` labels, compute fractions for toe-only, heel-only, flat, near, swing.
        Windowing smooths out flicker when reversing or carving.
        """
        labels = list(foot_hist.state_labels)
        if not labels:
            return 0.0, 0.0, 0.0, 0.0, 1.0
        labels = labels[-frames:]

        toe_only = sum(s in ("Toe Dominant", "Tiptoeing", "Wheel-Toe") for s in labels)
        heel_only = sum(s in ("Heel Strike", "Wheel-Heel") for s in labels)
        flat = sum(s in ("Flat Foot",) for s in labels)
        near = sum(s in ("Near Ground",) for s in labels)
        swing = sum(s in ("Swing",) for s in labels)

        denom = max(1, len(labels))
        return toe_only/denom, heel_only/denom, flat/denom, near/denom, swing/denom

    @staticmethod
    def _is_grounded_state(s: str) -> bool:
        if not s: return False
        grounded = {
            "Flat Foot", "Toe Dominant", "Tiptoeing", "Heel Strike",
            "Near Ground", "Wheel-Toe", "Wheel-Heel"
        }
        return s in grounded

    # ---------- core per-foot contact classification ----------
    def _classify_contact(self, ankle_xy, heel_xy, toe_xy):
        """
        Returns: state_label, heel_on, toe_on, near_ground
        """
        dy_heel = self._dist_to_ground(*heel_xy)
        dy_toe  = self._dist_to_ground(*toe_xy)

        heel_on = abs(dy_heel) < self.vertical_thresh
        toe_on  = abs(dy_toe)  < self.vertical_thresh
        near_g  = (not heel_on and not toe_on) and (abs(dy_heel) < self.near_thresh or abs(dy_toe) < self.near_thresh)

        # Relative verticals (y grows downward)
        ankle_high_vs_both = (ankle_xy[1] + self.vertical_thresh) < min(heel_xy[1], toe_xy[1])

        if heel_on and not toe_on and toe_xy[1] < heel_xy[1]:
            return "Heel Strike", heel_on, toe_on, near_g
        if toe_on and not heel_on and heel_xy[1] < toe_xy[1]:
            if ankle_high_vs_both and (heel_xy[1] + self.vertical_thresh/2) < toe_xy[1]:
                return "Tiptoeing", heel_on, toe_on, near_g
            return "Toe Dominant", heel_on, toe_on, near_g
        if heel_on and toe_on:
            return "Flat Foot", heel_on, toe_on, near_g
        if not heel_on and not toe_on:
            # far off?
            if min(heel_xy[1], toe_xy[1]) + 2*self.vertical_thresh < (
                self.ground_line[0]*ankle_xy[0] + self.ground_line[1] if self.ground_line
                else ankle_xy[1] + 2*self.vertical_thresh
            ):
                return "Swing", heel_on, toe_on, near_g
            return "Near Ground", heel_on, toe_on, near_g

        return "Indeterminate", heel_on, toe_on, near_g

    # ---------- stance selection ----------
    def _pick_stance_leg(self):
        # Stance = more on-ground recently and lower horizontal speed
        def score(hist: FootHistory):
            on_frac = (sum(hist.contact_flags) / max(1, len(hist.contact_flags)))
            speed   = self._horiz_speed(hist.ankle)
            return on_frac - 0.002*speed
        left_score  = score(self.left_hist)
        right_score = score(self.right_hist)
        return 'left' if left_score >= right_score else 'right'

    # ---------- robust wheeling detection (direction-agnostic) ----------
    """
    def _wheeling_state(self, foot_hist: FootHistory):
        if len(foot_hist.ankle) < 3:
            return None

        window = max(6, self.wheel_frames)
        toe_r, heel_r, flat_r, near_r, swing_r = self._contact_ratios_from_states(foot_hist, frames=window)

        vt, vn = self._proj_speeds(foot_hist.ankle)
        vert_stability = self._vert_std(foot_hist.ankle)

        # Must be rolling and stable vertically
        if vt < self.min_glide_speed or vert_stability > self.stable_vert_std:
            return None

        # Dominant toe/heel contact despite brief flickers
        if toe_r >= 0.55 and heel_r <= 0.30:
            return "Wheel-Toe"
        if heel_r >= 0.55 and toe_r <= 0.30:
            return "Wheel-Heel"

        return None
    """
    def _wheeling(self, hist: FootHistory):
        # Returns dict with angle even if not strictly "wheeling"
        if len(hist.ankle) < 3:
            return {'label': None, 'angle': None}

        # --- Compute extended angle knee–ankle–toe/heel ---
        try:
            heel_xy = np.array([hist.heel[-1][1], hist.heel[-1][2]])
            toe_xy  = np.array([hist.toe[-1][1],  hist.toe[-1][2]])
            ankle_xy = np.array([hist.ankle[-1][1], hist.ankle[-1][2]])
            # Approximate knee position relative to ankle if not passed explicitly, 
            # but relying on vector logic defined in previous code:
            knee_xy = ankle_xy + np.array([0, -40]) 
        except Exception:
            return {'label': None, 'angle': None}

        foot_vec = toe_xy - heel_xy
        leg_vec  = ankle_xy - knee_xy
        
        # This is the angle between the shin and the foot
        ang = abs(angle_between(leg_vec, foot_vec))  # degrees

        # --- Logic for labeling (optional, but we keep it for context) ---
        label = None
        
        # Only apply label if speed and stability requirements are met
        window = max(6, self.wheel_frames)
        toe_r, heel_r, flat_r, near_r, swing_r = self._contact_ratios(hist, frames=window)
        vt, vn = self._proj_speeds(hist.ankle)
        vert_stab = self._vert_std(hist.ankle)

        if vt >= self.min_glide_speed and vert_stab <= self.stable_vert_std:
            if ang >= 95:
                label = "Wheel-Toe"
            elif ang <= 80:
                label = "Wheel-Heel"
            elif toe_r >= 0.55 and heel_r <= 0.30:
                label = "Wheel-Toe"
            elif heel_r >= 0.55 and toe_r <= 0.30:
                label = "Wheel-Heel"

        return {"label": label, "angle": float(round(ang, 1))}

    # ---------- momentum (kicking) detection for swing leg ----------
    def _momentum_kick(self, swing_hist: FootHistory, stance_hist: FootHistory):
        # A kick = short burst of swing-foot horizontal speed while briefly ground-brushing,
        # with stance foot stably on ground.
        if len(swing_hist.ankle) < 3 or len(stance_hist.ankle) < 3:
            return False
        v_sw = self._horiz_speed(swing_hist.ankle)
        stance_stable = self._vert_std(stance_hist.ankle) < self.stable_vert_std and any(stance_hist.contact_flags)
        recent = swing_hist.state_labels[-6:]
        toe_like = any(s in ("Toe Dominant", "Tiptoeing", "Heel Strike") for s in recent)
        return (v_sw >= self.min_push_speed) and stance_stable and toe_like

    # ---------- session-level modes ----------
    def _grounded_ratio(self, foot_hist: FootHistory, frames=8):
        labels = list(foot_hist.state_labels)[-frames:]
        if not labels: return 0.0
        return sum(self._is_grounded_state(s) for s in labels) / max(1, len(labels))

    def _session_modes(self):
        """
        Decide global modes from left/right behavior (direction-agnostic):
          - single_wheel_skate: exactly one leg wheeling
          - single_leg_skate: exactly one leg grounded & rolling, other mostly not
          - two_foot_skate: both legs grounded & rolling in roughly same direction
          - two_wheels_skate: both legs wheeling
        Returns dict with booleans and details.
        """
        # Per-leg rolling speed along ground tangent
        vLt, _ = self._proj_speeds(self.left_hist.ankle)  if len(self.left_hist.ankle)  >= 2 else (0.0, 0.0)
        vRt, _ = self._proj_speeds(self.right_hist.ankle) if len(self.right_hist.ankle) >= 2 else (0.0, 0.0)
        L_roll = vLt >= self.min_glide_speed
        R_roll = vRt >= self.min_glide_speed

        # Grounded ratios (robust to flicker)
        L_gr = self._grounded_ratio(self.left_hist, frames=8)
        R_gr = self._grounded_ratio(self.right_hist, frames=8)
        L_grounded = L_gr >= 0.5
        R_grounded = R_gr >= 0.5

        # Wheeling states
        L_wh = self._wheeling_state(self.left_hist)
        R_wh = self._wheeling_state(self.right_hist)
        L_is_wheel = L_wh is not None
        R_is_wheel = R_wh is not None

        # Direction coherence (signed along ground tangent)
        t_hat, _ = self._ground_axes()
        vL_t_signed = float(self._last_v2d(self.left_hist.ankle)  @ t_hat)
        vR_t_signed = float(self._last_v2d(self.right_hist.ankle) @ t_hat)
        cos_sim = (vL_t_signed * vR_t_signed) / (abs(vL_t_signed) * abs(vR_t_signed) + 1e-6)

        two_wheels_skate = L_is_wheel and R_is_wheel
        single_wheel_skate = (L_is_wheel ^ R_is_wheel)
        two_foot_skate = (L_grounded and R_grounded and L_roll and R_roll and cos_sim >= 0.4)
        single_leg_skate = (
            (L_grounded and L_roll and not (R_grounded and R_roll)) ^
            (R_grounded and R_roll and not (L_grounded and L_roll))
        )

        return {
            'single_wheel_skate': bool(single_wheel_skate),
            'single_leg_skate':   bool(single_leg_skate),
            'two_foot_skate':     bool(two_foot_skate),
            'two_wheels_skate':   bool(two_wheels_skate),
            'details': {
                'left':  {'wheeling': L_wh, 'grounded_ratio': L_gr, 'vt': vLt},
                'right': {'wheeling': R_wh, 'grounded_ratio': R_gr, 'vt': vRt},
                'cos_sim': float(cos_sim)
            }
        }

    # ---------- Twist helpers ----------
    @staticmethod
    def _foot_yaw(heel_xy, toe_xy):
        """Angle (rad) of heel->toe vector; image coords (y down)."""
        dx = toe_xy[0] - heel_xy[0]
        dy = toe_xy[1] - heel_xy[1]
        return math.atan2(dy, dx)

    @staticmethod
    def _window_from_hist(hist_deque, t_now, window_s):
        """Subsequence of (t,x,y) within [t_now-window_s, t_now]."""
        if not hist_deque: return []
        lo = t_now - window_s
        out = []
        for (t,x,y) in reversed(hist_deque):
            if t < lo: break
            out.append((t,x,y))
        return list(reversed(out))

    @staticmethod
    def _unwrap_angles(angles):
        if len(angles) <= 1: return angles[:]
        return list(np.unwrap(np.asarray(angles, dtype=float)))

    def _ankle_horiz_speeds_window(self, ankle_hist, t_now, window_s):
        seg = self._window_from_hist(ankle_hist, t_now, window_s)
        if len(seg) < 2: return []
        speeds = []
        for (t0,x0,y0),(t1,x1,y1) in zip(seg[:-1], seg[1:]):
            dt = max(1e-6, t1 - t0)
            speeds.append(abs((x1 - x0)/dt))
        return speeds

    def _twist_metrics(self, foot_hist: FootHistory, t_now):
        """
        Compute twist metrics over a short window using heel->toe yaw.
        Returns dict or None: {delta_theta, omega_peak, omega_mean, direction,
                               grounded_ratio, last_state, trans_speed_mean, post_speed}
        """
        heel_seg = self._window_from_hist(foot_hist.heel, t_now, self.twist_window_s)
        toe_seg  = self._window_from_hist(foot_hist.toe,  t_now, self.twist_window_s)
        if len(heel_seg) < 2 or len(toe_seg) < 2:
            return None

        n = min(len(heel_seg), len(toe_seg))
        heel_seg, toe_seg = heel_seg[-n:], toe_seg[-n:]
        ts   = [p[0] for p in heel_seg]
        yaws = [self._foot_yaw((h[1],h[2]), (t_[1],t_[2])) for h,t_ in zip(heel_seg, toe_seg)]
        yaws_u = self._unwrap_angles(yaws)
        if len(yaws_u) < 2: return None

        # Angular velocity samples
        oma = []
        for (th0, t0), (th1, t1) in zip(zip(yaws_u[:-1], ts[:-1]), zip(yaws_u[1:], ts[1:])):
            dt = max(1e-6, t1 - t0)
            oma.append((th1 - th0)/dt)

        # Translation gating during the twist window
        trans = self._ankle_horiz_speeds_window(foot_hist.ankle, t_now, self.twist_window_s)
        trans_speed_mean = float(np.mean(trans)) if trans else 0.0

        # Grounded ratio via state labels
        recent_states = list(foot_hist.state_labels)[-len(ts):] if foot_hist.state_labels else []
        grounded_ratio = sum(self._is_grounded_state(s) for s in recent_states) / max(1, len(recent_states))
        last_state = recent_states[-1] if recent_states else None

        delta_theta = yaws_u[-1] - yaws_u[0]
        omega_peak  = float(max((abs(w) for w in oma), default=0.0))
        omega_mean  = float(np.mean(np.abs(oma))) if oma else 0.0
        direction   = 'ccw' if delta_theta > 0 else ('cw' if delta_theta < 0 else 'flat')

        # Immediate post-twist ankle speed (x-based; simple proxy)
        post_speed = self._horiz_speed(foot_hist.ankle) if len(foot_hist.ankle) >= 2 else 0.0

        return {
            'delta_theta': float(delta_theta),
            'omega_peak': omega_peak,
            'omega_mean': omega_mean,
            'direction': direction,
            'grounded_ratio': float(grounded_ratio),
            'last_state': last_state,
            'trans_speed_mean': trans_speed_mean,
            'post_speed': float(post_speed)
        }

    def _twist_detect(self, metrics):
        """
        Return (twist_flag, twist_push_flag, pivot, score).
        Fires when grounded (single-foot or wheeling), angle change & angular speed are high,
        and translation during the twist itself is limited.
        """
        if not metrics:
            return False, False, None, 0.0

        # grounded requirement (lenient so near-ground passes)
        grounded_ok = metrics['grounded_ratio'] >= 0.4

        # pivot inference from last_state
        ls = metrics.get('last_state')
        if ls in ('Tiptoeing', 'Toe Dominant', 'Wheel-Toe'):
            pivot = 'toe'
        elif ls in ('Heel Strike', 'Wheel-Heel'):
            pivot = 'heel'
        elif ls in ('Flat Foot', 'Near Ground'):
            pivot = 'flat'
        else:
            pivot = 'mixed'

        # allow slightly more translation if 'flat', stricter for wheel pivots
        trans_gate = metrics['trans_speed_mean'] <= (
            self.max_trans_speed_during_twist * (1.35 if pivot == 'flat' else 1.0)
        )

        twist = (
            grounded_ok and
            metrics['omega_peak']    >= self.min_twist_omega and
            abs(metrics['delta_theta']) >= self.min_twist_angle and
            trans_gate
        )

        twist_push = twist and (metrics['post_speed'] >= self.min_post_twist_speed)

        score = (
            (1.0 if grounded_ok else 0.5) *
            abs(metrics['delta_theta']) *
            (metrics['omega_peak'] / max(1e-6, self.min_twist_omega)) *
            (1.0 + (1.0 - min(1.0, metrics['trans_speed_mean'] / max(1e-6, self.max_trans_speed_during_twist))))
        )

        return twist, twist_push, pivot, float(score)

    # ---------- public update ----------
    def update(self, keypoints_frame, t, leg_choice='auto'):
        """
        keypoints_frame: list/array of (x,y,conf) for this frame
        t: timestamp in seconds
        leg_choice: 'auto' | 'left' | 'right'
        """
        def ok(i):
            k = keypoints_frame[i]
            return (k[2] if len(k) > 2 else 0.0) > self.min_conf

        def get_xy(i):
            k = keypoints_frame[i]; return (k[0], k[1])

        low_conf_return = {
            'stance_leg': None,
            'left':  {'state': 'Unknown (Low Confidence)'},
            'right': {'state': 'Unknown (Low Confidence)'},
            'wheeling': {'left': None, 'right': None},
            'momentum_kick': {'left': False, 'right': False},
            'ankle_twist': {'left': None, 'right': None},
            'twist_push': {'left': False, 'right': False},
            'modes': {
                'single_wheel_skate': False,
                'single_leg_skate': False,
                'two_foot_skate': False,
                'two_wheels_skate': False,
                'details': {}
            },
            'ground_line': self.ground_line
        }

        # LEFT classify
        if not (ok(self.left_idx.ankle) and ok(self.left_idx.heel) and ok(self.left_idx.big_toe)):
            left_state = 'Unknown (Low Confidence)'; L_flags=(False, False, False)
        else:
            L_ank = get_xy(self.left_idx.ankle)
            L_heel= get_xy(self.left_idx.heel)
            L_toe = get_xy(self.left_idx.big_toe)
            left_state, L_heel_on, L_toe_on, L_near = self._classify_contact(L_ank, L_heel, L_toe)
            L_flags = (L_heel_on, L_toe_on, L_near)

        # RIGHT classify
        if not (ok(self.right_idx.ankle) and ok(self.right_idx.heel) and ok(self.right_idx.big_toe)):
            right_state = 'Unknown (Low Confidence)'; R_flags=(False, False, False)
        else:
            R_ank = get_xy(self.right_idx.ankle)
            R_heel= get_xy(self.right_idx.heel)
            R_toe = get_xy(self.right_idx.big_toe)
            right_state, R_heel_on, R_toe_on, R_near = self._classify_contact(R_ank, R_heel, R_toe)
            R_flags = (R_heel_on, R_toe_on, R_near)

        # If both are unknown due to low conf, bail early
        if ('Unknown' in locals().get('left_state','')) and ('Unknown' in locals().get('right_state','')):
            return low_conf_return

        # Append histories & grow ground-point pool from confident contacts
        if 'Unknown' not in left_state:
            self._append_hist(self.left_hist, t, L_ank, L_heel, L_toe, any(L_flags[:2]), left_state)
            if L_flags[0]: self.ground_pts.append((L_heel[0], L_heel[1]))
            if L_flags[1]: self.ground_pts.append((L_toe[0],  L_toe[1]))
        if 'Unknown' not in right_state:
            self._append_hist(self.right_hist, t, R_ank, R_heel, R_toe, any(R_flags[:2]), right_state)
            if R_flags[0]: self.ground_pts.append((R_heel[0], R_heel[1]))
            if R_flags[1]: self.ground_pts.append((R_toe[0],  R_toe[1]))

        # Continuously refit ground line (handles slopes/uneven ground)
        self._fit_ground_line()

        # Stance choice
        stance_leg = (self._pick_stance_leg() if (leg_choice=='auto' and self.auto_select_stance)
                      else leg_choice if leg_choice in ('left','right') else None)

        # Robust wheeling per leg (direction-agnostic)
        left_wheel  = self._wheeling_state(self.left_hist)  if len(self.left_hist.ankle)>2 else None
        right_wheel = self._wheeling_state(self.right_hist) if len(self.right_hist.ankle)>2 else None

        # Momentum kick (by the swing leg)
        if stance_leg=='left':
            mk_left  = False
            mk_right = self._momentum_kick(self.right_hist, self.left_hist)
        elif stance_leg=='right':
            mk_left  = self._momentum_kick(self.left_hist, self.right_hist)
            mk_right = False
        else:
            mk_left  = self._momentum_kick(self.left_hist, self.right_hist)
            mk_right = self._momentum_kick(self.right_hist, self.left_hist)

        # --- Ankle twist: per leg (works for single-foot, wheeling or flat) ---
        t_now = t
        L_metrics = self._twist_metrics(self.left_hist,  t_now) if len(self.left_hist.ankle)  >= 2 else None
        R_metrics = self._twist_metrics(self.right_hist, t_now) if len(self.right_hist.ankle) >= 2 else None

        L_twist, L_twist_push, L_pivot, L_score = self._twist_detect(L_metrics) if L_metrics else (False, False, None, 0.0)
        R_twist, R_twist_push, R_pivot, R_score = self._twist_detect(R_metrics) if R_metrics else (False, False, None, 0.0)

        # Global modes (no "Two-Foot Glide" label; just mode summary)
        modes = self._session_modes()

        return {
            'stance_leg': stance_leg,
            'left':  {'state': left_state},
            'right': {'state': right_state},
            'wheeling': {
                'left': left_wheel,
                'right': right_wheel
            },
            'momentum_kick': {
                'left': mk_left,
                'right': mk_right
            },
            'ankle_twist': {
                'left':  {**(L_metrics or {}), 'detected': L_twist, 'pivot': L_pivot, 'score': L_score} if L_metrics else None,
                'right': {**(R_metrics or {}), 'detected': R_twist, 'pivot': R_pivot, 'score': R_score} if R_metrics else None
            },
            'twist_push': {
                'left':  L_twist_push,
                'right': R_twist_push
            },
            'modes': modes,
            'ground_line': self.ground_line
        }



In [4]:
#Biggest box or most active box is the main focus to do the calculation
# need to use opencv to find out the fps
# need to have index properly done

# 


In [1]:
"""
Skating Video Analyzer
----------------------
- Detects people with YOLOX, estimates keypoints with RTMPose
- Robust per-foot contact classification w/ wheeling detection
- Swing/kick (momentum generation) detection
- Twist + twist-push detection (torsion -> translation)
- Body stability (torso uprightness, head bob), hand stability (wrist speed)
- Knee angles (per leg) over time
- Ground-line fitting (handles slopes/uneven ground)
- Multi-person: pick main subject (centered + large box), track with simple IoU
- Frame skipping (process every N frames) to reduce compute
- On-screen overlays: skeleton, bbox, status panel inside the box, momentum arrow

⚠️ Notes
- This file is designed to plug into your existing pipeline using `rtmlib`.
- The wrapper signatures of `YOLOX` and `RTMPose` may vary; adjust the two
  methods `_yolo_detect()` and `_rtm_infer()` if needed.
- Thresholds are conservative defaults; tune per video.
- Coordinates assume image y increases downward (OpenCV default).
"""
from __future__ import annotations
import cv2
import math
import time
import numpy as np
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from rtmlib.tools import RTMPose, YOLOX
from rtmlib import draw_skeleton, draw_bbox

# --- External models (adjust imports to your environment) ---
try:
    from rtmlib.tools import RTMPose, YOLOX
    from rtmlib import draw_skeleton, draw_bbox
except Exception:
    RTMPose = YOLOX = None
    def draw_skeleton(img, kpts, pairs=None, conf=None, conf_thr=0.2, **kw):
        # Minimal fallback: draw small circles for keypoints
        for i, p in enumerate(kpts):
            if len(p) >= 3 and p[2] < conf_thr: 
                continue
            cv2.circle(img, (int(p[0]), int(p[1])), 2, (255,255,255), -1)
        return img
    def draw_bbox(img, box, **kw):
        x1,y1,x2,y2 = map(int, box)
        return cv2.rectangle(img,(x1,y1),(x2,y2),(255,255,255),1)

# -----------------------------
# Keypoint indexing (HALPE-26 / body7)
# -----------------------------
KEYPOINT_DICT = {"NOSE":0, "L_EYE":1, "R_EYE":2, "L_EAR":3, "R_EAR":4, "L_SHOULDER":5,
                 "R_SHOULDER":6, "L_ELBOW":7, "R_ELBOW":8, "L_WRIST":9, "R_WRIST":10,
                 "L_HIP":11, "R_HIP":12, "L_KNEE":13, "R_KNEE":14, "L_ANKLE":15, "R_ANKLE":16,
                 "HEAD":17, "NECK":18, "HIP_CENTER":19, "L_BIG_TOE":20, "R_BIG_TOE":21, "L_SMALL_TOE":22,
                 "R_SMALL_TOE":23, "L_HEEL":24, "R_HEEL":25}

# Convenience constants
NOSE = KEYPOINT_DICT["NOSE"]
L_EAR, R_EAR = KEYPOINT_DICT["L_EAR"], KEYPOINT_DICT["R_EAR"]
L_SHOULDER, R_SHOULDER = KEYPOINT_DICT["L_SHOULDER"], KEYPOINT_DICT["R_SHOULDER"]
L_ELBOW, R_ELBOW = KEYPOINT_DICT["L_ELBOW"], KEYPOINT_DICT["R_ELBOW"]
L_WRIST, R_WRIST = KEYPOINT_DICT["L_WRIST"], KEYPOINT_DICT["R_WRIST"]
L_HIP, R_HIP = KEYPOINT_DICT["L_HIP"], KEYPOINT_DICT["R_HIP"]
L_KNEE, R_KNEE = KEYPOINT_DICT["L_KNEE"], KEYPOINT_DICT["R_KNEE"]
L_ANKLE, R_ANKLE = KEYPOINT_DICT["L_ANKLE"], KEYPOINT_DICT["R_ANKLE"]
L_BIG_TOE, R_BIG_TOE = KEYPOINT_DICT["L_BIG_TOE"], KEYPOINT_DICT["R_BIG_TOE"]
L_HEEL, R_HEEL = KEYPOINT_DICT["L_HEEL"], KEYPOINT_DICT["R_HEEL"]

# -----------------------------
# Anthro + physics params (approx, Dempster-based)
# -----------------------------
GRAVITY = 9.81
ANTHRO = {
    "Head":      {"mass_perc": 0.081,  "com_perc_prox": 0.50},
    "Trunk":     {"mass_perc": 0.497,  "com_perc_prox": 0.50},
    "Thigh":     {"mass_perc": 0.100,  "com_perc_prox": 0.433},
    "Shank":     {"mass_perc": 0.0465, "com_perc_prox": 0.433},
    "Foot":      {"mass_perc": 0.0145, "com_perc_prox": 0.50},
    "UpperArm":  {"mass_perc": 0.027,  "com_perc_prox": 0.436},
    "Forearm":   {"mass_perc": 0.016,  "com_perc_prox": 0.430},
    "Hand":      {"mass_perc": 0.006,  "com_perc_prox": 0.50},
}

# -----------------------------
# Small math/pose helpers
# -----------------------------

def xyc(kp):
    if kp is None:
        return None
    a = np.asarray(kp)
    if a.size >= 3:
        return float(a[0]), float(a[1]), float(a[2])
    return float(a[0]), float(a[1]), 1.0

def valid(kp, thr=0.3):
    v = xyc(kp)
    return (v is not None) and (v[2] >= thr)

def midpoint(p1, p2, min_conf=0.0):
    a, b = xyc(p1), xyc(p2)
    if a is None or b is None:
        return None
    if a[2] < min_conf or b[2] < min_conf:
        return None
    x = 0.5 * (a[0] + b[0])
    y = 0.5 * (a[1] + b[1])
    c = min(a[2], b[2])
    return np.array([x, y, c], dtype=np.float32)

def vec(a, b):
    """Returns vector from point a to point b (b - a)."""
    if a is None or b is None:
        return None
    aa, bb = np.asarray(a)[:2], np.asarray(b)[:2]
    if aa.size < 2 or bb.size < 2:
        return None
    return bb - aa

def angle_between(v1, v2):
    """
    Calculates the interior angle (0-180 degrees) between two vectors.
    """
    if v1 is None or v2 is None:
        return None
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    
    # Magnitudes
    n1 = np.linalg.norm(v1)
    n2 = np.linalg.norm(v2)
    
    # Prevent Division by Zero
    if n1 < 1e-6 or n2 < 1e-6:
        return 0.0
    
    # Dot product formula: a . b = |a||b|cos(theta)
    dot_prod = np.dot(v1, v2)
    c = dot_prod / (n1 * n2)
    
    # Clipping is crucial because float precision can result in 1.000000002
    c = max(-1.0, min(1.0, c))
    
    return math.degrees(math.acos(c))

def angle_to_horizontal(vseg):
    if vseg is None:
        return None
    return math.degrees(math.atan2(float(vseg[1]), float(vseg[0])))

def sgolay(y, window=5, poly=2):
    # Minimal S-G filter (optional): uses numpy polyfit over sliding windows
    if len(y) < window:
        return np.asarray(y)
    out = np.zeros_like(y, dtype=float)
    half = window // 2
    xs = np.arange(len(y))
    for i in range(len(y)):
        lo = max(0, i - half)
        hi = min(len(y), i + half + 1)
        xi = xs[lo:hi]
        yi = y[lo:hi]
        if len(xi) < 2:
            out[i] = y[i]
            continue
        coeff = np.polyfit(xi - xi.mean(), yi, poly)
        out[i] = np.polyval(coeff, 0.0)  # center value
    return out

def deriv(y, dt, smooth=True):
    y = np.asarray(y, dtype=float)
    if y.size < 2:
        return np.zeros_like(y)
    if smooth:
        y = sgolay(y)
    return np.gradient(y, dt)

# -----------------------------
# IoU + simple geometry
# -----------------------------

def iou_xyxy(a, b):
    xa1, ya1, xa2, ya2 = a
    xb1, yb1, xb2, yb2 = b
    inter_w = max(0.0, min(xa2, xb2) - max(xa1, xb1))
    inter_h = max(0.0, min(ya2, yb2) - max(ya1, yb1))
    inter = inter_w * inter_h
    if inter <= 0:
        return 0.0
    area_a = max(0.0, xa2 - xa1) * max(0.0, ya2 - ya1)
    area_b = max(0.0, xb2 - xb1) * max(0.0, yb2 - yb1)
    return float(inter / (area_a + area_b - inter + 1e-6))

# -----------------------------
# L/R sanity fix (pelvis-aligned)
# -----------------------------

def pelvis_local_frame(kpts, scores, conf_thr=0.2):
    K = np.asarray(kpts)
    S = np.asarray(scores)
    mh = midpoint(K[L_HIP], K[R_HIP], 0.0)
    ms = midpoint(K[L_SHOULDER], K[R_SHOULDER], 0.0)
    if mh is None or ms is None:
        return None, None
    pelvis = mh[:2]
    up = ms[:2] - pelvis
    n = float(np.linalg.norm(up))
    if n < 1e-6:
        return pelvis, np.eye(2, dtype=np.float32)
    ang = math.atan2(up[1], up[0])
    # rotate so that up -> +y
    rot = np.array([[math.cos(math.pi/2 - ang), -math.sin(math.pi/2 - ang)],
                    [math.sin(math.pi/2 - ang),  math.cos(math.pi/2 - ang)]], dtype=np.float32)
    return pelvis, rot

def _side_sign(x, margin=4.0):
    if x < -margin: return -1
    if x >  margin: return  1
    return 0

def fix_lr_per_frame(kpts, scores, side_margin_px=4.0, conf_thr=0.2):
    K = np.asarray(kpts).copy()
    S = np.asarray(scores).copy()
    pelvis, rot = pelvis_local_frame(K, S, conf_thr=conf_thr)
    if pelvis is None or rot is None:
        return K, S
    def to_local(idx):
        if S[idx] < conf_thr:
            return None
        p = K[idx]
        return (p[:2] - pelvis) @ rot.T
    # Knees
    lk, rk = to_local(L_KNEE), to_local(R_KNEE)
    if lk is not None and rk is not None:
        sl, sr = _side_sign(lk[0], side_margin_px), _side_sign(rk[0], side_margin_px)
        if sl == 1 and sr != 1:
            K[[L_KNEE, R_KNEE]] = K[[R_KNEE, L_KNEE]]
            S[[L_KNEE, R_KNEE]] = S[[R_KNEE, L_KNEE]]
        elif sr == -1 and sl != -1:
            K[[L_KNEE, R_KNEE]] = K[[R_KNEE, L_KNEE]]
            S[[L_KNEE, R_KNEE]] = S[[R_KNEE, L_KNEE]]
    # Ankles
    la, ra = to_local(L_ANKLE), to_local(R_ANKLE)
    if la is not None and ra is not None:
        sl, sr = _side_sign(la[0], side_margin_px), _side_sign(ra[0], side_margin_px)
        if sl == 1 and sr != 1:
            K[[L_ANKLE, R_ANKLE]] = K[[R_ANKLE, L_ANKLE]]
            S[[L_ANKLE, R_ANKLE]] = S[[R_ANKLE, L_ANKLE]]
        elif sr == -1 and sl != -1:
            K[[L_ANKLE, R_ANKLE]] = K[[R_ANKLE, L_ANKLE]]
            S[[L_ANKLE, R_ANKLE]] = S[[R_ANKLE, L_ANKLE]]
    return K, S

# -----------------------------
# Foot contact/wheeling + twist logic
# -----------------------------
@dataclass
class FootIndices:
    ankle: int
    heel: int
    big_toe: int
    small_toe: Optional[int] = None

@dataclass
class FootHistory:
    ankle: deque
    heel: deque
    toe: deque
    contact_flags: deque
    state_labels: deque
    def __init__(self, maxlen=90):
        self.ankle = deque(maxlen=maxlen)
        self.heel = deque(maxlen=maxlen)
        self.toe  = deque(maxlen=maxlen)
        self.contact_flags = deque(maxlen=maxlen)
        self.state_labels  = deque(maxlen=maxlen)

class SkateContactAnalyzer:
    """Direction-agnostic wheeling & momentum detection with ground-line fitting."""
    def __init__(self,
                 left_idx: FootIndices,
                 right_idx: FootIndices,
                 history_len=120,
                 min_conf=0.3,
                 vertical_thresh=10.0,
                 near_thresh_mul=2.0,
                 min_glide_speed=40.0,
                 min_push_speed=80.0,
                 stable_vert_std=4.0,
                 wheel_frames=5,
                 auto_select_stance=True,
                 twist_window_s=0.25,
                 min_twist_angle=0.35,
                 min_twist_omega=3.0,
                 max_trans_speed_during_twist=35.0,
                 min_post_twist_speed=70.0):
        self.left_idx  = left_idx
        self.right_idx = right_idx
        self.min_conf = min_conf
        self.vertical_thresh = float(vertical_thresh)
        self.near_thresh = float(near_thresh_mul) * float(vertical_thresh)
        self.min_glide_speed = float(min_glide_speed)
        self.min_push_speed  = float(min_push_speed)
        self.stable_vert_std = float(stable_vert_std)
        self.wheel_frames = int(wheel_frames)
        self.auto_select_stance = bool(auto_select_stance)
        self.twist_window_s = float(twist_window_s)
        self.min_twist_angle = float(min_twist_angle)
        self.min_twist_omega = float(min_twist_omega)
        self.max_trans_speed_during_twist = float(max_trans_speed_during_twist)
        self.min_post_twist_speed = float(min_post_twist_speed)
        self.left_hist  = FootHistory(maxlen=history_len)
        self.right_hist = FootHistory(maxlen=history_len)
        self.ground_pts = deque(maxlen=400)
        self.ground_line: Optional[Tuple[float,float]] = None  # y = a*x + b

    # --- small helpers ---
    @staticmethod
    def _vel(hist: deque):
        if len(hist) < 2:
            return 0.0, 0.0
        (t0,x0,y0), (t1,x1,y1) = hist[-2], hist[-1]
        dt = max(1e-6, t1 - t0)
        return (x1-x0)/dt, (y1-y0)/dt

    @staticmethod
    def _horiz_speed(hist: deque):
        vx, _ = SkateContactAnalyzer._vel(hist)
        return abs(vx)

    @staticmethod
    def _vert_std(hist: deque, window=15):
        if len(hist) < 3:
            return float('inf')
        ys = [p[2] for p in list(hist)[-window:]]
        return float(np.std(ys)) if len(ys) >= 3 else float('inf')

    @staticmethod
    def _foot_pitch(heel_xy, toe_xy):
        dx = toe_xy[0]-heel_xy[0]
        dy = toe_xy[1]-heel_xy[1]
        return math.atan2(dy, dx)

    def _dist_to_ground(self, x, y):
        if self.ground_line is None:
            return 0.0
        a, b = self.ground_line
        y_line = a * x + b
        return y - y_line

    def _fit_ground_line(self):
        if len(self.ground_pts) < 40:
            return
        pts = np.array(self.ground_pts, dtype=float)
        xs, ys = pts[:,0], pts[:,1]
        cutoff = np.percentile(ys, 35)
        xs_f, ys_f = xs[ys <= cutoff], ys[ys <= cutoff]
        if len(xs_f) >= 10:
            a, b = np.polyfit(xs_f, ys_f, 1)
            self.ground_line = (float(a), float(b))

    def _append_hist(self, hist: FootHistory, t, ankle_xy, heel_xy, toe_xy, contact_flag, label):
        hist.ankle.append((t, *ankle_xy))
        hist.heel.append((t, *heel_xy))
        hist.toe.append((t, *toe_xy))
        hist.contact_flags.append(bool(contact_flag))
        hist.state_labels.append(label)

    def _ground_axes(self):
        if self.ground_line is not None:
            a, _ = self.ground_line
            t = np.array([1.0, a], dtype=float)
            t /= (np.linalg.norm(t) + 1e-12)
            n = np.array([-a, 1.0], dtype=float)
            n /= (np.linalg.norm(n) + 1e-12)
        else:
            t = np.array([1.0, 0.0], dtype=float)
            n = np.array([0.0, 1.0], dtype=float)
        return t, n

    def _last_v2d(self, hist: deque):
        if len(hist) < 2:
            return np.array([0.0, 0.0], dtype=float)
        (t0,x0,y0), (t1,x1,y1) = hist[-2], hist[-1]
        dt = max(1e-6, t1 - t0)
        return np.array([(x1-x0)/dt, (y1-y0)/dt], dtype=float)

    def _proj_speeds(self, ankle_hist: deque):
        v = self._last_v2d(ankle_hist)
        t_hat, n_hat = self._ground_axes()
        vt = abs(float(v @ t_hat))
        vn = abs(float(v @ n_hat))
        return vt, vn

    def _contact_ratios(self, hist: FootHistory, frames: int):
        labels = list(hist.state_labels)[-frames:]
        if not labels:
            return 0,0,0,0,1
        toe_only = sum(s in ("Toe Dominant", "Tiptoeing", "Wheel-Toe") for s in labels)
        heel_only= sum(s in ("Heel Strike", "Wheel-Heel") for s in labels)
        flat    = sum(s in ("Flat Foot",) for s in labels)
        near    = sum(s in ("Near Ground",) for s in labels)
        swing   = sum(s in ("Swing",) for s in labels)
        d = max(1, len(labels))
        return toe_only/d, heel_only/d, flat/d, near/d, swing/d

    @staticmethod
    def _grounded_state(s: str) -> bool:
        if not s: return False
        return s in {"Flat Foot","Toe Dominant","Tiptoeing","Heel Strike","Near Ground","Wheel-Toe","Wheel-Heel"}

    def _classify_contact(self, ankle_xy, heel_xy, toe_xy):
        dy_heel = self._dist_to_ground(*heel_xy)
        dy_toe  = self._dist_to_ground(*toe_xy)
        heel_on = abs(dy_heel) < self.vertical_thresh
        toe_on  = abs(dy_toe)  < self.vertical_thresh
        near_g  = (not heel_on and not toe_on) and (abs(dy_heel) < self.near_thresh or abs(dy_toe) < self.near_thresh)
        ankle_high = (ankle_xy[1] + self.vertical_thresh) < min(heel_xy[1], toe_xy[1])
        if heel_on and not toe_on and toe_xy[1] < heel_xy[1]:
            return "Heel Strike", heel_on, toe_on, near_g
        if toe_on and not heel_on and heel_xy[1] < toe_xy[1]:
            if ankle_high and (heel_xy[1] + self.vertical_thresh/2) < toe_xy[1]:
                return "Tiptoeing", heel_on, toe_on, near_g
            return "Toe Dominant", heel_on, toe_on, near_g
        if heel_on and toe_on:
            return "Flat Foot", heel_on, toe_on, near_g
        if not heel_on and not toe_on:
            return ("Swing" if min(heel_xy[1],toe_xy[1]) + 2*self.vertical_thresh < (ankle_xy[1] + 2*self.vertical_thresh) else "Near Ground"), heel_on, toe_on, near_g
        return "Indeterminate", heel_on, toe_on, near_g

    def _pick_stance(self):
        def score(hist: FootHistory):
            on_frac = (sum(hist.contact_flags) / max(1, len(hist.contact_flags)))
            speed   = self._horiz_speed(hist.ankle)
            return on_frac - 0.002*speed
        Ls = score(self.left_hist)
        Rs = score(self.right_hist)
        return 'left' if Ls >= Rs else 'right'

    def _wheeling(self, hist: FootHistory, kps: np.ndarray, scores: np.ndarray, side: str):
        if len(hist.ankle) < 3:
            return {'label': None, 'angle': None}

        # --- 1. Select Indices based on Side ---
        if side == 'left':
            KNEE_IDX, ANKLE_IDX = KEYPOINT_DICT["L_KNEE"], KEYPOINT_DICT["L_ANKLE"]
            BIG_TOE_IDX, SMALL_TOE_IDX = KEYPOINT_DICT["L_BIG_TOE"], KEYPOINT_DICT["L_SMALL_TOE"]
        else:
            KNEE_IDX, ANKLE_IDX = KEYPOINT_DICT["R_KNEE"], KEYPOINT_DICT["R_ANKLE"]
            BIG_TOE_IDX, SMALL_TOE_IDX = KEYPOINT_DICT["R_BIG_TOE"], KEYPOINT_DICT["R_SMALL_TOE"]

        # --- 2. Validate Scores & Get Coordinates ---
        if not (scores[KNEE_IDX] >= self.min_conf and scores[ANKLE_IDX] >= self.min_conf):
            return {'label': None, 'angle': None}
            
        knee_xy = kps[KNEE_IDX, :2]
        ankle_xy = kps[ANKLE_IDX, :2]

        # Handle Toe: Try midpoint of big/small, fallback to single
        s_big, s_small = scores[BIG_TOE_IDX], scores[SMALL_TOE_IDX]
        if s_big >= self.min_conf and s_small >= self.min_conf:
            toe_xy = (kps[BIG_TOE_IDX, :2] + kps[SMALL_TOE_IDX, :2]) / 2.0
        elif s_big >= self.min_conf:
            toe_xy = kps[BIG_TOE_IDX, :2]
        elif s_small >= self.min_conf:
            toe_xy = kps[SMALL_TOE_IDX, :2]
        else:
            # Cannot calculate foot angle without a toe
            return {'label': None, 'angle': None}

        # --- 3. Calculate Vectors & Angle ---
        # Vector 1: Ankle -> Knee (Shin direction)
        v_shin = knee_xy - ankle_xy
        # Vector 2: Ankle -> Toe (Foot direction)
        v_foot = toe_xy - ankle_xy
        
        # Calculate Angle
        ang = abs(angle_between(v_shin, v_foot))

        # --- 4. Labeling (Dorsiflexion vs Extension) ---
        # Note: In 2D side view, <90 is dorsiflexion (knee over toe).
        # In Front view, this calculates skate lean (inversion/eversion).
        label = None
        
        # Retrieve history stats for context
        vt, vn = self._proj_speeds(hist.ankle)
        vert_stab = self._vert_std(hist.ankle)
        
        # Only label if moving and relatively stable
        if vt >= self.min_glide_speed and vert_stab <= self.stable_vert_std:
            if ang < 80:
                label = "Deep Bend" # High Dorsiflexion
            elif ang > 110:
                label = "Extended"  # Plantarflexion / Push extension
            elif 80 <= ang <= 110:
                label = "Neutral"

        return {"label": label, "angle": float(round(ang, 1))}


    def _momentum_kick(self, swing_hist: FootHistory, stance_hist: FootHistory):
        if len(swing_hist.ankle) < 3 or len(stance_hist.ankle) < 3:
            return False
        v_sw = self._horiz_speed(swing_hist.ankle)
        stance_stable = self._vert_std(stance_hist.ankle) < self.stable_vert_std and any(stance_hist.contact_flags)
        recent = list(swing_hist.state_labels)[-6:]
        toe_like = any(s in ("Toe Dominant","Tiptoeing","Heel Strike") for s in recent)
        return (v_sw >= self.min_push_speed) and stance_stable and toe_like

    def _grounded_ratio(self, hist: FootHistory, frames=8):
        labels = list(hist.state_labels)[-frames:]
        if not labels:
            return 0.0
        return sum(self._grounded_state(s) for s in labels) / max(1, len(labels))

    def _modes(self):
        vLt, _ = self._proj_speeds(self.left_hist.ankle)  if len(self.left_hist.ankle)  >= 2 else (0.0, 0.0)
        vRt, _ = self._proj_speeds(self.right_hist.ankle) if len(self.right_hist.ankle) >= 2 else (0.0, 0.0)
        L_roll = vLt >= self.min_glide_speed
        R_roll = vRt >= self.min_glide_speed
        L_gr = self._grounded_ratio(self.left_hist, 8)
        R_gr = self._grounded_ratio(self.right_hist, 8)
        L_grounded, R_grounded = L_gr >= 0.5, R_gr >= 0.5
        L_wh, R_wh = self._wheeling(self.left_hist), self._wheeling(self.right_hist)
        t_hat, _ = self._ground_axes()
        vL_t_s = float(self._last_v2d(self.left_hist.ankle)  @ t_hat)
        vR_t_s = float(self._last_v2d(self.right_hist.ankle) @ t_hat)
        cos_sim = (vL_t_s * vR_t_s) / (abs(vL_t_s) * abs(vR_t_s) + 1e-6)
        two_wheels = (L_wh is not None) and (R_wh is not None)
        single_wheel = (L_wh is not None) ^ (R_wh is not None)
        two_foot = (L_grounded and R_grounded and L_roll and R_roll and cos_sim >= 0.4)
        single_leg = ((L_grounded and L_roll and not (R_grounded and R_roll)) ^
                      (R_grounded and R_roll and not (L_grounded and L_roll)))
        return {
            'single_wheel_skate': bool(single_wheel),
            'single_leg_skate':   bool(single_leg),
            'two_foot_skate':     bool(two_foot),
            'two_wheels_skate':   bool(two_wheels),
            'details': {
                'left':  {'wheeling': L_wh, 'grounded_ratio': L_gr, 'vt': vLt},
                'right': {'wheeling': R_wh, 'grounded_ratio': R_gr, 'vt': vRt},
                'cos_sim': float(cos_sim)
            }
        }

    @staticmethod
    def _unwrap(angles):
        if len(angles) <= 1:
            return list(angles)
        return list(np.unwrap(np.asarray(angles, dtype=float)))

    def _window(self, hist: deque, t_now: float, window_s: float):
        if not hist:
            return []
        lo = t_now - window_s
        out = []
        for (t,x,y) in reversed(hist):
            if t < lo:
                break
            out.append((t,x,y))
        return list(reversed(out))

    def _ankle_horiz_speeds(self, ankle_hist: deque, t_now: float, window_s: float):
        seg = self._window(ankle_hist, t_now, window_s)
        if len(seg) < 2:
            return []
        speeds = []
        for (t0,x0,y0), (t1,x1,y1) in zip(seg[:-1], seg[1:]):
            dt = max(1e-6, t1 - t0)
            speeds.append(abs((x1-x0)/dt))
        return speeds

    @staticmethod
    def _foot_yaw(heel_xy, toe_xy):
        dx = toe_xy[0]-heel_xy[0]
        dy = toe_xy[1]-heel_xy[1]
        return math.atan2(dy, dx)

    def _twist_metrics(self, hist: FootHistory, t_now: float):
        heel_seg = self._window(hist.heel, t_now, self.twist_window_s)
        toe_seg  = self._window(hist.toe,  t_now, self.twist_window_s)
        if len(heel_seg) < 2 or len(toe_seg) < 2:
            return None
        n = min(len(heel_seg), len(toe_seg))
        heel_seg, toe_seg = heel_seg[-n:], toe_seg[-n:]
        ts = [p[0] for p in heel_seg]
        yaws = [self._foot_yaw((h[1],h[2]), (t_[1],t_[2])) for h,t_ in zip(heel_seg, toe_seg)]
        yaws_u = self._unwrap(yaws)
        if len(yaws_u) < 2:
            return None
        oma = []
        for (th0,t0),(th1,t1) in zip(zip(yaws_u[:-1], ts[:-1]), zip(yaws_u[1:], ts[1:])):
            dt = max(1e-6, t1 - t0)
            oma.append((th1 - th0)/dt)
        trans = self._ankle_horiz_speeds(hist.ankle, t_now, self.twist_window_s)
        trans_mean = float(np.mean(trans)) if trans else 0.0
        recent_states = list(hist.state_labels)[-len(ts):] if hist.state_labels else []
        grounded_ratio = sum(self._grounded_state(s) for s in recent_states) / max(1, len(recent_states))
        last_state = recent_states[-1] if recent_states else None
        delta_theta = yaws_u[-1] - yaws_u[0]
        omega_peak  = float(max((abs(w) for w in oma), default=0.0))
        omega_mean  = float(np.mean(np.abs(oma))) if oma else 0.0
        direction   = 'ccw' if delta_theta > 0 else ('cw' if delta_theta < 0 else 'flat')
        post_speed  = self._horiz_speed(hist.ankle) if len(hist.ankle) >= 2 else 0.0
        return {
            'delta_theta': float(delta_theta),
            'omega_peak': omega_peak,
            'omega_mean': omega_mean,
            'direction': direction,
            'grounded_ratio': float(grounded_ratio),
            'last_state': last_state,
            'trans_speed_mean': trans_mean,
            'post_speed': float(post_speed)
        }

    def _twist_detect(self, m):
        if not m:
            return False, False, None, 0.0
        grounded_ok = m['grounded_ratio'] >= 0.4
        ls = m.get('last_state')
        if ls in ('Tiptoeing','Toe Dominant','Wheel-Toe'): pivot = 'toe'
        elif ls in ('Heel Strike','Wheel-Heel'): pivot = 'heel'
        elif ls in ('Flat Foot','Near Ground'): pivot = 'flat'
        else: pivot = 'mixed'
        trans_gate = m['trans_speed_mean'] <= (self.max_trans_speed_during_twist * (1.35 if pivot=='flat' else 1.0))
        twist = grounded_ok and (m['omega_peak'] >= self.min_twist_omega) and (abs(m['delta_theta']) >= self.min_twist_angle) and trans_gate
        twist_push = twist and (m['post_speed'] >= self.min_post_twist_speed)
        score = (
            (1.0 if grounded_ok else 0.5) *
            abs(m['delta_theta']) *
            (m['omega_peak'] / max(1e-6, self.min_twist_omega)) *
            (1.0 + (1.0 - min(1.0, m['trans_speed_mean'] / max(1e-6, self.max_trans_speed_during_twist))))
        )
        return twist, twist_push, pivot, float(score)

    def update(self, kps: np.ndarray, scores: np.ndarray, t: float, leg_choice: str = 'auto'):
        def ok(i):
            return scores[i] >= self.min_conf
        def get_xy(i):
            return (float(kps[i,0]), float(kps[i,1]))

        # --- Basic Classification (Left/Right) ---
        # (This logic remains the same as original)
        low = {
            'left':  {'state':'Unknown (Low Confidence)'},
            'right': {'state':'Unknown (Low Confidence)'},
            'wheeling': {'left':{'angle':None, 'label':None},'right':{'angle':None, 'label':None}},
            'modes': {'details':{}},
            'ground_line': self.ground_line
        }
        
        # Left
        if not (ok(L_ANKLE) and ok(L_HEEL) and ok(L_BIG_TOE)):
            left_state, L_flags = 'Unknown (Low Confidence)', (False,False,False)
        else:
            L_ank, L_heel, L_toe = get_xy(L_ANKLE), get_xy(L_HEEL), get_xy(L_BIG_TOE)
            left_state, L_heel_on, L_toe_on, L_near = self._classify_contact(L_ank, L_heel, L_toe)
            L_flags = (L_heel_on, L_toe_on, L_near)

        # Right
        if not (ok(R_ANKLE) and ok(R_HEEL) and ok(R_BIG_TOE)):
            right_state, R_flags = 'Unknown (Low Confidence)', (False,False,False)
        else:
            R_ank, R_heel, R_toe = get_xy(R_ANKLE), get_xy(R_HEEL), get_xy(R_BIG_TOE)
            right_state, R_heel_on, R_toe_on, R_near = self._classify_contact(R_ank, R_heel, R_toe)
            R_flags = (R_heel_on, R_toe_on, R_near)

        if 'Unknown' in left_state and 'Unknown' in right_state:
            return low

        # Update Histories
        if 'Unknown' not in left_state:
            self._append_hist(self.left_hist, t, L_ank, L_heel, L_toe, any(L_flags[:2]), left_state)
            if L_flags[0]: self.ground_pts.append((L_heel[0], L_heel[1]))
            if L_flags[1]: self.ground_pts.append((L_toe[0],  L_toe[1]))
        if 'Unknown' not in right_state:
            self._append_hist(self.right_hist, t, R_ank, R_heel, R_toe, any(R_flags[:2]), right_state)
            if R_flags[0]: self.ground_pts.append((R_heel[0], R_heel[1]))
            if R_flags[1]: self.ground_pts.append((R_toe[0],  R_toe[1]))

        # Fit Ground & Analyze Wheeling
        self._fit_ground_line()
        
        L_wheel = self._wheeling(self.left_hist, kps, scores, 'left')
        R_wheel = self._wheeling(self.right_hist, kps, scores, 'right')

        # REMOVED: Stance, Momentum Kick, Twist metrics

        return {
            'left': {'state': left_state},
            'right':{'state': right_state},
            'wheeling': {'left': L_wheel, 'right': R_wheel},
            'ground_line': self.ground_line
        }

# -----------------------------
# Posture & knee angles / body stability
# -----------------------------

def posture_metrics(kps: np.ndarray, scores: np.ndarray, conf_thr=0.3):
    l_sh, r_sh = kps[L_SHOULDER], kps[R_SHOULDER]
    l_hp, r_hp = kps[L_HIP], kps[R_HIP]
    l_ear, r_ear = kps[L_EAR], kps[R_EAR]
    nose = kps[NOSE]
    torso_angle = None
    esh_l = None
    esh_r = None
    head_center = None
    ms = midpoint(l_sh, r_sh, conf_thr) if (scores[L_SHOULDER]>=conf_thr and scores[R_SHOULDER]>=conf_thr) else None
    mh = midpoint(l_hp, r_hp, conf_thr) if (scores[L_HIP]>=conf_thr and scores[R_HIP]>=conf_thr) else None
    if ms is not None and mh is not None:
        torso_vec = vec(mh, ms)
        if torso_vec is not None:
            torso_angle = angle_between(torso_vec, np.array([0.0, -1.0]))
    if scores[L_EAR]>=conf_thr and scores[L_SHOULDER]>=conf_thr and scores[L_HIP]>=conf_thr:
        esh_l = angle_between(vec(l_sh, l_ear), vec(l_sh, l_hp))
    if scores[R_EAR]>=conf_thr and scores[R_SHOULDER]>=conf_thr and scores[R_HIP]>=conf_thr:
        esh_r = angle_between(vec(r_sh, r_ear), vec(r_sh, r_hp))
    head_kps = []
    for idx in (NOSE, L_EAR, R_EAR):
        if scores[idx] >= conf_thr:
            head_kps.append(kps[idx,:2])
    if head_kps:
        head_center = np.mean(np.asarray(head_kps), axis=0)
    return torso_angle, esh_l, esh_r, head_center


def knee_angle(kps: np.ndarray, scores: np.ndarray, side: str, conf_thr=0.3):
    if side == 'left':
        hip, knee, ankle = L_HIP, L_KNEE, L_ANKLE
    else:
        hip, knee, ankle = R_HIP, R_KNEE, R_ANKLE
    if min(scores[hip], scores[knee], scores[ankle]) < conf_thr:
        return None
    thigh = vec(kps[hip], kps[knee])
    shank = vec(kps[knee], kps[ankle])
    if thigh is None or shank is None:
        return None
    return 180.0 - angle_between(thigh, shank)

# -----------------------------
# Simple multi-person tracker (greedy IoU) + main-subject chooser
# -----------------------------
class BoxTracker:
    def __init__(self, iou_thr=0.2):
        self.iou_thr = iou_thr
        self.tracks = {}  # id -> {'box': np.array([x1,y1,x2,y2]), 'age': int}
        self.next_id = 0
        self.active_id = None

    @staticmethod
    def _center(b):
        x1,y1,x2,y2 = b
        return ((x1+x2)/2.0, (y1+y2)/2.0)

    def _score_main(self, b, W, H):
        cx, cy = self._center(b)
        dx = (cx - W/2.0) / max(1.0, W)
        dy = (cy - H/2.0) / max(1.0, H)
        area = max(1.0, (b[2]-b[0])*(b[3]-b[1]))
        # larger and more centered -> higher score
        return (1.0 - (abs(dx)+abs(dy))) + 0.00001 * area

    def update(self, boxes: List[np.ndarray], frame_shape):
        H, W = frame_shape[:2]
        # match
        assigned = {}
        used = set()
        for tid, t in list(self.tracks.items()):
            best_i, best_iou = -1, 0.0
            for i, b in enumerate(boxes):
                if i in used:
                    continue
                iou = iou_xyxy(t['box'], b)
                if iou > best_iou:
                    best_iou, best_i = iou, i
            if best_i >= 0 and best_iou >= self.iou_thr:
                assigned[tid] = best_i
                used.add(best_i)
                self.tracks[tid]['box'] = boxes[best_i]
                self.tracks[tid]['age'] = 0
            else:
                self.tracks[tid]['age'] += 1
        # add new tracks
        for i, b in enumerate(boxes):
            if i in used:
                continue
            self.tracks[self.next_id] = {'box': b, 'age': 0}
            self.next_id += 1
        # remove stale
        for tid in list(self.tracks.keys()):
            if self.tracks[tid]['age'] > 30:
                del self.tracks[tid]
        # choose main subject: highest center+size score among tracks updated this frame
        best_tid, best_sc = None, -1e9
        for tid, t in self.tracks.items():
            sc = self._score_main(t['box'], W, H)
            if sc > best_sc:
                best_sc, best_tid = sc, tid
        self.active_id = best_tid
        return best_tid

# -----------------------------
# Metric logger + v2 runner
class MetricLogger:
    def __init__(self, path: Optional[str]):
        self.path = path
        self.f = None
        if path:
            import csv
            self.csv = csv
            self.f = open(path, 'w', newline='')
            self.w = csv.DictWriter(self.f, fieldnames=[
                'frame','time_s','subject','stance','wheelL','wheelR','kickL','kickR',
                'twistL','twistR','torso_deg','kneeL_deg','kneeR_deg','hands','vt_left','vt_right','ground_slope'
            ])
            self.w.writeheader()
    def write_row(self, row: Dict):
        if self.f:
            row = {k: (str(v) if isinstance(v, (dict,list,tuple)) else v) for k,v in row.items()}
            self.w.writerow(row)
    def close(self):
        if self.f:
            self.f.flush(); self.f.close(); self.f = None

def analyze_video_v2(video_path: str,
                  yolo_path: str,
                  rtm_path: str,
                  process_every_n: int = 2,
                  yolo_size=(640,640),
                  rtm_size=(288,384),
                  backend='onnxruntime',
                  device='cuda',
                  out_path: Optional[str] = None,
                  max_subjects: int = 1,
                  log_csv: Optional[str] = None,
                  show_window: bool = True):
    if YOLOX is None or RTMPose is None:
        raise RuntimeError("rtmlib not available. Please install/import YOLOX and RTMPose wrappers.")
    det = YOLOX(onnx_model=yolo_path, model_input_size=yolo_size, backend=backend, device=device)
    pose = RTMPose(onnx_model=rtm_path, model_input_size=rtm_size, backend=backend, device=device)
    logger = MetricLogger(log_csv) if log_csv else None
    analyzer = SkatingVideoAnalyzer(det, pose, process_every_n=process_every_n, max_subjects=max_subjects, logger=logger)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    analyzer.fps = float(fps)
    W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = None
    if out_path:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (W, H))
    fidx = 0
    t0 = time.time()
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        ts = time.time() - t0
        out = analyzer.process(frame, fidx, ts)
        if writer is not None:
            writer.write(out)
        elif show_window:
            cv2.imshow('Skate Analyzer', out)
            if cv2.waitKey(1) & 0xFF == 27:
                break
        fidx += 1
    cap.release()
    if writer is not None:
        writer.release()
    if show_window:
        cv2.destroyAllWindows()
    if logger is not None:
        logger.close()

class LRTemporalResolver:
    """
    Keeps left/right labels stable across turns by minimizing temporal cost
    (no pelvis frame). Works per subject. Call resolve(kps, scores) each update.
    """
    def __init__(self,
                 conf_thr=0.3,
                 min_swap_improve=12.0,   # pixels: how much better the swap must be
                 persist_frames=2,         # how many consecutive frames before committing a swap
                 ema_alpha=0.6,            # smoothing for saved state
                 use_lengths_penalty=True):
        self.conf_thr = float(conf_thr)
        self.min_swap_improve = float(min_swap_improve)
        self.persist_frames = int(persist_frames)
        self.ema_alpha = float(ema_alpha)
        self.use_lengths_penalty = bool(use_lengths_penalty)
        self.prev = None  # dict of last smoothed positions per joint
        self._swap_streak = 0

    def _p(self, kps, scores, idx):
        if idx is None: return None
        if scores[idx] < self.conf_thr: return None
        return np.array(kps[idx,:2], dtype=float)

    def _d2(self, a, b):
        if a is None or b is None: return 0.0
        d = np.linalg.norm(a - b)
        return float(d)

    def _leg_cost(self, assign, cur, prev):
        # assign: {'L': {'ankle':pt,...}, 'R': {...}}
        # cost = sum of ankle/knee/heel/toe distances vs prev
        w = {'ankle': 1.0, 'knee': 0.7, 'heel': 0.5, 'toe': 0.5}
        cost = 0.0
        for side in ('L','R'):
            for part in ('ankle','knee','heel','toe'):
                cost += w[part] * self._d2(assign[side].get(part), prev[side].get(part))
        # optional: length consistency penalty (keeps thigh/shank similar to last frame)
        if self.use_lengths_penalty:
            def seglen(a, b):
                if a is None or b is None:
                    return None
                return float(np.linalg.norm(a - b))
            for side in ('L','R'):
                cur_thigh = seglen(assign[side]['hip'],  assign[side]['knee'])
                cur_shank = seglen(assign[side]['knee'], assign[side]['ankle'])
                prv_thigh = seglen(prev[side]['hip'],    prev[side]['knee'])
                prv_shank = seglen(prev[side]['knee'],   prev[side]['ankle'])
                if cur_thigh is not None and prv_thigh is not None:
                    cost += 0.4 * abs(cur_thigh - prv_thigh)
                if cur_shank is not None and prv_shank is not None:
                    cost += 0.4 * abs(cur_shank - prv_shank)
        return cost

    def _pack(self, kps, scores, order=('L','R')):
        sides = {}
        for side, A,B,T,H in (
            (order[0], L_ANKLE, L_KNEE, L_BIG_TOE, L_HEEL),
            (order[1], R_ANKLE, R_KNEE, R_BIG_TOE, R_HEEL),
        ):
            sides[side] = {
                'ankle': self._p(kps, scores, A),
                'knee':  self._p(kps, scores, B),
                'toe':   self._p(kps, scores, T),
                'heel':  self._p(kps, scores, H),
                'hip':   self._p(kps, scores, (L_HIP if side=='L' else R_HIP)),
            }
        return sides

    def _ema_update(self, prev, cur):
        a = self.ema_alpha
        for side in ('L','R'):
            for k in ('ankle','knee','heel','toe','hip'):
                c = cur[side].get(k)
                p = prev[side].get(k)
                if c is None:
                    continue
                if p is None:
                    prev[side][k] = c.copy()
                else:
                    prev[side][k] = (1-a)*p + a*c
        return prev

    def resolve(self, kps, scores):
        # Build two assignments: identity (as given) vs swapped
        cur_id = self._pack(kps, scores, order=('L','R'))
        cur_sw = self._pack(kps, scores, order=('R','L'))

        # Initialize prev if first time
        if self.prev is None:
            self.prev = {'L':{}, 'R':{}}
            self.prev = self._ema_update(self.prev, cur_id)
            return kps, scores  # keep as-is on first valid frame

        # Compute costs vs previous smoothed state
        cost_id = self._leg_cost(assign=cur_id, cur=cur_id, prev=self.prev)
        cost_sw = self._leg_cost(assign=cur_sw, cur=cur_sw, prev=self.prev)

        # Decide with hysteresis
        if cost_sw + self.min_swap_improve < cost_id:
            self._swap_streak += 1
        else:
            self._swap_streak = 0

        do_swap = (self._swap_streak >= self.persist_frames)

        # Update prev with the chosen assignment (EMA)
        chosen = cur_sw if do_swap else cur_id
        self.prev = self._ema_update(self.prev, chosen)

        if not do_swap:
            return kps, scores
        # Perform the label swap in-place (swap L* <-> R* indices)
        K = np.asarray(kps).copy()
        S = np.asarray(scores).copy()
        swap_pairs = [
            (L_ANKLE, R_ANKLE), (L_KNEE, R_KNEE), (L_HEEL, R_HEEL), (L_BIG_TOE, R_BIG_TOE),
            (L_HIP, R_HIP), (L_SHOULDER, R_SHOULDER), (L_ELBOW, R_ELBOW), (L_WRIST, R_WRIST),
            (L_EAR, R_EAR)
        ]
        for a,b in swap_pairs:
            K[[a,b]] = K[[b,a]]
            S[[a,b]] = S[[b,a]]
        return K, S





# -----------------------------
# Main Analyzer wrapper
# -----------------------------
class SkatingVideoAnalyzer:
    def __init__(self,
                 yolo_model: YOLOX,
                 rtm_model: RTMPose,
                 process_every_n=2,
                 conf_thr=0.3,
                 draw_pairs: Optional[List[Tuple[int,int]]] = None,
                 max_subjects: int = 1,
                 logger: Optional["MetricLogger"] = None):
        self.det = yolo_model
        self.pose = rtm_model
        self.nskip = max(1, int(process_every_n))
        self.conf_thr = float(conf_thr)
        self.tracker = BoxTracker(iou_thr=0.2)
        self.max_subjects = max(1, int(max_subjects))
        self.draw_pairs = draw_pairs
        self.fps = 30.0
        self.logger = logger
        # per-subject state
        self.subjects = []  # list of dicts per subject

    def _ensure_subjects(self, n):
        while len(self.subjects) < n:
            self.subjects.append({
                'contact': SkateContactAnalyzer(
                    left_idx=FootIndices(L_ANKLE, L_HEEL, L_BIG_TOE, None),
                    right_idx=FootIndices(R_ANKLE, R_HEEL, R_BIG_TOE, None),
                    history_len=160
                ),
                'last_kps': None,
                'last_scores': None,
                'torso_angles': deque(maxlen=60),
                'head_y': deque(maxlen=60),
                'L_knee': deque(maxlen=60),
                'R_knee': deque(maxlen=60),
                'hip_hist': deque(maxlen=6),
                'lr_resolver': LRTemporalResolver(conf_thr=self.conf_thr,
                                                  min_swap_improve=12.0,
                                                  persist_frames=2,
                                                  ema_alpha=0.6)
            })
        if len(self.subjects) > n:
            self.subjects = self.subjects[:n]

    # --- adapt these two if your wrappers differ ---
    def _yolo_detect(self, frame: np.ndarray):
        out = self.det(frame)
        boxes, scores = [], []

        if out is None:
            return boxes, scores

        # Dict style
        if isinstance(out, dict) and 'boxes' in out:
            arr = np.asarray(out['boxes'], dtype=float)
            if arr.ndim == 1 and arr.size >= 4:
                arr = arr.reshape(1, -1)
            if arr.ndim == 2 and arr.shape[1] >= 4:
                for row in arr:
                    boxes.append(row[:4].copy())
                    scores.append(float(row[4]) if arr.shape[1] >= 5 else 1.0)
            else:
                for b in out['boxes']:
                    b = np.asarray(b, dtype=float).ravel()
                    if b.size >= 4:
                        boxes.append(b[:4].copy())
                        scores.append(float(b[4]) if b.size >= 5 else 1.0)
            s = out.get('scores')
            if s is not None and not scores:  # keep dict’s scores if provided separately
                s = np.asarray(s).ravel().tolist()
                scores = [float(x) for x in s]
                if len(scores) != len(boxes):
                    scores = [1.0] * len(boxes)
            elif not scores:
                scores = [1.0] * len(boxes)
            return boxes, scores

        # ndarray style
        if isinstance(out, np.ndarray):
            arr = np.asarray(out, dtype=float)
            if arr.ndim == 1 and arr.size >= 4:
                arr = arr.reshape(1, -1)
            if arr.ndim == 2 and arr.shape[1] >= 4:
                for row in arr:
                    boxes.append(row[:4].copy())
                    scores.append(float(row[4]) if arr.shape[1] >= 5 else 1.0)
            return boxes, scores

        # list/tuple style
        if isinstance(out, (list, tuple)):
            for o in out:
                o = np.asarray(o, dtype=float).ravel()
                if o.size >= 4:
                    boxes.append(o[:4].copy())
                    scores.append(float(o[4]) if o.size >= 5 else 1.0)
            return boxes, scores

        return boxes, scores

    def _rtm_infer(self, frame: np.ndarray, boxes: List[np.ndarray]):
        if not boxes:
            return [], []
        bxs = [np.asarray(b, dtype=float).ravel()[:4].tolist() for b in boxes]
        try:
            res = self.pose(frame, bboxes=bxs)  # your working signature
        except TypeError:
            res = self.pose(frame, bxs)         # fallback
        except Exception:
            return [], []

        if isinstance(res, tuple) and len(res) == 2:
            k_list, s_list = res
        elif isinstance(res, dict):
            k_list = res.get('keypoints') or res.get('kps') or []
            s_list = res.get('scores')
        else:
            k_list, s_list = res, None

        out_k, out_s = [], []
        for i in range(min(len(bxs), len(k_list))):
            kps = np.asarray(k_list[i], dtype=float)
            if kps.ndim != 2 or kps.shape[1] < 2:
                continue
            if kps.shape[1] >= 3:
                scores = kps[:, 2].astype(float)
                kps = kps[:, :2].astype(float)
            else:
                scores = np.ones((kps.shape[0],), dtype=float)
                kps = kps[:, :2].astype(float)
            out_k.append(kps)
            out_s.append(scores)
        return out_k, out_s


    def process(self, frame: np.ndarray, fidx: int, ts: float) -> np.ndarray:
        H, W = frame.shape[:2]
        do_proc = (fidx % self.nskip == 0)
        boxes: List[np.ndarray] = []

        if do_proc:
            boxes, scores = self._yolo_detect(frame)
            boxes = [np.clip(b, [0,0,0,0], [W-1,H-1,W-1,H-1]) for b in boxes]
            self.tracker.update(boxes, frame.shape)
            boxes = sorted(boxes, key=lambda b: self.tracker._score_main(b, W, H), reverse=True)
            boxes = boxes[:self.max_subjects]
            self._ensure_subjects(len(boxes))
            kps_list, scr_list = self._rtm_infer(frame, boxes)
            while len(kps_list) < len(boxes):
                kps_list.append(None); scr_list.append(None)
            for i in range(len(boxes)):
                subj = self.subjects[i]
                kps, kp_scores = kps_list[i], scr_list[i]
                if kps is not None:
                    # ONLY use the smart resolver. 
                    # Do NOT use fix_lr_per_frame(kps, kp_scores) here.
                    kps_fixed, scores_fixed = subj['lr_resolver'].resolve(kps, kp_scores)
                    subj['last_kps'], subj['last_scores'] = kps_fixed, scores_fixed

        for i, subj in enumerate(self.subjects):
            kps = subj['last_kps']; scores = subj['last_scores']
            if kps is None: continue

            contact: SkateContactAnalyzer = subj['contact']
            out = contact.update(kps, scores, ts, leg_choice='auto')

            torso_ang, esh_l, esh_r, head_ctr = posture_metrics(kps, scores, self.conf_thr)
            L_k = knee_angle(kps, scores, 'left', self.conf_thr)
            R_k = knee_angle(kps, scores, 'right', self.conf_thr)
            if torso_ang is not None: subj['torso_angles'].append(float(torso_ang))
            if head_ctr is not None:  subj['head_y'].append(float(head_ctr[1]))
            if L_k is not None:      subj['L_knee'].append(float(L_k))
            if R_k is not None:      subj['R_knee'].append(float(R_k))

            hand_stab = self._hand_stability(i, kps, scores)
            bbox = boxes[i] if i < len(boxes) else self._bbox_from_kps(kps, scores)
            
            frame = self._draw_subject(frame, bbox, kps, scores, out, torso_ang, L_k, R_k, hand_stab, subj)

            # Logger update (Removed deleted keys to avoid crash)
            if self.logger is not None:
                slope = contact.ground_line[0] if contact.ground_line else 0.0
                # Safe extraction of angles
                waL = out['wheeling']['left'].get('angle') if out['wheeling']['left'] else None
                waR = out['wheeling']['right'].get('angle') if out['wheeling']['right'] else None
                
                self.logger.write_row({
                    'frame': fidx, 'time_s': round(ts,4), 'subject': i,
                    'wheelL_ang': waL, 'wheelR_ang': waR,
                    'torso_deg': None if torso_ang is None else float(torso_ang),
                    'kneeL_deg': None if L_k is None else float(L_k),
                    'kneeR_deg': None if R_k is None else float(R_k),
                    'ground_slope': slope
                })
        return frame

    def _bbox_from_kps(self, kps, scores):
        vis = scores >= self.conf_thr
        if not np.any(vis):
            return np.array([0,0,0,0], dtype=float)
        xs, ys = kps[vis,0], kps[vis,1]
        return np.array([float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())], dtype=float)

    def _hand_stability(self, idx, kps, scores, window=10):
        key = f'_wr_{idx}'
        if not hasattr(self, key):
            setattr(self, key, {'L': deque(maxlen=window), 'R': deque(maxlen=window)})
        hist = getattr(self, key)
        if scores[L_WRIST] >= self.conf_thr:
            hist['L'].append(tuple(kps[L_WRIST,:2]))
        if scores[R_WRIST] >= self.conf_thr:
            hist['R'].append(tuple(kps[R_WRIST,:2]))
        def speed(q):
            if len(q) < 2: return 0.0
            (x0,y0),(x1,y1) = q[-2], q[-1]
            return float(abs(x1-x0)+abs(y1-y0))
        vL, vR = speed(hist['L']), speed(hist['R'])
        label = 'stable' if max(vL,vR) < 6.0 else ('moderate' if max(vL,vR) < 14.0 else 'active')
        return {'left_v': vL, 'right_v': vR, 'label': label}

    def _draw_subject(self, frame, bbox, kps, scores, out, torso_ang, L_k, R_k, hand_stab, subj):
        # --- Normalize bbox ---
        b = np.asarray(bbox, dtype=float).ravel()
        if b.size < 4 or not np.all(np.isfinite(b[:4])):
            b = self._bbox_from_kps(kps, scores)
        else:
            b = b[:4]
        H, W = frame.shape[:2]
        b = np.clip(b, [0, 0, 0, 0], [W - 1, H - 1, W - 1, H - 1])

        # --- Draw bbox & Skeleton ---
        draw_bbox(frame, [b])
        
        if kps.ndim == 2: kps_vis = kps[None, :, :]
        else: kps_vis = kps
        if scores.ndim == 1: scores_vis = scores[:, None][None, :, :]
        else: scores_vis = scores
        
        kps_vis = np.nan_to_num(kps_vis, copy=False)
        scores_vis = np.nan_to_num(scores_vis, copy=False)
        
        draw_skeleton(frame, kps_vis, scores_vis, openpose_skeleton=False, kpt_thr=self.conf_thr)

        # --- Formatting Helper for Wheeling Angle ---
        def fmt_wheel(w_dict):
            if not w_dict or w_dict.get('angle') is None:
                return "-"
            ang = w_dict['angle']
            lbl = w_dict.get('label')
            return f"{ang}°" + (f" ({lbl})" if lbl else "")

        # --- Logic for Torso Labeling ---
        posture_label = "-"
        upright_angle_str = "-"
        
        if torso_ang is not None:
            val = float(torso_ang)
            if val <= 10:
                posture_label = "Upright"
            elif val <= 30:
                posture_label = "Slight Lean"
            else:
                posture_label = "Aggressive Lean"
            upright_angle_str = f"{val:.1f}°"

        # --- Panel Construction ---
        x1, y1, x2, y2 = map(int, b)
        panel = {
            'Posture':       posture_label,
            'Upright Angle': upright_angle_str,
            'L-Angle':       fmt_wheel(out['wheeling']['left']),
            'R-Angle':       fmt_wheel(out['wheeling']['right']),
            'kneeL°':        None if L_k is None else int(round(L_k)),
            'kneeR°':        None if R_k is None else int(round(R_k)),
            'hands':         (hand_stab['label'] if hand_stab else None),
        }
        
        # Draw the text panel
        self._draw_panel(frame, panel, x=20, y=40)
        
        # Sparklines
        try:
            self._draw_sparkline(frame, list(subj['L_knee']), (x1 + 8,  y2 - 34, 60, 26), label='Lk')
            self._draw_sparkline(frame, list(subj['R_knee']), (x1 + 74, y2 - 34, 60, 26), label='Rk')
        except Exception:
            pass

        return frame





    @staticmethod
    def _draw_panel(frame, panel: Dict, x=12, y=18, lh=18):
        for i, (k,v) in enumerate(panel.items()):
            txt = f"{k}: {v}"
            cv2.putText(frame, txt, (x, y + i*lh), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255,255,255), 1, cv2.LINE_AA)

    @staticmethod
    def _draw_sparkline(frame, values: List[float], rect: Tuple[int,int,int,int], label: str = None):
        x,y,w,h = rect
        cv2.rectangle(frame, (x,y), (x+w, y+h), (200,200,200), 1)
        if not values:
            return
        vs = np.array(values, dtype=float)
        vs = vs[np.isfinite(vs)] if vs.size else vs
        if vs.size == 0:
            return
        vmin, vmax = float(np.min(vs)), float(np.max(vs))
        if abs(vmax - vmin) < 1e-6:
            vmin, vmax = vmax-1.0, vmax+1.0
        xs = np.linspace(0, w-2, num=len(vs))
        ys = h-2 - (vs - vmin) * (h-4) / (vmax - vmin)
        pts = np.vstack([xs + x + 1, ys + y + 1]).T.astype(np.int32)
        for j in range(1, len(pts)):
            cv2.line(frame, tuple(pts[j-1]), tuple(pts[j]), (0,255,255), 1, cv2.LINE_AA)
        if label:
            cv2.putText(frame, label, (x+2, y-2), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (200,200,200), 1, cv2.LINE_AA)

# -----------------------------
# Public utility: analyze a video file
# -----------------------------

def analyze_video(video_path: str,
                  yolo_path: str,
                  rtm_path: str,
                  process_every_n: int = 2,
                  yolo_size=(640,640),
                  rtm_size=(288,384),
                  backend='onnxruntime',
                  device='cuda',
                  out_path: Optional[str] = None):
    # init models
    if YOLOX is None or RTMPose is None:
        raise RuntimeError("rtmlib not available. Please install/import YOLOX and RTMPose wrappers.")
    det = YOLOX(onnx_model=yolo_path, model_input_size=yolo_size, backend=backend, device=device)
    pose = RTMPose(onnx_model=rtm_path, model_input_size=rtm_size, backend=backend, device=device)
    analyzer = SkatingVideoAnalyzer(det, pose, process_every_n=process_every_n)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    analyzer.fps = float(fps)
    W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = None
    if out_path:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (W, H))
    fidx = 0
    t0 = time.time()
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        ts = time.time() - t0
        out = analyzer.process(frame, fidx, ts)
        if writer is not None:
            writer.write(out)
        else:
            cv2.imshow('Skate Analyzer', out)
            if cv2.waitKey(1) & 0xFF == 27:
                break
        fidx += 1
    cap.release()
    if writer is not None:
        writer.release()
    cv2.destroyAllWindows()
    print("finished")


In [2]:
from pathlib import Path

# if the code is in the same notebook cell/file, you can call analyze_video directly:
video = "/mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/spin.mp4"
yolo_path = "/mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/yolox_onnx/yolox_l_8xb8-300e_humanart-ce1d7a62/end2end.onnx"
rtm_path = "/mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/rtmpose_onnx/rtmpose-m_simcc-body7_pt-body7-halpe26_700e-384x288-89e6428b_20230605/end2end.onnx"
# Save to a file instead of opening a window (recommended inside notebooks/headless):
out = str(Path("/mnt/samsung_sata/Skate-Sport-Video-Analysis/annotated.mp4"))
analyze_video(video, yolo_path, rtm_path, process_every_n=2, out_path=out)
print("Wrote:", out)

load /mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/yolox_onnx/yolox_l_8xb8-300e_humanart-ce1d7a62/end2end.onnx with onnxruntime backend
load /mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/rtmpose_onnx/rtmpose-m_simcc-body7_pt-body7-halpe26_700e-384x288-89e6428b_20230605/end2end.onnx with onnxruntime backend


[0;93m2025-11-19 16:26:49.183856626 [W:onnxruntime:, session_state.cc:1280 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2025-11-19 16:26:49.183867076 [W:onnxruntime:, session_state.cc:1282 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m
[0;93m2025-11-19 16:26:49.310115160 [W:onnxruntime:, session_state.cc:1280 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2025-11-19 16:26:49.310123240 [W:onnxruntime:, session_state.cc:1282 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m


finished
Wrote: /mnt/samsung_sata/Skate-Sport-Video-Analysis/annotated.mp4


In [None]:
if __name__ == '__main__':
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument('--video', required=True, help='Path to input video')
    p.add_argument('--yolo', required=True, help='Path to YOLOX onnx model')
    p.add_argument('--rtm', required=True, help='Path to RTMPose onnx model')
    p.add_argument('--out', default=None, help='Optional path to save annotated video (.mp4)')
    p.add_argument('--nskip', type=int, default=2, help='Process every N frames (>=1)')
    p.add_argument('--max_subjects', type=int, default=1, help='Analyze up to this many people (1-2 recommended)')
    p.add_argument('--log_csv', default=None, help='CSV path to log per-frame metrics')
    p.add_argument('--no_window', action='store_true', help='Disable realtime window display')
    args = p.parse_args()
    analyze_video_v2(
        args.video, args.yolo, args.rtm,
        process_every_n=args.nskip,
        out_path=args.out,
        max_subjects=args.max_subjects,
        log_csv=args.log_csv,
        show_window=(not args.no_window)
    )

In [10]:
# -----------------------------
# If run as a script
# -----------------------------
if __name__ == '__main__':
    import argparse
    p = argparse.ArgumentParser()
    p.add_argument('--video', required=True, help='Path to input video')
    p.add_argument('--yolo', required=True, help='Path to YOLOX onnx model')
    p.add_argument('--rtm', required=True, help='Path to RTMPose onnx model')
    p.add_argument('--out', default=None, help='Optional path to save annotated video (.mp4)')
    p.add_argument('--nskip', type=int, default=2, help='Process every N frames (>=1)')
    args = p.parse_args()
    analyze_video(args.video, args.yolo, args.rtm, process_every_n=args.nskip, out_path=args.out)

usage: ipykernel_launcher.py [-h] --video VIDEO --yolo YOLO --rtm RTM
                             [--out OUT] [--nskip NSKIP]
ipykernel_launcher.py: error: the following arguments are required: --video, --yolo, --rtm


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [13]:
"""
Skating Video Analyzer
----------------------
- Detects people with YOLOX, estimates keypoints with RTMPose
- Robust per-foot contact classification w/ wheeling detection
- Swing/kick (momentum generation) detection
- Twist + twist-push detection (torsion -> translation)
- Body stability (torso uprightness, head bob) -> REMOVED HAND STABILITY
- Knee angles (per leg) over time
- Ground-line fitting (handles slopes/uneven ground)
- Multi-person: pick main subject (centered + large box), track with simple IoU
- Frame skipping (process every N frames) to reduce compute
- On-screen overlays: skeleton, bbox, status panel inside the box

⚠️ Notes
- Hand stability logic removed per request.
- Torso display updated to show Upright/Slight Lean/Aggressive Lean.
"""
from __future__ import annotations
import cv2
import math
import time
import numpy as np
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from rtmlib.tools import RTMPose, YOLOX
from rtmlib import draw_skeleton, draw_bbox

# --- External models (adjust imports to your environment) ---
try:
    from rtmlib.tools import RTMPose, YOLOX
    from rtmlib import draw_skeleton, draw_bbox
except Exception:
    RTMPose = YOLOX = None
    def draw_skeleton(img, kpts, pairs=None, conf=None, conf_thr=0.2, **kw):
        # Minimal fallback: draw small circles for keypoints
        for i, p in enumerate(kpts):
            if len(p) >= 3 and p[2] < conf_thr: 
                continue
            cv2.circle(img, (int(p[0]), int(p[1])), 2, (255,255,255), -1)
        return img
    def draw_bbox(img, box, **kw):
        x1,y1,x2,y2 = map(int, box)
        return cv2.rectangle(img,(x1,y1),(x2,y2),(255,255,255),1)

# -----------------------------
# Keypoint indexing (HALPE-26 / body7)
# -----------------------------
KEYPOINT_DICT = {"NOSE":0, "L_EYE":1, "R_EYE":2, "L_EAR":3, "R_EAR":4, "L_SHOULDER":5,
                 "R_SHOULDER":6, "L_ELBOW":7, "R_ELBOW":8, "L_WRIST":9, "R_WRIST":10,
                 "L_HIP":11, "R_HIP":12, "L_KNEE":13, "R_KNEE":14, "L_ANKLE":15, "R_ANKLE":16,
                 "HEAD":17, "NECK":18, "HIP_CENTER":19, "L_BIG_TOE":20, "R_BIG_TOE":21, "L_SMALL_TOE":22,
                 "R_SMALL_TOE":23, "L_HEEL":24, "R_HEEL":25}

# Convenience constants
NOSE = KEYPOINT_DICT["NOSE"]
L_EAR, R_EAR = KEYPOINT_DICT["L_EAR"], KEYPOINT_DICT["R_EAR"]
L_SHOULDER, R_SHOULDER = KEYPOINT_DICT["L_SHOULDER"], KEYPOINT_DICT["R_SHOULDER"]
L_ELBOW, R_ELBOW = KEYPOINT_DICT["L_ELBOW"], KEYPOINT_DICT["R_ELBOW"]
L_WRIST, R_WRIST = KEYPOINT_DICT["L_WRIST"], KEYPOINT_DICT["R_WRIST"]
L_HIP, R_HIP = KEYPOINT_DICT["L_HIP"], KEYPOINT_DICT["R_HIP"]
L_KNEE, R_KNEE = KEYPOINT_DICT["L_KNEE"], KEYPOINT_DICT["R_KNEE"]
L_ANKLE, R_ANKLE = KEYPOINT_DICT["L_ANKLE"], KEYPOINT_DICT["R_ANKLE"]
L_BIG_TOE, R_BIG_TOE = KEYPOINT_DICT["L_BIG_TOE"], KEYPOINT_DICT["R_BIG_TOE"]
L_HEEL, R_HEEL = KEYPOINT_DICT["L_HEEL"], KEYPOINT_DICT["R_HEEL"]

# -----------------------------
# Anthro + physics params (approx, Dempster-based)
# -----------------------------
GRAVITY = 9.81
ANTHRO = {
    "Head":      {"mass_perc": 0.081,  "com_perc_prox": 0.50},
    "Trunk":     {"mass_perc": 0.497,  "com_perc_prox": 0.50},
    "Thigh":     {"mass_perc": 0.100,  "com_perc_prox": 0.433},
    "Shank":     {"mass_perc": 0.0465, "com_perc_prox": 0.433},
    "Foot":      {"mass_perc": 0.0145, "com_perc_prox": 0.50},
    "UpperArm":  {"mass_perc": 0.027,  "com_perc_prox": 0.436},
    "Forearm":   {"mass_perc": 0.016,  "com_perc_prox": 0.430},
    "Hand":      {"mass_perc": 0.006,  "com_perc_prox": 0.50},
}

# -----------------------------
# Small math/pose helpers
# -----------------------------

def xyc(kp):
    if kp is None:
        return None
    a = np.asarray(kp)
    if a.size >= 3:
        return float(a[0]), float(a[1]), float(a[2])
    return float(a[0]), float(a[1]), 1.0

def valid(kp, thr=0.3):
    v = xyc(kp)
    return (v is not None) and (v[2] >= thr)

def midpoint(p1, p2, min_conf=0.0):
    a, b = xyc(p1), xyc(p2)
    if a is None or b is None:
        return None
    if a[2] < min_conf or b[2] < min_conf:
        return None
    x = 0.5 * (a[0] + b[0])
    y = 0.5 * (a[1] + b[1])
    c = min(a[2], b[2])
    return np.array([x, y, c], dtype=np.float32)

def vec(a, b):
    """Returns vector from point a to point b (b - a)."""
    if a is None or b is None:
        return None
    aa, bb = np.asarray(a)[:2], np.asarray(b)[:2]
    if aa.size < 2 or bb.size < 2:
        return None
    return bb - aa

def angle_between(v1, v2):
    """
    Calculates the interior angle (0-180 degrees) between two vectors.
    """
    if v1 is None or v2 is None:
        return None
    v1 = np.asarray(v1, dtype=float)
    v2 = np.asarray(v2, dtype=float)
    
    # Magnitudes
    n1 = np.linalg.norm(v1)
    n2 = np.linalg.norm(v2)
    
    # Prevent Division by Zero
    if n1 < 1e-6 or n2 < 1e-6:
        return 0.0
    
    # Dot product formula: a . b = |a||b|cos(theta)
    dot_prod = np.dot(v1, v2)
    c = dot_prod / (n1 * n2)
    
    # Clipping is crucial because float precision can result in 1.000000002
    c = max(-1.0, min(1.0, c))
    
    return math.degrees(math.acos(c))

def angle_to_horizontal(vseg):
    if vseg is None:
        return None
    return math.degrees(math.atan2(float(vseg[1]), float(vseg[0])))

def sgolay(y, window=5, poly=2):
    # Minimal S-G filter (optional): uses numpy polyfit over sliding windows
    if len(y) < window:
        return np.asarray(y)
    out = np.zeros_like(y, dtype=float)
    half = window // 2
    xs = np.arange(len(y))
    for i in range(len(y)):
        lo = max(0, i - half)
        hi = min(len(y), i + half + 1)
        xi = xs[lo:hi]
        yi = y[lo:hi]
        if len(xi) < 2:
            out[i] = y[i]
            continue
        coeff = np.polyfit(xi - xi.mean(), yi, poly)
        out[i] = np.polyval(coeff, 0.0)  # center value
    return out

def deriv(y, dt, smooth=True):
    y = np.asarray(y, dtype=float)
    if y.size < 2:
        return np.zeros_like(y)
    if smooth:
        y = sgolay(y)
    return np.gradient(y, dt)

# -----------------------------
# IoU + simple geometry
# -----------------------------

def iou_xyxy(a, b):
    xa1, ya1, xa2, ya2 = a
    xb1, yb1, xb2, yb2 = b
    inter_w = max(0.0, min(xa2, xb2) - max(xa1, xb1))
    inter_h = max(0.0, min(ya2, yb2) - max(ya1, yb1))
    inter = inter_w * inter_h
    if inter <= 0:
        return 0.0
    area_a = max(0.0, xa2 - xa1) * max(0.0, ya2 - ya1)
    area_b = max(0.0, xb2 - xb1) * max(0.0, yb2 - yb1)
    return float(inter / (area_a + area_b - inter + 1e-6))

# -----------------------------
# L/R sanity fix (pelvis-aligned)
# -----------------------------

def pelvis_local_frame(kpts, scores, conf_thr=0.2):
    K = np.asarray(kpts)
    S = np.asarray(scores)
    mh = midpoint(K[L_HIP], K[R_HIP], 0.0)
    ms = midpoint(K[L_SHOULDER], K[R_SHOULDER], 0.0)
    if mh is None or ms is None:
        return None, None
    pelvis = mh[:2]
    up = ms[:2] - pelvis
    n = float(np.linalg.norm(up))
    if n < 1e-6:
        return pelvis, np.eye(2, dtype=np.float32)
    ang = math.atan2(up[1], up[0])
    # rotate so that up -> +y
    rot = np.array([[math.cos(math.pi/2 - ang), -math.sin(math.pi/2 - ang)],
                    [math.sin(math.pi/2 - ang),  math.cos(math.pi/2 - ang)]], dtype=np.float32)
    return pelvis, rot

def _side_sign(x, margin=4.0):
    if x < -margin: return -1
    if x >  margin: return  1
    return 0

def fix_lr_per_frame(kpts, scores, side_margin_px=4.0, conf_thr=0.2):
    K = np.asarray(kpts).copy()
    S = np.asarray(scores).copy()
    pelvis, rot = pelvis_local_frame(K, S, conf_thr=conf_thr)
    if pelvis is None or rot is None:
        return K, S
    def to_local(idx):
        if S[idx] < conf_thr:
            return None
        p = K[idx]
        return (p[:2] - pelvis) @ rot.T
    # Knees
    lk, rk = to_local(L_KNEE), to_local(R_KNEE)
    if lk is not None and rk is not None:
        sl, sr = _side_sign(lk[0], side_margin_px), _side_sign(rk[0], side_margin_px)
        if sl == 1 and sr != 1:
            K[[L_KNEE, R_KNEE]] = K[[R_KNEE, L_KNEE]]
            S[[L_KNEE, R_KNEE]] = S[[R_KNEE, L_KNEE]]
        elif sr == -1 and sl != -1:
            K[[L_KNEE, R_KNEE]] = K[[R_KNEE, L_KNEE]]
            S[[L_KNEE, R_KNEE]] = S[[R_KNEE, L_KNEE]]
    # Ankles
    la, ra = to_local(L_ANKLE), to_local(R_ANKLE)
    if la is not None and ra is not None:
        sl, sr = _side_sign(la[0], side_margin_px), _side_sign(ra[0], side_margin_px)
        if sl == 1 and sr != 1:
            K[[L_ANKLE, R_ANKLE]] = K[[R_ANKLE, L_ANKLE]]
            S[[L_ANKLE, R_ANKLE]] = S[[R_ANKLE, L_ANKLE]]
        elif sr == -1 and sl != -1:
            K[[L_ANKLE, R_ANKLE]] = K[[R_ANKLE, L_ANKLE]]
            S[[L_ANKLE, R_ANKLE]] = S[[R_ANKLE, L_ANKLE]]
    return K, S

# -----------------------------
# Foot contact/wheeling + twist logic
# -----------------------------
@dataclass
class FootIndices:
    ankle: int
    heel: int
    big_toe: int
    small_toe: Optional[int] = None

@dataclass
class FootHistory:
    ankle: deque
    heel: deque
    toe: deque
    contact_flags: deque
    state_labels: deque
    def __init__(self, maxlen=90):
        self.ankle = deque(maxlen=maxlen)
        self.heel = deque(maxlen=maxlen)
        self.toe  = deque(maxlen=maxlen)
        self.contact_flags = deque(maxlen=maxlen)
        self.state_labels  = deque(maxlen=maxlen)

class SkateContactAnalyzer:
    def __init__(self,
                 left_idx: FootIndices,
                 right_idx: FootIndices,
                 history_len=120,
                 # We replace absolute pixel thresholds with Ratios (0.0 - 1.0)
                 ground_tolerance_ratio=0.15, # Foot is "grounded" if within 15% of leg length from bottom
                 pitch_threshold_ratio=0.25,  # Foot is "wheeling" if pitch is steep enough
                 min_conf=0.3,
                 **kwargs): # Absorb leftover kwargs
        self.left_idx = left_idx
        self.right_idx = right_idx
        self.min_conf = min_conf
        self.ground_tolerance = ground_tolerance_ratio
        self.pitch_threshold = pitch_threshold_ratio
        
        # Simple history for smoothing (optional)
        self.left_hist  = deque(maxlen=history_len)
        self.right_hist = deque(maxlen=history_len)
        
        # We still calculate wheeling angle for display, but not for classification
        self.wheel_frames = 5

    def _get_coords(self, kps, idx_obj):
        # Helper to safely extract (x,y) for ankle, heel, toe
        # Returns None if any keypoint is missing/low conf
        a_idx, h_idx, t_idx = idx_obj.ankle, idx_obj.heel, idx_obj.big_toe
        
        # Check boundaries
        if max(a_idx, h_idx, t_idx) >= len(kps): return None, None, None
        
        a = kps[a_idx]
        h = kps[h_idx]
        t = kps[t_idx]
        
        # Simple check if 3rd dim exists (score)
        if len(a) > 2 and (a[2] < self.min_conf or h[2] < self.min_conf or t[2] < self.min_conf):
            return None, None, None
            
        return a[:2], h[:2], t[:2]

    def _dist(self, p1, p2):
        return math.sqrt((p1[0]-p2[0])**2 + (p1[1]-p2[1])**2)

    def update(self, kps: np.ndarray, scores: np.ndarray, t: float, leg_choice='auto'):
        # 1. Get basic coordinates
        La, Lh, Lt = self._get_coords(kps, self.left_idx)
        Ra, Rh, Rt = self._get_coords(kps, self.right_idx)

        # 2. Determine Scale (Shank Length)
        # We use Knee->Ankle distance. If missing, fallback to a default or foot length.
        def get_shank_len(knee_idx, ankle_idx):
            if scores[knee_idx] < self.min_conf or scores[ankle_idx] < self.min_conf:
                return 0.0
            return self._dist(kps[knee_idx,:2], kps[ankle_idx,:2])

        L_shank = get_shank_len(KEYPOINT_DICT['L_KNEE'], KEYPOINT_DICT['L_ANKLE'])
        R_shank = get_shank_len(KEYPOINT_DICT['R_KNEE'], KEYPOINT_DICT['R_ANKLE'])
        
        # Use the larger valid shank as the "Body Unit". If both fail, use 1.0 (avoids div/0)
        scale_unit = max(L_shank, R_shank, 1.0)

        # 3. Determine "Ground Level" for this frame
        # We assume the lowest point in the entire skeleton (Y-max) is touching the floor.
        # We look at Heels and Toes of both feet.
        candidates = []
        if Lh is not None: candidates.extend([Lh[1], Lt[1]])
        if Rh is not None: candidates.extend([Rh[1], Rt[1]])
        
        if not candidates:
            return self._empty_result()

        ground_y_level = max(candidates) # In OpenCV, Y increases downwards, so Max Y is lowest point

        # 4. Define Classification Function
        def classify_foot(a, h, t):
            if a is None: return "Unknown", 0.0
            
            # A. Ground Check
            # Is the lowest part of this foot close to the global ground level?
            lowest_part = max(h[1], t[1])
            dist_to_floor = ground_y_level - lowest_part
            
            # Threshold is dynamic: e.g. 15% of shank length
            is_grounded = dist_to_floor < (self.ground_tolerance * scale_unit)
            
            if not is_grounded:
                return "AIR", 0.0
            
            # B. Pitch Check (Heel vs Toe)
            # Normalized Pitch = (Toe_Y - Heel_Y) / Foot_Length
            # Y is down. 
            # Positive Pitch -> Toe is "lower" (bigger Y) -> Toe Dominant?
            # Wait, if Toe is ON floor and Heel is IN AIR (lower Y), Toe_Y > Heel_Y. 
            # So Positive = Toe Wheeling.
            
            foot_len = self._dist(h, t)
            if foot_len < 1e-6: foot_len = 1.0
            
            raw_pitch = t[1] - h[1]
            norm_pitch = raw_pitch / foot_len
            
            # Use the dynamic threshold
            if norm_pitch > self.pitch_threshold:
                return "TOE", 0.0 # Toe is significantly lower/deeper
            elif norm_pitch < -self.pitch_threshold:
                return "HEEL", 0.0 # Heel is significantly lower/deeper
            else:
                return "FLAT", 0.0

        # 5. Run Classification
        L_state, _ = classify_foot(La, Lh, Lt)
        R_state, _ = classify_foot(Ra, Rh, Rt)
        
        # 6. Calculate Angle (Optional, for display only)
        # Kept the old logic just for the "degree" display, but it doesn't affect state
        L_wheel = self._calc_angle(kps, scores, 'left')
        R_wheel = self._calc_angle(kps, scores, 'right')

        return {
            'left': {'state': L_state},
            'right':{'state': R_state},
            'wheeling': {'left': L_wheel, 'right': R_wheel},
            'ground_line': None # We aren't using line fitting anymore
        }

    def _calc_angle(self, kps, scores, side):
        # (Kept from previous code for visual feedback)
        idx = self.left_idx if side == 'left' else self.right_idx
        knee = KEYPOINT_DICT['L_KNEE'] if side == 'left' else KEYPOINT_DICT['R_KNEE']
        if scores[knee] < self.min_conf or scores[idx.ankle] < self.min_conf:
             return {'label': None, 'angle': None}
        # Simple Shin-Foot angle
        v1 = kps[knee,:2] - kps[idx.ankle,:2]
        v2 = kps[idx.big_toe,:2] - kps[idx.ankle,:2]
        ang = abs(angle_between(v1, v2))
        return {'label': None, 'angle': round(ang, 1)}

    def _empty_result(self):
        return {
            'left': {'state': 'Unknown'}, 'right':{'state': 'Unknown'},
            'wheeling': {'left': None, 'right': None}, 'ground_line': None
        }
# -----------------------------
# Posture & knee angles / body stability
# -----------------------------

def posture_metrics(kps: np.ndarray, scores: np.ndarray, conf_thr=0.3):
    l_sh, r_sh = kps[L_SHOULDER], kps[R_SHOULDER]
    l_hp, r_hp = kps[L_HIP], kps[R_HIP]
    l_ear, r_ear = kps[L_EAR], kps[R_EAR]
    nose = kps[NOSE]
    torso_angle = None
    esh_l = None
    esh_r = None
    head_center = None
    ms = midpoint(l_sh, r_sh, conf_thr) if (scores[L_SHOULDER]>=conf_thr and scores[R_SHOULDER]>=conf_thr) else None
    mh = midpoint(l_hp, r_hp, conf_thr) if (scores[L_HIP]>=conf_thr and scores[R_HIP]>=conf_thr) else None
    if ms is not None and mh is not None:
        torso_vec = vec(mh, ms)
        if torso_vec is not None:
            torso_angle = angle_between(torso_vec, np.array([0.0, -1.0]))
    if scores[L_EAR]>=conf_thr and scores[L_SHOULDER]>=conf_thr and scores[L_HIP]>=conf_thr:
        esh_l = angle_between(vec(l_sh, l_ear), vec(l_sh, l_hp))
    if scores[R_EAR]>=conf_thr and scores[R_SHOULDER]>=conf_thr and scores[R_HIP]>=conf_thr:
        esh_r = angle_between(vec(r_sh, r_ear), vec(r_sh, r_hp))
    head_kps = []
    for idx in (NOSE, L_EAR, R_EAR):
        if scores[idx] >= conf_thr:
            head_kps.append(kps[idx,:2])
    if head_kps:
        head_center = np.mean(np.asarray(head_kps), axis=0)
    return torso_angle, esh_l, esh_r, head_center


def knee_angle(kps: np.ndarray, scores: np.ndarray, side: str, conf_thr=0.3):
    if side == 'left':
        hip, knee, ankle = L_HIP, L_KNEE, L_ANKLE
    else:
        hip, knee, ankle = R_HIP, R_KNEE, R_ANKLE
    if min(scores[hip], scores[knee], scores[ankle]) < conf_thr:
        return None
    thigh = vec(kps[hip], kps[knee])
    shank = vec(kps[knee], kps[ankle])
    if thigh is None or shank is None:
        return None
    return 180.0 - angle_between(thigh, shank)

# -----------------------------
# Simple multi-person tracker (greedy IoU) + main-subject chooser
# -----------------------------
class BoxTracker:
    def __init__(self, iou_thr=0.2):
        self.iou_thr = iou_thr
        self.tracks = {}  # id -> {'box': [x1,y1,x2,y2], 'age': int}
        self.next_id = 0

    def update(self, boxes: List[np.ndarray], frame_shape):
        # Returns a list of IDs corresponding to the input boxes
        assigned_ids = [-1] * len(boxes)
        used_tracks = set()
        
        # 1. Match existing tracks
        for i, box in enumerate(boxes):
            best_id = -1
            best_iou = 0.0
            
            for tid, track in self.tracks.items():
                if tid in used_tracks: continue
                
                iou = iou_xyxy(track['box'], box)
                if iou > best_iou:
                    best_iou = iou
                    best_id = tid
            
            if best_iou >= self.iou_thr:
                assigned_ids[i] = best_id
                used_tracks.add(best_id)
                self.tracks[best_id]['box'] = box
                self.tracks[best_id]['age'] = 0

        # 2. Create new tracks for unmatched boxes
        for i in range(len(boxes)):
            if assigned_ids[i] == -1:
                new_id = self.next_id
                self.next_id += 1
                self.tracks[new_id] = {'box': boxes[i], 'age': 0}
                assigned_ids[i] = new_id
        
        # 3. Age out stale tracks
        for tid in list(self.tracks.keys()):
            if tid not in used_tracks:
                self.tracks[tid]['age'] += 1
                if self.tracks[tid]['age'] > 30: # Lost for 1 sec
                    del self.tracks[tid]
                    
        return assigned_ids

    def _score_main(self, b, W, H):
        # (Same scoring logic as before)
        x1,y1,x2,y2 = b
        cx, cy = (x1+x2)/2.0, (y1+y2)/2.0
        dx = (cx - W/2.0) / max(1.0, W)
        dy = (cy - H/2.0) / max(1.0, H)
        area = max(1.0, (x2-x1)*(y2-y1))
        return (1.0 - (abs(dx)+abs(dy))) + 0.00001 * area

# -----------------------------
# Metric logger + v2 runner
class MetricLogger:
    def __init__(self, path: Optional[str]):
        self.path = path
        self.f = None
        if path:
            import csv
            self.csv = csv
            self.f = open(path, 'w', newline='')
            # Removed 'hands' from fieldnames
            self.w = csv.DictWriter(self.f, fieldnames=[
                'frame','time_s','subject','stance','wheelL','wheelR','kickL','kickR',
                'twistL','twistR','torso_deg','kneeL_deg','kneeR_deg','vt_left','vt_right','ground_slope'
            ])
            self.w.writeheader()
    def write_row(self, row: Dict):
        if self.f:
            row = {k: (str(v) if isinstance(v, (dict,list,tuple)) else v) for k,v in row.items()}
            self.w.writerow(row)
    def close(self):
        if self.f:
            self.f.flush(); self.f.close(); self.f = None

def analyze_video_v2(video_path: str,
                  yolo_path: str,
                  rtm_path: str,
                  process_every_n: int = 2,
                  yolo_size=(640,640),
                  rtm_size=(288,384),
                  backend='onnxruntime',
                  device='cuda',
                  out_path: Optional[str] = None,
                  max_subjects: int = 1,
                  log_csv: Optional[str] = None,
                  show_window: bool = True):
    if YOLOX is None or RTMPose is None:
        raise RuntimeError("rtmlib not available. Please install/import YOLOX and RTMPose wrappers.")
    det = YOLOX(onnx_model=yolo_path, model_input_size=yolo_size, backend=backend, device=device)
    pose = RTMPose(onnx_model=rtm_path, model_input_size=rtm_size, backend=backend, device=device)
    logger = MetricLogger(log_csv) if log_csv else None
    analyzer = SkatingVideoAnalyzer(det, pose, process_every_n=process_every_n, max_subjects=max_subjects, logger=logger)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    analyzer.fps = float(fps)
    W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = None
    if out_path:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (W, H))
    fidx = 0
    t0 = time.time()
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        ts = time.time() - t0
        out = analyzer.process(frame, fidx, ts)
        if writer is not None:
            writer.write(out)
        elif show_window:
            cv2.imshow('Skate Analyzer', out)
            if cv2.waitKey(1) & 0xFF == 27:
                break
        fidx += 1
    cap.release()
    if writer is not None:
        writer.release()
    if show_window:
        cv2.destroyAllWindows()
    if logger is not None:
        logger.close()

class LRTemporalResolver:
    """
    Smart L/R resolver that uses history AND anatomy (bone lengths) 
    to prevent legs from swapping during falls or crossovers.
    """
    def __init__(self,
                 conf_thr=0.3,
                 min_swap_improve=12.0,
                 persist_frames=2,
                 ema_alpha=0.6):
        self.conf_thr = float(conf_thr)
        self.min_swap_improve = float(min_swap_improve)
        self.persist_frames = int(persist_frames)
        self.ema_alpha = float(ema_alpha)
        
        # Store smoothed positions
        self.prev_kps = None 
        # Store smoothed bone lengths (to detect if a leg suddenly grows/shrinks)
        self.prev_lens = {'L': {'thigh': None, 'shank': None}, 
                          'R': {'thigh': None, 'shank': None}}
        self._swap_streak = 0

    def _p(self, kps, scores, idx):
        if idx is None or scores[idx] < self.conf_thr: return None
        return np.array(kps[idx,:2], dtype=float)

    def _dist(self, a, b):
        if a is None or b is None: return 0.0
        return float(np.linalg.norm(a - b))

    def _get_lengths(self, assign):
        # Calculate Thigh (Hip->Knee) and Shank (Knee->Ankle) lengths
        lens = {'L': {}, 'R': {}}
        for side in ('L','R'):
            lens[side]['thigh'] = self._dist(assign[side]['hip'], assign[side]['knee'])
            lens[side]['shank'] = self._dist(assign[side]['knee'], assign[side]['ankle'])
        return lens

    def _anatomy_cost(self, cur_lens):
        # Penalize if the new bone length is vastly different from history
        cost = 0.0
        # Tolerance: we allow bone length to change by ~15% (perspective shift)
        # If it changes by 50%+, it's likely a wrong connection.
        for side in ('L', 'R'):
            for part in ('thigh', 'shank'):
                curr = cur_lens[side][part]
                prev = self.prev_lens[side][part]
                if curr > 0 and prev is not None and prev > 0:
                    diff = abs(curr - prev)
                    ratio = diff / prev
                    if ratio > 0.2: # 20% change penalty
                        cost += 100.0 * ratio # Heavy penalty
        return cost

    def _pack(self, kps, scores, order=('L','R')):
        sides = {}
        for side, A,B,T,H in (
            (order[0], L_ANKLE, L_KNEE, L_BIG_TOE, L_HEEL),
            (order[1], R_ANKLE, R_KNEE, R_BIG_TOE, R_HEEL),
        ):
            sides[side] = {
                'ankle': self._p(kps, scores, A),
                'knee':  self._p(kps, scores, B),
                'toe':   self._p(kps, scores, T),
                'heel':  self._p(kps, scores, H),
                'hip':   self._p(kps, scores, (L_HIP if side=='L' else R_HIP)),
            }
        return sides

    def resolve(self, kps, scores):
        # 1. Pack current raw detections
        cur_id = self._pack(kps, scores, order=('L','R'))
        cur_sw = self._pack(kps, scores, order=('R','L'))

        # 2. First Frame Initialization
        if self.prev_kps is None:
            self.prev_kps = cur_id
            self.prev_lens = self._get_lengths(cur_id)
            return kps, scores

        # 3. Calculate Movement Cost (Euclidean distance from last frame)
        def move_cost(c, p):
            cost = 0.0
            for side in ('L','R'):
                for k in ('ankle','knee','hip'):
                    cost += self._dist(c[side][k], p[side][k])
            return cost

        m_cost_id = move_cost(cur_id, self.prev_kps)
        m_cost_sw = move_cost(cur_sw, self.prev_kps)

        # 4. Calculate Anatomy Cost (Bone length consistency)
        lens_id = self._get_lengths(cur_id)
        lens_sw = self._get_lengths(cur_sw)
        
        a_cost_id = self._anatomy_cost(lens_id)
        a_cost_sw = self._anatomy_cost(lens_sw)

        # 5. Total Cost
        total_id = m_cost_id + a_cost_id
        total_sw = m_cost_sw + a_cost_sw

        # 6. Decision Logic (Hysteresis)
        if total_sw + self.min_swap_improve < total_id:
            self._swap_streak += 1
        else:
            self._swap_streak = 0

        do_swap = (self._swap_streak >= self.persist_frames)
        chosen = cur_sw if do_swap else cur_id
        chosen_lens = lens_sw if do_swap else lens_id

        # 7. Update History (EMA)
        alpha = self.ema_alpha
        for side in ('L','R'):
            # Update positions
            for k in ('ankle','knee','hip'):
                curr = chosen[side][k]
                prev = self.prev_kps[side][k]
                if curr is not None:
                    if prev is None: self.prev_kps[side][k] = curr
                    else: self.prev_kps[side][k] = (1-alpha)*prev + alpha*curr
            
            # Update bone lengths (slower update to keep them stable)
            for part in ('thigh','shank'):
                curr_l = chosen_lens[side][part]
                prev_l = self.prev_lens[side][part]
                if curr_l > 0:
                    if prev_l is None: self.prev_lens[side][part] = curr_l
                    else: self.prev_lens[side][part] = 0.9*prev_l + 0.1*curr_l

        if not do_swap:
            return kps, scores

        # 8. Apply Swap to Output Array
        K = np.asarray(kps).copy()
        S = np.asarray(scores).copy()
        swap_pairs = [
            (L_ANKLE, R_ANKLE), (L_KNEE, R_KNEE), (L_HEEL, R_HEEL), (L_BIG_TOE, R_BIG_TOE),
            (L_HIP, R_HIP), (L_SHOULDER, R_SHOULDER), (L_ELBOW, R_ELBOW), (L_WRIST, R_WRIST),
            (L_EAR, R_EAR)
        ]
        for a,b in swap_pairs:
            K[[a,b]] = K[[b,a]]
            S[[a,b]] = S[[b,a]]
        return K, S


# -----------------------------
# Main Analyzer wrapper
# -----------------------------
class SkatingVideoAnalyzer:
    def __init__(self,
                 yolo_model: YOLOX,
                 rtm_model: RTMPose,
                 process_every_n=2,
                 conf_thr=0.3,
                 draw_pairs: Optional[List[Tuple[int,int]]] = None,
                 max_subjects: int = 1,
                 logger: Optional["MetricLogger"] = None):
        self.det = yolo_model
        self.pose = rtm_model
        self.nskip = max(1, int(process_every_n))
        self.conf_thr = float(conf_thr)
        self.tracker = BoxTracker(iou_thr=0.2)
        self.max_subjects = max_subjects
        self.draw_pairs = draw_pairs
        self.fps = 30.0
        self.logger = logger
        self.subject_states = {}
        
        # FIX 1: Persistent cache for frame skipping
        self.cached_subjects = [] 

    def _get_or_create_state(self, tid):
        if tid not in self.subject_states:
            self.subject_states[tid] = {
                'contact': SkateContactAnalyzer(
                    left_idx=FootIndices(L_ANKLE, L_HEEL, L_BIG_TOE, None),
                    right_idx=FootIndices(R_ANKLE, R_HEEL, R_BIG_TOE, None),
                    ground_tolerance_ratio=0.15, 
                    pitch_threshold_ratio=0.25,
                    min_conf=self.conf_thr
                ),
                'lr_resolver': LRTemporalResolver(conf_thr=self.conf_thr),
                'last_kps': None,
                'last_scores': None,
                'L_knee': deque(maxlen=60),
                'R_knee': deque(maxlen=60),
            }
        return self.subject_states[tid]

    # Legacy support
    def _ensure_subjects(self, n): pass

    # ... (Keep _yolo_detect and _rtm_infer exactly as they were) ...
    # ... (Paste _yolo_detect and _rtm_infer here if copying full class) ...
    # For brevity, I assume you have them from previous steps. 
    def _yolo_detect(self, frame):
        # ... [Insert your existing _yolo_detect code here] ...
        out = self.det(frame)
        boxes, scores = [], []
        if out is None: return boxes, scores
        if isinstance(out, dict) and 'boxes' in out:
            arr = np.asarray(out['boxes'], dtype=float)
            if arr.ndim == 1 and arr.size >= 4: arr = arr.reshape(1, -1)
            if arr.ndim == 2 and arr.shape[1] >= 4:
                for row in arr:
                    boxes.append(row[:4].copy())
                    scores.append(float(row[4]) if arr.shape[1] >= 5 else 1.0)
            else:
                for b in out['boxes']:
                    b = np.asarray(b, dtype=float).ravel()
                    if b.size >= 4:
                        boxes.append(b[:4].copy())
                        scores.append(float(b[4]) if b.size >= 5 else 1.0)
            return boxes, scores
        if isinstance(out, np.ndarray):
            arr = np.asarray(out, dtype=float)
            if arr.ndim == 1 and arr.size >= 4: arr = arr.reshape(1, -1)
            if arr.ndim == 2 and arr.shape[1] >= 4:
                for row in arr:
                    boxes.append(row[:4].copy())
                    scores.append(float(row[4]) if arr.shape[1] >= 5 else 1.0)
            return boxes, scores
        if isinstance(out, (list, tuple)):
            for o in out:
                o = np.asarray(o, dtype=float).ravel()
                if o.size >= 4:
                    boxes.append(o[:4].copy())
                    scores.append(float(o[4]) if o.size >= 5 else 1.0)
            return boxes, scores
        return boxes, scores

    def _rtm_infer(self, frame, boxes):
        # ... [Insert your existing _rtm_infer code here] ...
        if not boxes: return [], []
        bxs = [np.asarray(b, dtype=float).ravel()[:4].tolist() for b in boxes]
        try: res = self.pose(frame, bboxes=bxs)
        except: return [], []
        if isinstance(res, tuple): k_list, s_list = res
        elif isinstance(res, dict):
            k_list = res.get('keypoints') or res.get('kps') or []
            s_list = res.get('scores')
        else: k_list, s_list = res, None
        out_k, out_s = [], []
        for i in range(min(len(bxs), len(k_list))):
            kps = np.asarray(k_list[i], dtype=float)
            if kps.ndim != 2 or kps.shape[1] < 2: continue
            if kps.shape[1] >= 3:
                scores = kps[:, 2].astype(float)
                kps = kps[:, :2].astype(float)
            else:
                scores = np.ones((kps.shape[0],), dtype=float)
                kps = kps[:, :2].astype(float)
            out_k.append(kps)
            out_s.append(scores)
        return out_k, out_s

    def _get_global_state(self, L_state: str, R_state: str) -> str:
        l_type = L_state if L_state not in ["Unknown"] else "AIR"
        r_type = R_state if R_state not in ["Unknown"] else "AIR"
        if l_type == "AIR" and r_type == "AIR": return "Airborne"
        if l_type != "AIR" and r_type == "AIR": return f"1-Foot Left ({l_type})"
        if l_type == "AIR" and r_type != "AIR": return f"1-Foot Right ({r_type})"
        return f"2-Foot ({l_type} - {r_type})"

    def process(self, frame: np.ndarray, fidx: int, ts: float) -> np.ndarray:
        H, W = frame.shape[:2]
        do_proc = (fidx % self.nskip == 0)
        
        # FIX 2: Only clear/update cache if we are actually processing
        if do_proc:
            # Reset cache for this new detection cycle
            self.cached_subjects = [] 
            
            raw_boxes, raw_scores = self._yolo_detect(frame)
            clipped_boxes = [np.clip(b, [0,0,0,0], [W-1,H-1,W-1,H-1]) for b in raw_boxes]
            box_ids = self.tracker.update(clipped_boxes, frame.shape)
            candidates = list(zip(clipped_boxes, box_ids))
            candidates.sort(key=lambda x: self.tracker._score_main(x[0], W, H), reverse=True)
            top_candidates = candidates[:self.max_subjects]
            
            if top_candidates:
                chosen_boxes = [c[0] for c in top_candidates]
                chosen_ids   = [c[1] for c in top_candidates]
                kps_list, scr_list = self._rtm_infer(frame, chosen_boxes)
                
                for i, tid in enumerate(chosen_ids):
                    # Store in the persistent cache
                    self.cached_subjects.append({
                        'id': tid,
                        'bbox': chosen_boxes[i],
                        'kps': kps_list[i],
                        'scores': scr_list[i]
                    })
            
            # Garbage collection
            active_ids = set(box_ids)
            for stored_id in list(self.subject_states.keys()):
                if stored_id not in self.tracker.tracks:
                    del self.subject_states[stored_id]

        # FIX 3: ALWAYS iterate over cached_subjects, not a local list
        for subj_data in self.cached_subjects:
            tid = subj_data['id']
            bbox = subj_data['bbox']
            raw_kps = subj_data['kps']
            raw_scores = subj_data['scores']
            
            if raw_kps is None: continue
            
            state = self._get_or_create_state(tid)
            
            # Only update physics/history if this is a NEW detection frame
            # Otherwise we just redraw the previous state to avoid calculating physics on stale data
            if do_proc:
                kps, scores = state['lr_resolver'].resolve(raw_kps, raw_scores)
                state['last_kps'] = kps
                state['last_scores'] = scores
                
                contact_out = state['contact'].update(kps, scores, ts)
                
                torso_ang, _, _, _ = posture_metrics(kps, scores, self.conf_thr)
                L_k = knee_angle(kps, scores, 'left', self.conf_thr)
                R_k = knee_angle(kps, scores, 'right', self.conf_thr)
                
                if L_k: state['L_knee'].append(L_k)
                if R_k: state['R_knee'].append(R_k)

                # Store results in state for skipping frames
                state['last_contact'] = contact_out
                state['last_torso'] = torso_ang
                state['last_Lk'] = L_k
                state['last_Rk'] = R_k

                if self.logger is not None:
                     waL = contact_out['wheeling']['left'].get('angle') if contact_out['wheeling']['left'] else None
                     waR = contact_out['wheeling']['right'].get('angle') if contact_out['wheeling']['right'] else None
                     self.logger.write_row({
                         'frame': fidx, 'time_s': round(ts,4), 'subject': tid,
                         'wheelL_ang': waL, 'wheelR_ang': waR,
                         'torso_deg': None if torso_ang is None else float(torso_ang),
                         'kneeL_deg': None if L_k is None else float(L_k),
                         'kneeR_deg': None if R_k is None else float(R_k),
                         'ground_slope': 0.0
                     })

            # Retrieve data for drawing (either fresh or from last processing)
            kps = state.get('last_kps', raw_kps)
            scores = state.get('last_scores', raw_scores)
            contact_out = state.get('last_contact', {'left':{'state':'Unknown'},'right':{'state':'Unknown'},'wheeling':{'left':None,'right':None}})
            torso_ang = state.get('last_torso')
            L_k = state.get('last_Lk')
            R_k = state.get('last_Rk')

            state_str = self._get_global_state(contact_out['left']['state'], contact_out['right']['state'])
            
            frame = self._draw_subject(frame, bbox, kps, scores, contact_out, torso_ang, L_k, R_k, state_str, state)
            
        return frame

    def _bbox_from_kps(self, kps, scores):
        vis = scores >= self.conf_thr
        if not np.any(vis): return np.array([0,0,0,0], dtype=float)
        xs, ys = kps[vis,0], kps[vis,1]
        return np.array([float(xs.min()), float(ys.min()), float(xs.max()), float(ys.max())], dtype=float)

    def _draw_subject(self, frame, bbox, kps, scores, out, torso_ang, L_k, R_k, skating_state_str, subj):
        b = np.asarray(bbox, dtype=float).ravel()
        if b.size < 4 or not np.all(np.isfinite(b[:4])): b = self._bbox_from_kps(kps, scores)
        else: b = b[:4]
        H, W = frame.shape[:2]
        b = np.clip(b, [0, 0, 0, 0], [W - 1, H - 1, W - 1, H - 1])

        draw_bbox(frame, [b])
        
        kps_vis = kps[None, :, :] if kps.ndim == 2 else kps
        scores_vis = scores[:, None][None, :, :] if scores.ndim == 1 else scores
        kps_vis = np.nan_to_num(kps_vis, copy=False)
        scores_vis = np.nan_to_num(scores_vis, copy=False)
        draw_skeleton(frame, kps_vis, scores_vis, openpose_skeleton=False, kpt_thr=self.conf_thr)

        def fmt_wheel(w_dict):
            if not w_dict or w_dict.get('angle') is None: return "-"
            return f"{w_dict['angle']}°"

        posture_label = "-"
        if torso_ang is not None:
            val = float(torso_ang)
            if val <= 10: posture_label = "Upright"
            elif val <= 30: posture_label = "Slight Lean"
            else: posture_label = "Aggressive"

        x1, y1, x2, y2 = map(int, b)
        panel = {
            'STATE':         skating_state_str,
            'Posture':       posture_label,
            'L-Knee':        None if L_k is None else f"{int(round(L_k))}°",
            'R-Knee':        None if R_k is None else f"{int(round(R_k))}°",
            'L-Ankle':       fmt_wheel(out['wheeling']['left']),
            'R-Ankle':       fmt_wheel(out['wheeling']['right']),
        }
        
        self._draw_panel(frame, panel, x=20, y=40)
        
        try:
            self._draw_sparkline(frame, list(subj['L_knee']), (x1 + 8,  y2 - 34, 60, 26), label='Lk')
            self._draw_sparkline(frame, list(subj['R_knee']), (x1 + 74, y2 - 34, 60, 26), label='Rk')
        except Exception: pass

        return frame

    @staticmethod
    def _draw_panel(frame, panel: Dict, x=12, y=18, lh=18):
        for i, (k,v) in enumerate(panel.items()):
            txt = f"{k}: {v}"
            cv2.putText(frame, txt, (x, y + i*lh), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255,255,255), 1, cv2.LINE_AA)

    @staticmethod
    def _draw_sparkline(frame, values: List[float], rect: Tuple[int,int,int,int], label: str = None):
        x,y,w,h = rect
        cv2.rectangle(frame, (x,y), (x+w, y+h), (200,200,200), 1)
        if not values: return
        vs = np.array(values, dtype=float)
        vs = vs[np.isfinite(vs)] if vs.size else vs
        if vs.size == 0: return
        vmin, vmax = float(np.min(vs)), float(np.max(vs))
        if abs(vmax - vmin) < 1e-6: vmin, vmax = vmax-1.0, vmax+1.0
        xs = np.linspace(0, w-2, num=len(vs))
        ys = h-2 - (vs - vmin) * (h-4) / (vmax - vmin)
        pts = np.vstack([xs + x + 1, ys + y + 1]).T.astype(np.int32)
        for j in range(1, len(pts)):
            cv2.line(frame, tuple(pts[j-1]), tuple(pts[j]), (0,255,255), 1, cv2.LINE_AA)
        if label:
            cv2.putText(frame, label, (x+2, y-2), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (200,200,200), 1, cv2.LINE_AA)

# -----------------------------
# Public utility: analyze a video file
# -----------------------------

def analyze_video(video_path: str,
                  yolo_path: str,
                  rtm_path: str,
                  process_every_n: int = 2,
                  yolo_size=(640,640),
                  rtm_size=(288,384),
                  backend='onnxruntime',
                  device='cuda',
                  out_path: Optional[str] = None):
    # init models
    if YOLOX is None or RTMPose is None:
        raise RuntimeError("rtmlib not available. Please install/import YOLOX and RTMPose wrappers.")
    det = YOLOX(onnx_model=yolo_path, model_input_size=yolo_size, backend=backend, device=device)
    pose = RTMPose(onnx_model=rtm_path, model_input_size=rtm_size, backend=backend, device=device)
    analyzer = SkatingVideoAnalyzer(det, pose, process_every_n=process_every_n)
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Cannot open video: {video_path}")
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    analyzer.fps = float(fps)
    W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    writer = None
    if out_path:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (W, H))
    fidx = 0
    t0 = time.time()
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        ts = time.time() - t0
        out = analyzer.process(frame, fidx, ts)
        if writer is not None:
            writer.write(out)
        else:
            cv2.imshow('Skate Analyzer', out)
            if cv2.waitKey(1) & 0xFF == 27:
                break
        fidx += 1
    cap.release()
    if writer is not None:
        writer.release()
    cv2.destroyAllWindows()
    print("finished")

In [14]:
from pathlib import Path

# if the code is in the same notebook cell/file, you can call analyze_video directly:
video = "/mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/spin.mp4"
yolo_path = "/mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/yolox_onnx/yolox_l_8xb8-300e_humanart-ce1d7a62/end2end.onnx"
rtm_path = "/mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/rtmpose_onnx/rtmpose-m_simcc-body7_pt-body7-halpe26_700e-384x288-89e6428b_20230605/end2end.onnx"
# Save to a file instead of opening a window (recommended inside notebooks/headless):
out = str(Path("/mnt/samsung_sata/Skate-Sport-Video-Analysis/annotated.mp4"))
analyze_video(video, yolo_path, rtm_path, process_every_n=2, out_path=out)
print("Wrote:", out)

load /mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/yolox_onnx/yolox_l_8xb8-300e_humanart-ce1d7a62/end2end.onnx with onnxruntime backend
load /mnt/samsung_sata/Skate-Sport-Video-Analysis/Docker_RtmPose/rtmpose_onnx/rtmpose-m_simcc-body7_pt-body7-halpe26_700e-384x288-89e6428b_20230605/end2end.onnx with onnxruntime backend


[0;93m2025-11-20 13:15:30.632823482 [W:onnxruntime:, session_state.cc:1280 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2025-11-20 13:15:30.632834332 [W:onnxruntime:, session_state.cc:1282 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m
[0;93m2025-11-20 13:15:30.731461558 [W:onnxruntime:, session_state.cc:1280 VerifyEachNodeIsAssignedToAnEp] Some nodes were not assigned to the preferred execution providers which may or may not have an negative impact on performance. e.g. ORT explicitly assigns shape related ops to CPU to improve perf.[m
[0;93m2025-11-20 13:15:30.731470218 [W:onnxruntime:, session_state.cc:1282 VerifyEachNodeIsAssignedToAnEp] Rerunning with verbose output on a non-minimal build will show node assignments.[m


finished
Wrote: /mnt/samsung_sata/Skate-Sport-Video-Analysis/annotated.mp4
