# Instance Segmentation with Mask Region Based Convolutional Neural Network.

The Mask R-CNN has been trained on the COCO dataset, and a number of its classes also occur in the DAVIS 2016 training sequences:

Bear,
Bus,
Car,
Dog,
Elephant,
Horse,
Motorcycle,
Kite and
Train.



The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each image, and should be in 0-1 range. Different images can have different sizes.


During inference, the model requires only the input tensors, and returns the post-processed predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as follows:

- boxes (Tensor[N, 4]): the predicted boxes in [x0, y0, x1, y1] format, with values between 0 and H and 0 and W

- labels (Tensor[N]): the predicted labels for each image

-  scores (Tensor[N]): the scores of each prediction

-  masks (Tensor[N, H, W]): the predicted masks for each instance, in 0-1 range. In order to obtain the final segmentation masks, the soft masks can be thresholded, generally with a value of 0.5 (mask >= 0.5)

In [1]:
# import necessary libraries
from PIL import Image
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import torchvision
import torch.utils.data as data
import random
import time
import os
from typing import Any, Callable, cast, Dict, List, Optional, Tuple
import numpy as np
import cv2 as cv


from skimage.morphology import disk



 # These are the classes that are available in the COCO-Dataset
# Human-readable class names, looked up by the integer label the model predicts
# (see `get_prediction`, which indexes this list with each predicted label).
# NOTE(review): the 'N/A' entries presumably pad category ids that are unused
# in COCO so that list position equals label id - confirm against the
# torchvision detection documentation before editing the list.
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]





## Download the model

In [None]:
# Build Mask R-CNN (ResNet-50 + FPN backbone) from torchvision.models.detection.
# pretrained=True requests the pretrained weights for the model.
# NOTE(review): newer torchvision releases reportedly replace `pretrained=`
# with a `weights=` argument - confirm against the installed version.
# The model is moved to the GPU when one is available (otherwise it stays on
# the CPU) and switched to eval mode because this notebook only runs inference.
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()

Downloading: "https://download.pytorch.org/models/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth" to /Users/Papa/.cache/torch/hub/checkpoints/maskrcnn_resnet50_fpn_coco-bf2d0c1e.pth
1.3%

In [None]:
# Palette used to fill mask pixels: eleven colours, each a list of three
# 0-255 integer channel values.
# NOTE(review): `random_colour_masks` further down defines its own identical
# local copy of this palette rather than reading this module-level list.
colours = [[0, 255, 0],
           [0, 0, 255],
           [255, 0, 0],
           [0, 255, 255],
           [255, 255, 0],
           [255, 0, 255],
           [80, 70, 180],
           [250, 80, 190],
           [245, 145, 50],
           [70, 150, 250],
           [50, 190, 190]]

In [None]:
""" Utilities for computing, reading and saving DAVIS benchmark evaluation."""

def get_contours(mask, bound_th=0.5):
    """Return the one-pixel-wide contours of ``mask`` as a 0/1 array.

    :param mask: A ground truth or predicted 2D (grayscale) image mask (H x W).
        It is handed straight to OpenCV, so it must be a dtype that
        ``cv.threshold`` / ``cv.findContours`` accept (the callers in this
        file pass ``np.uint8`` masks).
    :param bound_th: The pixel threshold value for a one or zero assignment.
        It is truncated with ``int()`` before use, so any value below 1
        (including the default 0.5) yields a threshold of 0, i.e. every
        non-zero pixel counts as foreground.
    :returns: An array with the same shape and dtype as ``mask`` holding 1 on
        every contour pixel and 0 elsewhere.
    :raises AssertionError: If ``mask`` is not two dimensional.
    """

    # Tests the mask is only 2 dimensional H X W. Raise AssertionError otherwise
    assert len(mask.shape) == 2, f"Mask should be 2D (HxW) but got {mask.shape}"

    # Convert image to true binary (values 0 or 1) based on the threshold.
    _, thresh = cv.threshold(mask, int(bound_th * 1), 1, cv.THRESH_BINARY)

    # CHAIN_APPROX_NONE keeps every boundary point. CHAIN_APPROX_SIMPLE (used
    # previously) collapses straight runs to their end points, which left
    # gaps in the rasterised contour. Indexing with [-2] selects the contour
    # list for both the 2-tuple and the 3-tuple return conventions of
    # cv.findContours.
    contours = cv.findContours(thresh, cv.RETR_TREE, cv.CHAIN_APPROX_NONE)[-2]

    # Create a blank canvas the same shape as the input mask and mark the
    # contour pixels on it.
    cont_matte = np.zeros_like(mask)
    for contour in contours:
        # Each contour is an array of (x, y) points; rows are y, columns are x.
        points = contour.reshape(-1, 2)
        cont_matte[points[:, 1], points[:, 0]] = 1

    return cont_matte
