In [8]:
import os
import torch
from PIL import Image
import torchvision.transforms as T
import matplotlib.pyplot as plt
from torchvision.models import detection
import numpy as np
import random
import cv2
from skimage.io import imread
from collections import Counter
import pixellib
from pixellib.torchbackend.instance import instanceSegmentation
# from pixellib.instance import instance_segmentation

In [9]:
# Instance-segmentation backend to run: 'pointrend' (pixellib) or 'mask-rcnn' (torchvision).
model_name = 'pointrend'
images_folder = './images/signal_images/'
# Minimum detection score used by both backends.
detected_thresh = 0.4
# Target size for cv2.resize, which takes (width, height).
dim = (1000, 300)
# Make CUDA kernel launches synchronous so errors are reported at the failing call.
# This only works as an environment variable (a plain Python variable is ignored)
# and must be set before CUDA is first initialised.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Class names indexed by the model's label ids. The file is assumed to hold one
# comma-separated list; whitespace/newlines around each name are stripped, but
# empty entries are kept so list positions stay aligned with label ids.
with open("coco_names.txt", "r") as names_file:
    COCO_INSTANCE_CATEGORY_NAMES = [name.strip() for name in names_file.read().split(",")]

In [11]:
if model_name == 'mask-rcnn':
    # torchvision Mask R-CNN with pretrained weights, switched to inference mode.
    model = detection.maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()
elif model_name == 'pointrend':
    # pixellib PointRend wrapper; `confidence` is presumably the library's
    # minimum detection score — confirm against pixellib docs.
    model = instanceSegmentation()
    model.load_model("pointrend_resnet50.pkl", confidence = detected_thresh)
else:
    # Fail here instead of with a NameError on `model` in a later cell.
    raise ValueError(
        "unknown model_name {!r}; expected 'mask-rcnn' or 'pointrend'".format(model_name))


