In [1]:
import sys
import os
# Put the parent of the current working directory at the front of sys.path
# (only once), so that imports are resolved against it first.
# NOTE(review): presumably this is where the `rock_detection_2d` package
# lives when the notebook is started from its own folder — confirm.
path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if path not in sys.path:
    sys.path.insert(0, path)

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import math
import time
import pickle
from PIL import Image 

## Create dataset

In [3]:
import numbers
from collections.abc import Sequence
from torchvision.transforms.functional import gaussian_blur
from torchvision.transforms.functional import adjust_saturation, adjust_hue, adjust_contrast, adjust_brightness

In [4]:
# Seed PyTorch's global RNG for reproducibility. This does not seed NumPy's
# or Python's `random` generators.
torch.manual_seed(0)

def update_bbox(masks):
    """Recompute tight bounding boxes from a stack of instance masks.

    Args:
        masks: iterable of 2-D mask tensors (e.g. one tensor of shape
            (N, H, W)); every non-zero element counts as foreground.

    Returns:
        A float32 tensor of shape (K, 4) with rows ``[xmin, ymin, xmax, ymax]``
        expressed as pixel indices (``xmax`` / ``ymax`` are the last
        foreground column / row). Masks without any foreground are skipped, so
        K can be smaller than N; callers that keep per-instance fields
        (labels, areas, ...) must account for that. When no mask has
        foreground the result has shape (0, 4).
    """
    boxes = []
    for mask in masks:
        pos = torch.where(mask)
        if pos[0].numel() == 0:
            # Nothing to enclose: an empty mask has no valid box.
            continue
        rows, cols = pos[0], pos[1]
        boxes.append([
            cols.min().item(),
            rows.min().item(),
            cols.max().item(),
            rows.max().item(),
        ])
    if not boxes:
        # Keep the (K, 4) layout even for K == 0 so downstream code can
        # index columns without special-casing.
        return torch.zeros((0, 4), dtype=torch.float32)
    return torch.tensor(boxes, dtype=torch.float32)

class RandomHorizontalFlip(object):
    """Randomly mirror an image and its instance masks along dimension 2.

    Args:
        p: probability of applying the flip (0 never flips, 1 always flips).

    Calling the transform returns ``(image, target)``. When the flip is
    applied, ``target['masks']`` is flipped along the same dimension and
    ``target['boxes']`` is rebuilt from the flipped masks with
    ``update_bbox``; ``target`` is modified in place.
    NOTE(review): assumes ``image`` is (C, H, W) and ``target['masks']`` is
    (N, H, W), i.e. dimension 2 is the width — confirm against the dataset.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, target):
        # torch.rand samples from [0, 1), so `< p` fires with probability p.
        # Using torch's RNG also keeps the draw under torch.manual_seed().
        if torch.rand(1).item() < self.p:
            image = torch.flip(image, [2])
            target['masks'] = torch.flip(target['masks'], [2])
            target['boxes'] = update_bbox(target['masks'])
        return image, target
    
class RandomVerticalFlip(object):
    """Randomly mirror an image and its instance masks along dimension 1.

    Args:
        p: probability of applying the flip (0 never flips, 1 always flips).

    Calling the transform returns ``(image, target)``. When the flip is
    applied, ``target['masks']`` is flipped along the same dimension and
    ``target['boxes']`` is rebuilt from the flipped masks with
    ``update_bbox``; ``target`` is modified in place.
    NOTE(review): assumes ``image`` is (C, H, W) and ``target['masks']`` is
    (N, H, W), i.e. dimension 1 is the height — confirm against the dataset.
    """

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, image, target):
        # torch.rand samples from [0, 1), so `< p` fires with probability p.
        # Using torch's RNG also keeps the draw under torch.manual_seed().
        if torch.rand(1).item() < self.p:
            image = torch.flip(image, [1])
            target['masks'] = torch.flip(target['masks'], [1])
            target['boxes'] = update_bbox(target['masks'])
        return image, target

def _setup_size(size, error_msg):
    if isinstance(size, numbers.Number):
        return int(size), int(size)
    if isinstance(size, Sequence) and len(size) == 1:
        return size[0], size[0]
    if len(size) != 2:
        raise ValueError(error_msg)
    return size    
    
class GaussianBlur(torch.nn.Module):
    """Gaussian blur transform that passes the detection target through.

    Args:
        kernel_size: a single odd positive number or a pair of them.
        sigma: a positive number (fixed sigma) or a ``(min, max)`` pair from
            which a sigma is drawn uniformly on every call.

    ``forward(img, target)`` blurs ``img`` only and returns
    ``(img, target)`` with ``target`` untouched.
    """

    def __init__(self, kernel_size=5, sigma=(0.1, 2.0)):
        super().__init__()
        self.kernel_size = _setup_size(kernel_size, "Kernel size should be a tuple/list of two integers")
        if any(k <= 0 or k % 2 == 0 for k in self.kernel_size):
            raise ValueError("Kernel size value should be an odd and positive number.")

        if isinstance(sigma, numbers.Number):
            if sigma <= 0:
                raise ValueError("If sigma is a single number, it must be positive.")
            sigma = (sigma, sigma)
        elif not (isinstance(sigma, Sequence) and len(sigma) == 2):
            raise ValueError("sigma should be a single number or a list/tuple with length 2.")
        elif not 0. < sigma[0] <= sigma[1]:
            raise ValueError("sigma values should be positive and of the form (min, max).")
        self.sigma = sigma

    def get_params(self, sigma_min: float, sigma_max: float) -> float:
        """Draw one sigma uniformly between ``sigma_min`` and ``sigma_max``."""
        sample = torch.empty(1)
        sample.uniform_(sigma_min, sigma_max)
        return sample.item()

    def forward(self, img, target):
        low, high = self.sigma[0], self.sigma[1]
        drawn = self.get_params(low, high)
        # The same sigma is used for both kernel directions.
        blurred = gaussian_blur(img, self.kernel_size, [drawn, drawn])
        return blurred, target

    def __repr__(self):
        return f"{self.__class__.__name__}(kernel_size={self.kernel_size}, sigma={self.sigma})"

def _log_api_usage_once(obj):
    if not obj.__module__.startswith("torchvision"):
        return
    name = obj.__class__.__name__
    if isinstance(obj, FunctionType):
        name = obj.__name__
    torch._C._log_api_usage_once(f"{obj.__module__}.{name}")
    
class ColorJitter(torch.nn.Module):
    """Random brightness/contrast/saturation/hue jitter for detection pairs.

    Each constructor argument is either a non-negative number ``v`` (turned
    into the interval ``[center - v, center + v]``, with the lower end
    clipped at 0 except for hue) or an explicit ``(min, max)`` pair. An
    interval collapsing onto its neutral value disables that adjustment.

    ``forward(img, target)`` applies the enabled adjustments to ``img`` in a
    random order and returns ``(img, target)`` with ``target`` untouched.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        _log_api_usage_once(self)
        self.brightness = self._check_input(brightness, "brightness")
        self.contrast = self._check_input(contrast, "contrast")
        self.saturation = self._check_input(saturation, "saturation")
        self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False)

    @torch.jit.unused
    def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_first_on_zero=True):
        """Validate one jitter setting; return its interval or ``None``.

        Raises:
            ValueError: for a negative number or a pair outside ``bound``.
            TypeError: when ``value`` is neither a number nor a 2-item
                tuple/list.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError(f"If {name} is a single number, it must be non negative.")
            spread = float(value)
            lower = center - spread
            upper = center + spread
            if clip_first_on_zero:
                lower = max(lower, 0.0)
            value = [lower, upper]
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            within_bound = bound[0] <= value[0] <= value[1] <= bound[1]
            if not within_bound:
                raise ValueError(f"{name} values should be between {bound}")
        else:
            raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")

        # A degenerate interval sitting on the neutral value is a no-op.
        return None if value[0] == value[1] == center else value

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Draw the application order and one factor per enabled adjustment.

        Returns:
            ``(fn_idx, b, c, s, h)`` where ``fn_idx`` is a random permutation
            of ``0..3`` and each factor is a float, or ``None`` when its
            interval is ``None``.
        """
        def _draw(interval):
            if interval is None:
                return None
            return float(torch.empty(1).uniform_(interval[0], interval[1]))

        # The permutation is drawn first, then the factors in this fixed
        # order, so the RNG is consumed in a stable sequence.
        order = torch.randperm(4)
        factors = [_draw(interval) for interval in (brightness, contrast, saturation, hue)]
        return (order, *factors)

    def forward(self, img, target):
        order, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue)

        # Index i of the permutation selects the i-th (factor, op) pair.
        steps = (
            (brightness_factor, adjust_brightness),
            (contrast_factor, adjust_contrast),
            (saturation_factor, adjust_saturation),
            (hue_factor, adjust_hue),
        )
        for step_index in order.tolist():
            factor, adjust = steps[step_index]
            if factor is not None:
                img = adjust(img, factor)

        return img, target

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f"brightness={self.brightness}"
            f", contrast={self.contrast}"
            f", saturation={self.saturation}"
            f", hue={self.hue})"
        )

    
