In [1]:
"""
Build a mini COCO dataset from the full COCO annotations by sampling a roughly class-balanced subset of images
"""

import json
from pathlib import Path
from collections import defaultdict
import random

In [4]:
# Mount Google Drive so the COCO annotation files stored under MyDrive are
# reachable at /content/drive (requires the Colab runtime).
from google.colab import drive

drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


In [5]:
def _print_category_counts(category_images, category_dict):
    """Print, in ascending category-id order, each category's name and distinct-image count.

    Args:
        category_images: Mapping category_id -> set of image_ids containing that category.
        category_dict: Mapping category_id -> COCO category record (with a "name" key).
    """
    for cat_id, img_ids in sorted(category_images.items()):
        # Tolerate annotations that reference a category missing from "categories".
        cat_name = category_dict.get(cat_id, {}).get("name", "<unknown>")
        print(f"  Category {cat_id} ({cat_name}): {len(img_ids)} images")


def make_coco_mini_balanced(ann_path: Path, out_path: Path, images_per_class: int = 250,
                            random_seed: int = 42) -> None:
    """
    Create a mini COCO annotation file with (approximately) balanced classes.

    For every category, up to ``images_per_class`` images containing it are
    sampled at random; the union of those images is kept together with *all*
    of their annotations. Balance is therefore only approximate: an image
    drawn for one category also carries annotations of every other category
    it contains, so frequently co-occurring classes can appear in more than
    ``images_per_class`` images of the result. The number of unique images is
    at most ``images_per_class * len(categories)`` and usually fewer.

    The ``info`` and ``licenses`` sections and the full ``categories`` list of
    the source file are copied unchanged.

    Args:
        ann_path: Path to original COCO annotation file
        out_path: Path to save mini COCO annotation file
        images_per_class: Number of images per category (default 250)
        random_seed: Random seed for reproducibility

    Side effects:
        Seeds the global ``random`` module, writes ``out_path`` and prints
        statistics about the original and mini datasets.
    """
    # Seed the global RNG (not a private Random instance) so that a given seed
    # reproduces exactly the same subset as earlier runs of this notebook.
    random.seed(random_seed)

    # Load original annotations. Keep the whole document under its own name:
    # reusing the name for loop variables would lose "info"/"licenses".
    coco = json.loads(ann_path.read_text(encoding="utf-8"))
    images = coco["images"]
    annotations = coco["annotations"]
    categories = coco["categories"]

    # Create mappings for easier lookup
    image_dict = {img["id"]: img for img in images}
    category_dict = {cat["id"]: cat for cat in categories}

    # image_id -> category_id -> list of annotations of that category on that image
    image_annotations = defaultdict(lambda: defaultdict(list))
    image_categories = defaultdict(set)  # image_id -> set of category_ids

    for annotation in annotations:
        img_id = annotation["image_id"]
        cat_id = annotation["category_id"]
        image_annotations[img_id][cat_id].append(annotation)
        image_categories[img_id].add(cat_id)

    # Invert to category_id -> set of image_ids. Built in this exact order so the
    # sets iterate (and therefore sample) the same way as previous runs.
    category_image_counts = defaultdict(set)
    for img_id, cats in image_categories.items():
        for cat_id in cats:
            category_image_counts[cat_id].add(img_id)

    # Print initial statistics
    print("Original dataset statistics:")
    print(f"Total images: {len(images)}")
    print(f"Total categories: {len(categories)}")
    _print_category_counts(category_image_counts, category_dict)

    # For each category, sample up to images_per_class images containing it.
    selected_image_ids = set()
    for cat_id in sorted(category_dict.keys()):
        cat_name = category_dict[cat_id]["name"]
        # .get avoids inserting empty sets for categories with no images.
        available_images = list(category_image_counts.get(cat_id, set()))

        if len(available_images) < images_per_class:
            print(f"Warning: Category {cat_id} ({cat_name}) has only {len(available_images)} images, "
                  f"requested {images_per_class}. Using all available.")
            selected_for_cat = available_images
        else:
            selected_for_cat = random.sample(available_images, images_per_class)

        selected_image_ids.update(selected_for_cat)
        print(f"Selected {len(selected_for_cat)} images for category {cat_id} ({cat_name})")

    # Shuffle the union. No truncation is needed: every category contributes at
    # most images_per_class ids, so the union never exceeds
    # images_per_class * len(categories).
    selected_ids = list(selected_image_ids)
    random.shuffle(selected_ids)

    # Keep every annotation (of any category) belonging to a selected image.
    mini_images = [image_dict[img_id] for img_id in selected_ids]
    mini_annotations = [
        annotation
        for img_id in selected_ids
        for per_category in image_annotations[img_id].values()
        for annotation in per_category
    ]

    mini = {
        "info": coco.get("info", {}),
        "licenses": coco.get("licenses", []),
        "images": mini_images,
        "annotations": mini_annotations,
        "categories": categories,
    }
    out_path.write_text(json.dumps(mini), encoding="utf-8")

    # Print final statistics
    print("\nMini dataset statistics:")
    print(f"Saved mini COCO to {out_path}")
    print(f"Total images: {len(mini_images)}")
    print(f"Total annotations: {len(mini_annotations)}")

    # Distinct images per category actually present in the mini dataset
    # (can exceed images_per_class because of co-occurring categories).
    mini_category_counts = defaultdict(set)
    for annotation in mini_annotations:
        mini_category_counts[annotation["category_id"]].add(annotation["image_id"])

    print("\nImages per category in mini dataset:")
    _print_category_counts(mini_category_counts, category_dict)

    max_images = images_per_class * len(categories)
    print(f"\nTotal unique images selected: {len(selected_ids)}")
    print(f"Target: {images_per_class} images per category × {len(categories)} categories = "
          f"{max_images} images")


if __name__ == "__main__":
    # Colab-mounted Drive folder holding the COCO 2017 annotation JSON files.
    root = Path("/content/drive/MyDrive/xai-stability-data/coco/annotations")
    ann_path, out_path = (
        root / "instances_train2017.json",
        root / "instances_train2017_mini.json",
    )

    # Request 700 images for each of the 80 classes. 700 x 80 = 56,000 is only
    # an upper bound on the mini set size: an image containing several classes
    # can be drawn for more than one of them but is stored once.
    make_coco_mini_balanced(ann_path, out_path, images_per_class=700, random_seed=42)

Original dataset statistics:
Total images: 118287
Total categories: 80
  Category 1 (person): 64115 images
  Category 2 (bicycle): 3252 images
  Category 3 (car): 12251 images
  Category 4 (motorcycle): 3502 images
  Category 5 (airplane): 2986 images
  Category 6 (bus): 3952 images
  Category 7 (train): 3588 images
  Category 8 (truck): 6127 images
  Category 9 (boat): 3025 images
  Category 10 (traffic light): 4139 images
  Category 11 (fire hydrant): 1711 images
  Category 13 (stop sign): 1734 images
  Category 14 (parking meter): 705 images
  Category 15 (bench): 5570 images
  Category 16 (bird): 3237 images
  Category 17 (cat): 4114 images
  Category 18 (dog): 4385 images
  Category 19 (horse): 2941 images
  Category 20 (sheep): 1529 images
  Category 21 (cow): 1968 images
  Category 22 (elephant): 2143 images
  Category 23 (bear): 960 images
  Category 24 (zebra): 1916 images
  Category 25 (giraffe): 2546 images
  Category 27 (backpack): 5528 images
  Category 28 (umbrella): 3968