In [None]:
# Copyright 2022 Sony Semiconductor Solutions Corp. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Download Images from COCO

This notebook explains the workflow for getting images from [COCO](https://cocodataset.org/#home). <br>
<br>
Instructions are described in [README.md](./README.md).

## Imports

In [None]:
import errno
import json
import jsonschema
import os
import time
from contextlib import redirect_stdout
from pathlib import Path

from pycocotools.coco import COCO

## Load Configurations

Load the configuration file and set the variables.

In [None]:
def validate_symlink(path: Path):
    """Reject paths that are symbolic links.

    Args:
        path: File or folder path to check.

    Raises:
        OSError: With ``errno.ELOOP`` when ``path`` is a symbolic link.
    """
    # Regular files, folders and non-existent paths are accepted as they are.
    if not path.is_symlink():
        return

    message = "Symbolic link is not supported. Please use real folder or file"
    raise OSError(errno.ELOOP, message, f"{path}")


# Load the application configuration. Symbolic links are rejected before any
# file is opened.
configuration_path = Path("./configuration.json")
validate_symlink(configuration_path)

# JSON documents are UTF-8 encoded; state the encoding explicitly so that the
# result does not depend on the platform's locale default encoding.
with open(configuration_path, "r", encoding="utf-8") as f:
    app_configuration = json.load(f)

configuration_schema_path = Path("./configuration_schema.json")
validate_symlink(configuration_schema_path)

with open(configuration_schema_path, "r", encoding="utf-8") as f:
    json_schema = json.load(f)

# Validate configuration.
jsonschema.validate(app_configuration, json_schema)

# Set annotation file path (platform path separators are replaced by "/"):
annotation_file = app_configuration["annotation_file"].replace(os.path.sep, "/")
validate_symlink(Path(annotation_file))

# Set category names (optional, empty list when not configured):
category_names = app_configuration.get("category_names", [])

# Set max number of downloads from coco dataset
# (optional, 0 when not configured; limit_number() treats 0 as "no limit"):
max_download_count = app_configuration.get("max_download_count", 0)

# Set target licenses (optional, empty list when not configured):
licenses = app_configuration.get("licenses", [])

# Set category names to remove (optional, empty list when not configured)
remove_categories = app_configuration.get("remove_categories", [])

# Set output directory name (platform path separators are replaced by "/")
output_dir = app_configuration["output_dir"].replace(os.path.sep, "/")
validate_symlink(Path(output_dir))

## Validate the Configuration Value

In [None]:
# Initialize COCO API for instance annotations:
coco = COCO(annotation_file)

# An annotation file may lack the "categories" or "licenses" section.
# Fall back to an empty list instead of failing with a bare KeyError, so that
# a filter that is not configured never requires its section to exist.
categories = coco.dataset.get("categories", [])
licenses_lst = coco.dataset.get("licenses", [])


def get_category(categories, category_name):
    """Tell whether a category with the given name exists.

    Args:
        categories: Iterable of category records, each with a ``"name"`` key.
        category_name: Name to look for.

    Returns:
        True when at least one record has that name, otherwise False.
    """
    return any(category_name == entry["name"] for entry in categories)


def get_license(licenses_lst, license):
    """Tell whether a license with the given ID exists.

    Args:
        licenses_lst: Iterable of license records, each with an ``"id"`` key.
        license: License ID to look for.

    Returns:
        True when at least one record has that ID, otherwise False.
    """
    return any(license == entry["id"] for entry in licenses_lst)


def _keep_valid_values(values, is_valid, setting_name):
    """Filter a configured list down to the values accepted by ``is_valid``.

    Args:
        values: Values read from configuration.json for one setting.
        is_valid: Callable returning True for a value that exists in the
            annotation file.
        setting_name: Name of the setting, used in the messages.

    Returns:
        ``values`` unchanged when it is empty (the setting is not used),
        otherwise a new list with only the valid values, in their
        configured order.

    Raises:
        ValueError: When the setting is configured but none of its values
            is valid.
    """
    if len(values) == 0:
        return values

    valid_values = []
    for value in values:
        if is_valid(value):
            valid_values.append(value)
        else:
            # An invalid value is only reported; the valid ones are still used.
            print(f'[Warning]:The value "{value}" in "{setting_name}" is invalid.')

    if len(valid_values) == 0:
        raise ValueError(
            f'"{setting_name}" is not set to a valid value in configuration.json.'
        )
    return valid_values


# Validate category names:
category_names = _keep_valid_values(
    category_names,
    lambda name: get_category(categories, name),
    "category_names",
)

# Validate target licenses:
licenses = _keep_valid_values(
    licenses,
    lambda license_id: get_license(licenses_lst, license_id),
    "licenses",
)

# Validate category names to remove
remove_categories = _keep_valid_values(
    remove_categories,
    lambda name: get_category(categories, name),
    "remove_categories",
)

## Create Image List to Download

In [None]:
def get_image_ids_by_category(coco, category_names):
    """Collect image IDs, grouped by requested category.

    Args:
        coco: Object providing ``getCatIds(catNms=...)`` and ``getImgIds(...)``.
        category_names: Category names to select; empty selects everything.

    Returns:
        A list of image-ID lists: a single entry holding the result of
        ``coco.getImgIds()`` when no category is given, otherwise one entry
        per category ID returned by ``coco.getCatIds``.
    """
    if len(category_names) == 0:
        return [coco.getImgIds()]

    return [
        coco.getImgIds(catIds=cat_id)
        for cat_id in coco.getCatIds(catNms=category_names)
    ]


def filter_license(image_ids, coco, licenses):
    """Keep only the images released under one of the target licenses.

    Args:
        image_ids: Candidate image IDs.
        coco: Object providing ``loadImgs(ids)``, which returns image records
            with ``"id"`` and ``"license"`` keys.
        licenses: Accepted license IDs; empty disables the filter.

    Returns:
        ``image_ids`` itself when ``licenses`` is empty, otherwise a new list
        with the IDs of the matching images.
    """
    if len(licenses) == 0:
        return image_ids

    return [
        record["id"]
        for record in coco.loadImgs(image_ids)
        if record["license"] in licenses
    ]


def remove_category(image_ids, coco, remove_categories):
    """Drop the images that contain any of the categories to remove.

    Args:
        image_ids: Candidate image IDs. The list is not modified.
        coco: Object providing ``getCatIds(catNms=...)`` and
            ``getImgIds(imgIds=..., catIds=...)``.
        remove_categories: Names of the categories whose images are excluded;
            empty disables the filter.

    Returns:
        ``image_ids`` itself when there is nothing to filter, otherwise a new
        list with the remaining IDs in their original order.
    """
    # Nothing to do without categories to remove. The empty-candidate guard
    # matters because pycocotools treats an empty ``imgIds`` filter as "no
    # restriction", so the lookup below would return IDs that were never
    # candidates.
    if len(remove_categories) == 0 or len(image_ids) == 0:
        return image_ids

    # Union of all candidate images that contain at least one removed category.
    excluded_ids = set()
    for remove_category_id in coco.getCatIds(catNms=remove_categories):
        excluded_ids.update(
            coco.getImgIds(imgIds=image_ids, catIds=remove_category_id)
        )

    # Set membership keeps this linear and leaves the caller's list untouched.
    return [image_id for image_id in image_ids if image_id not in excluded_ids]


def limit_number(image_ids, coco, max_download_count):
    """Cap the number of image IDs.

    Args:
        image_ids: Candidate image IDs.
        coco: Unused; kept so the signature matches the other filter steps.
        max_download_count: Maximum number of IDs to keep; 0 means no limit.

    Returns:
        ``image_ids`` itself when the limit is 0, otherwise a new list with
        the leading ``max_download_count`` IDs.
    """
    unlimited = max_download_count == 0
    return image_ids if unlimited else image_ids[:max_download_count]


# Get Image Ids of specified categories (one list per category)
image_ids_by_category = get_image_ids_by_category(coco, category_names)

# For each category, extract Image IDs according to the settings.
# The steps run in a fixed order: license filter, category removal, limit.
# A category whose candidates are exhausted by a step contributes nothing.
download_image_ids = []
for image_ids in image_ids_by_category:
    image_ids = filter_license(image_ids, coco, licenses)
    if len(image_ids) > 0:
        image_ids = remove_category(image_ids, coco, remove_categories)
        if len(image_ids) > 0:
            image_ids = limit_number(image_ids, coco, max_download_count)
            download_image_ids.extend(image_ids)

# The same image can be selected through several categories; keep each ID once.
download_image_ids = list(set(download_image_ids))

## Download Images from COCO

> **NOTE**
>
> If the download stops with "[Errno 104] Connection reset by peer", run this cell again. <br>
> The download will resume from where it stopped.

> **NOTE**
>
> If you stop cell execution during download, the image may be corrupted. <br>
> You can identify the image whose download was interrupted from the output log (the last `downloading id:` line). <br>
> If the image is corrupted, delete it and run this cell again.

In [None]:
# Download Images
if len(download_image_ids) != 0:
    for i, img_id in enumerate(download_image_ids):
        # Printed before the download starts, so an interrupted image can be
        # identified from the log.
        print("downloading id:", img_id)
        # Silence whatever coco.download() writes to stdout. The devnull
        # handle is opened in the same "with" statement so that it is closed
        # after every image instead of leaking one file handle per iteration.
        with open(os.devnull, "w") as devnull, redirect_stdout(devnull):
            tic = time.time()
            coco.download(output_dir, [img_id])
        print(
            "downloaded {}/{} images (t={:0.1f}s)\n".format(
                i + 1, len(download_image_ids), time.time() - tic
            )
        )

    print("\ndownloaded image Ids:", download_image_ids)
    print("Total count:", len(download_image_ids))
else:
    print(
        "[Warning]:There are no images available for download. "
        + "Change the parameters in configuration.json."
    )