In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result has\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.
    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
    call `register(id)` and obtain an `SlavePipe` to communicate with the master.
    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
    and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine to message passed
    back to each slave devices.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def __getstate__(self):
        return {'master_callback': self._master_callback}

    def __setstate__(self, state):
        self.__init__(state['master_callback'])

    def register_slave(self, identifier):
        """
        Register an slave device.
        Args:
            identifier: an identifier, usually is the device id.
        Returns: a `SlavePipe` object which can be used to communicate with the master device.
        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages were first collected from each devices (including the master device), and then
        an callback will be invoked to compute the message to be sent back to each devices
        (including the master device).
        Args:
            master_msg: the message that the master want to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
        Returns: the message to be sent back to the master device.
        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belongs to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)

In [3]:
import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """sum over the first and last dimention"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dementions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.
    .. math::
        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
    This module differs from the built-in PyTorch BatchNorm1d as the mean and
    standard-deviation are reduced across all devices during training.
    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalize the tensor on each device using
    the statistics only on that device, which accelerated the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.
    Note that, for one-GPU or CPU-only case, this module behaves exactly same
    as the built-in PyTorch implementation.
    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).
    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.
    During evaluation, this running mean/variance is used for normalization.
    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``
    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs
    .. math::
        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
    This module differs from the built-in PyTorch BatchNorm2d as the mean and
    standard-deviation are reduced across all devices during training.
    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalize the tensor on each device using
    the statistics only on that device, which accelerated the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.
    Note that, for one-GPU or CPU-only case, this module behaves exactly same
    as the built-in PyTorch implementation.
    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).
    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.
    During evaluation, this running mean/variance is used for normalization.
    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``
    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)
    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    """Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs
    .. math::
        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
    This module differs from the built-in PyTorch BatchNorm3d as the mean and
    standard-deviation are reduced across all devices during training.
    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalize the tensor on each device using
    the statistics only on that device, which accelerated the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.
    Note that, for one-GPU or CPU-only case, this module behaves exactly same
    as the built-in PyTorch implementation.
    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).
    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.
    During evaluation, this running mean/variance is used for normalization.
    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm
    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``
    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)
    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)

In [4]:
from collections import OrderedDict
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo

#from models.sync_batchnorm import SynchronizedBatchNorm2d

def load_weights_sequential(target, source_state):
    
    new_dict = OrderedDict()

    for k1, v1 in target.state_dict().items():
        if not 'num_batches_tracked' in k1:
            tar_v = source_state[k1]

            if v1.shape != tar_v.shape:
                c, _, w, h = v1.shape
                tar_v = torch.cat([
                    tar_v, 
                    torch.zeros((c,v1.shape[1]-tar_v.shape[1],w,h)),
                ], 1)

            new_dict[k1] = tar_v
    target.load_state_dict(new_dict)

model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}


def conv3x3(in_planes, out_planes, stride=1, dilation=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, dilation=dilation, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
        self.bn1 = SynchronizedBatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
        self.bn2 = SynchronizedBatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = SynchronizedBatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
                               padding=dilation, bias=False)
        self.bn2 = SynchronizedBatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = SynchronizedBatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    def __init__(self, block, layers=(3, 4, 23, 3)):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        #self.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3,
         #                      bias=False)
        self.bn1 = SynchronizedBatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                SynchronizedBatchNorm2d(planes * block.expansion),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dilation=dilation))

        return nn.Sequential(*layers)

    def forward(self, x): # [6, 4, 224, 224]
        x = self.conv1(x)  # /2 [6, 64, 112, 112]
        x = self.bn1(x)
        x_1 = self.relu(x)
        x = self.maxpool(x_1)  # /2 [6, 64, 56, 56]

        x_2 = self.layer1(x) # [[6, 256, 56, 56]]
        x = self.layer2(x_2)   # /2 [1, 512, 28, 28]
        x_3 = self.layer3(x) # [1, 1024, 28, 28]
        x = self.layer4(x_3) # [1, 2048, 28, 28]

        return x_1, x_2, x_3


def resnet50(pretrained=True):
    model = ResNet(Bottleneck, [3, 4, 6, 3]) # [3, 4, 6, 3]
    if pretrained:
        load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']))

    #model = base_model
    

    return model

In [5]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
#from models.sync_batchnorm import SynchronizedBatchNorm2d

class _ASPPModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super(_ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU(inplace=True)

        self._init_weight()

    def forward(self, x):
        x = self.atrous_conv(x)
        x = self.bn(x)

        return self.relu(x)

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

class ASPP_no4level(nn.Module):
    def __init__(self, backbone, output_stride, BatchNorm):
        super(ASPP_no4level, self).__init__()
        if backbone == 'drn':
            inplanes = 512
        elif backbone == 'mobilenet':
            inplanes = 320
        else:
            inplanes = 2048
            low_level_inplanes = 256 #
        if output_stride == 16:
            dilations = [1, 6, 12, 18]
        elif output_stride == 8:
            dilations = [1, 12, 24, 36]
        else:
            raise NotImplementedError

        self.aspp1_128 = _ASPPModule(64, 64, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp1_256 = _ASPPModule(256, 64, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp1_1024 = _ASPPModule(1024, 128, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)

        self.bn1_128 = BatchNorm(64)
        self.bn1_256 = BatchNorm(64)
        self.bn1_1024 = BatchNorm(128)
        # self.bn1_2048 = BatchNorm(256)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(0.5)

        self.last_conv = nn.Sequential(nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                BatchNorm(256),
                                nn.ReLU(inplace=True),
                                nn.Dropout(0.5))

        self._init_weight()
        print("ASPP_4level")

    def forward(self, x_1, x_2, x_3):
        x_1 = self.aspp1_128(x_1)
        x_1 = self.bn1_128(x_1)
        x_1 = self.relu(x_1)
        x_1 = self.dropout(x_1)

        x_2 = self.aspp1_256(x_2)
        x_2 = self.bn1_256(x_2)
        x_2 = self.relu(x_2)
        x_2 = self.dropout(x_2)

        x_3 = self.aspp1_1024(x_3)
        x_3 = self.bn1_1024(x_3)
        x_3 = self.relu(x_3)
        x_3 = self.dropout(x_3)

        x_2 = F.interpolate(x_2, size=x_1.size()[2:], mode='bilinear', align_corners=True)
        x_3 = F.interpolate(x_3, size=x_1.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x_1, x_2, x_3), dim=1)
        x = self.last_conv(x)

        return x

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, SynchronizedBatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

In [6]:
import torch
from torch import nn
from torch.nn import functional as F

#from models.network import extractors
#from models.sync_batchnorm import SynchronizedBatchNorm2d
#from models.network.aspp import ASPP_no4level

def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret


class MLP(nn.Module):

    def __init__(self, in_dim, out_dim, hidden_list):
        super().__init__()
        layers = []
        lastv = in_dim
        for hidden in hidden_list:
            layers.append(nn.Linear(lastv, hidden))
            layers.append(nn.ReLU())
            lastv = hidden
        layers.append(nn.Linear(lastv, out_dim))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        shape = x.shape[:-1]
        x = self.layers(x.view(-1, x.shape[-1]))
        return x.view(*shape, -1)


class CRMNet(nn.Module):
    def __init__(self,  backend='resnet34', pretrained=True):
        super().__init__()
        self.feats = resnet50(pretrained)
        self.aspp_ = ASPP_no4level(backbone=backend, output_stride=8, BatchNorm=SynchronizedBatchNorm2d)
        self.imnet = MLP(in_dim=256+6, out_dim=1, hidden_list=[32, 32, 32, 32])

    def forward(self, x, seg, coord, cell, inter_s8=None, inter_s4=None):
 
        # extract feature
        p = torch.cat((x, seg), 1)
        # x, low_level_feat
        
        x1_feat, x2_feat, x3_feat = self.feats(p) # [6, 64, 112, 112] [6, 256, 56, 56] [6, 1024, 28, 28]
        feat = self.aspp_(x1_feat, x2_feat, x3_feat)
        
        vx_lst = [-1, 1]
        vy_lst = [-1, 1]
        eps_shift = 1e-6

        rx = 2 / feat.shape[-2] / 2 
        ry = 2 / feat.shape[-1] / 2 

        feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda().permute(2, 0, 1).unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:]) # 

        preds = []
        areas = []
        for vx in vx_lst:
            for vy in vy_lst:
                coord_ = coord.clone()
                coord_[:, :, 0] += vx * rx + eps_shift
                coord_[:, :, 1] += vy * ry + eps_shift
                coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
                q_feat = F.grid_sample(
                    feat, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                q_coord = F.grid_sample(
                    feat_coord, coord_.flip(-1).unsqueeze(1),
                    mode='nearest', align_corners=False)[:, :, 0, :] \
                    .permute(0, 2, 1)
                rel_coord = coord - q_coord
                rel_coord[:, :, 0] *= feat.shape[-2]
                rel_coord[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([q_feat, rel_coord, coord], dim=-1)

                # if self.cell_decode:
                rel_cell = cell.clone()
                rel_cell[:, :, 0] *= feat.shape[-2]
                rel_cell[:, :, 1] *= feat.shape[-1]
                inp = torch.cat([inp, rel_cell], dim=-1)

                bs, q = coord.shape[:2]
                pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
                preds.append(pred)

                area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
                areas.append(area + 1e-9)

        tot_area = torch.stack(areas).sum(dim=0)
        # if self.local_ensemble:
        t = areas[0]; areas[0] = areas[3]; areas[3] = t
        t = areas[1]; areas[1] = areas[2]; areas[2] = t
        ret = 0

        for pred, area in zip(preds, areas):
            ret = ret + pred * (area / tot_area).unsqueeze(-1)
        
        pred_224 = torch.sigmoid(ret) # [6, 1, 224, 224]

        images = {}
        images['out_224'] = ret
        images['pred_224'] = pred_224
        
        return images    

In [7]:
import torch
from torch import nn
from torch.nn import functional as F

import numpy as np

class SobelOperator(nn.Module):
    def __init__(self, epsilon):
        super().__init__()
        self.epsilon = epsilon

        x_kernel = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]])/4
        self.conv_x = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_x.weight.data = torch.tensor(x_kernel).unsqueeze(0).unsqueeze(0).float().cuda()
        self.conv_x.weight.requires_grad = False

        y_kernel = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]])/4
        self.conv_y = nn.Conv2d(1, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv_y.weight.data = torch.tensor(y_kernel).unsqueeze(0).unsqueeze(0).float().cuda()
        self.conv_y.weight.requires_grad = False

    def forward(self, x):

        b, c, h, w = x.shape
        if c > 1:
            x = x.view(b*c, 1, h, w)

        x = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)

        grad_x = self.conv_x(x)
        grad_y = self.conv_y(x)
        
        x = torch.sqrt(grad_x ** 2 + grad_y ** 2 + self.epsilon)

        x = x.view(b, c, h, w)

        return x

class SobelComputer:
    def __init__(self):
        self.sobel = SobelOperator(1e-4)

    def compute_edges(self, images):
        gt = images['gt'].float()
        pred = images['pred_224'].float()
        images['gt_sobel'] = self.sobel(gt)
        images['pred_sobel'] = self.sobel(pred)

In [8]:
import torchvision.transforms as transforms

import os

from torch.utils.tensorboard import SummaryWriter
# import git
import warnings

def tensor_to_numpy(image):
    image_np = (image.numpy() * 255).astype('uint8')
    return image_np

def detach_to_cpu(x):
    return x.detach().cpu()

def fix_width_trunc(x):
    return ('{:.9s}'.format('{:0.9f}'.format(x)))

class BoardLogger:
    def __init__(self, id):

        if id is None:
            self.no_log = True
            warnings.warn('Logging has been disbaled.')
        else:
            self.no_log = False

            self.inv_im_trans = transforms.Normalize(
                mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
                std=[1/0.229, 1/0.224, 1/0.225])

            self.inv_seg_trans = transforms.Normalize(
                mean=[-0.5/0.5],
                std=[1/0.5])

            log_path = os.path.join('.', 'log', '%s' % id)
            self.logger = SummaryWriter(log_path)

        # repo = git.Repo(".")
        # self.log_string('git', str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha))

    def log_scalar(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        self.logger.add_scalar(tag, x, step)

    def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
        tag = l1_tag + '/' + l2_tag
        text = 'It {:8d} [{:5s}] [{:19s}]: {:s}'.format(step, l1_tag.upper(), l2_tag, fix_width_trunc(val))
        print(text)
        if f is not None:
            f.write(text + '\n')
            f.flush()
        self.log_scalar(tag, val, step)

    def log_im(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = self.inv_im_trans(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_cv2(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = x.transpose((2, 0, 1))
        self.logger.add_image(tag, x, step)

    def log_seg(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = self.inv_seg_trans(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_gray(self, tag, x, step):
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        x = detach_to_cpu(x)
        x = tensor_to_numpy(x)
        self.logger.add_image(tag, x, step)

    def log_string(self, tag, x):
        print(tag, x)
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        self.logger.add_text(tag, x)

    def log_total(self, tag, im, gt, seg, pred, step):
        
        if self.no_log:
            warnings.warn('Logging has been disabled.')
            return
        
        row_cnt = min(10, im.shape[0])
        w = im.shape[2]
        h = im.shape[3]
        
        output_image = np.zeros([3, w*row_cnt, h*5], dtype=np.uint8)
        
        for i in range(row_cnt):
            im_ = tensor_to_numpy(self.inv_im_trans(detach_to_cpu(im[i])))
            gt_ = tensor_to_numpy(detach_to_cpu(gt[i]))
            seg_ = tensor_to_numpy(self.inv_seg_trans(detach_to_cpu(seg[i])))
            pred_ = tensor_to_numpy(detach_to_cpu(pred[i]))
            
            output_image[:, i * w : (i+1) * w, 0 : h] = im_
            output_image[:, i * w : (i+1) * w, h : 2*h] = gt_
            output_image[:, i * w : (i+1) * w, 2*h : 3*h] = seg_
            output_image[:, i * w : (i+1) * w, 3*h : 4*h] = pred_
            output_image[:, i * w : (i+1) * w, 4*h : 5*h] = im_*0.5 + 0.5 * (im_ * (1-(pred_/255)) + (pred_/255) * (np.array([255,0,0],dtype=np.uint8).reshape([1,3,1,1])))
            
        self.logger.add_image(tag, output_image, step)

In [9]:
import os
import torch

class ModelSaver:
    def __init__(self, id):

        if id is None:
            self.no_log = True
            print('Saving has been disbaled.')
        else:
            self.no_log = False

            self.save_path = os.path.join('.', 'weights', '%s' % id )

    def save_model(self, model, step):
        if self.no_log:
            print('Saving has been disabled.')
            return

        os.makedirs(self.save_path, exist_ok=True)

        model_path = os.path.join(self.save_path, 'model_%s' % step)
        torch.save(model.state_dict(), model_path)
        print('Model saved to %s.' % model_path)

In [10]:
from argparse import ArgumentParser

class HyperParameters():
    def parse(self, unknown_arg_ok=False):
        parser = ArgumentParser()
        parser = {}
        parser['iterations'] = 8000
        parser['batch_size'] = 12
        parser['lr'] = 2.25e-4
        parser['steps'] = [22500, 37500]
        parser['gamma'] = 0.1
        parser['weight_decay'] = 1e-4
        #parser['load'] = # path to pretrained model
        parser['ce_weight'] = 0.0
        parser['l1_weight'] = 1.0
        parser['l2_weight'] = 1.0
        parser['grad_weight'] = 5.0
        parser['id'] = 'first'

        return parser


        # Generic learning parameters
        #parser.add_argument('-i', '--iterations', help='Number of training iterations', default=4.5e4, type=int)
        #print("added first argument")
        #parser.add_argument('-b', '--batch_size', help='Batch size', default=12, type=int)
        #parser.add_argument('--lr', help='Initial learning rate', default=2.25e-4, type=float)
        #parser.add_argument('--steps', help='Iteration at which learning rate is decayed by gamma', default=[22500, 37500], type=int, nargs='*')
        #parser.add_argument('--gamma', help='Gamma used in learning rate decay', default=0.1, type=float)
        #parser.add_argument('--weight_decay', help='Weight decay', default=1e-4, type=float)

        # same decay applied to discriminator
        #parser.add_argument('--load', help='Path to pretrained model if available')

        #parser.add_argument('--ce_weight', help='Weight of the CE loss', default=0.0, type=float)
        #parser.add_argument('--l1_weight', help='Weight of the L1 loss', default=1.0, type=float)
        #parser.add_argument('--l2_weight', help='Weight of the L2 loss', default=1.0, type=float)
        #parser.add_argument('--grad_weight', help='Weight of the gradient loss', default=5.0, type=float)

        # Logging information, this one is positional and mandatory
        #parser.add_argument('id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard')

        #print("added all arguments")

        #if unknown_arg_ok:
         #   args, _ = parser.parse_known_args()
          #  self.args = vars(args)
        #else:
        #    print("unknown arg")
        #    pa = parser.parse_args(args=[])
         #   print("passed first")
         #   self.args = vars(pa)
         #   print("passed unknown args")

    #def __getitem__(self, key):
     #   return self.args[key]

    #def __str__(self):
     #   return str(self.args)

In [11]:
"""
Integrate numerical values for some iterations
Typically used for loss computation
Just call finalize and create a new Integrator when you want to display 
"""
class Integrator:
    def __init__(self, logger):
        self.values = {}
        self.counts = {}
        self.hooks  = [] # List is used here to maintain insertion order

        self.logger = logger

    def add_tensor(self, key, tensor):
        if key not in self.values:
            self.counts[key] = 1
            if type(tensor) == float or type(tensor) == int:
                self.values[key] = tensor
            else:
                self.values[key] = tensor.mean().item()
        else:
            self.counts[key] += 1
            if type(tensor) == float or type(tensor) == int:
                self.values[key] += tensor
            else:
                self.values[key] += tensor.mean().item()

    def add_dict(self, tensor_dict):
        for k, v in tensor_dict.items():
            self.add_tensor(k, v)

    def add_hook(self, hook):
        """
        Adds a custom hook, i.e. compute new metrics using values in the dict
        The hook takes the dict as argument, and returns a (k, v) tuple
        """
        if type(hook) == list:
            self.hooks.extend(hook)
        else:
            self.hooks.append(hook)

    def reset_except_hooks(self):
        self.values = {}
        self.counts = {}

    # Average and output the metrics
    def finalize(self, prefix, iter, f=None):

        for hook in self.hooks:
            k, v = hook(self.values)
            self.add_tensor(k, v)

        for k, v in self.values.items():
            avg = v / self.counts[k]

            self.logger.log_metrics(prefix, k, avg, iter, f)

In [12]:
from torch.nn import functional as F

def compute_tensor_iu(seg, gt):

    #seg = seg.squeeze(1)
    #gt = gt.squeeze(1)
    
    intersection = (seg & gt).float().sum()
    union = (seg | gt).float().sum()

    return intersection, union

def compute_tensor_iou(seg, gt):
    #seg = seg.squeeze(1)
    #gt = gt.squeeze(1)
    
    intersection = (seg & gt).float().sum((1, 2))
    union = (seg | gt).float().sum((1, 2))
    
    iou = (intersection + 1e-6) / (union + 1e-6)
    
    return iou 

def resize_min_side(im, size, method):
    h, w = im.shape[-2:]
    
    min_side = min(h, w)
    ratio = size / min_side
    if method == 'bilinear':
        return F.interpolate(im, scale_factor=ratio, mode=method, align_corners=False)
    else:
        return F.interpolate(im, scale_factor=ratio, mode=method)

def resize_max_side(im, size, method):
    h, w = im.shape[-2:]
    max_side = max(h, w)
    ratio = size / max_side
    if method in ['bilinear', 'bicubic']:
        return F.interpolate(im, scale_factor=ratio, mode=method, align_corners=False)
    else:
        return F.interpolate(im, scale_factor=ratio, mode=method)

In [13]:
import torch.nn.functional as F

#from util.util import compute_tensor_iu

def get_new_iou_hook(values, size):
    return 'iou/new_iou_%s'%size, values['iou/new_i_%s'%size]/values['iou/new_u_%s'%size]

def get_orig_iou_hook(values):
    return 'iou/orig_iou', values['iou/orig_i']/values['iou/orig_u']

def get_iou_gain(values, size):
    return 'iou/iou_gain_%s'%size, values['iou/new_iou_%s'%size] - values['iou/orig_iou']

iou_hooks_to_be_used = [
        get_orig_iou_hook,
        lambda x: get_new_iou_hook(x, '224'), lambda x: get_iou_gain(x, '224'),
    ]

iou_hooks_final_only = [
    get_orig_iou_hook,
    lambda x: get_new_iou_hook(x, '224'), lambda x: get_iou_gain(x, '224'),
]

# Compute common loss and metric for generator only
def compute_loss_and_metrics(images, para, detailed=True, need_loss=True, has_lower_res=True):

    """
    This part compute loss and metrics for the generator
    """

    loss_and_metrics = {}

    gt = images['gt']
    seg = images['seg']


    #seg = seg.argmax(1)

    pred_224 = images['pred_224']

    if need_loss:
        # Loss weights
        ce_weights = para['ce_weight']
        l1_weights = para['l1_weight']
        l2_weights = para['l2_weight']

        # temp holder for losses at different scale
        ce_loss = 0 
        l1_loss = 0 
        l2_loss = 0 
        loss = 0 

        ce_loss = F.binary_cross_entropy_with_logits(images['out_224'], (gt>0.5).float())
        l1_loss = F.l1_loss(pred_224, gt)
        l2_loss = F.mse_loss(pred_224, gt)

        loss_and_metrics['grad_loss'] = F.l1_loss(images['gt_sobel'], images['pred_sobel'])

        # Weighted loss for different levels
        loss = ce_loss * ce_weights + l1_loss * l1_weights + l2_loss * l2_weights
        
        loss += loss_and_metrics['grad_loss'] * para['grad_weight']
    
    #metric = SegmentationMetric(20)


    #metric.update(seg, gt)

    #pixAcc, mIoU = metric.get()


    

    """
    Compute IOU stats
    """
    #print(torch.nonzero(seg))
    orig_total_i, orig_total_u = compute_tensor_iu(seg>0.5, gt>0.5)
    loss_and_metrics['iou/orig_i'] = orig_total_i
    loss_and_metrics['iou/orig_u'] = orig_total_u

    new_total_i, new_total_u = compute_tensor_iu(pred_224>0.5, gt>0.5)
    loss_and_metrics['iou/new_i_224'] = new_total_i
    loss_and_metrics['iou/new_u_224'] = new_total_u
    #loss_and_metrics['mIoU'] = mIoU
        
    """
    All done.
    Now gather everything in a dict for logging
    """

    if need_loss:
        loss_and_metrics['total_loss'] = 0
        loss_and_metrics['ce_loss'] = ce_loss
        loss_and_metrics['l1_loss'] = l1_loss
        loss_and_metrics['l2_loss'] = l2_loss
        loss_and_metrics['loss'] = loss

        loss_and_metrics['total_loss'] += loss

    return loss_and_metrics

In [14]:
import cv2
import numpy as np

import torchvision.transforms as transforms

inv_im_trans = transforms.Normalize(
                mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
                std=[1/0.229, 1/0.224, 1/0.225])

inv_seg_trans = transforms.Normalize(
    mean=[-0.5/0.5],
    std=[1/0.5])

def tensor_to_numpy(image):
    image_np = (image.numpy() * 255).astype('uint8')
    return image_np

def tensor_to_np_float(image):
    image_np = image.numpy().astype('float32')
    return image_np

def detach_to_cpu(x):
    return x.detach().cpu()

def transpose_np(x):
    return np.transpose(x, [1,2,0])

def tensor_to_gray_im(x):
    x = detach_to_cpu(x)
    x = tensor_to_numpy(x)
    x = transpose_np(x)
    return x

def tensor_to_seg(x):
    x = detach_to_cpu(x)
    x = inv_seg_trans(x)
    x = tensor_to_numpy(x)
    x = transpose_np(x)
    return x

def tensor_to_im(x):
    x = detach_to_cpu(x)
    x = inv_im_trans(x)
    x = tensor_to_numpy(x)
    x = transpose_np(x)
    return x

# Predefined key <-> caption dict
key_captions = {
    'im': 'Image', 
    'gt': 'GT', 
    'seg': 'Input', 
    'error_map': 'Error map',
}
for k in ['28', '56', '224']:
    key_captions['pred_' + k] = 'Ours-%sx%s' % (k, k)
    key_captions['pred_' + k + '_overlay'] = '%sx%s' % (k, k)

"""
Return an image array with captions
keys in dictionary will be used as caption if not provided
values should contain lists of cv2 images
"""
def get_image_array(images, grid_shape, captions={}):
    w, h = grid_shape
    cate_counts = len(images)
    rows_counts = len(next(iter(images.values())))

    font = cv2.FONT_HERSHEY_SIMPLEX

    output_image = np.zeros([h*(rows_counts+1), w*cate_counts, 3], dtype=np.uint8)
    col_cnt = 0
    for k, v in images.items():

        # Default as key value itself
        caption = captions.get(k, k)

        # Handles new line character
        y0, dy = h-10-len(caption.split('\n'))*40, 40
        for i, line in enumerate(caption.split('\n')):
            y = y0 + i*dy
            cv2.putText(output_image, line, (col_cnt*w, y),
                     font, 0.8, (255,255,255), 2, cv2.LINE_AA)

        # Put images
        for row_cnt, img in enumerate(v):
            im_shape = img.shape
            if len(im_shape) == 2:
                img = img[..., np.newaxis]

            img = (img * 255).astype('uint8')

            output_image[(row_cnt+1)*h:(row_cnt+2)*h,
                         col_cnt*w:(col_cnt+1)*w, :] = img
            
        col_cnt += 1

    return output_image

"""
Create an image array, transform each image separately as needed
Will only put images in req_keys
"""
def pool_images(images, req_keys, row_cnt=10):
    req_images = {}

    def base_transform(im):
        im = tensor_to_np_float(im)
        im = im.transpose((1, 2, 0))

        # Resize
        if im.shape[1] != 224:
            im = cv2.resize(im, (224, 224), interpolation=cv2.INTER_NEAREST)

        if len(im.shape) == 2:
            im = im[..., np.newaxis]

        return im

    second_pass_keys = []
    for k in req_keys:

        if 'overlay' in k: 
            # Run overlay in the second pass, skip for now
            second_pass_keys.append(k)

            # Make sure the base key information is transformed
            base_key = k.replace('_overlay', '')
            if base_key in req_keys:
                continue
            else:
                k = base_key

        req_images[k] = []

        images[k] = detach_to_cpu(images[k])
        for i in range(min(row_cnt, len(images[k]))):

            im = images[k][i]

            # Handles inverse transform
            if k in ['im']:
                im = inv_im_trans(images[k][i])
            elif k in ['seg']:
                im = inv_seg_trans(images[k][i])

            # Now we are all numpy array
            im = base_transform(im)

            req_images[k].append(im)

    # Handle overlay images in the second pass
    for k in second_pass_keys:
        req_images[k] = []
        base_key = k.replace('_overlay', '')
        for i in range(min(row_cnt, len(images[base_key]))):

            # If overlay
            im = req_images[base_key][i]
            raw = req_images['im'][i]

            im = im.clip(0, 1)

            # Just red overlay
            im = (raw*0.5 + 0.5 * (raw * (1-im) 
                    + im * (np.array([1,0,0],dtype=np.float32)
                    .reshape([1,1,3]))))
            
            req_images[k].append(im)
    
    # Remove all temp items
    output_images = {}
    for k in req_keys:
        output_images[k] = req_images[k]

    return get_image_array(output_images, (224, 224), key_captions)

# Return cv2 image, directly usable for saving
def vis_prediction(images):

    keys = ['im', 'seg', 'gt', 'pred_224', 'pred_224_overlay'] # 'pred_28', 'pred_28_2', 'pred_56', 'pred_28_3', 'pred_56_2', 

    return pool_images(images, keys)

In [15]:
import cv2

import numpy as np

def get_random_structure(size):
    # The provided model is trained with 
    #   choice = np.random.randint(4)
    # instead, which is a bug that we fixed here
    choice = np.random.randint(1, 5)

    if choice == 1:
        return cv2.getStructuringElement(cv2.MORPH_RECT, (size, size))
    elif choice == 2:
        return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
    elif choice == 3:
        return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size//2))
    elif choice == 4:
        return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size//2, size))

def random_dilate(seg, min=3, max=10):
    size = np.random.randint(min, max)
    kernel = get_random_structure(size)
    seg = cv2.dilate(seg,kernel,iterations = 1)
    return seg

def random_erode(seg, min=3, max=10):
    size = np.random.randint(min, max)
    kernel = get_random_structure(size)
    seg = cv2.erode(seg,kernel,iterations = 1)
    return seg

def compute_iou(seg, gt):
    intersection = seg*gt
    union = seg+gt
    return (np.count_nonzero(intersection) + 1e-6) / (np.count_nonzero(union) + 1e-6)

def perturb_seg(gt, iou_target=0.6):
    h, w = gt.shape
    seg = gt.copy()

    _, seg = cv2.threshold(seg, 127, 255, 0)

    # Rare case
    if h <= 2 or w <= 2:
        print('GT too small, returning original')
        return seg

    # Do a bunch of random operations
    for _ in range(250):
        for _ in range(4):
            lx, ly = np.random.randint(w), np.random.randint(h)
            lw, lh = np.random.randint(lx+1,w+1), np.random.randint(ly+1,h+1)

            # Randomly set one pixel to 1/0. With the following dilate/erode, we can create holes/external regions
            if np.random.rand() < 0.25:
                cx = int((lx + lw) / 2)
                cy = int((ly + lh) / 2)
                seg[cy, cx] = np.random.randint(2) * 255

            if np.random.rand() < 0.5:
                seg[ly:lh, lx:lw] = random_dilate(seg[ly:lh, lx:lw])
            else:
                seg[ly:lh, lx:lw] = random_erode(seg[ly:lh, lx:lw])

        if compute_iou(seg, gt) < iou_target:
            break

    return seg

In [16]:
import cv2
import numpy as np
import random
import math

##try:
   # from util.de_transform import perturb_seg
#except:
  #  from de_transform import perturb_seg


def modify_boundary(image, regional_sample_rate=0.1, sample_rate=0.1, move_rate=0.0, iou_target = 0.8):
    # modifies boundary of the given mask.
    # remove consecutive vertice of the boundary by regional sample rate
    # ->
    # remove any vertice by sample rate
    # ->
    # move vertice by distance between vertice and center of the mask by move rate. 
    # input: np array of size [H,W] image
    # output: same shape as input
    
    # get boundaries
    if int(cv2.__version__[0]) >= 4:
        contours, _ = cv2.findContours(image, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
    else:
        _, contours, _ = cv2.findContours(image, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)

    #only modified contours is needed actually. 
    sampled_contours = []   
    modified_contours = [] 

    for contour in contours:
        if contour.shape[0] < 10:
            continue
        M = cv2.moments(contour)

        #remove region of contour
        number_of_vertices = contour.shape[0]
        number_of_removes = int(number_of_vertices * regional_sample_rate)
        
        idx_dist = []
        for i in range(number_of_vertices-number_of_removes):
            idx_dist.append([i, np.sum((contour[i] - contour[i+number_of_removes])**2)])
            
        idx_dist = sorted(idx_dist, key=lambda x:x[1])
        
        remove_start = random.choice(idx_dist[:math.ceil(0.1*len(idx_dist))])[0]
        
       #remove_start = random.randrange(0, number_of_vertices-number_of_removes, 1)
        new_contour = np.concatenate([contour[:remove_start], contour[remove_start+number_of_removes:]], axis=0)
        contour = new_contour
        

        #sample contours
        number_of_vertices = contour.shape[0]
        indices = random.sample(range(number_of_vertices), int(number_of_vertices * sample_rate))
        indices.sort()
        sampled_contour = contour[indices]
        sampled_contours.append(sampled_contour)

        modified_contour = np.copy(sampled_contour)
        if (M['m00'] != 0):
            center = round(M['m10'] / M['m00']), round(M['m01'] / M['m00'])

            #modify contours
            for idx, coor in enumerate(modified_contour):

                change = np.random.normal(0,move_rate) # 0.1 means change position of vertex to 10 percent farther from center
                x,y = coor[0]
                new_x = x + (x-center[0]) * change
                new_y = y + (y-center[1]) * change

                modified_contour[idx] = [new_x,new_y]
        modified_contours.append(modified_contour)
        

    #draw boundary
    gt = np.copy(image)
    image = np.zeros_like(image)

    modified_contours = [cont for cont in modified_contours if len(cont) > 0]
    if len(modified_contours) == 0:
        image = gt.copy()
    else:
        image = cv2.drawContours(image, modified_contours, -1, (255, 0, 0), -1)

    image = perturb_seg(image, iou_target)
    
    return image

In [17]:
def get_seg_as_input(seg_image):

  I = np.eye(20)

  returned_image = I[seg_image]

  returned_image = np.transpose(returned_image, (2, 0, 1))

  return returned_image

In [18]:
import torch
def expand_classes(gt, num_classes):
  I = np.eye(num_classes)
  returned_image = I[gt.numpy()]
  returned_image = torch.from_numpy(returned_image)
  returned_image = returned_image.squeeze(0)
  gt = np.transpose(returned_image, (2, 0, 1))
  gt = torch.tensor(gt, dtype=torch.float)
  return gt

In [19]:
import os
from os import path
import warnings

from torch.utils.data.dataset import Dataset
from torchvision import transforms, utils
from PIL import Image
import numpy as np
import random
#from dataset.reseed import reseed
#import util.boundary_modification as boundary_modification

import torch

seg_normalization = transforms.Normalize(
                mean=[0.5],
                std=[0.5]
            )

import torch
import random

def reseed(seed):
    random.seed(seed)
    torch.manual_seed(seed)

def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret


def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (3, H, W)
    """
    coord = make_coord_constant2
    #coord = make_coord(img.shape[-2:])
    rgb = img.view(1, -1).permute(1, 0)
    return coord, rgb


