In [2]:
%matplotlib inline
import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l  # Refer to https://d2l.ai/

# Define your model here

In [None]:
# Minimal trainable baseline so the notebook runs end to end.
# An empty `nn.Sequential()` has no parameters, so the SGD optimizer in the
# training cell would raise "optimizer got an empty parameter list". Its
# identity output (3 channels) also could not be scored against the VOC labels.
# Replace this with a real segmentation network (e.g. an FCN). Whatever you use
# must map a (batch, 3, H, W) image tensor to (batch, num_classes, H, W)
# per-pixel class logits.
num_classes = 21  # PASCAL VOC: 20 object classes + background
net = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),  # padding=1 keeps H and W
    nn.ReLU(),
    nn.Conv2d(64, num_classes, kernel_size=1),   # per-pixel class logits
)

# Reading the dataset

In [None]:
# Mini-batch size and the (height, width) of the crops taken from each
# image/label pair -- adjust both to fit your requirements.
batch_size = 32
crop_size = (320, 480)
train_iter, test_iter = d2l.load_data_voc(batch_size, crop_size)

# Training

In [None]:
def loss(inputs, targets):
    """Per-example segmentation loss.

    Computes the unreduced cross-entropy for every pixel. It then averages over
    dimension 1 twice, which collapses the two spatial dimensions and leaves
    one loss value per example in the batch.
    """
    pixel_losses = F.cross_entropy(inputs, targets, reduction='none')
    per_row = pixel_losses.mean(1)
    return per_row.mean(1)

# Training hyperparameters.
num_epochs = 5
lr = 0.001
wd = 1e-3  # L2 weight decay applied by the optimizer
devices = d2l.try_all_gpus()

trainer = torch.optim.SGD(net.parameters(), weight_decay=wd, lr=lr)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)

# Prediction

In [None]:
def predict(img):
    """Predict a per-pixel class-index map for a single image.

    Args:
        img: one image tensor. It is preprocessed with the test dataset's own
            ``normalize_image`` so inference matches the training pipeline.
            Presumably a (3, H, W) image as returned by ``d2l.read_voc_images``;
            TODO confirm.

    Returns:
        A 2-D tensor of class indices (argmax over ``net``'s channel
        dimension) on ``devices[0]``.
    """
    X = test_iter.dataset.normalize_image(img).unsqueeze(0)  # add batch dim
    # Inference only: use eval-mode layers (dropout/BatchNorm, if any) and
    # skip building an autograd graph for the forward pass.
    net.eval()
    with torch.no_grad():
        pred = net(X.to(devices[0])).argmax(dim=1)
    return pred[0]  # drop the batch dim of size 1

In [None]:
def label2image(pred):
    """Convert a tensor of VOC class indices into colors.

    Every index in ``pred`` selects a row of ``d2l.VOC_COLORMAP``. The result
    has the shape of ``pred`` plus a trailing colormap-entry dimension and
    lives on ``devices[0]``.
    """
    palette = torch.tensor(d2l.VOC_COLORMAP, device=devices[0])
    return palette[pred.long()]

In [None]:
voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)

# (top, left, height, width) for transforms.functional.crop. This is derived
# from the training crop size instead of being hard-coded a second time, and
# is computed once because it does not change between examples.
crop_rect = (0, 0, *crop_size)

n, imgs = 4, []
for i in range(n):
    X = transforms.functional.crop(test_images[i], *crop_rect)
    pred = label2image(predict(X))
    imgs += [
        X.permute(1, 2, 0),
        transforms.functional.crop(test_labels[i], *crop_rect).permute(1, 2, 0),
        pred.cpu(),
    ]
# Regroup the interleaved list so row 1 shows the inputs, row 2 the
# ground-truth labels and row 3 the predictions.
d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2);