In [1]:
import sys
import os
path = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if path not in sys.path:
    sys.path.insert(0, path)

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import math
import time
import pickle
from PIL import Image 

## Create dataset

In [3]:
torch.manual_seed(0)

class ToTensor(object):
    def __call__(self, image, target):
        image = torch.from_numpy(image).float()
        target['masks'] = torch.from_numpy(target['masks'])
        target['boxes'] = torch.from_numpy(target['boxes'])
        target['labels'] = torch.from_numpy(target['labels'])
        target['image_id'] = torch.from_numpy(target['image_id'])
        target['area'] = torch.from_numpy(target['area'])
        target['iscrowd'] = torch.from_numpy(target['iscrowd'])
        
        return image, target

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

def get_transform(training=True):
    transforms = []
    transforms.append(ToTensor()) 
    if training:
        None
    return Compose(transforms)

In [4]:
from rock_detection_2d.datasets.instance_segmentation.dataset import Dataset, create_datasets
from rock_detection_2d.utils import utils
from rock_detection_2d.utils.coco_utils import get_coco_api_from_dataset
from rock_detection_2d.utils.coco_eval import CocoEvaluator

In [5]:
#create_datasets('data/rocklas/tiles')

In [6]:
dataset = Dataset(['data/rocklas/tiles/train_split.json'], 1000, transforms=get_transform(training=True), include_name=True)
dataset_valid = Dataset(['data/rocklas/tiles/valid_split.json', 'data/rocklas/tiles/test_split.json'], 1000, transforms=get_transform(training=False), include_name=True)

In [7]:
print(len(dataset))
img, target = dataset[1]
print(target)
print(target['masks'].shape)
print(img.shape)

54
{'boxes': tensor([[  1,   0,  40, 168],
        [296, 222, 527, 478],
        [217, 402, 437, 663],
        [419,   0, 585, 208],
        [584, 167, 999, 561]]), 'labels': tensor([1, 1, 1, 1, 1]), 'masks': tensor([[[0, 1, 1,  ..., 0, 0, 0],
         [0, 1, 1,  ..., 0, 0, 0],
         [0, 1, 1,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
        

In [8]:
dataset.show(1)

In [9]:
pre_computed_stats = True
if not pre_computed_stats:
    data_mean, data_std, data_max, data_min = dataset.imageStat(100)
else:
    data_mean = [0.4508936958691256, 0.43597375552024875, 0.3957421697130724]
    data_std = [0.2869340501454762, 0.27112731116220695, 0.24573479879090088]
    data_max = [1.0, 1.0, 1.0]
    data_min = [0.0, 0.0, 0.0]
print(data_mean)
print(data_std)
print(data_max)
print(data_min)

[0.4508936958691256, 0.43597375552024875, 0.3957421697130724]
[0.2869340501454762, 0.27112731116220695, 0.24573479879090088]
[1.0, 1.0, 1.0]
[0.0, 0.0, 0.0]


In [10]:
# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=4, shuffle=False, num_workers=8,
    collate_fn=utils.collate_fn)

data_loader_valid = torch.utils.data.DataLoader(
    dataset_valid, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

## Train mask rcnn

In [11]:
import math
import sys
import time
import torch

import torchvision.models.detection.mask_rcnn

from coco_utils import get_coco_api_from_dataset
from coco_eval import CocoEvaluator
import utils


def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)

        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items() if not isinstance(v, str)} for t in targets]

        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())

        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print(loss_dict_reduced)
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])


def _get_iou_types(model):
    model_without_ddp = model
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model_without_ddp = model.module
    iou_types = ["bbox"]
    if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
        iou_types.append("segm")
    if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
        iou_types.append("keypoints")
    return iou_types


@torch.no_grad()
def evaluate(model, data_loader, device):
    n_threads = torch.get_num_threads()
    # FIXME remove this and make paste_masks_in_image run on the GPU
    torch.set_num_threads(1)
    cpu_device = torch.device("cpu")
    model.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'

    coco = get_coco_api_from_dataset(data_loader.dataset)
    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(coco, iou_types)

    for image, targets in metric_logger.log_every(data_loader, 100, header):
        image = list(img.to(device) for img in image)
        targets = [{k: v.to(device) for k, v in t.items() if not isinstance(v, str)} for t in targets]

        torch.cuda.synchronize()
        model_time = time.time()
        outputs = model(image)

        outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
        model_time = time.time() - model_time

        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
        evaluator_time = time.time()
        coco_evaluator.update(res)
        evaluator_time = time.time() - evaluator_time
        metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    torch.set_num_threads(n_threads)
    return coco_evaluator