def resize_fn(img, size):
    return transforms.ToTensor()(
        transforms.Resize(size, Image.BICUBIC)(
            transforms.ToPILImage()(img)))


class OnlineTransformDataset_Second(Dataset):
    """
    Method 0 - FSS style (class/1.jpg class/1.png)
    Method 1 - Others style (XXX.jpg XXX.png)
    """
    def __init__(self, root, need_name=False, method=0, perturb=True, test = False):
        self.root = root
        self.need_name = need_name
        self.method = method

        #self.im_list = []


        # code goes here

        self.im_list = []
        self.mask_list = []
        self.seg_list = []

        self.flag = True


        img_folder = "/content/drive/MyDrive/multiclass-seg/cityscapes/leftImg8bit/train"
        mask_folder = "/content/drive/MyDrive/multiclass-seg/cityscapes/gtFine/train"
        seg_folder = "/content/drive/MyDrive/multiclass-seg/cityscapes/segs"


        if test == True:
          train_image_path = "/content/drive/MyDrive/multiclass-seg/cityscapes/leftImg8bit/train/hamburg/hamburg_000000_000042_leftImg8bit.png"
          train_seg_path = "/content/drive/MyDrive/multiclass-seg/cityscapes/segs/hamburg_000000_000042_leftImg8bit.png"
          train_gt = "/content/drive/MyDrive/multiclass-seg/cityscapes/gtFine/train/hamburg/hamburg_000000_000042_gtFine_color.png"

          self.im_list.append(train_image_path)
          self.mask_list.append(train_seg_path)
          self.seg_list.append(train_gt)

        else:


          #####

          for root, _, files in os.walk(img_folder):
              for filename in files:
                  if filename.endswith('.png'):

                      imgpath = os.path.join(root, filename)
                      foldername = os.path.basename(os.path.dirname(imgpath))
                      maskname = filename.replace('leftImg8bit', 'gtFine_labelTrainIds')
                      maskpath = os.path.join(mask_folder, foldername, maskname)
                      segPath = os.path.join(seg_folder, filename)
                      if os.path.isfile(imgpath) and os.path.isfile(maskpath) and os.path.isfile(segPath):
                          self.im_list.append(imgpath)
                          self.mask_list.append(maskpath)
                          self.seg_list.append(segPath)

                      else:
                          print('cannot find the mask or image:', imgpath)

          #####




