In [None]:
# default_exp annotation.multi_category_adapter

In [None]:
# hide
from nbdev.showdoc import *
from nbdev.export import notebook2script

In [None]:
# export

import csv
import sys
import argparse
import shutil
import logging
from os.path import join, basename, isfile, dirname
from mlcore.io.core import create_folder
from mlcore.annotation.core import Annotation, AnnotationAdapter, Region

In [None]:
# hide
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# export

# Default file name for the CSV annotations file (not referenced in this section of the module).
DEFAULT_ANNOTATIONS_FILE = 'annotations.csv'

In [None]:
# export

# Module-level logger named after this module; `configure_logging` sets its level and attaches a handler.
logger = logging.getLogger(__name__)

# Multi Category Annotation Adapter
> Adapter to read and write annotations for multi-label classification.

The `args` parameter contains the following options:
- `files_path`: the path to the folder containing the source files (e.g.: *data/segmentation/my_collection/trainval*)
- `annotations_file`: the path to the multi classification CSV annotation file (e.g.: *data/segmentation/my_collection/annotations.csv*)

In [None]:
# export


class MultiCategoryAdapter(AnnotationAdapter):
    """
    Adapter to read and write annotations for multi label classification.

    The annotations live in a CSV file with the columns `image_name` and `tags`,
    where `tags` holds the whitespace separated categories of the file.

    `args`: the arguments containing the parameters (`files_path` and `annotations_file`)
    """

    def __init__(self, args):
        super().__init__()
        # folder the source files are read from (read) or copied into (write)
        self.files_path = args.files_path
        # path to the CSV annotations file
        self.annotations_file = args.annotations_file

    def read(self):
        """
        Read annotations from a multi classification CSV annotations file.
        Rows whose source file does not exist in `files_path` are skipped with a warning.
        Rows without any tag produce an annotation without regions.
        return: the annotations as dictionary, keyed by the file path
        """
        annotations = {}
        skipped_annotations = []

        logger.info('Read annotations from {}'.format(self.annotations_file))

        with open(self.annotations_file, newline='') as csv_file:
            reader = csv.DictReader(csv_file)

            for row in reader:
                file_path = join(self.files_path, row['image_name'])
                if not isfile(file_path):
                    logger.warning("{}: Source file not found, skip annotation.".format(file_path))
                    skipped_annotations.append(file_path)
                    continue

                if file_path not in annotations:
                    annotations[file_path] = Annotation(annotation_id=file_path, file_path=file_path)

                annotation = annotations[file_path]

                # The tags value is absent when the column is missing and None when the row is
                # shorter than the header; treat both like an empty tag list instead of crashing.
                tags = row.get('tags') or ''
                # split() without separator ignores repeated/leading/trailing whitespace, so no
                # region with an empty label is created.
                for category in tags.split():
                    region = Region(labels=[category])
                    annotation.regions.append(region)

        logger.info('Finished read annotations')
        logger.info('Annotations read: {}'.format(len(annotations)))
        if skipped_annotations:
            logger.info('Annotations skipped: {}'.format(len(skipped_annotations)))
        return annotations

    def write(self, annotations):
        """
        Writes a multi classification CSV annotations file and copy the corresponding source files.
        Annotations whose source file is missing, or whose target file already exists, are skipped
        with a warning and are not written to the CSV file.
        `annotations`: the annotations to write
        """
        target_folder = create_folder(self.files_path)
        annotations_folder = dirname(self.annotations_file)
        # dirname is empty for a bare file name (current working directory), nothing to create then
        if annotations_folder:
            create_folder(annotations_folder)
        logger.info('Write annotations to {}'.format(self.annotations_file))
        logger.info('Write file sources to {}'.format(target_folder))
        fieldnames = ['image_name', 'tags']
        skipped_annotations = []
        with open(self.annotations_file, 'w', newline='') as csv_file:
            writer = csv.DictWriter(csv_file, fieldnames=fieldnames, delimiter=',', quotechar='"',
                                    quoting=csv.QUOTE_MINIMAL)
            writer.writeheader()
            for annotation in annotations.values():
                target_file = join(target_folder, basename(annotation.file_path))

                if not isfile(annotation.file_path):
                    logger.warning("{}: Source file not found, skip annotation.".format(annotation.file_path))
                    skipped_annotations.append(annotation.file_path)
                    continue
                if isfile(target_file):
                    logger.warning("{}: Target file already exist, skip annotation.".format(annotation.file_path))
                    skipped_annotations.append(annotation.file_path)
                    continue

                # copy the file
                shutil.copy2(annotation.file_path, target_file)
                # only the file name is stored, the folder is given by `files_path` on read
                writer.writerow({'image_name': basename(annotation.file_path),
                                 'tags': ' '.join(annotation.labels())})

        logger.info('Finished write annotations')
        logger.info('Annotations written: {}'.format(len(annotations) - len(skipped_annotations)))
        if skipped_annotations:
            logger.info('Annotations skipped: {}'.format(len(skipped_annotations)))

    @classmethod
    def argparse(cls, prefix=None):
        """
        Returns the argument parser containing argument definition for command line use.
        `prefix`: a parameter prefix to set, if needed
        return: the argument parser
        """
        parser = argparse.ArgumentParser()
        parser.add_argument(cls.assign_prefix('--files_path', prefix),
                            dest="files_path",
                            help="The path to the folder containing the files.",
                            required=True)
        parser.add_argument(cls.assign_prefix('--annotations_file', prefix),
                            dest="annotations_file",
                            help="The path to the multi classification CSV annotation file.",
                            required=True)
        return parser