class ToTensor(object):
    """Convert a NumPy image and its NumPy target fields to torch tensors.

    The image is converted with ``torch.from_numpy`` and cast to float. The
    target entries ``masks``, ``boxes``, ``labels``, ``image_id``, ``area``
    and ``iscrowd`` are replaced in place by ``torch.from_numpy`` of their
    arrays (dtype preserved); any other entry is left as it is.
    """

    def __call__(self, image, target):
        image_tensor = torch.from_numpy(image).float()
        array_fields = ('masks', 'boxes', 'labels', 'image_id', 'area', 'iscrowd')
        for field in array_fields:
            target[field] = torch.from_numpy(target[field])
        return image_tensor, target

class Compose(object):
    """Chain several ``(image, target) -> (image, target)`` transforms."""

    def __init__(self, transforms):
        # Kept as given; the transforms run in this order.
        self.transforms = transforms

    def __call__(self, image, target):
        for step in self.transforms:
            produced = step(image, target)
            image, target = produced
        return image, target

def get_transform(training=True):
    """Build the transform pipeline for a dataset.

    Tensor conversion always comes first. With ``training=True`` the random
    horizontal/vertical flips, a 9-pixel Gaussian blur and a colour jitter
    are appended in that order.
    """
    pipeline = [ToTensor()]
    if training:
        pipeline += [
            RandomHorizontalFlip(p=0.5),
            RandomVerticalFlip(p=0.5),
            GaussianBlur(kernel_size=9),
            ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
        ]
    return Compose(pipeline)

In [5]:
from rock_detection_2d.datasets.instance_segmentation.dataset import Dataset, create_datasets
from rock_detection_2d.utils import utils
from rock_detection_2d.utils.coco_utils import get_coco_api_from_dataset
from rock_detection_2d.utils.coco_eval import CocoEvaluator

In [6]:
#create_datasets('data/rocklas/tiles')

In [7]:
# Training set: the train split, with get_transform(training=True).
# NOTE(review): the meaning of the positional 1000 is defined by Dataset —
# presumably the tile size in pixels; confirm.
dataset = Dataset(['data/rocklas/tiles/train_split.json'], 1000, transforms=get_transform(training=True))
# Evaluation set used during training: the valid AND test splits combined,
# with get_transform(training=False).
dataset_valid = Dataset(['data/rocklas/tiles/valid_split.json', 'data/rocklas/tiles/test_split.json'], 1000, transforms=get_transform(training=False))

In [8]:
# Sanity-check the training set: its size, then one sample (index 1).
print(len(dataset))
img, target = dataset[1]
# Number of foreground pixels in the sample's second instance mask.
print(torch.count_nonzero(target['masks'][1]))
print(target)
print(target['masks'].shape)
print(img.shape)