#        if method == 0:
 #           # Get images
  #          self.im_list = []
   #         classes = os.listdir(self.root)
    #        for c in classes:
     #           imgs = os.listdir(path.join(root, c))
      #          jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
       #         unmatched = any([im.replace('.jpg', '.png') not in imgs for im in jpg_list])

        #        if unmatched:
         #           print('Number of image/gt unmatch in class ', c)
          #          print('The whole class is ignored', len(jpg_list))

           #         warnings.warn('Dataset unmatch error')
            #    else:
             #       joint_list = [path.join(root, c, im) for im in jpg_list]
              #      self.im_list.extend(joint_list)

        #elif method == 1:
         #   self.im_list = [path.join(self.root, im) for im in os.listdir(self.root) if '.jpg' in im]

        print('%d images found' % len(self.im_list))

        if perturb:
            # Make up some transforms
            self.bilinear_dual_transform = transforms.Compose([
                transforms.RandomCrop((440, 880), pad_if_needed=True),
                transforms.RandomHorizontalFlip(),
            ])

            self.bilinear_dual_transform_im = transforms.Compose([
                transforms.RandomCrop((440, 880), pad_if_needed=True),
                transforms.RandomHorizontalFlip(),
            ])

            self.im_transform = transforms.Compose([
                transforms.ColorJitter(0.2, 0.05, 0.05, 0),
                transforms.RandomGrayscale(),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])
        else:
            # Make up some transforms
            self.bilinear_dual_transform = transforms.Compose([
                transforms.Resize((220, 440), interpolation=Image.NEAREST), 
                transforms.CenterCrop((220, 440)),
            ])

            self.bilinear_dual_transform_im = transforms.Compose([
                transforms.Resize((220, 440), interpolation=Image.BILINEAR), 
                transforms.CenterCrop((220, 440)),
            ])

            self.im_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])

        self.gt_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.seg_transform1 = transforms.Compose([
            transforms.ToTensor()
        ])

        self.seg_transform2 = transforms.Compose([
            seg_normalization
        ])



    def __getitem__(self, idx):

        #print(self.im_list[idx] + ", " + self.mask_list[idx] + ", " + self.seg_list[idx])


        im = Image.open(self.im_list[idx]).convert('RGB')  # im = Image.open(self.im_list[idx]).convert('RGB')

        #############################################################
        #############################################################
        #############################################################
        #############################################################
        #############################################################

        #if self.method == 0:
         #   gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')
        #else:
         #   path = self.im_list[idx].replace('.jpg','.png')
          #  gt = Image.open(path).convert('L')

        #seed = np.random.randint(2147483647)
        
        #reseed(seed)
        #im = self.bilinear_dual_transform_im(im)

        #reseed(seed)
        #gt = self.bilinear_dual_transform(gt)

        #iou_max = 1.0
        #iou_min = 0.8
        #iou_target = np.random.rand()*(iou_max-iou_min) + iou_min
        #seg = modify_boundary((np.array(gt)>0.5).astype('uint8')*255, iou_target=iou_target)


        #############################################################
        #############################################################
        #############################################################
        #############################################################
        #############################################################

        ############################

        im = self.bilinear_dual_transform_im(im)

        gt = self.mask_list[idx]
        gt = Image.open(gt)
        gt = self.bilinear_dual_transform(gt)

        seg_image = self.seg_list[idx]
        seg_image = Image.open(seg_image)
        seg_image = self.bilinear_dual_transform(seg_image)

        #seg_image = np.where(seg_image == 255, 0, seg_image)
        #seg_image = get_seg_as_input(seg_image)

        #seg_image *= 255

        

        ############################

        gt = np.array(gt).astype('int32')

        im = self.im_transform(im)
        gt = self.seg_transform1(gt)
        seg = self.seg_transform1(seg_image)

        seg = np.where(seg_image == 255, 0, seg_image)
        gt = np.where(gt == 255, 0, gt)

        #print(gt[0, 200:400, 200:400])



        seg = get_seg_as_input(seg)
        seg = torch.from_numpy(seg)
        seg = seg.float()

        #seg = seg.reshape(20, 224, 224)
        seg = self.seg_transform2(seg)



        hr_coord, hr_rgb = to_pixel_samples(seg.contiguous())

        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / seg.shape[-2] 
        cell[:, 1] *= 2 / seg.shape[-1]

        #crop_lr = resize_fn(seg, seg.shape[-2]) # 

        im = im.float()
        seg = seg.float()
        gt = torch.from_numpy(gt)
        gt = expand_classes(gt, 20)

        if self.need_name:
            return im, seg, gt, os.path.basename(self.im_list[idx][:-4])
        else:
            return im, seg, gt, {'coord': hr_coord, 'cell': cell, 'gt': hr_rgb}   # return im, seg, gt, {'inp': crop_lr, 'coord': hr_coord, 'cell': cell, 'gt': hr_rgb}

    def __len__(self):
        return len(self.im_list)