In [12]:
from rock_detection_2d.models.mask_rcnn import get_model_instance_segmentation

In [13]:
#device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device = torch.device('cuda:0')

In [14]:
num_classes = 2
input_c = 3
detections_per_img = 256
anchor_size = ((16,), (32,), (64,), (128,), (256,))

In [15]:
mask_rcnn = get_model_instance_segmentation(num_classes, input_channel=input_c, image_mean=data_mean, image_std=data_std, stats=True, detections_per_img=detections_per_img, anchor_size=anchor_size)

In [16]:
mask_rcnn.to(device)
params = [p for p in mask_rcnn.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005, momentum=0.9, weight_decay=0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=5,
                                               gamma=0.8)
init_epoch = 0
num_epochs = 30

In [17]:
save_param = "model/mask_rcnn/epoch_{:04d}.param".format(init_epoch)
torch.save(mask_rcnn.state_dict(), save_param)


for epoch in range(init_epoch, init_epoch + num_epochs):
    save_param = "model/mask_rcnn/epoch_{:04d}.param".format(epoch)
    print(save_param)
    train_one_epoch(mask_rcnn, optimizer, data_loader, device, epoch, print_freq=100)
    # update the learning rate
    lr_scheduler.step()
    evaluate(mask_rcnn, data_loader_valid, device=device)
    torch.save(mask_rcnn.state_dict(), save_param)

