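## Select a random sample of captioned images from COCO 2017

Draws a seeded random subset of the train and val splits (sizes set in the config cell), copies the images to `./dataset/<split>/images/`, and writes every caption of each selected image to `./dataset/<split>/captions.json`.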

In [None]:
import os
import json
import random
import shutil
from tqdm import tqdm

# === CONFIG ===
COCO_ROOT = './coco'       # COCO 2017 root containing annotations/, train2017/ and val2017/
OUTPUT_ROOT = './dataset'  # sampled images and captions.json files are written here
NUM_TRAIN = 5000           # number of images to sample from the train split
NUM_VAL = 1000             # number of images to sample from the val split
SEED = 42                  # seeds the global `random` module used for shuffling
random.seed(SEED)

# === CREATE OUTPUT DIRECTORIES ===
os.makedirs(OUTPUT_ROOT, exist_ok=True)

# Per-split input/output locations, keyed by split name.
ANNOTATIONS = {
    "train": os.path.join(COCO_ROOT, "annotations", "captions_train2017.json"),
    "val": os.path.join(COCO_ROOT, "annotations", "captions_val2017.json"),
}
IMAGES = {
    "train": os.path.join(COCO_ROOT, "train2017"),
    "val": os.path.join(COCO_ROOT, "val2017"),
}
OUTPUT = {
    "train": os.path.join(OUTPUT_ROOT, "train"),
    "val": os.path.join(OUTPUT_ROOT, "val"),
}

In [9]:
# === HELPER FUNCTION ===
def load_coco_annotations(ann_path, min_captions=2):
    """Load a COCO captions annotation file and group captions by image.

    Args:
        ann_path: Path to a COCO captions JSON file with a top-level ``images``
            list (entries with ``id`` and ``file_name``) and an ``annotations``
            list (entries with ``image_id`` and ``caption``).
        min_captions: Images with fewer captions than this are dropped.

    Returns:
        A list of ``(file_name, captions)`` tuples, one per image with at least
        ``min_captions`` captions, ordered by each image's first appearance in
        ``annotations``. All of an image's captions are kept.

    Raises:
        KeyError: if a kept annotation references an ``image_id`` that is not
            listed in ``images``.
    """
    # JSON is UTF-8; be explicit so decoding does not depend on the platform's
    # default locale encoding (e.g. cp1252 on Windows).
    with open(ann_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    id_to_filename = {img['id']: img['file_name'] for img in data['images']}

    # dicts preserve insertion order, so images come out in first-annotation order.
    image_captions = {}
    for ann in data['annotations']:
        image_captions.setdefault(ann['image_id'], []).append(ann['caption'])

    return [
        (id_to_filename[img_id], captions)  # keep all captions, not just min_captions
        for img_id, captions in image_captions.items()
        if len(captions) >= min_captions
    ]


# === MAIN EXTRACTION FUNCTION ===
def process_split(split, num_samples, rng=None):
    """Sample captioned images from one COCO split and export them.

    Selected images are copied to ``OUTPUT[split]/images/`` and their captions
    are written to ``OUTPUT[split]/captions.json`` as
    ``{file_name: [caption, ...]}``.

    Args:
        split: Split key into ANNOTATIONS / IMAGES / OUTPUT (``"train"`` or ``"val"``).
        num_samples: Maximum number of images to export. Fewer are exported if
            not enough annotated images exist on disk; ``0`` exports none.
        rng: Optional ``random.Random`` used for shuffling. When ``None`` the
            global ``random`` module is used, so the sample depends on how much
            of the global stream earlier calls consumed. Pass e.g.
            ``random.Random(42)`` for a per-split sample that is stable when
            this call is re-run on its own.

    Note:
        Images left in ``OUTPUT[split]/images/`` by a previous run are not
        removed; only ``captions.json`` is overwritten.
    """
    print(f"Processing {split} split...")

    images_out = os.path.join(OUTPUT[split], "images")
    os.makedirs(images_out, exist_ok=True)

    entries = load_coco_annotations(ANNOTATIONS[split])
    print(f"Total available {split} entries: {len(entries)}")

    # Shuffle, then keep the first `num_samples` entries whose image file exists.
    shuffle = random.shuffle if rng is None else rng.shuffle
    shuffle(entries)

    selected = []
    for fname, captions in entries:
        # Check the quota before appending so num_samples <= 0 selects nothing
        # instead of silently selecting the whole split.
        if len(selected) >= num_samples:
            break
        if os.path.exists(os.path.join(IMAGES[split], fname)):
            selected.append((fname, captions))

    print(f"Found {len(selected)} valid samples with existing images.")
    if len(selected) < num_samples:
        print(f"Warning: requested {num_samples} {split} samples but only {len(selected)} are available.")

    captions_dict = {}
    for fname, captions in tqdm(selected):
        shutil.copy(os.path.join(IMAGES[split], fname), os.path.join(images_out, fname))
        captions_dict[fname] = captions

    with open(os.path.join(OUTPUT[split], "captions.json"), "w", encoding="utf-8") as f:
        json.dump(captions_dict, f, indent=2)

    print(f"Saved {len(captions_dict)} {split} samples to {OUTPUT[split]}.")


In [10]:
# === RUN ===
# Export each split with its configured sample size (train first, then val).
for split_name, split_size in (("train", NUM_TRAIN), ("val", NUM_VAL)):
    process_split(split_name, split_size)

Processing train split...
Total available train entries: 118287
Found 5000 valid samples with existing images.


100%|██████████| 5000/5000 [00:01<00:00, 4501.60it/s]


Saved 5000 train samples to ./dataset/train.
Processing val split...
Total available val entries: 5000
Found 1000 valid samples with existing images.


100%|██████████| 1000/1000 [00:00<00:00, 1701.71it/s]

Saved 1000 val samples to ./dataset/val.



