In [None]:
# default_exp annotation.multi_category_adapter

In [None]:
# hide
from nbdev.showdoc import *

In [None]:
# export

import csv
import sys
import argparse
import logging
import shutil
from os.path import join, normpath, sep, getsize, basename
from mlcore.io.core import create_folder, scan_files
from mlcore.annotation.core import Annotation, Region, create_annotation_id

In [None]:
# hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# export

# Default file name of a multi category CSV annotations file.
# Not referenced by the functions below; presumably used by callers — verify.
DEFAULT_ANNOTATIONS_FILE = "annotations.csv"

In [None]:
# export

# Module-level logger used for warnings while reading annotations;
# configure_logging() attaches a stdout handler to it.
logger = logging.getLogger(name=__name__)

# Multi Category Annotation Adapter
> Adapter to read and write CSV annotations for multi-label classification.

In [None]:
# export


def read_annotations(annotations_file, files_source):
    """
    Reads a multi classification CSV annotations file.
    `annotations_file`: the path to the CSV annotation file to read; it must have an `image_name` column
        and may have a `tags` column holding space-separated category labels
    `files_source`: the path to the folder containing the source files; `image_name` is joined onto it
    return: the annotations as dictionary, keyed by the annotation ID created for each file.
        Every tag of a row becomes one region with that single label. Rows whose file cannot be
        resolved (create_annotation_id returns None) are skipped with a warning.
    """
    annotations = {}

    with open(annotations_file, newline='') as csvfile:
        reader = csv.DictReader(csvfile)

        for row in reader:
            file_path = join(files_source, row['image_name'])
            annotation_id = create_annotation_id(file_path)
            if annotation_id is None:
                logger.warning('File not found, skip annotations at path: %s', file_path)
                continue

            # Several rows may refer to the same file; create its annotation only once.
            annotation = annotations.get(annotation_id)
            if annotation is None:
                annotation = Annotation(annotation_id=annotation_id, file_name=basename(file_path),
                                        file_size=getsize(file_path), file_path=file_path)
                annotations[annotation_id] = annotation

            # A missing `tags` column, or a short row (DictReader fills the field with None),
            # means "no labels". Splitting without an argument also drops the empty tokens produced
            # by blank cells or repeated spaces, so no region with an empty label is created.
            tags = row.get('tags') or ''
            for category in tags.split():
                annotation.regions.append(Region(labels=[category]))

    return annotations

In [None]:
# export


def write_annotations(annotations_file, annotations):
    """
    Writes annotations to a multi classification CSV annotations file.
    `annotations_file`: the path to the CSV annotation file to write; an existing file is overwritten
    `annotations`: the annotations to write, as dictionary (only its values are used)
    Writes a header row (`image_name`, `tags`), then one row per annotation. Each row holds the
    annotation's `file_name` and its labels joined by single spaces.
    """
    with open(annotations_file, 'w', newline='') as csvfile:
        fieldnames = ['image_name', 'tags']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
        writer.writeheader()

        writer.writerows({'image_name': annotation.file_name,
                          'tags': ' '.join(annotation.labels())}
                         for annotation in annotations.values())

## Helper Methods

In [None]:
# export


def configure_logging(logging_level=logging.INFO):
    """
    Configures logging for the system.
    `logging_level`: the logging level applied to the module logger and to its stdout handler
    Calling this function again only updates the levels; it never attaches a second stdout handler.
    """
    logger.setLevel(logging_level)

    # Reuse an already attached stdout handler. Otherwise repeated calls (e.g. re-running the
    # notebook cell) would stack handlers and emit every log record several times.
    for existing in logger.handlers:
        if isinstance(existing, logging.StreamHandler) and getattr(existing, 'stream', None) is sys.stdout:
            existing.setLevel(logging_level)
            return

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging_level)

    logger.addHandler(handler)

## Run from command line

To run the multi category annotation adapter from the command line, use the following command:
`python -m mlcore.annotation.multi_category_adapter [parameters]`

The following parameters are supported:
- `[annotation]`: The path to the multi classification CSV annotation file (e.g.: *imagesets/segmentation/car_damage/annotations.csv*)
- `--files_source`: The path to the folder containing the source files (e.g.: *imagesets/segmentation/car_damage/trainval*)

In [None]:
# export


if __name__ == '__main__' and '__file__' in globals():
    # for direct shell execution; the `__file__` check presumably keeps this from running
    # when the notebook itself is executed in a kernel
    configure_logging()

    parser = argparse.ArgumentParser()
    parser.add_argument("annotation",
                        help="The path to the multi classification CSV annotation file.")
    # Required: read_annotations joins every image name onto this folder. Without it the
    # script would fail later with a TypeError from os.path.join(None, ...).
    parser.add_argument("--files_source", required=True,
                        help="The path to the folder containing the source files.")
    args = parser.parse_args()

    read_annotations(args.annotation, args.files_source)