model/mask_rcnn/epoch_0000.param


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Epoch: [0]  [ 0/14]  eta: 0:00:25  lr: 0.000389  loss: 4.8771 (4.8771)  loss_classifier: 0.9400 (0.9400)  loss_box_reg: 0.1016 (0.1016)  loss_mask: 3.6486 (3.6486)  loss_objectness: 0.1229 (0.1229)  loss_rpn_box_reg: 0.0639 (0.0639)  time: 1.8471  data: 0.6829  max mem: 3075
Epoch: [0]  [13/14]  eta: 0:00:01  lr: 0.005000  loss: 1.1439 (1.6837)  loss_classifier: 0.1100 (0.2801)  loss_box_reg: 0.0573 (0.0740)  loss_mask: 0.9650 (1.2095)  loss_objectness: 0.0783 (0.0949)  loss_rpn_box_reg: 0.0206 (0.0252)  time: 1.1331  data: 0.0595  max mem: 3368
Epoch: [0] Total time: 0:00:15 (1.1358 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:26  model_time: 0.3344 (0.3344)  evaluator_time: 0.2325 (0.2325)  time: 0.7460  data: 0.1758  max mem: 3368
Test:  [35/36]  eta: 0:00:00  model_time: 0.3056 (0.3084)  evaluator_time: 0.2077 (0.2123)  time: 0.5210  data: 0.0028  max mem: 3368
Test: Total time: 0:00:19 (0.5381 s / it)
Averaged stats: model_time: 0.3056 (0.3084)  evaluator_tim

model/mask_rcnn/epoch_0003.param
Epoch: [3]  [ 0/14]  eta: 0:00:23  lr: 0.005000  loss: 0.5552 (0.5552)  loss_classifier: 0.0829 (0.0829)  loss_box_reg: 0.0941 (0.0941)  loss_mask: 0.2608 (0.2608)  loss_objectness: 0.0484 (0.0484)  loss_rpn_box_reg: 0.0690 (0.0690)  time: 1.6597  data: 0.5550  max mem: 3368
Epoch: [3]  [13/14]  eta: 0:00:01  lr: 0.005000  loss: 0.4784 (0.5109)  loss_classifier: 0.0829 (0.0905)  loss_box_reg: 0.1121 (0.1327)  loss_mask: 0.2370 (0.2421)  loss_objectness: 0.0224 (0.0267)  loss_rpn_box_reg: 0.0131 (0.0189)  time: 1.1376  data: 0.0471  max mem: 3368
Epoch: [3] Total time: 0:00:15 (1.1401 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:26  model_time: 0.3231 (0.3231)  evaluator_time: 0.2216 (0.2216)  time: 0.7270  data: 0.1792  max mem: 3368
Test:  [35/36]  eta: 0:00:00  model_time: 0.3119 (0.2960)  evaluator_time: 0.2183 (0.1965)  time: 0.5107  data: 0.0027  max mem: 3368
Test: Total time: 0:00:18 (0.5094 s / it)
Averaged stats: model_tim

model/mask_rcnn/epoch_0006.param
Epoch: [6]  [ 0/14]  eta: 0:00:26  lr: 0.004000  loss: 0.4402 (0.4402)  loss_classifier: 0.1050 (0.1050)  loss_box_reg: 0.1114 (0.1114)  loss_mask: 0.1668 (0.1668)  loss_objectness: 0.0202 (0.0202)  loss_rpn_box_reg: 0.0368 (0.0368)  time: 1.9052  data: 0.5837  max mem: 3493
Epoch: [6]  [13/14]  eta: 0:00:01  lr: 0.004000  loss: 0.3608 (0.3775)  loss_classifier: 0.0592 (0.0718)  loss_box_reg: 0.0934 (0.1158)  loss_mask: 0.1586 (0.1671)  loss_objectness: 0.0101 (0.0114)  loss_rpn_box_reg: 0.0077 (0.0113)  time: 1.3114  data: 0.0494  max mem: 3493
Epoch: [6] Total time: 0:00:18 (1.3140 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:16  model_time: 0.1987 (0.1987)  evaluator_time: 0.0795 (0.0795)  time: 0.4676  data: 0.1862  max mem: 3493
Test:  [35/36]  eta: 0:00:00  model_time: 0.1972 (0.1962)  evaluator_time: 0.0676 (0.0780)  time: 0.2833  data: 0.0027  max mem: 3493
Test: Total time: 0:00:10 (0.2877 s / it)
Averaged stats: model_tim

model/mask_rcnn/epoch_0009.param
Epoch: [9]  [ 0/14]  eta: 0:00:25  lr: 0.004000  loss: 0.3034 (0.3034)  loss_classifier: 0.0772 (0.0772)  loss_box_reg: 0.0604 (0.0604)  loss_mask: 0.1324 (0.1324)  loss_objectness: 0.0073 (0.0073)  loss_rpn_box_reg: 0.0260 (0.0260)  time: 1.8114  data: 0.5688  max mem: 3493
Epoch: [9]  [13/14]  eta: 0:00:01  lr: 0.004000  loss: 0.2614 (0.2805)  loss_classifier: 0.0521 (0.0567)  loss_box_reg: 0.0563 (0.0767)  loss_mask: 0.1324 (0.1313)  loss_objectness: 0.0073 (0.0072)  loss_rpn_box_reg: 0.0056 (0.0085)  time: 1.2876  data: 0.0482  max mem: 3500
Epoch: [9] Total time: 0:00:18 (1.2903 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:13  model_time: 0.1572 (0.1572)  evaluator_time: 0.0235 (0.0235)  time: 0.3672  data: 0.1837  max mem: 3500
Test:  [35/36]  eta: 0:00:00  model_time: 0.1591 (0.1594)  evaluator_time: 0.0290 (0.0301)  time: 0.1966  data: 0.0027  max mem: 3500
Test: Total time: 0:00:07 (0.2012 s / it)
Averaged stats: model_tim

model/mask_rcnn/epoch_0012.param
Epoch: [12]  [ 0/14]  eta: 0:00:27  lr: 0.003200  loss: 0.2072 (0.2072)  loss_classifier: 0.0280 (0.0280)  loss_box_reg: 0.0410 (0.0410)  loss_mask: 0.1014 (0.1014)  loss_objectness: 0.0155 (0.0155)  loss_rpn_box_reg: 0.0213 (0.0213)  time: 1.9335  data: 0.6295  max mem: 3555
Epoch: [12]  [13/14]  eta: 0:00:01  lr: 0.003200  loss: 0.1989 (0.2006)  loss_classifier: 0.0280 (0.0327)  loss_box_reg: 0.0442 (0.0542)  loss_mask: 0.1010 (0.1021)  loss_objectness: 0.0040 (0.0048)  loss_rpn_box_reg: 0.0043 (0.0067)  time: 1.3026  data: 0.0525  max mem: 3565
Epoch: [12] Total time: 0:00:18 (1.3052 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:12  model_time: 0.1371 (0.1371)  evaluator_time: 0.0109 (0.0109)  time: 0.3340  data: 0.1828  max mem: 3565
Test:  [35/36]  eta: 0:00:00  model_time: 0.1401 (0.1448)  evaluator_time: 0.0149 (0.0163)  time: 0.1672  data: 0.0027  max mem: 3565
Test: Total time: 0:00:06 (0.1720 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0015.param
Epoch: [15]  [ 0/14]  eta: 0:00:26  lr: 0.002560  loss: 0.1900 (0.1900)  loss_classifier: 0.0255 (0.0255)  loss_box_reg: 0.0562 (0.0562)  loss_mask: 0.0862 (0.0862)  loss_objectness: 0.0044 (0.0044)  loss_rpn_box_reg: 0.0176 (0.0176)  time: 1.8989  data: 0.5961  max mem: 3565
Epoch: [15]  [13/14]  eta: 0:00:01  lr: 0.002560  loss: 0.1820 (0.1885)  loss_classifier: 0.0257 (0.0281)  loss_box_reg: 0.0495 (0.0568)  loss_mask: 0.0912 (0.0952)  loss_objectness: 0.0023 (0.0032)  loss_rpn_box_reg: 0.0033 (0.0052)  time: 1.3010  data: 0.0502  max mem: 3565
Epoch: [15] Total time: 0:00:18 (1.3039 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:11  model_time: 0.1354 (0.1354)  evaluator_time: 0.0065 (0.0065)  time: 0.3290  data: 0.1839  max mem: 3565
Test:  [35/36]  eta: 0:00:00  model_time: 0.1358 (0.1362)  evaluator_time: 0.0074 (0.0085)  time: 0.1498  data: 0.0027  max mem: 3565
Test: Total time: 0:00:05 (0.1553 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0018.param
Epoch: [18]  [ 0/14]  eta: 0:00:25  lr: 0.002560  loss: 0.1591 (0.1591)  loss_classifier: 0.0277 (0.0277)  loss_box_reg: 0.0286 (0.0286)  loss_mask: 0.0834 (0.0834)  loss_objectness: 0.0029 (0.0029)  loss_rpn_box_reg: 0.0165 (0.0165)  time: 1.8504  data: 0.5443  max mem: 3565
Epoch: [18]  [13/14]  eta: 0:00:01  lr: 0.002560  loss: 0.1723 (0.1838)  loss_classifier: 0.0252 (0.0309)  loss_box_reg: 0.0553 (0.0556)  loss_mask: 0.0837 (0.0897)  loss_objectness: 0.0025 (0.0029)  loss_rpn_box_reg: 0.0029 (0.0047)  time: 1.2951  data: 0.0466  max mem: 3565
Epoch: [18] Total time: 0:00:18 (1.2978 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:13  model_time: 0.1553 (0.1553)  evaluator_time: 0.0251 (0.0251)  time: 0.3646  data: 0.1807  max mem: 3565
Test:  [35/36]  eta: 0:00:00  model_time: 0.1623 (0.1583)  evaluator_time: 0.0308 (0.0276)  time: 0.1919  data: 0.0027  max mem: 3565
Test: Total time: 0:00:07 (0.1972 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0021.param
Epoch: [21]  [ 0/14]  eta: 0:00:27  lr: 0.002048  loss: 0.1976 (0.1976)  loss_classifier: 0.0253 (0.0253)  loss_box_reg: 0.0679 (0.0679)  loss_mask: 0.0843 (0.0843)  loss_objectness: 0.0062 (0.0062)  loss_rpn_box_reg: 0.0138 (0.0138)  time: 1.9566  data: 0.6580  max mem: 3565
Epoch: [21]  [13/14]  eta: 0:00:01  lr: 0.002048  loss: 0.1794 (0.1915)  loss_classifier: 0.0233 (0.0261)  loss_box_reg: 0.0534 (0.0697)  loss_mask: 0.0821 (0.0895)  loss_objectness: 0.0022 (0.0023)  loss_rpn_box_reg: 0.0030 (0.0040)  time: 1.3027  data: 0.0535  max mem: 3565
Epoch: [21] Total time: 0:00:18 (1.3054 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:12  model_time: 0.1371 (0.1371)  evaluator_time: 0.0106 (0.0106)  time: 0.3335  data: 0.1825  max mem: 3565
Test:  [35/36]  eta: 0:00:00  model_time: 0.1383 (0.1423)  evaluator_time: 0.0137 (0.0161)  time: 0.1629  data: 0.0027  max mem: 3565
Test: Total time: 0:00:06 (0.1689 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0024.param
Epoch: [24]  [ 0/14]  eta: 0:00:28  lr: 0.002048  loss: 0.1510 (0.1510)  loss_classifier: 0.0226 (0.0226)  loss_box_reg: 0.0338 (0.0338)  loss_mask: 0.0784 (0.0784)  loss_objectness: 0.0030 (0.0030)  loss_rpn_box_reg: 0.0131 (0.0131)  time: 2.0179  data: 0.7160  max mem: 3565
Epoch: [24]  [13/14]  eta: 0:00:01  lr: 0.002048  loss: 0.1541 (0.1615)  loss_classifier: 0.0191 (0.0221)  loss_box_reg: 0.0398 (0.0458)  loss_mask: 0.0829 (0.0876)  loss_objectness: 0.0020 (0.0024)  loss_rpn_box_reg: 0.0022 (0.0036)  time: 1.3064  data: 0.0570  max mem: 3565
Epoch: [24] Total time: 0:00:18 (1.3090 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:11  model_time: 0.1386 (0.1386)  evaluator_time: 0.0084 (0.0084)  time: 0.3333  data: 0.1832  max mem: 3565
Test:  [35/36]  eta: 0:00:00  model_time: 0.1392 (0.1421)  evaluator_time: 0.0136 (0.0155)  time: 0.1652  data: 0.0027  max mem: 3565
Test: Total time: 0:00:06 (0.1682 s / it)
Averaged stats: model_

model/mask_rcnn/epoch_0027.param
Epoch: [27]  [ 0/14]  eta: 0:00:26  lr: 0.001638  loss: 0.1384 (0.1384)  loss_classifier: 0.0174 (0.0174)  loss_box_reg: 0.0306 (0.0306)  loss_mask: 0.0763 (0.0763)  loss_objectness: 0.0032 (0.0032)  loss_rpn_box_reg: 0.0109 (0.0109)  time: 1.9124  data: 0.5643  max mem: 3565
Epoch: [27]  [13/14]  eta: 0:00:01  lr: 0.001638  loss: 0.1321 (0.1470)  loss_classifier: 0.0159 (0.0194)  loss_box_reg: 0.0358 (0.0439)  loss_mask: 0.0748 (0.0784)  loss_objectness: 0.0016 (0.0021)  loss_rpn_box_reg: 0.0017 (0.0031)  time: 1.3020  data: 0.0481  max mem: 3565
Epoch: [27] Total time: 0:00:18 (1.3047 s / it)
creating index...
index created!
Test:  [ 0/36]  eta: 0:00:11  model_time: 0.1387 (0.1387)  evaluator_time: 0.0085 (0.0085)  time: 0.3330  data: 0.1827  max mem: 3565
Test:  [35/36]  eta: 0:00:00  model_time: 0.1387 (0.1402)  evaluator_time: 0.0114 (0.0134)  time: 0.1592  data: 0.0027  max mem: 3565
Test: Total time: 0:00:05 (0.1642 s / it)
Averaged stats: model_

## Inference

In [18]:
dataset_test = Dataset(['data/rocklas/tiles/test_split.json'], 1000, transforms=get_transform(training=False), include_name=True)
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=4,
    collate_fn=utils.collate_fn)

In [23]:
mask_rcnn.load_state_dict(torch.load("model/mask_rcnn/epoch_0030.param"))
evaluate(mask_rcnn, data_loader_valid, device=device)
evaluate(mask_rcnn, data_loader_test, device=device)

creating index...
index created!
Test:  [ 0/36]  eta: 0:00:11  model_time: 0.1333 (0.1333)  evaluator_time: 0.0039 (0.0039)  time: 0.3215  data: 0.1808  max mem: 3565
Test:  [35/36]  eta: 0:00:00  model_time: 0.1210 (0.1242)  evaluator_time: 0.0076 (0.0087)  time: 0.1381  data: 0.0027  max mem: 3565
Test: Total time: 0:00:05 (0.1433 s / it)
Averaged stats: model_time: 0.1210 (0.1242)  evaluator_time: 0.0076 (0.0087)
Accumulating evaluation results...
DONE (t=0.01s).
Accumulating evaluation results...
DONE (t=0.01s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.494
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.631
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.575
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.230
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.590
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | max

<coco_eval.CocoEvaluator at 0x7f3a0d5bf490>

In [24]:
def save_inference(dataset, model, device, save_dir):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    
    for i, data in enumerate(dataset):
        image, target = data
        pred = model(image.unsqueeze(0).to(device))[0]
        boxes = pred['boxes'].to("cpu").detach().numpy()
        labels = pred['labels'].to("cpu").detach().numpy()
        scores = pred['scores'].to("cpu").detach().numpy()
        masks = pred['masks'].to("cpu").detach().numpy()
        image_name = target['image_name']
        result = {}
        result['image'] = image.to("cpu").detach().numpy()
        result['bb'] = boxes
        result['labels'] = labels
        result['scores'] = scores
        result['masks'] = masks
        result['image_name'] = image_name
        result['true_masks'] = target['masks'].to("cpu").detach().numpy()
        file_name = image_name.split('/')[-1].split('.')[0] +'.pickle'
        f = os.path.join(save_dir, file_name)
        with open(f, 'wb') as handle:
            pickle.dump(result, handle, protocol=pickle.HIGHEST_PROTOCOL)
        print(f)
        
def display_inference(save_dir):
    assert os.path.exists(save_dir)
    files = [os.path.join(save_dir, f) for f in os.listdir(save_dir) if f.endswith('.pickle')]
    for f in files:
        with open(f, 'rb') as handle:
            result = pickle.load(handle)
        image = result['image']
        image_name = result['image_name']
        masks = result['masks']
        if masks.shape[0] == 0:
            continue
        masks = np.squeeze(masks, axis=1).max(axis=0)
        x,y = masks.shape
        color_mask = np.zeros((3, x, y))
        color_mask[0,:,:] = masks
        color_mask[2,:,:] = masks
        overlay = image*0.7 + color_mask*0.3
        overlay = np.moveaxis(overlay, 0, -1)
        im = Image.fromarray(np.uint8(overlay*255))
        img_name = f.split('.pickle')[0]+'.png'
        im.save(img_name)
        print(img_name)
        masks = result['true_masks']
        if masks.shape[0] == 0:
            continue
        masks = masks.max(axis=0)
        x,y = masks.shape
        color_mask = np.zeros((3, x, y))
        color_mask[0,:,:] = masks
        color_mask[2,:,:] = masks
        overlay = image*0.7 + color_mask*0.3
        overlay = np.moveaxis(overlay, 0, -1)
        im = Image.fromarray(np.uint8(overlay*255))
        img_name = f.split('.pickle')[0]+'_true.png'
        im.save(img_name)
        print(img_name)

In [27]:
save_inference(dataset_test, mask_rcnn, device, 'data/rocklas/prediction_2d')

data/rocklas/prediction_2d/12_7.pickle
data/rocklas/prediction_2d/11_8.pickle
data/rocklas/prediction_2d/7_10.pickle
data/rocklas/prediction_2d/10_6.pickle
data/rocklas/prediction_2d/5_7.pickle
data/rocklas/prediction_2d/9_12.pickle
data/rocklas/prediction_2d/12_9.pickle
data/rocklas/prediction_2d/10_11.pickle
data/rocklas/prediction_2d/9_14.pickle
data/rocklas/prediction_2d/12_13.pickle
data/rocklas/prediction_2d/5_10.pickle
data/rocklas/prediction_2d/10_8.pickle
data/rocklas/prediction_2d/9_10.pickle
data/rocklas/prediction_2d/3_13.pickle
data/rocklas/prediction_2d/7_9.pickle
data/rocklas/prediction_2d/8_7.pickle
data/rocklas/prediction_2d/8_11.pickle
data/rocklas/prediction_2d/12_11.pickle


In [28]:
display_inference('data/rocklas/prediction_2d')

data/rocklas/prediction_2d/12_7.png
data/rocklas/prediction_2d/12_7_true.png
data/rocklas/prediction_2d/9_12.png
data/rocklas/prediction_2d/9_12_true.png
data/rocklas/prediction_2d/8_11.png
data/rocklas/prediction_2d/8_11_true.png
data/rocklas/prediction_2d/10_8.png
data/rocklas/prediction_2d/10_8_true.png
data/rocklas/prediction_2d/3_13.png
data/rocklas/prediction_2d/3_13_true.png
data/rocklas/prediction_2d/12_13.png
data/rocklas/prediction_2d/12_13_true.png
data/rocklas/prediction_2d/9_14.png
data/rocklas/prediction_2d/9_14_true.png
data/rocklas/prediction_2d/7_9.png
data/rocklas/prediction_2d/7_9_true.png
data/rocklas/prediction_2d/9_10.png
data/rocklas/prediction_2d/9_10_true.png
data/rocklas/prediction_2d/12_9.png
data/rocklas/prediction_2d/12_9_true.png
data/rocklas/prediction_2d/8_7.png
data/rocklas/prediction_2d/8_7_true.png
data/rocklas/prediction_2d/7_10.png
data/rocklas/prediction_2d/7_10_true.png
data/rocklas/prediction_2d/10_11.png
data/rocklas/prediction_2d/10_11_true.png