In [None]:
# Show the documentation of the adapter's public methods in the notebook
# (presumably via nbdev's `show_doc`, imported with `from nbdev.showdoc import *`).
show_doc(MultiCategoryAdapter.read)
show_doc(MultiCategoryAdapter.write)
show_doc(MultiCategoryAdapter.argparse)

## Helper Methods

In [None]:
# export


def configure_logging(logging_level=logging.INFO):
    """
    Configures logging for the system.

    Sets the level of the module logger and makes sure exactly one handler writes to
    the current `sys.stdout`. Calling the function repeatedly only updates the level
    instead of stacking handlers, which would duplicate every log line.

    :param logging_level: The logging level to use.
    """
    logger.setLevel(logging_level)

    # Reuse handlers already bound to the current stdout stream.
    stdout_handlers = [
        existing for existing in logger.handlers
        if isinstance(existing, logging.StreamHandler) and getattr(existing, 'stream', None) is sys.stdout
    ]
    if stdout_handlers:
        for existing in stdout_handlers:
            existing.setLevel(logging_level)
        return

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(logging_level)

    logger.addHandler(handler)


In [None]:
# hide

# for generating scripts from notebook directly
notebook2script()

Converted annotation-core.ipynb.
Converted annotation-folder_category_adapter.ipynb.
Converted annotation-multi_category_adapter.ipynb.
Converted annotation-via_adapter.ipynb.
Converted annotation-viewer.ipynb.
Converted annotation-yolo_adapter.ipynb.
Converted category_tools.ipynb.
Converted core.ipynb.
Converted dataset-core.ipynb.
Converted dataset-image_classification.ipynb.
Converted dataset-image_object_detection.ipynb.
Converted dataset-image_segmentation.ipynb.
Converted dataset-type.ipynb.
Converted dataset_generator.ipynb.
Converted evaluation-core.ipynb.
Converted geometry.ipynb.
Converted image-color_palette.ipynb.
Converted image-inference.ipynb.
Converted image-opencv_tools.ipynb.
Converted image-pillow_tools.ipynb.
Converted image-tools.ipynb.
Converted index.ipynb.
Converted io-core.ipynb.
Converted tensorflow-tflite_converter.ipynb.
Converted tensorflow-tflite_metadata.ipynb.
Converted tensorflow-tfrecord_builder.ipynb.
Converted tools-check_double_images.ipynb.
Conver