In [None]:
import json
import matplotlib.pyplot as plt
import os
from collections import defaultdict
from PIL import Image
from datetime import datetime
import numpy as np

# Paths to the annotation files to visualize
# annots_paths = ["/fs/ess/PAS2136/ggr_data/results/GGR2020_subset_refactor_final/lca/lca_left_annots.json", 
#                 "/fs/ess/PAS2136/ggr_data/results/GGR2020_subset_refactor_final/lca/lca_right_annots.json",]
annots_paths = ["/fs/ess/PAS2136/ggr_data/results/D2_Yolov9/lca/lca_left_annots.json",
                "/fs/ess/PAS2136/ggr_data/results/D2_Yolov9/lca/lca_right_annots.json"]
# Output directory for figures (one PNG per cluster/viewpoint group)
output_dir = "cluster_visualizations"


def cluster_sort_key(value):
    """Return a sort key that never compares an int with a str.

    Ids whose text consists only of decimal digits sort first, by integer
    value; every other id (for example the "unknown" placeholder) sorts
    afterwards by its text. Non-string ids (e.g. JSON integers) are handled
    through their ``str()`` form.
    """
    text = str(value)
    if text.isdecimal():
        return (0, int(text), text)
    return (1, 0, text)


def safe_filename_part(value):
    """Return ``str(value)`` with path separators, reserved punctuation and
    control characters replaced by ``"_"`` so it can be embedded in a file name.
    """
    return "".join(
        "_" if ch in '\\/:*?"<>|' or ch < " " else ch for ch in str(value)
    )


def unique_figure_filename(cluster_id, viewpoint, source_tag, taken):
    """Pick a PNG file name for a cluster/viewpoint group that is unused so far.

    The plain ``cluster_<id>_viewpoint_<viewpoint>.png`` name is used when it
    is free. If a figure with that name was already written during this run
    (e.g. by a previous annotation file), ``source_tag`` is appended, followed
    by a counter if needed, so that no earlier figure is overwritten.

    Args:
        cluster_id: Cluster identifier of the group.
        viewpoint: Viewpoint label of the group.
        source_tag: Text identifying the annotation file being processed.
        taken: Set of file names already used in this run; the returned name
            is added to it.

    Returns:
        The chosen file name (without directory).
    """
    stem = (
        f"cluster_{safe_filename_part(cluster_id)}"
        f"_viewpoint_{safe_filename_part(viewpoint)}"
    )
    candidate = f"{stem}.png"
    if candidate in taken:
        stem = f"{stem}_{safe_filename_part(source_tag)}"
        candidate = f"{stem}.png"
    counter = 2
    while candidate in taken:
        candidate = f"{stem}_{counter}.png"
        counter += 1
    taken.add(candidate)
    return candidate


def group_crops(annotations, uuid_to_path, path_exists=os.path.exists):
    """Organize annotations by cluster id and then by viewpoint.

    Args:
        annotations: Iterable of annotation dicts. Each must have "image_uuid";
            "bbox", "LCA_clustering_id" and "viewpoint" are optional, and
            "uuid" is required for every annotation that is kept.
        uuid_to_path: Mapping from image uuid to image file path.
        path_exists: Callable deciding whether an image path is present on
            disk. It is called at most once per distinct path.

    Returns:
        Nested defaultdict ``{cluster_id: {viewpoint: [(image_path, bbox,
        annotation_uuid), ...]}}``. Annotations without a known image path,
        without a bbox, or whose image file is missing are skipped. Missing
        cluster ids / viewpoints are filed under "unknown".
    """
    exists_cache = {}
    clusters = defaultdict(lambda: defaultdict(list))
    for ann in annotations:
        image_path = uuid_to_path.get(ann["image_uuid"])
        bbox = ann.get("bbox")
        if not image_path or not bbox:
            continue
        # Several annotations usually share one image; probe the file system
        # only once per path.
        if image_path not in exists_cache:
            exists_cache[image_path] = path_exists(image_path)
        if not exists_cache[image_path]:
            continue
        cluster_id = ann.get("LCA_clustering_id", "unknown")
        viewpoint = ann.get("viewpoint", "unknown")
        clusters[cluster_id][viewpoint].append((image_path, bbox, ann["uuid"]))
    return clusters


