In [1]:
import logging
import os
from collections import OrderedDict
import torch
from torch.nn.parallel import DistributedDataParallel

import detectron2.utils.comm as comm
from detectron2.data import MetadataCatalog, build_detection_train_loader
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
from detectron2.utils.events import EventStorage
from detectron2.evaluation import (
    COCOEvaluator,
    COCOPanopticEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    PascalVOCDetectionEvaluator,
    SemSegEvaluator,
    verify_results,
)
from detectron2.modeling import GeneralizedRCNNWithTTA
from detectron2.modeling import build_model
from detectron2.utils.logger import setup_logger

from adet.data.dataset_mapper import DatasetMapperWithBasis
from adet.config import get_cfg
from adet.checkpoint import AdetCheckpointer
from adet.evaluation import TextEvaluator

In [2]:
# Merge order: AdelaiDet defaults -> YAML config file -> key/value override list.
# NOTE(review): the weights file is named 3x while the YAML is the 1x config —
# confirm this pairing is intended.
config_file = '../configs/BoxInst/MS_R_50_1x.yaml'
opts = ['MODEL.WEIGHTS', 'BoxInst_MS_R_50_3x.pth']

cfg = get_cfg()
cfg.merge_from_file(config_file)  # YAML values override the defaults
cfg.merge_from_list(opts)         # overrides win over the YAML

Config '../configs/BoxInst/MS_R_50_1x.yaml' has no VERSION. Assuming it to be compatible with latest v2.


In [3]:
# Inspect the CondInst section of the merged config (mask head / mask branch settings).
cfg.MODEL.CONDINST

CfgNode({'MASK_OUT_STRIDE': 4, 'BOTTOM_PIXELS_REMOVED': -1, 'MAX_PROPOSALS': -1, 'TOPK_PROPOSALS_PER_IM': 64, 'MASK_HEAD': CfgNode({'CHANNELS': 8, 'NUM_LAYERS': 3, 'USE_FP16': False, 'DISABLE_REL_COORDS': False}), 'MASK_BRANCH': CfgNode({'OUT_CHANNELS': 16, 'IN_FEATURES': ['p3', 'p4', 'p5'], 'CHANNELS': 128, 'NORM': 'BN', 'NUM_CONVS': 4, 'SEMANTIC_LOSS_ON': False})})

In [7]:
class Trainer(DefaultTrainer):
    """DefaultTrainer with an explicit training loop.

    ``train`` drives iterations through ``train_loop`` and, on the main process,
    verifies and returns the last evaluation results when they exist.
    """

    def train_loop(self, start_iter: int, max_iter: int):
        """
        Run hooks and training steps for iterations ``[start_iter, max_iter)``.

        Args:
            start_iter (int): first iteration to run; also used to start the EventStorage.
            max_iter (int): exclusive upper bound of the iteration range.

        Raises:
            Any exception raised by a hook or by ``run_step`` is logged and re-raised
            after ``after_train`` has run.
        """
        logger = logging.getLogger("adet.trainer")
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step()
                    self.after_step()
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                # Run after_train even when a step fails, so hooks always get their
                # end-of-training callback (mirrors detectron2's TrainerBase.train).
                self.after_train()

    def train(self):
        """
        Run training.

        Returns:
            OrderedDict of results, if evaluation is enabled. Otherwise None.
        """
        self.train_loop(self.start_iter, self.max_iter)
        # `_last_eval_results` only exists if an evaluation ran during training.
        if hasattr(self, "_last_eval_results") and comm.is_main_process():
            verify_results(self.cfg, self._last_eval_results)
            return self._last_eval_results

In [8]:
# Build the model via the build_model classmethod Trainer inherits from DefaultTrainer.
model = Trainer.build_model(cfg)

In [9]:
model

