In [None]:
import os
import numpy as np
import random
import warnings

from xml.etree import ElementTree

# Mask RCNN
from mrcnn.config import Config
from mrcnn import utils
from mrcnn.utils import extract_bboxes
from mrcnn.visualize import display_instances
import mrcnn.model as modellib

# Enable mixed precision globally.
# The policy is set at import time, before any model is built further down.
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')

# Ignore the transparency warning since it really doesn't affect the training.
# The filter matches any warning whose message contains
# "Palette images with Transparency" (presumably raised by the image loader
# for palette-mode PNGs -- confirm).
warnings.filterwarnings("ignore", message=".*Palette images with Transparency.*")

In [None]:
#################################################
## DATASET
#################################################

class FoodDatasets(utils.Dataset):
    """Fruit detection dataset (apple, banana, orange) built from XML box annotations.

    Expected layout: ``<dataset_dir>/images/<id>.<ext>`` with one matching
    ``<dataset_dir>/annots/<id>.xml`` per image, where ``<id>`` is an integer.
    Each annotated bounding box is turned into a filled rectangular mask.
    """

    def load_dataset(self, dataset_dir, is_train=True, val_start_id=250):
        """Register the classes and the images belonging to one split.

        Args:
            dataset_dir: Directory containing the ``images`` and ``annots`` folders.
            is_train: If truthy, register images whose numeric id is below
                ``val_start_id``; otherwise register the remaining images.
            val_start_id: First numeric image id that belongs to the
                validation split. Defaults to 250.
        """
        # Classes
        self.add_class('dataset', 1, 'apple')
        self.add_class('dataset', 2, 'banana')
        self.add_class('dataset', 3, 'orange')

        # Dataset Dir
        images_dir = os.path.join(dataset_dir, 'images')
        annots_dir = os.path.join(dataset_dir, 'annots')

        # Sort the listing: os.listdir returns entries in arbitrary order, and
        # a stable order keeps the dataset reproducible between runs.
        for filename in sorted(os.listdir(images_dir)):
            # Extract image id (works for any extension length, e.g. ".jpeg")
            image_id, _ = os.path.splitext(filename)

            # Stray entries (".DS_Store", "Thumbs.db", sub-folders, ...) have no
            # numeric id; skip them with a warning instead of crashing.
            try:
                numeric_id = int(image_id)
            except ValueError:
                warnings.warn(
                    f"Skipping {filename!r} in {images_dir!r}: "
                    "file name is not a numeric image id"
                )
                continue

            # Split the train and val
            in_train_split = numeric_id < val_start_id
            if bool(is_train) != in_train_split:
                continue

            # Add to dataset
            img_path = os.path.join(images_dir, filename)
            ann_path = os.path.join(annots_dir, image_id + '.xml')

            self.add_image('dataset', image_id=image_id, path=img_path, annotation=ann_path)

    def extract_boxes(self, filename):
        """Parse one XML annotation file.

        Args:
            filename: Path of the XML annotation file.

        Returns:
            A tuple ``(boxes, width, height)`` where ``boxes`` is a list of
            ``[xmin, ymin, xmax, ymax, name]`` entries (one per ``object``
            element) and ``width`` / ``height`` come from the ``size`` element.
        """
        tree = ElementTree.parse(filename)
        root = tree.getroot()

        # Extract each bounding box
        boxes = list()
        for box in root.findall('.//object'):
            name = box.find('name').text
            xmin = int(box.find('./bndbox/xmin').text)
            ymin = int(box.find('./bndbox/ymin').text)
            xmax = int(box.find('./bndbox/xmax').text)
            ymax = int(box.find('./bndbox/ymax').text)
            coors = [xmin, ymin, xmax, ymax, name]
            boxes.append(coors)

        # Extract image dimensions
        width = int(root.find('.//size/width').text)
        height = int(root.find('.//size/height').text)

        return boxes, width, height

    def load_mask(self, image_id):
        """Build rectangular instance masks for one image.

        Args:
            image_id: Index into ``self.image_info``.

        Returns:
            A tuple ``(masks, class_ids)``: ``masks`` is a ``uint8`` array of
            shape ``[height, width, number_of_boxes]`` with 1 inside each box,
            and ``class_ids`` is an ``int32`` array holding, for each box, the
            position of its class name in ``self.class_names``.

        Raises:
            ValueError: If an annotation names a class that is not in
                ``self.class_names``.
        """
        # Get image info
        info = self.image_info[image_id]
        path = info['annotation']

        # Load XML
        boxes, w, h = self.extract_boxes(path)
        masks = np.zeros([h, w, len(boxes)], dtype='uint8')

        # Create masks
        class_ids = list()
        for i, box in enumerate(boxes):
            xmin, ymin, xmax, ymax, class_name = box

            # Clamp to 0: a negative coordinate would otherwise be read by numpy
            # as an offset from the far edge and paint the wrong region.
            top, bottom = max(ymin, 0), max(ymax, 0)
            left, right = max(xmin, 0), max(xmax, 0)

            # Set mask to 1 (binary) regardless of class
            masks[top:bottom, left:right, i] = 1

            # Append the class ID based on class name.
            # NOTE(review): self.class_names is not defined in this class; it is
            # presumably populated by the base class's prepare() -- confirm.
            try:
                class_ids.append(self.class_names.index(class_name))
            except ValueError:
                raise ValueError(
                    f"Unknown class {class_name!r} in annotation {path!r}"
                ) from None

        return masks, np.asarray(class_ids, dtype='int32')

    def image_reference(self, image_id):
        """Return the image file path stored for ``image_id``."""
        info = self.image_info[image_id]

        return info['path']

