# 1. Install dependencies

In [1]:
%%shell
# Stop at the first failing command so a broken install is reported here,
# not later as a confusing ImportError in the imports cell.
set -e

# cython is installed first because the cocoapi package below is built from source.
pip install cython
pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
pip install albumentations

# train_one_epoch / evaluate / collate_fn live in torchvision's reference
# scripts, which are not part of the pip package, so copy them next to this
# notebook. The clone is skipped when the cell is re-run.
# NOTE(review): master's reference scripts may need a newer torchvision than the
# installed one; consider checking out the tag matching torchvision.__version__.
if [ ! -d vision ]; then
  git clone https://github.com/pytorch/vision.git
fi
for helper in utils transforms coco_eval engine coco_utils; do
  cp "vision/references/detection/${helper}.py" .
done

# 2. Imports

In [None]:
import os
import numpy as np
import torch
import torch.utils.data
from PIL import Image
import matplotlib.pyplot as plt
import glob
import cv2
from torchvision.transforms import functional as F
from torch.utils.data import Dataset
import albumentations as TAA
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from engine import train_one_epoch, evaluate
import utils
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 3. Load dataset images from your Google Drive
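Alternatively, download the dataset from Zenodo (https://zenodo.org/record/4041488#.YEzSwy2708Z) and extract it so that the PNG images are in `dataset/img/` and the PNG masks are in `dataset/mask/`.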

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')
# !cp drive/MyDrive/dataset.zip ./
# !unzip dataset.zip

# 4. Dataset and Augmentations

In [11]:
def get_transform(train):
    """Build the augmentation pipeline for one split.

    Args:
        train: truthy for the training split, falsy for evaluation.

    Returns:
        An albumentations ``Compose`` (blur, shift/scale/rotate, horizontal
        and vertical flips) for training, or ``None`` for evaluation, meaning
        "no augmentation".
    """
    if not train:
        return None

    augmentations = [
        TAA.Blur(blur_limit=10),
        # interpolation=1 is cv2.INTER_LINEAR, border_mode=0 is cv2.BORDER_CONSTANT.
        TAA.ShiftScaleRotate(shift_limit=0.1, rotate_limit=30, interpolation=1, border_mode=0),
        TAA.HorizontalFlip(),
        TAA.VerticalFlip(),
    ]
    return TAA.Compose(augmentations)


class FoodDataset(Dataset):
    """Instance-segmentation dataset of food images and grey-level label masks.

    Images are read from ``{root}/img/*.png`` and masks from
    ``{root}/mask/*.png``. The two lists are paired by sorted file name, so the
    i-th image must correspond to the i-th mask. In a mask, 0 is background and
    every other grey level is a class id; each class id present in a mask
    becomes one object instance in the returned target.

    Args:
        root: directory containing the ``img`` and ``mask`` folders.
        transforms: optional callable invoked as
            ``transforms(image=img, mask=mask)`` that returns a dict with
            ``'image'`` and ``'mask'`` entries (e.g. an albumentations
            ``Compose``); ``None`` disables augmentation.

    Raises:
        ValueError: if the number of images and masks found differs.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        self.imgs = sorted(glob.glob(f'{root}/img/*.png'))
        self.masks = sorted(glob.glob(f'{root}/mask/*.png'))
        # Pairing is purely positional; unequal counts would silently
        # associate images with the wrong masks.
        if len(self.imgs) != len(self.masks):
            raise ValueError(
                f'found {len(self.imgs)} images but {len(self.masks)} masks in {root!r}')

    @staticmethod
    def _read_image(path, conversion):
        """Read ``path`` with OpenCV and apply the colour ``conversion`` code.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file (``cv2.imread``
                returns ``None`` instead of raising).
        """
        image = cv2.imread(path)
        if image is None:
            raise FileNotFoundError(f'could not read image {path!r}')
        return cv2.cvtColor(image, conversion)

    @staticmethod
    def _single_component(mask, label):
        """Return a copy of ``mask`` keeping only one region of ``label``.

        Pixels not equal to ``label`` are zeroed. If ``label`` has several
        external contours, one contour is chosen at random and only the
        ``label`` pixels inside its bounding box are kept, so that each
        label yields a single object instance.
        """
        component = np.copy(mask)
        component[component != label] = 0
        # Indexing with [-2] returns the contour list on both OpenCV 3.x
        # (image, contours, hierarchy) and 4.x (contours, hierarchy).
        contours = cv2.findContours(
            component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[-2]
        if len(contours) > 1:
            # Contour points are stored as (N, 1, 2) arrays of (x, y).
            points = contours[np.random.randint(len(contours))][:, 0, :]
            xmin, ymin = points.min(axis=0)
            xmax, ymax = points.max(axis=0)
            # Inclusive bounds (+1) so the chosen contour's own pixels always
            # survive: the result is never empty and no retry loop is needed.
            keep = np.zeros(component.shape, dtype=bool)
            keep[ymin:ymax + 1, xmin:xmax + 1] = True
            component *= keep
        return component

    def __getitem__(self, idx):
        """Return ``(image_tensor, target)`` for sample ``idx``.

        ``target`` holds ``boxes`` (float32, ``[N, 4]`` as
        ``xmin, ymin, xmax, ymax``), ``labels`` (int64, ``[N]``), ``masks``
        (uint8, ``[N, H, W]``), ``image_id``, ``area`` and ``iscrowd``.
        ``N`` may be 0 (e.g. when augmentation moves every object out of
        frame); the tensors then still have the ranks listed above.
        """
        img_path = self.imgs[idx]
        mask_path = self.masks[idx]
        img = self._read_image(img_path, cv2.COLOR_BGR2RGB)
        mask = self._read_image(mask_path, cv2.COLOR_BGR2GRAY)
        if self.transforms:
            aug = self.transforms(image=img, mask=mask)
            img = aug['image']
            mask = aug['mask']

        # Every non-zero grey level is an object class. Filtering out 0
        # explicitly (instead of dropping the first unique value) keeps a
        # real class when the mask happens to contain no background.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        boxes, masks, labels = [], [], []
        for label in obj_ids:
            component = np.clip(self._single_component(mask, label), 0, 1)
            ys, xs = np.nonzero(component)
            if xs.size == 0:
                continue
            xmin, xmax = xs.min(), xs.max()
            ymin, ymax = ys.min(), ys.max()
            # torchvision's Mask R-CNN raises on boxes with zero width or
            # height during training. Skip such slivers, dropping the mask,
            # box and label together so the three lists stay aligned.
            if xmax <= xmin or ymax <= ymin:
                continue
            masks.append(component)
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(label)

        num_objs = len(labels)
        # reshape / zeros keep the [N, 4] and [N, H, W] ranks even when N == 0.
        boxes = torch.as_tensor(np.asarray(boxes, dtype=np.float32).reshape(-1, 4))
        labels = torch.as_tensor(np.asarray(labels, dtype=np.int64))
        masks = torch.as_tensor(
            np.stack(masks) if masks else np.zeros((0, *mask.shape[:2]), dtype=np.uint8),
            dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }
        return F.to_tensor(img), target

    def __len__(self):
        return len(self.imgs)

# 5. Dataloader

In [None]:
# A fixed seed makes the train/test split identical on every run; without it a
# re-run (or a reloaded checkpoint) gets evaluated on images it was trained on.
SPLIT_SEED = 42
TEST_SIZE = 50  # number of images held out for evaluation

# Two views over the same files (same sorted order), differing only in augmentation.
dataset = FoodDataset('dataset', get_transform(train=True))
dataset_test = FoodDataset('dataset', get_transform(train=False))

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(dataset), generator=split_generator).tolist()
n_train = len(indices) - TEST_SIZE
if n_train <= 0:
    raise ValueError(
        f'need more than {TEST_SIZE} images for a train/test split, found {len(indices)}')
dataset = torch.utils.data.Subset(dataset, indices[:n_train])
dataset_test = torch.utils.data.Subset(dataset_test, indices[n_train:])

# Detection targets vary in size, so utils.collate_fn is used instead of the default collate.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=6, shuffle=True, num_workers=2,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=2,
    collate_fn=utils.collate_fn)

# 6. Load pretrained model

In [5]:
def get_instance_segmentation_model(num_classes):
    """Build a pretrained Mask R-CNN (ResNet-50 FPN) with new prediction heads.

    Both the box predictor and the mask predictor are replaced so that they
    predict ``num_classes`` classes (torchvision counts the background as one
    of them).

    Args:
        num_classes: number of output classes for both heads.

    Returns:
        The torchvision Mask R-CNN model with the replaced heads.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    heads = model.roi_heads

    # Box head: reuse the incoming feature size, new class count.
    box_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_features, num_classes)

    # Mask head: reuse the incoming channel count, 256 hidden channels.
    mask_channels = heads.mask_predictor.conv5_mask.in_channels
    heads.mask_predictor = MaskRCNNPredictor(mask_channels, 256, num_classes)
    return model
  
num_classes = 10  # number of classes passed to both Mask R-CNN heads
# nn.Module.to moves the parameters in place and returns the same module.
model = get_instance_segmentation_model(num_classes).to(device)

# Optimise only the parameters that are trainable.
params = list(filter(lambda param: param.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# Multiply the learning rate by 0.9 after every 3 calls to lr_scheduler.step().
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.9)

## 7. Training

In [None]:
num_epochs = 10
CHECKPOINT_PATH = 'model_split_1.pt'  # file name the web app expects

# train_one_epoch / evaluate are imported once in the imports cell.
for epoch in range(num_epochs):
    # Train for one epoch, logging every 10 iterations.
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)

    # Save the whole model object (not just its state_dict) after every epoch,
    # overwriting the previous checkpoint.
    torch.save(model, CHECKPOINT_PATH)

    # One scheduler step per epoch.
    lr_scheduler.step()

    # Evaluate on the held-out split.
    evaluate(model, data_loader_test, device=device)

Now you can move `model_split_1.pt` to the root folder and use the web app to perform predictions.