In [None]:
import tensorflow as tf
import csv
import tensorflow_hub as hub
from tensorflow_docs.vis import embed
import numpy as np
import cv2
from tqdm import tqdm
import os
import time
import requests
from pathlib import Path
import zipfile
import tarfile
import urllib.request
import shutil


# Sanity check: report how many GPUs TensorFlow can see.
print("Num GPUs Available:", len(tf.config.list_physical_devices('GPU')))

# Import matplotlib libraries
from matplotlib import pyplot as plt
from matplotlib.collections import LineCollection
import matplotlib.patches as patches

# Some modules to display an animation using imageio.
import imageio
from IPython.display import HTML, display
from collections import deque
import pandas as pd

## Pose Estimation Init

In [None]:
class SimpleMoveNetDownloader:
    """Download MoveNet SavedModels from TF Hub into a local cache directory.

    Download only: the model is fetched with ``hub.resolve`` (which stores it on
    disk without building it in memory) and placed at ``<cache_dir>/<model_name>``,
    from where it can be loaded with ``tf.saved_model.load``.
    """

    # Supported model names -> TF Hub handles.
    MODEL_URLS = {
        "movenet_thunder": "https://tfhub.dev/google/movenet/singlepose/thunder/4",
        "movenet_lightning": "https://tfhub.dev/google/movenet/singlepose/lightning/4",
    }

    def __init__(self, cache_dir="movenet_models"):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def download_hub_model(self, model_name="movenet_lightning"):
        """Download ``model_name`` into the cache unless it is already there.

        Args:
            model_name: a key of ``MODEL_URLS``.

        Returns:
            True if ``<cache_dir>/<model_name>`` is in place afterwards; False if the
            name is unsupported or the download/move failed.

        Side effects:
            ``TFHUB_CACHE_DIR`` is pointed at a temporary directory for the duration
            of the download and is always restored afterwards (also on failure); the
            temporary directory is always removed.
        """
        if model_name not in self.MODEL_URLS:
            print(f"❌ Unsupported model: {model_name}")
            return False

        model_path = self.cache_dir / model_name
        if model_path.exists():
            print(f"✅ Model {model_name} already cached at {model_path}")
            return True

        print(f"📥 Downloading {model_name} from TensorFlow Hub...")
        print("⏳ This may take a few minutes...")

        temp_cache = self.cache_dir / "temp_hub_cache"
        old_cache = os.environ.get('TFHUB_CACHE_DIR')  # None when the variable was unset
        os.environ['TFHUB_CACHE_DIR'] = str(temp_cache)
        try:
            print("🔄 Resolving model (this triggers the download)...")
            resolved = Path(hub.resolve(self.MODEL_URLS[model_name]))
            if not resolved.is_dir():
                print(f"❌ TF Hub did not return a local model directory: {resolved}")
                return False
            print("✅ Model downloaded successfully!")

            if temp_cache.resolve() in resolved.resolve().parents:
                shutil.move(str(resolved), str(model_path))
            else:
                # TF Hub cached the model somewhere else (e.g. when a --tfhub_cache_dir
                # flag overrides the env var): copy so that cache is left intact.
                shutil.copytree(resolved, model_path)
            print(f"✅ Model cached at {model_path}")
            return True

        except Exception as e:
            print(f"❌ Download failed: {e}")
            print("💡 This might be due to network issues or TF Hub server problems")
            return False

        finally:
            # Runs on success and failure alike, so the environment is never left
            # pointing at a deleted temp directory.
            shutil.rmtree(temp_cache, ignore_errors=True)
            if old_cache is None:
                os.environ.pop('TFHUB_CACHE_DIR', None)
            else:
                os.environ['TFHUB_CACHE_DIR'] = old_cache

def download_movenet_simple(model_name="movenet_lightning"):
    """Notebook-friendly wrapper: download `model_name` into ./movenet_models and report.

    Uses SimpleMoveNetDownloader with its default cache directory.

    Returns:
        True when the downloader reports success, False otherwise.
    """
    print("🚀 Simple MoveNet Downloader")
    print("=" * 40)

    downloader = SimpleMoveNetDownloader()
    if not downloader.download_hub_model(model_name):
        print(f"\n❌ Failed to download {model_name}")
        return False

    print(f"\n🎉 Success! {model_name} is ready to use")
    print(f"📁 Cached in: {downloader.cache_dir.absolute()}")
    return True

In [None]:
# Step 1: make sure the model is cached locally (downloads only on first run).
success = download_movenet_simple('movenet_lightning')

# Step 2 (loading) is done in the next cell by load_movenet_fast(), which is
# defined there; calling a loader here would fail under Restart & Run All.
if not success:
    raise RuntimeError(
        "MoveNet download failed, so the cached model needed by the next cell is "
        "missing. Check network access to tfhub.dev and re-run this cell."
    )

In [None]:


def load_movenet_fast(model_name="movenet_lightning", cache_dir="movenet_models"):
    """
    Load a MoveNet SavedModel previously cached by SimpleMoveNetDownloader.

    Args:
        model_name: sub-directory of ``cache_dir`` holding the SavedModel
            (e.g. "movenet_lightning" or "movenet_thunder").
        cache_dir: root cache directory.

    Returns:
        (movenet_inference, input_size): a callable that runs the model's
        'serving_default' signature on an image batch (cast to int32) and returns
        its 'output_0' as a NumPy array, and the square input size (256 when the
        name contains "thunder", otherwise 192).

    Raises:
        FileNotFoundError: if ``<cache_dir>/<model_name>`` does not exist.
    """
    model_path = Path(cache_dir) / model_name
    if not model_path.exists():
        raise FileNotFoundError(f"Model not found at {model_path}")

    module = tf.saved_model.load(str(model_path))
    # Look the signature up once instead of on every frame.
    serving_fn = module.signatures['serving_default']

    def movenet_inference(input_image):
        """MoveNet inference function - same interface as before."""
        outputs = serving_fn(tf.cast(input_image, dtype=tf.int32))
        return outputs['output_0'].numpy()

    # Keep the loaded object alive alongside the function: a signature obtained
    # from a loaded SavedModel can fail once that object is garbage-collected.
    movenet_inference.module = module

    input_size = 256 if "thunder" in model_name else 192
    print(f"✅ Loaded {model_name} from cache (input size: {input_size})")
    return movenet_inference, input_size

# Step 2: load the cached model and report how long loading actually took.
load_started = time.perf_counter()
movenet, input_size = load_movenet_fast("movenet_lightning", "movenet_models")
print(f"🎉 Model loaded in {time.perf_counter() - load_started:.1f} s")

## Preprocessing

In [5]:
# MoveNet's 17 keypoints and their index along the keypoint axis of the model output.
KEYPOINT_DICT = {
    'nose': 0, 'left_eye': 1, 'right_eye': 2, 'left_ear': 3,
    'right_ear': 4, 'left_shoulder': 5,
    'right_shoulder': 6, 'left_elbow': 7, 'right_elbow': 8, 'left_wrist': 9,
    'right_wrist': 10, 'left_hip': 11, 'right_hip': 12, 'left_knee': 13,
    'right_knee': 14, 'left_ankle': 15, 'right_ankle': 16
}

# Skeleton edges (keypoint index pairs) and their draw colour. Each key appears only
# once (a repeated key in a dict literal silently overrides the earlier value).
# Segments are drawn in this order, so later entries paint over earlier ones.
# NOTE(review): (5, 7)/(7, 5) and (6, 8)/(8, 6) are the same upper-arm segments listed
# twice; the later green copy covers the blue one. Confirm which colour is intended.
KEYPOINT_EDGE_INDS_TO_COLOR = {
    (5, 3): 'r', (6, 4): 'r',
    (5, 7): 'b', (7, 9): 'g', (6, 8): 'b', (8, 10): 'g',
    (5, 6): 'b', (5, 11): 'orange', (6, 12): 'orange',
    (7, 5): 'g', (8, 6): 'g',
    (11, 13): 'purple', (13, 15): 'purple', (12, 14): 'purple', (14, 16): 'purple'
}

def _keypoints_and_edges_for_display(keypoints_with_scores, height, width, keypoint_threshold=0.3):
    """Convert MoveNet output into pixel-space points and skeleton segments for plotting.

    Args:
        keypoints_with_scores: array indexed [0, instance, keypoint, (y, x, score)]
            with normalized y/x.
        height, width: image size in pixels used to scale the normalized coordinates.
        keypoint_threshold: minimum score for a keypoint to be kept; an edge is kept
            only when both of its endpoints pass.

    Returns:
        (keypoints_xy, edges_xy, edge_colors): (N, 2) pixel (x, y) points,
        (M, 2, 2) line segments, and the M colours from KEYPOINT_EDGE_INDS_TO_COLOR.
    """
    keypoints_all = []
    keypoint_edges_all = []
    edge_colors = []
    # Instances live on axis 1 (they are indexed [0, idx, ...] below); axis 0 is the batch.
    num_instances = keypoints_with_scores.shape[1]

    for idx in range(num_instances):
        kpts = keypoints_with_scores[0, idx]
        kpts_scores = kpts[:, 2]
        # normalized (y, x) -> pixel (x, y)
        kpts_absolute_xy = np.stack([width * kpts[:, 1], height * kpts[:, 0]], axis=-1)
        visible = kpts_scores > keypoint_threshold
        keypoints_all.append(kpts_absolute_xy[visible, :])

        for (start, end), color in KEYPOINT_EDGE_INDS_TO_COLOR.items():
            if visible[start] and visible[end]:
                keypoint_edges_all.append(kpts_absolute_xy[[start, end]])
                edge_colors.append(color)

    keypoints_xy = np.concatenate(keypoints_all, axis=0) if keypoints_all else np.zeros((0, 2))
    edges_xy = np.stack(keypoint_edges_all, axis=0) if keypoint_edges_all else np.zeros((0, 2, 2))

    return keypoints_xy, edges_xy, edge_colors