def eval_boundary(foreground_mask, gt_mask, bound_th=0.008):
    """
    Score how well the boundary of a predicted mask matches the ground truth.

    The contours of both masks are extracted, each is dilated with a disk
    shaped structuring element, and precision / recall are measured as the
    fraction of one mask's contour pixels that fall inside the other mask's
    dilated contour.

    :param foreground_mask: (ndarray) binary segmentation image (H x W).
    :param gt_mask:         (ndarray) binary annotated image.
    :param bound_th:        disk radius in pixels when >= 1, otherwise a
                            fraction of the image diagonal (rounded up).
    Returns:
        F (float): boundaries F-measure
        P (float): boundaries precision
        R (float): boundaries recall

    Based on github fperazzi/davis but using openCV functions for finding contours
    and dilation to significantly improve speed
    """
    # Only H x W masks are supported; anything else raises AssertionError.
    assert len(foreground_mask.shape) == 2, f"Foreground mask should be 2D (HxW) but got {foreground_mask.shape}"

    # One-pixel boundaries of the prediction and of the annotation.
    fg_boundary = get_contours(foreground_mask, bound_th=0.1)
    gt_boundary = get_contours(gt_mask)

    # Matching tolerance: an absolute radius, or a share of the diagonal.
    if bound_th >= 1:
        bound_pix = bound_th
    else:
        bound_pix = np.ceil(bound_th * np.linalg.norm(foreground_mask.shape))
    structuring_element = disk(bound_pix)
    fg_dil = cv.dilate(fg_boundary, structuring_element)
    gt_dil = cv.dilate(gt_boundary, structuring_element)

    # Number of contour pixels in each mask.
    n_fg = np.sum(fg_boundary)
    n_gt = np.sum(gt_boundary)

    # Degenerate cases first (one or both contours empty), then the general case.
    if n_fg == 0 and n_gt == 0:
        precision, recall = 1, 1
    elif n_fg == 0:
        precision, recall = 1, 0
    elif n_gt == 0:
        precision, recall = 0, 1
    else:
        # Contour pixels of one mask lying within tolerance of the other.
        fg_match = fg_boundary * gt_dil
        gt_match = gt_boundary * fg_dil
        precision = np.sum(fg_match) / float(n_fg)
        recall = np.sum(gt_match) / float(n_gt)

    # Harmonic mean of precision and recall, guarding the 0/0 case.
    pr_total = precision + recall
    F = 0 if pr_total == 0 else 2 * precision * recall / pr_total
    return F, precision, recall


def eval_iou(foreground_mask, gt_mask):
    """Compute region similarity (intersection over union, the Jaccard index).

    Follows github fperazzi/davis but without the dropped ``np.bool`` casts.
    Both masks are combined with the elementwise ``&`` and ``|`` operators,
    so they must hold 0/1 values in an integer or boolean dtype.

    :params foreground_mask (ndarray): Predicted 2D binary annotation mask - HxW of form cv.CV_8UC1.
            gt_mask (ndarray): Provided 2D binary annotation mask - HxW of form cv.CV_8UC1.
    :returns jaccard (float): region similarity; 1 when both masks are empty.
    """
    # Two blank masks agree perfectly; returning early also avoids 0 / 0.
    gt_is_blank = np.isclose(np.sum(gt_mask), 0)
    if gt_is_blank and np.isclose(np.sum(foreground_mask), 0):
        return 1

    overlap = np.sum(gt_mask & foreground_mask)
    combined = np.sum(gt_mask | foreground_mask)
    return overlap / combined

