# Newspaper Segmentation

The network backbone is ResNet-101.

In [None]:
import json
import math
import os
import random
import re
import sys
import time

import cv2
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import skimage
import skimage.draw
import skimage.io
%matplotlib inline

# Root directory of the project: the directory the notebook runs from.
ROOT_DIR = os.path.abspath("")
# Make the local `mrcnn` package importable. Guard the append so that
# re-running this cell does not keep adding the same sys.path entry.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

from mrcnn.config import Config
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log

# Local copy of the COCO-pretrained weights file.
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
# Directory where training logs and the trained model are saved.
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

# Download the COCO weights only when no local copy exists yet.
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

## Configurations

In [None]:
class ShapesConfig(Config):
    """Training configuration for the newspaper segmentation model."""

    NAME = "newspapers"

    # A single GPU holding 8 images, i.e. an effective batch size of 8.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 8

    # Background plus the three newspaper region classes registered by the
    # dataset: article, non-article and title.
    NUM_CLASSES = 1 + 3

    # Image resizing: 'square' mode with both the min and max side at 128 px.
    # Other modes noted by the original author: 'pad64', 'crop'.
    IMAGE_RESIZE_MODE = 'square'
    IMAGE_MAX_DIM = 128
    IMAGE_MIN_DIM = 128

    # Anchor side lengths, in pixels, used by the region proposal network.
    RPN_ANCHOR_SCALES = (8, 16, 32, 64, 128)
    TRAIN_ROIS_PER_IMAGE = 32

    # Short epochs with few validation steps keep each training round quick.
    STEPS_PER_EPOCH = 100
    VALIDATION_STEPS = 5


config = ShapesConfig()
config.display()

## Notebook Preferences

In [None]:
def get_ax(rows=1, cols=1, size=12):
    """Create a figure holding a rows x cols grid of subplots.

    Each grid cell is ``size`` inches on a side, so the figure measures
    ``cols * size`` by ``rows * size`` inches. The Axes object(s) returned
    by ``plt.subplots`` are handed back; the Figure is discarded.
    """
    figure_size = (cols * size, rows * size)
    _figure, axes = plt.subplots(rows, cols, figsize=figure_size)
    return axes

## Dataset

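The `newspaperDataset` class below extends `utils.Dataset`:

* `load_data()` registers the images of one subset together with their region annotations.
* `load_mask()` builds the per-region instance masks and class ids.
* `image_reference()` returns the file path of an image.

`load_image()` is not overridden; it is inherited from the base class.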

In [None]:
class newspaperDataset(utils.Dataset):
    """Mask R-CNN dataset of newspaper page regions.

    Every image carries a list of annotated regions (polygons or axis-aligned
    rectangles), each labelled "article", "non-article" or "title".
    """

    def load_data(self, dataset_dir, subset):
        """Register one subset of the newspaper dataset.

        Reads ``<dataset_dir>/<subset>/annotations.json`` and adds every image
        that has at least one annotated region.

        dataset_dir: Root directory of the dataset.
        subset: Subset to load, "train" or "val".

        Returns the file names of the registered images, in registration
        order, so the list lines up with the order images were added
        (callers index it by image id).
        """
        # NOTE(review): the source ids start at 0 here. Masks are labelled by
        # class *name* in load_mask, so these numeric ids are not relied on
        # within this class.
        self.add_class("newspaper", 0, "article")
        self.add_class("newspaper", 1, "non-article")
        self.add_class("newspaper", 2, "title")

        assert subset in ["train", "val"]
        dataset_dir = os.path.join(dataset_dir, subset)

        # Use a context manager so the annotation file is closed promptly.
        with open(os.path.join(dataset_dir, "annotations.json"), encoding="utf-8") as f:
            annotations = json.load(f)

        filenames = []
        for a in annotations.values():
            regions = a['regions']
            # Regions may be stored as a dict keyed by index or as a plain
            # list; accept both.
            if isinstance(regions, dict):
                regions = list(regions.values())
            # Images without any annotated region are skipped.
            if not regions:
                continue

            shapes = [r['shape_attributes'] for r in regions]
            classes = [r['region_attributes']['label'] for r in regions]

            image_path = os.path.join(dataset_dir, a['filename'])
            # The image is read only to learn its size for mask allocation.
            height, width = skimage.io.imread(image_path).shape[:2]

            self.add_image(
                "newspaper",
                image_id=a['filename'],  # use file name as a unique image id
                path=image_path,
                width=width, height=height,
                shapes=shapes,
                classes=classes)
            # Record the name only for images actually added, so that the
            # returned list stays aligned with the registered images.
            filenames.append(a['filename'])

        return filenames

    def load_mask(self, image_id):
        """Build instance masks for one image.

        image_id: Internal dataset image id.

        Returns:
            mask: uint8 array [height, width, instance_count] with 1 inside
                each region's channel.
            class_ids: int32 array [instance_count] holding each region's
                index into ``self.class_names``.
        """
        info = self.image_info[image_id]
        if info["source"] != "newspaper":
            return super(newspaperDataset, self).load_mask(image_id)

        height, width = info["height"], info["width"]
        shapes = info["shapes"]
        mask = np.zeros([height, width, len(shapes)], dtype=np.uint8)

        for i, p in enumerate(shapes):
            if p['name'] == 'polygon':
                # Clip to the image so vertices outside the frame cannot
                # index out of bounds or wrap around.
                rr, cc = skimage.draw.polygon(
                    p['all_points_y'], p['all_points_x'], shape=(height, width))
                mask[rr, cc, i] = 1
            elif p['name'] == 'rectangle':
                start = (int(p['ymin']), int(p['xmin']))
                # The annotation gives the far corner, not the size, so pass
                # it as skimage's (inclusive) `end` rather than `extent`.
                end = (int(p['ymax']), int(p['xmax']))
                rr, cc = skimage.draw.rectangle(
                    start=start, end=end, shape=(height, width))
                mask[rr, cc, i] = 1
            # NOTE(review): any other shape type leaves channel i empty.

        class_ids = np.array(
            [self.class_names.index(c) for c in info['classes']], dtype=np.int32)

        return mask, class_ids

    def image_reference(self, image_id):
        """Return the file path of a newspaper image; defer to the base class otherwise."""
        info = self.image_info[image_id]
        if info["source"] == "newspaper":
            return info["path"]
        return super(newspaperDataset, self).image_reference(image_id)

In [None]:
# Dataset root; load_data() appends the subset name ("train" or "val").
data_DIR = os.path.join(ROOT_DIR, "datasets/newspaper/")

# Register the training subset, keep the returned file names, then call
# prepare() once all images/classes have been added.
dataset = newspaperDataset()
data_names = dataset.load_data(data_DIR, "train")
dataset.prepare()

In [None]:
# Display the ids of the registered images as this cell's output.
dataset.image_ids

### Inspect sample annotations

In [None]:
" ============= Display some sample dataset informations ============="

# For the first registered image: show its masks and boxes at native size,
# then again after resizing with the configured settings.
for image_id in dataset.image_ids[:1]:
    mask, class_ids = dataset.load_mask(image_id)
    bbox = utils.extract_bboxes(mask)

    # Read the exact file registered for this id instead of rebuilding the
    # path from a separately returned name list, which may not line up.
    image = skimage.io.imread(dataset.image_info[image_id]["path"])

    visualize.display_top_masks(image, mask, class_ids, dataset.class_names)
    visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names, figsize=(5, 5))

    # Use the configured resize mode (not a hardcoded 'square') and forward
    # the crop window so masks stay aligned in every mode.
    image, window, scale, padding, crop = utils.resize_image(
        image,
        min_dim=config.IMAGE_MIN_DIM,
        max_dim=config.IMAGE_MAX_DIM,
        mode=config.IMAGE_RESIZE_MODE)
    mask = utils.resize_mask(mask, scale, padding, crop)

    log("image", image)
    log("mask", mask)

    bbox = utils.extract_bboxes(mask)
    visualize.display_instances(image, bbox, mask, class_ids, dataset.class_names, figsize=(5, 5))