In [None]:
import json
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import os
import numpy as np

# Option to specify UUIDs to display.
# None means the images are chosen at random further down; a list of image
# UUIDs means exactly those images are shown (those without annotations are
# dropped with a warning).
uuids_to_display = None  # default - random
# uuids_to_display = ['example-uuid-1', 'example-uuid-2', 'example-uuid-3', 'example-uuid-4', 'example-uuid-5']

# Option to filter by category_id (None = no filter, 0 = grevy's zebra, 1 = plains zebra).
# When set, only images with at least one annotation of that category are
# kept, and only annotations of that category are drawn and summarized.
category_filter = None  # default - no filter
# category_filter = 0  # only grevy's zebra
# category_filter = 1  # only plains zebra

# Number of images to display.
# Only used for the random selection (uuids_to_display is None); it is capped
# at the number of images that have annotations.
num_images = 5

# Load the JSON file with viewpoint annotations.
# The encoding is pinned to UTF-8 so the file is decoded the same way
# regardless of the locale the notebook kernel runs under.
json_path = '/fs/ess/PAS2136/ggr_data/results/GGR2020_subset_refactor_final/viewpoint_classifier/vc_annots.json'
with open(json_path, 'r', encoding='utf-8') as f:
    data = json.load(f)

# Report the size of the loaded dataset and list its categories as "id: species".
print(f"Total images: {len(data['images'])}")
print(f"Total annotations: {len(data['annotations'])}")
cats = [f"{cat['id']}: {cat['species']}" for cat in data['categories']]
print(f"Categories: {cats}")

# Print overall viewpoint statistics
from collections import Counter

print("\n" + "="*60)
print("OVERALL VIEWPOINT STATISTICS")
print("="*60)

# Species names keyed by category id, so each annotation's category_id is
# resolved through the category's own 'id' field rather than through its
# position in the data['categories'] list.
species_by_category_id = {cat['id']: cat['species'] for cat in data['categories']}

# Count annotations per viewpoint, and per (species, viewpoint) pair.
# Only the per-viewpoint totals are printed below; species_counter is
# tallied but not reported in this section.
viewpoint_counter = Counter()
species_counter = Counter()
for ann in data['annotations']:
    viewpoint = ann['viewpoint']
    species = species_by_category_id[ann['category_id']]
    viewpoint_counter[viewpoint] += 1
    species_counter[(species, viewpoint)] += 1

# Print total counts per viewpoint, ordered by viewpoint name
print("\nTotal annotations per viewpoint:")
for vp, count in sorted(viewpoint_counter.items()):
    print(f"  {vp}: {count}")


# Index every image record by its UUID for direct lookup
image_map = {image['uuid']: image for image in data['images']}

# Bucket the annotations under the UUID of the image they belong to
annotations_by_image = {}
for ann in data['annotations']:
    annotations_by_image.setdefault(ann['image_uuid'], []).append(ann)

# Only images that received at least one annotation appear as keys
images_with_annotations = list(annotations_by_image)

# Optionally narrow the candidates to images holding the requested category
if category_filter is None:
    print(f"\nImages with detections: {len(images_with_annotations)}")
else:
    # Keep an image as soon as one of its annotations matches the filter
    images_with_annotations = [
        img_uuid
        for img_uuid in images_with_annotations
        if any(ann['category_id'] == category_filter for ann in annotations_by_image[img_uuid])
    ]
    print(f"\nImages with detections after category filter (category_id={category_filter}): {len(images_with_annotations)}")

# Select images to display
if uuids_to_display is None:
    # Random sample of annotated images. The sample size is capped at the
    # number of candidates so random.sample is never asked for more items
    # than exist.
    selected_images = random.sample(images_with_annotations, min(num_images, len(images_with_annotations)))
else:
    # Use the specified UUIDs, keeping only those still present among the
    # candidate images (which have annotations and, when a category filter
    # is active, passed that filter). A set gives constant-time membership
    # tests instead of scanning the candidate list once per requested UUID.
    available_uuids = set(images_with_annotations)
    selected_images = [u for u in uuids_to_display if u in available_uuids]
    missing_uuids = [u for u in uuids_to_display if u not in available_uuids]
    if missing_uuids:
        # Name the dropped UUIDs so the user can see which requests were ignored
        print(f"Warning: Some UUIDs not found or have no annotations: {missing_uuids}")

