# MODEL SETUP

In [None]:
import os
import sys
import json
import random
import math
import re
import time
import numpy as np
import cv2
import matplotlib
import matplotlib.pyplot as plt
import skimage.draw
import tensorflow as tf

# Absolute path of the parent of the current working directory; used as the
# project root throughout the notebook
ROOT_DIR = os.path.abspath(os.pardir)

# Make the local copy of the Mask RCNN library importable
sys.path.append(ROOT_DIR)

from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.visualize import display_images
from mrcnn.model import log

#Import adaptions
from fashion_config import FashionConfig
from fashion_dataset import FashionDataset

%matplotlib inline

# Directory to save logs and trained model
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

# Hide some TensorFlow warning messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

In [None]:
""" Load config """ 
# Instantiate the project configuration (FashionConfig is imported from fashion_config)
config = FashionConfig()
# Uncomment to call the configuration's display() method
#config.display()

In [None]:
""" Prepare dataset """ 
def _load_split(subset):
    """Create a FashionDataset, load the given subset into it and prepare it."""
    dataset = FashionDataset()
    dataset.load_fashion(ROOT_DIR + '/datasets/big_deepfashion2', subset)
    dataset.prepare()
    return dataset


# Training split first, then validation split
dataset_train = _load_split("train")
dataset_val = _load_split("val")


In [None]:
class InferenceConfig(FashionConfig):
    """FashionConfig variant used for inference in this notebook.

    Overrides both GPU_COUNT and IMAGES_PER_GPU with 1.
    """
    IMAGES_PER_GPU = 1
    GPU_COUNT = 1


inference_config = InferenceConfig()

# Build the Mask RCNN model in inference mode, with MODEL_DIR as its model_dir
model = modellib.MaskRCNN(
    mode="inference",
    config=inference_config,
    model_dir=MODEL_DIR,
)


In [None]:
# Load trained weights.
# The checkpoint path is derived from MODEL_DIR (ROOT_DIR/logs) so that the
# logs directory is defined in one place and the path does not depend on the
# working directory at the time this cell runs.
model_path = os.path.join(MODEL_DIR, "final_logs", "logs_mask_rcnn_fashion_0045.h5")
print("Loading weights from ", model_path)
# by_name=True is passed through to the model's load_weights
model.load_weights(model_path, by_name=True)

In [None]:
def get_ax(rows=1, cols=1, size=8):
    """Create a grid of Matplotlib axes for the notebook's visualizations.

    Acts as the single place where figure sizes are controlled: the figure
    is `size * cols` wide and `size * rows` tall.

    rows: number of subplot rows
    cols: number of subplot columns
    size: edge length allotted to each subplot cell

    Returns whatever plt.subplots returns as its axes object for the
    requested grid; the figure handle is discarded.
    """
    width = size * cols
    height = size * rows
    _figure, axes = plt.subplots(rows, cols, figsize=(width, height))
    return axes

# EVALUATIONS


## Visual comparison between ground truth and model

In [None]:
# Pick a validation image at random
image_id = random.choice(dataset_val.image_ids)

# To inspect a specific image instead, uncomment one of these overrides
#image_id = 17803
#image_id = 7818
print("Image #{}".format(image_id))

# Load the image together with its ground truth (mini masks disabled)
image, image_meta, gt_class_id, gt_bbox, gt_mask, gt_landmark = modellib.load_image_gt(
    dataset_val, inference_config, image_id, use_mini_mask=False)

# Log each loaded array under its own name, in the order it was returned
for _label, _value in (
        ("image", image),
        ("image_meta", image_meta),
        ("gt_class_id", gt_class_id),
        ("gt_bbox", gt_bbox),
        ("gt_mask", gt_mask),
        ("gt_landmark", gt_landmark),
):
    log(_label, _value)

# Ground truth instances
visualize.display_instances(
    image, gt_bbox, gt_mask, gt_landmark, gt_class_id,
    dataset_val.class_names, figsize=(8, 8))

