Source for training pipiline:
https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html

In [27]:
import torch
import torchvision
import os
import cv2
import numpy as np
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import matplotlib.pyplot as plt
import albumentations as A
from tqdm import tqdm
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from PIL import Image
import random
from torchvision import transforms as T
import cv2

In [14]:
# download training and evaluating files
# os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/engine.py")
# os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/utils.py")
# os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_utils.py")
# os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/coco_eval.py")
# os.system("wget https://raw.githubusercontent.com/pytorch/vision/main/references/detection/transforms.py")

In [20]:
# dataset directories
train_dir = './dataset2/train'
valid_dir = './dataset2/valid'
test_dir = './dataset2/test'

In [23]:
# create custom coral dataset class for training model
coral_classes = {
    0: "Coral-growth-forms",
    1: "Encrusting",
    2: "Parascolymia",
    3: "branching",
    4: "foliose",
    5: "massive",
    6: "mushroom",
    7: "phaceolid",
    8: "submassive"
}

class CoralDataset(Dataset):
    
    def __init__(self, mode = 'train', augmentation=None):
        if mode == 'train':
            self.dataset_path = train_dir
            ann_path = os.path.join(train_dir, '_annotations.coco.json')
        if mode == 'valid':
            self.dataset_path = valid_dir
            ann_path = os.path.join(valid_dir, '_annotations.coco.json')
        if mode == 'test':
            self.dataset_path = test_dir
            ann_path = os.path.join(test_dir, '_annotations.coco.json')

        self.coco = COCO(ann_path)
        self.cat_ids = self.coco.getCatIds()
        self.augmentation=augmentation

        
    def __len__(self):
        return len(self.coco.imgs)
  

    def get_masks(self, index):
        ann_ids = self.coco.getAnnIds([index])
        anns = self.coco.loadAnns(ann_ids)
        masks=[]

        for ann in anns:
            mask = self.coco.annToMask(ann)
            masks.append(mask) 
        return masks

    
    def get_boxes(self, masks):
        num_objs = len(masks)
        boxes = []

        for i in range(num_objs):
            x,y,w,h = cv2.boundingRect(masks[i])
            boxes.append([x, y, x+w, y+h])

        return np.array(boxes)

    
    def get_labels(self, index, num_objs):
        ann_ids = self.coco.getAnnIds([index])
        cat_id = self.coco.loadAnns(ann_ids)[0]['category_id']
        return [cat_id] * num_objs
      

    def __getitem__(self, index):
    
        img_info = self.coco.loadImgs([index])[0]
        image = cv2.imread(os.path.join(self.dataset_path,
                                    img_info['file_name']))
        masks = self.get_masks(index)

        if self.augmentation:
            augmented = self.augmentation(image=image, masks=masks)
            image, masks = augmented['image'], augmented['masks']

        image = image.transpose(2,0,1) / 255.

        num_objs = len(masks)
        masks = np.array(masks)
        boxes = self.get_boxes(masks)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        labels = np.array(self.get_labels(index, num_objs))

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        image = torch.as_tensor(image, dtype=torch.float32)

        data = {}
        data["boxes"] =  boxes
        data["labels"] = labels
        data["masks"] = masks
        data["image_id"] = index
        data["iscrowd"] = torch.zeros((num_objs,), dtype=torch.int64)
        data["area"] = area

        return image, data
    

def collate_fn(batch):
    images = list()
    targets = list()
    
    for b in batch:
        images.append(b[0])
        targets.append(b[1])
        
    images = torch.stack(images, dim=0)
    return images, targets


transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomBrightnessContrast(
        contrast_limit=0.2, brightness_limit=0.3, p=0.5),
    A.OneOf([
        A.ImageCompression(p=0.8),
        A.RandomGamma(p=0.8),
        A.Blur(p=0.8),
        A.Equalize(mode='cv',p=0.8)
    ], p=1.0),
    A.OneOf([
        A.ImageCompression(p=0.8),
        A.RandomGamma(p=0.8),
        A.Blur(p=0.8),
        A.Equalize(mode='cv',p=0.8),
    ], p=1.0)
])

In [25]:
# create a pretrained MRCNN model
def get_model_instance_segmentation(num_classes):
    
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights="DEFAULT")

    in_features = model.roi_heads.box_predictor.cls_score.in_features

    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)


    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256

    model.roi_heads.mask_predictor = MaskRCNNPredictor(
        in_features_mask,
        hidden_layer,
        num_classes
    )

    return model

In [7]:
from engine import train_one_epoch, evaluate

# train on the GPU or on the CPU, if a GPU is not available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has 1 background + 8 coral species
num_classes = 9

# use our dataset and defined transformations
dataset = CoralDataset(mode='train', augmentation=transform)
dataset_test = CoralDataset(mode='valid')

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    collate_fn=collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    dataset=dataset_test,
    batch_size=8,
    shuffle=False,
    pin_memory=True,
    collate_fn=collate_fn
)

# get the model using our helper function
# model = get_model_instance_segmentation(num_classes)

# load a pretrained model
model = get_model_instance_segmentation(9)
weights = torch.load('./mrcnn_model.pth')
model.load_state_dict(weights)
model.eval()

# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# and a learning rate scheduler
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