# Directories, resolved against the current working directory
ROOT_DIR = os.path.abspath('./')
LOGS_DIR = os.path.join(ROOT_DIR, 'logs')

# Both splits are read from the same folder; is_train selects which image ids
# each dataset object registers.
train_dataset, valid_dataset = FoodDatasets(), FoodDatasets()
train_dataset.load_dataset(dataset_dir='datasets', is_train=True)
valid_dataset.load_dataset(dataset_dir='datasets', is_train=False)

# prepare() is called once all images of a split have been registered
train_dataset.prepare()
valid_dataset.prepare()

class FoodConfig(Config):
    """Configuration for the food dataset; overrides a few ``Config`` attributes."""
    NAME = 'foods_cfg'
    # The three food classes registered by FoodDatasets (apple, banana, orange)
    # plus one extra slot -- presumably the background class; confirm.
    NUM_CLASSES = 1 + 3
    # NOTE(review): the meaning of the two values below is defined by the
    # base Config class, which is not shown here.
    STEPS_PER_EPOCH = 100
    IMAGES_PER_GPU = 2

# Prepare Config: instantiate it and print its settings
config = FoodConfig()
config.display()

In [None]:
################# Exploring Dataset #################
# Besides exploring, displaying a random sample helps spot images whose
# annotations are incorrect.

# random.randint includes BOTH endpoints, so randint(0, len(ids)) could return
# len(ids) and index one past the last image. randrange excludes the upper bound.
num = random.randrange(len(train_dataset.image_ids))
print(f'Displaying image id of {num}')
image_id = num
image = train_dataset.load_image(image_id)
mask, class_ids = train_dataset.load_mask(image_id)
# Derive the boxes from the masks and draw everything on the image
bbox = extract_bboxes(mask)
display_instances(image, bbox, mask, class_ids, train_dataset.class_names)

In [None]:
#################################################################
# Models
#################################################################
# Build the model in training mode. LOGS_DIR (defined with ROOT_DIR above) is
# used instead of repeating the 'logs' literal so the log location is set in
# one place.
model = modellib.MaskRCNN(
    mode='training',
    model_dir=LOGS_DIR,
    config=config
)

# Load the COCO weights file by layer name, skipping the listed head layers
# (presumably because their shapes depend on NUM_CLASSES -- confirm).
model.load_weights(
    'mask_rcnn_coco.h5',
    by_name=True,
    exclude=["mrcnn_class_logits", "mrcnn_bbox_fc", "mrcnn_bbox", "mrcnn_mask"]
)

# Train with layers='heads' for 20 epochs at the config's learning rate
model.train(
    train_dataset=train_dataset,
    val_dataset=valid_dataset,
    learning_rate=config.LEARNING_RATE,
    epochs=20,
    layers='heads'
)