# Function to draw keypoints and edges on image with confidence
def draw_prediction_on_image(image, keypoints_with_scores, original_image=None, flip_applied=False, crop_region=None, close_figure=False, output_image_height=None):
    """Draws the keypoint predictions (points, skeleton, per-keypoint scores) on an image.

    Args:
        image: HxWx3 image the keypoints were predicted on.
        keypoints_with_scores: MoveNet output indexed [0, 0, keypoint, (y, x, score)],
            normalized coordinates.
        original_image: un-flipped frame, drawn instead of ``image`` when ``flip_applied``.
        flip_applied: whether the keypoints come from a horizontally flipped frame. When
            True the x coordinates are mirrored back IN PLACE (the caller's array changes).
        crop_region: optional dict with normalized 'x_min', 'y_min', 'x_max', 'y_max';
            drawn as a blue rectangle.
        close_figure: unused; kept for interface compatibility (the figure is always closed).
        output_image_height: if given, the rendered image is resized to this height,
            keeping the input aspect ratio.

    Returns:
        uint8 RGB array of the rendered figure.

    Raises:
        ValueError: if ``flip_applied`` is True but ``original_image`` is None.
    """
    if flip_applied and original_image is None:
        raise ValueError("original_image is required when flip_applied=True")
    vis_image = original_image if flip_applied else image

    if flip_applied:
        # Keypoints are in flipped-frame coordinates; mirror x to match the original frame.
        keypoints_with_scores[0, 0, :, 1] = 1 - keypoints_with_scores[0, 0, :, 1]

    height, width, _ = vis_image.shape
    aspect_ratio = float(width) / height
    fig, ax = plt.subplots(figsize=(12 * aspect_ratio, 12))
    fig.tight_layout(pad=0)
    ax.margins(0)
    ax.set_yticklabels([])
    ax.set_xticklabels([])
    ax.axis('off')

    ax.imshow(vis_image)
    line_segments = LineCollection([], linewidths=4, linestyle='solid')
    ax.add_collection(line_segments)
    scat = ax.scatter([], [], s=60, color='#FF1493', zorder=3)

    display_threshold = 0.3  # the default threshold of _keypoints_and_edges_for_display
    keypoint_locs, keypoint_edges, edge_colors = _keypoints_and_edges_for_display(
        keypoints_with_scores, height, width, keypoint_threshold=display_threshold)

    if keypoint_edges.shape[0]:
        line_segments.set_segments(keypoint_edges)
        line_segments.set_color(edge_colors)
    if keypoint_locs.shape[0]:
        scat.set_offsets(keypoint_locs)

    # Label every drawn keypoint with its score. Position and score come from the same
    # keypoint row, so each label sits on its own dot.
    for y_norm, x_norm, score in keypoints_with_scores[0, 0]:
        if score > display_threshold:
            ax.text(x_norm * width, y_norm * height, f'{score:.2f}', fontsize=10,
                    color='white', ha='center', va='center', zorder=4)

    if crop_region is not None:
        xmin = max(crop_region['x_min'] * width, 0.0)
        ymin = max(crop_region['y_min'] * height, 0.0)
        rec_width = min(crop_region['x_max'], 0.99) * width - xmin
        rec_height = min(crop_region['y_max'], 0.99) * height - ymin
        rect = patches.Rectangle(
            (xmin, ymin), rec_width, rec_height,
            linewidth=1, edgecolor='b', facecolor='none')
        ax.add_patch(rect)

    fig.canvas.draw()
    # buffer_rgba() replaces tostring_rgb() (deprecated in Matplotlib 3.8, removed in
    # 3.10) and carries the real pixel dimensions; copy before the figure is closed.
    image_from_plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    plt.close(fig)
    if output_image_height is not None:
        output_image_width = int(output_image_height / height * width)
        image_from_plot = cv2.resize(image_from_plot, dsize=(output_image_width, output_image_height),
                                     interpolation=cv2.INTER_CUBIC)
    return image_from_plot

# Cropping algorithm: pick the square region of each frame that MoveNet runs on.

# Minimum keypoint score for a joint to be used when sizing or centring the crop.
MIN_CROP_KEYPOINT_SCORE = 0.2

def init_crop_region(image_height, image_width):
    """Return the default crop: a square centred on the frame whose side is the longer image side.

    Values are normalized to the image size, so along the shorter axis the box is
    larger than 1.0 and starts at a negative offset (it extends past the frame).
    """
    if image_width > image_height:
        box_width, box_height = 1.0, image_width / image_height
        x_min, y_min = 0.0, (image_height / 2 - image_width / 2) / image_height
    else:
        box_width, box_height = image_height / image_width, 1.0
        x_min, y_min = (image_width / 2 - image_height / 2) / image_width, 0.0

    region = {'y_min': y_min, 'x_min': x_min}
    region['y_max'] = y_min + box_height
    region['x_max'] = x_min + box_width
    region['height'] = box_height
    region['width'] = box_width
    return region

def torso_visible(keypoints):
    """Return truthy when at least one hip AND at least one shoulder score above MIN_CROP_KEYPOINT_SCORE."""
    def confident(joint):
        return keypoints[0, 0, KEYPOINT_DICT[joint], 2] > MIN_CROP_KEYPOINT_SCORE

    hip_seen = confident('left_hip') or confident('right_hip')
    if not hip_seen:
        return hip_seen
    return confident('left_shoulder') or confident('right_shoulder')

def determine_torso_and_body_range(keypoints, target_keypoints, center_y, center_x):
    """Return [torso_y, torso_x, body_y, body_x]: the largest |offset| of joints from the centre.

    Torso ranges use the four shoulder/hip joints regardless of score; body ranges
    use every joint whose score is not below MIN_CROP_KEYPOINT_SCORE. Each range
    starts at 0.0. target_keypoints maps joint name -> [y, x] in pixels.
    """
    def farthest(joints):
        reach_y = reach_x = 0.0
        for name in joints:
            offset_y = abs(center_y - target_keypoints[name][0])
            offset_x = abs(center_x - target_keypoints[name][1])
            if offset_y > reach_y:
                reach_y = offset_y
            if offset_x > reach_x:
                reach_x = offset_x
        return reach_y, reach_x

    torso_y, torso_x = farthest(['left_shoulder', 'right_shoulder', 'left_hip', 'right_hip'])
    scored_joints = [
        name for name in KEYPOINT_DICT
        if not keypoints[0, 0, KEYPOINT_DICT[name], 2] < MIN_CROP_KEYPOINT_SCORE
    ]
    body_y, body_x = farthest(scored_joints)

    return [torso_y, torso_x, body_y, body_x]

def determine_crop_region(keypoints, image_height, image_width):
    """Determines the region to crop the image for the model.

    Builds a square crop (in pixels) centred on the mid-hip point that covers the
    torso and every confident joint, then returns it normalized to the image size.
    Falls back to init_crop_region() when the torso is not visible, or when half the
    crop side would exceed half of the longer image side.

    Args:
        keypoints: MoveNet output indexed [0, 0, keypoint, (y, x, score)] with
            normalized y/x.
        image_height, image_width: frame size in pixels.

    Returns:
        dict with normalized 'y_min', 'x_min', 'y_max', 'x_max', plus 'height'
        (y_max - y_min) and 'width' (x_max - x_min).
    """
    # Normalized (y, x) -> pixel [y, x] for every joint, regardless of score.
    target_keypoints = {}
    for joint in KEYPOINT_DICT.keys():
        target_keypoints[joint] = [
            keypoints[0, 0, KEYPOINT_DICT[joint], 0] * image_height,
            keypoints[0, 0, KEYPOINT_DICT[joint], 1] * image_width
        ]

    if torso_visible(keypoints):
        # Crop centre: midpoint of the two hips (pixels).
        center_y = (target_keypoints['left_hip'][0] + target_keypoints['right_hip'][0]) / 2
        center_x = (target_keypoints['left_hip'][1] + target_keypoints['right_hip'][1]) / 2

        (max_torso_yrange, max_torso_xrange, max_body_yrange, max_body_xrange) = determine_torso_and_body_range(
            keypoints, target_keypoints, center_y, center_x)

        # Half side of the square: the larger of 1.9x the torso extent and
        # 1.2x the extent of the confident joints.
        crop_length_half = np.amax([max_torso_xrange * 1.9, max_torso_yrange * 1.9,
                                    max_body_yrange * 1.2, max_body_xrange * 1.2])

        # Cap it at the largest distance from the centre to any image edge.
        tmp = np.array([center_x, image_width - center_x, center_y, image_height - center_y])
        crop_length_half = np.amin([crop_length_half, np.amax(tmp)])

        # Top-left corner [y, x] of the square, in pixels (may be negative).
        crop_corner = [center_y - crop_length_half, center_x - crop_length_half]

        if crop_length_half > max(image_width, image_height) / 2:
            return init_crop_region(image_height, image_width)
        else:
            crop_length = crop_length_half * 2
            return {
                'y_min': crop_corner[0] / image_height,
                'x_min': crop_corner[1] / image_width,
                'y_max': (crop_corner[0] + crop_length) / image_height,
                'x_max': (crop_corner[1] + crop_length) / image_width,
                'height': (crop_corner[0] + crop_length) / image_height - crop_corner[0] / image_height,
                'width': (crop_corner[1] + crop_length) / image_width - crop_corner[1] / image_width
            }
    else:
        return init_crop_region(image_height, image_width)