#if __name__ == '__main__':
#    ecssd_dir = '/PathTo/data/ecssd'
#    ecssd_dataset = OnlineTransformDataset(ecssd_dir, method=1, perturb=True)

#    import pdb; pdb.set_trace()
#    ecssd_dataset[0]

In [20]:
import os
from os import path
import warnings

from torch.utils.data.dataset import Dataset
from torchvision import transforms, utils
from PIL import Image
import numpy as np
import random
#from dataset.reseed import reseed
#import util.boundary_modification as boundary_modification

import torch

seg_normalization = transforms.Normalize(
                mean=[0.5],
                std=[0.5]
            )

def make_coord(shape, ranges=None, flatten=True):
    """ Make coordinates at grid centers.
    """
    coord_seqs = []
    for i, n in enumerate(shape):
        if ranges is None:
            v0, v1 = -1, 1
        else:
            v0, v1 = ranges[i]
        r = (v1 - v0) / (2 * n)
        seq = v0 + r + (2 * r) * torch.arange(n).float()
        coord_seqs.append(seq)
    ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
    if flatten:
        ret = ret.view(-1, ret.shape[-1])
    return ret


def to_pixel_samples(img):
    """ Convert the image to coord-RGB pairs.
        img: Tensor, (3, H, W)
    """
    coord = make_coord(img.shape[-2:])
    rgb = img.view(1, -1).permute(1, 0)
    return coord, rgb


