In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
import scipy.io


In [None]:
# File paths. Set the POSE_DATA_DIR environment variable to use another copy
# of the dataset; the default reproduces the original author's layout.
POSE_DATA_DIR = os.environ.get("POSE_DATA_DIR", "/Users/keshavsaraogi/data/pose")
mat_file_path = os.path.join(POSE_DATA_DIR, "joints.mat")
# Trailing "" keeps the trailing separator the original path string had.
image_folder = os.path.join(POSE_DATA_DIR, "images", "")

In [None]:
# Load joint annotations from joints.mat.
# Expected layout (validated below): joints[row, joint, image], where rows
# 0/1/2 hold x, y and a visibility flag, e.g. shape (3, 14, 2000).
data = scipy.io.loadmat(mat_file_path)
joints = data["joints"]
# Fail early with a readable message instead of a confusing unpacking error
# later in the plotting code, which relies on 3 rows and 14 joints.
if joints.ndim != 3 or joints.shape[:2] != (3, 14):
    raise ValueError(
        f"Unexpected 'joints' layout {joints.shape}; expected (3, 14, num_images)"
    )
num_images = joints.shape[2]

In [None]:
# Human-readable joint names in annotation order: entry i labels the joint
# stored at joints[:, i, :]. Legs run right ankle -> right hip, then
# left hip -> left ankle; arms run right wrist -> right shoulder, then
# left shoulder -> left wrist; neck and head top come last.
joint_names = [
    "Right ankle",     # 0
    "Right knee",      # 1
    "Right hip",       # 2
    "Left hip",        # 3
    "Left knee",       # 4
    "Left ankle",      # 5
    "Right wrist",     # 6
    "Right elbow",     # 7
    "Right shoulder",  # 8
    "Left shoulder",   # 9
    "Left elbow",      # 10
    "Left wrist",      # 11
    "Neck",            # 12
    "Head top",        # 13
]

In [None]:
# Function to visualize keypoints on an image
def visualize_pose(image_path, joints, img_index):
    """Show one image with its annotated joints overlaid.

    Parameters
    ----------
    image_path : str
        Image file to read with ``cv2.imread``.
    joints : numpy.ndarray
        Annotation array indexed ``joints[row, joint, image]``; rows 0, 1, 2
        are read as x, y and a visibility flag.
    img_index : int
        Index of the image along the last axis of ``joints``. The plot title
        shows it 1-based.

    Joints whose flag equals 1 are drawn as red circles and every other joint
    as a gray cross, with a legend for the two categories. If the image cannot
    be read, an error is printed and nothing is plotted. Returns None.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"Error: Could not load image {image_path}")
        return

    # OpenCV loads BGR; matplotlib expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # The per-image slice has one column per joint; unpacking its three rows
    # gives x, y and flag vectors, so the joint count comes from the data
    # rather than a hard-coded 14.
    xs, ys, flags = joints[:, :, img_index]
    # NOTE(review): flag == 1 is treated as "visible"; confirm this convention
    # for the dataset release in use, since some releases encode occlusion.
    visible = flags == 1

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(img)
    # One scatter call per category gives exactly one legend entry for each.
    ax.scatter(xs[visible], ys[visible], color="red", marker="o", label="Visible")
    ax.scatter(xs[~visible], ys[~visible], color="gray", marker="x", alpha=0.5,
               label="Not visible")
    ax.set_title(f"Pose Visualization for Image {img_index + 1}")
    ax.axis("off")
    ax.legend(loc="upper right", fontsize="small")
    plt.show()
    plt.close(fig)  # free the figure when this is called in a loop


In [None]:
# Visualize the first few images (joints with a non-1 flag are drawn in gray).
NUM_PREVIEW = 5
num_to_show = min(NUM_PREVIEW, num_images)  # don't ask for more than annotated

shown = 0
missing_files = []
for img_index in range(num_to_show):
    # Image files are 1-based and zero-padded: im0001.jpg, im0002.jpg, ...
    img_path = os.path.join(image_folder, f"im{img_index + 1:04d}.jpg")
    if not os.path.exists(img_path):
        missing_files.append(img_path)
        continue
    visualize_pose(img_path, joints, img_index)
    shown += 1

# Report what actually happened instead of assuming every file was present.
print(f"Visualized {shown} of {num_to_show} requested images (including missing keypoints).")
if missing_files:
    print(f"Skipped {len(missing_files)} image file(s) not found: {missing_files}")