# Body Pose Diversity Metrics
## Pose Entropy & Average Pairwise Diversity

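This notebook estimates 2D body poses with MoveNet Lightning (a lightweight, ~3 MB model) and uses them to compute two pose-diversity metrics per dataset: pose entropy and average pairwise pose distance.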

## 1. Setup and Installation

In [None]:
# Install dependencies
!pip install -q tensorflow tensorflow-hub numpy scipy matplotlib pillow tqdm

In [None]:
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
from scipy import stats
from scipy.spatial.distance import pdist, squareform
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import json

# Report which TensorFlow build is in use (useful when comparing runs).
print("TensorFlow version:", tf.__version__)

## 2. Load MoveNet Lightning Model

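MoveNet Lightning is a fast, lightweight (~3 MB) single-person pose estimation model. In this notebook it is fed a 192×192 image and predicts 17 keypoints, each as `[y, x, confidence]`.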

In [None]:
# Fetch the single-person MoveNet Lightning (v4) model from TensorFlow Hub and
# keep its default inference signature as the callable used for pose
# extraction (`movenet(...)` in extract_keypoints).
model_url = "https://tfhub.dev/google/movenet/singlepose/lightning/4"
model = hub.load(model_url)
movenet = model.signatures["serving_default"]

print("MoveNet Lightning loaded successfully!")

# Names of the 17 MoveNet keypoints; list position == keypoint index in the
# model output.
KEYPOINT_NAMES = (
    'nose left_eye right_eye left_ear right_ear '
    'left_shoulder right_shoulder left_elbow right_elbow '
    'left_wrist right_wrist left_hip right_hip '
    'left_knee right_knee left_ankle right_ankle'
).split()

# Skeleton edges (pairs of keypoint indices) drawn by visualize_pose,
# flattened from per-body-part groups in a fixed order.
KEYPOINT_EDGES = [
    edge
    for group in (
        ((0, 1), (0, 2), (1, 3), (2, 4)),            # face
        ((5, 6), (5, 7), (7, 9), (6, 8), (8, 10)),   # arms
        ((5, 11), (6, 12), (11, 12)),                # torso
        ((11, 13), (13, 15), (12, 14), (14, 16)),    # legs
    )
    for edge in group
]

## 3. Pose Extraction Functions

In [None]:
def load_and_preprocess_image(image_path, input_size=192):
    """Load an image file as a MoveNet-ready int32 tensor.

    The image is decoded to 3-channel RGB, resized with its aspect ratio
    preserved and zero-padded to ``input_size x input_size`` via
    ``tf.image.resize_with_pad``, then cast to int32.

    Args:
        image_path: path (str or Path) of the image file.
        input_size: side length of the square output (192 by default, the
            size used for MoveNet Lightning throughout this notebook).

    Returns:
        int32 tensor of shape (input_size, input_size, 3).
    """
    suffix = Path(str(image_path)).suffix.lower()
    if suffix in {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}:
        raw = tf.io.read_file(str(image_path))
        # expand_animations=False returns a single 3-D frame even for GIFs, so
        # the batch axis added by the caller is the only leading axis.
        img = tf.image.decode_image(raw, channels=3, expand_animations=False)
    else:
        # tf.image.decode_image only understands BMP/GIF/JPEG/PNG; other
        # formats (e.g. the *.webp files extract_all_poses collects) would
        # always fail, so decode them with Pillow instead.
        with Image.open(image_path) as pil_img:
            img = tf.convert_to_tensor(np.asarray(pil_img.convert('RGB')))
    img = tf.image.resize_with_pad(img, input_size, input_size)
    return tf.cast(img, dtype=tf.int32)

def extract_keypoints(image_path):
    """Run MoveNet on a single image file.

    Returns:
        numpy array of shape (17, 3) whose rows are ``[y, x, confidence]``,
        or None if loading or inference raised; the error is printed rather
        than propagated so batch extraction can continue.
    """
    try:
        batch = tf.expand_dims(load_and_preprocess_image(image_path), axis=0)
        prediction = movenet(batch)['output_0'].numpy()
        # Drop the two leading singleton axes to get the (17, 3) keypoint array.
        return prediction[0, 0]
    except Exception as err:
        print(f"Error processing {image_path}: {err}")
        return None

def extract_all_poses(image_dir, max_images=None):
    """Run MoveNet on every image found (recursively) under ``image_dir``.

    Files are matched by extension case-insensitively (.jpg, .jpeg, .png,
    .webp) and processed in sorted path order, so ``max_images`` selects the
    same subset on every run and on every filesystem.

    Args:
        image_dir: root directory to search.
        max_images: process at most this many images; None (or 0) means all.

    Returns:
        (keypoints, valid_paths): an array of shape (N, 17, 3) with one pose
        per successfully processed image, and the matching list of paths.
        Images for which ``extract_keypoints`` returns None are skipped.
    """
    image_dir = Path(image_dir)
    extensions = {'.jpg', '.jpeg', '.png', '.webp'}

    # One recursive walk instead of one per pattern; comparing the lowercased
    # suffix also picks up e.g. "IMG_001.JPG", which a case-sensitive glob
    # misses on Linux.
    image_paths = sorted(
        p for p in image_dir.rglob('*')
        if p.suffix.lower() in extensions and p.is_file()
    )

    if max_images:
        image_paths = image_paths[:max_images]

    print(f"Processing {len(image_paths)} images...")

    all_keypoints = []
    valid_paths = []

    for path in tqdm(image_paths):
        kp = extract_keypoints(path)
        if kp is not None:
            all_keypoints.append(kp)
            valid_paths.append(path)

    if not all_keypoints:
        # Keep the (N, 17, 3) shape even when nothing was extracted.
        return np.empty((0, 17, 3), dtype=np.float32), valid_paths
    return np.array(all_keypoints), valid_paths

## 4. Pose Diversity Metrics

In [None]:
def normalize_pose(keypoints):
    """Normalize one pose so it is translation- and scale-invariant.

    Args:
        keypoints: array of shape (17, 3) holding ``[y, x, confidence]`` per
            keypoint, in the index order of ``KEYPOINT_NAMES``.

    Returns:
        Flat vector of 34 values: the (y, x) coordinates of every keypoint,
        translated so the hip midpoint is the origin and divided by the torso
        length (shoulder midpoint to hip midpoint). If the torso length is at
        most 0.01 the coordinates are only translated, not rescaled, to avoid
        dividing by (near) zero.

    Note: confidences are ignored, so low-confidence keypoints contribute
    their raw predicted positions.
    """
    coords = np.asarray(keypoints)[:, :2]
    hip_center = (coords[11] + coords[12]) / 2        # left_hip, right_hip
    shoulder_center = (coords[5] + coords[6]) / 2     # left/right shoulder

    normalized = coords - hip_center
    torso_length = np.linalg.norm(shoulder_center - hip_center)
    if torso_length > 0.01:  # Avoid division by (near) zero
        normalized = normalized / torso_length

    return normalized.flatten()


def _normalize_poses(keypoints_list):
    """Apply normalize_pose to every pose; returns an array of shape (N, 34)."""
    normalized = [normalize_pose(kp) for kp in keypoints_list]
    if not normalized:
        return np.empty((0, 34))
    return np.array(normalized)


def compute_pose_entropy(keypoints_list, num_bins=20):
    """Mean normalized Shannon entropy of the per-coordinate pose distributions.

    Each of the 34 normalized coordinates is histogrammed into ``num_bins``
    equal-width bins spanning its observed range. The bin counts are turned
    into probabilities and the Shannon entropy is taken in base ``num_bins``,
    so every per-dimension value lies in [0, 1]: 0 means all poses agree on
    that coordinate, 1 means they are spread uniformly over the bins.
    Higher entropy = more diverse poses.

    Args:
        keypoints_list: sequence of (17, 3) poses.
        num_bins: number of histogram bins per coordinate (>= 2).

    Returns:
        (mean_entropy, per_dimension_entropies)

    Raises:
        ValueError: if ``keypoints_list`` is empty or ``num_bins`` < 2.
    """
    if num_bins < 2:
        raise ValueError("num_bins must be at least 2")
    poses = _normalize_poses(keypoints_list)
    if len(poses) == 0:
        raise ValueError("compute_pose_entropy needs at least one pose")

    entropies = []
    for values in poses.T:
        # Use raw counts (not density=True): densities depend on the bin width
        # and are not probabilities, which made the previous score unbounded
        # and even negative. stats.entropy normalizes the counts itself and
        # ignores empty bins.
        counts, _ = np.histogram(values, bins=num_bins)
        entropies.append(stats.entropy(counts, base=num_bins))

    return np.mean(entropies), entropies


def compute_pairwise_diversity(keypoints_list, max_pairs=5000, seed=0):
    """Average Euclidean distance between normalized poses.

    If the number of pose pairs exceeds ``max_pairs``, a random subset of about
    ``sqrt(2 * max_pairs)`` poses is used, so that roughly ``max_pairs`` pairs
    are compared. Higher value = more diverse poses.

    Args:
        keypoints_list: sequence of (17, 3) poses.
        max_pairs: approximate upper bound on the number of pairs compared.
        seed: seed for the subsampling RNG, making results reproducible; pass
            None for a non-deterministic subsample.

    Returns:
        (mean_distance, std_distance, distances). With fewer than two poses
        there are no pairs, and (0.0, 0.0, empty array) is returned.
    """
    poses = _normalize_poses(keypoints_list)
    n = len(poses)
    if n < 2:
        return 0.0, 0.0, np.empty(0)

    # Sample if too many poses (for efficiency)
    if n * (n - 1) / 2 > max_pairs:
        sample_size = max(2, int(np.sqrt(2 * max_pairs)))
        rng = np.random.default_rng(seed)
        poses = poses[rng.choice(n, size=sample_size, replace=False)]

    distances = pdist(poses, metric='euclidean')
    return np.mean(distances), np.std(distances), distances

## 5. Load Dataset Configuration

In [None]:
# Prefer the config written by the dataset download notebook; otherwise fall
# back to the default Colab locations.
config_path = Path('/content/datasets/dataset_config.json')

if not config_path.exists():
    config = {
        'vitonhd': '/content/datasets/vitonhd',
        'deepfashion1': '/content/datasets/deepfashion1',
        'dresscode': '/content/datasets/dresscode',
    }
    print("Using default paths - run 01_dataset_download.ipynb first")
else:
    config = json.loads(config_path.read_text())
    print("Loaded dataset configuration")

print(f"Dataset paths: {config}")

## 6. Evaluate Pose Diversity on Datasets

In [None]:
def evaluate_pose_diversity(dataset_name, dataset_path, max_images=500):
    """Extract poses from a dataset folder and compute both diversity metrics.

    Args:
        dataset_name: label used in the printed report and the results dict.
        dataset_path: root folder searched recursively for images.
        max_images: cap on the number of images passed to MoveNet.

    Returns:
        ``(results, keypoints, distances)``. ``results`` is a JSON-serializable
        summary dict, ``keypoints`` holds the extracted poses and ``distances``
        the pairwise pose distances. Returns None if the folder does not exist
        or no pose could be extracted.
    """
    print(f"\n{'='*60}")
    print(f"Evaluating: {dataset_name}")
    print(f"{'='*60}")

    dataset_path = Path(dataset_path)
    if not dataset_path.exists():
        print(f"Dataset path not found: {dataset_path}")
        return None

    # Extract poses
    keypoints, paths = extract_all_poses(dataset_path, max_images)

    if len(keypoints) == 0:
        print("No valid poses extracted")
        return None

    print(f"\nExtracted {len(keypoints)} valid poses")

    # Compute metrics
    print("\nComputing pose entropy...")
    entropy, dim_entropies = compute_pose_entropy(keypoints)

    print("Computing pairwise diversity...")
    avg_distance, std_distance, distances = compute_pairwise_diversity(keypoints)

    # Plain floats so the dict can be written with json.dump.
    results = {
        'dataset': dataset_name,
        'num_images': len(keypoints),
        'pose_entropy': float(entropy),
        'avg_pairwise_diversity': float(avg_distance),
        'std_pairwise_diversity': float(std_distance),
    }

    print(f"\nResults for {dataset_name}:")
    print(f"  - Pose Entropy: {entropy:.4f}")
    print(f"  - Avg Pairwise Diversity: {avg_distance:.4f} (±{std_distance:.4f})")

    return results, keypoints, distances

In [None]:
# Evaluate every known dataset present in the config, keeping only the summary
# dict (first element of evaluate_pose_diversity's return value) per dataset.
all_results = {}

for name, path in config.items():
    if name not in ('vitonhd', 'deepfashion1', 'dresscode'):
        continue
    result = evaluate_pose_diversity(name.upper(), path, max_images=500)
    if result is not None:
        all_results[name] = result[0]

## 7. Visualization

In [None]:
def visualize_pose(image_path, keypoints, ax=None, conf_threshold=0.3):
    """Draw keypoints and skeleton edges on top of the source image.

    Args:
        image_path: path of the image the keypoints were extracted from.
        keypoints: (17, 3) array of ``[y, x, confidence]`` as returned by
            ``extract_keypoints``. The model was fed the output of
            ``tf.image.resize_with_pad``, so y/x are normalized to that padded
            square, not to the original image.
        ax: matplotlib Axes to draw on; a new 6x8 figure is created if None.
        conf_threshold: keypoints (and edges touching them) whose confidence
            is not above this value are not drawn.

    Returns:
        The Axes that was drawn on.
    """
    # Convert to RGB so grayscale/RGBA files render like what the model saw,
    # and close the file handle once the pixels are read.
    with Image.open(image_path) as pil_img:
        img = np.array(pil_img.convert('RGB'))
    h, w = img.shape[:2]

    # Undo resize_with_pad: the longer side filled the square and the image
    # was centred. A padded coordinate u therefore maps back to
    # u * side - (side - extent) / 2 original pixels. Sub-pixel rounding in
    # the resize/pad is ignored. Scaling by w/h directly would misplace the
    # skeleton on every non-square image.
    side = max(h, w)
    kp = np.asarray(keypoints)
    ys = kp[:, 0] * side - (side - h) / 2
    xs = kp[:, 1] * side - (side - w) / 2
    visible = kp[:, 2] > conf_threshold

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 8))

    ax.imshow(img)

    # Draw keypoints
    for i in np.flatnonzero(visible):
        ax.plot(xs[i], ys[i], 'ro', markersize=5)

    # Draw skeleton
    for start, end in KEYPOINT_EDGES:
        if visible[start] and visible[end]:
            ax.plot([xs[start], xs[end]], [ys[start], ys[end]], 'g-', linewidth=2)

    ax.axis('off')
    return ax

# Visualize sample poses
def visualize_samples(dataset_path, num_samples=4):
    """Plot detected poses for up to ``num_samples`` images from a dataset.

    ``*.jpg`` files are taken first, then ``*.png``, each in ``rglob`` order.
    Nothing is drawn if the directory does not exist or contains no such
    images. Panels whose pose extraction failed are titled "No pose".
    """
    from itertools import islice

    dataset_path = Path(dataset_path)
    if not dataset_path.exists():
        return

    # Stop scanning once enough files are found instead of listing the whole
    # (possibly very large) dataset tree for each pattern.
    images = list(islice(dataset_path.rglob('*.jpg'), num_samples))
    if len(images) < num_samples:
        images += islice(dataset_path.rglob('*.png'), num_samples - len(images))

    if not images:
        return

    # squeeze=False always yields a 2-D array of Axes, even for one image.
    fig, axes = plt.subplots(1, len(images), figsize=(4 * len(images), 6), squeeze=False)

    for ax, img_path in zip(axes[0], images):
        kp = extract_keypoints(img_path)
        if kp is None:
            # Make the failure visible instead of leaving an empty framed panel.
            ax.set_title(f"No pose: {img_path.name}", fontsize=8)
            ax.axis('off')
            continue
        visualize_pose(img_path, kp, ax)

    plt.tight_layout()
    plt.show()

In [None]:
# Show a few pose-annotated samples for every known dataset in the config.
for name, path in config.items():
    if name not in ('vitonhd', 'deepfashion1', 'dresscode'):
        continue
    print(f"\n{name.upper()} Sample Poses:")
    visualize_samples(path, num_samples=4)

In [None]:
# Compare metrics across datasets: one bar panel per metric. Entropy is drawn
# on a fixed [0, 1] axis; pairwise diversity is auto-scaled.
if all_results:
    datasets = list(all_results)
    entropies = [all_results[d]['pose_entropy'] for d in datasets]
    diversities = [all_results[d]['avg_pairwise_diversity'] for d in datasets]

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    panels = [
        (entropies, 'Pose Entropy', 'Pose Entropy by Dataset', (0, 1)),
        (diversities, 'Avg Pairwise Diversity', 'Pairwise Pose Diversity by Dataset', None),
    ]
    for ax, (values, ylabel, title, ylim) in zip(axes, panels):
        ax.bar(datasets, values, color=['#3498db', '#e74c3c', '#2ecc71'])
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if ylim is not None:
            ax.set_ylim(*ylim)

    plt.tight_layout()
    plt.show()

## 8. Save Results

In [None]:
# Save results to JSON. Create the parent folder first: if the dataset
# download notebook was never run, /content/datasets may not exist and
# open(..., 'w') would raise FileNotFoundError.
results_path = Path('/content/datasets/pose_diversity_results.json')
results_path.parent.mkdir(parents=True, exist_ok=True)

with open(results_path, 'w') as f:
    json.dump(all_results, f, indent=2)

print(f"Results saved to: {results_path}")

# Print summary table (one row per evaluated dataset)
print("\n" + "="*60)
print("POSE DIVERSITY METRICS SUMMARY")
print("="*60)
print(f"{'Dataset':<15} {'Pose Entropy':<15} {'Pairwise Div.':<15}")
print("-"*45)
for name, results in all_results.items():
    print(f"{name:<15} {results['pose_entropy']:<15.4f} {results['avg_pairwise_diversity']:<15.4f}")
print("="*60)