In [1]:
import sys
import os
# Put the parent of the current working directory at the front of the import
# path (once) so modules living one level above the notebook can be imported.
path = os.path.abspath(
    os.path.join(os.getcwd(), os.pardir)
)
if sys.path.count(path) == 0:
    sys.path.insert(0, path)

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import math
import time
import pickle
from PIL import Image 
import zipfile

## Create dataset

In [3]:
import numbers
from collections.abc import Sequence
from torchvision.transforms.functional import gaussian_blur
from torchvision.transforms.functional import adjust_saturation, adjust_hue, adjust_contrast, adjust_brightness

In [4]:
# Seed PyTorch's global RNG so torch-based random draws (e.g. the blur sigma and
# colour-jitter factors sampled below) repeat across runs.
# Note: this does not seed NumPy's RNG.
torch.manual_seed(0)

def update_bbox(masks):
    """Recompute tight bounding boxes from a stack of instance masks.

    Args:
        masks: iterable of 2-D mask tensors (e.g. a tensor indexed as
            ``[instance, row, col]``); any non-zero element counts as foreground.

    Returns:
        A ``float32`` tensor of shape ``(K, 4)`` holding ``[xmin, ymin, xmax, ymax]``
        for each mask that has at least one foreground pixel, in input order.
        Masks with no foreground are skipped, so ``K`` can be smaller than the
        number of masks. When nothing is left the result has shape ``(0, 4)``
        (not ``(0,)``), so downstream code can still rely on the box dimension.
    """
    boxes = []
    for mask in masks:
        # torch.where(cond) returns one index tensor per dimension: (rows, cols).
        rows, cols = torch.where(mask)
        if rows.numel() == 0:
            # Empty mask: there is no box to derive.
            continue
        boxes.append([
            cols.min().item(),
            rows.min().item(),
            cols.max().item(),
            rows.max().item(),
        ])
    if not boxes:
        return torch.zeros((0, 4), dtype=torch.float32)
    return torch.tensor(boxes, dtype=torch.float32)

class RandomHorizontalFlip(object):
    """Flip an image and its instance masks along dimension 2 with probability ``p``.

    Assumes ``image`` and ``target['masks']`` are 3-D tensors whose last
    dimension is the horizontal axis — TODO confirm against the dataset.
    """

    def __init__(self, p=0.5):
        # Probability of applying the flip.
        self.p = p

    def __call__(self, image, target):
        """Return ``(image, target)``, flipped together when the random draw fires.

        ``target`` is modified in place: ``'masks'`` is flipped and ``'boxes'`` is
        recomputed from the flipped masks.
        """
        # `draw < p` fires with probability p (the previous `draw > p` fired with
        # probability 1 - p). The draw comes from torch's RNG so it follows
        # torch.manual_seed and PyTorch's per-worker DataLoader seeding.
        if torch.rand(1).item() < self.p:
            image = torch.flip(image, [2])
            target['masks'] = torch.flip(target['masks'], [2])
            target['boxes'] = update_bbox(target['masks'])
        return image, target
    
class RandomVerticalFlip(object):
    """Flip an image and its instance masks along dimension 1 with probability ``p``.

    Assumes ``image`` and ``target['masks']`` are 3-D tensors whose middle
    dimension is the vertical axis — TODO confirm against the dataset.
    """

    def __init__(self, p=0.5):
        # Probability of applying the flip.
        self.p = p

    def __call__(self, image, target):
        """Return ``(image, target)``, flipped together when the random draw fires.

        ``target`` is modified in place: ``'masks'`` is flipped and ``'boxes'`` is
        recomputed from the flipped masks.
        """
        # `draw < p` fires with probability p (the previous `draw > p` fired with
        # probability 1 - p). The draw comes from torch's RNG so it follows
        # torch.manual_seed and PyTorch's per-worker DataLoader seeding.
        if torch.rand(1).item() < self.p:
            image = torch.flip(image, [1])
            target['masks'] = torch.flip(target['masks'], [1])
            target['boxes'] = update_bbox(target['masks'])
        return image, target

def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)
    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]
    if len(size) != 2:
        raise ValueError(error_msg)
    return size    
    
class GaussianBlur(torch.nn.Module):
    """Gaussian blur with a randomly drawn sigma, for ``(image, target)`` pipelines.

    Only the image is blurred; ``target`` is passed through untouched.

    Args:
        kernel_size: a number or a 1-/2-element sequence; every entry must be
            odd and positive.
        sigma: a positive number (fixed sigma) or a ``(min, max)`` pair from
            which a sigma is drawn uniformly on every call.
    """

    def __init__(self, kernel_size=5, sigma=(0.1, 2.0)):
        super().__init__()
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        if any(ks <= 0 or ks % 2 == 0 for ks in self.kernel_size):
            raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif not (isinstance(sigma, Sequence) and len(sigma) == 2):
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")
        elif not 0. < sigma[0] <= sigma[1]:
            raise ValueError("sigma values should be positive and of the form (min, max).")
        self.sigma = sigma

    def get_params(self, sigma_min: float, sigma_max: float) -> float:
        """Draw one sigma uniformly between ``sigma_min`` and ``sigma_max`` (torch RNG)."""
        sample = torch.empty(1)
        sample.uniform_(sigma_min, sigma_max)
        return sample.item()

    def forward(self, img, target):
        """Blur ``img`` with a freshly drawn sigma (same value on both axes)."""
        low, high = self.sigma[0], self.sigma[1]
        drawn = self.get_params(low, high)
        blurred = gaussian_blur(img, self.kernel_size, [drawn, drawn])
        return blurred, target

    def __repr__(self):
        return '{}(kernel_size={}, sigma={})'.format(
            self.__class__.__name__, self.kernel_size, self.sigma)

