In [None]:
# Uncomment when running on Google Colab (mounts Google Drive).
# from google.colab import drive
# drive.mount('/content/drive')

import tensorflow as tf2

# Record which TensorFlow version this kernel is running.
tf_version = tf2.__version__
print(tf_version)

In [None]:
# Uncomment when running on Google Colab (switches into the project's src directory on Drive).
# import os
# PATH_TO_FILE = '/content/drive/My Drive/Development/fashion_instance_segmentation/src'
# os.chdir(PATH_TO_FILE)

In [None]:
import os
import sys
import itertools
import math
import logging
import json
import re
import random
from collections import OrderedDict
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.lines as lines
from matplotlib.patches import Polygon

# Root directory of the project. It is resolved relative to the current working
# directory, so the kernel must be started one level below the project root.
ROOT_DIR = os.path.abspath("../")

# Location of the vendored Mask R-CNN library (provides the `mrcnn` package).
MRNN_DIR = os.path.join(ROOT_DIR, 'externals/mask_rcnn')

# Make the project root (to find the local version of the library) and the
# Mask R-CNN checkout importable. Skip directories that are already present so
# that re-running this cell does not keep growing sys.path.
for _import_dir in (ROOT_DIR, MRNN_DIR):
    if _import_dir not in sys.path:
        sys.path.append(_import_dir)

from mrcnn import utils
from mrcnn import visualize
from mrcnn.visualize import display_images
import mrcnn.model as modellib
from mrcnn.model import log

import rle_helper
import dress
#from src import dress

%matplotlib inline 

In [None]:
def get_ax(rows=1, cols=1, size=8):
    """Create a new figure and return its Axes for the notebook's visualizations.

    All plotting cells call this helper so that figure size is controlled in
    one place: every subplot cell is ``size`` x ``size`` inches, and the whole
    figure is scaled by the number of rows and columns.

    Args:
        rows: number of subplot rows.
        cols: number of subplot columns.
        size: edge length, in inches, of one subplot.

    Returns:
        The Axes object(s) returned by ``plt.subplots`` (a single Axes for a
        1x1 grid under matplotlib's default ``squeeze=True``, otherwise an
        array of Axes).
    """
    figure_size = (cols * size, rows * size)
    _figure, axes = plt.subplots(rows, cols, figsize=figure_size)
    return axes

In [None]:
# Two training configurations to compare (variables are named after the
# backbone each one is meant to use).
configResnet101 = dress.DressConfig()
configResnet50 = dress.DressResnet50Config()

# Root of the dress dataset splits.
DRESS_DIR = os.path.join(ROOT_DIR, "data")

In [None]:
# Show the settings of the ResNet-101 training configuration.
configResnet101.display()

In [None]:
# Show the settings of the ResNet-50 training configuration.
configResnet50.display()

In [None]:
def _load_dress_split(subset):
    """Load, prepare and return one split ("train" or "val") of the dress dataset from DRESS_DIR."""
    split = dress.DressDataset()
    split.load_dress(DRESS_DIR, subset)
    split.prepare()
    return split

# Training and validation datasets
dataset_train = _load_dress_split("train")
dataset_val = _load_dress_split("val")

# Label each count with its split; the two lines were previously indistinguishable.
print("Train Image Count: {}".format(len(dataset_train.image_ids)))
print("Val Image Count: {}".format(len(dataset_val.image_ids)))
print("Class Count: {}".format(dataset_train.num_classes))
for i, info in enumerate(dataset_train.class_info):
    print("{:3}. {:50}".format(i, info['name']))

In [None]:
# Preview a few random training images together with their ground-truth masks.
# Sample without replacement so no image is shown twice, and never request more
# images than the training split contains.
N_PREVIEW_IMAGES = 4
n_preview = min(N_PREVIEW_IMAGES, len(dataset_train.image_ids))
image_ids = np.random.choice(dataset_train.image_ids, size=n_preview, replace=False)
for image_id in image_ids:
    image = dataset_train.load_image(image_id)
    mask, class_ids = dataset_train.load_mask(image_id)
    visualize.display_top_masks(image, mask, class_ids, dataset_train.class_names)

In [None]:
# Directory to save logs and trained model; passed as model_dir to every
# MaskRCNN instance created below.
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

In [None]:
# Create model in training mode
# Build one training-mode Mask R-CNN per configuration; both share MODEL_DIR
# as their model_dir.
modelResnet101 = modellib.MaskRCNN(
    mode="training",
    config=configResnet101,
    model_dir=MODEL_DIR,
)
modelResnet50 = modellib.MaskRCNN(
    mode="training",
    config=configResnet50,
    model_dir=MODEL_DIR,
)

