In [None]:
import json
import os
import random
import shutil
from shutil import copyfile

In [None]:
# Load the COCO train2017 instance annotations. JSON is UTF-8 by specification, so decode it
# explicitly rather than with the platform's locale encoding.
with open("./coco_data/annotations/instances_train2017.json", encoding="utf-8") as json_file:
  object_detections = json.load(json_file)

# Target number of images to keep for each selected category.
COUNT_PER_CATEGORY = 1000
# COCO category IDs for motorcycle, airplane, elephant and tennis racket respectively
category_list = [4, 5, 22, 43]
# category_id -> {image_id: ...} for the selected categories; filled in by the annotation pass.
category_dict = {category_id: dict() for category_id in category_list}

# image_id -> per-image category information, built up by the passes below.
all_images = dict()
# IDs of the images finally selected across all categories (overlaps are merged).
filtered_images = set()

# One pass over every instance annotation:
#  * register the image under each selected category it contains, and
#  * remember, per image, the area of the largest instance of every category it contains.
for annotation in object_detections['annotations']:
  image_id = annotation['image_id']
  category_id = annotation['category_id']
  area = annotation['area']
  if category_id in category_list:
    category_dict[category_id].setdefault(image_id, [])
  largest_areas = all_images.setdefault(image_id, dict())
  # The first instance seeds the entry; later instances only replace it when strictly larger.
  largest_areas[category_id] = max(largest_areas.get(category_id, area), area)

# Replace each image's {category_id: largest area} map with the list of its categories ordered
# from largest to smallest instance area, so list position == prominence rank (0 = most prominent).
# Sorting the keys by their area (instead of mapping each sorted area back with list.index, which
# always returns the first match) keeps every category exactly once even when two categories tie
# on area; ties keep their original insertion order because sorted() is stable.
for image_id, largest_areas in all_images.items():
  # Only values are reassigned here (no keys added/removed), which is safe while iterating.
  all_images[image_id] = sorted(largest_areas, key=largest_areas.get, reverse=True)

# Seed for the per-category shuffle below, so re-running the notebook selects the same images.
SAMPLING_SEED = 0
sampling_rng = random.Random(SAMPLING_SEED)

for category_id in category_dict:
  # Prominence rank of this category in each image that contains it: 0 means the category has
  # the largest instance in the image, 1 the second largest, and so on.
  prominence_ranks = dict()
  for image_id in category_dict[category_id]:
    category_dict[category_id][image_id] = all_images[image_id]
    prominence_ranks[image_id] = all_images[image_id].index(category_id)

  # Keep every image whose rank is no worse than the COUNT_PER_CATEGORY-th best rank, i.e. the
  # smallest rank cut-off that yields at least COUNT_PER_CATEGORY images. A category with fewer
  # images than requested keeps all of them instead of searching forever for more.
  sorted_ranks = sorted(prominence_ranks.values())
  if 0 < COUNT_PER_CATEGORY <= len(sorted_ranks):
    rank_cutoff = sorted_ranks[COUNT_PER_CATEGORY - 1]
    prominent_images = [image_id for image_id, rank in prominence_ranks.items() if rank <= rank_cutoff]
  else:
    prominent_images = list(prominence_ranks)
    if len(prominent_images) < COUNT_PER_CATEGORY:
      print("Warning: category {} has only {} images, fewer than the requested {}".format(
          category_id, len(prominent_images), COUNT_PER_CATEGORY))

  # Sample uniformly among all images within the cut-off (not strictly the best-ranked ones).
  sampling_rng.shuffle(prominent_images)
  filtered_images.update(prominent_images[:COUNT_PER_CATEGORY])

# Images containing several selected categories are counted once, so "actual" can be lower.
print("Expected number of filtered images is {}".format(len(category_list) * COUNT_PER_CATEGORY))
print("Actual number of filtered images is {}".format(len(filtered_images)))

In [None]:
# Load the COCO train2017 captions. JSON is UTF-8 by specification, so decode it explicitly
# rather than with the platform's locale encoding.
with open("./coco_data/annotations/captions_train2017.json", encoding="utf-8") as json_file:
  captions = json.load(json_file)

# Keep only the captions that belong to one of the selected images.
captions['annotations'] = [
    annotation for annotation in captions['annotations']
    if annotation['image_id'] in filtered_images
]
# COCO has roughly five captions per image.
print("Expected number of filtered annotations should be roughly {}".format(len(filtered_images) * 5))
print("Actual number of filtered annotations is {}".format(len(captions['annotations'])))

# Keep only the image records of the selected images, and remember their file names for the copy step.
images = [image for image in captions['images'] if image['id'] in filtered_images]
filtered_image_file_names = {image['file_name'] for image in images}
captions['images'] = images
print("Expected number of filtered images is {}".format(len(filtered_images)))
print("Actual number of filtered images is {}".format(len(captions['images'])))

# Persist the filtered captions file.
with open("./coco_data/captions.json", 'w', encoding="utf-8") as output_file:
  json.dump(captions, output_file)

# Copy only the selected images into their own directory; exist_ok makes a re-run safe.
os.makedirs("./coco_data/images", exist_ok=True)
for file_name in filtered_image_file_names:
  copyfile("./coco_data/train2017/{}".format(file_name), "./coco_data/images/{}".format(file_name))

# Free disk space by removing the full COCO annotations and training images. This runs only after
# every selected image was copied, so a failed copy leaves the raw data in place for a re-run.
# NOTE: once this runs, the earlier cells cannot be re-run without re-downloading the dataset.
shutil.rmtree("./coco_data/annotations", ignore_errors=True)
shutil.rmtree("./coco_data/train2017", ignore_errors=True)