# let's train it just for # epochs
num_epochs = 20

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

print("That's it!")

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
Epoch: [0]  [ 0/23]  eta: 0:06:17  lr: 0.000232  loss: 4.9395 (4.9395)  loss_classifier: 1.9666 (1.9666)  loss_box_reg: 0.1829 (0.1829)  loss_mask: 2.7758 (2.7758)  loss_objectness: 0.0078 (0.0078)  loss_rpn_box_reg: 0.0064 (0.0064)  time: 16.4244  data: 0.0927
Epoch: [0]  [10/23]  eta: 0:03:20  lr: 0.002503  loss: 1.6494 (2.4575)  loss_classifier: 0.4131 (0.8726)  loss_box_reg: 0.1663 (0.1691)  loss_mask: 1.0655 (1.4034)  loss_objectness: 0.0073 (0.0083)  loss_rpn_box_reg: 0.0033 (0.0042)  time: 15.4610  data: 0.0689
Epoch: [0]  [20/23]  eta: 0:00:45  lr: 0.004773  loss: 1.0045 (1.7303)  loss_classifier: 0.2019 (0.5497)  loss_box_reg: 0.1286 (0.1549)  loss_mask: 0.6129 (1.0039)  loss_objectness: 0.0145 (0.0166)  loss_rpn_box_reg: 0.0041 (0.0052)  time: 15.2399  data: 0.0695
Epoch: [0]  [22/23]  eta: 0:00:15  lr: 0.005000 

  torch.set_num_threads(1)


Test:  [0/3]  eta: 0:00:14  model_time: 4.8091 (4.8091)  evaluator_time: 0.0788 (0.0788)  time: 4.9162  data: 0.0283
Test:  [2/3]  eta: 0:00:04  model_time: 4.7648 (4.7734)  evaluator_time: 0.0720 (0.0724)  time: 4.8739  data: 0.0253
Test: Total time: 0:00:14 (4.8740 s / it)
Averaged stats: model_time: 4.7648 (4.7734)  evaluator_time: 0.0720 (0.0724)
Accumulating evaluation results...
DONE (t=0.01s).
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.110
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.262
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.051
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = -1.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.110
 Average Recall     (AR) @[ IoU=0.50:0.95 | ar

In [30]:
# save model
# PATH = './mrcnn_model.pth'
# torch.save(model.state_dict(), PATH)

In [28]:
# visualize predictions
model.eval()

def get_prediction(model, img_path, threshold=0.5):

    img = Image.open(img_path) 


    transform = T.Compose([T.ToTensor()]) 
    img = transform(img)

    pred = model([img]) 

    pred_score = list(pred[0]['scores'].detach().cpu().numpy())
    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
    masks = (pred[0]['masks'] > 0.5).squeeze().detach().cpu().numpy()
    pred_class = [coral_classes[i] for i in list(pred[0]['labels'].cpu().numpy())]
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])] for i in list(pred[0]['boxes'].detach().cpu().numpy())]

    masks = masks[:pred_t+1]
    pred_boxes = pred_boxes[:pred_t+1]
    pred_class = pred_class[:pred_t+1]
    pred_score = pred_score[:pred_t+1]
    
    masks = torch.as_tensor(masks, dtype=torch.uint8)
    
    bboxes = torchvision.ops.masks_to_boxes(masks)
    keep = torchvision.ops.nms(bboxes, torch.as_tensor(pred_score, dtype=torch.float32), 0.2)
    keep = keep.numpy()
    
    temp_mask = []
    temp_boxes = []
    temp_class = []
    
    for idx in keep:
        temp_mask.append(masks[idx])
        temp_boxes.append(pred_boxes[idx])
        temp_class.append(pred_class[idx])
        
    return temp_mask, temp_boxes, temp_class

def random_color_masks(image):
    
    colors = [[0, 255, 0],[0, 0, 255],[255, 0, 0],[0, 255, 255],[255, 255, 0],[255, 0, 255],[80, 70, 180], [250, 80, 190],[245, 145, 50],[70, 150, 250],[50, 190, 190]]

    r = np.zeros_like(image).astype(np.uint8)
    g = np.zeros_like(image).astype(np.uint8)
    b = np.zeros_like(image).astype(np.uint8)

    r[image==1], g[image==1], b[image==1] = colors[random.randrange(0, 10)]
    colored_mask = np.stack([r,g,b], axis=2)


    return colored_mask

def instance_segmentation(img_path, threshold=0.18, rect_th=3, text_size=1, text_th=2):
    masks, boxes, pred_class = get_prediction(model, img_path, threshold=threshold)
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    
    for i in range(len(masks)):
        rgb_mask = random_color_masks(masks[i])
        img = cv2.addWeighted(img, 1, rgb_mask, 0.5, 0)
        pt1 = tuple(int(x) for x in boxes[i][0])
        pt2 = tuple(int(x) for x in boxes[i][1])
        cv2.rectangle(img, pt1, pt2, color=(0, 0, 0), thickness=rect_th)
        cv2.putText(img, pred_class[i], pt1, cv2.FONT_HERSHEY_SIMPLEX, text_size, (0, 255, 0), thickness=text_th)
        
    return img, pred_class, masks[i]