54
tensor(40402)
{'boxes': tensor([[959., 831., 998., 999.],
        [472., 521., 703., 777.],
        [562., 336., 782., 597.],
        [414., 791., 580., 999.],
        [  0., 438., 415., 832.]]), 'labels': tensor([1, 1, 1, 1, 1]), 'masks': tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 1, 1, 0],
         [0, 0, 0,  ..., 1, 1, 0],
         [0, 0, 0,  ..., 1, 1, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         

In [9]:
# Visual check through the dataset's own `show` helper.
# NOTE(review): 10 is presumably a sample index — confirm against Dataset.show.
dataset.show(10)

In [10]:
# Per-channel image statistics; data_mean / data_std are later passed to the
# model as image_mean / image_std. With pre_computed_stats = False they are
# recomputed by dataset.imageStat(100); otherwise the hard-coded values below
# are used.
# NOTE(review): the argument 100 is presumably a sample count — confirm.
pre_computed_stats = True
if not pre_computed_stats:
    data_mean, data_std, data_max, data_min = dataset.imageStat(100)
else:
    data_mean = [0.4508936958691256, 0.43597375552024875, 0.3957421697130724]
    data_std = [0.2869340501454762, 0.27112731116220695, 0.24573479879090088]
    data_max = [1.0, 1.0, 1.0]
    data_min = [0.0, 0.0, 0.0]
print(data_mean)
print(data_std)
print(data_max)
print(data_min)

[0.4508936958691256, 0.43597375552024875, 0.3957421697130724]
[0.2869340501454762, 0.27112731116220695, 0.24573479879090088]
[1.0, 1.0, 1.0]
[0.0, 0.0, 0.0]


In [11]:
# Define the training and validation data loaders. Both use
# utils.collate_fn to assemble batches.
# NOTE(review): shuffle=False on the training loader means the samples are
# visited in the same order every epoch — confirm this is intended.
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=4, shuffle=False, num_workers=8,
    collate_fn=utils.collate_fn)

