In [None]:
import itertools
import json
import math
import os
import random
from functools import reduce

import numpy as np
import torchvision
from pycocotools.coco import COCO
from tqdm import tqdm

In [None]:
# COCO 2017 train annotations; every split and combination below is derived from this one file.
annotation_path = 'instances_train2017.json'
coco = COCO(annotation_file=annotation_path)

In [None]:
# Only consider concept combinations seen in at least 8/2/2 images in the train/val/test splits.
number_train = 8
number_val = 2
number_test = 2
k = 2  # Number of concepts being combined
n_combinations = 1000  # seen combinations
n_combinations_zero_shot = 250  # unseen category combinations

# Seed the global `random` module so the combinations and test cases that the cells
# below sample and write to disk can be regenerated exactly.
SEED = 0
random.seed(SEED)


# The split is a deterministic function of the image id: ids with id % 6 in {0, 1, 2, 3}
# go to train (4/6 of the ids), id % 6 == 4 to val and id % 6 == 5 to test (1/6 each).
# Every helper keeps the order of `img_list`.
def get_train_subset(img_list):
    """Return the ids of `img_list` that fall in the train split (id % 6 < 4)."""
    return [x for x in img_list if x % 6 < 4]


def get_val_subset(img_list):
    """Return the ids of `img_list` that fall in the val split (id % 6 == 4)."""
    return [x for x in img_list if x % 6 == 4]


def get_test_subset(img_list):
    """Return the ids of `img_list` that fall in the test split (id % 6 == 5)."""
    return [x for x in img_list if x % 6 == 5]

In [None]:
cats = coco.getCatIds()
# Keys ('id1_id2') of every combination handed out so far. The seen and zero-shot
# samplers share this dict, so the two sets never overlap.
hash_table_category_tuples = {}


def sample_combinations(pool, n, used_hashes, k=k):
    """Randomly draw `n` new combinations of `k` category ids from `pool`.

    A candidate (a sorted list of ids) is accepted when
    `coco.getImgIds(catIds=candidate)` yields at least `number_train` / `number_val` /
    `number_test` ids in the train / val / test split, and its key is not yet in
    `used_hashes`. Accepted keys are recorded in `used_hashes` (mutated in place).

    Args:
        pool: iterable of COCO category ids to draw from (list or set).
        n: number of combinations to return.
        used_hashes: dict mapping already-used combination keys to True.
        k: combination size (defaults to the config value `k`).

    Returns:
        List of `n` sorted lists of category ids, in acceptance order.

    Raises:
        RuntimeError: if every k-combination of `pool` has been tried and fewer than
            `n` qualified (instead of looping forever).
    """
    # random.sample() rejects sets on Python >= 3.11. Sorting also makes the draw order
    # independent of set iteration order, so a fixed seed reproduces the result.
    pool = sorted(pool)
    n_possible = math.comb(len(pool), k)
    tried = set()
    accepted = []
    with tqdm(total=n) as progress:
        while len(accepted) < n:
            if len(tried) >= n_possible:
                raise RuntimeError(
                    f'only {len(accepted)} of {n} requested combinations of {k} '
                    f'categories are unused and meet the split thresholds'
                )
            cat_ids = sorted(random.sample(pool, k=k))
            hash_code_cats = '_'.join(str(c) for c in cat_ids)
            if hash_code_cats in tried:
                continue  # this candidate was already evaluated in this call
            tried.add(hash_code_cats)
            if hash_code_cats in used_hashes:
                continue  # handed out earlier (e.g. as a seen combination)
            img_ids = coco.getImgIds(catIds=cat_ids)
            if (len(get_train_subset(img_ids)) >= number_train
                    and len(get_val_subset(img_ids)) >= number_val
                    and len(get_test_subset(img_ids)) >= number_test):
                used_hashes[hash_code_cats] = True
                accepted.append(cat_ids)
                progress.update(1)
    return accepted


