# Analysis

In [1]:
import json
import statistics
import random
import string

import pandas as pd
import numpy as np

In [2]:
data_filename = "full.json"

# Load the target -> images mapping. The context manager closes the file
# handle as soon as parsing is done instead of leaving it open.
with open(data_filename) as data_file:
    data_dict = json.load(data_file)

data_dict.keys()

dict_keys(['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'])

In [3]:
stats_dict = {}

# Each target gets a record with these fields (in this order):
#   num_images  - number of images listed for the target
#   num_males   - images whose "gender" is "male"
#   num_females - all remaining images
#   cooccuring  - (other_target, count) pairs, most frequent first
for target, imgs in data_dict.items():
    male_total = 0
    female_total = 0
    pair_counts = {}
    for img_data in imgs.values():
        if img_data["gender"] == "male":
            male_total += 1
        else:
            female_total += 1
        for other_target in img_data["other_targets"]:
            # "person" is skipped: it is not useful as a co-occurring target
            if other_target != "person":
                pair_counts[other_target] = pair_counts.get(other_target, 0) + 1
    stats_dict[target] = {
        "num_images": len(imgs),
        "num_males": male_total,
        "num_females": female_total,
        "cooccuring": sorted(pair_counts.items(), key=lambda pair: pair[1], reverse=True),
    }

list(stats_dict.items())[2]
            

