In [1]:

import json
import os
import shutil
from pathlib import Path

import numpy as np
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer
from pycocotools.coco import COCO

from Config import get_data_dir


In [2]:

def get_maps(mode, verbose = False, annotation_dir = '/home/gregory/Datasets/COCO/annotations', threshold = None):
    """
    Build lookup tables that link COCO 2017 images to their annotated objects
    and to the words used in their captions.

    Parameters
    ----------
    mode : str
        Dataset split.  It is substituted into the annotation file names
        'instances_{mode}2017.json' and 'captions_{mode}2017.json'.
    verbose : bool
        If True, print the list of common words.
    annotation_dir : str
        Directory that holds the annotation files.  The default is the
        location this notebook was originally written against.
    threshold : int or None
        Minimum number of images whose captions must contain a word for the
        word to be kept as "common".  None selects 1000 when mode is 'train'
        and 20 for any other mode.

    Returns
    -------
    id2objs : dict
        image id -> list of distinct category names annotated in that image.
    id2words : dict
        image id -> list of distinct lower-cased caption tokens that are not
        English stop words.
    word2ids : dict
        common word -> list of image ids whose captions contain that word.

    Only images that have both object annotations and captions are kept in
    id2objs / id2words.
    """

    if threshold is None:
        threshold = 1000 if mode == 'train' else 20

    coco = COCO(os.path.join(annotation_dir, 'instances_{}2017.json'.format(mode)))
    captions = COCO(os.path.join(annotation_dir, 'captions_{}2017.json'.format(mode)))

    # Get the objects associated with each image
    cats = {x['id']: x['name'] for x in coco.loadCats(coco.getCatIds())}

    id2objs = {}
    for ann in coco.anns.values():
        name = cats[ann['category_id']]
        objs = id2objs.setdefault(ann['image_id'], [])
        if name not in objs: # We don't care about how many of each object there are
            objs.append(name)

    # Get the words associated with each image via its captions
    stop_words = set(stopwords.words('english'))
    tokenizer = RegexpTokenizer(r'\w+')

    id2words = {}
    for ann in captions.anns.values():
        words = id2words.setdefault(ann['image_id'], [])
        for w in tokenizer.tokenize(ann['caption'].lower()):
            if w not in stop_words and w not in words:
                words.append(w)

    # Make sure that both of those mappings have the same keys
    shared = id2objs.keys() & id2words.keys()
    id2objs = {key: objs for key, objs in id2objs.items() if key in shared}
    id2words = {key: words for key, words in id2words.items() if key in shared}

    # Count, for every word, the number of images whose captions use it
    # (each image's word list is already de-duplicated)
    word2count = {}
    for words in id2words.values():
        for w in words:
            word2count[w] = word2count.get(w, 0) + 1

    # Find the most common words
    common_words = [w for w, count in word2count.items() if count >= threshold]

    if verbose:
        print(common_words)

    # Map the common words to images in a single pass over the images rather
    # than re-scanning every image's word list once per common word
    word2ids = {word: [] for word in common_words}
    for img_id, words in id2words.items():
        for w in words:
            if w in word2ids:
                word2ids[w].append(img_id)

    return id2objs, id2words, word2ids
        

In [3]:
# Lookup tables for the training split (also prints the common caption words)
id2objs, id2words, word2ids = get_maps(mode = 'train', verbose = True)

# Same tables for the validation split
id2objs_val, id2words_val, word2ids_val = get_maps(mode = 'val')
            