# Validation runs one image per batch, in dataset order.
data_loader_valid = torch.utils.data.DataLoader(
    dataset_valid, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

## Train Mask R-CNN

In [12]:
import math
import sys
import time
import torch

import torchvision.models.detection.mask_rcnn


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: detection model; called as ``model(images, targets)`` it must
            return a dict of loss tensors.
        optimizer: optimizer updating the model parameters.
        data_loader: yields ``(images, targets)`` batches.
        device: device the images and target tensors are moved to.
        epoch: epoch index; used in the log header and to enable the
            learning-rate warm-up during epoch 0 only.
        print_freq: logging period (in iterations) for the metric logger.

    Returns:
        The ``utils.MetricLogger`` holding the smoothed losses and learning
        rate of this epoch.

    The process exits with status 1 (``sys.exit``) as soon as the summed loss
    is not finite.
    """
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    # Per-iteration warm-up scheduler, only active during the first epoch.
    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = [image.to(device) for image in images]
        # String entries in the target cannot be moved to a device; they are
        # dropped here.
        targets = [{k: v.to(device) for k, v in t.items() if not isinstance(v, str)} for t in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger


def _get_iou_types(model):
    """List the IoU types to evaluate for ``model``.

    ``"bbox"`` is always included; ``"segm"`` is added for a ``MaskRCNN`` and
    ``"keypoints"`` for a ``KeypointRCNN``. A ``DistributedDataParallel``
    wrapper is looked through before the type checks.
    """
    wrapped = isinstance(model, torch.nn.parallel.DistributedDataParallel)
    core_model = model.module if wrapped else model

    detection = torchvision.models.detection
    optional_types = (
        (detection.MaskRCNN, "segm"),
        (detection.KeypointRCNN, "keypoints"),
    )
    iou_types = ["bbox"]
    for model_cls, iou_type in optional_types:
        if isinstance(core_model, model_cls):
            iou_types.append(iou_type)
    return iou_types


@torch.no_grad()
def evaluate(model, data_loader, device):
    """Run COCO-style evaluation of ``model`` over ``data_loader``.

    Args:
        model: detection model; called as ``model(images)`` in eval mode it
            must return one dict of tensors per image.
        data_loader: yields ``(images, targets)`` batches; its ``dataset`` is
            converted with ``get_coco_api_from_dataset``.
        device: device the images and target tensors are moved to.

    Returns:
        The ``CocoEvaluator`` after accumulation and summary printing.

    The model is left in eval mode. The torch thread count is forced to 1
    during evaluation and restored afterwards, even if evaluation raises.
    """
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    try:
        cpu_device = torch.device("cpu")
        model.eval()
        metric_logger = utils.MetricLogger(delimiter="  ")
        header = 'Test:'

        coco = get_coco_api_from_dataset(data_loader.dataset)
        iou_types = _get_iou_types(model)
        coco_evaluator = CocoEvaluator(coco, iou_types)

        for image, targets in metric_logger.log_every(data_loader, 100, header):
            image = [img.to(device) for img in image]
            targets = [{k: v.to(device) for k, v in t.items() if not isinstance(v, str)} for t in targets]

            # Synchronise so model_time measures finished GPU work; skipped
            # when CUDA is unavailable, where the call would raise.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            model_time = time.time()
            outputs = model(image)

            outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
            model_time = time.time() - model_time

            # Key the predictions by image id, as expected by the evaluator.
            res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
            evaluator_time = time.time()
            coco_evaluator.update(res)
            evaluator_time = time.time() - evaluator_time
            metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        print("Averaged stats:", metric_logger)
        coco_evaluator.synchronize_between_processes()

        # accumulate predictions from all images
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    finally:
        torch.set_num_threads(n_threads)
    return coco_evaluator


In [13]:
from rock_detection_2d.models.mask_rcnn import get_model_instance_segmentation

In [14]:
# Use the first CUDA GPU when one is available; otherwise fall back to the
# CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

In [15]:
# Model hyper-parameters, passed to get_model_instance_segmentation below.
# NOTE(review): num_classes = 2 is presumably background + one object class
# (the sample printed above only carries label 1) — confirm.
num_classes = 2
# Number of input image channels.
input_c = 3
# Passed as the model's `detections_per_img` argument.
detections_per_img = 256
# One anchor size per tuple.
# NOTE(review): presumably one tuple per feature-map level — confirm against
# the model builder.
anchor_size = ((16,), (32,), (64,), (128,), (256,))

In [16]:
# Build the Mask R-CNN through the project's factory, using the dataset
# statistics from above as image_mean / image_std.
# NOTE(review): the effect of stats=True is defined by the factory — confirm.
mask_rcnn = get_model_instance_segmentation(num_classes, input_channel=input_c, image_mean=data_mean, image_std=data_std, stats=True, detections_per_img=detections_per_img, anchor_size=anchor_size)

In [17]:
# Move the model to the training device and optimise only the parameters
# that require gradients.
mask_rcnn.to(device)
params = [p for p in mask_rcnn.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
# Epoch-level schedule: multiply the learning rate by 0.8 every 5 scheduler
# steps (the training loop steps it once per epoch).
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=5,
                                               gamma=0.8)
# Epoch range used by the training loop and the checkpoint file names.
init_epoch = 0
num_epochs = 30

In [18]:
# Make sure the checkpoint directory exists before the first torch.save.
os.makedirs("model/mask_rcnn", exist_ok=True)

# Snapshot the starting weights under their own name. They used to be written
# to epoch_{init_epoch}.param, which the loop below overwrites as soon as the
# first epoch finishes.
save_param = "model/mask_rcnn/epoch_{:04d}_init.param".format(init_epoch)
torch.save(mask_rcnn.state_dict(), save_param)


for epoch in range(init_epoch, init_epoch + num_epochs):
    # epoch_{N}.param holds the weights obtained AFTER training epoch index N.
    save_param = "model/mask_rcnn/epoch_{:04d}.param".format(epoch)
    print(save_param)
    train_one_epoch(mask_rcnn, optimizer, data_loader, device, epoch, print_freq=100)
    # update the learning rate
    lr_scheduler.step()
    evaluate(mask_rcnn, data_loader_valid, device=device)
    torch.save(mask_rcnn.state_dict(), save_param)

model/mask_rcnn/epoch_0000.param


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Epoch: [0]  [ 0/14]  eta: 0:01:06  lr: 0.000389  loss: 5.3278 (5.3278)  loss_classifier: 0.7879 (0.7879)  loss_box_reg: 0.1027 (0.1027)  loss_mask: 4.2882 (4.2882)  loss_objectness: 0.0887 (0.0887)  loss_rpn_box_reg: 0.0603 (0.0603)  time: 4.7394  data: 3.4784  max mem: 3076
Epoch: [0]  [13/14]  eta: 0:00:01  lr: 0.005000  loss: 1.0648 (1.7696)  loss_classifier: 0.0986 (0.2472)  loss_box_reg: 0.0480 (0.0623)  loss_mask: 0.8265 (1.3359)  loss_objectness: 0.0861 (0.1014)  loss_rpn_box_reg: 0.0201 (0.0229)  time: 1.3311  data: 0.2555  max mem: 3364
Epoch: [0] Total time: 0:00:18 (1.3362 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:13  model_time: 0.1307 (0.1307)  evaluator_time: 0.0169 (0.0169)  time: 0.3676  data: 0.2170  max mem: 3364
Test:  [35/36]  eta: 0:00:00  model_time: 0.1428 (0.1430)  evaluator_time: 0.0252 (0.0278)  time: 0.1813  data: 0.0029  max mem: 3364
Test: Total time: 0:00:06 (0.1842 s / it)
Averaged stats: model_time: 0.1428 (0.1430)  evaluator_tim

model/mask_rcnn/epoch_0003.param
Epoch: [3]  [ 0/14]  eta: 0:01:04  lr: 0.005000  loss: 0.5619 (0.5619)  loss_classifier: 0.0777 (0.0777)  loss_box_reg: 0.1008 (0.1008)  loss_mask: 0.2795 (0.2795)  loss_objectness: 0.0334 (0.0334)  loss_rpn_box_reg: 0.0705 (0.0705)  time: 4.6027  data: 3.4076  max mem: 3364
Epoch: [3]  [13/14]  eta: 0:00:01  lr: 0.005000  loss: 0.4886 (0.5223)  loss_classifier: 0.0744 (0.0834)  loss_box_reg: 0.1086 (0.1242)  loss_mask: 0.2652 (0.2675)  loss_objectness: 0.0202 (0.0262)  loss_rpn_box_reg: 0.0151 (0.0209)  time: 1.4086  data: 0.2508  max mem: 3364
Epoch: [3] Total time: 0:00:19 (1.4138 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:25  model_time: 0.2873 (0.2873)  evaluator_time: 0.1873 (0.1873)  time: 0.7030  data: 0.2251  max mem: 3364
Test:  [35/36]  eta: 0:00:00  model_time: 0.2859 (0.2768)  evaluator_time: 0.1850 (0.1762)  time: 0.4783  data: 0.0029  max mem: 3364
Test: Total time: 0:00:17 (0.4731 s / it)
Averaged stats: model_tim

model/mask_rcnn/epoch_0006.param
Epoch: [6]  [ 0/14]  eta: 0:01:00  lr: 0.004000  loss: 0.4714 (0.4714)  loss_classifier: 0.1341 (0.1341)  loss_box_reg: 0.0805 (0.0805)  loss_mask: 0.1780 (0.1780)  loss_objectness: 0.0186 (0.0186)  loss_rpn_box_reg: 0.0602 (0.0602)  time: 4.3454  data: 3.1191  max mem: 3364
Epoch: [6]  [13/14]  eta: 0:00:01  lr: 0.004000  loss: 0.4249 (0.4626)  loss_classifier: 0.0922 (0.1035)  loss_box_reg: 0.0850 (0.1144)  loss_mask: 0.1987 (0.2103)  loss_objectness: 0.0180 (0.0195)  loss_rpn_box_reg: 0.0098 (0.0150)  time: 1.4698  data: 0.2316  max mem: 3364
Epoch: [6] Total time: 0:00:20 (1.4745 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:21  model_time: 0.2457 (0.2457)  evaluator_time: 0.1235 (0.1235)  time: 0.5865  data: 0.2144  max mem: 3364
Test:  [35/36]  eta: 0:00:00  model_time: 0.2263 (0.2325)  evaluator_time: 0.1218 (0.1161)  time: 0.3578  data: 0.0029  max mem: 3364
Test: Total time: 0:00:13 (0.3659 s / it)
Averaged stats: model_tim

model/mask_rcnn/epoch_0009.param
Epoch: [9]  [ 0/14]  eta: 0:01:02  lr: 0.004000  loss: 0.3914 (0.3914)  loss_classifier: 0.0968 (0.0968)  loss_box_reg: 0.0842 (0.0842)  loss_mask: 0.1609 (0.1609)  loss_objectness: 0.0138 (0.0138)  loss_rpn_box_reg: 0.0356 (0.0356)  time: 4.4896  data: 3.1813  max mem: 3405
Epoch: [9]  [13/14]  eta: 0:00:01  lr: 0.004000  loss: 0.4108 (0.4171)  loss_classifier: 0.0893 (0.0926)  loss_box_reg: 0.0856 (0.1060)  loss_mask: 0.1894 (0.1962)  loss_objectness: 0.0118 (0.0116)  loss_rpn_box_reg: 0.0079 (0.0107)  time: 1.4991  data: 0.2368  max mem: 3450
Epoch: [9] Total time: 0:00:21 (1.5037 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:19  model_time: 0.2138 (0.2138)  evaluator_time: 0.0849 (0.0849)  time: 0.5320  data: 0.2301  max mem: 3450
Test:  [35/36]  eta: 0:00:00  model_time: 0.2157 (0.2171)  evaluator_time: 0.0808 (0.0828)  time: 0.3124  data: 0.0029  max mem: 3450
Test: Total time: 0:00:11 (0.3160 s / it)
Averaged stats: model_tim

model/mask_rcnn/epoch_0012.param
Epoch: [12]  [ 0/14]  eta: 0:00:54  lr: 0.003200  loss: 0.3368 (0.3368)  loss_classifier: 0.0816 (0.0816)  loss_box_reg: 0.0748 (0.0748)  loss_mask: 0.1376 (0.1376)  loss_objectness: 0.0187 (0.0187)  loss_rpn_box_reg: 0.0242 (0.0242)  time: 3.9098  data: 2.4517  max mem: 3450
Epoch: [12]  [13/14]  eta: 0:00:01  lr: 0.003200  loss: 0.3449 (0.3691)  loss_classifier: 0.0774 (0.0804)  loss_box_reg: 0.0767 (0.0964)  loss_mask: 0.1670 (0.1712)  loss_objectness: 0.0110 (0.0109)  loss_rpn_box_reg: 0.0077 (0.0103)  time: 1.4778  data: 0.1864  max mem: 3496
Epoch: [12] Total time: 0:00:20 (1.4828 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:16  model_time: 0.1716 (0.1716)  evaluator_time: 0.0378 (0.0378)  time: 0.4629  data: 0.2501  max mem: 3496
Test:  [35/36]  eta: 0:00:00  model_time: 0.1835 (0.1885)  evaluator_time: 0.0530 (0.0554)  time: 0.2517  data: 0.0031  max mem: 3496
Test: Total time: 0:00:09 (0.2600 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0015.param
Epoch: [15]  [ 0/14]  eta: 0:00:55  lr: 0.002560  loss: 0.2851 (0.2851)  loss_classifier: 0.0636 (0.0636)  loss_box_reg: 0.0552 (0.0552)  loss_mask: 0.1328 (0.1328)  loss_objectness: 0.0113 (0.0113)  loss_rpn_box_reg: 0.0223 (0.0223)  time: 3.9653  data: 2.4306  max mem: 3496
Epoch: [15]  [13/14]  eta: 0:00:01  lr: 0.002560  loss: 0.3210 (0.3174)  loss_classifier: 0.0636 (0.0694)  loss_box_reg: 0.0647 (0.0780)  loss_mask: 0.1514 (0.1526)  loss_objectness: 0.0066 (0.0085)  loss_rpn_box_reg: 0.0064 (0.0089)  time: 1.5185  data: 0.1837  max mem: 3496
Epoch: [15] Total time: 0:00:21 (1.5227 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:15  model_time: 0.1672 (0.1672)  evaluator_time: 0.0373 (0.0373)  time: 0.4212  data: 0.2139  max mem: 3496
Test:  [35/36]  eta: 0:00:00  model_time: 0.1884 (0.1863)  evaluator_time: 0.0558 (0.0564)  time: 0.2496  data: 0.0029  max mem: 3496
Test: Total time: 0:00:09 (0.2571 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0018.param
Epoch: [18]  [ 0/14]  eta: 0:01:05  lr: 0.002560  loss: 0.2619 (0.2619)  loss_classifier: 0.0595 (0.0595)  loss_box_reg: 0.0481 (0.0481)  loss_mask: 0.1217 (0.1217)  loss_objectness: 0.0101 (0.0101)  loss_rpn_box_reg: 0.0225 (0.0225)  time: 4.6977  data: 3.3392  max mem: 3507
Epoch: [18]  [13/14]  eta: 0:00:01  lr: 0.002560  loss: 0.2687 (0.2830)  loss_classifier: 0.0477 (0.0561)  loss_box_reg: 0.0611 (0.0730)  loss_mask: 0.1333 (0.1393)  loss_objectness: 0.0046 (0.0066)  loss_rpn_box_reg: 0.0059 (0.0081)  time: 1.5021  data: 0.2450  max mem: 3507
Epoch: [18] Total time: 0:00:21 (1.5062 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:13  model_time: 0.1479 (0.1479)  evaluator_time: 0.0230 (0.0230)  time: 0.3880  data: 0.2133  max mem: 3507
Test:  [35/36]  eta: 0:00:00  model_time: 0.1651 (0.1696)  evaluator_time: 0.0358 (0.0402)  time: 0.2166  data: 0.0031  max mem: 3507
Test: Total time: 0:00:08 (0.2237 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0021.param
Epoch: [21]  [ 0/14]  eta: 0:01:03  lr: 0.002048  loss: 0.2307 (0.2307)  loss_classifier: 0.0362 (0.0362)  loss_box_reg: 0.0395 (0.0395)  loss_mask: 0.1249 (0.1249)  loss_objectness: 0.0105 (0.0105)  loss_rpn_box_reg: 0.0195 (0.0195)  time: 4.5429  data: 3.2169  max mem: 3507
Epoch: [21]  [13/14]  eta: 0:00:01  lr: 0.002048  loss: 0.2612 (0.2728)  loss_classifier: 0.0454 (0.0526)  loss_box_reg: 0.0574 (0.0739)  loss_mask: 0.1321 (0.1330)  loss_objectness: 0.0053 (0.0055)  loss_rpn_box_reg: 0.0061 (0.0078)  time: 1.4928  data: 0.2366  max mem: 3507
Epoch: [21] Total time: 0:00:20 (1.4972 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:13  model_time: 0.1405 (0.1405)  evaluator_time: 0.0155 (0.0155)  time: 0.3715  data: 0.2122  max mem: 3507
Test:  [35/36]  eta: 0:00:00  model_time: 0.1540 (0.1593)  evaluator_time: 0.0259 (0.0304)  time: 0.1931  data: 0.0028  max mem: 3507
Test: Total time: 0:00:07 (0.2027 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0024.param
Epoch: [24]  [ 0/14]  eta: 0:00:56  lr: 0.002048  loss: 0.2026 (0.2026)  loss_classifier: 0.0292 (0.0292)  loss_box_reg: 0.0359 (0.0359)  loss_mask: 0.1124 (0.1124)  loss_objectness: 0.0084 (0.0084)  loss_rpn_box_reg: 0.0167 (0.0167)  time: 4.0518  data: 2.6318  max mem: 3519
Epoch: [24]  [13/14]  eta: 0:00:01  lr: 0.002048  loss: 0.2334 (0.2491)  loss_classifier: 0.0392 (0.0476)  loss_box_reg: 0.0441 (0.0622)  loss_mask: 0.1221 (0.1275)  loss_objectness: 0.0039 (0.0056)  loss_rpn_box_reg: 0.0059 (0.0063)  time: 1.4617  data: 0.1969  max mem: 3555
Epoch: [24] Total time: 0:00:20 (1.4662 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:13  model_time: 0.1444 (0.1444)  evaluator_time: 0.0158 (0.0158)  time: 0.3820  data: 0.2184  max mem: 3555
Test:  [35/36]  eta: 0:00:00  model_time: 0.1577 (0.1628)  evaluator_time: 0.0326 (0.0330)  time: 0.2012  data: 0.0029  max mem: 3555
Test: Total time: 0:00:07 (0.2092 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0027.param
Epoch: [27]  [ 0/14]  eta: 0:01:04  lr: 0.001638  loss: 0.2114 (0.2114)  loss_classifier: 0.0374 (0.0374)  loss_box_reg: 0.0424 (0.0424)  loss_mask: 0.1107 (0.1107)  loss_objectness: 0.0056 (0.0056)  loss_rpn_box_reg: 0.0153 (0.0153)  time: 4.6060  data: 3.2738  max mem: 3555
Epoch: [27]  [13/14]  eta: 0:00:01  lr: 0.001638  loss: 0.2321 (0.2511)  loss_classifier: 0.0351 (0.0456)  loss_box_reg: 0.0504 (0.0663)  loss_mask: 0.1192 (0.1289)  loss_objectness: 0.0033 (0.0044)  loss_rpn_box_reg: 0.0051 (0.0059)  time: 1.4939  data: 0.2405  max mem: 3555
Epoch: [27] Total time: 0:00:20 (1.4981 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:13  model_time: 0.1434 (0.1434)  evaluator_time: 0.0157 (0.0157)  time: 0.3793  data: 0.2168  max mem: 3555
Test:  [35/36]  eta: 0:00:00  model_time: 0.1530 (0.1513)  evaluator_time: 0.0219 (0.0233)  time: 0.1802  data: 0.0028  max mem: 3555
Test: Total time: 0:00:06 (0.1873 s / it)
Averaged stats: model_

## Inference

In [19]:
# Test split on its own, with get_transform(training=False); one image per
# batch, in dataset order.
dataset_test = Dataset(['data/rocklas/tiles/test_split.json'], 1000, transforms=get_transform(training=False))
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

In [20]:
# Restore the checkpoint saved for epoch index 20, then evaluate it first on
# the combined valid+test loader and then on the test-only loader.
mask_rcnn.load_state_dict(torch.load("model/mask_rcnn/epoch_0020.param"))
evaluate(mask_rcnn, data_loader_valid, device=device)
evaluate(mask_rcnn, data_loader_test, device=device)

creating index...
index created!
Test:  [ 0/36]  eta: 0:00:14  model_time: 0.1532 (0.1532)  evaluator_time: 0.0232 (0.0232)  time: 0.3982  data: 0.2185  max mem: 3555
Test:  [35/36]  eta: 0:00:00  model_time: 0.1634 (0.1662)  evaluator_time: 0.0322 (0.0370)  time: 0.2052  data: 0.0028  max mem: 3555
Test: Total time: 0:00:07 (0.2167 s / it)
Averaged stats: model_time: 0.1634 (0.1662)  evaluator_time: 0.0322 (0.0370)
Accumulating evaluation results...
DONE (t=0.01s).
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.332
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.501
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.368
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.085
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | max

<rock_detection_2d.utils.coco_eval.CocoEvaluator at 0x7f0e14e157d0>

In [21]:
def save_inference(dataset, model, device, save_dir):
    """Run ``model`` on every sample of ``dataset`` and pickle the results.

    Args:
        dataset: iterable of ``(image, target)`` pairs; ``image`` is a tensor
            and ``target`` a dict with at least ``'image_name'`` (str) and
            ``'masks'`` (tensor).
        model: callable returning, for a one-image batch, a list whose first
            item is a dict with ``'boxes'``, ``'labels'``, ``'scores'`` and
            ``'masks'`` tensors.
        device: device the image is moved to before calling the model.
        save_dir: output directory, created when missing.

    One ``<stem>.pickle`` file is written per sample, where ``<stem>`` is the
    base name of ``image_name`` up to its first dot. Each file holds a dict
    with the NumPy arrays ``image``, ``bb``, ``labels``, ``scores``,
    ``masks`` and ``true_masks`` plus ``image_name``. Every written path is
    printed.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Inference must run in eval mode; remember the previous mode so the
    # caller's model is handed back unchanged. Plain callables without a
    # `training` attribute are used as they are.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        # No gradients are needed here; disabling autograd avoids building a
        # graph for every image.
        with torch.no_grad():
            for image, target in dataset:
                pred = model(image.unsqueeze(0).to(device))[0]
                image_name = target['image_name']
                result = {
                    'image': image.to("cpu").detach().numpy(),
                    'bb': pred['boxes'].to("cpu").detach().numpy(),
                    'labels': pred['labels'].to("cpu").detach().numpy(),
                    'scores': pred['scores'].to("cpu").detach().numpy(),
                    'masks': pred['masks'].to("cpu").detach().numpy(),
                    'image_name': image_name,
                    'true_masks': target['masks'].to("cpu").detach().numpy(),
                }
                file_name = image_name.split('/')[-1].split('.')[0] + '.pickle'
                f = os.path.join(save_dir, file_name)
                with open(f, 'wb') as handle:
                    pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)
                print(f)
    finally:
        if was_training:
            model.train()
        