def resize_fn(img, size):
    return transforms.ToTensor()(
        transforms.Resize(size, Image.BICUBIC)(
            transforms.ToPILImage()(img)))


class OnlineTransformDataset_crm(Dataset):
    """
    Method 0 - FSS style (class/1.jpg class/1.png)
    Method 1 - Others style (XXX.jpg XXX.png)
    """
    def __init__(self, root, need_name=False, method=0, perturb=True):
        self.root = root
        self.need_name = need_name
        self.method = method


        if method == 0:
            # Get images
            self.im_list = []
            classes = os.listdir(self.root)
            for c in classes:
                imgs = os.listdir(path.join(root, c))
                jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
                unmatched = any([im.replace('.jpg', '.png') not in imgs for im in jpg_list])

                if unmatched:
                    print('Number of image/gt unmatch in class ', c)
                    print('The whole class is ignored', len(jpg_list))

                    warnings.warn('Dataset unmatch error')
                else:
                    joint_list = [path.join(root, c, im) for im in jpg_list]
                    self.im_list.extend(joint_list)

        elif method == 1:
            self.im_list = [path.join(self.root, im) for im in os.listdir(self.root) if '.jpg' in im]

        print('%d images found' % len(self.im_list))

        if perturb:
            # Make up some transforms
            self.bilinear_dual_transform = transforms.Compose([
                transforms.RandomCrop((224, 224), pad_if_needed=True),
                transforms.RandomHorizontalFlip(),
            ])

            self.bilinear_dual_transform_im = transforms.Compose([
                transforms.RandomCrop((224, 224), pad_if_needed=True),
                transforms.RandomHorizontalFlip(),
            ])

            self.im_transform = transforms.Compose([
                transforms.ColorJitter(0.2, 0.05, 0.05, 0),
                transforms.RandomGrayscale(),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])
        else:
            # Make up some transforms
            self.bilinear_dual_transform = transforms.Compose([
                transforms.Resize(224, interpolation=Image.NEAREST), 
                transforms.CenterCrop(224),
            ])

            self.bilinear_dual_transform_im = transforms.Compose([
                transforms.Resize(224, interpolation=Image.BILINEAR), 
                transforms.CenterCrop(224),
            ])

            self.im_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
            ])

        self.gt_transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        self.seg_transform = transforms.Compose([
            transforms.ToTensor(),
            seg_normalization,
        ])

    def __getitem__(self, idx):
        im = Image.open(self.im_list[idx]).convert('RGB')

        if self.method == 0:
            gt = Image.open(self.im_list[idx][:-3]+'png').convert('L')
        else:
            gt = Image.open(self.im_list[idx].replace('.jpg','.png')).convert('L')

        seed = np.random.randint(2147483647)
        
        reseed(seed)
        im = self.bilinear_dual_transform_im(im)

        reseed(seed)
        gt = self.bilinear_dual_transform(gt)

        iou_max = 1.0
        iou_min = 0.8
        iou_target = np.random.rand()*(iou_max-iou_min) + iou_min
        seg = modify_boundary((np.array(gt)>0.5).astype('uint8')*255, iou_target=iou_target)

        temp = seg

        im = self.im_transform(im)
        gt = self.gt_transform(gt)
        seg = self.seg_transform(seg)


        hr_coord, hr_rgb = to_pixel_samples(seg.contiguous())

        cell = torch.ones_like(hr_coord)
        cell[:, 0] *= 2 / seg.shape[-2] 
        cell[:, 1] *= 2 / seg.shape[-1]

        crop_lr = resize_fn(seg, seg.shape[-2]) # 

        if self.need_name:
            return im, seg, gt, os.path.basename(self.im_list[idx][:-4])
        else:
            return im, seg, gt, {'inp': crop_lr, 'coord': hr_coord, 'cell': cell, 'gt': hr_rgb}

    def __len__(self):
        return len(self.im_list)

