# Fully Convolutional Network (FCN-8s) for Semantic Segmentation

![image.png](imgs/2.png)

![image.png](imgs/3.png)

In [None]:
import os
import os.path as osp
import pytz
import torch

import warnings
# Ignore all warnings for the rest of the session.
warnings.filterwarnings(action='ignore')

# Solver settings keyed by configuration id. Id 1 mirrors the original work
# at https://github.com/shelhamer/fcn.berkeleyvision.org
configurations = {
    1: {
        'max_iteration': 100000,
        'lr': 1e-10,
        'momentum': 0.99,
        'weight_decay': 5e-4,
        'interval_validate': 4000,
    },
}

In [None]:
from types import SimpleNamespace
# Run options shared by the cells below. An empty ``resume`` means the model
# is initialised from pretrained FCN-16s weights instead of a checkpoint.
opts = SimpleNamespace(cfg=configurations[1], resume='')
print(opts.cfg)

In [None]:
from utils import get_log_dir
# Output directory for this run. get_log_dir is a project helper that
# presumably derives the directory from the model tag ('vgg8s'), the
# configuration id and its values — confirm in utils.
opts.out = get_log_dir('vgg8s', 1, opts.cfg)
print(opts.out)

In [None]:
# Physical GPU index exposed to this process. The variable is set before any
# CUDA call below so it takes effect.
# NOTE(review): on a machine with fewer than two GPUs, index 1 hides every
# device, so `cuda` becomes False and everything silently runs on CPU.
gpu = 1
os.environ['CUDA_VISIBLE_DEVICES'] = '{}'.format(gpu)
cuda = torch.cuda.is_available()
print('Cuda: %s' % cuda)

# Despite its name, opts.cuda holds a device string ('cuda' or 'cpu').
if cuda:
    opts.cuda = 'cuda'
else:
    opts.cuda = 'cpu'
opts.mode, opts.backbone, opts.fcn = 'train', 'vgg', '8s'

## Pascal VOC dataset — downloaded to the directory given by `root`

In [None]:
# Directory holding the Pascal VOC data; passed to Pascal_Data below.
root = './data/Pascal_VOC'
print(root)

In [None]:
from data_loader import Pascal_Data
# Worker processes are only used when CUDA is available.
kwargs = {'num_workers': 4} if cuda else {}

# batch_size=1: presumably because VOC images differ in size and cannot be
# stacked into one batch — confirm against Pascal_Data.
# The backbone comes from opts so the dataset preprocessing always matches
# the model configured above instead of a second hard-coded copy.
train_loader = torch.utils.data.DataLoader(
    Pascal_Data(root, image_set='train', backbone=opts.backbone),
    batch_size=1, shuffle=True, **kwargs)
val_loader = torch.utils.data.DataLoader(
    Pascal_Data(root, image_set='val', backbone=opts.backbone),
    batch_size=1, shuffle=False, **kwargs)

# [train, val] pair handed to the Trainer.
data_loader = [train_loader, val_loader]

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
# Grab a single training batch to sanity-check shapes and the inverse
# transform used for display.
data, target = next(iter(train_loader))
print(data.shape)
print(target.shape)
# The original bare `data.min()` was not the cell's last expression, so its
# value was silently discarded; print the input range explicitly instead.
print('data range: [{:.3f}, {:.3f}]'.format(data.min().item(), data.max().item()))
data_show, label_show = train_loader.dataset.untransform(
    data[0].cpu().clone(), target[0].cpu().clone())

plt.imshow(data_show)
plt.show()

def imshow_label(label_show, class_names=None):
    """Display a label map with one discrete color per class plus a legend.

    Args:
        label_show: map of integer class indices as returned by the dataset's
            ``untransform`` (presumably H x W — confirm).
        class_names: sequence of class names indexed by label value. Defaults
            to ``train_loader.dataset.class_names``.

    Label values outside ``0..len(class_names) - 1`` (e.g. an ignore label)
    are drawn with the colormap's under/over colors.
    """
    import matplotlib.colors as mcolors
    import numpy as np

    if class_names is None:
        class_names = train_loader.dataset.class_names
    n_classes = len(class_names)

    # jet colormap with its first entry forced to black, so label 0 is black.
    base = plt.cm.jet
    colors = [base(i) for i in range(base.N)]
    colors[0] = (0.0, 0.0, 0.0, 1.0)
    cmap = mcolors.LinearSegmentedColormap.from_list('Custom cmap', colors, base.N)

    # n + 1 edges at -0.5, 0.5, ..., n - 0.5 give exactly n bins centred on the
    # integer labels. Every class, including the last one, gets its own color,
    # and the colorbar ticks land on the bin centres.
    bounds = np.arange(n_classes + 1) - 0.5
    norm = mcolors.BoundaryNorm(bounds, cmap.N)
    plt.imshow(label_show, cmap=cmap, norm=norm)
    cbar = plt.colorbar(ticks=np.arange(n_classes))
    cbar.ax.set_yticklabels(class_names)
    plt.show()


