In [9]:
!pip install torch==1.4.0 torchvision==0.5.0 ninja torch-encoding
!pip install git+https://github.com/zhanghang1989/PyTorch-Encoding/
!pip install git+https://github.com/facebookresearch/fvcore.git

Collecting git+https://github.com/zhanghang1989/PyTorch-Encoding/
  Cloning https://github.com/zhanghang1989/PyTorch-Encoding/ to /tmp/pip-req-build-kygb_f52
  Running command git clone -q https://github.com/zhanghang1989/PyTorch-Encoding/ /tmp/pip-req-build-kygb_f52
Building wheels for collected packages: torch-encoding
  Building wheel for torch-encoding (setup.py) ... [?25l[?25hdone
  Created wheel for torch-encoding: filename=torch_encoding-1.2.2b20210418-cp37-cp37m-linux_x86_64.whl size=7347175 sha256=b4be4af21d78222465b06cf4c341b38901a397a31d1fcb5a3b94bbbb87c05372
  Stored in directory: /tmp/pip-ephem-wheel-cache-aom21gtw/wheels/f8/4f/46/924a4c89ee95252b34c3e257f1de2664a053e52c5aa5013d4a
Successfully built torch-encoding
Collecting git+https://github.com/facebookresearch/fvcore.git
  Cloning https://github.com/facebookresearch/fvcore.git to /tmp/pip-req-build-01jjs_a1
  Running command git clone -q https://github.com/facebookresearch/fvcore.git /tmp/pip-req-build-01jjs_a1
Build

In [1]:
import warnings
try:
    from queue import Queue
except ImportError:
    from Queue import Queue

import torch
from torch.nn.modules.batchnorm import _BatchNorm

from encoding.utils.misc import EncodingDeprecationWarning
from encoding.functions import *

class DistSyncBatchNorm(_BatchNorm):
    r"""Cross-GPU Synchronized Batch Normalization (SyncBN) for ``torch.distributed``.

    Standard BN [1]_ normalizes the data only within each device (GPU).
    SyncBN normalizes the input over the whole mini-batch, i.e. across all
    processes of the process group. We follow the sync-once implementation
    described in the paper [2]_ .
    Please see the design idea in the `notes <./notes/syncbn.html>`_.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    The mean and standard-deviation are calculated per-channel over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance, updated with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.
    In that mode ``forward`` itself makes no ``torch.distributed`` calls and
    passes ``process_group=None`` to the kernel.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm

    The layer is always affine and always tracks running statistics.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        process_group: ``torch.distributed`` process group to synchronize
            statistics over. ``None`` (default) uses the default WORLD group.

    Shape:
        - Input: :math:`(N, C, *)`, e.g. :math:`(N, C, H, W)`
        - Output: same shape as input

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*

    Examples:
        >>> m = DistSyncBatchNorm(100)
        >>> net = torch.nn.parallel.DistributedDataParallel(m)
        >>> output = net(x)
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, process_group=None):
        super(DistSyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True, track_running_stats=True)
        self.process_group = process_group

    def forward(self, x):
        """Normalize ``x`` with per-channel statistics shared across the process group.

        Args:
            x: tensor whose dim 1 has size ``num_features``; all trailing
                dimensions are flattened internally and restored on output.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Batch statistics are only computed (and therefore need to be
        # synchronized) in training mode or when running stats are not tracked.
        process_group = None
        if self.training or not self.track_running_stats:
            process_group = self.process_group
            if process_group is None:
                process_group = torch.distributed.group.WORLD
            # Fail fast with torch.distributed's own error if the group has not
            # been initialized (e.g. training outside a distributed launch).
            torch.distributed.get_world_size(process_group)

        # (N, C, *) -> (N, C, L). ``reshape`` (unlike ``view``) also accepts
        # non-contiguous inputs, copying only when required.
        input_shape = x.size()
        x = x.reshape(input_shape[0], self.num_features, -1)
        y = dist_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var,
                               self.eps, self.momentum, self.training, process_group)
        return y.reshape(input_shape)


In [3]:
import encoding

# Flags must be real booleans: any non-empty string (including 'False') is
# truthy, so the previous se_loss='False' silently turned se_loss ON.
# aux stays enabled, exactly as aux='True' effectively did, so the module
# layout still matches the checkpoint loaded later (presumably trained with aux).
model = encoding.models.get_segmentation_model(
    'deeplab',
    dataset='citys',
    backbone="resnest50",
    aux=True,
    se_loss=False,
    norm_layer=DistSyncBatchNorm,
    base_size=520,
    crop_size=248,
)

In [4]:
import torch.cuda  # also binds the top-level name ``torch``, same as ``import torch``

# Hand cached-but-unused GPU blocks back from PyTorch's caching allocator.
torch.cuda.empty_cache()

In [10]:
# NOTE(review): this value is replaced before use by the inference cell, which
# loads 'model_best.pth(9).tar'; this cell is likely safe to delete.
# map_location='cpu' materialises the tensors in host memory rather than on the
# device they were saved from, so loading works without a GPU and does not tie
# up GPU memory with the raw checkpoint. torch.load unpickles: trusted files only.
checkpoint = torch.load('/content/drive/MyDrive/model_best.pth(1).tar', map_location='cpu')


In [22]:
torch.cuda.empty_cache()

# Get the model. Load the checkpoint into host memory (map_location='cpu');
# load_state_dict copies the weights into the model, which is then moved to the GPU.
# torch.load unpickles: only load trusted checkpoint files.
checkpoint = torch.load('/content/drive/MyDrive/model_best.pth(9).tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()
model.cuda()

# Prepare the image: move to the GPU and add a leading batch dimension.
filename = '/content/munster_000000_000019_leftImg8bitpng.png'
img = encoding.utils.load_image(filename).cuda().unsqueeze(0)

# Make prediction. no_grad must wrap the forward pass itself (previously it
# only wrapped weight loading, so inference recorded autograd history).
with torch.no_grad():
    output = model.evaluate(img)

# Per-pixel class index = argmax over dim 1.
# NOTE(review): the "+ 1" from the PyTorch-Encoding examples was dropped: the
# 'citys' palette appears to assign colour index 0 to the first Cityscapes class
# (road), so shifting labels painted each class with its neighbour's colour.
# Verify against encoding.utils.get_mask_pallete.
predict = torch.max(output, 1)[1].cpu().numpy()

# Get color pallete for visualization
mask = encoding.utils.get_mask_pallete(predict, 'citys')
mask.save('/content/output_{}.png'.format(checkpoint["epoch"]))