In [5]:
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
!unzip annotations_trainval2017.zip

--2025-08-07 16:28:33--  http://images.cocodataset.org/annotations/annotations_trainval2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 3.5.29.225, 3.5.30.39, 3.5.21.80, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|3.5.29.225|:80... connected.
HTTP request sent, awaiting response... 503 Slow Down
2025-08-07 16:28:33 ERROR 503: Slow Down.

unzip:  cannot find or open annotations_trainval2017.zip, annotations_trainval2017.zip.zip or annotations_trainval2017.zip.ZIP.


In [1]:
from pycocotools.coco import COCO
import pandas as pd
import json, os, random

# --- paths ---
ANNO_DIR = '../data/annotations_trainval2017/annotations'
instances_json = f'{ANNO_DIR}/instances_train2017.json'
captions_json  = f'{ANNO_DIR}/captions_train2017.json'
SAVE_DIR = '../data'
os.makedirs(SAVE_DIR, exist_ok=True)

# --- load annotations ---
coco = COCO(instances_json)
cap  = COCO(captions_json)

cat_id   = coco.getCatIds(catNms=['cat'])[0]      # 17
dog_id   = coco.getCatIds(catNms=['dog'])[0]      # 18
human_id = coco.getCatIds(catNms=['person'])[0]   # 1

# Precompute membership sets (faster than per-image getAnnIds calls)
imgs_with_cat   = set(coco.getImgIds(catIds=[cat_id]))
imgs_with_dog   = set(coco.getImgIds(catIds=[dog_id]))
imgs_with_human = set(coco.getImgIds(catIds=[human_id]))

# Map: image_id -> [captions...]
caps_by_img = {}
for ann in cap.dataset['annotations']:
    caps_by_img.setdefault(ann['image_id'], []).append(ann['caption'])

# Build master table (one row per image)
records = []
for img in coco.loadImgs(coco.getImgIds()):
    img_id = img['id']
    records.append({
        'image_id': img_id,
        'file_name': img['file_name'],
        'captions': json.dumps(caps_by_img.get(img_id, []), ensure_ascii=False),
        'has_cat':   int(img_id in imgs_with_cat),
        'has_dog':   int(img_id in imgs_with_dog),
        'has_human': int(img_id in imgs_with_human),
    })
df = pd.DataFrame(records)

# --- helper: balance to ~1:1 by downsampling majority ---
SEED = 42
def balanced_1to1(df_bin, seed=SEED):
    pos = df_bin[df_bin['label'] == 1]
    neg = df_bin[df_bin['label'] == 0]
    if len(pos) == 0 or len(neg) == 0:
        return df_bin.copy()
    n = min(len(pos), len(neg))
    out = pd.concat([
        pos.sample(n, random_state=seed),
        neg.sample(n, random_state=seed)
    ], axis=0).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    return out

# --- build three binary datasets from the full DF (overlap allowed) ---
def make_binary(df, label_col):
    raw = df[['image_id', 'file_name', 'captions', label_col]].rename(columns={label_col: 'label'})
    bal = balanced_1to1(raw)
    return raw, bal

cat_raw,   cat_df   = make_binary(df, 'has_cat')
dog_raw,   dog_df   = make_binary(df, 'has_dog')
human_raw, human_df = make_binary(df, 'has_human')

# --- quick stats ---
def stats(name, d):
    pos = int(d.label.sum()); neg = len(d) - pos
    print(f"{name}: {len(d)} rows | pos={pos}, neg={neg}")

stats("cat_df (balanced)",   cat_df)
stats("dog_df (balanced)",   dog_df)
stats("human_df (balanced)", human_df)

# --- save (balanced) ---
cat_df.to_csv(  f'{SAVE_DIR}/coco_cat_binary_with_captions_balanced.csv',   index=False)
dog_df.to_csv(  f'{SAVE_DIR}/coco_dog_binary_with_captions_balanced.csv',   index=False)
human_df.to_csv(f'{SAVE_DIR}/coco_human_binary_with_captions_balanced.csv', index=False)

# (optional) also save the full, unbalanced versions for reference
cat_raw.to_csv(  f'{SAVE_DIR}/coco_cat_binary_with_captions_all.csv',   index=False)
dog_raw.to_csv(  f'{SAVE_DIR}/coco_dog_binary_with_captions_all.csv',   index=False)
human_raw.to_csv(f'{SAVE_DIR}/coco_human_binary_with_captions_all.csv', index=False)


loading annotations into memory...
Done (t=14.84s)
creating index...
index created!
loading annotations into memory...
Done (t=0.76s)
creating index...
index created!
cat_df (balanced): 8228 rows | pos=4114, neg=4114
dog_df (balanced): 8770 rows | pos=4385, neg=4385
human_df (balanced): 108344 rows | pos=54172, neg=54172


In [3]:
human_df

Unnamed: 0,image_id,file_name,captions,label
0,92953,000000092953.jpg,"[""A woman rides a horse while others look on"",...",1
1,533739,000000533739.jpg,"[""Heavy traffic in a city with a \""Citi Bank\""...",1
2,72850,000000072850.jpg,"[""a close up of a baseball player with a ball ...",1
3,421103,000000421103.jpg,"[""A horse and donkeys standing up in a field o...",0
4,453481,000000453481.jpg,"[""A couple of people with some bikes on a stre...",1
...,...,...,...,...
108339,571287,000000571287.jpg,"[""A large cooked cut pizza on a table."", ""A cl...",0
108340,266503,000000266503.jpg,"[""A sign that has a camel on it."", ""Various st...",0
108341,464498,000000464498.jpg,"[""a few batches of various baked goods like do...",0
108342,394627,000000394627.jpg,"[""A small boy holding a yellow baseball bat, ""...",1


In [38]:
coco.loadAnns(coco.getAnnIds(imgIds=[img_id],iscrowd=None))

[{'segmentation': [[303.0,
    227.0,
    303.0,
    221.0,
    296.0,
    215.0,
    292.0,
    208.0,
    297.0,
    199.0,
    301.0,
    189.0,
    308.0,
    179.0,
    316.0,
    169.0,
    325.0,
    166.0,
    326.0,
    166.0,
    323.0,
    159.0,
    330.0,
    141.0,
    335.0,
    132.0,
    339.0,
    130.0,
    351.0,
    131.0,
    358.0,
    139.0,
    359.0,
    146.0,
    359.0,
    158.0,
    363.2,
    173.05,
    371.11,
    178.99,
    374.41,
    182.29,
    376.39,
    196.14,
    377.05,
    204.72,
    380.35,
    213.95,
    382.99,
    220.55,
    384.31,
    231.77,
    382.33,
    237.05,
    375.07,
    236.39,
    370.45,
    231.77,
    369.79,
    225.83,
    361.88,
    231.11,
    348.68,
    231.11,
    307.78,
    235.07,
    306.46,
    229.79,
    303.16,
    223.19]],
  'area': 5945.7057,
  'iscrowd': 0,
  'image_id': 475546,
  'bbox': [292.0, 130.0, 92.31, 107.05],
  'category_id': 1,
  'id': 439530},
 {'segmentation': [[413.5,
    228.18,
   

In [32]:
dir(coco)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'annToMask',
 'annToRLE',
 'anns',
 'catToImgs',
 'cats',
 'createIndex',
 'dataset',
 'download',
 'getAnnIds',
 'getCatIds',
 'getImgIds',
 'imgToAnns',
 'imgs',
 'info',
 'loadAnns',
 'loadCats',
 'loadImgs',
 'loadNumpyAnnotations',
 'loadRes',
 'showAnns']