## Running inference on the sequences

# Set up test dataset and dataloaders

In [None]:
# class ImglistToTensor(torch.nn.Module):
#     """
#     Converts a list of PIL images in the range [0,255] to a torch.FloatTensor
#     of shape (NUM_IMAGES x CHANNELS x HEIGHT x WIDTH) in the range [0,1].
#     Can be used as first transform for ``VideoFrameDataset``.
#     """
#     @staticmethod
#     def forward(img_list: List[Image.Image]) -> 'torch.Tensor[NUM_IMAGES, CHANNELS, HEIGHT, WIDTH]':
#         """
#         Converts each PIL image in a list to
#         a torch Tensor and stacks them into
#         a single tensor.
#         Args:
#             img_list: list of PIL images.
#         Returns:
#             tensor of size ``NUM_IMAGES x CHANNELS x HEIGHT x WIDTH``
#         """
#         return torch.stack([transforms.functional.to_tensor(pic) for pic in img_list])

In [None]:
from torchvision.datasets.folder import IMG_EXTENSIONS
from torchvision.datasets.folder import default_loader, has_file_allowed_extension
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from torchvision.datasets.vision import VisionDataset
import os
import torchvision


class SelectImageFolder(torchvision.datasets.DatasetFolder):
    """Folder dataset limited to a chosen subset of the class sub-folders.

    Only the sub-folders of ``root`` selected through ``chosen_classes`` are
    indexed. The constructor fills in the dataset attributes itself and does
    not call the parent ``__init__``.
    """

    def __init__(
            self,
            root: str,
            chosen_classes: list,
            extensions: Optional[Tuple[str, ...]] = None,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            loader: Callable[[str], Any] = default_loader,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        """Index every valid file below the chosen class folders of ``root``.

        NOTE: the ``extensions`` argument is accepted but not consulted; the
        file filter is ``IMG_EXTENSIONS`` unless ``is_valid_file`` is given.
        """
        self.root = root
        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform

        # Exactly one of "extension list" / "validator callable" is active.
        if is_valid_file is None:
            self.extensions = IMG_EXTENSIONS
        else:
            self.extensions = None

        self.classes, self.class_to_idx = self.find_classes(self.root, chosen_classes)
        self.samples = self.make_dataset(
            self.root, self.class_to_idx, self.extensions, is_valid_file)
        # Class index of every sample, in sample order.
        self.targets = [class_index for _, class_index in self.samples]

    @staticmethod
    def make_dataset(
            directory: str,
            class_to_idx: Dict[str, int],
            extensions: Optional[Tuple[str, ...]] = None,
            is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Return the samples as a list of (path_to_sample, class) pairs.

        Overrides the parent function and, unlike the module-level helper it
        delegates to, refuses a missing ``class_to_idx``.
        """
        if class_to_idx is not None:
            return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
        raise ValueError("The class_to_idx parameter cannot be None.")

    def find_classes(self, directory: str, chosen_classes: list) -> Tuple[List[str], Dict[str, int]]:
        """Return the chosen class folders of ``directory`` and their indices.

        Overrides the parent function by taking the extra ``chosen_classes``
        argument and delegating to the module-level ``find_classes``.
        """
        return find_classes(directory, chosen_classes)


def find_classes(directory: str,
                 chosen_classes: Optional[Union[str, List[str]]] = None,
                 ) -> Tuple[List[str], Dict[str, int]]:
    """Find the class folders of a dataset, optionally restricted to a subset.

    This customises the ``DatasetFolder`` behaviour by allowing the caller to
    pick which sub-folders of ``directory`` are treated as classes.

    :param directory: Root folder whose immediate sub-folders are the classes.
    :param chosen_classes: Folder name(s) to keep. A single ``str`` is treated
        as one folder name (not as a collection of characters or substrings);
        ``None`` keeps every sub-folder.
    :returns: ``(classes, class_to_idx)`` where ``classes`` is the sorted list
        of selected folder names and ``class_to_idx`` maps each name to its
        position in that list.
    :raises FileNotFoundError: If no selected class folder exists.
    """
    if chosen_classes is None:
        allowed = None
    elif isinstance(chosen_classes, str):
        # `name in "dog-agility"` is a substring test and would also accept a
        # folder called "dog"; wrap the string so matching is exact.
        allowed = {chosen_classes}
    else:
        allowed = set(chosen_classes)

    classes = sorted(
        entry.name for entry in os.scandir(directory)
        if entry.is_dir() and (allowed is None or entry.name in allowed))
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx


def make_dataset(directory: str,
                 class_to_idx: Optional[Dict[str, int]] = None,
                 extensions: Optional[Union[str, Tuple[str, ...]]] = None,
                 is_valid_file: Optional[Callable[[str], bool]] = None,
                 ) -> List[Tuple[str, int]]:
    """Generates a list of samples of a form (path_to_sample, class).

    See :class:`DatasetFolder` for details.

    :param directory: Root folder; ``~`` is expanded.
    :param class_to_idx: Mapping of class folder name to class index. When
        ``None`` every immediate sub-folder of ``directory`` becomes a class,
        indexed in sorted order.
    :param extensions: Allowed file extension(s). Exactly one of
        ``extensions`` and ``is_valid_file`` must be given.
    :param is_valid_file: Callable deciding whether a path is a valid sample.
    :returns: ``(path, class_index)`` pairs, ordered by class name and then by
        the sorted directory walk.
    :raises ValueError: If ``class_to_idx`` is empty, or if ``extensions`` and
        ``is_valid_file`` are both given or both omitted.
    :raises FileNotFoundError: If no class folder exists, or a class has no
        valid file.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        # Discover the classes here: every sub-folder is a class. (The earlier
        # call `find_classes(directory)` omitted a required argument and
        # raised TypeError.)
        class_names = sorted(
            entry.name for entry in os.scandir(directory) if entry.is_dir())
        if not class_names:
            raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
        class_to_idx = {name: index for index, name in enumerate(class_names)}
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    both_none = extensions is None and is_valid_file is None
    both_something = extensions is not None and is_valid_file is not None
    if both_none or both_something:
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, extensions)  # type: ignore[arg-type]

    is_valid_file = cast(Callable[[str], bool], is_valid_file)

    instances = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        class_index = class_to_idx[target_class]
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            # A missing folder is reported below through `empty_classes`.
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                path = os.path.join(root, fname)
                if is_valid_file(path):
                    instances.append((path, class_index))
                    available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances

# find_classes(PROJECT_ROOT_DIR + "JPEGImages/480p/", 'bear')

In [None]:
import sys
import os

# Runtime configuration. On Google Colab the DAVIS data is read from a mounted
# Google Drive and more workers are configured; elsewhere the data is expected
# relative to the current working directory.
MULTI_PROCESSOR = False
WORKERS = 1
ON_COLAB = "google.colab" in sys.modules

if ON_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    PROJECT_ROOT_DIR = "drive/MyDrive/DAVIS/DAVIS/"
    MULTI_PROCESSOR = True
    WORKERS = 8
else:
    # Keep the trailing separator: later cells build paths by plain string
    # concatenation (PROJECT_ROOT_DIR + "JPEGImages/480p/"), so "." alone
    # would produce ".JPEGImages/480p/".
    PROJECT_ROOT_DIR = "./"




In [None]:
def check_image(path):
    """Return True when ``path`` can be opened as an image by PIL, else False.

    Only opening is attempted (PIL reads the header lazily); the pixel data is
    not decoded, so a truncated file may still be reported as valid.
    """
    try:
        # The context manager closes the file handle that Image.open leaves
        # open, which previously leaked on every call.
        with Image.open(path):
            return True
    except Exception:
        # Deliberately broad: any failure to open means "not a usable image".
        # Unlike a bare `except`, this still lets KeyboardInterrupt and
        # SystemExit propagate.
        return False

def label_cnt(targets, classes):
    """Summarise how often each class occurs in ``targets`` as a print string.

    ``targets`` holds class indices and ``classes`` maps an index to its name.
    The result lists ``name: count, `` for every class in order of first
    appearance (note the trailing separator); it is empty for no targets.
    """
    tally = {}
    for class_index in targets:
        name = classes[class_index]
        tally[name] = tally.get(name, 0) + 1
    return ''.join(f"{str(name)}: {str(total)}, " for name, total in tally.items())

# DAVIS sequence folders to evaluate; each name is a sub-folder of both the
# image root and the annotation root.
seq_folder = [
    'bear', 'bus', 'car-roundabout', 'car-turn', 'dog', 'dog-agility',
    'drift-turn', 'horsejump-low', 'kite-surf', 'motor-bike', 'train',
]

# Frames are only converted to tensors (no resizing and no normalisation).
def_transforms = transforms.Compose([transforms.ToTensor()])

# Annotation masks are reduced to a single grey channel before conversion.
gt_transforms = transforms.Compose([transforms.Grayscale(), transforms.ToTensor()])



# Evaluate every sequence frame by frame: run Mask R-CNN on the frame, take
# the highest scoring instance mask and compare it with the annotation.
# os.path.join is used so the paths are valid whether or not PROJECT_ROOT_DIR
# ends with a separator.
test_data_path = os.path.join(PROJECT_ROOT_DIR, "JPEGImages", "480p")
gt_mask_path = os.path.join(PROJECT_ROOT_DIR, "Annotations", "480p")

# results[sequence] = {'f': [(F, precision, recall), ...], 'j': [IoU, ...]}
# with one entry per frame.
results = {}
for seq in seq_folder:
    results[seq] = {'f': [], 'j': []}

    # Define test data and gt_mask data sets. The sequence name is wrapped in
    # a list so it is matched as a whole folder name; passing the bare string
    # made 'dog-agility' also select the 'dog' folder (substring match).
    test_data = SelectImageFolder(
        root=test_data_path,
        chosen_classes=[seq],
        transform=def_transforms
    )
    gt_mask_data = SelectImageFolder(
        root=gt_mask_path,
        chosen_classes=[seq],
        transform=gt_transforms
    )

    # One frame per batch and no shuffling, so frame i lines up with mask i.
    test_loader = data.DataLoader(test_data, batch_size=1, shuffle=False)
    gt_mask_loader = data.DataLoader(gt_mask_data, batch_size=1, shuffle=False)
    print(f'{len(test_loader.dataset)} test samples.{label_cnt(test_data.targets, test_data.classes)}')
    print(f'{len(gt_mask_loader.dataset)} annotation.{label_cnt(gt_mask_data.targets, gt_mask_data.classes)}')

    # Loop through each frame of the sequence and measure the metrics.
    # NOTE: zip stops at the shorter loader if the two folders differ in size.
    for (frames, _), (gt_masks, _) in zip(test_loader, gt_mask_loader):
        # Inference only: no gradients needed. The model was moved to
        # `device`, so the frame has to live there as well. The model takes a
        # list of [C, H, W] tensors.
        with torch.no_grad():
            preds = model(list(frames.to(device)))

        gt_mask = gt_masks.squeeze().cpu().numpy().astype(np.uint8)

        scores = preds[0]['scores']
        if scores.numel() == 0:
            # Nothing detected in this frame: score an empty prediction
            # instead of failing on an out-of-range index.
            mask = np.zeros_like(gt_mask)
        else:
            # Determine the most likely mask index over ALL scores (the old
            # code took argmax of the single value scores[0], always 0).
            max_score_idx = int(torch.argmax(scores))
            mask = (preds[0]['masks'][max_score_idx] > 0.5).squeeze().cpu().numpy().astype(np.uint8)

        # Calculate the measurements and add them to the dict.
        results[seq]['f'].append(eval_boundary(mask, gt_mask, bound_th=0.008))
        results[seq]['j'].append(eval_iou(mask, gt_mask))







In [None]:
# Show which device (GPU or CPU) was selected for the model.
print(device)

In [None]:
# Calculate average IoU (J) & Boundary (F) values per sequence and report the
# frames with the best and the worst boundary F-measure.
print(f'Sequence{" ": <15s}J{" ": <10s}F{" ": <10s}J&F{" ": <8s}Best Frame{" ": <8s}Worse frame')
for seq, scores in results.items():
    samples = len(scores['j'])
    if samples == 0:
        # Avoid dividing by zero for a sequence without evaluated frames.
        print(f'{seq: <18s}    no frames evaluated')
        continue

    # Each 'f' entry is the (F, precision, recall) tuple from eval_boundary;
    # keep only F (column 0). The previous code summed the three values of the
    # FIRST frame only, and argmax/argmin ran over the flattened tuples, so
    # the reported indices were not frame indices. The reshape also copes
    # with plain scalar F entries.
    f_scores = np.asarray(scores['f'], dtype=float).reshape(len(scores['f']), -1)[:, 0]
    mean_j = sum(scores['j']) / samples
    mean_f = f_scores.sum() / len(f_scores)

    print(
        f'{seq: <18s}    {mean_j:.2f}',
        f'      {mean_f:.2f}',
        f'      {(mean_j + mean_f) / 2:.2f}',
        f'          {np.argmax(f_scores)}{"  ": <3s}',
        f'\t\t{np.argmin(f_scores)}')

In [None]:
# # We will keep only the pixels with values  greater than 0.5 as 1, and set the rest to 0.
# print(pred[0]['masks'].shape)
# print(pred[0]['masks'].squeeze().shape)
# print(pred[0]['masks'].squeeze().detach().shape)
# print(pred[0]['masks'].squeeze().detach().numpy().shape)
# print(pred[0]['masks'].squeeze().detach().numpy())
# masks = (pred[0]['masks']>0.5).squeeze().detach().cpu().numpy()
# print(masks.shape)
# print(masks)
# print(pred[0]['labels'].item())

In [None]:
# # Let's plot the mask for the `person` class since the 0th mask belongs to `person`
# plt.imshow(masks, cmap='gray')
# plt.show()

In [None]:
# # Let's color the `person` mask using the `random_colour_masks` function
# mask1 = random_colour_masks(masks)
# plt.imshow(mask1)
# plt.show()

In [None]:
# # Let's blend the original and the masked image and plot it.
# blend_img = cv2.addWeighted(np.asarray(img), 0.5, mask1, 0.5, 0)

# plt.imshow(blend_img)
# plt.show()

Let's create some helper functions.   
We will create the `random_colour_masks()` function to fill the predicted mask with a colour, `get_prediction()` to return the final predictions from the model, and finally `instance_segmentation_api()` to overlay the coloured masks on the original image and plot it.

In [None]:

def random_colour_masks(image):
    """
    random_colour_masks
    parameters:
      - image - a single predicted 2D mask (H x W); pixels equal to 1 (or
        True) belong to the object
    method:
      - one colour is drawn at random from a fixed palette and painted on
        every object pixel; all other pixels stay black
    returns:
      - uint8 array of shape (H, W, 3) holding the coloured mask
    """
    colours = [[0, 255, 0],[0, 0, 255],[255, 0, 0],[0, 255, 255],[255, 255, 0],[255, 0, 255],[80, 70, 180],[250, 80, 190],[245, 145, 50],[70, 150, 250],[50, 190, 190]]
    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)
    # random.choice covers the whole palette; the previous randrange(0, 10)
    # could never select the last of the eleven colours.
    r[image == 1], g[image == 1], b[image == 1] = random.choice(colours)
    coloured_mask = np.stack([r, g, b], axis=2)
    return coloured_mask