# Print selected UUIDs for future use
print(f"\nSelected {len(selected_images)} images")
print("\nTo re-display these same images, copy and paste this line at the beginning:")
print(f"uuids_to_display = {selected_images}")

# Box/legend colour for each known viewpoint label (plain, high-contrast
# matplotlib colour names).
viewpoint_colors = {
    # Cardinal views
    "front": "red",
    "back": "blue",
    "left": "green",
    "right": "orange",
    # Diagonal views
    "frontleft": "purple",
    "frontright": "brown",
    "backleft": "darkred",
    "backright": "darkblue",
}

# Species names keyed by category id, so each annotation's category_id is
# resolved through the category's own 'id' field rather than through its
# position in the data['categories'] list.
species_by_category_id = {cat['id']: cat['species'] for cat in data['categories']}

# Process each selected image: first one figure of the full image with every
# bounding box drawn, then a grid of per-detection crops.
for img_idx, img_uuid in enumerate(selected_images):
    # Get image info and annotations
    img_info = image_map[img_uuid]
    img_path = img_info['image_path']
    annotations = annotations_by_image[img_uuid]

    # Apply category filter to annotations if specified
    if category_filter is not None:
        annotations = [ann for ann in annotations if ann['category_id'] == category_filter]

    # Skip images whose file is missing instead of aborting the whole run
    if not os.path.exists(img_path):
        print(f"Warning: Image not found at {img_path}")
        continue

    # Copy the pixels into an array while the file is open; leaving the
    # `with` block closes the file handle, so handles do not pile up across
    # iterations.
    with Image.open(img_path) as img:
        img_array = np.array(img)
    img_height, img_width = img_array.shape[:2]

    print(f"\n{'='*60}")
    print(f"Image {img_idx+1}/{len(selected_images)}: {os.path.basename(img_path)}")
    print(f"Image UUID: {img_uuid}")
    print(f"Total detections: {len(annotations)}")
    print(f"{'='*60}")

    # First, show the full image with all detections
    fig_full = plt.figure(figsize=(12, 8))
    ax_full = fig_full.add_subplot(111)
    ax_full.imshow(img_array)

    # Draw all bounding boxes; bbox is unpacked as (x, y, width, height)
    for ann_idx, ann in enumerate(annotations):
        x, y, w, h = ann['bbox']
        viewpoint = ann['viewpoint']
        # Viewpoints missing from the colour table fall back to red
        color = viewpoint_colors.get(viewpoint, 'red')

        rect = patches.Rectangle((x, y), w, h, linewidth=3,
                                 edgecolor=color, facecolor='none')
        ax_full.add_patch(rect)

        # Add the 1-based detection number just above the box; the same
        # number labels the matching crop in the grid below
        ax_full.text(x + w/2, y - 10, str(ann_idx + 1),
                     color='white', fontsize=14, weight='bold',
                     ha='center', va='bottom',
                     bbox=dict(boxstyle='circle,pad=0.3', facecolor=color, alpha=0.8))

    ax_full.set_title(f"Full Image with All Detections - {len(annotations)} total", fontsize=16)
    ax_full.axis('off')
    plt.tight_layout()
    plt.show()

    # Print annotation UUIDs
    print("\nAnnotation UUIDs for this image:")
    for ann_idx, ann in enumerate(annotations):
        print(f"  Detection {ann_idx + 1}: {ann['uuid']}")

    # Now show individual detections
    if len(annotations) > 0:
        n_cols = min(4, len(annotations))  # Max 4 columns for individual detections
        n_rows = (len(annotations) + n_cols - 1) // n_cols  # ceiling division

        # squeeze=False always returns a 2-D array of axes, so a single
        # flattening step covers the 1x1, single-row and multi-row layouts.
        fig_ind, axes_grid = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*n_rows),
                                          squeeze=False)
        axes_ind = list(axes_grid.flat)

        for ann_idx, ann in enumerate(annotations):
            ax = axes_ind[ann_idx]

            # Clip the bounding box to the image. The right/bottom edges are
            # also kept at or beyond the left/top edges: a box lying entirely
            # left of or above the image would otherwise yield a negative
            # end index, which numpy slicing interprets as counting from the
            # end of the axis.
            x, y, w, h = ann['bbox']
            x1, y1 = int(max(0, x)), int(max(0, y))
            x2 = max(x1, int(min(img_width, x + w)))
            y2 = max(y1, int(min(img_height, y + h)))

            # Crop the detection
            detection_crop = img_array[y1:y2, x1:x2]

            # Display the cropped detection; a box with no overlap with the
            # image produces an empty crop, in which case there is nothing
            # to draw and only the caption is shown.
            if detection_crop.size > 0:
                ax.imshow(detection_crop)

            # Create informative caption
            viewpoint = ann['viewpoint']
            confidence = ann['confidence']
            category_name = species_by_category_id[ann['category_id']]

            caption = f"Detection {ann_idx + 1}\n{viewpoint} ({confidence:.3f})\n{category_name}\nUUID: {ann['uuid'][:8]}..."
            ax.set_title(caption, fontsize=11, color='black', weight='bold')
            ax.axis('off')

        # Hide unused subplots in the last row of the grid
        for idx in range(len(annotations), len(axes_ind)):
            axes_ind[idx].axis('off')

        plt.suptitle(f"Individual Detections from Image {img_idx+1}", fontsize=16)
        plt.tight_layout()
        plt.show()