In [21]:
"""


import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader, ConcatDataset

#from models.network.crm import CRMNet
#from models.sobel_op import SobelComputer

#from dataset import OnlineTransformDataset_crm as OnlineTransformDataset
 
#from util.logger import BoardLogger
#from util.model_saver import ModelSaver
#from util.hyper_para import HyperParameters
#from util.log_integrator import Integrator
#from util.metrics_compute_crm import compute_loss_and_metrics, iou_hooks_to_be_used
#from util.image_saver_crm import vis_prediction

import time
import os
import datetime

torch.backends.cudnn.benchmark = True

# Parse command line arguments
para = HyperParameters()
para = para.parse()



# Logging
if para['id'].lower() != 'null':
    long_id = '%s_%s' % (para['id'],datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S'))
else:
    long_id = None



  
logger = BoardLogger(long_id)
logger.log_string('hyperpara', str(para))

print('CUDA Device count: ', torch.cuda.device_count())

# Construct model
model = CRMNet(backend='resnet50')
model = nn.DataParallel(
        model.cuda(), device_ids=[0] #[0,1]
    )

#if para['load'] is not None:
    #model.load_state_dict(torch.load(para['load']))
optimizer = optim.Adam(model.parameters(), lr=para['lr'], weight_decay=para['weight_decay'])


duts_tr_dir = os.path.join('data', 'DUTS-TR')
duts_te_dir = os.path.join('data', 'DUTS-TE')
ecssd_dir = os.path.join('data', 'ecssd')
msra_dir = os.path.join('data', 'MSRA_10K')


#root_dir = "/content/drive/MyDrive/multiclass-seg/cityscapes"


#train_dataset = OnlineTransformDataset_crm(root_dir, method=1, perturb=False)


#fss_dataset = OnlineTransformDataset_crm(os.path.join('data', 'fss'), method=0, perturb=True)

duts_tr_dataset = OnlineTransformDataset_crm(duts_tr_dir, method=1, perturb=True)
duts_te_dataset = OnlineTransformDataset_crm(duts_te_dir, method=1, perturb=True)

#ecssd_dataset = OnlineTransformDataset_crm(ecssd_dir, method=1, perturb=True)
msra_dataset = OnlineTransformDataset_crm(msra_dir, method=1, perturb=True)

####print('DUTS-TR dataset size: ', len(duts_tr_dataset))
####print('DUTS-TE dataset size: ', len(duts_te_dataset))
####print('MSRA-10K dataset size: ', len(msra_dataset))

train_dataset = ConcatDataset([duts_tr_dataset, duts_te_dataset, msra_dataset]) #[fss_dataset, duts_tr_dataset, duts_te_dataset, ecssd_dataset, msra_dataset]

##################train_dataset = ConcatDataset([ duts_tr_dataset, duts_te_dataset, msra_dataset]) #[fss_dataset, duts_tr_dataset, duts_te_dataset, ecssd_dataset, msra_dataset]

print('Total training size: ', len(train_dataset))

# For randomness: https://github.com/pytorch/pytorch/issues/5059
def worker_init_fn(worker_id): 
    np.random.seed(np.random.get_state()[1][0] + worker_id)

# Dataloaders, multi-process data loading
train_loader = DataLoader(train_dataset, para['batch_size'], shuffle=True, num_workers=8,
                            worker_init_fn=worker_init_fn, drop_last=True, pin_memory=True)

sobel_compute = SobelComputer()

# Learning rate decay scheduling
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, para['steps'], para['gamma'])

saver = ModelSaver(long_id)
report_interval = 50
save_im_interval = 800
memory_chunk = 50176

total_epoch = int(para['iterations']/len(train_loader) + 0.5)
print('Actual training epoch: ', total_epoch)

train_integrator = Integrator(logger)
train_integrator.add_hook(iou_hooks_to_be_used)
total_iter = 0
last_time = 0
for e in range(total_epoch):
    np.random.seed() # reset seed
    epoch_start_time = time.time()

    # Train loop
    model = model.train()
    for im, seg, gt, crm_data in train_loader:
        im, seg, gt = im.cuda(), seg.cuda(), gt.cuda() # [12, 3, 224, 224] [12, 1, 224, 224] [12, 1, 224, 224]
        for k, v in crm_data.items():
            crm_data[k] = v.cuda()

        total_iter += 1
        if total_iter % 5000 == 0:
            saver.save_model(model, total_iter)

        images = {}
        for i in range(0, seg.shape[-2]*seg.shape[-1], memory_chunk):
            chunk_images = model(im, seg, coord=crm_data['coord'][:, i:i+memory_chunk, :], cell=crm_data['cell'][:, i:i+memory_chunk, :])
            if 'pred_224' not in images.keys():
                images = chunk_images
            else:
                for key in images.keys():
                    images[key] = torch.cat((images[key], chunk_images[key]), axis=1)
        for key in images.keys():
            images[key] = images[key].view(images[key].shape[0], images[key].shape[1]//(seg.shape[-2]*seg.shape[-1]), *seg.shape[-2:])

        images['im'] = im
        images['seg'] = seg
        images['gt'] = gt
        sobel_compute.compute_edges(images)

        loss_and_metrics = compute_loss_and_metrics(images, para)
        train_integrator.add_dict(loss_and_metrics)

        optimizer.zero_grad()
        loss1 = loss_and_metrics['total_loss'].float()
        (loss1).backward()
        optimizer.step()

        if total_iter % report_interval == 0:
            logger.log_scalar('train/lr', scheduler.get_lr()[0], total_iter)
            train_integrator.finalize('train', total_iter)
            train_integrator.reset_except_hooks()

        # Need to put step AFTER get_lr() for correct logging, see issue #22107 in PyTorch
        scheduler.step()

        if total_iter % save_im_interval == 0:
            predict_vis = vis_prediction(images)
            logger.log_cv2('train/predict', predict_vis, total_iter)

# Final save!
saver.save_model(model, total_iter)

"""

