In [None]:
import json
import random
import matplotlib.pyplot as plt
from PIL import Image
import os
import math

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Annotation export to visualise. The code below reads its 'categories',
# 'images' and 'annotations' lists.
json_path: str = (
    '/fs/ess/PAS2136/ggr_data/results/GGR2020_subset_refactor_final/'
    'species_identifier/si_annots.json'
)
# Upper bound on crops sampled per category (categories with fewer
# annotations show all of them). A multiple of max_per_row fills every row.
n_per_category: int = 30
# Grid width: never more than this many crops in one figure row.
max_per_row: int = 3

# Load the JSON file. JSON text is UTF-8 by specification (RFC 8259), so the
# encoding is given explicitly rather than relying on the platform locale.
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Build category mapping: category id -> species name.
cat_id_to_name = {cat['id']: cat['species'] for cat in data['categories']}
print("Categories:", cat_id_to_name)

# Build image mapping: image uuid -> full image record (holds 'image_path').
image_map = {img['uuid']: img for img in data['images']}

# Group annotations by category. Every declared category gets a (possibly
# empty) list so the plotting step can report categories with no annotations.
annotations_by_category = {cat['id']: [] for cat in data['categories']}
# Annotations whose category_id is not declared in 'categories' are counted
# and reported instead of aborting the whole run with a KeyError.
unknown_category_counts = {}
for ann in data['annotations']:
    bucket = annotations_by_category.get(ann['category_id'])
    if bucket is None:
        unknown_category_counts[ann['category_id']] = (
            unknown_category_counts.get(ann['category_id'], 0) + 1
        )
        continue
    bucket.append(ann)
if unknown_category_counts:
    print(f"Warning: skipped annotations with undeclared category ids: {unknown_category_counts}")

def _show_placeholder(ax, message, title):
    """Fill *ax* with a centred text *message* and *title*, hiding the axes frame."""
    ax.text(0.5, 0.5, message, ha='center', va='center', transform=ax.transAxes)
    ax.set_title(title)
    ax.axis('off')


def _load_crop(img_path, bbox):
    """Return the ``(x, y, w, h)`` region *bbox* of the image stored at *img_path*.

    Coordinates are truncated to ints. The source file is opened in a ``with``
    block so its handle is released immediately; ``Image.crop`` is not lazy
    (Pillow >= 3.4), so the returned crop does not depend on the closed file.
    """
    x, y, w, h = map(int, bbox)
    with Image.open(img_path) as img:
        return img.crop((x, y, x + w, y + h))


# For each category, select up to n_per_category random annotation crops and
# show them in a grid with at most max_per_row crops per row.
for cat_id, cat_name in cat_id_to_name.items():
    anns = annotations_by_category[cat_id]
    if not anns:
        print(f"No annotations found for category '{cat_name}'")
        continue
    selected_anns = random.sample(anns, min(n_per_category, len(anns)))
    n = len(selected_anns)
    n_rows = math.ceil(n / max_per_row)
    print(f"\nCategory '{cat_name}' ({cat_id}): showing {n} annotation crops")

    # squeeze=False always yields a 2-D array of Axes, so ravel() gives a flat
    # sequence even for a single-row grid. (With the default squeeze=True a
    # 1 x max_per_row grid comes back 1-D and cannot be indexed uniformly.)
    fig, axes = plt.subplots(
        n_rows, max_per_row,
        figsize=(4 * max_per_row, 4 * n_rows),
        squeeze=False,
    )
    axes = axes.ravel()

    # zip stops after the n selected annotations; extra cells are hidden below.
    for ax, ann in zip(axes, selected_anns):
        img_info = image_map.get(ann['image_uuid'])
        if img_info is None:
            print(f"Warning: No image record for image_uuid {ann['image_uuid']}")
            _show_placeholder(ax, 'Image record missing', "Missing image")
            continue
        img_path = img_info['image_path']
        if not os.path.exists(img_path):
            print(f"Warning: Image not found at {img_path}")
            _show_placeholder(ax, 'Image not found', "Missing image")
            continue

        ax.imshow(_load_crop(img_path, ann['bbox']))
        ax.set_title(f"Conf: {ann['confidence']:.2f}\nUUID: {ann['uuid'][:8]}")
        ax.axis('off')

        # Print details
        print(f"  Ann UUID: {ann['uuid']}, Image: {os.path.basename(img_path)}, Conf: {ann['confidence']:.3f}, BBox: {ann['bbox']}")

    # Hide grid cells beyond the last selected crop.
    for ax in axes[n:]:
        ax.axis('off')

    fig.suptitle(f"Category: {cat_name}", fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    plt.show()