In [None]:
import json
import random
from shutil import copyfile

In [None]:
# Choose which COCO training images to keep.
#
# For every category in CATEGORY_LIST this cell collects the images that
# contain at least one annotation of that category and then keeps at most
# COUNT_PER_CATEGORY of them, preferring images in which the category is most
# "prominent" (its largest annotation area ranks highest among the categories
# present in the image). The union of the selected image ids ends up in
# `filtered_images`.
obj_fl = "./coco_data/annotations/instances_train2017.json"
with open(obj_fl) as json_file:
  object_detections = json.load(json_file)

# Category ids (as they appear in the annotation file) whose images are kept.
CATEGORY_LIST = [4, 5, 22, 43]

# Upper bound on the number of images kept per category; -1 keeps every image
# that has at least one annotation of the category.
COUNT_PER_CATEGORY = 1000

# category_id -> {image_id: placeholder}. In the capped branch below the
# placeholder is replaced by the image's categories ordered by prominence.
category_dict = dict()
for category_id in CATEGORY_LIST:
  category_dict[category_id] = dict()

# image_id -> {category_id: largest annotation area of that category}.
all_images = dict()
# Result of this cell: ids of all images selected for any category.
filtered_images = set()

for annotation in object_detections['annotations']:
  category_id = annotation['category_id']
  image_id = annotation['image_id']
  area = annotation['area']
  if category_id in CATEGORY_LIST:
    if image_id not in category_dict[category_id]:
      category_dict[category_id][image_id] = []
  if image_id not in all_images:
    all_images[image_id] = dict()
  # Only the largest area per (image, category) pair is remembered.
  largest_areas = all_images[image_id]
  if category_id not in largest_areas or area > largest_areas[category_id]:
    largest_areas[category_id] = area

if COUNT_PER_CATEGORY == -1:
  # No cap: keep every image that contains the category.
  for category_id in category_dict:
    print("Processing category {}".format(category_id))
    filtered_images.update(category_dict[category_id].keys())
    print("  Filtered total {} images of category {}".format(len(category_dict[category_id].keys()), category_id))
else:
  # Replace each image's {category: area} mapping by the list of its category
  # ids ordered from largest to smallest area. Sorting the ids themselves
  # (rather than looking each area back up with list.index) guarantees that
  # every category appears exactly once even when two categories share the
  # same area; previously such a tie duplicated one category and dropped the
  # other, which made the `.index(category_id)` lookup below raise ValueError.
  for image_id in all_images:
    largest_areas = all_images[image_id]
    all_images[image_id] = sorted(largest_areas, key=largest_areas.get, reverse=True)

  for category_id in category_dict:
    print("Processing category {}".format(category_id))
    for image_id in category_dict[category_id]:
      category_dict[category_id][image_id] = all_images[image_id]
    # prominance_index is the rank of the category inside an image
    # (0 = largest area in the image). Images are consumed rank by rank until
    # the quota is reached or no candidate images remain.
    prominance_index = 0
    prominent_image_ids = []
    while len(category_dict[category_id]) > 0 and len(prominent_image_ids) < COUNT_PER_CATEGORY:
      remaining_count = COUNT_PER_CATEGORY - len(prominent_image_ids)
      image_ids = []
      for image_id in category_dict[category_id]:
        if category_dict[category_id][image_id].index(category_id) == prominance_index:
          image_ids.append(image_id)
      # Candidates at this rank are removed from the pool whether or not they
      # all end up being selected.
      for image_id in image_ids:
        del category_dict[category_id][image_id]
      if len(image_ids) <= remaining_count:
        prominent_image_ids = prominent_image_ids + image_ids
        if prominance_index > 4:
          # Diagnostic: list images in which the category ranks very low.
          print(image_ids)
        print("  Added all {} images at prominance_index {}".format(len(image_ids), prominance_index))
      else:
        # More candidates than quota left: take a random sample of them.
        # NOTE: the `random` module is not seeded in this cell, so the sample
        # differs between runs unless a seed is set beforehand.
        random.shuffle(image_ids)
        prominent_image_ids = prominent_image_ids + image_ids[0:remaining_count]
        print("  Added {} images at prominance_index {} out of {} images".format(remaining_count, prominance_index, len(image_ids)))
      prominance_index = prominance_index + 1
    filtered_images.update(prominent_image_ids)
    print("  Completed filtering of total {} images of category {}".format(len(prominent_image_ids), category_id))

print("Processed all categories. Number of filtered images is {}".format(len(filtered_images)))

In [None]:
caps_fl = "./coco_data/annotations/captions_train2017.json"
with open(caps_fl) as json_file:
  captions = json.load(json_file)

filtered_annotations = []
for annotation in captions['annotations']:
  if annotation['image_id'] in filtered_images:
    filtered_annotations.append(annotation)
captions['annotations'] = filtered_annotations
print("Number of filtered annotations is {}".format(len(captions['annotations'])))

images = []
filtered_image_file_names = set()
for image in captions['images']:
  if image['id'] in filtered_images:
    images.append(image)
    filtered_image_file_names.add(image['file_name'])
captions['images'] = images
print("Expected number of filtered images is {}, actual number is {}".format(len(filtered_images), len(captions['images'])))

with open("./coco_data/captions.json", 'w+') as output_file:
  json.dump(captions, output_file)
!rm -rf ./coco_data/annotations

!mkdir coco_data/images
for file_name in filtered_image_file_names:
  copyfile("./coco_data/train2017/{}".format(file_name),
           "./coco_data/images/{}".format(file_name))
!rm -rf ./coco_data/train2017