# Run detection on the same image; `r` holds the result for that single image
results = model.detect([image], verbose=2)
r = results[0]

# Predicted instances, drawn on a fresh axes from get_ax()
visualize.display_instances(
    image, r['rois'], r['masks'], r['landmarks'], r['class_ids'],
    dataset_val.class_names, r['scores'], ax=get_ax())

## Precision-recall curve
* The X-axis is recall: recall is high when the number of false negatives is low.
* The Y-axis is precision: precision is high when the number of false positives is low.


In [None]:
# Draw the precision-recall curve for the current image: the ground truth
# (gt_bbox, gt_class_id, gt_mask) is compared with the detections stored in `r`.
# `overlaps` is not used here; it is consumed by the plot_overlaps cell below.
AP, precisions, recalls, overlaps = utils.compute_ap(gt_bbox, gt_class_id, gt_mask,
                                          r['rois'], r['class_ids'], r['scores'], r['masks'])
visualize.plot_precision_recall(AP, precisions, recalls)

## Heatmap of predictions and ground truth (Still in beta)

In [None]:
# Grid of ground truth objects and their predictions.
# `overlaps` is the fourth value returned by utils.compute_ap in the
# precision-recall cell, so that cell must be run first.
visualize.plot_overlaps(gt_class_id, r['class_ids'], r['scores'],
                        overlaps, dataset_val.class_names)

## Compare the generated masks
Useful to see what the network sees when there is overlap

In [None]:
# Ground-truth masks and class ids for the current image
visualize.display_top_masks(image, gt_mask, gt_class_id, 
                            dataset_val.class_names)
# Predicted masks and class ids from the detection result `r` for the same image
visualize.display_top_masks(image, r['masks'], r['class_ids'], 
                            dataset_val.class_names)

## Visualize Activations

In some cases it helps to look at the output from different layers and visualize them to catch issues and odd patterns.

In [None]:
# Print the name of every layer of the underlying Keras model, one per line.
# These are the names accepted by model.keras_model.get_layer(...) when
# choosing a layer to visualize.
for layer in model.keras_model.layers:
    print(layer.name)

In [None]:
# Get activations of a few sample layers

# Select layer to visualize. Most play nicely with the backbone feature mapping, but not all.
layer_name = 'res3a_out'

# Run the graph on the current image and fetch the listed tensors by name
activations = model.run_graph([image], [
    (layer_name,           model.keras_model.get_layer(layer_name).output),
    ("input_image",        tf.identity(model.keras_model.get_layer("input_image").output)),
    ("rpn_bbox",           model.keras_model.get_layer("rpn_bbox").output),
    ("roi",                model.keras_model.get_layer("ROI").output),
])

# Show the image as fed to the network, passed back through unmold_image.
# inference_config is used because it is the configuration the model was
# built with.
_ = plt.imshow(modellib.unmold_image(activations["input_image"][0], inference_config))

# Show the first 16 channels of the selected layer's output for the first
# batch entry; the channel axis is moved to the front so each channel is
# passed to display_images as one image
display_images(np.transpose(activations[layer_name][0,:,:,:16], [2, 0, 1]))

## Calculate mask mean Average Precision (mAP)

In [None]:
# Compute VOC-Style mAP @ IoU=0.5
# Runs on `sample_size` randomly drawn validation images. Increase for better accuracy.
import time
sample_size = 1
report_every = 10  # print progress and timing after this many processed images

# NOTE(review): np.random.choice samples with replacement by default, so the
# same image can be evaluated more than once -- confirm this is intended.
image_ids = np.random.choice(dataset_val.image_ids, sample_size)
APs = []