loading annotations into memory...
Done (t=11.85s)
creating index...
index created!
loading annotations into memory...
Done (t=0.70s)
creating index...
index created!
['bicycle', 'clock', 'front', 'bike', 'black', 'metal', 'inside', 'room', 'blue', 'walls', 'white', 'sink', 'door', 'small', 'bathroom', 'wall', 'boat', 'painted', 'baby', 'car', 'parked', 'behind', 'two', 'cars', 'sidewalk', 'street', 'city', 'bench', 'parking', 'along', 'couple', 'busy', 'large', 'passenger', 'airplane', 'flying', 'air', 'plane', 'taking', 'cloudy', 'sky', 'red', 'toilet', 'full', 'little', 'decorated', 'many', 'colorful', 'long', 'empty', 'home', 'kitchen', 'picture', 'looking', 'area', 'refrigerator', 'gray', 'stove', 'counter', 'various', 'items', 'cabinets', 'several', 'across', 'open', 'box', 'four', 'food', 'filled', 'green', 'vegetables', 'bananas', 'purple', 'old', 'station', 'colored', 'sitting', 'beside', 'road', 'surfboard', 'top', 'side', 'next', 'silver', 'riding', 'group', 'motorcycle', 'd

In [4]:

def get_obj_counts(img_ids, id2objs = None):
    """
    Compute, for each object, the fraction of the given images that contain it.

    Parameters
    ----------
    img_ids : sequence
        Image ids to aggregate over.
    id2objs : dict or None
        image id -> list of object names.  When None, the notebook-level
        global `id2objs` is used, which is what this function always read
        before the parameter existed.

    Returns
    -------
    dict
        object name -> fraction of `img_ids` whose object list contains it.
        Empty when `img_ids` is empty.
    """
    num_imgs = len(img_ids)
    out = {}
    if num_imgs == 0:
        return out

    if id2objs is None:
        # Backward-compatible fallback to the table built at notebook level
        id2objs = globals()['id2objs']

    weight = 1 / num_imgs
    for img_id in img_ids:
        for obj in id2objs[img_id]:
            out[obj] = out.get(obj, 0) + weight
    return out

def compare_words(word1, word2, word2ids, id2objs = None):
    """
    Rank objects by how differently they co-occur with two caption words.

    Parameters
    ----------
    word1, word2 : str
        Keys of `word2ids`.
    word2ids : dict
        word -> list of image ids whose captions contain the word.
    id2objs : dict or None
        image id -> list of object names for the same split as `word2ids`.
        Passed through to get_obj_counts; None means the notebook-level
        global `id2objs`.

    Returns
    -------
    list of (object name, difference) tuples
        difference = (fraction of word1 images containing the object)
                   - (fraction of word2 images containing the object),
        sorted by decreasing absolute difference.  Ties keep a deterministic
        order: objects seen for word1 first, then those only seen for word2.
    """
    counts1 = get_obj_counts(word2ids[word1], id2objs)
    counts2 = get_obj_counts(word2ids[word2], id2objs)

    # Merging the two dicts yields every object exactly once, in insertion
    # order, so the result does not depend on string hash randomisation
    diff = {}
    for key in {**counts1, **counts2}:
        diff[key] = counts1.get(key, 0.0) - counts2.get(key, 0.0)

    return sorted(diff.items(), key = lambda x: abs(x[1]), reverse = True)

def get_splits(label1, label2, spurious, word2ids, id2objs):
    """
    Partition images by caption label and by presence of a spurious object.

    Only images whose captions contain exactly one of the two labels are
    used.  Image ids are returned as strings under four keys:

        '1s'  : label1 images that contain `spurious`
        '1ns' : label1 images that do not contain `spurious`
        '0s'  : label2 images that contain `spurious`
        '0ns' : label2 images that do not contain `spurious`

    Parameters
    ----------
    label1, label2 : str
        Keys of `word2ids`.
    spurious : str
        Object name looked up in each image's object list.
    word2ids : dict
        word -> list of image ids.
    id2objs : dict
        image id -> list of object names.
    """
    def partition(img_ids):
        # Separate ids by whether the spurious object appears in the image
        present, absent = [], []
        for img_id in img_ids:
            bucket = present if spurious in id2objs[img_id] else absent
            bucket.append(str(img_id))
        return present, absent

    label1_ids = word2ids[label1]
    label2_ids = word2ids[label2]

    # Images captioned with both labels are ambiguous, so they are dropped
    only_label1 = np.setdiff1d(label1_ids, label2_ids)
    only_label2 = np.setdiff1d(label2_ids, label1_ids)

    pos_spurious, pos_clean = partition(only_label1)
    neg_spurious, neg_clean = partition(only_label2)

    return {
        '1s': pos_spurious,
        '1ns': pos_clean,
        '0s': neg_spurious,
        '0ns': neg_clean,
    }

In [5]:

# Candidate groups of caption words to contrast:
#   cloudy, sunny, snowy
#   mountain, field, beach, river
#   runway/airport, street
#   kitchen, bathroom

label1, label2 = 'runway', 'street'

# Show the ten objects whose frequency differs most between the two labels
print(compare_words(word1 = label1, word2 = label2, word2ids = word2ids)[:10])


[('airplane', 0.984017654083118), ('car', -0.3438270646869643), ('person', -0.285584405587939), ('traffic light', -0.24937485988238162), ('bus', -0.17644913683344116), ('truck', 0.13954683703286622), ('handbag', -0.13596361169261262), ('motorcycle', -0.12464757660148773), ('bicycle', -0.11915601772811545), ('stop sign', -0.08854591882873832)]


In [6]:


# Object whose presence is treated as the spurious feature when splitting
spurious = 'airplane'

# When True, recreate the output directory and write the splits to disk;
# when False, only print the split sizes
mkdir = True

if mkdir:
    out_dir = '{}/{}-{}/{}'.format(get_data_dir(), label1, label2, spurious)
    # Start from a clean directory.  shutil / pathlib are used instead of
    # shell commands so that names containing spaces (e.g. 'traffic light')
    # are handled correctly.
    if Path(out_dir).exists():
        shutil.rmtree(out_dir)
    for subdir in ('train', 'val'):
        Path(out_dir, subdir).mkdir(parents = True, exist_ok = True)

# (title, output sub-directory, word -> image ids, image id -> objects)
split_inputs = [
    ('Train', 'train', word2ids, id2objs),
    ('Val', 'val', word2ids_val, id2objs_val),
]

for idx, (title, subdir, words_map, objs_map) in enumerate(split_inputs):
    if idx > 0:
        print() # Blank line between the two reports

    print('{} Sizes'.format(title))
    splits = get_splits(label1, label2, spurious, words_map, objs_map)
    for key in splits:
        print(key, len(splits[key]))

    if mkdir:
        with open('{}/{}/splits.json'.format(out_dir, subdir), 'w') as f:
            json.dump(splits, f)

Train Sizes
1s 1120
1ns 14
0s 33
0ns 12510

Val Sizes
1s 34
1ns 2
0s 2
0ns 537
