In [5]:
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
!unzip annotations_trainval2017.zip

--2025-08-07 16:28:33--  http://images.cocodataset.org/annotations/annotations_trainval2017.zip
Resolving images.cocodataset.org (images.cocodataset.org)... 3.5.29.225, 3.5.30.39, 3.5.21.80, ...
Connecting to images.cocodataset.org (images.cocodataset.org)|3.5.29.225|:80... connected.
HTTP request sent, awaiting response... 503 Slow Down
2025-08-07 16:28:33 ERROR 503: Slow Down.

unzip:  cannot find or open annotations_trainval2017.zip, annotations_trainval2017.zip.zip or annotations_trainval2017.zip.ZIP.


In [44]:
from pycocotools.coco import COCO
import pandas as pd
import json, os, random

# --- paths ---
ANNO_DIR = '../data/annotations_trainval2017/annotations'
instances_json = f'{ANNO_DIR}/instances_train2017.json'
captions_json  = f'{ANNO_DIR}/captions_train2017.json'
SAVE_DIR = '../data'
os.makedirs(SAVE_DIR, exist_ok=True)

# --- load annotations ---
coco = COCO(instances_json)
cap  = COCO(captions_json)

cat_id = coco.getCatIds(catNms=['cat'])[0]  # 17
dog_id = coco.getCatIds(catNms=['dog'])[0]  # 18

# map: image_id -> [captions...]
caps_by_img = {}
for ann in cap.dataset['annotations']:
    caps_by_img.setdefault(ann['image_id'], []).append(ann['caption'])

# build master table
records = []
for img in coco.loadImgs(coco.getImgIds()):
    img_id = img['id']
    has_cat = int(len(coco.getAnnIds(imgIds=[img_id], catIds=[cat_id], iscrowd=None)) > 0)
    has_dog = int(len(coco.getAnnIds(imgIds=[img_id], catIds=[dog_id], iscrowd=None)) > 0)
    captions = caps_by_img.get(img_id, [])
    records.append({
        'image_id': img_id,
        'file_name': img['file_name'],
        'captions': json.dumps(captions, ensure_ascii=False),
        'has_cat': has_cat,
        'has_dog': has_dog,
    })

df = pd.DataFrame(records)

# ---------- disjoint split ----------
SEED = 1337
rng = random.Random(SEED)
all_ids = df['image_id'].tolist()
rng.shuffle(all_ids)
half = len(all_ids) // 2
cat_pool_ids = set(all_ids[:half])      # only used to build cat_df
dog_pool_ids = set(all_ids[half:])      # only used to build dog_df
assert cat_pool_ids.isdisjoint(dog_pool_ids)

# ---------- build raw cat/dog tables (before balancing) ----------
cat_pool = df[df['image_id'].isin(cat_pool_ids)].copy()
dog_pool = df[df['image_id'].isin(dog_pool_ids)].copy()

cat_df_raw = cat_pool[['file_name', 'captions', 'has_cat']].rename(columns={'has_cat':'label'})
dog_df_raw = dog_pool[['file_name', 'captions', 'has_dog']].rename(columns={'has_dog':'label'})

# ---------- balance to ~1:1 by downsampling majority ----------
def balanced_1to1(df_bin, seed=SEED):
    pos = df_bin[df_bin['label'] == 1]
    neg = df_bin[df_bin['label'] == 0]
    if len(pos) == 0 or len(neg) == 0:
        # nothing to balance; return as-is (or raise if you prefer)
        return df_bin.copy()
    n = min(len(pos), len(neg))
    df_bal = pd.concat([
        pos.sample(n, random_state=seed),
        neg.sample(n, random_state=seed)
    ], axis=0).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    return df_bal

cat_df = balanced_1to1(cat_df_raw)
dog_df = balanced_1to1(dog_df_raw)

# ---------- sanity checks ----------
assert set(cat_df['file_name']).isdisjoint(set(dog_df['file_name'])), "Overlap detected!"
print(f"cat_df: {len(cat_df)} rows | pos={int(cat_df.label.sum())}, neg={len(cat_df)-int(cat_df.label.sum())}")
print(f"dog_df: {len(dog_df)} rows | pos={int(dog_df.label.sum())}, neg={len(dog_df)-int(dog_df.label.sum())}")

# ---------- save ----------
cat_df.to_csv(f'{SAVE_DIR}/coco_cat_binary_with_captions_balanced.csv', index=False)
dog_df.to_csv(f'{SAVE_DIR}/coco_dog_binary_with_captions_balanced.csv', index=False)


loading annotations into memory...
Done (t=15.90s)
creating index...
index created!
loading annotations into memory...
Done (t=0.73s)
creating index...
index created!
cat_df: 4066 rows | pos=2033, neg=2033
dog_df: 4444 rows | pos=2222, neg=2222


In [43]:
cat_df

Unnamed: 0,file_name,captions,label
0,000000391895.jpg,"[""A man with a red helmet on a small moped on ...",0
1,000000522418.jpg,"[""A woman wearing a net on her head cutting a ...",0
2,000000184613.jpg,"[""A child holding a flowered umbrella and pett...",0
3,000000318219.jpg,"[""A young boy standing in front of a computer ...",0
4,000000554625.jpg,"[""a boy wearing headphones using one computer ...",0
...,...,...,...
118282,000000444010.jpg,"[""A group of friends sitting down at a table s...",0
118283,000000565004.jpg,"[""wine being poured into a glass over a table""...",0
118284,000000516168.jpg,"[""A man is standing behind a bar with glases"",...",0
118285,000000547503.jpg,"[""A group of men sitting at a bar having drink...",0


In [21]:
dog_df.label.value_counts()

label
0    113902
1      4385
Name: count, dtype: int64

In [38]:
coco.loadAnns(coco.getAnnIds(imgIds=[img_id],iscrowd=None))

[{'segmentation': [[303.0,
    227.0,
    303.0,
    221.0,
    296.0,
    215.0,
    292.0,
    208.0,
    297.0,
    199.0,
    301.0,
    189.0,
    308.0,
    179.0,
    316.0,
    169.0,
    325.0,
    166.0,
    326.0,
    166.0,
    323.0,
    159.0,
    330.0,
    141.0,
    335.0,
    132.0,
    339.0,
    130.0,
    351.0,
    131.0,
    358.0,
    139.0,
    359.0,
    146.0,
    359.0,
    158.0,
    363.2,
    173.05,
    371.11,
    178.99,
    374.41,
    182.29,
    376.39,
    196.14,
    377.05,
    204.72,
    380.35,
    213.95,
    382.99,
    220.55,
    384.31,
    231.77,
    382.33,
    237.05,
    375.07,
    236.39,
    370.45,
    231.77,
    369.79,
    225.83,
    361.88,
    231.11,
    348.68,
    231.11,
    307.78,
    235.07,
    306.46,
    229.79,
    303.16,
    223.19]],
  'area': 5945.7057,
  'iscrowd': 0,
  'image_id': 475546,
  'bbox': [292.0, 130.0, 92.31, 107.05],
  'category_id': 1,
  'id': 439530},
 {'segmentation': [[413.5,
    228.18,
   

In [32]:
dir(coco)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'annToMask',
 'annToRLE',
 'anns',
 'catToImgs',
 'cats',
 'createIndex',
 'dataset',
 'download',
 'getAnnIds',
 'getCatIds',
 'getImgIds',
 'imgToAnns',
 'imgs',
 'info',
 'loadAnns',
 'loadCats',
 'loadImgs',
 'loadNumpyAnnotations',
 'loadRes',
 'showAnns']