def get_prediction(img_path, threshold):
    """
    get_prediction
    parameters:
      - img_path - path of the input image
      - threshold - minimum score a detection needs to be kept
    method:
      - Image is obtained from the image path
      - the image is converted to image tensor using PyTorch's Transforms
      - image is passed through the model to get the predictions
      - masks, classes and bounding boxes are obtained from the model and soft masks are made binary(0 or 1) on masks
        ie: eg. segment of cat is made 1 and rest of the image is made 0
      - every detection up to and including the last one whose score exceeds
        the threshold is kept
    returns:
      - masks - boolean array of shape (N, H, W)
      - pred_boxes - list of [(x0, y0), (x1, y1)] corner pairs
      - pred_class - list of class names
      All three are empty (N == 0) when no score exceeds the threshold.
    """
    # Close the file after loading and force three channels so grey-scale or
    # RGBA inputs do not reach the model with an unexpected channel count.
    with Image.open(img_path) as pil_img:
        img = transforms.ToTensor()(pil_img.convert('RGB'))

    # Run on whichever device the model lives on, without tracking gradients.
    model_device = next(model.parameters()).device
    with torch.no_grad():
        pred = model([img.to(model_device)])[0]

    scores = pred['scores'].detach().cpu().numpy()
    pred_score = list(scores)
    print("pred_score: ", pred_score)

    above = np.flatnonzero(scores > threshold)
    if above.size == 0:
        # Previously this case raised IndexError on an empty list.
        height, width = img.shape[-2:]
        return np.zeros((0, height, width), dtype=bool), [], []
    pred_t = int(above[-1])
    print("pred_t: ", pred_t)

    # squeeze(1) drops only the channel axis; a bare squeeze() would also drop
    # the instance axis when there is exactly one detection.
    masks = (pred['masks'] > 0.5).squeeze(1).detach().cpu().numpy()
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[i] for i in pred['labels'].cpu().numpy()]
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in pred['boxes'].detach().cpu().numpy()]

    masks = masks[:pred_t+1]
    pred_boxes = pred_boxes[:pred_t+1]
    print(pred_boxes)
    pred_class = pred_class[:pred_t+1]
    print(pred_class)
    return masks, pred_boxes, pred_class