In [None]:
# Local path to the COCO-pretrained weights; download them on first use.
COCO_MODEL_PATH = os.path.join(MRNN_DIR, "mask_rcnn_coco.h5")
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

# Output-head layers skipped when loading the COCO checkpoint. In the Mask R-CNN
# samples these are excluded because their shapes depend on the number of
# classes, so they must be trained from scratch for the dress classes.
# A tuple is used so the shared list cannot be mutated between calls.
COCO_EXCLUDED_LAYERS = ("mrcnn_class_logits", "mrcnn_bbox_fc",
                        "mrcnn_bbox", "mrcnn_mask")

# Same initialization for both backbones; a single loop replaces the duplicated calls.
for _model in (modelResnet101, modelResnet50):
    _model.load_weights(COCO_MODEL_PATH, by_name=True,
                        exclude=COCO_EXCLUDED_LAYERS)

In [None]:
# Stage 1 (ResNet-50): fine-tune only the head branches.
# layers='heads' freezes every layer except the head layers; a regular
# expression over layer names can be passed instead to choose what to train.
resnet50_head_epochs = 10
modelResnet50.train(
    dataset_train,
    dataset_val,
    learning_rate=configResnet50.LEARNING_RATE,
    epochs=resnet50_head_epochs,
    layers='heads',
)

In [None]:
# Rebuild the ResNet-50 model in inference mode and load trained weights.
inference_config = dress.InferenceResnet50Config()
modelResnet50 = modellib.MaskRCNN(
    mode="inference",
    config=inference_config,
    model_dir=MODEL_DIR,
)

# find_last() locates the most recently trained weights under MODEL_DIR.
# To pin a specific checkpoint instead, use e.g.
#     model_path = os.path.join(ROOT_DIR, ".h5 file name here")
# NOTE(review): both backbones train into the same MODEL_DIR; if their configs
# share a run NAME, re-running this cell after the ResNet-101 training cell may
# pick up the ResNet-101 checkpoint — confirm.
model_path = modelResnet50.find_last()

print("Loading weights from ", model_path)
modelResnet50.load_weights(model_path, by_name=True)

In [None]:
# Pick a random validation image and show its ground-truth annotations.
image_id = random.choice(dataset_val.image_ids)
original_image, image_meta, gt_class_id, gt_bbox, gt_mask = modellib.load_image_gt(
    dataset_val, inference_config, image_id, use_mini_mask=False)

for _label, _value in (("original_image", original_image),
                       ("image_meta", image_meta),
                       ("gt_class_id", gt_class_id),
                       ("gt_bbox", gt_bbox),
                       ("gt_mask", gt_mask)):
    log(_label, _value)

# The image comes from the validation split, so label it with that split's
# class names (consistent with the prediction cells below).
visualize.display_instances(original_image, gt_bbox, gt_mask, gt_class_id,
                            dataset_val.class_names, figsize=(8, 8))

In [None]:
# Run the fine-tuned ResNet-50 model on the selected validation image and draw
# its predicted instances with their scores.
results = modelResnet50.detect([original_image], verbose=1)
r = results[0]
pred_boxes, pred_masks = r['rois'], r['masks']
pred_class_ids, pred_scores = r['class_ids'], r['scores']
visualize.display_instances(original_image, pred_boxes, pred_masks, pred_class_ids,
                            dataset_val.class_names, pred_scores, ax=get_ax())

In [None]:
# Stage 1 (ResNet-101): fine-tune only the head layers, mirroring the
# ResNet-50 run.
resnet101_head_epochs = 10
modelResnet101.train(
    dataset_train,
    dataset_val,
    learning_rate=configResnet101.LEARNING_RATE,
    epochs=resnet101_head_epochs,
    layers='heads',
)

In [None]:

# Rebuild the ResNet-101 model in inference mode and load trained weights.
inference_config = dress.InferenceResnet101Config()
modelResnet101 = modellib.MaskRCNN(
    mode="inference",
    config=inference_config,
    model_dir=MODEL_DIR,
)

# find_last() locates the most recently trained weights under MODEL_DIR.
# To pin a specific checkpoint instead, use e.g.
#     model_path = os.path.join(ROOT_DIR, ".h5 file name here")
model_path = modelResnet101.find_last()

print("Loading weights from ", model_path)
modelResnet101.load_weights(model_path, by_name=True)

In [None]:
# Run the fine-tuned ResNet-101 model on the same validation image so its
# predictions can be compared with the ResNet-50 output above.
results = modelResnet101.detect([original_image], verbose=1)
r = results[0]
pred_boxes, pred_masks = r['rois'], r['masks']
pred_class_ids, pred_scores = r['class_ids'], r['scores']
visualize.display_instances(original_image, pred_boxes, pred_masks, pred_class_ids,
                            dataset_val.class_names, pred_scores, ax=get_ax())