def crop_and_resize(image, crop_region, crop_size):
    """Crop the normalized `crop_region` box from image batch element 0 and resize it to `crop_size`.

    `image` is passed straight to tf.image.crop_and_resize, which expects a 4-D
    (batch, height, width, channels) tensor; callers add the batch axis.
    """
    box = [crop_region[key] for key in ('y_min', 'x_min', 'y_max', 'x_max')]
    return tf.image.crop_and_resize(image, boxes=[box], box_indices=[0], crop_size=crop_size)

def should_flip_image(keypoints_with_scores):
    """Decide whether to mirror the frame, by voting on left/right keypoint positions.

    Up to three cues vote +1 ("facing left") or -1 ("facing right"): shoulder order,
    left wrist vs. left shoulder, and left knee vs. left shoulder. A cue only votes when
    all of its keypoints score above KEYPOINT_THRESHOLD. Returns True when at least two
    cues voted and the net vote is positive; otherwise False.
    """
    def kp(name):
        return keypoints_with_scores[0, 0, KEYPOINT_DICT[name]]

    def confident(point):
        return point[2] > KEYPOINT_THRESHOLD

    l_shoulder, r_shoulder = kp('left_shoulder'), kp('right_shoulder')
    l_wrist, l_knee = kp('left_wrist'), kp('left_knee')

    votes = []
    if confident(l_shoulder) and confident(r_shoulder):
        votes.append(1 if l_shoulder[1] > r_shoulder[1] else -1)
    if confident(l_wrist) and confident(l_shoulder):
        votes.append(1 if l_wrist[1] < l_shoulder[1] else -1)
    if confident(l_knee) and confident(l_shoulder):
        votes.append(1 if l_knee[1] < l_shoulder[1] else -1)

    if len(votes) < 2:
        return False
    return sum(votes) > 0

def flip_image_and_keypoints(image, keypoints_with_scores):
    """Mirror an image and its keypoints horizontally.

    The x column of keypoints_with_scores[0, 0] is replaced by 1 - x IN PLACE; the
    same array is returned together with the mirrored image.
    """
    x_coords = keypoints_with_scores[0, 0, :, 1]
    keypoints_with_scores[0, 0, :, 1] = 1 - x_coords

    mirrored = cv2.flip(image, 1)  # flipCode=1: flip around the vertical axis
    return keypoints_with_scores, mirrored


# Confidence threshold for keypoints. should_flip_image reads it at call time, so
# defining it after that function is fine as long as this cell has been run.
KEYPOINT_THRESHOLD = 0.3

## Inference, Angle Calculation, and Frame Saving

In [6]:
def run_inference(movenet, image, crop_region, crop_size):
    """Runs model inference on the cropped region with proper flip handling.

    Runs MoveNet once on the crop. If should_flip_image() votes for a flip, the frame
    is mirrored with cv2.flip and MoveNet runs again on the same crop_region. The
    keypoints are then mapped from crop-normalized to image-normalized coordinates
    in place.

    Args:
        movenet: callable returning a mutable keypoints array indexed
            [0, 0, keypoint, (y, x, score)] (e.g. the function from load_movenet_fast).
        image: HxWxC frame.
        crop_region: dict from init_crop_region / determine_crop_region (normalized).
        crop_size: model input size forwarded to crop_and_resize.

    Returns:
        (keypoints_with_scores, image, original_image, flip_required): `image` is the
        mirrored frame when flip_required is True, and original_image is always a copy
        of the un-mirrored input.
    """
    image_height, image_width, _ = image.shape
    
    # First pass to determine orientation
    input_image = crop_and_resize(tf.expand_dims(image, axis=0), crop_region, crop_size=crop_size)
    keypoints_with_scores = movenet(input_image)
    flip_required = should_flip_image(keypoints_with_scores)
    
    # Second pass if flipping is needed
    # NOTE(review): crop_region is applied unchanged to the mirrored frame. For an
    # off-centre crop (from determine_crop_region) the mirrored region would be
    # x_min' = 1 - x_max, x_max' = 1 - x_min — confirm whether that is intended.
    if flip_required:
        flipped_image = cv2.flip(image, 1)
        input_image = crop_and_resize(tf.expand_dims(flipped_image, axis=0), crop_region, crop_size=crop_size)
        keypoints_with_scores = movenet(input_image)
        original_image = image.copy()
        image = flipped_image
    else:
        original_image = image.copy()
    
    # Adjust keypoints for crop region: crop-relative (y, x) -> whole-image normalized
    # (y_min + height * y, x_min + width * x), written back into the array.
    for idx in range(17):
        keypoints_with_scores[0, 0, idx, 0] = (
            crop_region['y_min'] * image_height +
            crop_region['height'] * image_height *
            keypoints_with_scores[0, 0, idx, 0]) / image_height
        keypoints_with_scores[0, 0, idx, 1] = (
            crop_region['x_min'] * image_width +
            crop_region['width'] * image_width *
            keypoints_with_scores[0, 0, idx, 1]) / image_width

    return keypoints_with_scores, image, original_image, flip_required
  


def save_visualized_frame(frame, output_folder, frame_idx, flip_applied):
    """Save a visualized frame as ``<output_folder>/frame_<frame_idx:04d>.png``.

    Args:
        frame: image array handed to cv2.imwrite (OpenCV treats 3-channel data as BGR).
        output_folder: destination directory; created if it does not exist.
        frame_idx: frame number used in the file name.
        flip_applied: if True, the frame is mirrored with cv2.flip before saving.

    Returns:
        True if cv2.imwrite reported success, False otherwise.
    """
    # NOTE(review): draw_prediction_on_image already renders on the un-flipped original
    # when its flip_applied is True, and returns RGB data. If its output is saved here
    # with flip_applied=True the file ends up mirrored, and colours are swapped unless
    # it is converted to BGR first — confirm at the call site.
    if flip_applied:
        frame = cv2.flip(frame, 1)

    os.makedirs(output_folder, exist_ok=True)
    output_path = os.path.join(output_folder, f"frame_{frame_idx:04d}.png")
    # cv2.imwrite signals an unwritable path by returning False rather than raising.
    if not cv2.imwrite(output_path, frame):
        print(f"Failed to save frame {frame_idx} to {output_path}")
        return False
    print(f"Saved frame {frame_idx} to {output_path}")
    return True

# Helper function to calculate joint angles from the dot product of the two limb vectors
def calculate_angle(a, b, c):
    """Return the angle ABC in degrees (0-180) between the rays b->a and b->c.

    Points are equal-length coordinate sequences (e.g. (y, x) pairs).
    Returns NaN when a or c coincides with b, because the angle is undefined there.
    This case is handled explicitly so no divide-by-zero RuntimeWarning is emitted.
    """
    b = np.asarray(b, dtype=float)
    ba = np.asarray(a, dtype=float) - b
    bc = np.asarray(c, dtype=float) - b
    norm_product = np.linalg.norm(ba) * np.linalg.norm(bc)
    if norm_product == 0:
        return np.nan
    # Clip guards against rounding pushing the cosine slightly outside [-1, 1].
    cosine_angle = np.clip(np.dot(ba, bc) / norm_product, -1.0, 1.0)
    return np.degrees(np.arccos(cosine_angle))

def get_video_rotation(video_path):
    """Return the rotation (0, 90, 180 or 270) stored in the video's orientation metadata.

    Any other metadata value, or an error while reading it, yields 0; read errors
    are printed.
    NOTE(review): recent OpenCV builds may already auto-rotate decoded frames
    (CAP_PROP_ORIENTATION_AUTO); applying this rotation again would double-rotate —
    confirm for the installed OpenCV version.
    """
    capture = cv2.VideoCapture(video_path)
    detected = 0
    try:
        meta = capture.get(cv2.CAP_PROP_ORIENTATION_META)
        for candidate in (90, 180, 270):
            if meta == candidate:
                detected = candidate
                break
    except Exception as e:
        print(f"Error while reading video metadata: {e}")
    capture.release()
    return detected

