# COCO Pose Estimation Quality Metric

**Evaluating ControlNet-Generated Poses using COCO Dataset Keypoints**

This notebook compares poses from ControlNet-generated images against ground truth COCO keypoints using:
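1. **COCO keypoint annotations** (`person_keypoints_val2017.json`) to extract ground truth keypoints
2. **MMPose** with a top-down HRNet-W48 model (run on person boxes from a Faster R-CNN detector) to extract keypoints from generated images
3. **OKS (Object Keypoint Similarity)** metric for comparison
4. Additional metrics: mAP@OKS (here, the share of images whose best OKS reaches a threshold) and PCK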
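
For reference, here is OKS as defined in the COCO keypoint evaluation, together with PCK as computed by `compute_pck` further down. The symbols are: $d_i$, the distance between the predicted and ground-truth keypoint $i$; $v_i$, its COCO visibility flag; $s^2$, the object scale; $\sigma_i$, the per-keypoint COCO constant; $D$, the diagonal of the box spanned by the labeled ground-truth keypoints; and $\alpha = 0.2$.

$$
\mathrm{OKS} = \frac{\sum_i \exp\left(-\frac{d_i^2}{2 s^2 k_i^2}\right)\,\delta(v_i > 0)}{\sum_i \delta(v_i > 0)}, \qquad k_i = 2\sigma_i
$$

$$
\mathrm{PCK}_\alpha = \frac{\sum_i \mathbb{1}\left[d_i \le \alpha D\right]\,\delta(v_i > 0)}{\sum_i \delta(v_i > 0)}
$$

COCO takes $s^2$ from the annotation's segmentation area. This notebook approximates it with the ground-truth bounding-box area $w \cdot h$.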

**Approach:**
- Take the first 75 entries of val_captions.json and recover their COCO image IDs
- Extract COCO ground truth keypoints
- Run MMPose on generated images
- Compare using COCO evaluation metrics

## 1. Install Dependencies

In [None]:
# Install basic dependencies first
%pip install -q numpy pandas matplotlib seaborn scipy
%pip install -q opencv-python
%pip install -q pycocotools

import sys
print(f"Python version: {sys.version}")

# For Python >= 3.11 (Colab), use OpenMMLab v2 stack with openmim
print("Installing OpenMMLab v2 stack (mmengine + mmcv 2.1.0 + mmdet 3.3.0 + mmpose 1.3.2)...")

# Clean up any old installations
print("Cleaning up previous installations...")
%pip uninstall -y mmcv mmcv-full mmengine mmdet mmpose openmim 2>/dev/null || true

# Upgrade pip
print("Upgrading pip...")
%pip install -U pip setuptools wheel

# Install openmim (OpenMMLab package installer - handles platform-specific wheels)
print("Installing openmim...")
%pip install -q -U openmim

# Install mmengine (required base for all v2 packages)
print("Installing mmengine...")
%pip install -q mmengine

# Use mim to install mmcv (automatically finds compatible prebuilt wheels)
print("Installing mmcv 2.1.0 via mim...")
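import subprocess
try:
    result = subprocess.run(
        ['mim', 'install', 'mmcv==2.1.0'],
        capture_output=True,
        text=True
    )
    print(result.stdout)
    mim_ok = result.returncode == 0
    if not mim_ok:
        print(f"Warning: mim install had issues: {result.stderr}")
except FileNotFoundError:
    # The `mim` executable is not on PATH (e.g., it was installed into another environment)
    mim_ok = False
    print("Warning: `mim` executable not found on PATH")

if not mim_ok:
    print("Trying direct pip install as fallback (this may build mmcv from source and take a while)...")
    %pip install -q "mmcv==2.1.0"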

# Install MMDetection 3.3.0
print("Installing MMDetection 3.3.0...")
%pip install -q "mmdet==3.3.0"

# Install MMPose 1.3.2 dependencies first
print("Installing MMPose dependencies...")
%pip install -q json_tricks munkres pillow

# Install MMPose 1.3.2 with --no-deps to skip problematic builds
print("Installing MMPose 1.3.2...")
%pip install -q --no-deps "mmpose==1.3.2"

print("✓ Installation complete (OpenMMLab v2)!")
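
The install cell pins mmcv 2.1.0, mmdet 3.3.0 and mmpose 1.3.2, but `mim` or the pip fallback can still resolve a different build. The sketch below compares installed versions against those pins using `importlib.metadata`. That module reads package metadata without importing the packages, so a broken compiled mmcv extension cannot hide which version is installed. It assumes the PyPI distribution names match the import names. `PINS` and `check_pins` are illustrative names, not part of OpenMMLab.

```python
from importlib.metadata import PackageNotFoundError, version

# Illustrative copy of the pins used by the install cell; keep in sync manually
PINS = {"mmcv": "2.1.0", "mmdet": "3.3.0", "mmpose": "1.3.2"}


def check_pins(pins, get_version=version):
    """Return {package: (expected, installed or None)} for every pin that does not match."""
    mismatches = {}
    for pkg, expected in pins.items():
        try:
            installed = get_version(pkg)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != expected:
            mismatches[pkg] = (expected, installed)
    return mismatches
```

Usage: `print(check_pins(PINS) or "All pins satisfied")`. The `get_version` parameter exists only so the check can be exercised without the real packages installed.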

## 2. Verify Installation

In [None]:
# Verify installation
import sys
import torch
print(f"Python: {sys.version}")
print(f"Torch:   {torch.__version__}, CUDA available: {torch.cuda.is_available()}")

try:
    import mmengine
    print(f"MMEngine: {mmengine.__version__}")
except Exception as e:
    print(f"MMEngine import failed: {type(e).__name__}: {e}")

try:
    import mmcv
    print(f"MMCV:     {mmcv.__version__}")
except Exception as e:
    print(f"MMCV import failed: {type(e).__name__}: {e}")

try:
    import mmdet
    print(f"MMDet:    {mmdet.__version__}")
except Exception as e:
    print(f"MMDet import failed: {type(e).__name__}: {e}")

try:
    import mmpose
    print(f"MMPose:   {mmpose.__version__}")
    print("✅ All packages loaded successfully (v2 API)")
except Exception as e:
    print(f"MMPose import failed: {type(e).__name__}: {e}")

## 3. Imports & Configuration

In [None]:
# Imports and configuration
import json
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Dict, Tuple
import pandas as pd
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

# Try v1 APIs first; fall back to v2 inferencers on ImportError
USE_INFERENCER = False
try:
    from mmdet.apis import init_detector, inference_detector
    from mmpose.apis import inference_top_down_pose_model, init_pose_model
    print("✓ Using OpenMMLab v1 APIs (init_* / inference_*).")
except Exception:
    USE_INFERENCER = True
    from mmdet.apis import DetInferencer
    from mmpose.apis import MMPoseInferencer
    print("✓ Using OpenMMLab v2 Inferencers (DetInferencer / MMPoseInferencer).")

# Paths and model URLs
VAL_CAPTIONS_PATH = "./val_captions.json"
GENERATED_IMAGES_PATH = "path/to/generated/images"  # e.g., "./generated_images"
NUM_IMAGES = 75

# v1 configs/checkpoints (used when USE_INFERENCER=False)
MMPOSE_CONFIG_URL = "https://raw.githubusercontent.com/open-mmlab/mmpose/master/configs/body/2d_kpt_sview_rgb_img/topdown_heatmap/coco/hrnet_w48_coco_256x192.py"
MMPOSE_CHECKPOINT_URL = "https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_256x192-b9e0b3ab_20200708.pth"

DETECTOR_CONFIG_URL = "https://raw.githubusercontent.com/open-mmlab/mmdetection/v2.28.2/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py"
DETECTOR_CHECKPOINT_URL = "https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth"
DETECTOR_SCORE_THR = 0.4  # keep person boxes above this score (applied only in the v1 path)

# v2 Inferencer model aliases (used when USE_INFERENCER=True)
# Pose model: top-down HRNet-W48 on COCO 256x192
POSE2D_ALIAS = "td-hm_hrnet-w48_8xb32-210e_coco-256x192"
# Detector model: Faster R-CNN R50-FPN on COCO
DET_MODEL_ALIAS = "mmdet::faster-rcnn_r50_fpn_1x_coco"

COCO_KEYPOINT_NAMES = [
    'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
    'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
]

SKELETON = [
    (0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6),
    (5, 7), (7, 9), (6, 8), (8, 10), (5, 11), (6, 12),
    (11, 13), (13, 15), (12, 14), (14, 16)
]

print("✓ Imports & config loaded")
print(f"  Val Captions: {VAL_CAPTIONS_PATH}")
print(f"  Generated Images: {GENERATED_IMAGES_PATH}")
print(f"  API Mode: {'Inferencer(v2)' if USE_INFERENCER else 'Classic(v1)'}")
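
The cell above chooses between the v1 and v2 APIs by trying the v1 imports and falling back on any exception. A more explicit check is possible if all you need to know is whether a module exposes a given name. `has_api`, a hypothetical helper sketched below, imports the module with `importlib.import_module` and tests for the attribute. Like the cell above, it treats any import-time error as "not available".

```python
import importlib


def has_api(module_name: str, attr: str) -> bool:
    """Return True if `module_name` imports cleanly and exposes `attr`."""
    try:
        module = importlib.import_module(module_name)
    except Exception:
        # Missing module or a failing compiled extension: treat as unavailable
        return False
    return hasattr(module, attr)
```

On an MMPose 1.x install, `has_api("mmpose.apis", "MMPoseInferencer")` is expected to be True. `has_api("mmpose.apis", "init_pose_model")` should be True only on the 0.x line.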

## 4. Load COCO Dataset & Extract Ground Truth Keypoints (On-the-fly)

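No local COCO image dataset is needed. On the first run, the cell below downloads the official 2017 annotation archive (`annotations_trainval2017.zip`, about 250 MB) from the COCO servers. It then extracts `person_keypoints_val2017.json` and caches it under `~/.cache/coco_annotations`.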
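
The cell below converts `val_captions.json` keys such as `000000480936.jpg` into integer IDs with `int(basename)`. That conversion fails on prefixed names like the COCO 2014 style `COCO_val2014_000000480936.jpg`. If your captions file uses names like that, one option is to take the last run of digits in the file stem. `coco_id_from_filename` below is a hypothetical helper sketching this approach, not part of the COCO tooling.

```python
import re
from typing import Optional


def coco_id_from_filename(filename: str) -> Optional[int]:
    """Extract the numeric COCO image id from names like '000000480936.jpg'.

    Uses the last run of digits in the stem, so 'COCO_val2014_000000480936.jpg'
    also yields 480936. Returns None if the stem contains no digits.
    """
    stem = filename.rsplit('.', 1)[0]
    digit_runs = re.findall(r"\d+", stem)
    return int(digit_runs[-1]) if digit_runs else None  # int() drops leading zeros
```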

In [None]:
# Load val_captions.json
print("Loading val_captions.json...")
with open(VAL_CAPTIONS_PATH, 'r') as f:
    val_captions_raw = json.load(f)

# Handle different JSON structures
image_ids = []

if isinstance(val_captions_raw, dict):
    # Check if it's a filename->caption dictionary (e.g., {"000000480936.jpg": "caption text"})
    sample_key = next(iter(val_captions_raw.keys())) if val_captions_raw else None
    
    if sample_key and isinstance(sample_key, str) and sample_key.endswith(('.jpg', '.jpeg', '.png')):
        # Format: {"filename.jpg": "caption"}
        print("Detected filename->caption dictionary format")
        
        # Extract image IDs from filenames
        for filename in list(val_captions_raw.keys())[:NUM_IMAGES]:
            # Extract numeric ID from filename like "000000480936.jpg"
            # Remove extension and any leading zeros
            basename = filename.rsplit('.', 1)[0]  # "000000480936"
            try:
                img_id = int(basename)  # Convert to integer (removes leading zeros)
                image_ids.append(img_id)
            except ValueError:
                print(f"  Warning: Could not extract ID from filename: {filename}")
    
    elif 'annotations' in val_captions_raw:
        val_captions = val_captions_raw['annotations']
        for cap in val_captions[:NUM_IMAGES]:
            if isinstance(cap, dict):
                img_id = cap.get('image_id') or cap.get('id') or cap.get('image')
                if img_id and img_id not in image_ids:
                    image_ids.append(img_id)
    
    elif 'images' in val_captions_raw:
        val_captions = val_captions_raw['images']
        for cap in val_captions[:NUM_IMAGES]:
            if isinstance(cap, dict):
                img_id = cap.get('image_id') or cap.get('id') or cap.get('image')
                if img_id and img_id not in image_ids:
                    image_ids.append(img_id)
    
    else:
        # Try to use the values if they look like a list
        vals = list(val_captions_raw.values())
        if vals and isinstance(vals[0], dict):
            for cap in vals[:NUM_IMAGES]:
                img_id = cap.get('image_id') or cap.get('id') or cap.get('image')
                if img_id and img_id not in image_ids:
                    image_ids.append(img_id)

elif isinstance(val_captions_raw, list):
    for cap in val_captions_raw[:NUM_IMAGES]:
        if isinstance(cap, dict):
            img_id = cap.get('image_id') or cap.get('id') or cap.get('image')
            if img_id and img_id not in image_ids:
                image_ids.append(img_id)

if not image_ids:
    raise ValueError("Could not extract image IDs from val_captions.json. Please check the file format.")

print(f"Top-level entries in val_captions.json: {len(val_captions_raw)}")
print(f"Extracted {len(image_ids)} unique image IDs from the first {NUM_IMAGES} entries")
print(f"Sample image IDs: {image_ids[:5]}")

# ============================================
# Download COCO annotations on-the-fly
# ============================================
print("\n" + "="*70)
print("Downloading COCO annotations on-the-fly...")
print("="*70)

import urllib.request
import tempfile
import zipfile
import os

# Download the COCO keypoint annotations (cached after the first run)
try:
    # Use a cache directory
    cache_dir = Path.home() / ".cache" / "coco_annotations"
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_file = cache_dir / "person_keypoints_val2017.json"

    if cache_file.exists():
        print(f"Using cached annotations from {cache_file}")
        with open(cache_file, 'r') as f:
            coco_annotations = json.load(f)
    else:
        # Download the zip file containing all 2017 train/val annotations
        zip_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
        zip_path = cache_dir / "annotations_trainval2017.zip"

        print("Downloading annotations_trainval2017.zip (~250MB, this may take a few minutes)...")

        # Progress callback for urlretrieve; totalsize is -1 when the server sends no Content-Length
        def download_progress(blocknum, blocksize, totalsize):
            downloaded = blocknum * blocksize
            if totalsize > 0:
                percent = min(downloaded * 100 / totalsize, 100)
                print(f"  Progress: {percent:.1f}% ({downloaded / 1024 / 1024:.1f}MB / {totalsize / 1024 / 1024:.1f}MB)", end='\r')
            else:
                print(f"  Downloaded: {downloaded / 1024 / 1024:.1f}MB", end='\r')

        urllib.request.urlretrieve(zip_url, zip_path, reporthook=download_progress)
        print("\n✓ Download complete!")

        # Extract only the JSON file we need
        print("Extracting person_keypoints_val2017.json...")
        target_file = "annotations/person_keypoints_val2017.json"
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extract(target_file, cache_dir)

        # Move it to the cache location and remove the now-empty folder
        extracted_path = cache_dir / target_file
        extracted_path.rename(cache_file)
        (cache_dir / "annotations").rmdir()

        # Remove the zip file to save space
        zip_path.unlink()
        print(f"✓ Extracted to {cache_file}")

        # Load the annotations
        with open(cache_file, 'r') as f:
            coco_annotations = json.load(f)

    print("✓ COCO annotations loaded successfully")

except Exception as e:
    print(f"✗ Error downloading COCO annotations: {e}")
    print("Make sure you have an internet connection")
    raise

# ============================================
# Extract ground truth keypoints
# ============================================
print("\nProcessing COCO annotations...")

# Build lookup: image_id -> annotations
coco_images = {img['id']: img for img in coco_annotations.get('images', [])}
coco_annotations_by_img = {}

for ann in coco_annotations.get('annotations', []):
    img_id = ann['image_id']
    if img_id not in coco_annotations_by_img:
        coco_annotations_by_img[img_id] = []
    coco_annotations_by_img[img_id].append(ann)

print(f"Total images in COCO: {len(coco_images)}")
print(f"Total annotations in COCO: {len(coco_annotations.get('annotations', []))}")

# Extract ground truth keypoints for our images
gt_keypoints_dict = {}  # {image_id: {ann_id: keypoints}}

print(f"\nExtracting ground truth keypoints for {len(image_ids)} images...")
for idx, img_id in enumerate(image_ids):
    if (idx + 1) % 10 == 0:
        print(f"  [{idx+1}/{len(image_ids)}]", end='\r')
    
    # Get annotations for this image
    anns = coco_annotations_by_img.get(img_id, [])
    
    if anns:
        gt_keypoints_dict[img_id] = {}
        for ann in anns:
            if 'keypoints' in ann:
                # COCO format: [x1, y1, v1, x2, y2, v2, ...]
                kpts = np.array(ann['keypoints']).reshape(17, 3)  # 17 keypoints, (x, y, v)
                gt_keypoints_dict[img_id][ann['id']] = {
                    'keypoints': kpts,
                    'bbox': ann.get('bbox'),
                    'area': ann.get('area'),
                    'iscrowd': ann.get('iscrowd', 0),
                    'category_id': ann.get('category_id')
                }

print(f"\n✓ Found ground truth for {len(gt_keypoints_dict)} images")
total_anns = sum(len(v) for v in gt_keypoints_dict.values())
print(f"  Total annotations: {total_anns}")
print(f"  Images without keypoint annotations: {len(image_ids) - len(gt_keypoints_dict)}")
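
In COCO keypoint annotations, each visibility flag `v` takes one of three values: 0 (not labeled), 1 (labeled but not visible), or 2 (labeled and visible). An image can also carry several person annotations, including crowd regions. Section 8 uses the first annotation of each image as the reference pose. If that choice proves too arbitrary for multi-person images, one alternative is to prefer the non-crowd annotation with the most labeled keypoints. This is a sketch of that option, not the method used below. `pick_reference_annotation` is a hypothetical helper that operates on the `gt_keypoints_dict[img_id]` entries built above.

```python
import numpy as np


def pick_reference_annotation(anns: dict):
    """Return (ann_id, info) for the non-crowd annotation with the most labeled keypoints.

    `anns` maps ann_id -> {'keypoints': (17, 3) array, 'iscrowd': 0 or 1, ...}.
    Returns None when every annotation is a crowd region or `anns` is empty.
    """
    best, best_count = None, -1
    for ann_id, info in anns.items():
        if info.get('iscrowd', 0):
            continue  # crowd regions have no usable individual pose
        labeled = int((np.asarray(info['keypoints'])[:, 2] > 0).sum())  # v = 1 or 2
        if labeled > best_count:
            best, best_count = (ann_id, info), labeled
    return best
```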

## 5. Load Detection and Pose Models

In [None]:
import torch
import urllib.request
from pathlib import Path
import subprocess
import sys

print("Loading models...")
print("This may take a moment...")

try:
    cache_dir = Path.home() / ".cache" / "mmpose_detector"
    cache_dir.mkdir(parents=True, exist_ok=True)

    if not USE_INFERENCER:
        # -------------------
        # v1: Download configs
        # -------------------
        detector_config_path = cache_dir / "faster_rcnn_r50_fpn_1x_coco.py"
        if not detector_config_path.exists():
            urllib.request.urlretrieve(DETECTOR_CONFIG_URL, detector_config_path)
        pose_config_path = cache_dir / "hrnet_w48_coco_256x192.py"
        if not pose_config_path.exists():
            urllib.request.urlretrieve(MMPOSE_CONFIG_URL, pose_config_path)

        # -------------------
        # v1: Download checkpoints
        # -------------------
        detector_ckpt_path = cache_dir / "faster_rcnn_r50_fpn_1x_coco.pth"
        if not detector_ckpt_path.exists():
            urllib.request.urlretrieve(DETECTOR_CHECKPOINT_URL, detector_ckpt_path)
        pose_ckpt_path = cache_dir / "hrnet_w48_coco_256x192.pth"
        if not pose_ckpt_path.exists():
            urllib.request.urlretrieve(MMPOSE_CHECKPOINT_URL, pose_ckpt_path)

        # -------------------
        # v1: Init detector + pose
        # -------------------
        detector_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        detector_model = init_detector(
            str(detector_config_path),
            str(detector_ckpt_path),
            device=detector_device
        )

        pose_device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        pose_model = init_pose_model(
            str(pose_config_path),
            str(pose_ckpt_path),
            device=pose_device
        )
        print("✓ v1 models loaded successfully")
        print(f"  Detector device: {detector_device}")
        print(f"  Pose device: {pose_device}")
    else:
        # -------------------
        # v2: Init inferencers
        # -------------------
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        pose_inferencer = MMPoseInferencer(
            pose2d=POSE2D_ALIAS,
            det_model=DET_MODEL_ALIAS,
            device=device
        )
        print("✓ v2 inferencers initialized successfully")
        print(f"  Device: {device}")

except Exception as e:
    print(f"✗ Error loading models/inferencers: {e}")
    import traceback
    traceback.print_exc()
    raise
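
According to the MMPose 1.x inferencer documentation, each item yielded by `MMPoseInferencer` is a dict. Its `'predictions'` entry holds one list of person instances per input image. Each instance carries `'keypoints'` (17 [x, y] pairs for COCO) and `'keypoint_scores'`. Because the structure is nested, it is easy to iterate over the wrong level. The sketch below assumes that layout. It stacks one image's instances into an (N, 17, 3) array and can be checked against a mock dict without loading a model. `instances_to_array` is a hypothetical name.

```python
import numpy as np


def instances_to_array(result: dict) -> np.ndarray:
    """Stack the person instances of one inferencer result into an (N, K, 3) array of x, y, score.

    Assumes result['predictions'] is a list with one entry per input image, each a list of
    dicts holding 'keypoints' (K [x, y] pairs) and optionally 'keypoint_scores' (K floats).
    Only the first image is used; returns shape (0, 17, 3) when no person is found.
    """
    preds = result.get('predictions', [])
    # Unwrap the per-image level; tolerate an already-flat list of instances
    instances = preds[0] if preds and isinstance(preds[0], list) else preds
    poses = []
    for inst in instances:
        xy = np.asarray(inst['keypoints'], dtype=float)  # (K, 2)
        scores = np.asarray(inst.get('keypoint_scores', np.ones(len(xy))), dtype=float)
        poses.append(np.column_stack([xy, scores]))  # (K, 3)
    return np.stack(poses) if poses else np.empty((0, 17, 3))
```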

## 7. Extract Keypoints from Generated Images

In [None]:
def detect_persons(image_path: str, detector_model, score_thr: float = DETECTOR_SCORE_THR):
    """Run person detector and return list of bboxes [x1, y1, x2, y2, score]. (v1 path)"""
    det_results = inference_detector(detector_model, image_path)
    # COCO class index 0 is person
    person_dets = det_results[0] if isinstance(det_results, (list, tuple)) else det_results
    if person_dets is None or len(person_dets) == 0:
        return []
    # Keep boxes above threshold
    keep = []
    for det in person_dets:
        if det[4] >= score_thr:
            keep.append(det)
    return keep


def _extract_keypoints_v1(image_path: str, detector_model, pose_model, score_thr: float = DETECTOR_SCORE_THR) -> List[np.ndarray]:
    """Detect persons (v1) then run pose on each detected bbox."""
    person_bboxes = detect_persons(image_path, detector_model, score_thr)
    if len(person_bboxes) == 0:
        return None

    # Prepare person_results for top-down pose (xyxy + score)
    person_results = []
    for det in person_bboxes:
        x1, y1, x2, y2, score = det
        person_results.append({'bbox': np.array([x1, y1, x2, y2, score])})

    pose_results, _ = inference_top_down_pose_model(
        pose_model,
        image_path,
        person_results,
        bbox_thr=0.0,
        format='xyxy',
        dataset='TopDownCocoDataset',
        return_heatmap=False
    )

    if pose_results is None or len(pose_results) == 0:
        return None

    detections = []
    for pr in pose_results:
        if 'keypoints' in pr:
            kpts = pr['keypoints']
            if kpts is not None and len(kpts) == 17:
                detections.append(kpts)

    return detections if len(detections) > 0 else None


def _extract_keypoints_v2(image_path: str) -> List[np.ndarray]:
    """Use MMPoseInferencer (v2) to get keypoints for all detected persons.

    Errors propagate to the caller, which records them as extraction failures.
    """
    # The inferencer returns a generator; take the first (and only) result
    result = next(pose_inferencer(image_path, return_vis=False), None)
    if result is None:
        return None

    # result['predictions'] holds one list of person instances per input image,
    # so a single image path yields a one-element outer list
    preds = result.get('predictions', [])
    instances = preds[0] if preds and isinstance(preds[0], list) else preds

    detections = []
    for inst in instances:
        # Expect 'keypoints' (17x2) and optional 'keypoint_scores' (17,)
        if 'keypoints' not in inst:
            continue
        pts = np.asarray(inst['keypoints'], dtype=float)
        if pts.shape != (17, 2):
            continue
        scores = np.asarray(inst.get('keypoint_scores', np.ones(17)), dtype=float).reshape(17, 1)
        detections.append(np.concatenate([pts, scores], axis=1))  # (17, 3): x, y, score
    return detections if len(detections) > 0 else None


def extract_keypoints_from_generated_image(image_path: str, detector_model=None, pose_model=None, score_thr: float = DETECTOR_SCORE_THR) -> List[np.ndarray]:
    """
    Unified extractor: if USE_INFERENCER=True (v2), use MMPoseInferencer.
    Otherwise (v1), run detector -> pose.
    Returns list of keypoint arrays (17,3) per detected person, or None.
    """
    if USE_INFERENCER:
        return _extract_keypoints_v2(image_path)
    else:
        return _extract_keypoints_v1(image_path, detector_model, pose_model, score_thr)


def compute_oks(gt_keypoints: np.ndarray, pred_keypoints: np.ndarray, bbox) -> float:
    """Compute Object Keypoint Similarity (OKS) between one GT pose and one predicted pose.

    Uses the COCO per-keypoint constants k_i = 2 * sigma_i. The scale s^2 is approximated
    by the GT bbox area (w * h); official COCO evaluation uses the segmentation area instead.
    bbox is [x, y, width, height].
    """
    sigmas = np.array([
        .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89
    ]) / 10.0
    variances = (2 * sigmas) ** 2  # k_i^2, matching pycocotools
    x, y, w, h = bbox
    scale = w * h  # s^2
    if scale <= 0:
        return 0.0
    dx = pred_keypoints[:, 0] - gt_keypoints[:, 0]
    dy = pred_keypoints[:, 1] - gt_keypoints[:, 1]
    d_squared = dx**2 + dy**2
    visible = gt_keypoints[:, 2] > 0
    if visible.sum() == 0:
        return 0.0
    oks_per_kpt = np.exp(-d_squared / (2 * scale * variances))
    oks = (oks_per_kpt * visible).sum() / visible.sum()
    return float(oks)


def compute_pck(gt_keypoints: np.ndarray, pred_keypoints: np.ndarray, threshold: float = 0.2) -> Dict:
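    """Compute PCK: the share of labeled GT keypoints whose prediction lies within
    threshold * D, where D is the diagonal of the box spanned by the labeled GT keypoints."""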
    visible_gt = gt_keypoints[gt_keypoints[:, 2] > 0]
    if len(visible_gt) == 0:
        return {'pck': 0.0, 'correct_keypoints': 0, 'visible_keypoints': 0}
    x_min, y_min = visible_gt[:, 0].min(), visible_gt[:, 1].min()
    x_max, y_max = visible_gt[:, 0].max(), visible_gt[:, 1].max()
    bbox_diagonal = np.sqrt((x_max - x_min)**2 + (y_max - y_min)**2)
    if bbox_diagonal == 0:
        bbox_diagonal = 1.0
    dx = pred_keypoints[:, 0] - gt_keypoints[:, 0]
    dy = pred_keypoints[:, 1] - gt_keypoints[:, 1]
    distances = np.sqrt(dx**2 + dy**2)
    visible = gt_keypoints[:, 2] > 0
    num_visible = visible.sum()
    if num_visible == 0:
        return {'pck': 0.0, 'correct_keypoints': 0, 'visible_keypoints': 0}
    correct = (distances <= threshold * bbox_diagonal) & visible
    num_correct = correct.sum()
    pck = num_correct / num_visible
    return {
        'pck': float(pck),
        'correct_keypoints': int(num_correct),
        'visible_keypoints': int(num_visible)
    }

print("✓ Helper functions defined (detector + pose, v1/v2 unified)")
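
Section 8 calls `compute_oks` once per detected person and keeps the maximum. The same comparison can be vectorized with NumPy broadcasting, scoring one ground-truth pose (17, 3) against N predictions (N, 17, 3) at once. This sketch follows the pycocotools convention k_i = 2σ_i and takes the scale s² as an explicit `area` argument. `oks_matrix` is an illustrative helper, not an MMPose or pycocotools function.

```python
import numpy as np

# COCO per-keypoint sigmas (nose, eyes, ears, shoulders, elbows, wrists, hips, knees, ankles)
COCO_SIGMAS = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
                        .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0


def oks_matrix(gt: np.ndarray, preds: np.ndarray, area: float) -> np.ndarray:
    """OKS of one GT pose (17, 3) against N predicted poses (N, 17, 3); returns shape (N,)."""
    gt = np.asarray(gt, dtype=float)
    preds = np.asarray(preds, dtype=float)
    visible = gt[:, 2] > 0  # keypoints labeled in the GT
    if area <= 0 or not visible.any():
        return np.zeros(len(preds))
    # Squared distances via broadcasting: (N, 17, 2) - (1, 17, 2) -> (N, 17)
    d2 = ((preds[:, :, :2] - gt[None, :, :2]) ** 2).sum(axis=-1)
    k2 = (2 * COCO_SIGMAS) ** 2  # k_i^2 with k_i = 2 * sigma_i
    per_kpt = np.exp(-d2 / (2 * area * k2))  # (N, 17)
    return per_kpt[:, visible].mean(axis=1)  # average over labeled keypoints only
```

Usage: `best = int(np.argmax(oks_matrix(gt_kpts, np.stack(pred_kpts_list), w * h)))` selects the best-matching person in one call.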

In [None]:
print("="*70)
print("EXTRACTING KEYPOINTS FROM GENERATED IMAGES")
print("="*70)

# Load all generated images from folder
gen_images_dir = Path(GENERATED_IMAGES_PATH)
if not gen_images_dir.exists():
    raise FileNotFoundError(f"Generated images directory not found: {GENERATED_IMAGES_PATH}")

# Collect all image files (.png/.jpg/.jpeg, with lower- or upper-case extensions)
image_extensions = ['.png', '.jpg', '.jpeg']
generated_image_files = []
for ext in image_extensions:
    generated_image_files.extend(gen_images_dir.glob(f'*{ext}'))
    generated_image_files.extend(gen_images_dir.glob(f'*{ext.upper()}'))

# Deduplicate
generated_image_files = list(dict.fromkeys(generated_image_files))
print(f"Found {len(generated_image_files)} generated images in {GENERATED_IMAGES_PATH}")
print(f"Need matches for {len(image_ids)} validation captions")

import re

# Helper: find the generated_* file whose name contains the image id as a whole number,
# plain or zero-padded (e.g., generated_480936.png or generated_000000480936.png).
# The digit boundaries stop id 123 from matching generated_1234.png.
def find_generated_image_for_id(img_id: int):
    id_pattern = re.compile(rf"(?<!\d)0*{img_id}(?!\d)")
    candidates = [
        f for f in generated_image_files
        if f.name.lower().startswith("generated_") and id_pattern.search(f.name)
    ]
    # None if nothing matches; otherwise the first match in sorted order
    return sorted(candidates)[0] if candidates else None

# Test extraction on first image to see full error
print("\nüîç Testing extraction on first image with detector -> pose ...")
test_img_id = image_ids[0]
test_img_path = find_generated_image_for_id(test_img_id)
if test_img_path:
    print(f"Test image: {test_img_path}")
    try:
        test_kpts_list = extract_keypoints_from_generated_image(
            str(test_img_path), detector_model, pose_model, score_thr=DETECTOR_SCORE_THR
        )
        if test_kpts_list:
            print(f"✓ Extracted {len(test_kpts_list)} person(s) with keypoints; first shape: {test_kpts_list[0].shape}")
        else:
            print("✗ Extraction returned None or empty list")
    except Exception as e:
        print(f"✗ Extraction error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f"✗ Could not find test image for ID {test_img_id}")

print("\n" + "="*70)

generated_keypoints_dict = {}  # {image_id: [keypoints_array_per_person]}
missing_images = []
extraction_failures = []

print("\nMatching generated images by filename substring (image id) ...\n")

for idx, img_id in enumerate(image_ids):
    if (idx + 1) % 10 == 0:
        print(f"  [{idx+1}/{len(image_ids)}] Processed", end='\r')
    
    gen_img_path = find_generated_image_for_id(img_id)
    
    if gen_img_path is None:
        missing_images.append(img_id)
        continue
    
    # Extract keypoints (all detected persons)
    try:
        kpts_list = extract_keypoints_from_generated_image(
            str(gen_img_path), detector_model, pose_model, score_thr=DETECTOR_SCORE_THR
        )
        
        if kpts_list:
            generated_keypoints_dict[img_id] = kpts_list
        else:
            extraction_failures.append((img_id, "Extraction returned None or empty"))
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}" if str(e) else type(e).__name__
        extraction_failures.append((img_id, error_msg))

print(f"\n\n✓ Processed {len(image_ids)} target images")
print(f"  Successfully extracted (>=1 person): {len(generated_keypoints_dict)}")
print(f"  Missing images: {len(missing_images)}")
print(f"  Extraction failures: {len(extraction_failures)}")

if missing_images[:5]:
    print(f"\n  First 5 missing image IDs: {missing_images[:5]}")
    print("  ⚠️ Ensure filenames include the image id, e.g., generated_<id>.png")

if extraction_failures[:5]:
    print(f"\n  First 5 extraction failures:")
    for img_id, reason in extraction_failures[:5]:
        print(f"    - Image {img_id}: {reason}")
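
The keypoints above are in the pixel grid of the generated image, while the COCO ground truth is in the pixel grid of the original photo. Centroid alignment in Section 8 corrects only a translation. If the generated images were produced at a different resolution, a rescale step before comparison may be needed. The sketch below assumes each generated image is a plain resize of the original frame, with no cropping or padding; the mapping is wrong if the generation pipeline crops or pads. `rescale_keypoints` is a hypothetical helper.

```python
import numpy as np


def rescale_keypoints(kpts, src_size, dst_size):
    """Map keypoints from an image of src_size=(width, height) to dst_size=(width, height).

    Assumes a plain resize (no crop, no padding). Column 2 (score or visibility) is unchanged.
    """
    src_w, src_h = src_size
    dst_w, dst_h = dst_size
    out = np.asarray(kpts, dtype=float).copy()
    out[:, 0] *= dst_w / src_w  # x scales with width
    out[:, 1] *= dst_h / src_h  # y scales with height
    return out
```

Usage: read the generated size with `h, w = cv2.imread(str(gen_img_path)).shape[:2]` and the original size from `coco_images[img_id]['width']` and `coco_images[img_id]['height']`. Then apply the function to each array in `generated_keypoints_dict[img_id]` before computing metrics.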

## 8. Compute OKS & Comparison Metrics

In [None]:
print("\n" + "="*70)
print("COMPUTING OKS & COMPARISON METRICS")
print("="*70)

# Validate that we have both GT and predictions
print(f"\nValidation:")
print(f"  Images with GT keypoints: {len(gt_keypoints_dict)}")
print(f"  Images with predicted keypoints (any person): {len(generated_keypoints_dict)}")

# Find intersection
common_images = set(gt_keypoints_dict.keys()) & set(generated_keypoints_dict.keys())
print(f"  Images with both GT and predictions: {len(common_images)}")

if len(common_images) == 0:
    print("\n✗ ERROR: No images have both GT and predicted keypoints!")
    print("Please check:")
    print("  1. Generated images are in GENERATED_IMAGES_PATH")
    print("  2. Filenames follow generated_<id>.<ext> (the id may be zero-padded)")
    print("  3. MMPose is detecting poses (see the extraction failures above)")
    raise ValueError("No valid image pairs for comparison")

print(f"\n✓ Ready to compare {len(common_images)} image-annotation pairs")

# Helper function to align keypoints by centroid
def align_keypoints(gt_kpts, pred_kpts):
    """
    Translate predicted keypoints so their centroid matches the ground-truth centroid.
    This accounts for the person appearing at a different location in the generated image.
    Both centroids are taken over the same keypoints (those labeled in the GT), so
    unlabeled GT joints do not bias the translation. Only translation is corrected:
    if the generated image has a different resolution than the COCO original,
    keypoints should be rescaled to the original image size first.

    Returns an aligned copy of the predicted keypoints.
    """
    labeled = gt_kpts[:, 2] > 0
    if not labeled.any():
        return pred_kpts

    # Centroids over the GT-labeled keypoints
    gt_centroid = gt_kpts[labeled, :2].mean(axis=0)
    pred_centroid = pred_kpts[labeled, :2].mean(axis=0)

    # Apply the translation to x and y of every predicted keypoint
    aligned_pred = np.array(pred_kpts, dtype=float, copy=True)
    aligned_pred[:, :2] += gt_centroid - pred_centroid

    return aligned_pred

results = []
best_generated_keypoints_dict_aligned = {}  # {image_id: best aligned kpts (17,3)}
best_generated_keypoints_dict_raw = {}       # {image_id: best raw kpts (17,3)}

total_skipped_low_vis = 0
comparison_count = 0

print("\nüìç Strategy: Compare ALL detected poses to GT, take MAX OKS")
print("   ALIGN keypoints by centroid (handles different person locations)")
print("   Using GT PERSON BBOX for OKS scale calculation\n")

for img_id in common_images:
    gt_data = gt_keypoints_dict[img_id]

    if len(gt_data) == 0:
        continue

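    # Reference pose: the image's first annotation (in multi-person images it may not be the most complete)
    ann_id = list(gt_data.keys())[0]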
    gt_info = gt_data[ann_id]
    gt_kpts = gt_info['keypoints']  # (17, 3)
    gt_bbox = gt_info['bbox']  # [x, y, width, height]

    # Skip if labeled keypoints < 10
    visible_count = int((gt_kpts[:, 2] > 0).sum())
    if visible_count < 10:
        total_skipped_low_vis += 1
        continue

    # Get predictions from generated image (detector + pose pipeline from Section 7)
    if img_id not in generated_keypoints_dict:
        continue

    pred_kpts_list = generated_keypoints_dict[img_id]

    if not pred_kpts_list or len(pred_kpts_list) == 0:
        continue

    # ============================================================
    # Compare ALL detected poses, take MAX OKS
    # Align keypoints to the GT centroid before comparison
    # ============================================================
    max_oks = -1.0  # below any valid OKS, so a pose is always kept even when its OKS is 0
    best_pred_aligned = None
    best_pred_raw = None
    best_pck_result = None

    for pred_kpts in pred_kpts_list:
        # ALIGN predicted keypoints to GT centroid
        aligned_pred = align_keypoints(gt_kpts, pred_kpts)

        # Compute OKS using GT bbox and ALIGNED keypoints
        oks = compute_oks(gt_kpts, aligned_pred, gt_bbox)

        # Keep track of best match
        if oks > max_oks:
            max_oks = oks
            best_pred_aligned = aligned_pred
            best_pred_raw = pred_kpts
            best_pck_result = compute_pck(gt_kpts, aligned_pred)

    if best_pred_aligned is None:
        continue

    results.append({
        'image_id': img_id,
        'annotation_id': ann_id,
        'oks': max_oks,
        'pck': best_pck_result['pck'],
        'correct_keypoints': best_pck_result['correct_keypoints'],
        'visible_keypoints': best_pck_result['visible_keypoints'],
        'num_detected_poses': len(pred_kpts_list)
    })

    best_generated_keypoints_dict_aligned[img_id] = best_pred_aligned
    best_generated_keypoints_dict_raw[img_id] = best_pred_raw
    comparison_count += 1

print(f"\nCompared {comparison_count} image-annotation pairs")
print(f"Skipped (labeled keypoints < 10): {total_skipped_low_vis}")

# Convert to DataFrame for analysis
results_df = pd.DataFrame(results)
if results_df.empty:
    raise ValueError("No image passed the filters above; cannot compute metrics")

print("\n" + "="*70)
print("üìä POSE PRESERVATION RESULTS")
print("   (MAX OKS across all detected poses, with centroid alignment)")
print("="*70)
print(f"\nOKS (Object Keypoint Similarity):")
print(f"  Mean:   {results_df['oks'].mean():.4f}")
print(f"  Median: {results_df['oks'].median():.4f}")
print(f"  Std:    {results_df['oks'].std():.4f}")
print(f"  Min:    {results_df['oks'].min():.4f}")
print(f"  Max:    {results_df['oks'].max():.4f}")

print(f"\nPCK (Percentage of Correct Keypoints @ 0.2 threshold):")
print(f"  Mean:   {results_df['pck'].mean():.4f}")
print(f"  Median: {results_df['pck'].median():.4f}")
print(f"  Std:    {results_df['pck'].std():.4f}")

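# mAP@OKS here = share of images whose best OKS reaches the threshold
# (a per-image acceptance rate, not COCO's precision/recall-based AP)
print("\nmAP@OKS Thresholds (share of images with OKS >= threshold):")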
for oks_threshold in [0.5, 0.75, 0.9]:
    mAP = (results_df['oks'] >= oks_threshold).mean()
    print(f"  mAP@OKS={oks_threshold:.2f}: {mAP:.4f}")

# Detection statistics
print(f"\nDetection Statistics:")
print(f"  Avg detected poses per image: {results_df['num_detected_poses'].mean():.2f}")
print(f"  Max detected poses: {results_df['num_detected_poses'].max()}")
print(f"  Images with multiple poses: {(results_df['num_detected_poses'] > 1).sum()}")

print("\n" + "="*70)
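
The three thresholds above give a coarse view of the results. COCO's summary metrics average over the OKS grid 0.50, 0.55, ..., 0.95. The sketch below computes the same per-image acceptance rate at each grid point and then averages the rates. That average is a simplified stand-in for COCO's AP averaging and is not the same computation, because no precision/recall curve is built here. `acceptance_curve` is a hypothetical helper.

```python
import numpy as np


def acceptance_curve(oks_scores, thresholds=None):
    """Share of images whose best OKS reaches each threshold, plus the mean over thresholds."""
    if thresholds is None:
        thresholds = np.linspace(0.5, 0.95, 10)  # 0.50, 0.55, ..., 0.95
    scores = np.asarray(oks_scores, dtype=float)
    rates = np.array([(scores >= t).mean() for t in thresholds])
    return thresholds, rates, float(rates.mean())
```

Usage: `thresholds, rates, mean_rate = acceptance_curve(results_df['oks'])`.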

## 9. Visualize OKS Distribution

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# OKS histogram
axes[0, 0].hist(results_df['oks'], bins=30, color='steelblue', edgecolor='black', alpha=0.7)
axes[0, 0].axvline(results_df['oks'].mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {results_df["oks"].mean():.3f}')
axes[0, 0].axvline(results_df['oks'].median(), color='green', linestyle='--', linewidth=2, label=f'Median: {results_df["oks"].median():.3f}')
axes[0, 0].set_xlabel('OKS Score', fontsize=12)
axes[0, 0].set_ylabel('Frequency', fontsize=12)
axes[0, 0].set_title('OKS Distribution', fontsize=14, fontweight='bold')
axes[0, 0].legend()
axes[0, 0].grid(alpha=0.3)

# PCK histogram
axes[0, 1].hist(results_df['pck'], bins=30, color='coral', edgecolor='black', alpha=0.7)
axes[0, 1].axvline(results_df['pck'].mean(), color='red', linestyle='--', linewidth=2, label=f'Mean: {results_df["pck"].mean():.3f}')
axes[0, 1].set_xlabel('PCK Score', fontsize=12)
axes[0, 1].set_ylabel('Frequency', fontsize=12)
axes[0, 1].set_title('PCK Distribution (@ 0.2 threshold)', fontsize=14, fontweight='bold')
axes[0, 1].legend()
axes[0, 1].grid(alpha=0.3)

# mAP@OKS thresholds
oks_thresholds = [0.5, 0.75, 0.9, 0.95]
mAPs = [(results_df['oks'] >= t).mean() for t in oks_thresholds]
axes[1, 0].bar([f'{t:.2f}' for t in oks_thresholds], mAPs, color='mediumseagreen', edgecolor='black', alpha=0.7)
axes[1, 0].set_xlabel('OKS Threshold', fontsize=12)
axes[1, 0].set_ylabel('Fraction of images (OKS ≥ threshold)', fontsize=12)
axes[1, 0].set_title('mAP @ Different OKS Thresholds', fontsize=14, fontweight='bold')
axes[1, 0].set_ylim([0, 1])
for i, v in enumerate(mAPs):
    axes[1, 0].text(i, v + 0.02, f'{v:.3f}', ha='center', fontweight='bold')
axes[1, 0].grid(axis='y', alpha=0.3)

# Box plot comparison
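axes[1, 1].boxplot([results_df['oks'], results_df['pck']], patch_artist=True, widths=0.6)
# Set tick labels separately: boxplot's `labels` kwarg was renamed `tick_labels` in Matplotlib 3.9
axes[1, 1].set_xticks([1, 2], ['OKS', 'PCK'])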
axes[1, 1].set_ylabel('Score', fontsize=12)
axes[1, 1].set_title('OKS vs PCK Score Distribution', fontsize=14, fontweight='bold')
axes[1, 1].grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

## 10. Detailed Statistics by Keypoint Type

In [None]:
print("\n" + "="*70)
print("KEYPOINT-LEVEL ANALYSIS")
print("="*70)

# Compute per-keypoint distances using ALIGNED predictions (same as metrics)
keypoint_errors = defaultdict(list)

for img_id in best_generated_keypoints_dict_aligned.keys():
    if img_id not in gt_keypoints_dict:
        continue

    pred_kpts = best_generated_keypoints_dict_aligned[img_id]
    gt_data = gt_keypoints_dict[img_id]

    if len(gt_data) == 0:
        continue

    # Uses the first GT annotation stored for this image
    ann_id = next(iter(gt_data))
    gt_kpts = gt_data[ann_id]['keypoints']

    for kpt_idx in range(17):
        if gt_kpts[kpt_idx, 2] > 0:  # Labeled keypoint (v=1: labeled but occluded, v=2: visible)
            dist = np.hypot(pred_kpts[kpt_idx, 0] - gt_kpts[kpt_idx, 0],
                            pred_kpts[kpt_idx, 1] - gt_kpts[kpt_idx, 1])
            keypoint_errors[COCO_KEYPOINT_NAMES[kpt_idx]].append(dist)

# Print per-keypoint statistics
print("\nPer-Keypoint Euclidean Distance Statistics:")
print(f"{'Keypoint':<20} {'Mean':<10} {'Std':<10} {'Min':<10} {'Max':<10}")
print("-" * 60)

keypoint_stats = []
for kpt_name in COCO_KEYPOINT_NAMES:
    if kpt_name in keypoint_errors and len(keypoint_errors[kpt_name]) > 0:
        errors = np.array(keypoint_errors[kpt_name])
        mean_err = errors.mean()
        std_err = errors.std()
        min_err = errors.min()
        max_err = errors.max()

        print(f"{kpt_name:<20} {mean_err:<10.2f} {std_err:<10.2f} {min_err:<10.2f} {max_err:<10.2f}")

        keypoint_stats.append({
            'keypoint': kpt_name,
            'mean_distance': mean_err,
            'std_distance': std_err,
            'min_distance': min_err,
            'max_distance': max_err,
            'num_samples': len(errors)
        })

keypoint_stats_df = pd.DataFrame(keypoint_stats)

# Plot per-keypoint errors
plt.figure(figsize=(14, 6))
plt.bar(keypoint_stats_df['keypoint'], keypoint_stats_df['mean_distance'],
        color='steelblue', edgecolor='black', alpha=0.7)
plt.errorbar(keypoint_stats_df['keypoint'], keypoint_stats_df['mean_distance'],
             yerr=keypoint_stats_df['std_distance'], fmt='none', color='red', capsize=5, alpha=0.5)
plt.xlabel('Keypoint', fontsize=12)
plt.ylabel('Mean Euclidean Distance (pixels)', fontsize=12)
plt.title('Per-Keypoint Distance from Ground Truth (aligned predictions)', fontsize=14, fontweight='bold')
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()


## 11. Visualization: Pose Comparison

In [None]:
def draw_skeleton(image, keypoints, skeleton, title=""):
    """Return a copy of `image` with keypoints (conf > 0.3, labeled by index) and skeleton limbs drawn."""
    image = image.copy()
    # Draw keypoints
    for kpt_idx, (x, y, conf) in enumerate(keypoints):
        if conf > 0.3:  # Only draw if confidence > 0.3
            cv2.circle(image, (int(x), int(y)), 5, (0, 255, 0), -1)
            cv2.putText(image, str(kpt_idx), (int(x) + 5, int(y) - 5),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

    # Draw skeleton connections
    for start, end in skeleton:
        if keypoints[start, 2] > 0.3 and keypoints[end, 2] > 0.3:
            x1, y1, _ = keypoints[start]
            x2, y2, _ = keypoints[end]
            cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)

    return image


import re
from typing import Optional


def find_generated_image_for_id(img_id: int) -> Optional[Path]:
    """Locate the generated image for a COCO image id in GENERATED_IMAGES_PATH.

    Expects file names of the form ``generated_<id>...`` where <id> may be
    zero-padded (e.g. ``generated_139.png`` or ``generated_000000000139.png``).
    The id is parsed as an integer and compared exactly, so 139 does not
    match ``generated_1390.png``.
    """
    gen_dir = Path(GENERATED_IMAGES_PATH)
    if not gen_dir.exists():
        return None
    image_extensions = {'.png', '.jpg', '.jpeg'}
    candidates = []
    for f in gen_dir.glob("generated_*"):
        if f.suffix.lower() not in image_extensions:
            continue  # skip non-image files
        match = re.match(r"generated_(\d+)", f.name)
        if match and int(match.group(1)) == img_id:
            candidates.append(f)
    # Deterministic choice if several files match (e.g. different extensions)
    return sorted(candidates)[0] if candidates else None


# Show a few examples (samples with highest and lowest OKS)
print("Visualizing samples with highest and lowest OKS scores...\n")

if results_df.empty:
    print("✗ No results to visualize. Run metric computation first.")
else:
    best_indices = results_df.nlargest(3, 'oks').index
    worst_indices = results_df.nsmallest(3, 'oks').index

    fig, axes = plt.subplots(3, 2, figsize=(16, 14))
    fig.suptitle('Pose Estimation Quality: Best vs Worst Examples', fontsize=16, fontweight='bold')

    for row, (best_idx, worst_idx) in enumerate(zip(best_indices, worst_indices)):
        # nlargest/nsmallest return index labels, so look rows up with .loc
        best_result = results_df.loc[best_idx]
        worst_result = results_df.loc[worst_idx]

        best_img_id = int(best_result['image_id'])
        worst_img_id = int(worst_result['image_id'])

        # Best example from generated image
        best_img_path = find_generated_image_for_id(best_img_id)
        if best_img_path and best_img_id in best_generated_keypoints_dict:
            img = cv2.imread(str(best_img_path))
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            pred_kpts = best_generated_keypoints_dict[best_img_id]
            img_with_pose = draw_skeleton(img_rgb, pred_kpts, SKELETON)
            axes[row, 0].imshow(img_with_pose)
            axes[row, 0].set_title(f'Best (OKS={best_result["oks"]:.3f})',
                                  fontsize=12, fontweight='bold', color='green')
        else:
            axes[row, 0].text(0.5, 0.5, 'Image not found', ha='center', va='center')
            axes[row, 0].set_title('Best example missing', fontsize=12, color='red')
        axes[row, 0].axis('off')

        # Worst example from generated image
        worst_img_path = find_generated_image_for_id(worst_img_id)
        if worst_img_path and worst_img_id in best_generated_keypoints_dict:
            img = cv2.imread(str(worst_img_path))
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            pred_kpts = best_generated_keypoints_dict[worst_img_id]
            img_with_pose = draw_skeleton(img_rgb, pred_kpts, SKELETON)
            axes[row, 1].imshow(img_with_pose)
            axes[row, 1].set_title(f'Worst (OKS={worst_result["oks"]:.3f})',
                                  fontsize=12, fontweight='bold', color='red')
        else:
            axes[row, 1].text(0.5, 0.5, 'Image not found', ha='center', va='center')
            axes[row, 1].set_title('Worst example missing', fontsize=12, color='red')
        axes[row, 1].axis('off')

    plt.tight_layout()
    plt.show()


## 12. Summary & Interpretation

### Metrics Explanation:

**OKS (Object Keypoint Similarity)** - Primary Metric
- Measures how similar predicted poses are to ground truth
- Range: 0 to 1 (higher is better)
- Standard COCO evaluation metric
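- Formula: $\text{OKS} = \dfrac{\sum_i \exp\left(-d_i^2 / (2 s^2 k_i^2)\right)\,\delta(v_i > 0)}{\sum_i \delta(v_i > 0)}$, where $d_i$ is the distance between predicted and ground-truth keypoint $i$, $s^2$ is the object area, $k_i$ is a per-keypoint constant, and $v_i$ is the ground-truth visibility flag
- Accounts for keypoint distance, object scale, per-keypoint difficulty, and visibility (only labeled keypoints count)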
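The formula above can be sketched in NumPy as follows, following the conventions of pycocotools' `COCOeval`: $s^2$ is the ground-truth object area and $k_i = 2\sigma_i$, with the standard COCO per-keypoint sigmas. `compute_oks` is a hypothetical helper, not the notebook's own implementation, and it simply returns 0 when no keypoint is labeled.

```python
import numpy as np

# Per-keypoint sigmas for the 17 COCO keypoints, as hard-coded in pycocotools' COCOeval
COCO_SIGMAS = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72,
                        .62, .62, 1.07, 1.07, .87, .87, .89, .89]) / 10.0


def compute_oks(pred_xy: np.ndarray, gt_kpts: np.ndarray, area: float) -> float:
    """OKS between one predicted pose and one ground-truth pose (hypothetical helper).

    pred_xy: (17, 2) predicted (x, y) coordinates.
    gt_kpts: (17, 3) COCO ground truth (x, y, v).
    area:    ground-truth object area, used as s^2.
    """
    labeled = gt_kpts[:, 2] > 0                            # delta(v_i > 0)
    if not labeled.any():
        return 0.0                                         # simplification: COCOeval handles this case differently
    k = 2.0 * COCO_SIGMAS                                  # k_i = 2 * sigma_i
    d2 = np.sum((pred_xy - gt_kpts[:, :2]) ** 2, axis=1)   # d_i^2
    e = d2 / (2.0 * area * k ** 2)                         # d_i^2 / (2 s^2 k_i^2)
    return float(np.exp(-e[labeled]).mean())               # average over labeled keypoints only
```

Because $k_i$ is small for the face keypoints and large for hips and knees, the same pixel error costs much more OKS on an eye than on a hip.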

**mAP@OKS** - Secondary Metric
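- Fraction of images whose best OKS is ≥ the threshold (a simplified proxy: COCO's AP ranks detections by score and averages precision over recall)
- Reported here at OKS = 0.5, 0.75, and 0.9; COCO itself reports AP@0.5, AP@0.75, and AP averaged over thresholds 0.50:0.05:0.95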

**PCK (Percentage of Correct Keypoints)**
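- Fraction of labeled keypoints whose prediction lies within a distance threshold of the ground truth
- Threshold often set to 0.2 × bounding box diagonal
- Coarser than OKS: each keypoint is a binary hit or miss, and the threshold is the same for every keypoint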
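A sketch of PCK under the 0.2 × bounding-box-diagonal convention mentioned above, assuming a COCO-style `[x, y, width, height]` box for the ground-truth person. `compute_pck` is hypothetical; the notebook's own PCK code is not part of this section and may normalize differently (for example by torso size).

```python
import numpy as np


def compute_pck(pred_xy: np.ndarray, gt_kpts: np.ndarray, bbox, alpha: float = 0.2) -> float:
    """Fraction of labeled GT keypoints within alpha * bbox diagonal of the prediction.

    pred_xy: (17, 2) predicted (x, y); gt_kpts: (17, 3) COCO (x, y, v);
    bbox: assumed COCO-style [x, y, width, height] of the ground-truth person.
    """
    labeled = gt_kpts[:, 2] > 0
    if not labeled.any():
        return 0.0
    _, _, w, h = bbox
    threshold = alpha * np.hypot(w, h)                      # alpha * bbox diagonal, in pixels
    dist = np.linalg.norm(pred_xy - gt_kpts[:, :2], axis=1)
    return float((dist[labeled] <= threshold).mean())       # binary hit/miss per keypoint
```

Every labeled keypoint gets the same pixel tolerance here, so unlike OKS, PCK does not judge small joints such as eyes more strictly than hips.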

### Why These Metrics?

✅ **OKS**: Standard metric for COCO keypoint evaluation  
✅ **mAP@OKS**: Shows performance at increasingly strict thresholds  
✅ **PCK**: Provides a complementary, per-keypoint view of accuracy  
✅ **Per-keypoint analysis**: Identifies which joints are hardest to match  

### Interpretation Guide:

Rule-of-thumb bands (not defined by COCO); each range includes its lower bound:

| OKS Range | Quality | Interpretation |
|-----------|---------|----------------|
| ≥ 0.90 | Excellent | Near-perfect pose estimation |
| 0.75 to < 0.90 | Very Good | High-quality results, minor errors |
| 0.50 to < 0.75 | Good | Acceptable, some keypoint errors |
| 0.25 to < 0.50 | Fair | Significant errors in some keypoints |
| < 0.25 | Poor | Major errors, unreliable poses |
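To tally how many images fall into each band of the table, the per-image OKS scores can be binned with `pandas.cut`. This is a minimal sketch: `oks_quality_band` is a hypothetical helper, each band is treated as including its lower bound (`right=False`), and the top edge is set to infinity so that an OKS of exactly 1.0 still lands in "Excellent".

```python
import pandas as pd

# Band edges and labels taken from the interpretation table; lower bound inclusive
OKS_BINS = [0.0, 0.25, 0.5, 0.75, 0.9, float('inf')]
OKS_LABELS = ['Poor', 'Fair', 'Good', 'Very Good', 'Excellent']


def oks_quality_band(oks_scores) -> pd.Series:
    """Map each OKS score to its rule-of-thumb quality band."""
    return pd.cut(pd.Series(oks_scores, dtype=float), bins=OKS_BINS,
                  labels=OKS_LABELS, right=False)
```

For example, `oks_quality_band(results_df['oks']).value_counts(sort=False)` lists the number of images per band, ordered from Poor to Excellent.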

### Next Steps:

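1. **If OKS is low**: the generated images may not reproduce the conditioning pose, or RTMPose may fail to detect the person; inspect the worst examples in Section 11
2. **If specific keypoints have high error**: the model may need fine-tuning for those joints
3. **Compare with and without spatial conditioning**: evaluate whether ControlNet conditioning improves pose preservation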