def _log_api_usage_once(obj):
    if not obj.__module__.startswith("torchvision"):
        return
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{obj.__module__}.{name}")
    
class ColorJitter(torch.nn.Module):
    """Random brightness/contrast/saturation/hue jitter for ``(image, target)`` pipelines.

    Each argument is either a non-negative number ``v`` (expanded to the range
    ``[center - v, center + v]``) or an explicit ``(min, max)`` pair. A setting
    that collapses to "no change" is stored as ``None`` and skipped at call time.
    Only the image is altered; ``target`` is returned unchanged.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        _log_api_usage_once(self)
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        # Hue is centred on 0 and bounded to [-0.5, 0.5]; the lower end is not clipped at 0.
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        """Validate one jitter setting and return its ``[min, max]`` range, or ``None`` for a no-op."""
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            delta = float(value)
            low, high = center - delta, center + delta
            if clip_first_on_zero:
                low = max(low, 0.0)
            value = [low, high]
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
        elif not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError(f"{name} values should be between {bound}")

        # A degenerate range sitting exactly on the centre means "leave unchanged".
        first, second = value[0], value[1]
        if first == second == center:
            return None
        return value

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Draw the application order and one factor per enabled adjustment.

        Returns ``(order, b, c, s, h)`` where ``order`` is a random permutation
        of ``0..3`` and each factor is ``None`` when its range is ``None``.
        """
        def _draw(interval):
            if interval is None:
                return None
            return float(torch.empty(1).uniform_(interval[0], interval[1]))

        # The permutation is drawn first, then the factors in b, c, s, h order.
        order = torch.randperm(4)
        factors = [_draw(interval) for interval in (brightness, contrast, saturation, hue)]
        return (order, *factors)

    def forward(self, img, target):
        """Apply the enabled adjustments to ``img`` in a random order."""
        order, b_factor, c_factor, s_factor, h_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue)

        # Index i of the permutation selects entry i of this table.
        adjustments = (
            (b_factor, adjust_brightness),
            (c_factor, adjust_contrast),
            (s_factor, adjust_saturation),
            (h_factor, adjust_hue),
        )
        for idx in order.tolist():
            factor, adjust = adjustments[idx]
            if factor is not None:
                img = adjust(img, factor)

        return img, target

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )

    
class ToTensor(object):
    """Convert a NumPy image and the NumPy fields of its target dict to torch tensors."""

    def __call__(self, image, target):
        """Return ``(image_as_float_tensor, target)``.

        ``target`` is updated in place: each of ``masks``, ``boxes``, ``labels``,
        ``image_id``, ``area`` and ``iscrowd`` is replaced by
        ``torch.from_numpy`` of its current value. Other keys are left alone.
        """
        tensor_image = torch.from_numpy(image).float()
        for key in ('masks', 'boxes', 'labels', 'image_id', 'area', 'iscrowd'):
            target[key] = torch.from_numpy(target[key])
        return tensor_image, target

class Compose:
    """Chain several ``(image, target) -> (image, target)`` callables into one."""

    def __init__(self, transforms):
        # Ordered collection of callables; applied first to last.
        self.transforms = transforms

    def __call__(self, image, target):
        """Feed ``image`` and ``target`` through every stage in order and return the final pair."""
        stages = iter(self.transforms)
        for stage in stages:
            image, target = stage(image, target)
        return image, target

def get_transform(training=True):
    """Build the preprocessing pipeline for a dataset split.

    Tensor conversion is always applied. When ``training`` is true the random
    augmentations (horizontal flip, vertical flip, Gaussian blur, colour
    jitter) are appended in that order.
    """
    pipeline = [ToTensor()]
    if training:
        pipeline += [
            RandomHorizontalFlip(p=0.5),
            RandomVerticalFlip(p=0.5),
            GaussianBlur(kernel_size=9),
            ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        ]
    return Compose(pipeline)

In [5]:
from dataset import Dataset, create_datasets
import mask_rcnn_utils as utils
from coco_utils import get_coco_api_from_dataset
from coco_eval import CocoEvaluator

In [6]:
#create_datasets('data/rocklas/tiles')

In [7]:
# The training split gets the augmenting pipeline; the validation split only
# gets tensor conversion (get_transform(training=False)).
# NOTE(review): the second positional argument (1000) is presumably the tile
# size — confirm against dataset.Dataset.
dataset = Dataset(['data/rock/train_split.json'], 1000, transforms=get_transform(training=True))
dataset_valid = Dataset(['data/rock/valid_split.json'], 1000, transforms=get_transform(training=False))

In [8]:
# Sanity check: dataset size, then one sample's target dict and tensor shapes.
# Because the training pipeline is random, re-running this cell can show a
# differently augmented version of the same sample.
print(len(dataset))
img, target = dataset[1]
print(torch.count_nonzero(target['masks'][1]))
print(target)
print(target['masks'].shape)
print(img.shape)