def apply_rotation(frame, angle):
    """Rotate `frame` clockwise by `angle` degrees (90, 180 or 270); any other angle returns it unchanged."""
    if angle not in (90, 180, 270):
        return frame
    rotate_codes = {
        90: cv2.ROTATE_90_CLOCKWISE,
        180: cv2.ROTATE_180,
        270: cv2.ROTATE_90_COUNTERCLOCKWISE,
    }
    return cv2.rotate(frame, rotate_codes[angle])

## REBA Scoring Guide

In [17]:
class TrunkREBA:
    """
    REBA-style trunk score from the waist angle.

    Angle convention:
    - Positive angle = forward flexion (leaning forward)
    - Negative angle = extension (leaning backward)

    Scores as implemented:
    - flexion 0-5°: 1, >5-20°: 2, >20-60°: 3, >60°: 4
    - extension up to 5°: 1, beyond 5°: 2
    - NaN angle: 0
    NOTE(review): published REBA scores extension beyond 20° as 3 and has no 5°
    upright band; this implementation deliberately(?) differs — confirm.

    trunk_reba_score() returns [score, 0, 0].
    """
    def __init__(self, trunk_degrees):
        self.trunk_degrees = trunk_degrees

    def trunk_reba_score(self):
        angle = self.trunk_degrees[0]  # waist angle from get_joint_angles_and_labels

        if angle >= 0:
            score = 4
            # Inclusive upper bound of each flexion band and its score.
            for upper_bound, band_score in ((5, 1), (20, 2), (60, 3)):
                if angle <= upper_bound:
                    score = band_score
                    break
        else:
            extension = abs(angle)
            if 0 < extension <= 5:
                score = 1
            elif extension > 5:
                score = 2
            else:  # only reachable for NaN
                score = 0

        return [score, 0, 0]

class NeckREBA:
    """
    REBA neck score from the neck angle.

    - 0° to <20° flexion: 1
    - 20° or more flexion: 2
    - any extension (negative angle): 2
    - NaN angle: 0

    neck_reba_score() returns [score, 0, 0].
    """
    def __init__(self, neck_degrees):
        self.neck_degrees = neck_degrees

    def neck_reba_score(self):
        angle = self.neck_degrees[0]  # raw neck angle from get_joint_angles_and_labels

        if angle < 0:
            score = 2  # extension
        elif angle < 20:
            score = 1
        elif angle >= 20:
            score = 2
        else:  # only reachable for NaN
            score = 0

        return [score, 0, 0]


class UpperArmREBA:
    """
    REBA upper-arm score from the left/right upper-arm angles.

    Only the magnitude of each angle is scored (flexion and extension are treated
    alike), and the worse arm wins:
    - |angle| < 20°: 1
    - 20° <= |angle| < 45°: 2
    - 45° <= |angle| < 90°: 3
    - |angle| >= 90°: 4
    NaN angles are ignored; if both arms are NaN the score is 0.

    upper_arm_score() returns [score].
    """
    def __init__(self, arm_degrees):
        self.arm_degrees = arm_degrees  # [left_angle, right_angle]

    def upper_arm_score(self):
        left_angle = self.arm_degrees[0]
        right_angle = self.arm_degrees[1]

        # Drop NaN readings: max() with a NaN is order-dependent (max(nan, 30) is nan
        # but max(30, nan) is 30), so a missing left arm used to zero the score.
        magnitudes = [abs(a) for a in (left_angle, right_angle) if not np.isnan(a)]
        if not magnitudes:
            return [0]

        # Worst case between left and right arm.
        max_angle = max(magnitudes)

        if max_angle < 20:
            upper_arm_reba_score = 1
        elif max_angle < 45:
            upper_arm_reba_score = 2
        elif max_angle < 90:
            upper_arm_reba_score = 3
        else:
            upper_arm_reba_score = 4

        return [upper_arm_reba_score]


class LAREBA:
    """
    REBA lower-arm score.

    The measured reflex angles are converted to elbow bend angles (180 - reflex).
    Each arm scores 1 when its bend is within 60-100° (inclusive) and 2 otherwise
    (including NaN); lower_arm_score() returns [worst of the two arms].

    Input: list of measured reflex angles [left_reflex, right_reflex].
    """
    def __init__(self, reflex_angles):
        # Elbow bend = 180° minus the measured reflex angle.
        self.elbow_bend_angles = list(map(lambda reflex: 180 - reflex, reflex_angles))

    def lower_arm_score(self):
        bends = self.elbow_bend_angles
        worst = max(self._score_single_arm(bends[0]), self._score_single_arm(bends[1]))
        return [worst]

    def _score_single_arm(self, bend_angle):
        """Return 1 for a neutral bend (60-100° inclusive), else 2."""
        return 1 if 60 <= bend_angle <= 100 else 2


class LegREBA:
    """
    REBA leg score for a sitting assessment.

    Per expert feedback, leg angles do not matter when sitting, so the score is
    always 1 (neutral). The angles are stored only for future compatibility.
    """
    def __init__(self, leg_degrees):
        self.leg_degrees = leg_degrees  # [left_angle, right_angle]; not used for scoring

    def leg_reba_score(self):
        """Return [1] regardless of leg angles (legs are ignored for sitting)."""
        return [1]

class AngleSmoother:
    """Moving average over the last `window_size` valid angle readings."""
    def __init__(self, window_size=3):
        # deque(maxlen=...) drops the oldest reading automatically.
        self.history = deque(maxlen=window_size)

    def smooth(self, angle):
        """Add `angle` to the window and return the window mean.

        Returns None for a None reading. A NaN reading is not stored, so it cannot
        poison the next `window_size` averages. Instead, the mean of the readings
        already in the window is returned, or the NaN itself if the window is empty.
        """
        if angle is None:
            return None
        if np.isnan(angle):
            return np.mean(self.history) if self.history else angle
        self.history.append(angle)
        return np.mean(self.history)

## Get Pose Inference Data

In [8]:
# Helper: fetch a keypoint's (y, x) only if it passed the confidence check
def get_keypoint_if_valid(validated_keypoints, keypoint_name):
    """Return (y, x) for `keypoint_name` when its entry is marked valid, else None."""
    entry = validated_keypoints[keypoint_name]
    if not entry['valid']:
        return None
    return entry['y'], entry['x']

# Helper function to calculate the angle with fallback
def calculate_angle_with_fallback(a_name, b_name, c_name, angle_name, validated_keypoints, imputed_angles, neutral_angles):
    """Angle at keypoint `b_name` between `a_name` and `c_name`, or a neutral fallback.

    Falls back to ``neutral_angles[angle_name]`` and sets
    ``imputed_angles[angle_name] = True`` (mutating the passed dict) when any of the
    three keypoints is invalid, when calculate_angle raises, or when it yields a
    non-finite value (e.g. NaN for coincident points).
    """
    points = [get_keypoint_if_valid(validated_keypoints, name) for name in (a_name, b_name, c_name)]

    if all(point is not None for point in points):
        try:
            angle = calculate_angle(*points)
        except Exception:  # best effort, but don't swallow KeyboardInterrupt/SystemExit
            angle = None
        if angle is not None and np.isfinite(angle):
            return angle

    # If we get here, use neutral angle and flag as imputed
    imputed_angles[angle_name] = True
    return neutral_angles[angle_name]