'\n\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch import optim\nfrom torch.utils.data import DataLoader, ConcatDataset\n\n#from models.network.crm import CRMNet\n#from models.sobel_op import SobelComputer\n\n#from dataset import OnlineTransformDataset_crm as OnlineTransformDataset\n \n#from util.logger import BoardLogger\n#from util.model_saver import ModelSaver\n#from util.hyper_para import HyperParameters\n#from util.log_integrator import Integrator\n#from util.metrics_compute_crm import compute_loss_and_metrics, iou_hooks_to_be_used\n#from util.image_saver_crm import vis_prediction\n\nimport time\nimport os\nimport datetime\n\ntorch.backends.cudnn.benchmark = True\n\n# Parse command line arguments\npara = HyperParameters()\npara = para.parse()\n\n\n\n# Logging\nif para[\'id\'].lower() != \'null\':\n    long_id = \'%s_%s\' % (para[\'id\'],datetime.datetime.now().strftime(\'%Y-%m-%d_%H:%M:%S\'))\nelse:\n    long_id = None\n\n\n\n  \nlogger = BoardLogger(long_i

In [22]:
import shutil

from distutils.dir_util import copy_tree

def save_to_drive(source_path):
  #print(source_path)
  shutil.copy(source_path,"/content/drive/MyDrive/multiclass-seg/CRM-models/CRM_unedited")