imshow_label(label_show)


## FCN-8s model

In [None]:
import numpy as np
import torch.nn as nn

class FCN8s(nn.Module):
    """FCN-8s semantic-segmentation network on a VGG-16-style backbone.

    fc6/fc7 are implemented as convolutions, so the network accepts inputs of
    arbitrary spatial size. Coarse scores from the last layer are upsampled
    and fused with scores predicted from pool4 (1/16 resolution) and pool3
    (1/8 resolution), then upsampled 8x and cropped to the input size.

    Args:
        n_class: number of output score channels. The default 21 presumably
            corresponds to the Pascal VOC classes plus background used in
            this notebook.
    """

    def __init__(self, n_class=21):
        super(FCN8s, self).__init__()
        # conv1. The large padding enlarges the feature maps so that the
        # later 7x7 fc6 and the crops in forward() have room; forward() crops
        # the result back to the input size.
        self.conv1_1 = nn.Conv2d(3, 64, 3, padding=100)
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
        self.relu1_2 = nn.ReLU(inplace=True)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/2

        # conv2
        self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
        self.relu2_2 = nn.ReLU(inplace=True)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/4

        # conv3
        self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_2 = nn.ReLU(inplace=True)
        self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
        self.relu3_3 = nn.ReLU(inplace=True)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/8

        # conv4
        self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_2 = nn.ReLU(inplace=True)
        self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu4_3 = nn.ReLU(inplace=True)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/16

        # conv5
        self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_2 = nn.ReLU(inplace=True)
        self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
        self.relu5_3 = nn.ReLU(inplace=True)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)  # 1/32

        # fc6 as a 7x7 convolution
        self.fc6 = nn.Conv2d(512, 4096, 7)
        self.relu6 = nn.ReLU(inplace=True)
        self.drop6 = nn.Dropout2d()

        # fc7 as a 1x1 convolution
        self.fc7 = nn.Conv2d(4096, 4096, 1)
        self.relu7 = nn.ReLU(inplace=True)
        self.drop7 = nn.Dropout2d()

        # 1x1 classifiers on fc7, pool3 and pool4 features
        self.score_fr = nn.Conv2d(4096, n_class, 1)
        self.score_pool3 = nn.Conv2d(256, n_class, 1)
        self.score_pool4 = nn.Conv2d(512, n_class, 1)

        # Learnable upsampling layers (2x, 2x, 8x), bilinear-initialised below.
        self.upscore2 = nn.ConvTranspose2d(n_class,
                                           n_class,
                                           4,
                                           stride=2,
                                           bias=False)
        self.upscore_pool4 = nn.ConvTranspose2d(n_class,
                                                n_class,
                                                4,
                                                stride=2,
                                                bias=False)
        self.upscore8 = nn.ConvTranspose2d(n_class,
                                           n_class,
                                           16,
                                           stride=8,
                                           bias=False)

        self._initialize_weights()

    def _initialize_weights(self):
        """Zero every Conv2d weight/bias and make every ConvTranspose2d bilinear.

        In this notebook the zeroed convolutions are afterwards overwritten by
        ``copy_params_from_fcn16s`` or a resumed checkpoint. Layers without a
        counterpart (e.g. score_pool3) keep their zero initialisation.
        """
        from models.vgg.helpers import get_upsampling_weight
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.weight.data.zero_()
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.ConvTranspose2d):
                assert m.kernel_size[0] == m.kernel_size[1]
                initial_weight = get_upsampling_weight(m.in_channels,
                                                       m.out_channels,
                                                       m.kernel_size[0])
                m.weight.data.copy_(initial_weight)

    def forward(self, x, debug=False):
        """Compute per-pixel class scores.

        Args:
            x: input batch, indexed as (N, C, H, W); conv1_1 expects C == 3.
            debug: if True, print the shape of every intermediate tensor.

        Returns:
            Score tensor with ``n_class`` channels and the same spatial size
            as ``x``.
        """
        def trace(t, tag=None):
            # Print the shape (optionally labelled) when debugging; pass t through.
            if debug:
                print(t.shape if tag is None else '{}: {}'.format(tag, t.shape))
            return t

        h = trace(x)

        h = trace(self.relu1_1(self.conv1_1(h)))
        h = trace(self.relu1_2(self.conv1_2(h)))
        h = trace(self.pool1(h))

        h = trace(self.relu2_1(self.conv2_1(h)))
        h = trace(self.relu2_2(self.conv2_2(h)))
        h = trace(self.pool2(h))

        h = trace(self.relu3_1(self.conv3_1(h)))
        h = trace(self.relu3_2(self.conv3_2(h)))
        h = trace(self.relu3_3(self.conv3_3(h)))
        pool3 = h = trace(self.pool3(h), 'pool3')  # 1/8

        h = trace(self.relu4_1(self.conv4_1(h)))
        h = trace(self.relu4_2(self.conv4_2(h)))
        h = trace(self.relu4_3(self.conv4_3(h)))
        pool4 = h = trace(self.pool4(h), 'pool4')  # 1/16

        h = trace(self.relu5_1(self.conv5_1(h)))
        h = trace(self.relu5_2(self.conv5_2(h)))
        h = trace(self.relu5_3(self.conv5_3(h)))
        h = trace(self.pool5(h))  # 1/32

        h = trace(self.relu6(self.fc6(h)))
        h = trace(self.drop6(h))

        h = trace(self.relu7(self.fc7(h)))
        h = trace(self.drop7(h))

        h = trace(self.score_fr(h))
        upscore2 = trace(self.upscore2(h), 'upscore2')  # 1/16

        # Crop the pool4 scores to the upsampled map before fusing. The fixed
        # offsets (5 here, 9 and 31 below) compensate for the padding=100 in
        # conv1_1; values follow the reference FCN-8s.
        score_pool4 = trace(self.score_pool4(pool4), 'score_pool4')
        score_pool4c = trace(
            score_pool4[:, :, 5:5 + upscore2.size(2), 5:5 + upscore2.size(3)],
            'score_pool4c')  # 1/16

        h = trace(upscore2 + score_pool4c, 'upscore2+score_pool4c')
        upscore_pool4 = trace(self.upscore_pool4(h), 'upscore_pool4')  # 1/8

        score_pool3 = trace(self.score_pool3(pool3), 'score_pool3')
        score_pool3c = trace(
            score_pool3[:, :, 9:9 + upscore_pool4.size(2),
                        9:9 + upscore_pool4.size(3)],
            'score_pool3c')  # 1/8

        h = trace(upscore_pool4 + score_pool3c,
                  'upscore_pool4+score_pool3c')  # 1/8

        h = trace(self.upscore8(h), 'upscore8')
        # Crop the 8x-upsampled scores back to the input's spatial size.
        return trace(h[:, :, 31:31 + x.size(2), 31:31 + x.size(3)].contiguous(),
                     'upscore8 rearranged')

    def copy_params_from_fcn16s(self, fcn16s):
        """Copy weights from an FCN-16s model into the same-named layers.

        Every direct child of ``fcn16s`` that has a ``weight`` and a
        same-named, weighted layer on this model is copied, together with its
        bias if it has one. Parameter-free children (ReLU, pooling, dropout)
        and layers FCN-8s does not have are skipped.

        Raises:
            ValueError: if a same-named layer has a different weight or bias
                shape, or this model's layer lacks a bias the source has.
        """
        for name, src in fcn16s.named_children():
            dst = getattr(self, name, None)
            # Skip layers absent from FCN-8s and layers without parameters,
            # instead of swallowing every exception.
            if (getattr(dst, 'weight', None) is None
                    or getattr(src, 'weight', None) is None):
                continue
            if src.weight.size() != dst.weight.size():
                raise ValueError('weight shape mismatch for {}: {} vs {}'.format(
                    name, tuple(src.weight.size()), tuple(dst.weight.size())))
            dst.weight.data.copy_(src.weight.data)

            src_bias = getattr(src, 'bias', None)
            if src_bias is not None:
                dst_bias = getattr(dst, 'bias', None)
                if dst_bias is None or src_bias.size() != dst_bias.size():
                    raise ValueError('bias shape mismatch for {}'.format(name))
                dst_bias.data.copy_(src_bias.data)