The checkpoint state_dict contains keys that are not used by the model:
  [35mproposal_generator.anchor_generator.cell_anchors.{0, 1, 2, 3, 4}[0m


In [12]:
def _select_detections(output, threshold, class_names, mask_threshold=0.5):
    """Reduce one torchvision detection dict to the instances scoring above `threshold`.

    Args:
        output: one element of the list a torchvision Mask R-CNN returns in eval
            mode, with 'scores', 'labels', 'boxes' and 'masks' tensors
            (torchvision documents masks as (N, 1, H, W)).
        threshold: a detection is kept when its score is strictly greater.
        class_names: sequence indexed by label id.
        mask_threshold: per-pixel value above which a mask pixel is set.

    Returns:
        (masks, boxes, classes): `masks` is a bool array of shape (K, H, W),
        `boxes` a list of [(x1, y1), (x2, y2)], `classes` a list of names, where
        K is the number of kept detections (possibly 0).
    """
    keep = output['scores'] > threshold
    # Squeeze only the channel axis: a plain .squeeze() would turn a single
    # detection into a bare (H, W) image and break per-instance indexing.
    masks = (output['masks'][keep] > mask_threshold).squeeze(1).cpu().numpy()
    boxes = [[(x1, y1), (x2, y2)] for x1, y1, x2, y2 in output['boxes'][keep].tolist()]
    classes = [class_names[label] for label in output['labels'][keep].tolist()]
    return masks, boxes, classes


def get_prediction(img_path, threshold):
    """Run the Mask R-CNN `model` on one image and return its confident detections.

    The image is resized to the global `dim` before inference, so the returned
    masks and boxes are in resized-image coordinates.

    Args:
        img_path: path of the image file to segment.
        threshold: minimum score (exclusive) for a detection to be returned.

    Returns:
        (masks, boxes, classes) from `_select_detections`; all three are empty
        when no detection scores above `threshold` (previously an IndexError).
    """
    raw = imread(img_path)
    resized = cv2.resize(raw, dim, interpolation=cv2.INTER_AREA)
    # Force 3 channels: grayscale is replicated and RGBA loses its alpha channel,
    # so every image reaches the model as RGB.
    image = Image.fromarray(resized).convert('RGB')
    tensor = T.ToTensor()(image)
    # Inference only, so no autograd graph is needed.
    with torch.no_grad():
        output = model([tensor])[0]
    return _select_detections(output, threshold, COCO_INSTANCE_CATEGORY_NAMES)

In [13]:
# Palette for instance-mask overlays; each entry is [r, g, b].
MASK_COLOURS = [[0, 255, 0], [0, 0, 255], [255, 0, 0], [0, 255, 255], [255, 255, 0], [255, 0, 255],
                [80, 70, 180], [250, 80, 190], [245, 145, 50], [70, 150, 250], [50, 190, 190]]


def random_colour_masks(image):
    """Paint a binary mask with one randomly chosen palette colour.

    Args:
        image: 2-D mask; pixels equal to 1 (or True) are painted.

    Returns:
        uint8 array of shape image.shape + (3,): the chosen colour where the mask
        is set and zeros elsewhere.
    """
    # random.choice covers the whole palette; the former randrange(0, 10) could
    # never select the last of the 11 colours.
    colour = random.choice(MASK_COLOURS)
    coloured_mask = np.zeros(image.shape + (3,), dtype=np.uint8)
    coloured_mask[image == 1] = colour
    return coloured_mask

In [14]:
def instance_segmentation_api(img_path, img_num, threshold=0.5, rect_th=3, text_size=3, text_th=3,
                              output_dir='./results'):
    """Segment one image with Mask R-CNN and save a coloured-mask overlay as PNG.

    Masks are blended onto a grayscale copy of the image resized to `dim`, and the
    result is written to `{output_dir}/images_segmented_{model_name}_{img_num}.png`.

    Args:
        img_path: image to segment.
        img_num: index used in the output file name.
        threshold: minimum detection score, passed to `get_prediction`.
        rect_th, text_size, text_th: currently unused (box/label drawing is
            disabled); kept so existing calls keep working.
        output_dir: directory for the output image; created if missing.

    Raises:
        FileNotFoundError: if OpenCV cannot read `img_path`.
        IOError: if the overlay cannot be written.
    """
    masks, _, _ = get_prediction(img_path, threshold)
    gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError('could not read image: {}'.format(img_path))
    gray = cv2.resize(gray, dim, interpolation=cv2.INTER_AREA)
    # IMREAD_GRAYSCALE yields one channel, which COLOR_BGR2RGB rejects; expand to
    # 3 channels so coloured masks can be blended on top.
    overlay = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
    for mask in masks:
        overlay = cv2.addWeighted(overlay, 1, random_colour_masks(mask), 0.5, 0)

    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, 'images_segmented_{}_{}.png'.format(model_name, img_num))
    # cv2.imwrite expects BGR channel order; convert so palette colours are saved as intended.
    if not cv2.imwrite(out_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)):
        raise IOError('could not write {}'.format(out_path))

In [26]:
# Segment every file in `images_folder` with the selected backend and save the
# overlays under ./results/. Files are sorted so output index `i` refers to the
# same image on every run (os.listdir order is arbitrary).
data_list = sorted(os.listdir(images_folder))
# NOTE(review): unused in the visible notebook; kept in case other cells rely on it.
pred = []
os.makedirs('./results', exist_ok=True)

for i, file_name in enumerate(data_list):
    img_path = os.path.join(images_folder, file_name)
    if model_name == 'mask-rcnn':
        # Use the configured threshold, as the pointrend branch does, instead of the 0.5 default.
        instance_segmentation_api(img_path, i, threshold=detected_thresh)
    elif model_name == 'pointrend':
        gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if gray is None:
            raise FileNotFoundError('could not read image: {}'.format(img_path))
        gray = cv2.resize(gray, dim, interpolation=cv2.INTER_AREA)
        # Single-channel input: COLOR_BGR2RGB would reject it, so expand to 3 channels.
        frame = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        model.segmentFrame(frame, show_bboxes=True,
                           output_image_name='./results/images_segmented_{}_{}.png'.format(model_name, i))
    else:
        raise ValueError(
            "unknown model_name {!r}; expected 'mask-rcnn' or 'pointrend'".format(model_name))

(128, 2048)
(128, 2048)
(128, 2048)
(128, 2048)
(128, 2048)
(128, 2048)
(128, 2048)
(128, 2048)
(128, 2048)
(128, 2048)
(128, 2048)
(128, 2048)