def save_cluster_figure(crops, title, filepath, n_cols=3):
    """Draw the cropped annotations in a grid and save the figure as an image.

    Args:
        crops: Non-empty list of ``(image_path, bbox, annotation_name)``
            tuples; ``bbox`` is read as ``x, y, w, h``.
        title: Figure title.
        filepath: Destination path of the saved figure.
        n_cols: Maximum number of crops per row.

    A crop that cannot be loaded or cut out is replaced by an error message in
    its grid cell instead of aborting the whole figure.
    """
    n_rows = (len(crops) + n_cols - 1) // n_cols
    # NOTE(review): the figure height grows with the number of crops; very
    # large groups may exceed the renderer's maximum image size - consider
    # paginating if that happens.
    # squeeze=False always returns a 2-D array of axes, whatever the grid shape.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(n_cols * 4, n_rows * 4), squeeze=False
    )
    try:
        axes = axes.ravel()
        fig.suptitle(title, fontsize=16)

        for ax, (img_path, bbox, annot_name) in zip(axes, crops):
            try:
                # The context manager releases the file handle of every image.
                with Image.open(img_path) as img:
                    x, y, w, h = map(int, bbox)
                    ax.imshow(img.crop((x, y, x + w, y + h)))
                ax.set_title(annot_name, fontsize=8)
            except Exception as e:  # deliberate best effort per grid cell
                ax.text(0.5, 0.5, f"Error: {e}", ha='center', va='center')
            ax.axis('off')

        # Hide unused axes
        for ax in axes[len(crops):]:
            ax.axis('off')

        fig.tight_layout()
        fig.savefig(filepath, dpi=150, bbox_inches='tight')
    finally:
        plt.close(fig)  # Close to free memory, even if saving failed


# The guard is true both when run as a script and inside a notebook cell; it
# only keeps the file-system work from running if this file is imported.
if __name__ == "__main__":
    os.makedirs(output_dir, exist_ok=True)
    print(f"Saving figures to: {output_dir}")

    # File names written so far, shared across annotation files so a later
    # file cannot silently overwrite a figure of an earlier one.
    written_filenames = set()

    for annots_path in annots_paths:
        # Load the annotation data
        with open(annots_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Build a mapping from image_uuid to image_path
        uuid_to_path = {img["uuid"]: img["image_path"] for img in data["images"]}

        # Organize annotations by cluster and viewpoint
        clusters = group_crops(data["annotations"], uuid_to_path)
        source_tag = os.path.splitext(os.path.basename(annots_path))[0]

        # Count total figures to generate
        total_figures = sum(len(viewpoints) for viewpoints in clusters.values())
        figure_count = 0

        # Visualization: save cropped images for each cluster and viewpoint, max 3 per row
        for cluster_id in sorted(clusters, key=cluster_sort_key):
            print(f"\nProcessing Cluster {cluster_id}")
            # key=str keeps the sort working if a viewpoint is not a string.
            for viewpoint in sorted(clusters[cluster_id], key=str):
                crops = clusters[cluster_id][viewpoint]
                if not crops:
                    continue

                figure_count += 1
                print(f"  Generating figure {figure_count}/{total_figures}: Cluster {cluster_id} - Viewpoint {viewpoint}")

                filename = unique_figure_filename(
                    cluster_id, viewpoint, source_tag, written_filenames
                )
                save_cluster_figure(
                    crops,
                    f"Cluster {cluster_id} - Viewpoint: {viewpoint}",
                    os.path.join(output_dir, filename),
                )
                print(f"    Saved: {filename}")

        print(f"\nCompleted! All {figure_count} figures saved to '{output_dir}' directory.")