combinations = sample_combinations(cats, n_combinations, hash_table_category_tuples)

# generate a new set of unseen category combinations for zero shot experiments:
# built only from categories that occur in some seen combination, but never equal
# to a seen combination.
seen = set(i for c in combinations for i in c)
combinations_zero_shot = sample_combinations(seen, n_combinations_zero_shot, hash_table_category_tuples)




In [None]:
# generate a set of impossible combinations for feasibility experiments:
# combinations of k categories for which `coco.getImgIds` returns no image at all.
n_combinations_impossible = 250  # impossible combinations
combinations_impossible = []
hash_table_category_tuples_impossible = {}

# Read the category list from `coco` instead of the global `cats`. Later cells may rebind
# `cats` to a single pair, and re-running this cell would then loop forever.
all_cat_ids = coco.getCatIds()
n_possible_impossible = math.comb(len(all_cat_ids), k)
tried_impossible = set()

with tqdm(total=n_combinations_impossible) as progress:
    while len(combinations_impossible) < n_combinations_impossible:
        # Stop with an error once every k-combination has been checked (instead of spinning).
        if len(tried_impossible) >= n_possible_impossible:
            raise RuntimeError(
                f'only {len(combinations_impossible)} of {n_combinations_impossible} '
                f'requested combinations of {k} categories have no images'
            )
        cat_ids = sorted(random.sample(all_cat_ids, k=k))
        hash_code_cats = '_'.join(str(c) for c in cat_ids)
        if hash_code_cats in tried_impossible:
            continue
        tried_impossible.add(hash_code_cats)
        if len(coco.getImgIds(catIds=cat_ids)) == 0:
            hash_table_category_tuples_impossible[hash_code_cats] = True
            combinations_impossible.append(cat_ids)
            progress.update(1)


In [None]:
# Persist the sampled category combinations (lists of sorted category ids).
# The file prefix is the combination size `k` from the config cell (previously a
# hard-coded '2_'), so k != 2 runs are not written under a k=2 name.
for out_name, combos in [
    (f'{k}_cat_combinations.json', combinations),
    (f'{k}_cat_combinations_zero_shot.json', combinations_zero_shot),
    (f'{k}_cat_combinations_impossible.json', combinations_impossible),
]:
    with open(out_name, 'w', encoding='utf-8') as f:
        json.dump(combos, f, ensure_ascii=False, indent=4)

In [None]:
# Save the train/test/val image splits over every image id in the annotation file.
img_ids = coco.getImgIds()
img_ids_train, img_ids_val, img_ids_test = (
    split_fn(img_ids) for split_fn in (get_train_subset, get_val_subset, get_test_subset)
)

for split_path, split_ids in (
    ('train_imgs.json', img_ids_train),
    ('test_imgs.json', img_ids_test),
    ('val_imgs.json', img_ids_val),
):
    with open(split_path, 'w', encoding='utf-8') as f:
        json.dump(split_ids, f, ensure_ascii=False, indent=4)

In [None]:
# Pre-compute, per split, the image ids for every single category and for every sampled
# combination of concepts. We do this to speed up processing during training/evaluation.
# Keys are the int category id for a single category and 'id1_id2' for a combination
# (json.dump stores the int keys as strings).
#
# The loop variables are deliberately not named `cats`. Rebinding the global category
# list here made a re-run of this cell iterate over a single pair only.

img_ids_per_cat_train = {}
img_ids_per_cat_test = {}
img_ids_per_cat_val = {}

lookup_queries = [(cat, cat) for cat in coco.getCatIds()]
lookup_queries += [
    ('_'.join(str(x) for x in cat_combo), cat_combo)
    for cat_combo in combinations + combinations_zero_shot
]

for key, cat_query in tqdm(lookup_queries):
    img_ids = coco.getImgIds(catIds=cat_query)
    img_ids_per_cat_train[key] = get_train_subset(img_ids)
    img_ids_per_cat_val[key] = get_val_subset(img_ids)
    img_ids_per_cat_test[key] = get_test_subset(img_ids)