# Stand-alone figure that only carries the viewpoint colour key
fig_legend, ax_legend = plt.subplots(figsize=(10, 3))
ax_legend.axis('off')

# One filled swatch per viewpoint, in the order the colour table defines them
handles = [
    patches.Rectangle((0, 0), 1, 1, facecolor=swatch_color, edgecolor='black', linewidth=2)
    for swatch_color in viewpoint_colors.values()
]
labels = list(viewpoint_colors)

ax_legend.legend(
    handles,
    labels,
    loc='center',
    ncol=4,
    title='Viewpoint Colors',
    fontsize=12,
    title_fontsize=14,
    frameon=True,
    fancybox=True,
)
plt.tight_layout()
plt.show()

# Species names keyed by category id, so each annotation's category_id is
# resolved through the category's own 'id' field rather than through its
# position in the data['categories'] list.
species_by_category_id = {cat['id']: cat['species'] for cat in data['categories']}

# Print a text summary for every selected image. Images skipped earlier
# because their file was missing are still listed here, since the summary
# is driven by selected_images.
print("\n" + "="*60)
print("SUMMARY OF ALL VISUALIZED IMAGES")
print("="*60)
for idx, img_uuid in enumerate(selected_images):
    img_info = image_map[img_uuid]
    annotations = annotations_by_image[img_uuid]

    # Apply category filter if specified
    if category_filter is not None:
        annotations = [ann for ann in annotations if ann['category_id'] == category_filter]

    print(f"\nImage {idx+1}:")
    print(f"  Image UUID: {img_uuid}")
    print(f"  Path: {os.path.basename(img_info['image_path'])}")
    print(f"  Detections: {len(annotations)}")

    # Tally detections per "species - viewpoint" label; Counter keeps the
    # labels in first-seen order, so the printout follows annotation order.
    viewpoint_counts = Counter(
        f"{species_by_category_id[ann['category_id']]} - {ann['viewpoint']}"
        for ann in annotations
    )

    for key, count in viewpoint_counts.items():
        print(f"    - {key}: {count} detection(s)")

    # Print annotation details
    print("  Annotation details:")
    for ann_idx, ann in enumerate(annotations):
        cat_name = species_by_category_id[ann['category_id']]
        print(f"    {ann_idx+1}. UUID: {ann['uuid']}")
        print(f"       {ann['viewpoint']}, conf: {ann['confidence']:.3f}, {cat_name}")