def get_reba_tables():
    """
    Returns the REBA scoring tables (A, B, and C).

    Values follow the published REBA worksheet (Hignett & McAtamney, 2000).
    All lookups are 0-based: index = score - 1.

    Returns:
        dict with
          "table_a": table_a[neck][trunk][legs]        (3 x 5 x 4)
          "table_b": table_b[upper_arm][lower_arm][wrist] (6 x 2 x 3)
          "table_c": table_c[score_a][score_b]          (12 x 12)
        Fresh lists are built on every call, so callers may mutate them safely.
    """
    # TABLE A: Neck, Trunk, and Legs scores
    table_a = [
        # Neck = 1
        [
            [1, 2, 3, 4],  # Trunk = 1
            [2, 3, 4, 5],  # Trunk = 2
            [2, 4, 5, 6],  # Trunk = 3
            [3, 5, 6, 7],  # Trunk = 4
            [4, 6, 7, 8],  # Trunk = 5
        ],
        # Neck = 2
        [
            [1, 2, 3, 4],  # Trunk = 1
            [3, 4, 5, 6],  # Trunk = 2
            [4, 5, 6, 7],  # Trunk = 3
            [5, 6, 7, 8],  # Trunk = 4
            [6, 7, 8, 9],  # Trunk = 5
        ],
        # Neck = 3
        [
            [3, 3, 5, 6],  # Trunk = 1
            [4, 5, 6, 7],  # Trunk = 2
            [5, 6, 7, 8],  # Trunk = 3
            [6, 7, 8, 9],  # Trunk = 4
            [7, 8, 9, 9],  # Trunk = 5
        ],
    ]

    # TABLE B: Upper Arm, Lower Arm, and Wrist scores
    table_b = [
        # Upper Arm = 1
        [
            [1, 2, 2],  # Lower Arm = 1
            [1, 2, 3],  # Lower Arm = 2
        ],
        # Upper Arm = 2
        [
            [1, 2, 3],  # Lower Arm = 1
            [2, 3, 4],  # Lower Arm = 2
        ],
        # Upper Arm = 3
        [
            [3, 4, 5],  # Lower Arm = 1
            [4, 5, 5],  # Lower Arm = 2
        ],
        # Upper Arm = 4
        [
            [4, 5, 5],  # Lower Arm = 1
            [5, 6, 7],  # Lower Arm = 2
        ],
        # Upper Arm = 5
        [
            [6, 7, 8],  # Lower Arm = 1
            [7, 8, 8],  # Lower Arm = 2
        ],
        # Upper Arm = 6
        [
            [7, 8, 8],  # Lower Arm = 1
            [8, 9, 9],  # Lower Arm = 2
        ],
    ]

    # TABLE C: Score A (rows) and Score B (columns) combination
    table_c = [
        [1, 1, 1, 2, 3, 3, 4, 5, 6, 7, 7, 7],  # Score A = 1
        [1, 2, 2, 3, 4, 4, 5, 6, 6, 7, 7, 8],  # Score A = 2
        [2, 3, 3, 3, 4, 5, 6, 7, 7, 8, 8, 8],  # Score A = 3
        [3, 4, 4, 4, 5, 6, 7, 8, 8, 9, 9, 9],  # Score A = 4
        [4, 4, 4, 5, 6, 7, 8, 8, 9, 9, 9, 9],  # Score A = 5
        [6, 6, 6, 7, 8, 8, 9, 9, 10, 10, 10, 10],  # Score A = 6
        [7, 7, 7, 8, 9, 9, 9, 10, 10, 11, 11, 11],  # Score A = 7
        [8, 8, 8, 9, 10, 10, 10, 10, 10, 11, 11, 11],  # Score A = 8
        [9, 9, 9, 10, 10, 10, 11, 11, 11, 12, 12, 12],  # Score A = 9
        [10, 10, 10, 11, 11, 11, 11, 12, 12, 12, 12, 12],  # Score A = 10
        [11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12],  # Score A = 11
        [12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12],  # Score A = 12
    ]

    return {
        "table_a": table_a,
        "table_b": table_b,
        "table_c": table_c
    }

def calculate_final_reba_score(neck_score, trunk_score, leg_score, upper_arm_score, lower_arm_score,
                               wrist_score=1, load_score=0, coupling_score=0, activity_score=0,
                               is_sitting=True, return_details=False):
    """
    Combine REBA body-part scores into a final score via Tables A, B and C.

    Before each lookup, scores are clamped into the table's range: neck 1-3,
    trunk 1-5, legs 1-4 (Table A); upper arm 1-6, lower arm 1-2, wrist 1-3
    (Table B); Score A and Score B 1-12 (Table C).

    Args:
        neck_score, trunk_score, leg_score: Table A inputs (integers).
        upper_arm_score, lower_arm_score: Table B inputs (integers).
        wrist_score: Table B wrist input (default 1, neutral).
        load_score: added to the Table A result to form Score A (default 0).
        coupling_score: added to the Table B result to form Score B (default 0).
        activity_score: added to the Table C result (default 0).
        is_sitting: when True, leg_score is raised to at least 1 (default True).
        return_details: also return a dict of intermediate scores (default False).

    Returns:
        (final_score, risk_level), or (final_score, risk_level, intermediate_scores)
        when return_details is True. risk_level is "Negligible" (<=1), "Low" (<=3),
        "Medium" (<=7), "High" (<=10) or "Very High".
    """
    def to_index(score, highest_index):
        # 1-based score -> 0-based table index, clamped to the table's bounds.
        return min(max(score - 1, 0), highest_index)

    if is_sitting:
        # Legs are supported when sitting, so the leg score is at least 1.
        leg_score = max(1, leg_score)

    tables = get_reba_tables()

    neck_idx = to_index(neck_score, 2)
    trunk_idx = to_index(trunk_score, 4)
    leg_idx = to_index(leg_score, 3)
    score_a = tables["table_a"][neck_idx][trunk_idx][leg_idx] + load_score

    upper_arm_idx = to_index(upper_arm_score, 5)
    lower_arm_idx = to_index(lower_arm_score, 1)
    wrist_idx = to_index(wrist_score, 2)
    score_b = tables["table_b"][upper_arm_idx][lower_arm_idx][wrist_idx] + coupling_score

    score_c = tables["table_c"][to_index(score_a, 11)][to_index(score_b, 11)]
    final_score = score_c + activity_score

    risk_bands = ((1, "Negligible"), (3, "Low"), (7, "Medium"), (10, "High"))
    risk_level = next((label for limit, label in risk_bands if final_score <= limit), "Very High")

    if not return_details:
        return final_score, risk_level

    intermediate_scores = {
        "score_a": score_a,
        "score_b": score_b,
        "score_c": score_c,
        "neck_score": neck_score,
        "trunk_score": trunk_score,
        "leg_score": leg_score,
        "upper_arm_score": upper_arm_score,
        "lower_arm_score": lower_arm_score,
        "wrist_score": wrist_score,
        "load_score": load_score,
        "coupling_score": coupling_score,
        "activity_score": activity_score
    }
    return final_score, risk_level, intermediate_scores