def instance_segmentation_api(img_path, threshold=0.5, rect_th=3, text_size=3, text_th=3):
    """
    instance_segmentation_api
    parameters:
      - img_path - path to input image
      - threshold - minimum detection score passed on to get_prediction
      - rect_th - line thickness of the bounding boxes
      - text_size - font scale of the class labels
      - text_th - line thickness of the class labels
    method:
      - prediction is obtained by get_prediction
      - each mask is given random color
      - each mask is added to the image in the ratio 1:0.5 with opencv
      - a bounding box and the class name are drawn for each instance
      - final output is displayed
    raises:
      - FileNotFoundError when OpenCV cannot read img_path
    """
    masks, boxes, pred_cls = get_prediction(img_path, threshold)

    # OpenCV is imported in this file as `cv` (not `cv2`).
    img = cv.imread(img_path)
    if img is None:
        # cv.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv.cvtColor(img, cv.COLOR_BGR2RGB)

    for mask, box, label in zip(masks, boxes, pred_cls):
        rgb_mask = random_colour_masks(mask)
        img = cv.addWeighted(img, 1, rgb_mask, 0.5, 0)
        # OpenCV drawing functions need integer pixel coordinates.
        top_left = (int(box[0][0]), int(box[0][1]))
        bottom_right = (int(box[1][0]), int(box[1][1]))
        cv.rectangle(img, top_left, bottom_right,
                     color=(0, 255, 0), thickness=rect_th)
        cv.putText(img, label, top_left,
                   cv.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)

    plt.figure(figsize=(20, 30))
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])
    plt.show()