def display_inference(save_dir, groundtruth=False):
    """Render mask overlays for every ``.pickle`` result in ``save_dir``.

    For each result the predicted masks are merged (max over instances),
    painted into the first and third colour channels and blended 30/70 with
    the image; the overlay is saved next to the pickle as ``<stem>.png``.
    Unless ``groundtruth`` is true, the same is done for ``true_masks`` and
    saved as ``<stem>_true.png``. Every written path is printed.

    NOTE(review): ``groundtruth=True`` SKIPS the ground-truth overlay, which
    reads as inverted; kept as is because callers rely on it. A result with
    no predicted masks is skipped entirely, ground truth included.
    NOTE(review): pickles are loaded without validation — only point this at
    directories you trust.
    """
    assert os.path.exists(save_dir)

    def _write_overlay(image, merged_mask, out_path):
        # Blend the merged 2-D mask into channels 0 and 2 and save it.
        rows, cols = merged_mask.shape
        tint = np.zeros((3, rows, cols))
        tint[0, :, :] = merged_mask
        tint[2, :, :] = merged_mask
        blended = np.moveaxis(image * 0.7 + tint * 0.3, 0, -1)
        Image.fromarray(np.uint8(blended * 255)).save(out_path)
        print(out_path)

    # Snapshot the listing first: PNGs are written into the same directory.
    pickle_paths = [
        os.path.join(save_dir, entry)
        for entry in os.listdir(save_dir)
        if entry.endswith('.pickle')
    ]
    for pickle_path in pickle_paths:
        with open(pickle_path, 'rb') as handle:
            result = pickle.load(handle)
        image = result['image']
        _ = result['image_name']  # unused, but a record without it still fails here
        predicted = result['masks']
        if predicted.shape[0] == 0:
            continue
        stem = pickle_path.split('.pickle')[0]
        _write_overlay(image, np.squeeze(predicted, axis=1).max(axis=0), stem + '.png')

        if groundtruth:
            continue
        truth = result['true_masks']
        if truth.shape[0] == 0:
            continue
        _write_overlay(image, truth.max(axis=0), stem + '_true.png')