def get_joint_angles_and_labels(keypoints_with_scores, keypoint_threshold=KEYPOINT_THRESHOLD, is_sitting=True):
    """
    Calculate joint angles and REBA scores from pose keypoints.

    Args:
        keypoints_with_scores: Output from the MoveNet model. Only
            ``[0, 0, :, :]`` is read: column 0 is used as y, column 1 as x and
            column 2 as the per-keypoint confidence score.
        keypoint_threshold: Confidence threshold above which a keypoint is
            treated as valid.
        is_sitting: Forwarded to ``calculate_final_reba_score`` (previously
            hard-coded to True, which remains the default).

    Returns:
        tuple: ``(result, final_score)``. ``result`` is a dict holding every
        angle, the REBA component scores, the REBA table A/B/C scores, the
        grand total, the risk level, per-angle imputation flags and the
        validated keypoints. ``(None, None)`` is returned when the neck or
        waist angle could not be measured.

    Notes:
        Temporal smoothing state lives on the function object as
        ``get_joint_angles_and_labels.smoothers`` and persists between calls.
        Deleting that attribute resets it; it is recreated on the next call.
    """
    keypoints = keypoints_with_scores[0, 0, :, :2]
    scores = keypoints_with_scores[0, 0, :, 2]

    # One smoother per smoothed angle key produced below. The earlier keys
    # 'trunk', 'upper_arm' and 'lower_arm' never matched any key in `angles`,
    # so the waist and arm angles were silently left unsmoothed. Missing keys
    # are filled in lazily so a stale dict from an older definition (e.g. in a
    # long-running kernel) is upgraded in place.
    smoothers = getattr(get_joint_angles_and_labels, 'smoothers', None)
    if smoothers is None:
        smoothers = {}
        get_joint_angles_and_labels.smoothers = smoothers
    for smoothed_name in ('left_leg', 'right_leg', 'neck', 'waist',
                          'left_upper_arm', 'right_upper_arm',
                          'left_lower_arm', 'right_lower_arm'):
        if smoothed_name not in smoothers:
            smoothers[smoothed_name] = AngleSmoother()

    # True means the angle could not be measured and a neutral placeholder is used.
    imputed_angles = {
        'left_leg': False,
        'right_leg': False,
        'neck': False,
        'waist': False,
        'left_upper_arm': False,
        'right_upper_arm': False,
        'left_lower_arm': False,
        'right_lower_arm': False
    }

    # Placeholder values used when an angle cannot be measured.
    neutral_angles = {
        'left_leg': 100,
        'right_leg': 100,
        'left_upper_arm': 0,
        'right_upper_arm': 0,
        'left_lower_arm': 90,
        'right_lower_arm': 90,
        'waist': 110,
        'neck': 170
    }

    # Keypoints below the confidence threshold get x/y = None and valid = False.
    validated_keypoints = {}
    for name, idx in KEYPOINT_DICT.items():
        is_valid = scores[idx] > keypoint_threshold
        validated_keypoints[name] = {
            'x': keypoints[idx][1] if is_valid else None,
            'y': keypoints[idx][0] if is_valid else None,
            'valid': is_valid
        }

    angles = {}

    # ---- Waist angle (signed by forward/backward lean) ----
    shoulder_left = get_keypoint_if_valid(validated_keypoints, 'left_shoulder')
    shoulder_right = get_keypoint_if_valid(validated_keypoints, 'right_shoulder')
    hip_left = get_keypoint_if_valid(validated_keypoints, 'left_hip')
    hip_right = get_keypoint_if_valid(validated_keypoints, 'right_hip')

    # Explicit None checks (as in the neck block below): truth-testing a
    # point would be ambiguous if get_keypoint_if_valid returned an array.
    if all(point is not None for point in (shoulder_left, shoulder_right, hip_left, hip_right)):
        # NOTE(review): this is the angle between the left->right shoulder line
        # and the left->right hip line (a relative twist of the two lines), with
        # the sign taken from vertical position; confirm it is the intended
        # trunk measure for TrunkREBA.
        shoulder_vec = np.array([shoulder_left[0] - shoulder_right[0],
                                 shoulder_left[1] - shoulder_right[1]])
        hip_vec = np.array([hip_left[0] - hip_right[0],
                            hip_left[1] - hip_right[1]])

        dot_product = np.dot(shoulder_vec, hip_vec)
        shoulder_mag = np.linalg.norm(shoulder_vec)
        hip_mag = np.linalg.norm(hip_vec)

        if shoulder_mag > 0 and hip_mag > 0:
            cos_angle = dot_product / (shoulder_mag * hip_mag)
            cos_angle = np.clip(cos_angle, -1.0, 1.0)  # guard arccos against rounding
            unsigned_angle = np.degrees(np.arccos(cos_angle))

            shoulder_center_y = (shoulder_left[1] + shoulder_right[1]) / 2
            hip_center_y = (hip_left[1] + hip_right[1]) / 2

            # Forward lean: shoulders lower than hips in image coordinates.
            if shoulder_center_y > hip_center_y:
                angles['waist'] = unsigned_angle
                angles['waist_direction'] = "forward"
            else:
                angles['waist'] = -unsigned_angle
                angles['waist_direction'] = "backward"

            imputed_angles['waist'] = False
        else:
            # Degenerate (zero-length) shoulder or hip line.
            angles['waist'] = neutral_angles['waist']
            imputed_angles['waist'] = True
    else:
        angles['waist'] = neutral_angles['waist']
        imputed_angles['waist'] = True

    # ---- Neck angle (ear relative to mid-shoulder) ----
    ear_point = get_keypoint_if_valid(validated_keypoints, 'left_ear')
    if ear_point is None:
        ear_point = get_keypoint_if_valid(validated_keypoints, 'right_ear')

    if ear_point is not None and shoulder_left is not None and shoulder_right is not None:
        mid_shoulder = ((shoulder_left[0] + shoulder_right[0]) / 2,
                        (shoulder_left[1] + shoulder_right[1]) / 2)

        # NOTE(review): index 1 is used as y in the waist block above, so index 0
        # is x and this reference point is offset horizontally, not vertically.
        # Left unchanged because NeckREBA thresholds and the neutral value (170)
        # are presumably calibrated on this convention; confirm.
        vertical_point = (mid_shoulder[0] - 1, mid_shoulder[1])

        try:
            angles['neck'] = calculate_angle(ear_point, mid_shoulder, vertical_point)
            imputed_angles['neck'] = False
        except Exception:
            angles['neck'] = neutral_angles['neck']
            imputed_angles['neck'] = True
    else:
        angles['neck'] = neutral_angles['neck']
        imputed_angles['neck'] = True

    # ---- Limb angles: (first point, vertex, last point) ----
    angle_mapping = {
        'left_upper_arm': ('left_hip', 'left_shoulder', 'left_elbow'),
        'right_upper_arm': ('right_hip', 'right_shoulder', 'right_elbow'),
        'left_lower_arm': ('left_shoulder', 'left_elbow', 'left_wrist'),
        'right_lower_arm': ('right_shoulder', 'right_elbow', 'right_wrist'),
        'left_leg': ('left_hip', 'left_knee', 'left_ankle'),
        'right_leg': ('right_hip', 'right_knee', 'right_ankle')
    }

    for angle_name, points in angle_mapping.items():
        angles[angle_name] = calculate_angle_with_fallback(
            points[0], points[1], points[2], angle_name,
            validated_keypoints, imputed_angles, neutral_angles)

    # A frame is scored only when both neck and waist were actually measured.
    # This check runs before smoothing so that skipped frames (whose neck and
    # waist hold neutral placeholders) do not feed placeholders into the
    # smoothers' history.
    if imputed_angles['neck'] or imputed_angles['waist']:
        missing_angles = [name for name in ('neck', 'waist') if imputed_angles[name]]
        print(f"⚠ Skipping frame - Missing: {', '.join(missing_angles)}")
        return None, None

    for angle_name in angles:
        if angle_name in smoothers:
            angles[angle_name] = smoothers[angle_name].smooth(angles[angle_name])

    # ---- Individual REBA component scores ----
    upper_arm_reba = UpperArmREBA([angles['left_upper_arm'], angles['right_upper_arm']]).upper_arm_score()[0]
    lower_arm_reba = LAREBA([angles['left_lower_arm'], angles['right_lower_arm']]).lower_arm_score()[0]
    neck_reba = NeckREBA([angles['neck'], 0, 0]).neck_reba_score()[0]
    trunk_reba = TrunkREBA([angles['waist'], 0, 0]).trunk_reba_score()[0]
    leg_reba = LegREBA([angles['left_leg'], angles['right_leg']]).leg_reba_score()[0]

    final_score, risk_level, intermediate_scores = calculate_final_reba_score(
        neck_reba, trunk_reba, leg_reba, upper_arm_reba, lower_arm_reba,
        is_sitting=is_sitting,
        return_details=True
    )

    result = {
        **angles,
        'upper_arm_reba': upper_arm_reba,
        'lower_arm_reba': lower_arm_reba,
        'neck_reba': neck_reba,
        'trunk_reba': trunk_reba,
        'leg_reba': leg_reba,
        'reba_table_a_score': intermediate_scores['score_a'],
        'reba_table_b_score': intermediate_scores['score_b'],
        'reba_table_c_score': intermediate_scores['score_c'],
        'reba_grand_total': final_score,
        'reba_risk_level': risk_level,
        'imputed_angles': imputed_angles,
        'validated_keypoints': validated_keypoints
    }

    return result, final_score

## Save Data To CSV

In [9]:
def get_csv_fieldnames():
    """Build the ordered CSV header.

    The header consists of the fixed per-frame columns (identifiers, joint
    angles, REBA component scores, REBA table scores, imputation flags),
    followed by an ``<name> X`` / ``<name> Y`` column pair for every entry
    of ``KEYPOINT_DICT`` in dictionary order.
    """
    identity_cols = ['File Name', 'Frame']
    angle_cols = [
        'Neck Angle', 'Left Upper Arm Angle', 'Right Upper Arm Angle',
        'Left Lower Arm Angle', 'Right Lower Arm Angle',
        'Waist Angle', 'Left Leg Angle', 'Right Leg Angle',
    ]
    component_cols = [
        'Upper Arm REBA', 'Lower Arm REBA',
        'Neck REBA', 'Trunk REBA', 'Leg REBA',
    ]
    table_cols = [
        'REBA Table A Score', 'REBA Table B Score', 'REBA Table C Score',
        'REBA Grand Total', 'REBA Risk Level',
    ]
    imputation_cols = [
        'Neck Imputed', 'Left Arm Imputed', 'Right Arm Imputed',
        'Left Elbow Imputed', 'Right Elbow Imputed',
        'Waist Imputed', 'Left Leg Imputed', 'Right Leg Imputed',
    ]
    coordinate_cols = [
        f'{kp_name} {axis}'
        for kp_name in KEYPOINT_DICT
        for axis in ('X', 'Y')
    ]
    return (identity_cols + angle_cols + component_cols
            + table_cols + imputation_cols + coordinate_cols)


def create_csv_row(angles, filename, frame_num):
    """Flatten one frame's result dict into a CSV row keyed by header name.

    Missing numeric fields fall back to -1 and a missing risk level falls
    back to "Unknown". Imputation flags are written as 0/1. Keypoint
    coordinates are copied from ``validated_keypoints`` (None when absent).
    """
    flags = angles.get('imputed_angles', {})
    keypoint_lookup = angles.get('validated_keypoints', {})

    row = {'File Name': filename, 'Frame': frame_num}

    # (CSV column, key in `angles`) for every numeric field; -1 marks "absent".
    numeric_columns = (
        ('Neck Angle', 'neck'),
        ('Left Upper Arm Angle', 'left_upper_arm'),
        ('Right Upper Arm Angle', 'right_upper_arm'),
        ('Left Lower Arm Angle', 'left_lower_arm'),
        ('Right Lower Arm Angle', 'right_lower_arm'),
        ('Waist Angle', 'waist'),
        ('Left Leg Angle', 'left_leg'),
        ('Right Leg Angle', 'right_leg'),
        ('Upper Arm REBA', 'upper_arm_reba'),
        ('Lower Arm REBA', 'lower_arm_reba'),
        ('Neck REBA', 'neck_reba'),
        ('Trunk REBA', 'trunk_reba'),
        ('Leg REBA', 'leg_reba'),
        ('REBA Table A Score', 'reba_table_a_score'),
        ('REBA Table B Score', 'reba_table_b_score'),
        ('REBA Table C Score', 'reba_table_c_score'),
        ('REBA Grand Total', 'reba_grand_total'),
    )
    for column, key in numeric_columns:
        row[column] = angles.get(key, -1)

    row['REBA Risk Level'] = angles.get('reba_risk_level', "Unknown")

    # (CSV column, key in `imputed_angles`) for the 0/1 imputation flags.
    flag_columns = (
        ('Neck Imputed', 'neck'),
        ('Left Arm Imputed', 'left_upper_arm'),
        ('Right Arm Imputed', 'right_upper_arm'),
        ('Left Elbow Imputed', 'left_lower_arm'),
        ('Right Elbow Imputed', 'right_lower_arm'),
        ('Waist Imputed', 'waist'),
        ('Left Leg Imputed', 'left_leg'),
        ('Right Leg Imputed', 'right_leg'),
    )
    for column, key in flag_columns:
        row[column] = int(flags.get(key, False))

    for kp_name in KEYPOINT_DICT.keys():
        point = keypoint_lookup.get(kp_name, {'x': None, 'y': None})
        row[f'{kp_name} X'] = point.get('x', None)
        row[f'{kp_name} Y'] = point.get('y', None)

    return row


