In [12]:
import os
import json
import random
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image, ImageDraw, ImageOps

In [3]:
BASE_DIR = "../RWDS_Dataset/RWDS_dataset/RWDS_CZ/train"
IMAGES_DIR = os.path.join(BASE_DIR, "Group_combined_train_images_512_02")
ANNOTATIONS_PATH = os.path.join(BASE_DIR, "Group_combined_train_512_02.json")
OUTPUT_DIR = "../Datasets/processed"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [4]:
def load_dataset(json_path):
    with open(json_path, "r") as f:
        data = json.load(f)
    print(f"Loaded {len(data.get('images', []))} images and {len(data.get('annotations', []))} annotations.")
    return data

In [5]:
def build_lookups(data):
    images_info = data.get("images", [])
    annotations = data.get("annotations", [])
    categories = data.get("categories", [])
    
    # image_id -> annotations
    anns_by_imageid = {}
    for ann in annotations:
        anns_by_imageid.setdefault(str(ann["image_id"]), []).append(ann)
    
    cat_map = {str(cat["id"]): cat.get("name", f"cat{cat['id']}") for cat in categories} if categories else {}
    return images_info, anns_by_imageid, cat_map

In [6]:
def get_annotations_for_image(image_info, anns_by_imageid, all_annotations):
    """Match annotations to image robustly."""
    img_id = str(image_info["id"])
    fname = image_info["file_name"]
    
    # 1. Match by id
    anns = anns_by_imageid.get(img_id, [])
    if anns:
        return anns, "matched_by_id"

    # 2. Match by filename (rare case)
    anns = [a for a in all_annotations if str(a.get("image_id")) == fname]
    if anns:
        return anns, "matched_by_filename"

    # 3. Fallback: leading number
    leading = fname.split("_")[0]
    if leading.isdigit():
        anns = anns_by_imageid.get(leading, [])
        if anns:
            return anns, "matched_by_leading_number"

    return [], "no_match"

In [59]:
def save_image_with_boxes(image_path, annotations, cat_map, save_dir, filename_suffix=""):
    """Draw bounding boxes and save color image (for 'all' folder)."""
    if not os.path.exists(image_path):
        return

    img = Image.open(image_path).convert("RGB")
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(img)
    ax.axis("off")

    for ann in annotations:
        bbox = ann.get("bbox", [])
        if len(bbox) != 4:
            continue
        x, y, w, h = bbox
        rect = patches.Rectangle((x, y), w, h, linewidth=1.5, edgecolor='red', facecolor='none')
        ax.add_patch(rect)
        cat_name = cat_map.get(str(ann.get("category_id")), f"id{ann.get('category_id')}")
        ax.text(x, max(y - 3, 0), cat_name, color='white', fontsize=6,
                bbox=dict(facecolor='red', alpha=0.5, pad=0.2, edgecolor='none'))

    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, os.path.splitext(os.path.basename(image_path))[0] + filename_suffix + ".jpg")
    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches="tight", pad_inches=0)
    plt.close(fig)
    img.close()

In [None]:
def save_object_images(image_info, annotations, cat_map, image_dir, save_dir):
    """Save each object in an image as a separate cropped image, manually correcting 180° rotation."""
    image_path = os.path.join(image_dir, image_info["file_name"])
    if not os.path.exists(image_path):
        return
    
    img = Image.open(image_path).convert("RGB")
    img = img.rotate(180, expand=True)

    for idx, ann in enumerate(annotations, 1):
        bbox = ann.get("bbox", [])
        if len(bbox) != 4:
            continue
        x, y, w, h = map(int, bbox)
        cropped = img.crop((x, y, x + w, y + h))
        
        cat_name = cat_map.get(str(ann.get("category_id")), f"id{ann.get('category_id')}")
        obj_dir = os.path.join(save_dir, cat_name)
        os.makedirs(obj_dir, exist_ok=True)
        
        base_name = os.path.splitext(image_info["file_name"])[0]
        save_path = os.path.join(obj_dir, f"{base_name}_obj{idx}.jpg")
        cropped.save(save_path)
        cropped.close()
    
    img.close()


In [24]:
def save_object_images(image_info, annotations, cat_map, image_dir, save_dir):
    """
    Save each object in an image as a separate cropped image, fixing 180° rotation.
    """
    image_path = os.path.join(image_dir, image_info["file_name"])
    if not os.path.exists(image_path):
        return
    
    img = Image.open(image_path).convert("RGB")
    w_img, h_img = img.size

    # Rotate image 180 degrees
    img = img.rotate(180, expand=True)

    for idx, ann in enumerate(annotations, 1):
        bbox = ann.get("bbox", [])
        if len(bbox) != 4:
            continue
        x, y, w, h = map(int, bbox)

        # Adjust coordinates because image is rotated 180°
        x_new = w_img - (x + w)
        y_new = h_img - (y + h)

        cropped = img.crop((x_new, y_new, x_new + w, y_new + h))
        
        cat_name = cat_map.get(str(ann.get("category_id")), f"id{ann.get('category_id')}")
        obj_dir = os.path.join(save_dir, cat_name)
        os.makedirs(obj_dir, exist_ok=True)
        
        base_name = os.path.splitext(image_info["file_name"])[0]
        save_path = os.path.join(obj_dir, f"{base_name}_obj{idx}.jpg")
        cropped.save(save_path)
        cropped.close()
    
    img.close()

In [25]:
data = load_dataset(ANNOTATIONS_PATH)
images_info, anns_by_imageid, cat_map = build_lookups(data)
annotations_all = data.get("annotations", [])

# Create output folders
for cat_name in list(cat_map.values()) + ["all"]:
    os.makedirs(os.path.join(OUTPUT_DIR, cat_name), exist_ok=True)

Loaded 24160 images and 514625 annotations.


In [26]:
for idx, image_info in enumerate(images_info, 1):
    image_id = image_info["id"]
    file_name = image_info["file_name"]
    image_path = os.path.join(IMAGES_DIR, file_name)

    image_annotations, match_method = get_annotations_for_image(image_info, anns_by_imageid, annotations_all)
    if not image_annotations:
        continue

    # Get categories present
    categories_in_image = list({cat_map.get(str(a["category_id"]), f"id{a['category_id']}") for a in image_annotations})

    # Save "all" (colored with bboxes)
    # save_image_with_boxes(image_path, image_annotations, cat_map, os.path.join(OUTPUT_DIR, "all"))

    # Save individual objects instead of masks
    for cat in categories_in_image:
        cat_id = [k for k, v in cat_map.items() if v == cat][0]
        cat_anns = [a for a in image_annotations if str(a["category_id"]) == cat_id]
        save_object_images(image_info, cat_anns, cat_map, IMAGES_DIR, OUTPUT_DIR)

    # Save category masks (white object, black background)
    # for cat in categories_in_image:
    #     cat_id = [k for k, v in cat_map.items() if v == cat][0]
    #     cat_anns = [a for a in image_annotations if str(a["category_id"]) == cat_id]
    #     cat_dir = os.path.join(OUTPUT_DIR, cat)
    #     save_mask_image(image_info, cat_anns, cat_map, IMAGES_DIR, cat_dir)

    if idx % 100 == 0:
        print(f"Processed {idx}/{len(images_info)} images...")
    
    if idx == 1000:
        break

print("\nAll images processed and saved.")

Processed 100/24160 images...
Processed 200/24160 images...
Processed 300/24160 images...
Processed 500/24160 images...
Processed 600/24160 images...
Processed 700/24160 images...
Processed 800/24160 images...
Processed 900/24160 images...
Processed 1000/24160 images...

All images processed and saved.