CondInst(
  (backbone): FPN(
    (fpn_lateral3): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral4): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (fpn_lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
    (fpn_output5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (top_block): LastLevelP6P7(
      (p6): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (p7): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (bottom_up): ResNet(
      (stem): BasicStem(
        (conv1): Conv2d(
          3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
          (norm): FrozenBatchNorm2d(num_features=64, eps=1e-05)
        )
      )
      (res2): Sequential(
        (0): BottleneckBlock(
  

In [11]:
# Pixel mean from the merged config; the recorded values look like detectron2's
# BGR-ordered default — confirm against cfg.INPUT.FORMAT.
cfg.MODEL.PIXEL_MEAN

[103.53, 116.28, 123.675]

In [3]:
# NOTE(review): execution counts restart at In [3] here, so the kernel was restarted
# and the `model` / `Trainer` objects from the cells above no longer exist.
# Dataset mapper in training mode (second argument True).
mapper = DatasetMapperWithBasis(cfg, True)

In [4]:
train_loader = build_detection_train_loader(cfg, mapper=mapper)

In [5]:
data_loader_iter = iter(train_loader)

In [6]:
# One batch: a list of per-image dicts. Re-running this cell fetches a different batch.
data = next(data_loader_iter)

In [7]:
data[0]['instances']

Instances(num_instances=1, image_height=640, image_width=853, fields=[gt_boxes: Boxes(tensor([[262.7773, 220.2133, 696.1546, 534.0267]])), gt_classes: tensor([18])])

In [26]:
# Rebuild the model with detectron2's build_model; the cells below read
# model.backbone.size_divisibility and model.mask_out_stride from it.
model = build_model(cfg)

In [56]:
# NOTE(review): these imports belong in the top-of-notebook import cell.
import copy
from detectron2.structures import ImageList
# Device named in the config (recorded output: cuda).
device = torch.device(cfg.MODEL.DEVICE)
device

device(type='cuda')

In [57]:
batched_inputs = copy.deepcopy(data)

In [58]:
original_images = [x["image"].to(device) for x in batched_inputs]
gt_instances = [x["instances"].to(device) for x in batched_inputs]

In [59]:
original_image_masks = [torch.ones_like(x[0], dtype=torch.float32) for x in original_images]

In [60]:
original_images = ImageList.from_tensors(original_images, model.backbone.size_divisibility)
original_image_masks = ImageList.from_tensors(
    original_image_masks, model.backbone.size_divisibility, pad_value=0.0
)

In [62]:
stride = model.mask_out_stride
start = int(stride // 2)

In [63]:
images = original_images.tensor

In [65]:
images.size()

torch.Size([16, 3, 800, 1344])

In [66]:
# Mask targets are sampled every `stride` pixels, so both padded spatial sides
# must be exact multiples of the stride.
assert all(side % stride == 0 for side in (images.size(2), images.size(3)))

In [68]:
# -*- coding: utf-8 -*-
import logging
from skimage import color

import torch
from torch import nn
import torch.nn.functional as F

from detectron2.structures import ImageList
from detectron2.modeling.proposal_generator import build_proposal_generator
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.structures.instances import Instances
from detectron2.structures.masks import PolygonMasks, polygons_to_bitmask

# from .dynamic_mask_head import build_dynamic_mask_head
# from .mask_branch import build_mask_branch

from adet.utils.comm import aligned_bilinear

In [69]:
downsampled_images = F.avg_pool2d(images.float(), kernel_size=stride, stride=stride, padding=0)[:, [2,1,0]]

In [71]:
image_masks = original_image_masks.tensor

In [73]:
image_masks = image_masks[:, start::stride, start::stride]

In [75]:
images_lab = color.rgb2lab(downsampled_images[0].byte().permute(1, 2, 0).cpu().numpy())

In [78]:
images_lab = torch.as_tensor(images_lab, device=device, dtype=torch.float32)

In [79]:
images_lab = images_lab.permute(2, 0, 1)[None]

In [85]:
pairwise_size = cfg.MODEL.BOXINST.PAIRWISE.SIZE
pairwise_dilation = cfg.MODEL.BOXINST.PAIRWISE.DILATION

In [87]:
images_color_similarity = get_images_color_similarity(
    images_lab, image_masks[0],
    pairwise_size, pairwise_dilation
)

In [92]:
unfolded_images = unfold_wo_center(
        images_lab, kernel_size=pairwise_size, dilation=pairwise_dilation
    )

In [97]:
diff = images_lab[:,:,None] - unfolded_images

In [105]:
similarity = torch.exp(-torch.norm(diff, dim=1) * 0.5)

In [102]:
unfolded_weights = unfold_wo_center(
    image_masks[0][None, None], kernel_size=pairwise_size,
    dilation=pairwise_dilation
)

In [106]:
images_color_similarity = similarity * unfolded_weights

In [107]:
images_color_similarity.size()

torch.Size([1, 1, 8, 200, 336])

In [109]:
# GT boxes of the first image as an (N, 4) tensor. The bitmask cell below indexes
# them as (x0, y0, x1, y1).
# NOTE(review): the recorded output shows 10 boxes while the In [7] output above
# reports num_instances=1, so `data` was presumably re-fetched between runs.
per_im_boxes = gt_instances[0].gt_boxes.tensor

In [113]:
per_im_boxes

tensor([[  2.3655, 370.6394,  39.4778, 551.3984],
        [285.3402, 408.8090, 347.9527, 481.9584],
        [  0.0000, 567.1322,  72.3610, 589.1379],
        [658.5600, 367.0016, 720.8141, 404.7411],
        [247.8515, 367.2525, 330.8390, 425.6358],
        [588.4032, 358.8659, 652.8256, 394.3654],
        [  0.0000, 236.2752, 591.2883, 671.8029],
        [344.9062, 313.8330, 460.6515, 382.2874],
        [804.1600, 321.8074, 896.0000, 396.8026],
        [396.4262, 325.6781, 427.7863, 335.7491]], device='cuda:0')

In [116]:
im_h, im_w = images.size(2), images.size(3)

In [111]:
per_im_bitmasks = []
per_im_bitmasks_full = []

In [117]:
for per_box in per_im_boxes:
    bitmask_full = torch.zeros((img_h, img_w)).to(device).float()
    bitmask_full[int(per_box[1]):int(per_box[3] + 1), int(per_box[0]):int(per_box[2] + 1)] = 1.0
    bitmask = bitmask_full[start::stride, start::stride]
    assert bitmask.size(0) * stride == im_h
    assert bitmask.size(1) * stride == im_w
    per_im_bitmasks.append(bitmask)
    per_im_bitmasks_full.append(bitmask_full)

In [120]:
gt_instances[0].gt_bitmasks = torch.stack(per_im_bitmasks, dim=0)

In [121]:
gt_instances[0].gt_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0)

In [126]:
gt_instances[0].image_color_similarity = torch.cat([
                images_color_similarity for _ in range(len(gt_instances[0]))
            ], dim=0)

In [134]:
cfg.MODEL.BASIS_MODULE.LOSS_ON

False

In [81]:
def unfold_wo_center(x, kernel_size, dilation):
    """Gather each pixel's dilated k x k neighbourhood, excluding the pixel itself.

    Args:
        x: 4-D tensor (asserted), indexed as (N, C, H, W).
        kernel_size: neighbourhood side length; must be odd (asserted).
        dilation: spacing between sampled neighbours.

    Returns:
        Tensor of shape (N, C, kernel_size**2 - 1, H, W): for every pixel, its
        neighbours in raster order with the centre tap removed. Neighbours outside
        the image come from F.unfold's implicit zero padding.
    """
    assert x.dim() == 4
    assert kernel_size % 2 == 1

    batch, channels, height, width = x.shape
    # SAME padding: half of the dilated kernel's effective extent, so every input
    # pixel gets exactly one patch and the spatial size is preserved.
    effective_extent = kernel_size + (dilation - 1) * (kernel_size - 1)
    patches = F.unfold(
        x, kernel_size=kernel_size, padding=effective_extent // 2, dilation=dilation
    )
    patches = patches.reshape(batch, channels, -1, height, width)

    # Keep every tap except the centre one (the pixel compared with itself).
    center_tap = kernel_size ** 2 // 2
    kept_taps = [tap for tap in range(kernel_size ** 2) if tap != center_tap]
    return patches[:, :, kept_taps]


def get_images_color_similarity(images, image_masks, kernel_size, dilation):
    """Colour affinity between each pixel and its non-centre neighbours.

    Args:
        images: 4-D tensor with batch size 1 (both asserted), indexed as (1, C, H, W);
            in this notebook it is the CIELAB image.
        image_masks: (H, W) validity mask; a neighbour's affinity is multiplied by the
            mask value at that neighbour (0 outside the padded image).
        kernel_size, dilation: neighbourhood definition, as in unfold_wo_center.

    Returns:
        (1, kernel_size**2 - 1, H, W) tensor of exp(-0.5 * ||I_p - I_q||_2) * mask_q,
        where the norm is taken over the channel axis.
    """
    assert images.dim() == 4
    assert images.size(0) == 1

    neighbours = unfold_wo_center(images, kernel_size=kernel_size, dilation=dilation)
    # Euclidean colour distance to each neighbour, over the channel axis.
    distance = (images.unsqueeze(2) - neighbours).norm(dim=1)
    similarity = (-distance * 0.5).exp()

    neighbour_valid = unfold_wo_center(
        image_masks[None, None], kernel_size=kernel_size, dilation=dilation
    )
    # Collapse the singleton mask-channel axis so the result stays 4-D.
    neighbour_valid = neighbour_valid.max(dim=1).values

    return similarity * neighbour_valid