def save_angles_to_csv(angles_list, output_csv, filename=None, frame_num=0, is_new_file=False):
    """Write result dicts to ``output_csv``, one row each.

    The file is appended to when it already has content and ``is_new_file``
    is False. Otherwise it is (re)created and a header row is written first.
    Row ``i`` gets frame number ``frame_num + i`` and the file name
    ``filename``. When ``filename`` is falsy, the dict's own ``'filename'``
    entry is used, or ``'unknown'`` if that is missing too.
    """
    has_content = os.path.exists(output_csv) and os.path.getsize(output_csv) > 0
    # Truncating and writing a header always go together.
    start_fresh = is_new_file or not has_content
    mode = 'w' if start_fresh else 'a'

    with open(output_csv, mode, newline='') as handle:
        writer = csv.DictWriter(handle, fieldnames=get_csv_fieldnames())
        if start_fresh:
            writer.writeheader()

        for offset, angles in enumerate(angles_list):
            row_name = filename or angles.get('filename', 'unknown')
            writer.writerow(create_csv_row(angles, row_name, frame_num + offset))

## Processing Media

In [10]:
def process_media(input_path, output_folder, output_csv, frame_interval=3, batch_size=32, checkpoint_file='checkpoint.txt'):
    """
    Run pose estimation and REBA scoring on one media file or on every
    supported file in a directory.

    Args:
        input_path: Image/video file, or a directory containing such files.
        output_folder: Folder for visualization images (created if missing).
        output_csv: CSV file that per-frame rows are written/appended to.
        frame_interval: Process every nth frame of videos (images always use 1).
        batch_size: Number of video frames read per batch.
        checkpoint_file: Resume checkpoint used by video processing.

    In directory mode, videos write their rows while they are processed, and
    all image rows are appended once the whole directory has been handled.
    Single-file mode delegates entirely to ``process_single_media``.
    """
    os.makedirs(output_folder, exist_ok=True)

    # Directory Processing Mode
    if os.path.isdir(input_path):
        print(f"\n📂 Processing directory: {input_path}")
        image_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.gif')
        video_exts = ('.mp4', '.avi', '.mov', '.mkv')
        files = sorted(f for f in os.listdir(input_path) if f.lower().endswith(image_exts + video_exts))

        # (file name, angles) for images; videos return None and write their own rows.
        image_results = []
        total_files = len(files)
        for i, fname in enumerate(files, 1):
            full_path = os.path.join(input_path, fname)
            print(f"\n[{i}/{total_files}] Processing: {fname}")

            # Use frame_interval=1 for images, keep specified interval for videos
            current_interval = 1 if fname.lower().endswith(image_exts) else frame_interval

            angles = process_single_media(
                full_path,
                output_folder,
                output_csv,
                current_interval,
                batch_size,
                checkpoint_file,
                save_to_csv=False
            )

            if angles is not None:
                image_results.append((fname, angles))

        if image_results:
            # Always append: videos processed above may already have written rows
            # (and the header) to output_csv during this run. Passing
            # is_new_file=True here would reopen the file in 'w' mode and wipe
            # them. save_angles_to_csv still writes a header itself when the file
            # is missing or empty. Each image is its own file, so it is saved as
            # frame 0, matching single-image mode.
            for fname, angles in image_results:
                save_angles_to_csv([angles], output_csv, filename=fname,
                                   frame_num=0, is_new_file=False)
            print(f"\n✓ Saved data for {len(image_results)} files to {output_csv}")
        return

    # Single File Processing
    process_single_media(
        input_path,
        output_folder,
        output_csv,
        frame_interval,
        batch_size,
        checkpoint_file,
        save_to_csv=True
    )

In [11]:
def process_single_media(input_path, output_folder, output_csv, frame_interval, batch_size=32, checkpoint_file='checkpoint.txt', save_to_csv=False):
    """
    Dispatch one media file to the image or video pipeline by its extension.

    Args:
        input_path: Path to the input media file (image or video).
        output_folder: Folder to save the visualization results.
        output_csv: Path to save CSV data.
        frame_interval: Process every nth frame for videos.
        batch_size: Number of frames to process in a batch.
        checkpoint_file: File to store processing checkpoint for resuming.
        save_to_csv: Whether to save single image results to CSV.

    Returns:
        Whatever the selected pipeline returns (the angle dict for images,
        None for videos), or None for an unsupported extension.
    """
    filename = os.path.basename(input_path)
    print(f"\nStarting processing for: {filename}")

    # Resume information is gathered up front, whatever the file type.
    is_new_csv, resume_frame, existing_frames = check_resume_conditions(
        filename, output_csv, checkpoint_file)

    extension = os.path.splitext(input_path)[1].lower()
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.gif')
    video_extensions = ('.mp4', '.avi', '.mov', '.mkv')

    if extension in image_extensions:
        return process_single_image(
            input_path, filename, output_folder, output_csv,
            is_new_csv, save_to_csv)

    if extension in video_extensions:
        # `apply_rotation` is not a parameter here; it is resolved from the
        # enclosing (module) scope and handed through explicitly.
        return process_video(
            input_path, filename, output_folder, output_csv,
            frame_interval, batch_size, checkpoint_file,
            is_new_csv, resume_frame, existing_frames, apply_rotation)

    print("⚠ Unsupported file type")
    return None


def check_resume_conditions(filename, output_csv, checkpoint_file):
    """
    Decide whether ``filename`` should resume from a previous run, and where.

    Args:
        filename: Base name of the media file. It is matched against the
            checkpoint and the CSV 'File Name' column.
        output_csv: CSV written by previous runs.
        checkpoint_file: Text file holding ``"<file name>,<frame>"``.

    Returns:
        tuple: ``(is_new_csv, resume_frame, existing_frames)``.
        ``is_new_csv`` is True when the CSV is missing or empty, meaning a
        header must be written. ``resume_frame`` is the frame to continue
        from: the checkpoint's value if it names this file, else one past
        the last frame recorded in the CSV, else 0. ``existing_frames`` is
        the set of frame numbers already recorded for this file.
    """
    # An existing-but-empty CSV has no header yet, so it must count as new
    # (this matches the emptiness check used by process_media and
    # save_angles_to_csv). Previously such a file was appended to headerless.
    csv_has_content = os.path.exists(output_csv) and os.path.getsize(output_csv) > 0
    is_new_csv = not csv_has_content
    resume_frame = 0
    existing_frames = set()

    # Checkpoint first: it records where the last interrupted run stopped.
    if os.path.exists(checkpoint_file):
        try:
            with open(checkpoint_file, 'r') as f:
                content = f.read().strip()
            if content:
                # rsplit on the last comma: file names may contain commas.
                last_file, last_frame = content.rsplit(',', 1)
                if last_file == filename:
                    resume_frame = int(last_frame)
                    print(f"↻ Resuming {filename} from frame {resume_frame} (checkpoint)")
        except Exception as e:
            print("⚠ Checkpoint file error:", e)

    # Skip empty files: pd.read_csv raises EmptyDataError on them.
    if csv_has_content:
        try:
            df_existing = pd.read_csv(output_csv)
            if not df_existing.empty and 'File Name' in df_existing.columns:
                file_records = df_existing[df_existing['File Name'] == filename]
                if not file_records.empty:
                    existing_frames = {int(frame) for frame in file_records['Frame']}
                    print(f"ℹ Found {len(existing_frames)} existing frames in CSV")

                    # If no checkpoint but CSV has data, resume from last CSV frame + 1
                    if resume_frame == 0 and len(existing_frames) > 0:
                        resume_frame = max(existing_frames) + 1
                        print(f"↻ Resuming {filename} from frame {resume_frame} (CSV)")
        except Exception as e:
            print("⚠ CSV read error:", e)

    return is_new_csv, resume_frame, existing_frames