print("Calculating mAP on {} items".format(sample_size))
start = time.time()
for item, image_id in enumerate(image_ids):
    # Report only once a full batch of images has been processed, so that the
    # printed per-item time is an average over images that were actually timed
    if item > 0 and item % report_every == 0:
        end = time.time()
        print("{} percent done".format(100*item/sample_size))
        print("{} seconds per item".format((end-start)/report_every))
        start = time.time()

    # Load image and ground truth data (gt_landmark is returned but not needed for AP)
    image_org, image_meta, gt_class_id, gt_bbox, gt_mask, gt_landmark =\
        modellib.load_image_gt(dataset_val, inference_config,
                               image_id, use_mini_mask=False)

    # Run object detection on the loaded image
    results = model.detect([image_org], verbose=0)
    r = results[0]

    # Compute AP of the detections against the ground truth of this image
    AP, precisions, recalls, overlaps =\
        utils.compute_ap(gt_bbox, gt_class_id, gt_mask,
                         r["rois"], r["class_ids"], r["scores"], r['masks'])
    APs.append(AP)

# Mean of the per-image AP values
print("mAP: ", np.mean(APs))

## Compare Class Accuracy
### Setup data tables

In [None]:
import pandas as pd
import seaborn as sns
import ast

# Load the predictions table and drop its 'Unnamed: 0' column
# (presumably the index written out when the CSV was saved -- confirm)
data = pd.read_csv('val_class_predictions.csv')
preds = data.drop(columns=['Unnamed: 0'])

# Same for the ground-truth table
data = pd.read_csv('val_gt.csv')
truth = data.drop(columns=['Unnamed: 0'])

# Convert the string representation of lists into actual lists.
# Applied to the 'gt' column directly instead of row by row: same result,
# but faster and well-defined when the table is empty.
truth['as_list'] = truth['gt'].apply(ast.literal_eval)

# Add the number of clothes in each image for convenience
truth['gt_item_count'] = truth['as_list'].apply(len)

# Join predictions and ground truth on their row index
table = preds.join(truth)

# Check it
table.head()

### Single item detection
Some statistics for images that contain only one article of clothing to detect, for simplicity's sake.

In [None]:
# Get single-item images.
# An explicit copy is taken so that the column assignment below writes to an
# independent frame instead of a filtered slice of `table` (which would
# trigger pandas' SettingWithCopyWarning).
single = table[table['gt_item_count'] == 1].copy()

# Pick the only item from each list, otherwise we can't use groupby
# (lists are not hashable). This overwrites the original 'gt' column.
single['gt'] = single['as_list'].apply(lambda items: items[0])

In [None]:
# For each type of cloth, check what the network believes it to be on average.
# numeric_only=True keeps the list-valued 'as_list' column out of the mean;
# without it, recent pandas versions raise a TypeError on that column.
average_confidence = (
    single.groupby('gt')
    .mean(numeric_only=True)
    .drop(columns=['gt_item_count'])
)

# Use names instead of numbers for predictions
average_confidence.columns = dataset_val.class_names

# The BG column is useless
average_confidence = average_confidence.drop(columns="BG")

# Use names instead of numbers for gt_classes.
# NOTE(review): this assumes every non-BG class occurs at least once as a
# single-item ground truth and that the sorted group keys line up with the
# column order -- otherwise the lengths/labels will not match. Confirm.
average_confidence.index = average_confidence.columns.values

In [None]:
# Render the average-confidence table as an annotated heatmap
sns.set()
fig, ax = plt.subplots(figsize=(15, 10))
sns.heatmap(average_confidence, ax=ax, annot=True)
# Set the y-limits explicitly, from the number of rows down to 0
ax.set_ylim(average_confidence.shape[0], 0)
plt.show()

In [None]:
print("Number of items in each class")
fig, ax = plt.subplots(figsize = (15,10))

# Number of single-item images per ground-truth class
counts = single.groupby('gt').count()[['gt_item_count']]
# Label the rows with the class names used for the heatmap columns
# (relies on average_confidence from the earlier cell)
counts.index = average_confidence.columns.values

# Draw onto the axes created above rather than the implicit current axes
sns.barplot(x=counts.index, y = counts.gt_item_count, ax=ax)
plt.xticks(rotation=45)
plt.show()