251
tensor(3741)
{'boxes': tensor([[900.,   0., 999.,  75.],
        [  0., 204.,  35., 345.],
        [ 84., 736., 154., 801.],
        [ 58., 801., 117., 840.],
        [  0., 747.,  91., 819.],
        [433., 464., 504., 532.],
        [338., 443., 459., 533.],
        [276., 469., 346., 562.],
        [ 77., 211., 162., 275.],
        [ 28., 163., 111., 273.],
        [150.,  44., 189.,  80.],
        [  5.,  22., 156., 182.],
        [ 55., 341.,  78., 356.],
        [107., 290., 124., 316.],
        [  0., 367.,  10., 395.],
        [688., 729., 823., 805.],
        [808., 718., 858., 774.],
        [715., 804., 845., 873.],
        [612., 832., 792., 898.],
        [568., 926., 710., 998.],
        [823., 159., 959., 265.],
        [636., 207., 750., 294.],
        [661., 178., 744., 230.],
        [360., 425., 409., 452.],
        [109., 390., 272., 599.],
        [355., 517., 448., 591.],
        [  0., 544.,  12., 616.],
        [ 66., 365., 115., 425.],
        [427., 406., 

In [9]:
# NOTE(review): presumably a visual check of the training data; the meaning of
# the argument (10) is defined by dataset.Dataset.show — confirm there.
dataset.show(10)

In [11]:
# Per-channel image statistics. The hard-coded values are used unless
# `pre_computed_stats` is switched off, in which case they are recomputed via
# dataset.imageStat(100).
pre_computed_stats = True
if pre_computed_stats:
    data_mean = [0.5321785499092146, 0.5256574577876263, 0.5179107016311176]
    data_std = [0.24644456418805552, 0.2399634340866946, 0.23669749307655985]
    data_max = [1.0, 1.0, 1.0]
    data_min = [0.0, 0.0, 0.0]
else:
    data_mean, data_std, data_max, data_min = dataset.imageStat(100)
print(data_mean, data_std, data_max, data_min, sep="\n")

[0.5321785499092146, 0.5256574577876263, 0.5179107016311176]
[0.24644456418805552, 0.2399634340866946, 0.23669749307655985]
[1.0, 1.0, 1.0]
[0.0, 0.0, 0.0]


In [12]:
# define training and validation data loaders
# The training loader shuffles so that every epoch sees the samples in a new
# order and grouped into different batches; with shuffle=False each epoch
# replays exactly the same batches.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=4, shuffle=True, num_workers=8,
    collate_fn=utils.collate_fn)

# Validation keeps a fixed order and one image per batch.
data_loader_valid = torch.utils.data.DataLoader(
    dataset_valid, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

## Train Mask R-CNN

In [13]:
import math
import sys
import time
import torch

import torchvision.models.detection.mask_rcnn


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    """Run one training epoch over ``data_loader``.

    Args:
        model: detection model; called as ``model(images, targets)`` and
            expected to return a dict of loss tensors.
        optimizer: optimizer stepping the model's parameters.
        data_loader: yields ``(images, targets)`` batches.
        device: device the images and tensor-valued target entries are moved to.
        epoch: epoch index; used for the log header, and epoch 0 additionally
            gets a per-iteration learning-rate warm-up.
        print_freq: logging interval passed to the metric logger.

    Returns:
        The ``utils.MetricLogger`` holding the smoothed loss and lr values.

    Exits via ``sys.exit(1)`` when the summed loss becomes non-finite.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    # Warm-up only during the very first epoch: the scheduler is stepped once
    # per iteration, for at most 1000 iterations.
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        # String-valued target entries are dropped here; they cannot be moved to a device.
        targets = [{k: v.to(device) for k, v in t.items() if not isinstance(v, str)} for t in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            # Report the individual loss terms once, then abort.
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger


def _get_iou_types(model):
    """List the COCO IoU types to evaluate for ``model``.

    Always contains ``"bbox"``; ``"segm"`` is added for a Mask R-CNN and
    ``"keypoints"`` for a Keypoint R-CNN. A DistributedDataParallel wrapper is
    unwrapped before the type checks.
    """
    detection = torchvision.models.detection
    is_ddp = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    core_model = model.module if is_ddp else model

    iou_types = ["bbox"]
    for model_cls, iou_name in (
        (detection.MaskRCNN, "segm"),
        (detection.KeypointRCNN, "keypoints"),
    ):
        if isinstance(core_model, model_cls):
            iou_types.append(iou_name)
    return iou_types


@torch.no_grad()
def evaluate(model, data_loader, device):
    """Evaluate ``model`` on ``data_loader`` with COCO metrics.

    The model is put in eval mode, run on every batch, and its outputs are fed
    to a ``CocoEvaluator`` built from ``data_loader.dataset``. Timing for the
    model and the evaluator is logged, and the summary is printed.

    Args:
        model: detection model; called as ``model(images)``.
        data_loader: yields ``(images, targets)`` batches; each target must
            carry a single-element ``"image_id"`` tensor.
        device: device to run inference on.

    Returns:
        The accumulated and summarised ``CocoEvaluator``.
    """
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    try:
        cpu_device = torch.device("cpu")
        model.eval()
        metric_logger = utils.MetricLogger(delimiter="  ")
        header = 'Test:'

        coco = get_coco_api_from_dataset(data_loader.dataset)
        iou_types = _get_iou_types(model)
        coco_evaluator = CocoEvaluator(coco, iou_types)

        for image, targets in metric_logger.log_every(data_loader, 100, header):
            image = list(img.to(device) for img in image)
            targets = [{k: v.to(device) for k, v in t.items() if not isinstance(v, str)} for t in targets]

            # Only synchronise when CUDA exists; the unconditional call fails on
            # CPU-only machines.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            model_time = time.time()
            outputs = model(image)

            outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
            model_time = time.time() - model_time

            res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
            evaluator_time = time.time()
            coco_evaluator.update(res)
            evaluator_time = time.time() - evaluator_time
            metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        coco_evaluator.synchronize_between_processes()

        # accumulate predictions from all images
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    finally:
        # Restore the caller's thread count even if evaluation raised.
        torch.set_num_threads(n_threads)
    return coco_evaluator


In [14]:
from mask_rcnn import get_model_instance_segmentation

In [15]:
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook still runs on machines without CUDA.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [16]:
# Arguments forwarded to get_model_instance_segmentation in the next cell.
num_classes = 2  # presumably background + one foreground class — TODO confirm
input_c = 3  # passed as input_channel; matches the 3-element data_mean / data_std
detections_per_img = 256
anchor_size = ((16,), (32,), (64,), (128,), (256,))  # five entries, one anchor size each

In [17]:
# Build the model from the configuration above and the dataset statistics.
mask_rcnn = get_model_instance_segmentation(
    num_classes,
    input_channel=input_c,
    image_mean=data_mean,
    image_std=data_std,
    stats=True,
    detections_per_img=detections_per_img,
    anchor_size=anchor_size,
)



In [18]:
# Move the model to the training device, then optimise only trainable parameters.
mask_rcnn.to(device)
params = list(filter(lambda p: p.requires_grad, mask_rcnn.parameters()))
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# Multiply the learning rate by 0.8 every 5 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.8)
init_epoch, num_epochs = 0, 30

In [21]:
# Checkpoints are written under model/; create the directory up front so
# torch.save cannot fail on a fresh checkout.
os.makedirs("model", exist_ok=True)

# Snapshot the starting weights.
# NOTE(review): this file shares its name with the checkpoint written at the
# end of epoch `init_epoch`, so it is overwritten once that epoch finishes.
save_param = "model/epoch_{:04d}.param".format(init_epoch)
torch.save(mask_rcnn.state_dict(), save_param)


for epoch in range(init_epoch, init_epoch + num_epochs):
    save_param = "model/epoch_{:04d}.param".format(epoch)
    print(save_param)
    train_one_epoch(mask_rcnn, optimizer, data_loader, device, epoch, print_freq=100)
    # update the learning rate
    lr_scheduler.step()
    evaluate(mask_rcnn, data_loader_valid, device=device)
    # Save after evaluation: epoch_NNNN.param holds the weights trained through epoch NNNN.
    torch.save(mask_rcnn.state_dict(), save_param)

model/epoch_0000.param
Epoch: [0]  [ 0/63]  eta: 0:21:01  lr: 0.000086  loss: 8.1847 (8.1847)  loss_classifier: 0.7942 (0.7942)  loss_box_reg: 0.5031 (0.5031)  loss_mask: 2.7760 (2.7760)  loss_objectness: 3.8906 (3.8906)  loss_rpn_box_reg: 0.2207 (0.2207)  time: 20.0239  data: 7.4542  max mem: 3535
Epoch: [0]  [62/63]  eta: 0:00:01  lr: 0.005000  loss: 1.1909 (1.8456)  loss_classifier: 0.3055 (0.3574)  loss_box_reg: 0.3847 (0.3960)  loss_mask: 0.2446 (0.4879)  loss_objectness: 0.1324 (0.4564)  loss_rpn_box_reg: 0.0913 (0.1478)  time: 1.0387  data: 0.1904  max mem: 4357
Epoch: [0] Total time: 0:01:17 (1.2234 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:03:02  model_time: 0.6859 (0.6859)  evaluator_time: 0.4017 (0.4017)  time: 1.6726  data: 0.5553  max mem: 4357
Test:  [100/109]  eta: 0:00:06  model_time: 0.3151 (0.3491)  evaluator_time: 0.2618 (0.3269)  time: 0.7150  data: 0.0107  max mem: 4357
Test:  [108/109]  eta: 0:00:00  model_time: 0.3140 (0.3475)  evaluator_t

model/epoch_0003.param
Epoch: [3]  [ 0/63]  eta: 0:09:53  lr: 0.005000  loss: 1.1222 (1.1222)  loss_classifier: 0.2696 (0.2696)  loss_box_reg: 0.4020 (0.4020)  loss_mask: 0.2542 (0.2542)  loss_objectness: 0.1105 (0.1105)  loss_rpn_box_reg: 0.0859 (0.0859)  time: 9.4241  data: 8.3962  max mem: 4476
Epoch: [3]  [62/63]  eta: 0:00:01  lr: 0.005000  loss: 0.9073 (0.9515)  loss_classifier: 0.2608 (0.2557)  loss_box_reg: 0.2956 (0.3057)  loss_mask: 0.2229 (0.2289)  loss_objectness: 0.0669 (0.0784)  loss_rpn_box_reg: 0.0635 (0.0827)  time: 0.9236  data: 0.0483  max mem: 4494
Epoch: [3] Total time: 0:01:08 (1.0878 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:03:15  model_time: 0.6090 (0.6090)  evaluator_time: 0.4749 (0.4749)  time: 1.7961  data: 0.6853  max mem: 4494
Test:  [100/109]  eta: 0:00:08  model_time: 0.3892 (0.4247)  evaluator_time: 0.3877 (0.4120)  time: 0.8781  data: 0.0132  max mem: 4494
Test:  [108/109]  eta: 0:00:00  model_time: 0.3892 (0.4229)  evaluator_ti

model/epoch_0006.param
Epoch: [6]  [ 0/63]  eta: 0:08:18  lr: 0.004000  loss: 1.0667 (1.0667)  loss_classifier: 0.2922 (0.2922)  loss_box_reg: 0.3751 (0.3751)  loss_mask: 0.2509 (0.2509)  loss_objectness: 0.0685 (0.0685)  loss_rpn_box_reg: 0.0799 (0.0799)  time: 7.9102  data: 6.8628  max mem: 4502
Epoch: [6]  [62/63]  eta: 0:00:01  lr: 0.004000  loss: 0.8367 (0.8749)  loss_classifier: 0.2283 (0.2325)  loss_box_reg: 0.2800 (0.2852)  loss_mask: 0.2185 (0.2239)  loss_objectness: 0.0483 (0.0602)  loss_rpn_box_reg: 0.0535 (0.0731)  time: 0.9102  data: 0.0476  max mem: 4502
Epoch: [6] Total time: 0:01:06 (1.0618 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:03:18  model_time: 0.6354 (0.6354)  evaluator_time: 0.5121 (0.5121)  time: 1.8242  data: 0.6503  max mem: 4502
Test:  [100/109]  eta: 0:00:08  model_time: 0.3469 (0.4184)  evaluator_time: 0.2612 (0.4186)  time: 0.8372  data: 0.0115  max mem: 4502
Test:  [108/109]  eta: 0:00:00  model_time: 0.3469 (0.4161)  evaluator_ti

model/epoch_0009.param
Epoch: [9]  [ 0/63]  eta: 0:09:02  lr: 0.004000  loss: 0.9610 (0.9610)  loss_classifier: 0.2593 (0.2593)  loss_box_reg: 0.3352 (0.3352)  loss_mask: 0.2353 (0.2353)  loss_objectness: 0.0614 (0.0614)  loss_rpn_box_reg: 0.0697 (0.0697)  time: 8.6190  data: 7.6010  max mem: 4505
Epoch: [9]  [62/63]  eta: 0:00:01  lr: 0.004000  loss: 0.8153 (0.8448)  loss_classifier: 0.2249 (0.2224)  loss_box_reg: 0.2735 (0.2780)  loss_mask: 0.2148 (0.2213)  loss_objectness: 0.0483 (0.0535)  loss_rpn_box_reg: 0.0451 (0.0696)  time: 0.9110  data: 0.0463  max mem: 4505
Epoch: [9] Total time: 0:01:07 (1.0670 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:03:29  model_time: 0.6422 (0.6422)  evaluator_time: 0.5005 (0.5005)  time: 1.9185  data: 0.7490  max mem: 4505
Test:  [100/109]  eta: 0:00:07  model_time: 0.3087 (0.4017)  evaluator_time: 0.2699 (0.3822)  time: 0.7613  data: 0.0125  max mem: 4505
Test:  [108/109]  eta: 0:00:00  model_time: 0.3095 (0.3992)  evaluator_ti

model/epoch_0012.param
Epoch: [12]  [ 0/63]  eta: 0:08:44  lr: 0.003200  loss: 1.0489 (1.0489)  loss_classifier: 0.2794 (0.2794)  loss_box_reg: 0.3643 (0.3643)  loss_mask: 0.2459 (0.2459)  loss_objectness: 0.0854 (0.0854)  loss_rpn_box_reg: 0.0739 (0.0739)  time: 8.3330  data: 7.2653  max mem: 4505
Epoch: [12]  [62/63]  eta: 0:00:01  lr: 0.003200  loss: 0.7920 (0.8322)  loss_classifier: 0.2109 (0.2199)  loss_box_reg: 0.2694 (0.2741)  loss_mask: 0.2089 (0.2188)  loss_objectness: 0.0370 (0.0502)  loss_rpn_box_reg: 0.0450 (0.0692)  time: 0.9082  data: 0.0463  max mem: 4505
Epoch: [12] Total time: 0:01:07 (1.0652 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:03:32  model_time: 0.6973 (0.6973)  evaluator_time: 0.4964 (0.4964)  time: 1.9515  data: 0.7268  max mem: 4505
Test:  [100/109]  eta: 0:00:08  model_time: 0.3459 (0.4511)  evaluator_time: 0.3188 (0.4467)  time: 0.8725  data: 0.0120  max mem: 4505
Test:  [108/109]  eta: 0:00:00  model_time: 0.3637 (0.4503)  evaluator

model/epoch_0015.param
Epoch: [15]  [ 0/63]  eta: 0:09:32  lr: 0.002560  loss: 0.9972 (0.9972)  loss_classifier: 0.2731 (0.2731)  loss_box_reg: 0.3551 (0.3551)  loss_mask: 0.2443 (0.2443)  loss_objectness: 0.0552 (0.0552)  loss_rpn_box_reg: 0.0694 (0.0694)  time: 9.0923  data: 8.0710  max mem: 4505
Epoch: [15]  [62/63]  eta: 0:00:01  lr: 0.002560  loss: 0.7765 (0.8033)  loss_classifier: 0.1985 (0.2114)  loss_box_reg: 0.2637 (0.2654)  loss_mask: 0.2095 (0.2159)  loss_objectness: 0.0394 (0.0465)  loss_rpn_box_reg: 0.0436 (0.0641)  time: 0.9119  data: 0.0473  max mem: 4505
Epoch: [15] Total time: 0:01:07 (1.0786 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:03:21  model_time: 0.5406 (0.5406)  evaluator_time: 0.4755 (0.4755)  time: 1.8522  data: 0.8077  max mem: 4505
Test:  [100/109]  eta: 0:00:07  model_time: 0.2598 (0.3877)  evaluator_time: 0.2235 (0.3764)  time: 0.7045  data: 0.0112  max mem: 4505
Test:  [108/109]  eta: 0:00:00  model_time: 0.3067 (0.3850)  evaluator

model/epoch_0018.param
Epoch: [18]  [ 0/63]  eta: 0:08:31  lr: 0.002560  loss: 0.9220 (0.9220)  loss_classifier: 0.2526 (0.2526)  loss_box_reg: 0.3300 (0.3300)  loss_mask: 0.2367 (0.2367)  loss_objectness: 0.0392 (0.0392)  loss_rpn_box_reg: 0.0635 (0.0635)  time: 8.1245  data: 7.0779  max mem: 4505
Epoch: [18]  [62/63]  eta: 0:00:01  lr: 0.002560  loss: 0.7437 (0.7783)  loss_classifier: 0.2008 (0.2022)  loss_box_reg: 0.2525 (0.2582)  loss_mask: 0.2107 (0.2143)  loss_objectness: 0.0334 (0.0414)  loss_rpn_box_reg: 0.0412 (0.0622)  time: 0.9132  data: 0.0467  max mem: 4505
Epoch: [18] Total time: 0:01:07 (1.0676 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:03:36  model_time: 0.5736 (0.5736)  evaluator_time: 0.6472 (0.6472)  time: 1.9883  data: 0.7404  max mem: 4505
Test:  [100/109]  eta: 0:00:07  model_time: 0.2745 (0.3917)  evaluator_time: 0.2430 (0.3758)  time: 0.7239  data: 0.0120  max mem: 4505
Test:  [108/109]  eta: 0:00:00  model_time: 0.3085 (0.3900)  evaluator

model/epoch_0021.param
Epoch: [21]  [ 0/63]  eta: 0:09:49  lr: 0.002048  loss: 0.9718 (0.9718)  loss_classifier: 0.2462 (0.2462)  loss_box_reg: 0.3480 (0.3480)  loss_mask: 0.2361 (0.2361)  loss_objectness: 0.0806 (0.0806)  loss_rpn_box_reg: 0.0609 (0.0609)  time: 9.3630  data: 8.3032  max mem: 4505
Epoch: [21]  [62/63]  eta: 0:00:01  lr: 0.002048  loss: 0.7231 (0.7724)  loss_classifier: 0.1891 (0.2001)  loss_box_reg: 0.2530 (0.2574)  loss_mask: 0.2074 (0.2142)  loss_objectness: 0.0279 (0.0395)  loss_rpn_box_reg: 0.0447 (0.0612)  time: 0.9081  data: 0.0441  max mem: 4505
Epoch: [21] Total time: 0:01:08 (1.0815 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:02:59  model_time: 0.4949 (0.4949)  evaluator_time: 0.4972 (0.4972)  time: 1.6449  data: 0.6238  max mem: 4505
Test:  [100/109]  eta: 0:00:07  model_time: 0.2920 (0.3889)  evaluator_time: 0.2510 (0.3880)  time: 0.7780  data: 0.0130  max mem: 4505
Test:  [108/109]  eta: 0:00:00  model_time: 0.3215 (0.3881)  evaluator

model/epoch_0024.param
Epoch: [24]  [ 0/63]  eta: 0:08:16  lr: 0.002048  loss: 0.9037 (0.9037)  loss_classifier: 0.2393 (0.2393)  loss_box_reg: 0.3167 (0.3167)  loss_mask: 0.2317 (0.2317)  loss_objectness: 0.0490 (0.0490)  loss_rpn_box_reg: 0.0670 (0.0670)  time: 7.8829  data: 6.8094  max mem: 4505
Epoch: [24]  [62/63]  eta: 0:00:01  lr: 0.002048  loss: 0.7361 (0.7555)  loss_classifier: 0.1960 (0.1959)  loss_box_reg: 0.2529 (0.2519)  loss_mask: 0.2088 (0.2121)  loss_objectness: 0.0346 (0.0349)  loss_rpn_box_reg: 0.0494 (0.0608)  time: 0.9182  data: 0.0479  max mem: 4505
Epoch: [24] Total time: 0:01:07 (1.0747 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:02:48  model_time: 0.4175 (0.4175)  evaluator_time: 0.4651 (0.4651)  time: 1.5451  data: 0.6339  max mem: 4505
Test:  [100/109]  eta: 0:00:06  model_time: 0.2239 (0.3346)  evaluator_time: 0.1772 (0.3268)  time: 0.6635  data: 0.0123  max mem: 4505
Test:  [108/109]  eta: 0:00:00  model_time: 0.2336 (0.3304)  evaluator

model/epoch_0027.param
Epoch: [27]  [ 0/63]  eta: 0:08:44  lr: 0.001638  loss: 0.9112 (0.9112)  loss_classifier: 0.2450 (0.2450)  loss_box_reg: 0.3189 (0.3189)  loss_mask: 0.2329 (0.2329)  loss_objectness: 0.0517 (0.0517)  loss_rpn_box_reg: 0.0626 (0.0626)  time: 8.3249  data: 7.2791  max mem: 4505
Epoch: [27]  [62/63]  eta: 0:00:01  lr: 0.001638  loss: 0.7009 (0.7397)  loss_classifier: 0.1874 (0.1878)  loss_box_reg: 0.2399 (0.2485)  loss_mask: 0.2073 (0.2099)  loss_objectness: 0.0264 (0.0344)  loss_rpn_box_reg: 0.0436 (0.0590)  time: 0.9184  data: 0.0500  max mem: 4505
Epoch: [27] Total time: 0:01:07 (1.0724 s / it)
creating index...
index created!
Test:  [  0/109]  eta: 0:03:13  model_time: 0.4885 (0.4885)  evaluator_time: 0.5166 (0.5166)  time: 1.7749  data: 0.7427  max mem: 4505
Test:  [100/109]  eta: 0:00:06  model_time: 0.2188 (0.3352)  evaluator_time: 0.1824 (0.3290)  time: 0.6396  data: 0.0116  max mem: 4505
Test:  [108/109]  eta: 0:00:00  model_time: 0.2382 (0.3320)  evaluator

## Inference

In [19]:
# Test split: tensor conversion only (no augmentation), one image per batch,
# fixed order.
dataset_test = Dataset(['data/rock/test_split.json'], 1000, transforms=get_transform(training=False))
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

In [20]:
# Restore the weights saved after epoch 26 and score them on the validation split.
# map_location lets a checkpoint saved on one device be loaded on whatever
# `device` is in this session (e.g. a GPU checkpoint on a CPU-only machine).
mask_rcnn.load_state_dict(torch.load("model/epoch_0026.param", map_location=device))
evaluate(mask_rcnn, data_loader_valid, device=device)
#evaluate(mask_rcnn, data_loader_test, device=device)

creating index...
index created!
Test:  [  0/109]  eta: 0:07:13  model_time: 3.0963 (3.0963)  evaluator_time: 0.2622 (0.2622)  time: 3.9798  data: 0.5904  max mem: 1047
Test:  [100/109]  eta: 0:00:06  model_time: 0.2175 (0.3432)  evaluator_time: 0.1793 (0.2899)  time: 0.6226  data: 0.0118  max mem: 2422
Test:  [108/109]  eta: 0:00:00  model_time: 0.2175 (0.3383)  evaluator_time: 0.2058 (0.2862)  time: 0.5468  data: 0.0086  max mem: 2422
Test: Total time: 0:01:12 (0.6682 s / it)
Averaged stats: model_time: 0.2175 (0.3383)  evaluator_time: 0.2058 (0.2862)
Accumulating evaluation results...
DONE (t=0.05s).
Accumulating evaluation results...
DONE (t=0.05s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.446
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.678
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.510
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.271
 

<coco_eval.CocoEvaluator at 0x7fc7594451f0>

In [21]:
# Module-level switch read by save_inference: when true, each result pickle is
# packed into a .zip and the loose .pickle file is deleted.
zip_pickle = True

def save_inference(dataset, model, device, save_dir, zip_output=None):
    """Run ``model`` on every sample of ``dataset`` and store one result file per image.

    For each ``(image, target)`` sample a dict with keys ``image``, ``bb``,
    ``labels``, ``scores``, ``masks``, ``image_name`` and ``ids`` (all NumPy
    arrays except ``image_name``) is pickled to ``<save_dir>/<stem>.pickle``,
    where ``<stem>`` is the last ``/``-separated component of
    ``target['image_name']`` up to its first dot. The stored ``image_name`` is
    the original path with its first three ``/``-separated components dropped
    and re-rooted under ``/root``.

    Args:
        dataset: iterable of ``(image_tensor, target_dict)`` pairs; the target
            must contain ``'image_name'``.
        model: detection model returning a list of prediction dicts with
            ``boxes``, ``labels``, ``scores`` and ``masks``.
        device: device the image is moved to before the forward pass.
        save_dir: output directory; created if missing.
        zip_output: when true, each pickle is packed into ``<stem>.zip`` and the
            loose pickle is removed. ``None`` (default) falls back to the
            module-level ``zip_pickle`` flag, matching the previous behaviour.
    """
    if zip_output is None:
        zip_output = zip_pickle

    os.makedirs(save_dir, exist_ok=True)

    # Detection models need targets in training mode, so switch to eval
    # explicitly rather than relying on an earlier evaluate() call; no_grad
    # avoids building autograd graphs during pure inference.
    model.eval()
    with torch.no_grad():
        for image, target in dataset:
            pred = model(image.unsqueeze(0).to(device))[0]
            boxes = pred['boxes'].to("cpu").detach().numpy()
            labels = pred['labels'].to("cpu").detach().numpy()
            scores = pred['scores'].to("cpu").detach().numpy()
            masks = pred['masks'].to("cpu").detach().numpy()
            image_name = target['image_name']

            num_objs = len(masks)

            image_name_list = ['/root'] + image_name.split('/')[3:]

            result = {}
            result['image'] = image.to("cpu").detach().numpy()
            result['bb'] = boxes
            result['labels'] = labels
            result['scores'] = scores
            result['masks'] = masks
            result['image_name'] = os.path.join(*image_name_list)
            # Placeholder id '-1' for every predicted object.
            result['ids'] = np.asarray(['-1'] * num_objs)

            stem = image_name.split('/')[-1].split('.')[0]
            file_name = stem + '.pickle'
            pickle_name = os.path.join(save_dir, file_name)
            with open(pickle_name, 'wb') as handle:
                pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)

            if zip_output:
                zip_file = os.path.join(save_dir, stem + '.zip')
                with zipfile.ZipFile(zip_file, 'w', compression=zipfile.ZIP_DEFLATED) as zip_obj:
                    zip_obj.write(pickle_name, arcname=file_name)

                os.remove(pickle_name)
        
        
def _save_mask_overlay(image, merged_mask, out_path):
    """Blend a 2-D mask into channels 0 and 2 of ``image``, save it to ``out_path`` and print the path."""
    height, width = merged_mask.shape
    color_mask = np.zeros((3, height, width))
    color_mask[0, :, :] = merged_mask
    color_mask[2, :, :] = merged_mask
    # 70 % image, 30 % mask colour; then channels-first -> channels-last for PIL.
    overlay = image * 0.7 + color_mask * 0.3
    overlay = np.moveaxis(overlay, 0, -1)
    Image.fromarray(np.uint8(overlay * 255)).save(out_path)
    print(out_path)


def display_inference(save_dir, groundtruth=False):
    """Render mask overlays for every zipped result in ``save_dir``.

    Each ``*.zip`` is extracted next to itself, the pickle of the same base name
    is loaded, and ``<base>.png`` is written with all predicted masks merged
    (per-pixel maximum) and blended over the image. With ``groundtruth=True`` a
    second image ``<base>_true.png`` is written from ``result['true_masks']``.
    Results without any predicted mask are skipped. The extracted pickle is
    always deleted afterwards, including on the skip paths and on errors.

    Args:
        save_dir: directory containing the ``*.zip`` result files; must exist.
        groundtruth: also render the ground-truth overlay.

    Raises:
        AssertionError: if ``save_dir`` does not exist.
        KeyError: if ``groundtruth`` is true and a result has no ``'true_masks'``.
    """
    assert os.path.exists(save_dir)
    files = [os.path.join(save_dir, f) for f in os.listdir(save_dir) if f.endswith('.zip')]
    for f in files:
        with zipfile.ZipFile(f, 'r') as zip_ref:
            zip_ref.extractall(os.path.dirname(f))

        base = f.split('.zip')[0]
        pickle_f = base + '.pickle'
        try:
            # SECURITY: pickle.load can execute arbitrary code; only use this
            # on result files produced by this notebook.
            with open(pickle_f, 'rb') as handle:
                result = pickle.load(handle)

            image = result['image']
            masks = result['masks']
            if masks.shape[0] == 0:
                continue
            # Predicted masks carry a singleton axis 1; drop it and merge all instances.
            _save_mask_overlay(image, np.squeeze(masks, axis=1).max(axis=0), base + '.png')

            if not groundtruth:
                continue
            true_masks = result['true_masks']
            if true_masks.shape[0] == 0:
                continue
            _save_mask_overlay(image, true_masks.max(axis=0), base + '_true.png')
        finally:
            # Single cleanup point, so no exit path leaves the extracted pickle behind.
            if os.path.exists(pickle_f):
                os.remove(pickle_f)

In [22]:
# Write one result file per test image into data/rock (the same directory that
# holds the split JSON files).
save_inference(dataset_test, mask_rcnn, device, 'data/rock')

In [30]:
# Render a prediction overlay PNG next to every result zip in data/rock; the
# path of each written image is printed.
display_inference('data/rock')

data/rock/26_11.png
data/rock/8_15.png
data/rock/14_12.png
data/rock/11_12.png
data/rock/13_12.png
data/rock/5_14.png
data/rock/24_9.png
data/rock/12_9.png
data/rock/22_6.png
data/rock/12_8.png
data/rock/14_11.png
data/rock/9_9.png
data/rock/15_19.png
data/rock/5_16.png
data/rock/9_19.png
data/rock/23_4.png
data/rock/16_15.png
data/rock/16_5.png
data/rock/8_22.png
data/rock/8_12.png
data/rock/8_7.png
data/rock/12_19.png
data/rock/21_13.png
data/rock/20_10.png
data/rock/5_17.png
data/rock/6_21.png
data/rock/18_8.png
data/rock/10_11.png
data/rock/6_19.png
data/rock/5_18.png
data/rock/19_9.png
data/rock/10_15.png
data/rock/28_8.png
data/rock/25_13.png
data/rock/6_20.png
data/rock/12_18.png
data/rock/17_13.png
data/rock/19_4.png
data/rock/10_19.png
data/rock/17_11.png
data/rock/8_14.png
data/rock/22_5.png
data/rock/19_3.png
data/rock/7_21.png
data/rock/16_16.png
data/rock/12_7.png
data/rock/18_10.png
data/rock/22_8.png
data/rock/24_13.png
data/rock/11_7.png
data/rock/27_12.png
data/rock/2_