In [23]:
"""


import os

save_path = os.path.join('.', 'weights')

os.makedirs(save_path, exist_ok=True)

model_path = os.path.join(save_path, 'model_%s' % 8000)
torch.save(model.state_dict(), model_path)
print('Model saved to %s.' % model_path)
save_to_drive(model_path)

"""

"\n\n\nimport os\n\nsave_path = os.path.join('.', 'weights')\n\nos.makedirs(save_path, exist_ok=True)\n\nmodel_path = os.path.join(save_path, 'model_%s' % 8000)\ntorch.save(model.state_dict(), model_path)\nprint('Model saved to %s.' % model_path)\nsave_to_drive(model_path)\n\n"

In [24]:
import torch
import os
from torch.utils.data import DataLoader, ConcatDataset
import timeit

model_temp = torch.load("/content/drive/MyDrive/multiclass-seg/CRM-models/CRM_unedited/model_8000")
model = torch.nn.DataParallel(CRMNet(backend = 'resnet50'))
model.load_state_dict(model_temp)


def worker_init_fn(worker_id): 
    np.random.seed(np.random.get_state()[1][0] + worker_id)


root_dir = "/content/drive/MyDrive/multiclass-seg/cityscapes"
dataset = OnlineTransformDataset_Second(root_dir, method=1, perturb=False)
loader = DataLoader(dataset, 1, shuffle=True, num_workers=8,
                            worker_init_fn=worker_init_fn, drop_last=True, pin_memory=True)

memory_chunk = 50176

counter = 0

final_images =[]
for im, seg, gt, crm_data in loader:
        im, seg, gt = im.cuda(), seg.cuda(), gt.cuda() # [12, 3, 224, 224] [12, 1, 224, 224] [12, 1, 224, 224]
        for k, v in crm_data.items():
            crm_data[k] = v.cuda()


        counter += 1

        final_seg = torch.empty(1, 1, 220, 440)
        

        #start = timeit.default_timer()
        for c in range(1, seg.shape[1]):



          class_images = []
          input_seg = seg[0, c]
          input_seg = input_seg.unsqueeze(0).unsqueeze(0)
          
          images = {}
          for i in range(0, seg.shape[-2]*seg.shape[-1], memory_chunk):
              
              chunk_images = model(im, input_seg, coord=crm_data['coord'][:, i:i+memory_chunk, :], cell=crm_data['cell'][:, i:i+memory_chunk, :])
              if 'pred_224' not in images.keys():
                  images = chunk_images
              else:
                  for key in images.keys():
                      images[key] = torch.cat((images[key], chunk_images[key]), axis=1)
          for key in images.keys():
            images[key] = images[key].view(images[key].shape[0], images[key].shape[1]//(seg.shape[-2]*seg.shape[-1]), *seg.shape[-2:])

            if key == 'pred_224':

              converted_matrix = np.where(images[key].cpu().detach().numpy() <= 0.5, 0, 1)
              converted_matrix = torch.from_numpy(converted_matrix)
              final_seg = torch.cat((final_seg, converted_matrix), axis=1)

        #stop = timeit.default_timer()

        #print("time: " + str(stop-start))


        np_matrix = final_seg.detach().numpy()
        zero_rows = np.where(~np_matrix.any(axis=(2, 3)))[0]
        np_matrix[0, zero_rows, :, :] = 1

        j_max_index = np.argmax(np_matrix, axis=1)
        I = np.eye(20)
        new_matrix = I[j_max_index]
        new_matrix = np.transpose(new_matrix, (0, 3, 1, 2))
        bool_matrix = (~np.all(np_matrix==0, axis=1)).astype(int)
        new_matrix *= bool_matrix

        new_matrix = np.argmax(new_matrix, axis=1)
        new_matrix = torch.from_numpy(new_matrix).to("cuda")
        gtmax = gt.argmax(1).to("cuda")
        seg = seg.argmax(1).to("cuda")

        old_iou = compute_tensor_iou(seg, gtmax)
        new_iou = compute_tensor_iou(new_matrix, gtmax)


        print("old iou:" + str(old_iou))
        print("new iou: " + str(new_iou))
        print("iou gain: " + str(new_iou-old_iou))

"""
images['im'] = im
images['seg'] = seg
images['gt'] = gt
sobel_compute.compute_edges(images)

loss_and_metrics = compute_loss_and_metrics(images, para)
train_integrator.add_dict(loss_and_metrics)

"""

ASPP_4level
2975 images found


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  gt = torch.tensor(gt, dtype=torch.float)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  gt = torch.tensor(gt, dtype=torch.float)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  gt = torch.tensor(gt, dtype=torch.float)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  gt = torch.tensor(gt, dtype=torch.float)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  gt = torch.tensor(gt, dtype=torch.float)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  gt = torch.tensor(gt, dtype=torch.float)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  gt = torch.tensor(gt, dtype=torch.float)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  gt = torch.tensor(gt, dtype=torch.float)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


old iou:tensor([0.9138], device='cuda:0')
new iou: tensor([0.4593], device='cuda:0')
iou gain: tensor([-0.4544], device='cuda:0')
old iou:tensor([0.6994], device='cuda:0')
new iou: tensor([0.3651], device='cuda:0')
iou gain: tensor([-0.3343], device='cuda:0')
old iou:tensor([0.8475], device='cuda:0')
new iou: tensor([0.2769], device='cuda:0')
iou gain: tensor([-0.5706], device='cuda:0')
old iou:tensor([0.8244], device='cuda:0')
new iou: tensor([0.4861], device='cuda:0')
iou gain: tensor([-0.3383], device='cuda:0')
old iou:tensor([0.8429], device='cuda:0')
new iou: tensor([0.1481], device='cuda:0')
iou gain: tensor([-0.6948], device='cuda:0')
old iou:tensor([0.7822], device='cuda:0')
new iou: tensor([0.2493], device='cuda:0')
iou gain: tensor([-0.5329], device='cuda:0')
old iou:tensor([0.8655], device='cuda:0')
new iou: tensor([0.2831], device='cuda:0')
iou gain: tensor([-0.5825], device='cuda:0')


KeyboardInterrupt: ignored