In [None]:
# https://github.com/shelhamer/fcn.berkeleyvision.org/blob/master/surgery.py
def get_upsampling_weight(in_channels, out_channels, kernel_size):
    """Make a 2D bilinear kernel suitable for upsampling.

    Builds a ``nn.ConvTranspose2d`` weight of shape
    ``(in_channels, out_channels, kernel_size, kernel_size)``. Channel ``i``
    maps onto channel ``i`` through a bilinear interpolation filter, and all
    cross-channel entries are zero. When the channel counts differ, only the
    first ``min(in_channels, out_channels)`` channel pairs receive the filter.

    Args:
        in_channels: number of input channels of the transposed convolution.
        out_channels: number of output channels of the transposed convolution.
        kernel_size: side length of the square kernel.

    Returns:
        float32 torch tensor of the shape above.
    """
    factor = (kernel_size + 1) // 2
    # Odd kernels have an integer centre; even kernels centre between pixels.
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    rows, cols = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(rows - center) / factor) * (1 - abs(cols - center) / factor)

    weight = np.zeros((in_channels, out_channels, kernel_size, kernel_size),
                      dtype=np.float64)
    # Place the filter on the channel diagonal. Using min() keeps this valid
    # when in_channels != out_channels; plain fancy indexing would fail there.
    diag = np.arange(min(in_channels, out_channels))
    weight[diag, diag, :, :] = filt
    return torch.from_numpy(weight).float()

