In [1]:
%matplotlib inline
import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l  # Refer to https://d2l.ai/

In [2]:
class double_conv2d_bn(nn.Module):
    """Two stacked Conv2d -> BatchNorm2d -> ReLU stages.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Both share ``kernel_size``, ``strides`` and
    ``padding``.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        super().__init__()
        # Settings common to both convolutions.
        conv_kwargs = dict(kernel_size=kernel_size, stride=strides,
                           padding=padding, bias=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Run ``x`` through conv1/bn1/ReLU and then conv2/bn2/ReLU."""
        first_stage = F.relu(self.bn1(self.conv1(x)))
        return F.relu(self.bn2(self.conv2(first_stage)))
    
class deconv2d_bn(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> ReLU stage.

    With the default ``kernel_size=2`` and ``strides=2`` the transposed
    convolution doubles the spatial size of its input.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=strides,
            bias=True,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply the transposed convolution, batch norm and ReLU to ``x``."""
        upsampled = self.conv1(x)
        normalized = self.bn1(upsampled)
        return F.relu(normalized)
    
class Unet(nn.Module):
    """U-Net style encoder/decoder built from the double-conv and deconv blocks.

    The encoder applies four double-conv blocks, each followed by a 2x2
    max-pool, then a 1024-channel bottleneck. The decoder upsamples four times
    with stride-2 transposed convolutions, concatenating the matching encoder
    feature map along the channel axis before each double-conv block.

    Args:
        num_classes: Number of output channels produced by the final
            convolution.
        in_channels: Number of channels of the input image. Defaults to 3,
            which reproduces the previously hard-coded behaviour.

    Note:
        Input height and width must be divisible by 16. The four 2x2 max-pools
        floor odd sizes while the transposed convolutions exactly double, so
        any other size makes a skip-connection ``torch.cat`` fail with a shape
        mismatch.

        ``forward`` returns the raw output of the last convolution; no
        sigmoid/softmax is applied.
    """

    def __init__(self, num_classes, in_channels=3):
        super().__init__()
        # Contracting path and bottleneck.
        self.layer1_conv = double_conv2d_bn(in_channels, 64)
        self.layer2_conv = double_conv2d_bn(64, 128)
        self.layer3_conv = double_conv2d_bn(128, 256)
        self.layer4_conv = double_conv2d_bn(256, 512)
        self.layer4_drop = nn.Dropout2d(0.5)
        self.layer5_conv = double_conv2d_bn(512, 1024)
        self.layer5_drop = nn.Dropout2d(0.5)
        # Expanding path: each block consumes [upsampled, skip] concatenated,
        # hence twice as many input channels as output channels.
        self.layer6_conv = double_conv2d_bn(1024, 512)
        self.layer7_conv = double_conv2d_bn(512, 256)
        self.layer8_conv = double_conv2d_bn(256, 128)
        self.layer9_conv = double_conv2d_bn(128, 64)
        # Final projection to per-class score maps.
        self.layer10_conv = nn.Conv2d(64, num_classes, kernel_size=3,
                                      stride=1, padding=1, bias=True)

        # Upsampling stages of the expanding path.
        self.deconv1 = deconv2d_bn(1024, 512)
        self.deconv2 = deconv2d_bn(512, 256)
        self.deconv3 = deconv2d_bn(256, 128)
        self.deconv4 = deconv2d_bn(128, 64)

    def forward(self, x):
        """Return a ``num_classes``-channel score map with the spatial size of ``x``."""
        # Encoder: keep each pre-pool feature map for the skip connections.
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1, 2)

        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2, 2)

        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3, 2)

        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4, 2)
        drop4 = self.layer4_drop(pool4)

        # Bottleneck.
        conv5 = self.layer5_conv(drop4)
        drop5 = self.layer5_drop(conv5)

        # Decoder: upsample, concatenate the skip feature map, convolve.
        convt1 = self.deconv1(drop5)                 # C = 512
        concat1 = torch.cat([convt1, conv4], dim=1)  # C = 1024
        conv6 = self.layer6_conv(concat1)            # C = 512

        convt2 = self.deconv2(conv6)                 # C = 256
        concat2 = torch.cat([convt2, conv3], dim=1)  # C = 512
        conv7 = self.layer7_conv(concat2)            # C = 256

        convt3 = self.deconv3(conv7)                 # C = 128
        concat3 = torch.cat([convt3, conv2], dim=1)  # C = 256
        conv8 = self.layer8_conv(concat3)            # C = 128

        convt4 = self.deconv4(conv8)                 # C = 64
        concat4 = torch.cat([convt4, conv1], dim=1)  # C = 128
        conv9 = self.layer9_conv(concat4)            # C = 64
        return self.layer10_conv(conv9)              # C = num_classes
    

# Smoke test: push one random batch through the network and check that the
# output has one channel per class at the input's spatial size.
net = Unet(num_classes=21)
inp = torch.rand(10, 3, 224, 224)
# no_grad avoids building an autograd graph that `outp` would otherwise keep
# alive in kernel memory for the rest of the session.
with torch.no_grad():
    outp = net(inp)
print(outp.shape)
assert outp.shape == (10, 21, 224, 224)

torch.Size([10, 21, 224, 224])


# Reading the Dataset

In [None]:
# Mini-batch size and crop size handed to the VOC loader;
# adjust to fit your requirements
batch_size = 64
crop_size = (224, 224)
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)

# Training

In [None]:
def train(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    """Train ``net`` on ``devices`` and plot loss/accuracy while training.

    Adapted from the multi-GPU training loop of the d2l book (Chapter 13).

    Args:
        net: Model to train. It is wrapped in ``nn.DataParallel`` and moved to
            ``devices[0]``.
        train_iter: Iterable of ``(features, labels)`` training batches; must
            support ``len()``.
        test_iter: Evaluation data passed to ``d2l.evaluate_accuracy_gpu``
            after every epoch.
        loss: Loss callable forwarded to ``d2l.train_batch_ch13``.
        trainer: Optimizer forwarded to ``d2l.train_batch_ch13``.
        num_epochs: Number of passes over ``train_iter``.
        devices: Devices used for ``nn.DataParallel``. The default is
            evaluated once, when this function is defined.

    Prints the last epoch's average loss, train accuracy and test accuracy,
    followed by a throughput figure in examples per second.
    """
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1.5],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    # Plot roughly five points per epoch. Clamp to 1 so that an iterator with
    # fewer than five batches does not trigger a modulo-by-zero.
    plot_every = max(1, num_batches // 5)
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % plot_every == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')

In [None]:
def loss(inputs, targets):
    """Per-example cross-entropy averaged over the two trailing axes.

    The unreduced cross-entropy is averaged over axis 1 twice, leaving one
    value per leading (batch) index. Assumes ``inputs`` is
    (batch, classes, H, W) and ``targets`` is (batch, H, W) -- TODO confirm
    against the data loader.
    """
    per_pixel = F.cross_entropy(inputs, targets, reduction='none')
    per_row = per_pixel.mean(1)
    return per_row.mean(1)

num_epochs, lr, wd, devices = 100, 1e-4, 1e-3, d2l.try_all_gpus()
# Adam is used without weight decay; `wd` is only needed by the SGD
# alternative: torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
trainer = torch.optim.Adam(net.parameters(), lr=lr)
train(net, train_iter, test_iter, loss, trainer, num_epochs, devices)
# Name the checkpoint after the configured epoch count so the file name stays
# truthful when num_epochs is changed (100 epochs -> 'unet_100.pth').
torch.save(net.state_dict(), f'unet_{num_epochs}.pth')

# Prediction

In [None]:
def predict(img):
    """Predict a class index for every pixel of one image.

    Args:
        img: Image tensor accepted by ``test_iter.dataset.normalize_image``
            (assumed channel-first (C, H, W); the visualisation code permutes
            it to (H, W, C) for display -- TODO confirm).

    Returns:
        A 2-D tensor on ``devices[0]`` holding the arg-max over the channel
        axis of the network output.
    """
    X = test_iter.dataset.normalize_image(img).unsqueeze(0)
    # Inference must not depend on dropout or on batch statistics of a single
    # image, so switch to eval mode and restore the previous mode afterwards.
    was_training = net.training
    net.eval()
    try:
        # No gradients are needed for prediction; skip building the graph.
        with torch.no_grad():
            pred = net(X.to(devices[0])).argmax(dim=1)
    finally:
        net.train(was_training)
    # Drop the batch axis of size 1 that was added above.
    return pred.reshape(pred.shape[1], pred.shape[2])

In [None]:
def label2image(pred):
    """Turn a tensor of class indices into colours from ``d2l.VOC_COLORMAP``.

    Each index in ``pred`` selects the corresponding row of the colormap,
    which is created on ``devices[0]``.
    """
    class_idx = pred.long()
    palette = torch.tensor(d2l.VOC_COLORMAP, device=devices[0])
    return palette[class_idx, :]

In [None]:
# Download/extract VOC2012 through d2l, then read its images and labels.
# The second argument False presumably selects the non-training split --
# confirm against d2l.read_voc_images.
voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
voc_eval_pairs = d2l.read_voc_images(voc_dir, False)
test_images, test_labels = voc_eval_pairs

In [None]:
# Show the first n test images: cropped input, cropped ground-truth label and
# the predicted segmentation.
n = 4
crop_rect = (0, 0, 320, 480)  # same crop rectangle for every image and label
imgs = []
for i in range(n):
    X = transforms.functional.crop(test_images[i], *crop_rect)
    truth = transforms.functional.crop(test_labels[i], *crop_rect)
    pred = label2image(predict(X))
    # Channel-first tensors are permuted to channel-last for display.
    imgs.extend([X.permute(1, 2, 0), truth.permute(1, 2, 0), pred.cpu()])
# `imgs` is interleaved as (input, label, prediction) per sample; regroup it
# so each of the 3 rows of the grid holds one kind of image.
inputs_row, labels_row, preds_row = imgs[::3], imgs[1::3], imgs[2::3]
d2l.show_images(inputs_row + labels_row + preds_row, 3, n, scale=2);