In [None]:
# Persist the per-category / per-combination image-id lookups, one file per split.
# The file prefix is the combination size `k` from the config cell (previously a
# hard-coded '2_').
for split_name, lookup in [
    ('train', img_ids_per_cat_train),
    ('val', img_ids_per_cat_val),
    ('test', img_ids_per_cat_test),
]:
    with open(f'{k}_img_ids_per_cats_{split_name}.json', 'w', encoding='utf-8') as f:
        json.dump(lookup, f, ensure_ascii=False, indent=4)

In [None]:
# Generate test cases for both seen and unseen combinations scenarios.
# For every combination, `test_cases_per_combination` cases are built. Each case takes one
# image per category, drawn with replacement from that single category's images in the
# requested split, plus a random 'image'/'text' modality per image.

test_cases_per_combination = 100


def build_cases(cat_combos, split_subset, n_cases=test_cases_per_combination):
    """Return a list of test-case dicts for the combinations in `cat_combos`.

    Args:
        cat_combos: list of category-id lists (e.g. `combinations`).
        split_subset: split filter applied to each category's image ids
            (`get_test_subset` or `get_val_subset`).
        n_cases: number of cases generated per combination.

    Returns:
        List of {'images': [...], 'categories': combo, 'modalities': [...]} dicts, grouped
        by combination in input order. 'categories' is the combination list object itself.
    """
    cases = []
    # Named `cat_combo` (not `cats`) so the global category list is never rebound.
    for cat_combo in tqdm(cat_combos):
        per_cat_samples = [
            random.choices(split_subset(coco.getImgIds(catIds=cat)), k=n_cases)
            for cat in cat_combo
        ]
        for img_set in zip(*per_cat_samples):
            cases.append({
                'images': list(img_set),
                'categories': cat_combo,
                'modalities': random.choices(['image', 'text'], k=len(img_set)),
            })
    return cases


# Keep this call order: it fixes the sequence in which random draws are consumed.
testset = build_cases(combinations, get_test_subset)
valset = build_cases(combinations, get_val_subset)
testset_zero_shot = build_cases(combinations_zero_shot, get_test_subset)
valset_zero_shot = build_cases(combinations_zero_shot, get_val_subset)

In [None]:
# Format the test cases in a way that will be easily parsed during training/evaluation:
# bucket cases by their modality pattern (e.g. 'image_text'), keeping only the images
# and categories of each case.


def group_cases_by_modalities(cases):
    """Return {'<mod1>_<mod2>...': [{'images': ..., 'categories': ...}, ...]} for `cases`.

    Buckets appear in first-seen order; cases keep their input order within a bucket.
    """
    grouped = {}
    for case in cases:
        bucket_key = '_'.join(case['modalities'])
        grouped.setdefault(bucket_key, []).append(
            {'images': case['images'], 'categories': case['categories']}
        )
    return grouped


formatted_testset = group_cases_by_modalities(testset)
formatted_valset = group_cases_by_modalities(valset)
formatted_testset_zero_shot = group_cases_by_modalities(testset_zero_shot)
formatted_valset_zero_shot = group_cases_by_modalities(valset_zero_shot)


In [None]:
# Persist the formatted test/val cases for seen and zero-shot combinations.
# The file prefix is the combination size `k` from the config cell (previously a
# hard-coded '2_').
for out_name, formatted_cases in [
    (f'{k}_test_cases.json', formatted_testset),
    (f'{k}_val_cases.json', formatted_valset),
    (f'{k}_test_cases_zero_shot.json', formatted_testset_zero_shot),
    (f'{k}_val_cases_zero_shot.json', formatted_valset_zero_shot),
]:
    with open(out_name, 'w', encoding='utf-8') as f:
        json.dump(formatted_cases, f, ensure_ascii=False, indent=4)