('car',
 {'num_images': 3357,
  'num_males': 2461,
  'num_females': 896,
  'cooccuring': [('truck', 815),
   ('handbag', 660),
   ('traffic light', 504),
   ('skateboard', 497),
   ('bicycle', 493),
   ('motorcycle', 447),
   ('backpack', 413),
   ('bus', 337),
   ('umbrella', 332),
   ('cell phone', 328),
   ('bench', 316),
   ('chair', 180),
   ('tie', 167),
   ('dog', 157),
   ('sports ball', 156),
   ('frisbee', 152),
   ('horse', 152),
   ('fire hydrant', 133),
   ('stop sign', 132),
   ('suitcase', 122),
   ('dining table', 111),
   ('cup', 107),
   ('tennis racket', 101),
   ('bottle', 97),
   ('kite', 95),
   ('parking meter', 88),
   ('potted plant', 81),
   ('baseball glove', 67),
   ('baseball bat', 66),
   ('surfboard', 62),
   ('clock', 45),
   ('skis', 41),
   ('train', 41),
   ('cow', 40),
   ('pizza', 39),
   ('banana', 36),
   ('book', 36),
   ('elephant', 36),
   ('knife', 35),
   ('boat', 34),
   ('donut', 34),
   ('bowl', 32),
   ('bird', 31),
   ('fork', 26),
   ('

In [4]:
# Convert the per-target stats dictionary to a DataFrame.
# orient="index" uses the outer keys (the targets) as the row index and the
# inner keys (num_images, num_males, ...) as the columns.

df = pd.DataFrame.from_dict(stats_dict, orient="index")
df

Unnamed: 0,num_images,num_males,num_females,cooccuring
person,33912,24086,9826,"[(chair, 4157), (car, 3357), (tennis racket, 2..."
bicycle,1312,964,348,"[(car, 493), (handbag, 319), (backpack, 310), ..."
car,3357,2461,896,"[(truck, 815), (handbag, 660), (traffic light,..."
motorcycle,1283,1087,196,"[(car, 447), (truck, 201), (bicycle, 179), (ha..."
airplane,180,152,28,"[(truck, 41), (handbag, 23), (suitcase, 23), (..."
...,...,...,...,...
vase,423,194,229,"[(chair, 179), (cup, 164), (dining table, 155)..."
scissors,236,129,107,"[(chair, 48), (bottle, 38), (cup, 38), (dining..."
teddy bear,349,136,213,"[(chair, 78), (couch, 51), (bed, 50), (handbag..."
hair drier,72,18,54,"[(sink, 16), (bottle, 12), (chair, 11), (couch..."


In [5]:
# Show the targets ordered by image count, largest first.
# sort_values returns a new DataFrame here; df itself is not reordered.
df.sort_values(by=["num_images"], ascending=False)

Unnamed: 0,num_images,num_males,num_females,cooccuring
person,33912,24086,9826,"[(chair, 4157), (car, 3357), (tennis racket, 2..."
chair,4157,2504,1653,"[(dining table, 1459), (cup, 947), (bottle, 77..."
car,3357,2461,896,"[(truck, 815), (handbag, 660), (traffic light,..."
tennis racket,2926,1775,1151,"[(sports ball, 1566), (chair, 616), (bench, 22..."
handbag,2896,1266,1630,"[(car, 660), (backpack, 592), (cell phone, 551..."
...,...,...,...,...
broccoli,88,41,47,"[(bowl, 45), (dining table, 41), (cup, 34), (k..."
hair drier,72,18,54,"[(sink, 16), (bottle, 12), (chair, 11), (couch..."
bear,37,23,14,"[(chair, 6), (handbag, 4), (dining table, 4), ..."
toaster,32,18,14,"[(bottle, 19), (oven, 15), (cup, 12), (microwa..."


In [6]:
# Smaller of the male/female counts for each target.
df['min_num_gender'] = df[['num_males', 'num_females']].min(axis=1)
# reindex returns a new DataFrame rather than modifying df, so the result
# must be assigned back for the requested column order to take effect.
df = df.reindex(columns = ["num_images", "num_males", "num_females", "min_num_gender", "cooccuring"])
df.sort_values(by=["num_images"], ascending=False)

Unnamed: 0,num_images,num_males,num_females,cooccuring,min_num_gender
person,33912,24086,9826,"[(chair, 4157), (car, 3357), (tennis racket, 2...",9826
chair,4157,2504,1653,"[(dining table, 1459), (cup, 947), (bottle, 77...",1653
car,3357,2461,896,"[(truck, 815), (handbag, 660), (traffic light,...",896
tennis racket,2926,1775,1151,"[(sports ball, 1566), (chair, 616), (bench, 22...",1151
handbag,2896,1266,1630,"[(car, 660), (backpack, 592), (cell phone, 551...",1266
...,...,...,...,...,...
broccoli,88,41,47,"[(bowl, 45), (dining table, 41), (cup, 34), (k...",41
hair drier,72,18,54,"[(sink, 16), (bottle, 12), (chair, 11), (couch...",18
bear,37,23,14,"[(chair, 6), (handbag, 4), (dining table, 4), ...",14
toaster,32,18,14,"[(bottle, 19), (oven, 15), (cup, 12), (microwa...",14


In [7]:
# Export the stats to CSV, ordered by image count (largest first).
# The target names are the DataFrame index and are written as the first CSV column.
df.sort_values(by=["num_images"], ascending=False).to_csv("mscoco_gender_stats.csv")

# Multi-Label to Single-Label

In [8]:
# One (target, set of image ids) pair per target; "person" is left out.
data_list = [
    (target, set(imgs.keys()))
    for target, imgs in data_dict.items()
    if target != "person"
]

data_list

[('bicycle',
  {'466247',
   '396137',
   '190313',
   '296901',
   '67470',
   '540206',
   '215738',
   '358361',
   '239828',
   '399922',
   '309526',
   '225686',
   '16123',
   '575227',
   '21647',
   '490847',
   '24621',
   '440087',
   '99129',
   '235281',
   '366789',
   '428111',
   '36273',
   '67208',
   '566064',
   '497415',
   '218305',
   '279278',
   '514191',
   '82246',
   '72156',
   '414510',
   '20172',
   '124859',
   '113212',
   '578705',
   '413556',
   '400822',
   '76942',
   '558213',
   '117371',
   '320467',
   '232329',
   '555002',
   '25506',
   '386326',
   '380447',
   '580191',
   '521689',
   '530758',
   '82312',
   '7524',
   '400915',
   '467000',
   '12993',
   '437485',
   '562876',
   '225329',
   '396853',
   '340998',
   '313502',
   '43773',
   '471572',
   '258552',
   '73527',
   '455393',
   '239656',
   '505501',
   '437073',
   '505579',
   '160596',
   '468505',
   '575088',
   '504021',
   '373212',
   '260627',
   '58143',
   '1

In [9]:
# Order the targets by the size of their image-id set, largest first.
# The sort is in place and stable, so targets with equally many images keep
# their existing relative order.
data_list.sort(key=lambda x: len(x[1]), reverse=True)
data_list

[('chair',
  {'376983',
   '461567',
   '1591',
   '413623',
   '31082',
   '572907',
   '306335',
   '575768',
   '560323',
   '66468',
   '412281',
   '219880',
   '577819',
   '143636',
   '3109',
   '513053',
   '175804',
   '175102',
   '331370',
   '289012',
   '333899',
   '102655',
   '66172',
   '347542',
   '194184',
   '226119',
   '107686',
   '364941',
   '103348',
   '253955',
   '313541',
   '528411',
   '288964',
   '9226',
   '291366',
   '437720',
   '383890',
   '282599',
   '102256',
   '379332',
   '204147',
   '303429',
   '204219',
   '201111',
   '294564',
   '84546',
   '42528',
   '385016',
   '516416',
   '170476',
   '452380',
   '343570',
   '71244',
   '561324',
   '158956',
   '416478',
   '194310',
   '189845',
   '280377',
   '64263',
   '360306',
   '347263',
   '242931',
   '272082',
   '146078',
   '11661',
   '94268',
   '87101',
   '312889',
   '114510',
   '167028',
   '197525',
   '252036',
   '357526',
   '426975',
   '265934',
   '563839',
   '

In [10]:
from collections import Counter
from itertools import chain

def get_sym_diff(sets):
    """Return the elements that occur exactly once across all given collections.

    For two sets this equals their symmetric difference. For more than two it
    keeps only the elements that belong to a single collection.

    Args:
        sets: iterable of iterables (e.g. a list of sets) of hashable items.

    Returns:
        A list of the items counted exactly once, in first-seen order.
    """
    occurrences = Counter()
    for group in sets:
        occurrences.update(group)
    return [item for item in occurrences if occurrences[item] == 1]

In [11]:
# Greedily pick targets, visiting them in the order of data_list.
# A target is kept only if adding its image set increases the number of
# images that belong to exactly one of the kept sets.
selected_image_sets = []
saved_imgs_len = 0

for target, img_set in data_list:
    # images that would occur in exactly one set if this target were added
    sym_diff = get_sym_diff([*([p[1] for p in selected_image_sets]), img_set])
    if len(sym_diff) > saved_imgs_len:
        saved_imgs_len = len(sym_diff)
        selected_image_sets.append((target, img_set))

# image ids that occur in exactly one of the kept sets (returned as a list)
selected_images = get_sym_diff(([p[1] for p in selected_image_sets]))
print(len(selected_images), saved_imgs_len)
selected_targets = [p[0] for p in selected_image_sets]
print(len(selected_targets), selected_targets)

    

20028 20028
35 ['chair', 'car', 'tennis racket', 'handbag', 'skateboard', 'cell phone', 'surfboard', 'bottle', 'backpack', 'tie', 'skis', 'motorcycle', 'baseball glove', 'couch', 'frisbee', 'horse', 'snowboard', 'kite', 'pizza', 'bed', 'boat', 'banana', 'donut', 'elephant', 'train', 'hot dog', 'cow', 'toothbrush', 'scissors', 'toilet', 'giraffe', 'airplane', 'sheep', 'hair drier', 'bear']


In [12]:
filtered_data_dict = {}
count = 0
# selected_images and selected_targets are lists; membership tests against a
# list scan it element by element. Build sets once so each lookup is O(1).
selected_image_ids = set(selected_images)
selected_target_names = set(selected_targets)
for target in selected_targets:
    filtered_data_dict[target] = {}
    for img_id, img_data in data_dict[target].items():
        if img_id in selected_image_ids:
            # Single-label check: a kept image must not list any other selected target.
            assert not any(other_target in selected_target_names for other_target in img_data["other_targets"])
            filtered_data_dict[target][img_id] = img_data
            count += 1

count

20028

In [13]:
filtered_stats_dict = {}

# Same record layout as the full-dataset stats:
#   num_images, num_males, num_females, cooccuring
for target, imgs in filtered_data_dict.items():
    male_total = 0
    female_total = 0
    companion_counts = {}
    for img_data in imgs.values():
        if img_data["gender"] == "male":
            male_total += 1
        else:
            female_total += 1
        # tally every co-listed target except "person", which is not useful here
        for other_target in img_data["other_targets"]:
            if other_target == "person":
                continue
            companion_counts[other_target] = companion_counts.get(other_target, 0) + 1
    ranked_companions = list(companion_counts.items())
    ranked_companions.sort(key=lambda entry: entry[1], reverse=True)
    filtered_stats_dict[target] = {
        "num_images": len(imgs),
        "num_males": male_total,
        "num_females": female_total,
        "cooccuring": ranked_companions,
    }

list(filtered_stats_dict.items())[0]


('chair',
 {'num_images': 875,
  'num_males': 469,
  'num_females': 406,
  'cooccuring': [('dining table', 381),
   ('cup', 223),
   ('cake', 155),
   ('laptop', 140),
   ('tv', 134),
   ('bowl', 132),
   ('knife', 131),
   ('book', 122),
   ('spoon', 107),
   ('remote', 106),
   ('fork', 106),
   ('potted plant', 78),
   ('wine glass', 74),
   ('umbrella', 65),
   ('mouse', 61),
   ('keyboard', 53),
   ('sandwich', 50),
   ('bench', 46),
   ('clock', 45),
   ('vase', 42),
   ('baseball bat', 39),
   ('refrigerator', 38),
   ('dog', 37),
   ('sports ball', 30),
   ('oven', 29),
   ('teddy bear', 29),
   ('cat', 26),
   ('suitcase', 26),
   ('sink', 21),
   ('microwave', 19),
   ('apple', 17),
   ('orange', 14),
   ('truck', 14),
   ('carrot', 12),
   ('bicycle', 10),
   ('bird', 10),
   ('broccoli', 9),
   ('bus', 7),
   ('fire hydrant', 2),
   ('zebra', 2),
   ('toaster', 2),
   ('traffic light', 2),
   ('stop sign', 1),
   ('parking meter', 1)]})

In [14]:
# Rebuild df from the filtered (single-label) stats; this rebinds the name df
# that previously held the full-dataset stats.
df = pd.DataFrame.from_dict(filtered_stats_dict, orient="index")
# sort_values is ascending by default, so the smallest targets are listed first
df.sort_values("num_images")

Unnamed: 0,num_images,num_males,num_females,cooccuring
bear,23,16,7,"[(bicycle, 1), (dog, 1), (truck, 1)]"
hair drier,32,5,27,"[(sink, 4), (dog, 3), (bowl, 2), (book, 2), (t..."
sheep,98,77,21,"[(dog, 17), (truck, 5), (bench, 3), (bus, 3), ..."
airplane,106,99,7,"[(truck, 21), (suitcase, 8), (bench, 4), (bird..."
scissors,111,62,49,"[(cup, 8), (knife, 8), (book, 7), (bowl, 7), (..."
giraffe,127,65,62,"[(bench, 6), (bird, 5), (cup, 2), (carrot, 2),..."
toilet,134,94,40,"[(sink, 47), (book, 6), (cup, 5), (bowl, 4), (..."
cow,151,109,42,"[(truck, 9), (dog, 8), (bicycle, 6), (bird, 5)..."
train,153,129,24,"[(traffic light, 11), (suitcase, 10), (bench, ..."
hot dog,164,106,58,"[(sandwich, 36), (dining table, 26), (cup, 16)..."


In [15]:
# Smaller of the male/female counts for each target.
df['min_num_gender'] = df[['num_males', 'num_females']].min(axis=1)
# reindex returns a new DataFrame rather than modifying df, so the result
# must be assigned back for the requested column order to take effect.
df = df.reindex(columns = ["num_images", "num_males", "num_females", "min_num_gender", "cooccuring"])
df.sort_values(by=["num_images"], ascending=False)

Unnamed: 0,num_images,num_males,num_females,cooccuring,min_num_gender
surfboard,2127,1824,303,"[(dog, 37), (bird, 26), (umbrella, 13), (bicyc...",303
tennis racket,2074,1244,830,"[(sports ball, 1109), (bench, 101), (clock, 38...",830
skateboard,1909,1868,41,"[(bench, 155), (bicycle, 88), (truck, 29), (um...",41
tie,1163,1048,115,"[(cup, 54), (book, 41), (clock, 33), (dining t...",115
skis,1081,811,270,"[(dog, 21), (bench, 15), (stop sign, 2), (truc...",270
baseball glove,989,961,28,"[(baseball bat, 592), (sports ball, 480), (ben...",28
cell phone,888,540,348,"[(cup, 69), (laptop, 46), (bench, 44), (dining...",348
chair,875,469,406,"[(dining table, 381), (cup, 223), (cake, 155),...",406
frisbee,797,637,160,"[(dog, 80), (bench, 47), (umbrella, 16), (spor...",160
snowboard,745,693,52,"[(bench, 6), (truck, 3), (fire hydrant, 1), (r...",52


In [16]:
# Export the single-label stats to CSV, ordered by image count (largest first).
# NOTE(review): assumes the single_label/ directory already exists — confirm before running on a fresh checkout.
df.sort_values(by=["num_images"], ascending=False).to_csv("single_label/mscoco_single_gender_stats.csv")

In [17]:
# Sum, across all targets, of the smaller of the male/female image counts.
df["min_num_gender"].sum()

4829

In [18]:
# Load all annotations; the context manager closes the file once parsed.
with open("full_annotations.json") as annotations_file:
    full_annotations = json.load(annotations_file)

selected_annotations = {}

# selected_images is a list; a set makes each membership test O(1) instead
# of a scan over every selected id.
selected_image_ids = set(selected_images)

for img_id, img_data in full_annotations.items():
    if img_id in selected_image_ids:
        # shallow copy so later edits to the selection leave full_annotations untouched
        selected_annotations[img_id] = img_data.copy()

len(selected_annotations)


20028

In [19]:
# Export the selected annotations to JSON (indent=4 pretty-prints the file).
# NOTE(review): assumes the single_label/ directory already exists — confirm.
with open("single_label/mscoco_single_annotations.json", "w") as outfile:
    json.dump(selected_annotations, outfile, indent=4)

# Get top-10 targets for single-label

In [20]:
# Rank the targets by their smaller gender count (min_num_gender), highest
# first, and keep the names of the first ten as a plain list.
top_10_targets_balanced = df.sort_values(by=["min_num_gender"], ascending=False).head(10).index.tolist()
top_10_targets_balanced

['tennis racket',
 'chair',
 'cell phone',
 'surfboard',
 'bottle',
 'skis',
 'horse',
 'bed',
 'car',
 'couch']

In [21]:
# Image ids that occur in exactly one of the top-10 targets' image sets.
# Uniqueness is judged only among these ten sets, not among all previously
# selected targets.
top_10_images = get_sym_diff(([p[1] for p in selected_image_sets if p[0] in top_10_targets_balanced]))
len(top_10_images)

14999

In [22]:
def filter_data_dict(selected_targets, selected_images, data_dict):
    """Restrict data_dict to the chosen targets and the chosen image ids.

    Args:
        selected_targets: target names to keep; the result has one key per
            name, in this order.
        selected_images: collection of hashable image ids to keep.
        data_dict: mapping target -> {img_id: img_data}, where each img_data
            has an "other_targets" list.

    Returns:
        dict target -> {img_id: img_data} holding only the kept images. The
        img_data objects are the same objects as in data_dict (not copies).

    Raises:
        AssertionError: if a kept image lists another selected target in
            its "other_targets".
        KeyError: if a selected target is missing from data_dict.

    Also prints the total number of kept images.
    """
    # Callers pass lists here; build sets once so every membership test is
    # O(1) instead of a scan over all selected ids or targets.
    wanted_image_ids = set(selected_images)
    wanted_targets = set(selected_targets)

    filtered_data_dict = {}
    count = 0
    for target in selected_targets:
        filtered_data_dict[target] = {}
        for img_id, img_data in data_dict[target].items():
            if img_id in wanted_image_ids:
                # Single-label check: no other selected target may appear in this image.
                assert not any(other_target in wanted_targets for other_target in img_data["other_targets"])
                filtered_data_dict[target][img_id] = img_data
                count += 1
    print("number of images: ", count)
    return filtered_data_dict

#columns:
# 1. num_images
# 2. num_males
# 3. num_females
# 4. cooccuring
def get_stats_dict(data_dict):
    """Summarise every target of a target -> {img_id: img_data} mapping.

    Args:
        data_dict: mapping target -> {img_id: img_data}; each img_data has a
            "gender" value and an "other_targets" list.

    Returns:
        dict target -> record with the keys
            "num_images":  number of images listed for the target,
            "num_males":   images whose gender is "male",
            "num_females": all remaining images,
            "cooccuring":  list of (other_target, count) tuples sorted by
                           count, highest first; "person" is never counted.
    """
    summary = {}
    for label, images in data_dict.items():
        male_total = 0
        female_total = 0
        tally = {}
        for info in images.values():
            if info["gender"] == "male":
                male_total += 1
            else:
                female_total += 1
            for companion in info["other_targets"]:
                # "person" is not useful as a co-occurring target
                if companion != "person":
                    tally[companion] = tally.get(companion, 0) + 1
        summary[label] = {
            "num_images": len(images),
            "num_males": male_total,
            "num_females": female_total,
            "cooccuring": sorted(tally.items(), key=lambda pair: pair[1], reverse=True),
        }
    return summary
    

In [23]:
# Keep, for each top-10 target, only the images in top_10_images, then
# compute the per-target statistics of that subset.
top_10_filtered_dict = filter_data_dict(top_10_targets_balanced, top_10_images, data_dict)
top_10_stats_dict = get_stats_dict(top_10_filtered_dict)

number of images:  14999


In [24]:
# Rebind df to the top-10 stats (one row per target) and show it with the
# largest targets first.
df = pd.DataFrame.from_dict(top_10_stats_dict, orient="index")
df.sort_values("num_images", ascending=False)

Unnamed: 0,num_images,num_males,num_females,cooccuring
car,2477,1895,582,"[(truck, 614), (skateboard, 471), (handbag, 46..."
surfboard,2347,2017,330,"[(boat, 89), (kite, 82), (dog, 46), (backpack,..."
tennis racket,2171,1309,862,"[(sports ball, 1149), (bench, 120), (backpack,..."
chair,1900,1179,721,"[(dining table, 700), (cup, 417), (tie, 218), ..."
skis,1553,1205,348,"[(backpack, 311), (snowboard, 157), (dog, 30),..."
cell phone,1497,863,634,"[(handbag, 316), (backpack, 139), (bench, 113)..."
bottle,1232,727,505,"[(cup, 359), (bowl, 351), (dining table, 229),..."
horse,837,558,279,"[(tie, 41), (handbag, 40), (truck, 40), (cow, ..."
couch,510,286,224,"[(remote, 247), (book, 114), (laptop, 90), (cu..."
bed,475,207,268,"[(book, 103), (laptop, 63), (dog, 44), (cat, 4..."


In [25]:
# Total image count over the top-10 targets.
df["num_images"].sum()

14999

In [26]:
# NOTE(review): this displays the whole nested dict (every image of every top-10 target); consider showing a single entry instead.
top_10_filtered_dict

{'tennis racket': {'471762': {'file_name': 'COCO_train2014_000000471762.jpg',
   'gender': 'male',
   'target': 'tennis racket',
   'other_targets': ['sports ball', 'person']},
  '353754': {'file_name': 'COCO_train2014_000000353754.jpg',
   'gender': 'male',
   'target': 'tennis racket',
   'other_targets': ['sports ball', 'person']},
  '170601': {'file_name': 'COCO_train2014_000000170601.jpg',
   'gender': 'male',
   'target': 'tennis racket',
   'other_targets': ['sports ball', 'person']},
  '442961': {'file_name': 'COCO_train2014_000000442961.jpg',
   'gender': 'male',
   'target': 'tennis racket',
   'other_targets': ['person']},
  '484313': {'file_name': 'COCO_train2014_000000484313.jpg',
   'gender': 'male',
   'target': 'tennis racket',
   'other_targets': ['sports ball', 'person']},
  '255550': {'file_name': 'COCO_train2014_000000255550.jpg',
   'gender': 'male',
   'target': 'tennis racket',
   'other_targets': ['person', 'sports ball']},
  '158821': {'file_name': 'COCO_train2

# Balance Datasets

In [27]:
# Seed the module-level generator so the sampling below is repeatable when
# the notebook is run top to bottom.
random.seed(42)

def get_balanced_target_dict(imgs_dict):
    """Down-sample one target's images to equal male and non-male counts.

    Images whose "gender" is "male" form one group and all other images form
    the second group. The same number of images, the size of the smaller
    group, is drawn at random from each.

    Args:
        imgs_dict: mapping img_id -> img_data, where img_data is a dict with
            a "gender" key.

    Returns:
        dict img_id -> shallow copy of img_data for the sampled images, in
        the iteration order of imgs_dict.

    Uses the global `random` state and prints the per-group sample size.
    """
    male_image_ids = []
    female_image_ids = []
    for img_id, img_data in imgs_dict.items():
        if img_data["gender"] == "male":
            male_image_ids.append(img_id)
        else:
            female_image_ids.append(img_id)
    balance_threshold = min(len(male_image_ids), len(female_image_ids))
    print("balance threshold: ", balance_threshold)

    # Same two random.sample calls, in the same order, as before, so a given
    # seed selects the same images. Collecting the ids in one set makes the
    # filter below O(1) per image instead of scanning two lists.
    kept_ids = set(random.sample(male_image_ids, balance_threshold))
    kept_ids.update(random.sample(female_image_ids, balance_threshold))

    return {
        img_id: img_data.copy()
        for img_id, img_data in imgs_dict.items()
        if img_id in kept_ids
    }

In [28]:
# Balance every top-10 target independently (one "balance threshold" line is
# printed per target, in the order of top_10_filtered_dict).
top_10_balanced_dict = {
    target: get_balanced_target_dict(imgs_dict)
    for target, imgs_dict in top_10_filtered_dict.items()
}

balance threshold:  862
balance threshold:  721
balance threshold:  634
balance threshold:  330
balance threshold:  505
balance threshold:  348
balance threshold:  279
balance threshold:  207
balance threshold:  582
balance threshold:  224


In [29]:
# Recompute the per-target statistics on the balanced subset and show them
# with the largest targets first (df is rebound to the balanced stats).
top_10_balanced_stats_dict = get_stats_dict(top_10_balanced_dict)
df = pd.DataFrame.from_dict(top_10_balanced_stats_dict, orient="index")
df.sort_values("num_images", ascending=False)

Unnamed: 0,num_images,num_males,num_females,cooccuring
tennis racket,1724,862,862,"[(sports ball, 903), (bench, 102), (backpack, ..."
chair,1442,721,721,"[(dining table, 561), (cup, 328), (knife, 177)..."
cell phone,1268,634,634,"[(handbag, 289), (backpack, 109), (bench, 96),..."
car,1164,582,582,"[(handbag, 316), (truck, 284), (traffic light,..."
bottle,1010,505,505,"[(cup, 304), (bowl, 300), (dining table, 196),..."
skis,696,348,348,"[(backpack, 136), (snowboard, 71), (bench, 15)..."
surfboard,660,330,330,"[(boat, 27), (dog, 19), (backpack, 16), (kite,..."
horse,558,279,279,"[(tie, 29), (handbag, 29), (truck, 27), (cow, ..."
couch,448,224,224,"[(remote, 219), (book, 102), (laptop, 82), (cu..."
bed,414,207,207,"[(book, 91), (laptop, 52), (cat, 41), (dog, 40..."


In [30]:
# Total image count over the balanced top-10 targets.
df["num_images"].sum()

9384

In [31]:
# Every image id that appears under any balanced top-10 target.
top_10_balanced_images = {
    img_id
    for imgs_dict in top_10_balanced_dict.values()
    for img_id in imgs_dict.keys()
}

len(top_10_balanced_images)

9384

In [46]:
#export balanced annotations to json

# The context manager closes the file once the JSON has been parsed.
with open("full_annotations.json") as annotations_file:
    full_annotations = json.load(annotations_file)

selected_annotations = {}

# Earlier variant, kept for reference: it selected from full_annotations.
# for img_id, img_data in full_annotations.items():
#     if img_id in top_10_balanced_images and img_data["target"] in top_10_targets_balanced:
#         selected_annotations[img_id] = img_data.copy()

# Flatten the balanced per-target dicts into one img_id -> img_data mapping.
# Shallow copies let keys be added later without touching top_10_balanced_dict.
for target, imgs_dict in top_10_balanced_dict.items():
    for img_id, img_data in imgs_dict.items():
        selected_annotations[img_id] = img_data.copy()

len(selected_annotations)

9384

In [47]:
# Export the selected annotations to JSON.
# The same path is written again further down, after target ids are added.
with open("single_label/balanced_annotations.json", "w") as outfile:
    json.dump(selected_annotations, outfile, indent=4)

In [34]:
# Export the balanced dict as well (target -> {img_id: img_data}).
# This file is read back when the train/val/test split is prepared.
with open("single_label/balanced.json", "w") as outfile:
    json.dump(top_10_balanced_dict, outfile, indent=4)

In [48]:
# Map each top-10 target name to its position in top_10_targets_balanced.
target_to_id = {
    target: position
    for position, target in enumerate(top_10_targets_balanced)
}

target_to_id

{'tennis racket': 0,
 'chair': 1,
 'cell phone': 2,
 'surfboard': 3,
 'bottle': 4,
 'skis': 5,
 'horse': 6,
 'bed': 7,
 'car': 8,
 'couch': 9}

In [49]:
# Add the numeric target id to every annotation. The entries of
# selected_annotations are modified in place; a "target" value missing from
# target_to_id raises KeyError.
for img_id, img_data in selected_annotations.items():
    img_data["target_id"] = target_to_id[img_data["target"]]

# Export the selected annotations to JSON, overwriting the earlier export at
# this path with the version that includes target_id.
with open("single_label/balanced_annotations.json", "w") as outfile:
    json.dump(selected_annotations, outfile, indent=4)

# Train-Val-Test Split

In [50]:
# Load in annotations (target -> {img_id: img_data}); the context manager
# closes the file once the JSON has been parsed.
with open("single_label/balanced.json") as balanced_file:
    target_img_dict = json.load(balanced_file)
len(target_img_dict)

10

In [54]:
# Add the target index to each image, modifying the loaded entries in place.
# Relies on target_to_id still being defined in the kernel from the earlier cell.
for target, imgs_dict in target_img_dict.items():
    for img_id, img_data in imgs_dict.items():
        img_data["target_id"] = target_to_id[target]

In [64]:
# balance modes: balanced, imbalanced_1 (first half male), imbalanced_2 (first half female)
def train_val_test_split(data_dict, test_ratio, val_ratio, random_seed = 42):
    """Split every target's images into test, validation and three train sets.

    Within each target the images are divided into "male" and all other
    genders, and each group is shuffled. Split sizes are derived from the
    size of the male group and applied to both groups:
      * test / val: the first int(size * ratio) images of each group;
      * balanced train: half of the remaining count from each group;
      * imbalanced1 train: the full remaining count, taken from the male
        group for targets in the first half of data_dict and from the other
        group for the rest;
      * imbalanced2 train: the mirror image of imbalanced1.

    Args:
        data_dict: mapping target -> {img_id: img_data with a "gender" key}.
        test_ratio: fraction of each gender group used for the test set.
        val_ratio: fraction of each gender group used for the validation set.
        random_seed: value used to seed the global `random` module.

    Returns:
        Tuple of five dicts img_id -> img_data:
        (balanced_train, imbalanced1_train, imbalanced2_train, val, test).
    """
    random.seed(random_seed)

    balanced_items = []
    imbalanced1_items = []
    imbalanced2_items = []
    val_items = []
    test_items = []

    halfway = len(data_dict) / 2

    for position, imgs in enumerate(data_dict.values()):
        males = []
        females = []
        for pair in imgs.items():
            (males if pair[1]["gender"] == "male" else females).append(pair)

        # shuffle order (males first, then the rest) determines the random draw
        random.shuffle(males)
        random.shuffle(females)

        # all sizes are based on the male group
        group_size = len(males)
        test_count = int(group_size * test_ratio)
        val_count = int(group_size * val_ratio)
        train_count = group_size - test_count - val_count
        train_start = test_count + val_count
        half_train = train_count // 2

        # targets in the first half are all-male in imbalanced1 and
        # all-female in imbalanced2; the second half is the reverse
        if position < halfway:
            first_half_count, second_half_count = train_count, 0
        else:
            first_half_count, second_half_count = 0, train_count

        for pool in (males, females):
            test_items.extend(pool[:test_count])
            val_items.extend(pool[test_count:train_start])
            balanced_items.extend(pool[train_start:train_start + half_train])

        imbalanced1_items.extend(males[train_start:train_start + first_half_count])
        imbalanced1_items.extend(females[train_start:train_start + second_half_count])

        imbalanced2_items.extend(males[train_start:train_start + second_half_count])
        imbalanced2_items.extend(females[train_start:train_start + first_half_count])

    return (
        dict(balanced_items),
        dict(imbalanced1_items),
        dict(imbalanced2_items),
        dict(val_items),
        dict(test_items),
    )
        

In [65]:
# Run the split with 20% test and 20% validation (default seed), then print
# the number of images in each resulting set.
balanced_train_set, imbalanced1_train_set, imbalanced2_train_set, val_set, test_set = train_val_test_split(target_img_dict, test_ratio=0.2, val_ratio=0.2)
print(len(balanced_train_set), len(imbalanced1_train_set), len(imbalanced2_train_set), len(val_set), len(test_set))

2820 2824 2824 1868 1868


In [66]:
# export train, val, test sets to json (one file per split, written in this order)
split_exports = [
    ("single_label/balanced_train.json", balanced_train_set),
    ("single_label/imbalanced1_train.json", imbalanced1_train_set),
    ("single_label/imbalanced2_train.json", imbalanced2_train_set),
    ("single_label/val.json", val_set),
    ("single_label/test.json", test_set),
]

for export_path, split_data in split_exports:
    with open(export_path, "w") as outfile:
        json.dump(split_data, outfile, indent=4)

In [58]:
# Export the target -> id mapping to JSON so the numeric ids stored in the
# annotations can be translated back to target names.
with open("single_label/target_to_id.json", "w") as outfile:
    json.dump(target_to_id, outfile, indent=4)