## Select sample from dataset with captions

In [1]:
import os
import json
import random
import shutil
from tqdm import tqdm

# === CONFIG ===
COCO_ROOT = './coco'
OUTPUT_ROOT = './dataset_cap'
NUM_TRAIN = 16000
NUM_VAL = 4000
random.seed(42)

# === CREATE OUTPUT DIRECTORIES ===
os.makedirs(OUTPUT_ROOT, exist_ok=True)

ANNOTATIONS = {
    "train": os.path.join(COCO_ROOT, "annotations", "captions_train2017.json"),
    "val": os.path.join(COCO_ROOT, "annotations", "captions_val2017.json"),
}
INSTANCES = {
    "train": os.path.join(COCO_ROOT, "annotations", "instances_train2017.json"),
    "val": os.path.join(COCO_ROOT, "annotations", "instances_val2017.json"),
}
IMAGES = {
    "train": os.path.join(COCO_ROOT, "train2017"),
    "val": os.path.join(COCO_ROOT, "val2017"),
}
OUTPUT = {
    "train": os.path.join(OUTPUT_ROOT, "train"),
    "val": os.path.join(OUTPUT_ROOT, "val"),
}

In [2]:
# === LOAD CAPTIONS ===
def load_coco_annotations(ann_path, min_captions=2):
    with open(ann_path, 'r') as f:
        data = json.load(f)

    id_to_filename = {img['id']: img['file_name'] for img in data['images']}
    image_captions = {}

    for ann in data['annotations']:
        img_id = ann['image_id']
        if img_id not in image_captions:
            image_captions[img_id] = []
        image_captions[img_id].append(ann['caption'])

    entries = []
    for img_id, captions in image_captions.items():
        if len(captions) >= min_captions:
            fname = id_to_filename[img_id]
            entries.append((img_id, fname, captions))  # include image_id

    return entries

# === LOAD CATEGORIES FROM INSTANCES FILE ===
def load_coco_categories(instances_path):
    with open(instances_path, 'r') as f:
        data = json.load(f)

    img_to_cats = {}
    for ann in data['annotations']:
        img_id = ann['image_id']
        if img_id not in img_to_cats:
            img_to_cats[img_id] = []
        img_to_cats[img_id].append(ann['category_id'])

    return img_to_cats

# === MAIN EXTRACTION FUNCTION ===
def process_split(split, num_samples):
    print(f"Processing {split} split...")

    os.makedirs(os.path.join(OUTPUT[split], "images"), exist_ok=True)

    # Load captions and categories
    entries = load_coco_annotations(ANNOTATIONS[split])
    img_to_cats = load_coco_categories(INSTANCES[split])

    print(f"Total available {split} entries: {len(entries)}")

    random.shuffle(entries)
    selected = []

    for img_id, fname, captions in entries:
        src_path = os.path.join(IMAGES[split], fname)
        if os.path.exists(src_path):
            selected.append((img_id, fname, captions))
        if len(selected) == num_samples:
            break

    print(f"Found {len(selected)} valid samples with existing images.")

    captions_dict = {}

    for img_id, fname, captions in tqdm(selected):
        src_path = os.path.join(IMAGES[split], fname)
        dst_path = os.path.join(OUTPUT[split], "images", fname)
        shutil.copy(src_path, dst_path)

        categories = img_to_cats.get(img_id, [])
        cat_id = max(set(categories), key=categories.count) if categories else -1
        
        if cat_id == -1:
            continue

        captions_dict[fname] = {
            "captions": captions,
            "category_id": cat_id
        }

    with open(os.path.join(OUTPUT[split], "captions.json"), "w") as f:
        json.dump(captions_dict, f, indent=2)

    print(f"Saved {len(captions_dict)} {split} samples to {OUTPUT[split]}.")


In [3]:
# === RUN ===
process_split("train", NUM_TRAIN)
process_split("val", NUM_VAL)

Processing train split...
Total available train entries: 118287
Found 8181 valid samples with existing images.


100%|██████████| 8181/8181 [00:35<00:00, 228.56it/s]


Saved 8097 train samples to ./dataset_cap/train.
Processing val split...
Total available val entries: 5000
Found 4000 valid samples with existing images.


100%|██████████| 4000/4000 [00:10<00:00, 374.14it/s] 

Saved 3967 val samples to ./dataset_cap/val.