def process_single_image(input_path, filename, output_folder, output_csv, is_new_csv, save_to_csv):
    """
    Run pose estimation on one image, save a visualization and compute angles.

    Args:
        input_path: Path of the image file.
        filename: Base name used for messages, the output image name and the
            'filename' entry of the returned dict.
        output_folder: Folder where ``<stem>_pose.png`` is written.
        output_csv: CSV to write to when ``save_to_csv`` is True.
        is_new_csv: Forwarded to ``save_angles_to_csv`` as ``is_new_file``.
        save_to_csv: Whether to write the result row (as frame 0) immediately.

    Returns:
        dict | None: The result dict from ``get_joint_angles_and_labels`` with
        an added ``'filename'`` key. None when the image cannot be read, has
        too few keypoints, or any step raises (errors are printed and
        swallowed so directory runs continue).
    """
    try:
        frame = cv2.imread(input_path)
        if frame is None:
            print(f"⚠ Error: Could not read image {filename}")
            return None

        # OpenCV loads BGR; the model and drawing code work in RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # First inference pass decides whether the image should be mirrored.
        keypoints, processed_img, original_img, _ = run_inference(
            movenet, frame, init_crop_region(frame.shape[0], frame.shape[1]),
            crop_size=[input_size, input_size])

        flip_required = should_flip_image(keypoints)
        print(f"  Orientation: {'FLIPPED' if flip_required else 'NORMAL'}")

        # Re-run inference on the mirrored image so keypoints match it.
        if flip_required:
            flipped_frame = cv2.flip(frame, 1)
            keypoints, processed_img, original_img, _ = run_inference(
                movenet, flipped_frame, init_crop_region(flipped_frame.shape[0], flipped_frame.shape[1]),
                crop_size=[input_size, input_size])

        vis_frame = draw_prediction_on_image(processed_img, keypoints, original_img, flip_required)
        vis_frame = cv2.cvtColor(vis_frame, cv2.COLOR_RGB2BGR)

        output_path = os.path.join(output_folder, f"{os.path.splitext(filename)[0]}_pose.png")
        cv2.imwrite(output_path, vis_frame)
        print(f"Saved visualization to {output_path}")

        # get_joint_angles_and_labels keeps temporal smoothers on its function
        # object. A still image is independent of whatever was processed before
        # (another image or a video), so drop that state; it is recreated
        # lazily on the next call. Otherwise this image's angles would be
        # averaged with unrelated earlier measurements.
        if hasattr(get_joint_angles_and_labels, 'smoothers'):
            del get_joint_angles_and_labels.smoothers

        angles, _ = get_joint_angles_and_labels(keypoints)

        if angles is None:
            print(f"⚠ Skipping image {filename} - insufficient keypoints")
            return None

        # Lets save_angles_to_csv recover the name when none is passed.
        angles['filename'] = filename

        if save_to_csv:
            save_angles_to_csv([angles], output_csv,
                               filename=filename, frame_num=0, is_new_file=is_new_csv)

        return angles

    except Exception as e:
        print(f"⚠ Error processing image {filename}: {e}")
        return None


def process_video(input_path, filename, output_folder, output_csv,
                  frame_interval, batch_size, checkpoint_file,
                  is_new_csv, resume_frame, existing_frames, apply_rotation):
    """
    Run pose estimation and REBA scoring over a video, streaming rows to CSV.

    Args:
        input_path: Path of the video file.
        filename: Base name written to the 'File Name' column and checkpoint.
        output_folder: Folder for visualization frames.
        output_csv: CSV to append to, or to create when ``is_new_csv``.
        frame_interval: Score only frames whose absolute index is a multiple
            of this value (values below 1 are treated as 1).
        batch_size: Number of frames read per batch.
        checkpoint_file: Rewritten after every batch with
            ``"<filename>,<next frame index>"`` and removed on completion.
        is_new_csv: Whether to truncate the CSV and write a header first.
        resume_frame: Absolute frame index to start from.
        existing_frames: Frame indices already present in the CSV (skipped).
        apply_rotation: Callable ``(frame, rotation) -> frame`` applied with
            the value returned by ``get_video_rotation``.

    Returns:
        None. Rows are written directly to ``output_csv``.
    """
    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        print(f"⚠ Error opening video {filename}")
        return None

    # try/finally releases the capture on every exit path, including
    # exceptions raised outside the per-frame try block (e.g. CSV write errors).
    try:
        frame_interval = max(1, frame_interval)  # avoid modulo-by-zero
        resume_frame = max(0, resume_frame)

        rotation = get_video_rotation(input_path)
        print(f"Rotation detected: {rotation}°")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames / fps if fps > 0 else 0

        print(f"  Total frames: {total_frames} (~{duration:.1f} seconds)")
        print(f"  Processing every {frame_interval} frames")

        # The first frame decides the orientation used for the whole video.
        ret, first_frame = cap.read()
        if not ret:
            print(f"⚠ Error reading first frame of {filename}")
            return None

        first_frame = apply_rotation(first_frame, rotation)
        first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

        kpts, proc_img, orig_img, flip_required = run_inference(
            movenet, first_frame_rgb,
            init_crop_region(first_frame_rgb.shape[0], first_frame_rgb.shape[1]),
            crop_size=[input_size, input_size])

        print(f"  Orientation: {'FLIPPED' if flip_required else 'NORMAL'}")

        vis_frame = draw_prediction_on_image(proc_img, kpts, orig_img, flip_required)
        vis_frame = cv2.cvtColor(vis_frame, cv2.COLOR_RGB2BGR)
        save_visualized_frame(vis_frame, output_folder, 0, flip_required)

        # Start each video with fresh temporal smoothing state instead of
        # inheriting the previous file's history (it is recreated lazily).
        if hasattr(get_joint_angles_and_labels, 'smoothers'):
            del get_joint_angles_and_labels.smoothers

        with open(output_csv, 'w' if is_new_csv else 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=get_csv_fieldnames())
            if is_new_csv:
                writer.writeheader()

            # Reading the first frame above advanced the capture by one, so
            # always seek (also to 0) to make `frame_count` match the real
            # frame index. Previously a fresh run labelled frame 1 as frame 0,
            # every later index was off by one, and frame 0 was never scored.
            cap.set(cv2.CAP_PROP_POS_FRAMES, resume_frame)
            if resume_frame > 0:
                print(f"  Seeking to frame {resume_frame}/{total_frames}")

            frame_count = resume_frame
            processed_count = 0

            while cap.isOpened():
                frames = []
                for _ in range(batch_size):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frames.append(frame)

                if not frames:
                    break

                for i, frame in enumerate(frames):
                    current_frame = frame_count + i
                    if current_frame % frame_interval == 0 and current_frame not in existing_frames:
                        try:
                            frame = apply_rotation(frame, rotation)
                            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                            # Keep orientation consistent with the first frame.
                            if flip_required:
                                frame_rgb = cv2.flip(frame_rgb, 1)

                            kpts, proc_img, orig_img, _ = run_inference(
                                movenet, frame_rgb,
                                init_crop_region(frame_rgb.shape[0], frame_rgb.shape[1]),
                                crop_size=[input_size, input_size])

                            angles, _ = get_joint_angles_and_labels(kpts)
                            if angles is None:
                                print(f"⚠ Skipping frame {current_frame} - insufficient keypoints")
                                continue

                            processed_count += 1
                            writer.writerow(create_csv_row(angles, filename, current_frame))

                            # Periodically save visualizations
                            if current_frame % 1000 == 0:
                                vis_frame = draw_prediction_on_image(
                                    proc_img, kpts, orig_img, flip_required)
                                vis_frame = cv2.cvtColor(vis_frame, cv2.COLOR_RGB2BGR)
                                save_visualized_frame(vis_frame, output_folder, current_frame, flip_required)

                        except Exception as e:
                            # Best effort: one bad frame must not abort the video.
                            print(f"⚠ Frame {current_frame} error: {str(e)[:100]}")

                frame_count += len(frames)
                # CAP_PROP_FRAME_COUNT can report 0 for some containers/streams.
                if total_frames > 0:
                    progress = (frame_count / total_frames) * 100
                    print(f"  Progress: {frame_count}/{total_frames} ({progress:.1f}%) - Processed: {processed_count}", end='\r')
                else:
                    print(f"  Progress: {frame_count} frames - Processed: {processed_count}", end='\r')

                # Next frame to read, so an interrupted run resumes here.
                with open(checkpoint_file, 'w') as f:
                    f.write(f"{filename},{frame_count}\n")

        print(f"\n✓ Completed {filename} - Processed {processed_count} keyframes")
        try:
            if os.path.exists(checkpoint_file):
                os.remove(checkpoint_file)
        except Exception as e:
            print(f"⚠ Failed to remove checkpoint: {e}")
        return None
    finally:
        cap.release()

In [None]:
# Run the whole pipeline on 'Video' (a directory of media, or a single file):
# visualizations go to 'rev' and per-frame rows to 'Datarevisi.csv'.
# frame_interval, batch_size and checkpoint_file keep their defaults.
process_media(
    input_path='Video',
    output_folder='rev',
    output_csv='Datarevisi.csv',
)