In [22]:
# Tiles to run inference on, with get_transform(training=False).
dataset_infer = Dataset(['data/rocklas/inference_tiles/infer_split.json'], 1000, transforms=get_transform(training=False))
# NOTE(review): save_inference below iterates dataset_infer directly; this
# loader is not used in the cells shown here.
data_loader_infer = torch.utils.data.DataLoader(
    dataset_infer, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

In [23]:
# Predict every inference tile and store one pickle per tile.
save_inference(dataset_infer, mask_rcnn, device, 'data/rocklas/prediction_2d')

data/rocklas/prediction_2d/20_8.pickle
data/rocklas/prediction_2d/11_16.pickle
data/rocklas/prediction_2d/16_15.pickle
data/rocklas/prediction_2d/10_6.pickle
data/rocklas/prediction_2d/14_1.pickle
data/rocklas/prediction_2d/6_9.pickle
data/rocklas/prediction_2d/2_4.pickle
data/rocklas/prediction_2d/13_6.pickle
data/rocklas/prediction_2d/12_17.pickle
data/rocklas/prediction_2d/5_7.pickle
data/rocklas/prediction_2d/13_0.pickle
data/rocklas/prediction_2d/0_14.pickle
data/rocklas/prediction_2d/17_20.pickle
data/rocklas/prediction_2d/12_12.pickle
data/rocklas/prediction_2d/15_6.pickle
data/rocklas/prediction_2d/1_5.pickle
data/rocklas/prediction_2d/5_4.pickle
data/rocklas/prediction_2d/15_15.pickle
data/rocklas/prediction_2d/6_2.pickle
data/rocklas/prediction_2d/0_0.pickle
data/rocklas/prediction_2d/18_3.pickle
data/rocklas/prediction_2d/16_9.pickle
data/rocklas/prediction_2d/3_9.pickle
data/rocklas/prediction_2d/0_13.pickle
data/rocklas/prediction_2d/6_14.pickle
data/rocklas/prediction_2d/

data/rocklas/prediction_2d/4_12.pickle
data/rocklas/prediction_2d/0_16.pickle
data/rocklas/prediction_2d/15_11.pickle
data/rocklas/prediction_2d/16_0.pickle
data/rocklas/prediction_2d/14_11.pickle
data/rocklas/prediction_2d/11_2.pickle
data/rocklas/prediction_2d/20_10.pickle
data/rocklas/prediction_2d/12_5.pickle
data/rocklas/prediction_2d/7_3.pickle
data/rocklas/prediction_2d/10_18.pickle
data/rocklas/prediction_2d/2_10.pickle
data/rocklas/prediction_2d/19_16.pickle
data/rocklas/prediction_2d/11_13.pickle
data/rocklas/prediction_2d/18_15.pickle
data/rocklas/prediction_2d/10_16.pickle
data/rocklas/prediction_2d/11_7.pickle
data/rocklas/prediction_2d/13_13.pickle
data/rocklas/prediction_2d/18_12.pickle
data/rocklas/prediction_2d/9_7.pickle
data/rocklas/prediction_2d/14_20.pickle
data/rocklas/prediction_2d/5_5.pickle
data/rocklas/prediction_2d/18_2.pickle
data/rocklas/prediction_2d/11_8.pickle
data/rocklas/prediction_2d/3_16.pickle
data/rocklas/prediction_2d/8_2.pickle
data/rocklas/predi

data/rocklas/prediction_2d/15_1.pickle
data/rocklas/prediction_2d/2_2.pickle
data/rocklas/prediction_2d/17_17.pickle
data/rocklas/prediction_2d/4_11.pickle
data/rocklas/prediction_2d/2_11.pickle
data/rocklas/prediction_2d/18_11.pickle
data/rocklas/prediction_2d/9_6.pickle
data/rocklas/prediction_2d/7_0.pickle
data/rocklas/prediction_2d/10_13.pickle
data/rocklas/prediction_2d/20_15.pickle
data/rocklas/prediction_2d/0_10.pickle
data/rocklas/prediction_2d/3_7.pickle
data/rocklas/prediction_2d/6_10.pickle
data/rocklas/prediction_2d/20_2.pickle
data/rocklas/prediction_2d/13_11.pickle
data/rocklas/prediction_2d/6_7.pickle
data/rocklas/prediction_2d/12_9.pickle
data/rocklas/prediction_2d/4_16.pickle
data/rocklas/prediction_2d/3_19.pickle
data/rocklas/prediction_2d/5_20.pickle


In [24]:
# Write the prediction overlays as PNGs. With groundtruth=True,
# display_inference skips the '_true.png' ground-truth overlays.
display_inference('data/rocklas/prediction_2d', groundtruth=True)

data/rocklas/prediction_2d/8_2.png
data/rocklas/prediction_2d/9_4.png
data/rocklas/prediction_2d/1_11.png
data/rocklas/prediction_2d/9_13.png
data/rocklas/prediction_2d/7_8.png
data/rocklas/prediction_2d/15_14.png
data/rocklas/prediction_2d/18_8.png
data/rocklas/prediction_2d/14_6.png
data/rocklas/prediction_2d/6_6.png
data/rocklas/prediction_2d/5_6.png
data/rocklas/prediction_2d/14_3.png
data/rocklas/prediction_2d/9_11.png
data/rocklas/prediction_2d/18_6.png
data/rocklas/prediction_2d/8_14.png
data/rocklas/prediction_2d/13_6.png
data/rocklas/prediction_2d/8_4.png
data/rocklas/prediction_2d/11_9.png
data/rocklas/prediction_2d/3_10.png
data/rocklas/prediction_2d/5_2.png
data/rocklas/prediction_2d/13_10.png
data/rocklas/prediction_2d/12_5.png
data/rocklas/prediction_2d/10_3.png
data/rocklas/prediction_2d/8_9.png
data/rocklas/prediction_2d/12_7.png
data/rocklas/prediction_2d/9_12.png
data/rocklas/prediction_2d/2_12.png
data/rocklas/prediction_2d/7_4.png
data/rocklas/prediction_2d/11_10.pn

data/rocklas/prediction_2d/17_11.png
data/rocklas/prediction_2d/8_15.png
data/rocklas/prediction_2d/3_7.png
data/rocklas/prediction_2d/12_10.png
data/rocklas/prediction_2d/8_17.png
data/rocklas/prediction_2d/8_12.png
data/rocklas/prediction_2d/4_10.png
data/rocklas/prediction_2d/5_4.png
data/rocklas/prediction_2d/8_1.png
data/rocklas/prediction_2d/12_2.png
data/rocklas/prediction_2d/11_5.png
data/rocklas/prediction_2d/6_10.png
data/rocklas/prediction_2d/10_19.png
data/rocklas/prediction_2d/10_5.png
data/rocklas/prediction_2d/11_19.png
data/rocklas/prediction_2d/8_18.png
data/rocklas/prediction_2d/3_15.png
data/rocklas/prediction_2d/11_13.png
data/rocklas/prediction_2d/5_7.png
data/rocklas/prediction_2d/17_10.png
data/rocklas/prediction_2d/13_15.png
data/rocklas/prediction_2d/11_11.png
data/rocklas/prediction_2d/5_17.png
data/rocklas/prediction_2d/10_9.png
data/rocklas/prediction_2d/5_5.png
data/rocklas/prediction_2d/14_7.png
data/rocklas/prediction_2d/9_9.png
data/rocklas/prediction_2d