In [1]:
import os
import json
import time
import math
import random
import numpy as np
from PIL import Image, ImageFilter, ImageOps

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

# Helpers

In [2]:
class AverageMeter(object):
    """
    Computes and stores the average and current value of some metric.
    
    Reference: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        
        
class Normalize(object):
    """
    Normalize an image (tensor) with mean and standard deviation. This
    should be just before ToTensor.
    """
    def __init__(self, mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)):
        self.mean = mean
        self.std = std
        
    def __call__(self, sample):
        # Extract PIL image and PIL label from dict
        img = sample['image']
        lab = sample['label']
        
        # Convert PIL data to NumPy array
        img = np.array(img).astype(np.float32)
        lab = np.array(lab).astype(np.float32)
        
        # Normalize img
        img /= 255.0
        img -= self.mean
        img /= self.std
        
        return {'image': img, 'label': lab}
    
    
class ToTensor(object):
    """
    Convert NumPy arrays to PyTorch tensors. This should be 
    the last transformation.
    """
    def __call__(self, sample):
        # Extract PIL image and PIL label from dict
        img = sample['image']
        lab = sample['label']
        
        # Convert PIL data to NumPy array
        img = np.array(img).astype(np.float32)
        lab = np.array(lab).astype(np.float32)
        
        # H x W x C -> C x H x W
        img = img.transpose((2, 0, 1))
        
        # Convert NumPy array to PyTorch tensor
        img = torch.from_numpy(img).float()
        lab = torch.from_numpy(lab).float()
        
        return {'image': img, 'label': lab}
    
    
class RandomHorizontalFlip(object):
    """
    Randomly horizontal flip image and label.
    
    NOTE: Returns data in PIL format
    """
    def __init__(self, p=0.5):
        self.p = p
        
    def __call__(self, sample):
        # Extract PIL image and PIL label from dict
        img = sample['image']
        lab = sample['label']
        
        # Horizontally flip
        if random.random() < self.p:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            lab = lab.transpose(Image.FLIP_LEFT_RIGHT)
            
        return {'image': img, 'label': lab}
    
                
class RandomGaussianBlur(object):
    """
    Randomly apply Gaussian blur to image only.
    
    NOTE: Returns data in PIL format
    """
    def __init__(self, p=0.5):
        self.p = p
        
    def __call__(self, sample):
        # Extract PIL image and PIL label from dict
        img = sample['image']
        lab = sample['label']
        
        # Apply Gaussian blur to image
        if random.random() < self.p:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
                
        return {'image': img, 'label': lab}
    
    
class FixedResize(object):
    """
    Resizes image and label to a fixed size.
    
    NOTE: Returns data in PIL format
    """
    def __init__(self, size=256):
        self.size = (size, size)
        
    def __call__(self, sample):
        # Extract PIL image and PIL label from dict
        img = sample['image']
        lab = sample['label']
        
        # Resize image and label
        img = img.resize(size=self.size, resample=Image.BILINEAR)
        lab = lab.resize(size=self.size, resample=Image.NEAREST)
        
        return {'image': img, 'label': lab}
    
    
class FixedScaleCrop(object):
    """
    Resizes image and label to a fixed size and then returns
    center cropped image and label
    
    NOTE: Returns data in PIL format
    """
    def __init__(self, crop_size=256):
        self.crop_size = crop_size
        
    def __call__(self, sample):
        # Extract PIL image and PIL label from dict
        img = sample['image']
        lab = sample['label']
        
        # Compute resize width and height
        width, height = img.size
        if width > height:
            resize_h = self.crop_size
            resize_w = int(resize_h * float(width) / height)
        else:
            resize_w = self.crop_size
            resize_h = int(resize_w * float(height) / width)
            
        # Resize image and label
        img = img.resize(size=(resize_w, resize_h), resample=Image.BILINEAR)
        lab = lab.resize(size=(resize_w, resize_h), resample=Image.NEAREST)
        
        # Center crop the resized image
        x1 = int(round(resize_w - self.crop_size) / 2.0)
        y1 = int(round(resize_h - self.crop_size) / 2.0)
        x2 = x1 + self.crop_size
        y2 = y1 + self.crop_size
        
        img = img.crop(box=(x1, y1, x2, y2))
        lab = lab.crop(box=(x1, y1, x2, y2))
        
        return {'image': img, 'label': lab}
    
    
class RandomScaleCrop(object):
    """
    Resize image and label by a random scale, and then randomly 
    crop the resized image and label.
    
    base_size must be > crop_size and multiple of 8
    
    fill: int, for ignoring purpose, as labelId 255 is to be ignored
    
    NOTE: Returns data in PIL format
    """
    def __init__(self, base_size, crop_size, fill=255):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill
    
    def __call__(self, sample):
        # Extract PIL image and PIL label from dict
        img = sample['image']
        lab = sample['label']
        
        # Randomly scale short edge
        short_size = random.randint(int(self.base_size * 0.75), 
                                    int(self.base_size * 1.75))
        
        # Compute resize width and height
        width, height = img.size
        if width > height:
            resize_h = short_size
            resize_w = int(resize_h * float(width) / height)
        else:
            resize_w = short_size
            resize_h = int(resize_w * float(height) / width)
            
        # Resize image and label
        img = img.resize(size=(resize_w, resize_h), resample=Image.BILINEAR)
        lab = lab.resize(size=(resize_w, resize_h), resample=Image.NEAREST)
        
        # Pad image and label
        if short_size < self.crop_size:
            pad_h = self.crop_size - resize_h if resize_h < self.crop_size else 0
            pad_w = self.crop_size - resize_w if resize_w < self.crop_size else 0
            
            img = ImageOps.expand(img, border=(0, 0, pad_w, pad_h), fill=0)
            lab = ImageOps.expand(lab, border=(0, 0, pad_w, pad_h), fill=self.fill)
            
        # Randomly crop the resized image and label
        max_x = 1 if resize_w - self.crop_size < 0 else resize_w - self.crop_size
        max_y = 1 if resize_h - self.crop_size < 0 else resize_h - self.crop_size
        x1 = random.randint(0, max_x)
        y1 = random.randint(0, max_y)
        x2 = x1 + self.crop_size
        y2 = y1 + self.crop_size
        
        img = img.crop(box=(x1, y1, x2, y2))
        lab = lab.crop(box=(x1, y1, x2, y2))
        
        return {'image': img, 'label': lab}

# Cityscapes Dataset

In [3]:
class Cityscapes(data.Dataset):
    """
    Modified from: https://pytorch.org/docs/master/_modules/torchvision/datasets/cityscapes.html#Cityscapes
    
    `Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="gtFine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``gtFine`` or ``gtCoarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.

    Examples:

        Get semantic segmentation target

        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
                                 target_type='semantic')

            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='train', mode='gtFine',
                                 target_type=['instance', 'color', 'polygon'])

            img, (inst, col, poly) = dataset[0]

        Validate on the "gtCoarse" set

        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='val', mode='gtCoarse',
                                 target_type='semantic')

            img, smnt = dataset[0]
    """

    def __init__(self, root, split='train', mode='gtFine', target_type='instance', transform=None):
        self.root = os.path.expanduser(root)
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, mode, split)
        self.transform = transform
        self.target_type = target_type
        self.split = split
        self.mode = mode
        self.images = []
        self.targets = []
        
        # Modifications to ignore trainId = [255, -1] as per Cityscapes label file and for training with correct index
        self.ignore_index = 255
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.n_classes = len(self.valid_classes)
        self.class_map = dict(zip(self.valid_classes, range(self.n_classes)))

        if mode not in ['gtFine', 'gtCoarse']:
            raise ValueError('Invalid mode! Please use mode="gtFine" or mode="gtCoarse"')

        if mode == 'gtFine' and split not in ['train', 'test', 'val']:
            raise ValueError('Invalid split for mode "gtFine"! Please use split="train", split="test"'
                             ' or split="val"')
        elif mode == 'gtCoarse' and split not in ['train', 'train_extra', 'val']:
            raise ValueError('Invalid split for mode "gtCoarse"! Please use split="train", split="train_extra"'
                             ' or split="val"')

        if not isinstance(target_type, list):
            self.target_type = [target_type]

        if not all(t in ['instance', 'semantic', 'polygon', 'color'] for t in self.target_type):
            raise ValueError('Invalid value for "target_type"! Valid values are: "instance", "semantic", "polygon"'
                             ' or "color"')

        if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
            raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
                               ' specified "split" and "mode" are inside the "root" directory')

        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_types = []
                for t in self.target_type:
                    target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
                                                 self._get_target_suffix(self.mode, t))
                    target_types.append(os.path.join(target_dir, target_name))

                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(target_types)
                
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """

        image = Image.open(self.images[index]).convert('RGB')
        
        targets = []
        for i, t in enumerate(self.target_type):
            if t == 'polygon':
                target = self._load_json(self.targets[index][i])
                
            # Modifications added to take care of ignore ids and updating ids
            elif t == 'semantic':
                temp = np.array(Image.open(self.targets[index][i])).astype(np.int32)
                temp = self._encode_target(temp)
                target = Image.fromarray(temp)
            else:
                target = np.array(Image.open(self.targets[index][i])).astype(np.int32)

            targets.append(target)

        target = tuple(targets) if len(targets) > 1 else targets[0]
        
        sample = {'image': image, 'label': target}
        
        if self.transform:
            sample = self.transform(sample)

        return sample

    def __len__(self):
        return len(self.images)

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        fmt_str += '    Split: {}\n'.format(self.split)
        fmt_str += '    Mode: {}\n'.format(self.mode)
        fmt_str += '    Type: {}\n'.format(self.target_type)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    def _load_json(self, path):
        with open(path, 'r') as file:
            data = json.load(file)
        return data

    def _get_target_suffix(self, mode, target_type):
        if target_type == 'instance':
            return '{}_instanceIds.png'.format(mode)
        elif target_type == 'semantic':
            return '{}_labelIds.png'.format(mode)
        elif target_type == 'color':
            return '{}_color.png'.format(mode)
        else:
            return '{}_polygons.json'.format(mode)
        
    def _encode_target(self, mask):
        # Fill void class with value 255
        for void_class in self.void_classes:
            mask[mask == void_class] = self.ignore_index
            
        # Fill valid class with updated index
        for valid_class in self.valid_classes:
            mask[mask == valid_class] = self.class_map[valid_class]
            
        return mask

# Modified MobileNetV2 for DeepLabV3+

In [4]:
def _make_divisible(v, divisor, min_value=None):
    """
    This function makes sure that number of channels number is divisible by 8.
    Source: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class ConvBnReLU(nn.Module):
    """
    [CONV]-[BN]-[ReLU6]
    """

    def __init__(self, inCh, outCh, stride):
        super(ConvBnReLU, self).__init__()
        self.inCh = inCh  # Number of input channels
        self.outCh = outCh  # Number of output channels
        self.stride = stride  # Stride
        self.conv = nn.Sequential(
            nn.Conv2d(self.inCh, self.outCh, 3, stride=self.stride, padding=1, bias=False),
            nn.BatchNorm2d(outCh),
            nn.ReLU6(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)


class InvertedResidual(nn.Module):
    """
    [EXP:CONV_1x1-BN-ReLU6]-[DW:CONV_3x3-BN-ReLU6]-[PW:CONV_1x1-BN] with identity shortcut 
    and dilation.
    """

    def __init__(self, inCh, outCh, t, s, r):
        super(InvertedResidual, self).__init__()
        self.inCh = inCh
        self.outCh = outCh
        self.t = t  # t: expansion factor
        self.r = r  # r: dilation
        if self.r > 1:
            self.s = 1  # s: Stride
            self.padding = self.r  # Atrous Conv padding same as dilation rate
        else:
            self.s = s  # s: Stride
            self.padding = 1
        self.identity_shortcut = (self.inCh == self.outCh) and (self.s == 1)  # L:506 Keras official code

        # Bottleneck block
        self.block = nn.Sequential(
            # Expansition Conv
            nn.Conv2d(self.inCh, self.t * self.inCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.t * self.inCh),
            nn.ReLU6(inplace=True),

            # Depthwise Conv
            nn.Conv2d(self.t * self.inCh, self.t * self.inCh, kernel_size=3, stride=self.s, padding=self.padding, 
                      dilation=self.r, groups=self.t * self.inCh, bias=False),
            nn.BatchNorm2d(self.t * self.inCh),
            nn.ReLU6(inplace=True),

            # Pointwise Linear Conv (Projection): i.e. No non-linearity
            nn.Conv2d(self.t * self.inCh, self.outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(self.outCh),
        )

    def forward(self, x):
        if self.identity_shortcut:
            return x + self.block(x)
        else:
            return self.block(x)


class PointwiseConv(nn.Module):
    def __init__(self, inCh, outCh):
        super(PointwiseConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(inCh, outCh, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(outCh),
            nn.ReLU6(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)


# MobileNetV2
class MobileNetV2(nn.Module):
    """
    MobileNetV2 feature extractor modified to include dilation for DeepLabV3+. 
    NOTE: Last conv Layer and classification layer removed.
    """

    def __init__(self, params):
        super(MobileNetV2, self).__init__()
        self.params = params
        self.first_inCh = 3

        self.c = [_make_divisible(c * self.params.alpha, 8) for c in self.params.c]
        
        # Layer-0
        self.layer0 = nn.Sequential(ConvBnReLU(self.first_inCh, self.c[0], self.params.s[0]))

        # Layer-1
        self.layer1 = self._make_layer(self.c[0], self.c[1], self.params.t[1], self.params.s[1], 
                                       self.params.n[1], self.params.r[1])

        # Layer-2: Image size: 512 -> [IRB-2] -> Output size: 128 (low level feature: 128 * 4 = 512)
        self.layer2 = self._make_layer(self.c[1], self.c[2], self.params.t[2], self.params.s[2], 
                                       self.params.n[2], self.params.r[2])

        # Layer-3
        self.layer3 = self._make_layer(self.c[2], self.c[3], self.params.t[3], self.params.s[3], 
                                       self.params.n[3], self.params.r[3])

        # Layer-4
        self.layer4 = self._make_layer(self.c[3], self.c[4], self.params.t[4], self.params.s[4], 
                                       self.params.n[4], self.params.r[4])

        # Layer-5: Image size: 512 -> [IRB-5] -> Output size: 32, so output stride = 16 achieved
        self.layer5 = self._make_layer(self.c[4], self.c[5], self.params.t[5], self.params.s[5], 
                                       self.params.n[5], self.params.r[5])

        # Layer-6: Apply dilation rate = 2
        self.layer6 = self._make_layer(self.c[5], self.c[6], self.params.t[6], self.params.s[6], 
                                       self.params.n[6], self.params.r[6])

        # Layer-7: Apply dilation rate = 2
        self.layer7 = self._make_layer(self.c[6], self.c[7], self.params.t[7], self.params.s[7], 
                                       self.params.n[7], self.params.r[7])
        
        # Initialize weights
        self._initialize_weights()

    def _make_layer(self, inCh, outCh, t, s, n, r):
        layers = []
        for i in range(n):
            # First layer of each sequence has a stride s and all others use stride 1
            if i == 0:
                layers.append(InvertedResidual(inCh, outCh, t, s, r))
            else:
                layers.append(InvertedResidual(inCh, outCh, t, 1, r))

            # Update input channel for next IRB layer in the block
            inCh = outCh
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        low_level_features = self.layer2(x)  # [512, 512]/4 = [128, 128] 
        x = self.layer3(low_level_features)
        x = self.layer4(x)
        x = self.layer5(x)
        x = self.layer6(x)
        x = self.layer7(x)
        return x, low_level_features
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
                

def MobileNet(pretrained=True, **kwargs):
    """
    Constructs a MobileNet V2 model.
    
    Parameters
    ----------
    pretrained: bool, use ImageNet pretrained model or not.
    n_class: int, 1000 classes in ImageNet data.
    weight_file: str, path to pretrained weights
    """
    weight_file = kwargs.pop('weight_file', '')
    model = MobileNetV2(**kwargs)
    if pretrained:
        state_dict = torch.load(weight_file)
        model.load_state_dict(state_dict)
    return model

# Atrous Spatial Pyramid Pooling (ASPP)

In [5]:
class AtrousConvBnRelu(nn.Module):
    """
    [Atrous CONV]-[BN]-[ReLU]
    """
    def __init__(self, inCh, outCh, dilation=1):
        super(AtrousConvBnRelu, self).__init__()
        self.inCh = inCh
        self.outCh = outCh
        self.dilation = dilation
        self.kernel = 1 if self.dilation == 1 else 3
        self.padding = 0 if self.dilation == 1 else self.dilation
        self.atrous_conv = nn.Sequential(
            nn.Conv2d(self.inCh, self.outCh, self.kernel, stride=1, 
                      padding=self.padding, dilation=self.dilation, bias=False), 
            nn.BatchNorm2d(self.outCh),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        return self.atrous_conv(x)
    

class ASPP(nn.Module):
    """
    Atrous Spatial Pyramid Pooling
    
    Ref(s): https://github.com/rishizek/tensorflow-deeplab-v3-plus/blob/master/deeplab_model.py
    and https://github.com/chenxi116/DeepLabv3.pytorch/blob/master/deeplab.py
    """
    def __init__(self, inCh, outCh):
        super(ASPP, self).__init__()
        self.rates = [1, 6, 12, 18] # for output stride 16
        self.inCh = inCh
        self.outCh = outCh
        
        # ASPP layers
        # (a) One 1x1 convolution and three 3x3 convolutions with rates = (6, 12, 18)
        self.conv_1x1_0 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, 
                                           dilation=self.rates[0])
        self.conv_3x3_1 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, 
                                           dilation=self.rates[1])
        self.conv_3x3_2 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, 
                                           dilation=self.rates[2])
        self.conv_3x3_3 = AtrousConvBnRelu(inCh=self.inCh, outCh=self.outCh, 
                                           dilation=self.rates[3])
        
        # (b) The image-level features
        # Global Average Pooling
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        
        # CONV-BN-ReLU after Global Average Pooling
        self.conv_bn_relu_4 = nn.Sequential(
            nn.Conv2d(self.inCh, self.outCh, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(self.outCh),
            nn.ReLU(inplace=True)
        )
        
        # CONV-BN-ReLU after Concatenation. NOTE: 5 Layers are concatenated
        self.conv_bn_relu_5 = nn.Sequential(
            nn.Conv2d(self.outCh * 5, self.outCh, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(self.outCh),
            nn.ReLU(inplace=True)
        )
        
        self._initialize_weights()
        
    def forward(self, x):
        x0 = self.conv_1x1_0(x)  # size: [1, outCh, fs, fs]
        x1 = self.conv_3x3_1(x)  # size: [1, outCh, fs, fs]
        x2 = self.conv_3x3_2(x)  # size: [1, outCh, fs, fs]
        x3 = self.conv_3x3_3(x)  # size: [1, outCh, fs, fs]
        
        # Global Average Pooling, CONV-BN-ReLU and upsample
        global_avg_pool = self.global_avg_pooling(x)
        
        x4 = self.conv_bn_relu_4(global_avg_pool)
        
        upsample = F.interpolate(x4, size=(x.size(2), x.size(3)), mode='bilinear', 
                                 align_corners=True)
        
        # Concatinate
        x_concat = torch.cat([x0, x1, x2, x3, upsample], dim=1) # size: [1, 5 * outCh, fs, fs]
        
        # CONV-BN-ReLU after concatination
        out = self.conv_bn_relu_5(x_concat)
        
        return out
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

# Decoder

In [6]:
class Decoder(nn.Module):
    """
    Decoder for DeepLabV3+
    """
    def __init__(self, low_level_inch, low_level_outch, inCh, outCh, n_classes):
        super(Decoder, self).__init__()
        self.low_level_inch = low_level_inch
        self.low_level_outch = low_level_outch # 48 (or lower for speed)
        self.inCh = inCh
        self.outCh = outCh
        self.n_classes = n_classes
        
        # 1x1 Conv with BN and ReLU for low level features
        self.conv_1x1_bn_relu = nn.Sequential(
            nn.Conv2d(self.low_level_inch, self.low_level_outch, kernel_size=1, bias=False),
            nn.BatchNorm2d(self.low_level_outch),
            nn.ReLU(inplace=True)
        )
        
        # Conv block with BN and ReLU (paper suggests to use a few 3x3 Convs, but using only 1
        # for speed improvement) and final Conv 1x1 
        self.conv_block = nn.Sequential(
            nn.Conv2d(self.inCh + self.low_level_outch, self.outCh, kernel_size=3, stride=1, padding=1, 
                      bias=False),
            nn.BatchNorm2d(self.outCh),
            nn.ReLU(inplace=True),
            
            # For reducing number of channels
            nn.Conv2d(self.outCh, self.n_classes, kernel_size=1, stride=1, bias=False)
        )
        
        self._initialize_weights()
    
    def forward(self, x, low_level_features):
        
        # Low level features from MobileNetV2
        low_level_features = self.conv_1x1_bn_relu(low_level_features)
        
        # Upsample features from ASPP by 4
        x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=True)
        
        # Concatinate
        x_concat = torch.cat([x, low_level_features], dim=1)
        
        # Final Convolution
        out = self.conv_block(x_concat)
        
        return out
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

# DeepLabV3+

In [7]:
class DeepLabV3Plus(nn.Module):
    def __init__(self, config):
        super(DeepLabV3Plus, self).__init__()
        self.config = config
        
        # Base Network
        self.base = MobileNet(weight_file=self.config.pretrained_weights, params=self.config)
        
        # ASPP Module
        self.aspp = ASPP(inCh=self.config.aspp_inch, 
                         outCh=self.config.aspp_outch)
        
        # Decoder
        self.decoder = Decoder(low_level_inch=self.config.low_level_inCh, 
                               low_level_outch=self.config.low_level_outCh, 
                               inCh=self.config.in_channels, 
                               outCh=self.config.out_channels,
                               n_classes=self.config.n_classes)
        
    def forward(self, x):
        # Extract features from base network
        base_out, low_level_features = self.base(x)
        
        # Pool base network output using Atrous Spatial Pyramid Pooling
        aspp_out = self.aspp(base_out)
        
        # Use decoder to obtain object boundaries
        decoder_out = self.decoder(aspp_out, low_level_features)
        
        # Upsample features from decoder by 4
        out = F.interpolate(decoder_out, scale_factor=4, mode='bilinear', align_corners=True)
        
        return out

# Config

In [8]:
class Config():
    """
    Configuration for training DeepLabV3+
    """
    def __init__(self):
        # MobileNetV2 parameters
        # ----------------------
        self.pretrained_weights = './MobileNetV2-Pretrained-Weights.pth.tar'
        # Conv and Inverted Residual Parameters: Table-2 (https://arxiv.org/pdf/1801.04381.pdf)
        self.t = [1, 1, 6, 6, 6, 6, 6, 6]  # t: expansion factor
        self.c = [32, 16, 24, 32, 64, 96, 160, 320]  # c: Output channels
        self.n = [1, 1, 2, 3, 4, 3, 3, 1]  # n: Number of times layer is repeated
        self.s = [2, 1, 2, 2, 2, 1, 2, 1]  # s: Stride
        self.r = [1, 1, 1, 1, 1, 1, 2, 2]  # r: Dilation (added to take care of dilation)
        # Width multiplier: Controls the width of the network
        self.alpha = 1 # Use multiples of 0.25, min=0.25, max=1.0
        
        # Data Augmentations 
        # ------------------
        self.img_mean = [0.485, 0.456, 0.406]
        self.img_std = [0.229, 0.224, 0.225]
        self.base_size = 640  # Scale
        self.image_size = 512  # Crop size
        
        # ASPP Parameters
        # ---------------
        self.aspp_inch = int(self.alpha * self.c[-1])  # Width multiplier * 320
        self.aspp_outch = int(self.alpha * 256)  # Width multiplier * 256
        
        # Decoder Parameters
        # ------------------
        self.n_classes = 19
        self.low_level_inCh = _make_divisible(self.alpha * self.c[2], 8) # Width multiplier * 32
        self.low_level_outCh = int(2 * self.low_level_inCh)  # 2 * low level features channels
        self.in_channels = _make_divisible(self.alpha * 256, 8) # Width multiplier * 256
        self.out_channels = _make_divisible(self.alpha * 256, 8) # Width multiplier * 256
        
        # Data
        # ----
        self.dataset_root = './cityscapes'
        
        # Training config
        # ---------------
        self.use_gpu = True
        self.batch_size = 8
        self.start_epoch = 0
        self.num_epochs = 2
        self.power = 0.9 # Learning rate policy multiplier
        self.lr = 0.0001 # Learning rate 
        self.lr_multiplier = 0.9 # Learning rate decay
        self.device_id = 0 
        self.device = 'cuda:' + str(self.device_id) if self.device_id else 'cpu'
        
        # Terminal display
        # ----------------
        self.display_interval = 100
        
        # Checkpoint config
        # -----------------
        self.best_acc = 0
        self.start_epoch = 0
        self.start_from = None # Use None if training from epoch 0
        self.checkpoint_path = './checkpoints'
        self.load_best_model = False
        
config = Config()

# Trainer

**Learning Rate Policy**

$\text{mult} = (1-\frac{\text{iter}}{\text{max iter}})^p$, where $p=0.9$

In [9]:
class Trainer(object):
    """
    Trainer for DeepLabV3+ with (modified) MobileNetV2 base
    """
    def __init__(self, opt):
        self.opt = opt
        
        # Start training
        self.start()
        
    # Helpers
    @staticmethod
    def get_optimizer(opt, net):
        optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, net.parameters()), 
                                     lr=opt.lr)
        return optimizer
    
    @staticmethod
    def decay_learning_rate(opt, optimizer, epoch):
        """
        Adjust learning rate at each epoch as per policy stated in DeepLabV3 
        paper (page 5)
        """
        lr = opt.lr * (1 - float(epoch) / opt.num_epochs) ** opt.power
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print('Learning rate updated to %f' % (lr))
    
    @staticmethod
    def accuracy(scores, targets):
        """
        scores: PyTorch Tensor, output of DeepLabV3+ model, [M, 19, 1024, 2048]
        targets: PyTorch Tensor, labelIds of shape [M, 1024, 2048]
        """
        # Get indices maximum values
        preds = torch.argmax(scores, dim=1) # size: [M, 1024, 2048]
        
        # Compute element wise equality and number of elements
        correct = torch.eq(preds, targets) # targets size: [M, 1024, 2048]
        num_elements = correct.numel()
        
        # Total correct
        tot_correct = torch.sum(correct)
        
        return tot_correct.float().item() * 100.0 / num_elements
    
    def create_model(self):
        info = {}

        # DeepLabV3+ and its optimizer
        deeplab = DeepLabV3Plus(self.opt)

        optimizer = self.get_optimizer(self.opt, deeplab)
        
        if self.opt.start_from:
            if self.opt.load_best_model == 1:
                model_path = os.path.join(self.opt.checkpoint_path, 'MobileNetV2_DeepLabV3Plus.pth.tar')
            else:
                epoch = self.opt.start_from
                model_path = os.path.join(self.opt.checkpoint_path, 
                                          'MobileNetV2_DeepLabV3Plus_{}.pth.tar'.format(epoch))

            # Load checkpoint
            checkpoint = torch.load(model_path)
            info['epoch'] = checkpoint['epoch'] + 1
            info['best_accuracy'] = checkpoint['accuracy']

            # Load state dicts for encoder, decoder, and their optimizers
            deeplab.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
                
            # Reference: https://github.com/pytorch/pytorch/issues/2830
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.opt.device)

        return deeplab, optimizer, info
    
    def train(self, train_loader, model, loss_fn, optimizer, epoch):
        # Display string
        display = """>>> step: {}/{} (epoch: {}), loss: {ls.val:f}, avg loss: {ls.avg:f}, 
        time/batch: {proc_time.val:.3f}, avg time/batch: {proc_time.avg:.3f}, acc: {acc.val:f}"""
        
        # Training mode
        model.train()
        
        # Stats
        batch_time = AverageMeter() # Forward propagation + back propatation time
        losses = AverageMeter() # Loss 
        accs = AverageMeter() # Accuracy
        
        start = time.time()
        
        # Training loop for one epoch
        for i, batch in enumerate(train_loader):
            
            imgs = batch['image']
            masks = batch['label']
            
            batch_size = imgs.size(0)

            # Using CUDA as default
            imgs = imgs.to(self.opt.device)
            masks = masks.long().to(self.opt.device)

            # Forward pass
            logits = model(imgs)
            
            # Compute loss
            loss = loss_fn(logits.to(self.opt.device), masks)
            
            # Compute accuracy
            acc = self.accuracy(logits.cpu(), masks.cpu())
            
            # Backward propagation and update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Update metrics
            losses.update(loss.item(), batch_size)
            accs.update(acc, batch_size)
            batch_time.update(time.time() - start)
            start = time.time() # Restart timer

            if i % self.opt.display_interval == 0 and i != 0:
                print(display.format(i, len(train_loader), epoch, ls=losses, 
                                     proc_time=batch_time, acc=accs))
                
        # Average Accuracy
        show = '>>> epoch: {}, avg training loss: {ls.avg:f}, avg training accuracy: {acc.avg:f}'
        print(show.format(epoch, ls=losses, acc=accs))
        
    def validate(self, val_loader, model, loss_fn, epoch):
        # Display string
        display = """>>> step: {}/{} (epoch: {}), loss: {ls.val:f}, avg loss: {ls.avg:f}, 
        time/batch: {proc_time.val:.3f}, avg time/batch: {proc_time.avg:.3f}, acc: {acc.val:f}"""

        # Stats
        batch_time = AverageMeter() # Forward propagation
        losses = AverageMeter() # Loss
        accs = AverageMeter() # Accuracy

        # Evaluation mode
        model.eval()
        
        start = time.time()

        # Validation loop for one epoch
        for i, batch in enumerate(val_loader):
            
            imgs = batch['image']
            masks = batch['label']
            
            batch_size = imgs.size(0)

            # Using CUDA as default
            imgs = imgs.to(self.opt.device)
            masks = masks.long().to(self.opt.device)

            # Forward pass
            logits = model(imgs)
            
            # Compute loss
            loss = loss_fn(logits.to(self.opt.device), masks)
            
            # Compute accuracy
            acc = self.accuracy(logits.cpu(), masks.cpu())
            
            # Update metrics
            losses.update(loss.item(), batch_size)
            accs.update(acc, batch_size)
            batch_time.update(time.time() - start)
            
            start = time.time() # Restart timer

            if i % self.opt.display_interval == 0 and i != 0:
                print(display.format(i, len(val_loader), epoch, ls=losses, 
                                     proc_time=batch_time, acc=accs))
                
        # Average Accuracy
        show = '>>> epoch: {}, avg validation loss: {ls.avg:f}, avg validation accuracy: {acc.avg:f}'
        print(show.format(epoch, ls=losses, acc=accs))
        
        return accs.avg, losses.avg
    
    def test(self):
        """
        Test functionality seprately coded for the App.
        """
        raise NotImplementedError
        
    def save_checkpoint(self, epoch, best_acc, val_avg_loss, model, optimizer, best_flag=False):
        if not os.path.exists(self.opt.checkpoint_path):
            os.makedirs(self.opt.checkpoint_path)
            
        checkpoint_name = 'MobileNetV2_DeepLabV3Plus_{}.pth.tar'.format(epoch)
            
        state = {
            'epoch': epoch,
            'accuracy':  best_acc, # Best average accuracy on validation data so far
            'loss': val_avg_loss,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict()}

        torch.save(state, os.path.join(self.opt.checkpoint_path, checkpoint_name))
        
        if best_flag:
            best_checkpoint_name = 'MobileNetV2_DeepLabV3Plus.pth.tar'
            torch.save(state, os.path.join(self.opt.checkpoint_path, best_checkpoint_name))
            
    def start(self):
        
        # Create model
        deeplab, optimizer, info = self.create_model()
        
        # Loss Function
        loss_function = nn.CrossEntropyLoss(ignore_index=255).to(self.opt.device)
        
        if self.opt.use_gpu:
            deeplab = deeplab.to(self.opt.device)
            loss_function = loss_function.to(self.opt.device)
            
        # Data Transforms: train, val and test
        train_transforms = transforms.Compose([
            RandomHorizontalFlip(),
            RandomScaleCrop(base_size=self.opt.base_size, crop_size=self.opt.image_size),
            RandomGaussianBlur(),
            Normalize(mean=self.opt.img_mean, std=self.opt.img_std),
            ToTensor()
        ])
        
        val_transforms = transforms.Compose([
            FixedScaleCrop(crop_size=self.opt.image_size),
            Normalize(mean=self.opt.img_mean, std=self.opt.img_std), 
            ToTensor()
        ])
        
        test_transforms = transforms.Compose([
            FixedResize(size=self.opt.image_size),
            Normalize(mean=self.opt.img_mean, std=self.opt.img_std), 
            ToTensor()
        ])
        
        # Data loaders
        train_data = Cityscapes(self.opt.dataset_root, split='train', mode='gtFine', target_type='semantic', 
                                transform=train_transforms)
        train_loader = DataLoader(train_data, batch_size=self.opt.batch_size, shuffle=True)
        
        val_data = Cityscapes(self.opt.dataset_root, split='val', mode='gtFine', target_type='semantic', 
                              transform=val_transforms)
        val_loader = DataLoader(val_data, batch_size=self.opt.batch_size, shuffle=True)
        
        # Start training: Train for epochs
        start_epoch = info.get('epoch', 0) if info.get('epoch', 0) else self.opt.start_epoch
        best_acc = info.get('best_accuracy', 0) if info.get('best_accuracy', 0) else self.opt.best_acc
        
        # Train for epochs
        for epoch in range(start_epoch, self.opt.num_epochs):
            
            # One epoch training
            self.train(train_loader=train_loader, model=deeplab, loss_fn=loss_function, optimizer=optimizer, 
                       epoch=epoch)

            # One epoch validation
            val_acc, val_loss = self.validate(val_loader=val_loader, model=deeplab, loss_fn=loss_function, 
                                              epoch=epoch)
            
            # Decay learning rate after each epoch as per policy
            self.decay_learning_rate(self.opt, optimizer, epoch)
            
            # Check for best accuracy
            best_flag = val_acc > best_acc
            best_acc = max(val_acc, best_acc)

            # Save checkpoint
            self.save_checkpoint(epoch, best_acc, val_loss, deeplab, optimizer, best_flag=best_flag)

# Trial Run

In [10]:
Trainer(config)

>>> step: 100/372 (epoch: 0), loss: 0.497480, avg loss: 0.921589, 
        time/batch: 1.547, avg time/batch: 1.609, acc: 78.832960
>>> step: 200/372 (epoch: 0), loss: 0.555770, avg loss: 0.723667, 
        time/batch: 1.576, avg time/batch: 1.604, acc: 65.876341
>>> step: 300/372 (epoch: 0), loss: 0.589132, avg loss: 0.644243, 
        time/batch: 2.012, avg time/batch: 1.610, acc: 71.677637
>>> epoch: 0, avg training loss: 0.603462, avg training accuracy: 76.935045
>>> epoch: 0, avg validation loss: 0.359429, avg validation accuracy: 77.333437
Learning rate updated to 0.000100
>>> step: 100/372 (epoch: 1), loss: 0.390489, avg loss: 0.425153, 
        time/batch: 1.732, avg time/batch: 1.643, acc: 78.600502
>>> step: 200/372 (epoch: 1), loss: 0.485446, avg loss: 0.404533, 
        time/batch: 1.558, avg time/batch: 1.616, acc: 79.573870
>>> step: 300/372 (epoch: 1), loss: 0.296199, avg loss: 0.391227, 
        time/batch: 1.527, avg time/batch: 1.613, acc: 84.401560
>>> epoch: 1, avg 

<__main__.Trainer at 0x7fd8c3d3bb38>