## Build FCN-8s and initialize it from pretrained FCN-16s weights

In [None]:
# Build the 21-class FCN-8s and move it to the selected device
# (nn.Module.to moves parameters in place and returns the same module).
model = FCN8s(21).to(opts.cuda)
model  # last expression: shown as the cell output, as before

In [None]:
# Show the module structure (the notebook displays the cell's last expression).
model

In [None]:
# Push one training batch through the freshly initialised model to check that
# the forward pass runs. The model is still in train mode (dropout active),
# which is fine for a shape check.
iter_loader = iter(train_loader)
data, target = next(iter_loader)
with torch.no_grad():
    data = data.to(opts.cuda)
    output = model(data, debug=False)

In [None]:
# Output should match the input's spatial size, with one channel per class.
print('input:  {}'.format(data.shape))
print('output:  {}'.format(output.shape))

In [None]:
# Next batch, printing every intermediate shape inside the network.
data, target = next(iter_loader)
with torch.no_grad():
    data = data.to(opts.cuda)
    output = model(data, debug=True)

In [None]:
# One more batch (presumably a different image size) to compare how the
# intermediate shapes and crops change.
data, target = next(iter_loader)
with torch.no_grad():
    data = data.to(opts.cuda)
    output = model(data, debug=True)

In [None]:
# Initialise the model: resume a previous run if opts.resume names a
# checkpoint, otherwise copy the matching layers of the pretrained FCN-16s.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
if opts.resume:
    # Previously referenced an undefined name `resume` (NameError).
    print('Loading checkpoint from: ' + opts.resume)
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    checkpoint = torch.load(opts.resume, map_location=opts.cuda)
    model.load_state_dict(checkpoint['model_state_dict'])
else:
    from models.vgg.fcn16s import FCN as FCN16
    fcn16s = FCN16()
    fcn16s_weights = FCN16.download()  # Original FCN16 pretrained model
    # fcn16s lives on the CPU, so load its weights there.
    fcn16s.load_state_dict(torch.load(fcn16s_weights, map_location='cpu'))
    model.copy_params_from_fcn16s(fcn16s)

In [None]:
%matplotlib inline
from trainer import Trainer

In [None]:
# The Trainer receives the [train, val] loaders and the run options
# (config, device string, output directory, resume path, ...).
# NOTE(review): no model is passed here — confirm Trainer uses the FCN8s
# instance built and initialised above; otherwise the FCN-16s weights copied
# into `model` are never used for training.
trainer = Trainer(data_loader, opts)

In [None]:
# Validation interval in iterations (one epoch, i.e. len(train_loader), if
# the config omits it), followed by the output directory, one per line.
print(opts.cfg.get('interval_validate', len(train_loader)), opts.out, sep='\n')

In [None]:
# Continue the epoch/iteration counters from the loaded checkpoint, or start
# from zero for a fresh run.
if opts.resume:
    start_epoch, start_iteration = checkpoint['epoch'], checkpoint['iteration']
else:
    start_epoch = start_iteration = 0

In [None]:
# Seed the trainer's counters, then launch the training loop.
trainer.epoch, trainer.iteration = start_epoch, start_iteration
trainer.Train()