# COCO Pose Estimation Quality Metric

**Evaluating ControlNet-Generated Poses using YOLOv8-Pose & Procrustes Alignment**

This notebook compares poses from ControlNet-generated images against ground truth COCO keypoints using:
1. **COCO annotations** (`person_keypoints_val2017.json`) to extract ground truth keypoints
2. **YOLOv8-Pose** unified model for keypoint detection from generated images
3. **Procrustes Alignment** (translation + rotation + scale) to isolate pose shape
4. **MPJPE** (primary metric) and **OKS** (secondary metric) for evaluation

**Key Innovation:**
- Procrustes alignment removes location, size, and orientation differences → measures pure **pose shape similarity**
- MPJPE: Mean Per Joint Position Error (lower = better)
- OKS: Object Keypoint Similarity (higher = better)
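
As a compact reference for the two ideas above, the alignment and the primary metric can be written as follows. The symbols are notation introduced here for illustration, not taken from the notebook: $g_j$ and $p_j$ are the ground-truth and predicted 2D positions of joint $j$, $V$ is the set of joints usable in both poses, and $V_g$ is the set of joints labeled in the ground truth.

$$
(s^*, R^*, t^*) = \arg\min_{s > 0,\; R \in SO(2),\; t \in \mathbb{R}^2} \; \sum_{j \in V} \left\lVert g_j - (s R\, p_j + t) \right\rVert_2^2
$$

$$
\mathrm{MPJPE} = \frac{1}{|V_g|} \sum_{j \in V_g} \left\lVert g_j - (s^* R^*\, p_j + t^*) \right\rVert_2
$$

Restricting $R$ to proper rotations means a mirrored pose is not treated as a match, which is the intent of the determinant check in the alignment helper later in this notebook.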

**Workflow:**
- Load up to 300 image IDs from val_captions.json
- Extract COCO ground truth keypoints
- Run YOLOv8-Pose on generated images
- Align predictions using Procrustes analysis
- Compare using MPJPE and OKS metrics

In [None]:
from google.colab import drive
drive.mount('/content/drive')

## 1. Install & Verify Dependencies

In [None]:

%pip install -q numpy pandas matplotlib seaborn scipy
%pip install -q opencv-python
%pip install -q pycocotools

import sys
print(f"Python version: {sys.version}")


print("Installing YOLOv8-Pose (ultralytics)...")
%pip install -q -U ultralytics torch torchvision

print("Installation complete (YOLOv8-Pose)")
print("No MMPose/MMDetection/OpenMMLab compilation issues")

In [None]:

import sys
import torch
print(f"Python: {sys.version}")
print(f"Torch:   {torch.__version__}, CUDA available: {torch.cuda.is_available()}")

try:
    from ultralytics import YOLO
    print(f"Ultralytics (YOLOv8): imported successfully")
except Exception as e:
    print(f"Ultralytics import failed: {type(e).__name__}: {e}")
    raise

try:
    import cv2
    import numpy as np
    import pandas as pd
    print(f"Core libs (cv2, numpy, pandas): imported successfully")
    print("\nAll packages loaded successfully (YOLOv8-Pose ready!)")
except Exception as e:
    print(f"Core libs import failed: {type(e).__name__}: {e}")
    raise

In [None]:
# Imports and configuration
import json
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from typing import List, Dict, Tuple
import pandas as pd
from collections import defaultdict
import warnings
warnings.filterwarnings('ignore')

from ultralytics import YOLO


VAL_CAPTIONS_PATH = "/content/val_captions.json"
GENERATED_IMAGES_PATH = "/content/drive/MyDrive/generatedimages"  # e.g., "./generated_images"
NUM_IMAGES = 300


YOLO_MODEL = 'yolov8m-pose.pt'  
CONF_THRESHOLD = 0.4  # confidence threshold for detections


COCO_KEYPOINT_NAMES = [
    'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
    'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
]


SKELETON = [
    (0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6),
    (5, 7), (7, 9), (6, 8), (8, 10), (5, 11), (6, 12),
    (11, 13), (13, 15), (12, 14), (14, 16)
]

print(" Imports & config loaded")
print(f"  Val Captions: {VAL_CAPTIONS_PATH}")
print(f"  Generated Images: {GENERATED_IMAGES_PATH}")
print(f"  Model: {YOLO_MODEL} (YOLOv8-Pose)")
print(f"  Confidence threshold: {CONF_THRESHOLD}")

## 2. Configuration & Dataset Setup

Load COCO annotations and extract ground truth keypoints
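
Before the loading cell, a minimal sketch of how one COCO keypoint annotation is laid out may help. COCO stores the 17 keypoints as a flat `[x, y, v, ...]` list, where `v = 0` means not labeled, `v = 1` labeled but not visible, and `v = 2` labeled and visible. The coordinates below are made up for illustration.

```python
import numpy as np

# Toy COCO-style annotation: 17 keypoints as a flat [x1, y1, v1, x2, y2, v2, ...] list.
# All coordinates here are invented for the example.
flat = [0, 0, 0] * 17
flat[0:3] = [120, 80, 2]      # nose (index 0): labeled and visible
flat[15:18] = [100, 140, 1]   # left_shoulder (index 5): labeled but occluded
flat[18:21] = [140, 140, 2]   # right_shoulder (index 6): labeled and visible

# One row per joint: columns are x, y, visibility flag
kpts = np.array(flat, dtype=float).reshape(17, 3)

# v > 0 selects every labeled joint, whether visible or occluded
labeled = kpts[:, 2] > 0
num_labeled = int(labeled.sum())

print(num_labeled)
print(kpts[labeled, :2])
```

The same `v > 0` mask is what the metric code in this notebook treats as "visible", so occluded-but-labeled joints still count toward the error.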

In [None]:
import json
import numpy as np
import urllib.request
import zipfile
from pathlib import Path


print("Loading validation captions...")
with open(VAL_CAPTIONS_PATH, 'r') as f:
    val_captions_raw = json.load(f)


image_ids = []

if isinstance(val_captions_raw, dict):
    sample_key = next(iter(val_captions_raw.keys()), None)

    # Case 1: filename -> caption mapping
    if sample_key and isinstance(sample_key, str) and sample_key.endswith(('.jpg', '.jpeg', '.png')):
        print("Detected filename-based caption format")

        for filename in list(val_captions_raw.keys())[:NUM_IMAGES]:
            basename = filename.rsplit('.', 1)[0]
            try:
                image_ids.append(int(basename))
            except ValueError:
                pass

    # Case 2: COCO-style annotation dictionaries
    elif 'annotations' in val_captions_raw:
        for cap in val_captions_raw['annotations'][:NUM_IMAGES]:
            if isinstance(cap, dict):
                img_id = cap.get('image_id') or cap.get('id') or cap.get('image')
                if img_id is not None:
                    image_ids.append(img_id)

    elif 'images' in val_captions_raw:
        for cap in val_captions_raw['images'][:NUM_IMAGES]:
            if isinstance(cap, dict):
                img_id = cap.get('image_id') or cap.get('id') or cap.get('image')
                if img_id is not None:
                    image_ids.append(img_id)

    # Fallback: dictionary values containing image metadata
    else:
        for cap in list(val_captions_raw.values())[:NUM_IMAGES]:
            if isinstance(cap, dict):
                img_id = cap.get('image_id') or cap.get('id') or cap.get('image')
                if img_id is not None:
                    image_ids.append(img_id)

elif isinstance(val_captions_raw, list):
    for cap in val_captions_raw[:NUM_IMAGES]:
        if isinstance(cap, dict):
            img_id = cap.get('image_id') or cap.get('id') or cap.get('image')
            if img_id is not None:
                image_ids.append(img_id)

image_ids = list(dict.fromkeys(image_ids))  # preserve order, remove duplicates

if not image_ids:
    raise ValueError("Failed to extract image IDs from validation captions")

print(f"Extracted {len(image_ids)} image IDs")


print("Loading COCO keypoint annotations...")

cache_dir = Path.home() / ".cache" / "coco_annotations"
cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / "person_keypoints_val2017.json"

if cache_file.exists():
    with open(cache_file, 'r') as f:
        coco_annotations = json.load(f)
else:
    zip_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    zip_path = cache_dir / "annotations_trainval2017.zip"

    urllib.request.urlretrieve(zip_url, zip_path)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        target = "annotations/person_keypoints_val2017.json"
        zip_ref.extract(target, cache_dir)

    extracted_path = cache_dir / target
    extracted_path.rename(cache_file)
    (cache_dir / "annotations").rmdir()
    zip_path.unlink()

    with open(cache_file, 'r') as f:
        coco_annotations = json.load(f)

print("COCO annotations loaded")


coco_images = {img['id']: img for img in coco_annotations.get('images', [])}
coco_annotations_by_img = {}

for ann in coco_annotations.get('annotations', []):
    coco_annotations_by_img.setdefault(ann['image_id'], []).append(ann)


gt_keypoints_dict = {}

print(f"Extracting ground-truth keypoints for {len(image_ids)} images...")
for idx, img_id in enumerate(image_ids):
    anns = coco_annotations_by_img.get(img_id, [])
    if not anns:
        continue

    gt_keypoints_dict[img_id] = {}
    for ann in anns:
        if 'keypoints' in ann:
            kpts = np.array(ann['keypoints']).reshape(17, 3)
            gt_keypoints_dict[img_id][ann['id']] = {
                'keypoints': kpts,
                'bbox': ann.get('bbox'),
                'area': ann.get('area'),
                'iscrowd': ann.get('iscrowd', 0),
                'category_id': ann.get('category_id')
            }

print(f"Images with GT annotations: {len(gt_keypoints_dict)}")
print(f"Total GT instances: {sum(len(v) for v in gt_keypoints_dict.values())}")


## 3. Load YOLOv8-Pose Model

In [None]:
import torch

print("Loading YOLOv8-Pose model...")
print("(Model will auto-download on first run)\n")

try:
    
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    
    yolo_model = YOLO(YOLO_MODEL)
    yolo_model.to(device)

    print(f"YOLOv8-Pose model loaded successfully")
    print(f"  Model: {YOLO_MODEL}")
    print(f"  Device: {device}")
    print(f"  Unified detection + pose (no separate pipeline)")

except Exception as e:
    print(f"Error loading YOLOv8-Pose: {e}")
    import traceback
    traceback.print_exc()
    raise

## 4. Helper Functions & Metrics

Procrustes Alignment + MPJPE + OKS computation
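
To make the alignment step concrete, here is a small self-contained sketch of similarity Procrustes alignment on plain `(N, 2)` point sets. The function name `similarity_align` and the toy shape are assumptions for illustration; the notebook's own helper additionally handles visibility flags.

```python
import numpy as np

def similarity_align(target: np.ndarray, source: np.ndarray) -> np.ndarray:
    """Align source (N, 2) to target (N, 2) using rotation, uniform scale, translation."""
    mu_t, mu_s = target.mean(axis=0), source.mean(axis=0)
    t_c, s_c = target - mu_t, source - mu_s

    # SVD of the cross-covariance gives the best orthogonal map (Kabsch)
    U, S, Vt = np.linalg.svd(s_c.T @ t_c)

    # If that map is a reflection, flip the last singular direction instead
    d = np.sign(np.linalg.det(Vt.T @ U.T))
    D = np.diag([1.0, d])
    R = Vt.T @ D @ U.T

    # Least-squares optimal uniform scale, consistent with the chosen rotation
    scale = (S * np.diag(D)).sum() / (s_c ** 2).sum()
    return scale * (R @ s_c.T).T + mu_t

# Toy shape, then a rotated, scaled and shifted copy of it
target = np.array([[0, 0], [1, 0], [1, 2], [0, 3], [-1, 1]], dtype=float)
theta = np.deg2rad(40)
rot = np.array([[np.cos(theta), -np.sin(theta)],
                [np.sin(theta),  np.cos(theta)]])
source = 2.5 * (rot @ target.T).T + np.array([30.0, -12.0])

aligned = similarity_align(target, source)
err_same_shape = np.linalg.norm(aligned - target, axis=1).mean()

# A mirrored copy cannot be matched by rotation alone, so some error remains
mirrored = target * np.array([-1.0, 1.0])
err_mirrored = np.linalg.norm(similarity_align(target, mirrored) - target, axis=1).mean()

print(err_same_shape, err_mirrored)
```

The first error should be numerically zero because the source differs from the target only by a similarity transform, while the mirrored copy keeps a clearly non-zero error.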

In [None]:
from typing import Dict, List, Optional
import numpy as np


def extract_keypoints_from_generated_image(image_path: str) -> Optional[List[np.ndarray]]:
    """
    Runs YOLOv8-Pose to detect persons and extract 17-keypoint poses.
    Returns a list of (17, 3) arrays [x, y, confidence] per detected person,
    or None if no valid poses are found.
    """
    results = yolo_model(image_path, conf=CONF_THRESHOLD, verbose=False)
    if not results:
        return None

    result = results[0]
    if result.keypoints is None or len(result.keypoints) == 0:
        return None

    keypoints_xy = result.keypoints.xy
    confidences = result.keypoints.conf
    # conf can be None when the model provides no per-keypoint scores
    if keypoints_xy is None or confidences is None or len(keypoints_xy) == 0:
        return None

    detections = []
    for i in range(len(keypoints_xy)):
        kpts_xy = keypoints_xy[i].cpu().numpy()
        kpts_conf = confidences[i].cpu().numpy().reshape(17, 1)
        # Stack into one (17, 3) array: x, y, confidence
        detections.append(np.concatenate([kpts_xy, kpts_conf], axis=1))

    return detections if detections else None


def procrustes_align(gt_keypoints: np.ndarray, pred_keypoints: np.ndarray) -> np.ndarray:
    """
    Aligns predicted keypoints to ground truth using Procrustes analysis.
    Removes effects of translation, rotation, and uniform scaling so that
    evaluation focuses on pose shape rather than absolute position.
    """
    gt_visible = gt_keypoints[:, 2] > 0
    pred_visible = pred_keypoints[:, 2] > 0
    common_visible = gt_visible & pred_visible

    if common_visible.sum() < 3:
        return pred_keypoints

    gt_pts = gt_keypoints[common_visible, :2]
    pred_pts = pred_keypoints[common_visible, :2]

    gt_centroid = gt_pts.mean(axis=0)
    pred_centroid = pred_pts.mean(axis=0)

    gt_centered = gt_pts - gt_centroid
    pred_centered = pred_pts - pred_centroid

    H = pred_centered.T @ gt_centered
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

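    if np.linalg.det(R) < 0:
        # Reflection case: flip the last singular direction to force a proper
        # rotation, and negate the matching singular value so the scale stays optimal
        Vt[-1, :] *= -1
        S[-1] *= -1
        R = Vt.T @ U.T

    pred_norm = np.sum(pred_centered ** 2)
    scale = S.sum() / pred_norm if pred_norm > 0 else 1.0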

    aligned = pred_keypoints.copy()
    aligned[:, :2] -= pred_centroid
    aligned[:, :2] *= scale
    aligned[:, :2] = (R @ aligned[:, :2].T).T
    aligned[:, :2] += gt_centroid

    return aligned


def compute_mpjpe(gt_keypoints: np.ndarray, pred_keypoints: np.ndarray) -> Dict:
    """
    Computes Mean Per Joint Position Error (MPJPE) over labeled ground-truth joints.
    Returns NaN when no joint is labeled, so an empty comparison cannot pass as a perfect match.
    """
    visible = gt_keypoints[:, 2] > 0
    if visible.sum() == 0:
        return {
            'mpjpe': float('nan'),
            'per_joint_errors': np.full(17, np.nan),
            'num_visible': 0
        }

    diff = pred_keypoints[:, :2] - gt_keypoints[:, :2]
    distances = np.linalg.norm(diff, axis=1)

    return {
        'mpjpe': float(distances[visible].mean()),
        'per_joint_errors': distances,
        'num_visible': int(visible.sum())
    }


def compute_oks(gt_keypoints: np.ndarray, pred_keypoints: np.ndarray, bbox: np.ndarray) -> float:
    """
    Computes Object Keypoint Similarity (OKS) following the COCO formulation,
    using the bounding-box area (w * h) as the object scale s^2.
    """
    sigmas = np.array([
        .26, .25, .25, .35, .35, .79, .79, .72, .72,
        .62, .62, 1.07, 1.07, .87, .87, .89, .89
    ]) / 10.0
    # COCO's per-keypoint constant is k = 2 * sigma, so the variance term is (2 * sigma)^2
    variances = (2 * sigmas) ** 2

    x, y, w, h = bbox
    # pycocotools uses the annotation's segmentation area; the box area is an approximation
    scale = w * h
    if scale <= 0:
        return 0.0

    visible = gt_keypoints[:, 2] > 0
    if visible.sum() == 0:
        return 0.0

    dx = pred_keypoints[:, 0] - gt_keypoints[:, 0]
    dy = pred_keypoints[:, 1] - gt_keypoints[:, 1]
    d2 = dx**2 + dy**2

    oks_per_kpt = np.exp(-d2 / (2 * scale * variances))
    return float((oks_per_kpt * visible).sum() / visible.sum())


print("Pose evaluation utilities initialized")


## 5. Extract Keypoints from Generated Images

In [None]:

gen_images_dir = Path(GENERATED_IMAGES_PATH)
if not gen_images_dir.exists():
    raise FileNotFoundError(f"Generated images directory not found: {GENERATED_IMAGES_PATH}")

image_extensions = ['.png', '.jpg', '.jpeg']
generated_image_files = []
for ext in image_extensions:
    generated_image_files.extend(gen_images_dir.glob(f'*{ext}'))
    generated_image_files.extend(gen_images_dir.glob(f'*{ext.upper()}'))

# Deduplicate
generated_image_files = list(dict.fromkeys(generated_image_files))
print(f"Found {len(generated_image_files)} generated images in {GENERATED_IMAGES_PATH}")
print(f"Need matches for {len(image_ids)} validation captions")

# Helper: find the generated file whose name carries exactly this image id
import re
from typing import Optional

def find_generated_image_for_id(img_id: int) -> Optional[Path]:
    candidates = []
    for f in generated_image_files:
        if not f.name.lower().startswith("generated_"):
            continue
        # Compare whole digit runs so that id 139 does not match generated_1390.png;
        # int() also absorbs COCO-style zero padding (000000000139)
        if any(int(run) == int(img_id) for run in re.findall(r"\d+", f.stem)):
            candidates.append(f)
    # Deterministic choice when several files carry the same id
    return sorted(candidates)[0] if candidates else None

print("\n Testing YOLOv8-Pose extraction on first image...")
test_img_id = image_ids[0]
test_img_path = find_generated_image_for_id(test_img_id)
if test_img_path:
    print(f"Test image: {test_img_path}")
    try:
        test_kpts_list = extract_keypoints_from_generated_image(str(test_img_path))
        if test_kpts_list:
            print(f" Extracted {len(test_kpts_list)} person(s) with keypoints; first shape: {test_kpts_list[0].shape}")
        else:
            print(" Extraction returned None or empty list")
    except Exception as e:
        print(f" Extraction error: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
else:
    print(f" Could not find test image for ID {test_img_id}")



generated_keypoints_dict = {}  # {image_id: [keypoints_array_per_person]}
missing_images = []
extraction_failures = []

print("\nMatching generated images by filename substring (image id) ...\n")

for idx, img_id in enumerate(image_ids):
    if (idx + 1) % 10 == 0:
        print(f"  [{idx+1}/{len(image_ids)}] Processed", end='\r')

    gen_img_path = find_generated_image_for_id(img_id)

    if gen_img_path is None:
        missing_images.append(img_id)
        continue

    # Extract keypoints (all detected persons)
    try:
        kpts_list = extract_keypoints_from_generated_image(str(gen_img_path))

        if kpts_list:
            generated_keypoints_dict[img_id] = kpts_list
        else:
            extraction_failures.append((img_id, "Extraction returned None or empty"))
    except Exception as e:
        error_msg = f"{type(e).__name__}: {str(e)}" if str(e) else type(e).__name__
        extraction_failures.append((img_id, error_msg))

print(f"\n\n Processed {len(image_ids)} target images")
print(f"  Successfully extracted (>=1 person): {len(generated_keypoints_dict)}")
print(f"  Missing images: {len(missing_images)}")
print(f"  Extraction failures: {len(extraction_failures)}")

if missing_images[:5]:
    print(f"\n  First 5 missing image IDs: {missing_images[:5]}")
    print(f"   Ensure filenames include the image id, e.g., generated_<id>.png")

if extraction_failures[:5]:
    print(f"\n  First 5 extraction failures:")
    for img_id, reason in extraction_failures[:5]:
        print(f"    - Image {img_id}: {reason}")

## 6. Compute Pose Shape Similarity Metrics

MPJPE (primary) and OKS (secondary) using Procrustes-aligned predictions
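
For orientation, the sketch below shows how OKS behaves on a toy pose, assuming the COCO definition with per-keypoint constant `k = 2 * sigma` and an object scale `s^2` passed in as `area`. The function name `oks_sketch`, the random pose, and the chosen area are illustrative assumptions, not values from this notebook.

```python
import numpy as np

# Per-keypoint sigmas used by the COCO keypoint evaluation
COCO_SIGMAS = np.array([
    .26, .25, .25, .35, .35, .79, .79, .72, .72,
    .62, .62, 1.07, 1.07, .87, .87, .89, .89
]) / 10.0

def oks_sketch(gt: np.ndarray, pred: np.ndarray, area: float) -> float:
    """OKS over labeled GT joints; gt and pred are (17, 3), area is the object scale s^2."""
    labeled = gt[:, 2] > 0
    if area <= 0 or not labeled.any():
        return 0.0
    # Squared pixel distance per joint
    d2 = ((pred[:, :2] - gt[:, :2]) ** 2).sum(axis=1)
    # COCO uses k = 2 * sigma in the Gaussian falloff
    k2 = (2 * COCO_SIGMAS) ** 2
    per_joint = np.exp(-d2 / (2 * area * k2))
    return float(per_joint[labeled].mean())

# Toy pose inside a 100 x 100 region, with every joint marked visible
rng = np.random.default_rng(0)
gt = np.column_stack([rng.uniform(0, 100, (17, 2)), np.full(17, 2.0)])

# Same pose with every joint moved 5 px to the right
shifted = gt.copy()
shifted[:, 0] += 5.0

oks_identical = oks_sketch(gt, gt, area=100 * 100)
oks_shifted = oks_sketch(gt, shifted, area=100 * 100)
print(oks_identical, oks_shifted)
```

An identical pose scores exactly 1.0, and the same 5 px offset is penalized more on joints with small sigmas (face) than on joints with large sigmas (hips), which is why OKS and MPJPE can rank examples differently.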

In [None]:


# Validate that we have both GT and predictions
print(f"\nValidation:")
print(f"  Images with GT keypoints: {len(gt_keypoints_dict)}")
print(f"  Images with predicted keypoints (any person): {len(generated_keypoints_dict)}")

# Find intersection
common_images = set(gt_keypoints_dict.keys()) & set(generated_keypoints_dict.keys())
print(f"  Images with both GT and predictions: {len(common_images)}")

if len(common_images) == 0:
    print("\n ERROR: No images have both GT and predicted keypoints!")
    print("Please check:")
    print("  1. Generated images are in the correct folder")
    print("  2. Image naming matches one of the patterns")
    print("  3. YOLOv8-Pose is correctly detecting poses")
    raise ValueError("No valid image pairs for comparison")

print(f"\n Ready to compare {len(common_images)} image-annotation pairs")

results = []
best_generated_keypoints_dict_aligned = {}  
best_generated_keypoints_dict_raw = {}      

total_skipped_low_vis = 0
comparison_count = 0

print("\n Strategy: Procrustes Alignment + MPJPE for Pose Shape Similarity")
print("    Translation: Centering to same location")
print("    Rotation: Optimal angle alignment")
print("    Scale: Uniform scaling to same size")
print("Measures PURE POSE SHAPE, ignoring location/size/orientation\n")

for img_id in common_images:
    gt_data = gt_keypoints_dict[img_id]

    if len(gt_data) == 0:
        continue

    # Only the first annotated person is compared; other GT instances in the image are ignored
    ann_id = list(gt_data.keys())[0]
    gt_info = gt_data[ann_id]
    gt_kpts = gt_info['keypoints']  
    gt_bbox = gt_info['bbox'] 

    # Skip if labeled keypoints < 10
    visible_count = int((gt_kpts[:, 2] > 0).sum())
    if visible_count < 10:
        total_skipped_low_vis += 1
        continue

    # Get predictions from generated image
    if img_id not in generated_keypoints_dict:
        continue

    pred_kpts_list = generated_keypoints_dict[img_id]

    if not pred_kpts_list or len(pred_kpts_list) == 0:
        continue

   
    best_mpjpe = float('inf')
    best_pred_aligned = None
    best_pred_raw = None
    best_oks = 0.0

    for pred_kpts in pred_kpts_list:
        # PROCRUSTES ALIGN: translation + rotation + scale
        aligned_pred = procrustes_align(gt_kpts, pred_kpts)

        
        mpjpe_result = compute_mpjpe(gt_kpts, aligned_pred)
        mpjpe = mpjpe_result['mpjpe']

        
        oks = compute_oks(gt_kpts, aligned_pred, gt_bbox)

        # Keep track of best match (lowest MPJPE = best shape match)
        if mpjpe < best_mpjpe:
            best_mpjpe = mpjpe
            best_pred_aligned = aligned_pred
            best_pred_raw = pred_kpts
            best_oks = oks

    if best_pred_aligned is None:
        continue

    results.append({
        'image_id': img_id,
        'annotation_id': ann_id,
        'mpjpe': best_mpjpe,  # Primary metric: lower = better pose shape match
        'oks': best_oks,      # Secondary: higher = better
        'num_detected_poses': len(pred_kpts_list)
    })

    best_generated_keypoints_dict_aligned[img_id] = best_pred_aligned
    best_generated_keypoints_dict_raw[img_id] = best_pred_raw
    comparison_count += 1

print(f"\nCompared {comparison_count} image-annotation pairs")
print(f"Skipped (labeled keypoints < 10): {total_skipped_low_vis}")


results_df = pd.DataFrame(results)

print(f"\n MPJPE (Mean Per Joint Position Error) - PRIMARY METRIC:")
print(f"   Lower is better (measures pose shape similarity)")
print(f"  Mean:   {results_df['mpjpe'].mean():.2f} pixels")
print(f"  Median: {results_df['mpjpe'].median():.2f} pixels")
print(f"  Std:    {results_df['mpjpe'].std():.2f} pixels")
print(f"  Min:    {results_df['mpjpe'].min():.2f} pixels (best)")
print(f"  Max:    {results_df['mpjpe'].max():.2f} pixels (worst)")

print(f"\nOKS (Object Keypoint Similarity) - SECONDARY METRIC:")
print(f"  Mean:   {results_df['oks'].mean():.4f}")
print(f"  Median: {results_df['oks'].median():.4f}")
print(f"  Std:    {results_df['oks'].std():.4f}")
print(f"  Min:    {results_df['oks'].min():.4f}")
print(f"  Max:    {results_df['oks'].max():.4f}")

## 7. Visualization: Best & Worst Pose Examples

Ranked by MPJPE (lower = better pose shape match)

In [None]:
def draw_skeleton(image, keypoints, skeleton, title=""):
    """Draw skeleton on image."""
    image = image.copy()
    # Draw keypoints
    for kpt_idx, (x, y, conf) in enumerate(keypoints):
        if conf > 0.3:  # Only draw if confidence > 0.3
            cv2.circle(image, (int(x), int(y)), 5, (0, 255, 0), -1)
            cv2.putText(image, str(kpt_idx), (int(x) + 5, int(y) - 5),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

    # Draw skeleton connections
    for start, end in skeleton:
        if keypoints[start, 2] > 0.3 and keypoints[end, 2] > 0.3:
            x1, y1, _ = keypoints[start]
            x2, y2, _ = keypoints[end]
            cv2.line(image, (int(x1), int(y1)), (int(x2), int(y2)), (255, 0, 0), 2)

    return image


def find_generated_image_for_id(img_id: int) -> Path:
    """Locate generated image by id in GENERATED_IMAGES_PATH; returns None if absent."""
    import re
    gen_dir = Path(GENERATED_IMAGES_PATH)
    if not gen_dir.exists():
        return None
    image_extensions = {'.png', '.jpg', '.jpeg'}
    candidates = []
    for f in gen_dir.glob("generated_*"):
        if f.suffix.lower() not in image_extensions:
            continue
        # Compare whole digit runs so that id 139 does not match generated_1390.png;
        # int() also absorbs COCO-style zero padding
        if any(int(run) == int(img_id) for run in re.findall(r"\d+", f.stem)):
            candidates.append(f)
    # Deterministic choice when several files carry the same id
    return sorted(candidates)[0] if candidates else None


def find_coco_original_image(img_id: int) -> Path:
    """Download COCO image on-demand if not in cache."""
    import urllib.request
    cache_dir = Path.home() / ".cache" / "coco_val2017"
    cache_dir.mkdir(parents=True, exist_ok=True)

    img_filename = f"{img_id:012d}.jpg"
    cached_path = cache_dir / img_filename

    if cached_path.exists():
        return cached_path

    
    coco_url = f"http://images.cocodataset.org/val2017/{img_filename}"
    try:
        print(f"  Downloading COCO image {img_id}...", end='\r')
        urllib.request.urlretrieve(coco_url, cached_path)
        return cached_path
    except Exception as e:
        print(f"  Failed to download {img_id}: {e}")
        return None



print("Visualizing samples with best and worst pose shape similarity...")
print("(Ranked by MPJPE: lower = better pose shape match)")
print("(Left: Generated images | Right: Original COCO images with GT poses)\n")

if results_df.empty:
    print("No results to visualize. Run metric computation first.")
else:
    # Sort by MPJPE: lowest = best shape match
    best_indices = results_df.nsmallest(3, 'mpjpe').index
    worst_indices = results_df.nlargest(3, 'mpjpe').index

    fig, axes = plt.subplots(3, 4, figsize=(20, 14))
    fig.suptitle('Pose Shape Similarity: Best vs Worst Examples\n(MPJPE metric: lower = better | Left: Generated | Right: Original COCO)',
                 fontsize=16, fontweight='bold')

    for row, (best_idx, worst_idx) in enumerate(zip(best_indices, worst_indices)):
        # nsmallest/nlargest return index labels, so look rows up by label
        best_result = results_df.loc[best_idx]
        worst_result = results_df.loc[worst_idx]

        best_img_id = int(best_result['image_id'])
        worst_img_id = int(worst_result['image_id'])

        # BEST EXAMPLE
        # Generated image
        best_img_path = find_generated_image_for_id(best_img_id)
        if best_img_path and best_img_id in best_generated_keypoints_dict_raw:
            img = cv2.imread(str(best_img_path))
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            pred_kpts = best_generated_keypoints_dict_raw[best_img_id]
            img_with_pose = draw_skeleton(img_rgb, pred_kpts, SKELETON)
            axes[row, 0].imshow(img_with_pose)
            axes[row, 0].set_title(f'Best Generated (MPJPE={best_result["mpjpe"]:.1f}px)',
                                  fontsize=11, fontweight='bold', color='green')
        else:
            axes[row, 0].text(0.5, 0.5, 'Image not found', ha='center', va='center')
            axes[row, 0].set_title('Best (generated) missing', fontsize=11, color='red')
        axes[row, 0].axis('off')

        # Original COCO image
        best_coco_path = find_coco_original_image(best_img_id)
        if best_coco_path and best_img_id in gt_keypoints_dict:
            img = cv2.imread(str(best_coco_path))
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            ann_id = list(gt_keypoints_dict[best_img_id].keys())[0]
            gt_kpts = gt_keypoints_dict[best_img_id][ann_id]['keypoints']
            img_with_pose = draw_skeleton(img_rgb, gt_kpts, SKELETON)
            axes[row, 1].imshow(img_with_pose)
            axes[row, 1].set_title(f'Best Original COCO (OKS={best_result["oks"]:.3f})',
                                  fontsize=11, fontweight='bold', color='darkgreen')
        else:
            axes[row, 1].text(0.5, 0.5, 'Image not found', ha='center', va='center')
            axes[row, 1].set_title('Best (COCO) missing', fontsize=11, color='red')
        axes[row, 1].axis('off')

        # WORST EXAMPLE 
        # Generated image
        worst_img_path = find_generated_image_for_id(worst_img_id)
        if worst_img_path and worst_img_id in best_generated_keypoints_dict_raw:
            img = cv2.imread(str(worst_img_path))
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            pred_kpts = best_generated_keypoints_dict_raw[worst_img_id]
            img_with_pose = draw_skeleton(img_rgb, pred_kpts, SKELETON)
            axes[row, 2].imshow(img_with_pose)
            axes[row, 2].set_title(f'Worst Generated (MPJPE={worst_result["mpjpe"]:.1f}px)',
                                  fontsize=11, fontweight='bold', color='red')
        else:
            axes[row, 2].text(0.5, 0.5, 'Image not found', ha='center', va='center')
            axes[row, 2].set_title('Worst (generated) missing', fontsize=11, color='red')
        axes[row, 2].axis('off')

        # Original COCO image
        worst_coco_path = find_coco_original_image(worst_img_id)
        if worst_coco_path and worst_img_id in gt_keypoints_dict:
            img = cv2.imread(str(worst_coco_path))
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            ann_id = list(gt_keypoints_dict[worst_img_id].keys())[0]
            gt_kpts = gt_keypoints_dict[worst_img_id][ann_id]['keypoints']
            img_with_pose = draw_skeleton(img_rgb, gt_kpts, SKELETON)
            axes[row, 3].imshow(img_with_pose)
            axes[row, 3].set_title(f'Worst Original COCO (OKS={worst_result["oks"]:.3f})',
                                  fontsize=11, fontweight='bold', color='darkred')
        else:
            axes[row, 3].text(0.5, 0.5, 'Image not found', ha='center', va='center')
            axes[row, 3].set_title('Worst (COCO) missing', fontsize=11, color='red')
        axes[row, 3].axis('off')

    plt.tight_layout()
    plt.show()


## 8. Summary & Interpretation

**MPJPE:** Mean Per Joint Position Error after Procrustes alignment (pixels)
- Measures pure pose **shape** similarity independent of location/size/orientation
- Lower is better

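**OKS:** Object Keypoint Similarity (reference metric)
- Follows the COCO keypoint formulation, but is computed here after alignment, so values are not directly comparable to standard COCO OKS or AP figures
- Higher is better

**Note:** Both metrics are computed on Procrustes-aligned keypoints to isolate pose shape quality from spatial differences.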