In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import VOCSegmentation
import numpy as np
import matplotlib.pyplot as plt
import VOCLoader  # import the custom dataset

# Data Preparation

In [None]:
# Build the training and test DataLoaders through the project-local VOCLoader module
# (32 samples per training batch, 1 sample per test batch).
# NOTE(review): later cells assume each batch is an (images, masks) pair with
# ImageNet-normalized images and masks stored as class_index / 255 — confirm in VOCLoader.
train_loader, test_loader = VOCLoader.load(train_batch_size=32, test_batch_size=1)

In [None]:
# train_dataset = VOCSegmentation(root="./data", year='2012', image_set='train')

# images, masks = train_dataset[0]

# fig, ax = plt.subplots(1, 2, figsize=(10, 5))

# ax[0].imshow(images)
# ax[1].imshow(masks)

# plt.show()

# Show Images and Labels

In [None]:
# Pull the first batch from the training loader to preview it.
# NOTE(review): whether this batch is random depends on VOCLoader's shuffle setting — confirm.
images, masks = next(iter(train_loader))

# Color palette for segmentation masks: VOC_COLORMAP[k] is the RGB color of class index k.
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

def mask2label(mask, num_classes=21):
    """Colorize a single-channel segmentation mask with VOC_COLORMAP.

    Args:
        mask: tensor (or array) of shape (1, H, W). Floating-point masks are
            interpreted as ``class_index / 255`` (the encoding the loader's
            masks use) and are scaled back to 0-255. Integer masks are taken
            to already hold class indices (e.g. an argmax over model logits)
            and are NOT rescaled.
        num_classes: number of leading palette entries to use. Values larger
            than the palette are clamped to its length.

    Returns:
        Float tensor of shape (3, H, W) with values in [0, 1]. Pixels equal to
        255 (the border / ignore label) are white; any other index that is not
        a colored class stays black.
    """
    if torch.is_tensor(mask):
        mask = mask.detach().cpu().numpy()
    else:
        mask = np.asarray(mask)
    mask = mask[0]  # remove the channel dimension -> (H, W)

    if np.issubdtype(mask.dtype, np.floating):
        # Undo the 1/255 scaling. Rounding keeps the decode exact even when the
        # float product is not exactly on an integer.
        mask = np.rint(mask * 255)
    # Integer inputs skip the scaling: multiplying class indices by 255 would
    # wrap around in uint8 and turn class 1 into the 255 border label.
    mask = mask.astype(np.uint8)

    # Lookup table: one RGB row per possible uint8 value.
    palette = np.zeros((256, 3), dtype=np.uint8)
    n_colors = max(0, min(num_classes, len(VOC_COLORMAP)))
    palette[:n_colors] = np.asarray(VOC_COLORMAP, dtype=np.uint8)[:n_colors]
    palette[255] = 255  # border

    rgb = palette[mask]  # (H, W, 3) uint8
    # Same result as torchvision's ToTensor on a uint8 HWC array:
    # channels first, float, scaled to [0, 1].
    rgb = torch.from_numpy(rgb).permute(2, 0, 1).contiguous()
    return rgb.to(torch.get_default_dtype()).div(255)


num_images = 4
images = images[:num_images]
masks = masks[:num_images]
plt.figure(figsize=(10, 6))

# Per-channel mean/std used to undo the loader's normalization, shaped
# (3, 1, 1) so they broadcast over a (3, H, W) image.
# NOTE(review): these must match the Normalize transform in VOCLoader — confirm.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# show images
for i, image in enumerate(images):
    # Reverse of the transformation used in the dataloader. Computed out of
    # place so the batch tensor is left untouched, and clamped to [0, 1] so
    # imshow does not have to clip float overshoot.
    restored = (image * norm_std + norm_mean).clamp(0, 1)
    npimg = restored.numpy()
    plt.subplot(2, num_images, i + 1)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')

# show labels (second row, colorized with mask2label)
labels = map(mask2label, masks)
for i, label in enumerate(labels):
    plt.subplot(2, num_images, i + 1 + num_images)
    plt.imshow(label.permute(1, 2, 0))
    plt.axis('off')

plt.show()

# FCN
See Long, Shelhamer & Darrell, "Fully Convolutional Networks for Semantic Segmentation": https://arxiv.org/abs/1411.4038

In [None]:
def bilinear_kernel(in_channels, out_channels, kernel_size):
    """Create a weight tensor filled with a 2-D bilinear (tent-shaped) filter.

    Args:
        in_channels: size of the first weight axis.
        out_channels: size of the second weight axis.
        kernel_size: side length of the square filter.

    Returns:
        Tensor of shape (in_channels, out_channels, kernel_size, kernel_size).
        The filter is written where index i of the first axis is paired with
        index i of the second axis; every other entry is zero.
    """
    half_width = (kernel_size + 1) // 2
    # Odd kernels peak on a sample; even kernels peak between the two middle samples.
    is_odd = kernel_size % 2 == 1
    midpoint = half_width - 1 if is_odd else half_width - 0.5

    # 1-D tent profile; the 2-D filter is its outer product with itself.
    positions = torch.arange(kernel_size)
    profile = 1 - torch.abs(positions - midpoint) / half_width
    kernel_2d = profile.reshape(-1, 1) * profile.reshape(1, -1)

    weight = torch.zeros((in_channels, out_channels,
                          kernel_size, kernel_size))
    in_idx, out_idx = range(in_channels), range(out_channels)
    weight[in_idx, out_idx, :, :] = kernel_2d
    return weight

In [None]:
# Define FCN model
# class FCN(nn.Module):
#     def __init__(self, num_classes):
#         super(FCN, self).__init__()
#         # Load the pre-trained ResNet-18 model
#         resnet18 = models.resnet18(pretrained=True)
#         features = list(resnet18.features.children())
#         self.features = nn.Sequential(*features[:-2])  # Extract features until the last max pooling layer
        
#         self.classifier = nn.Sequential(
#             nn.Conv2d(512, num_classes, kernel_size=1),
#             nn.ConvTranspose2d(num_classes, num_classes,
#                                 kernel_size=64, padding=16, stride=32)
#         )
            
#     def forward(self, x):
#         x = self.features(x)
#         x = self.classifier(x)
#         return x


# Backbone: pretrained ResNet-18 with its last two child modules dropped
# (in torchvision's ResNet these are the global average pool and the fc classifier).
resnet18 = models.resnet18(pretrained=True)
FCN = nn.Sequential(*list(resnet18.children())[:-2])

num_classes = 21  # Pascal VOC dataset has 21 classes

# Head: a 1x1 convolution maps the 512 backbone channels to per-class scores,
# then a stride-32 transposed convolution upsamples them.
FCN.add_module(
    'final_conv',
    nn.Conv2d(512, num_classes, kernel_size=1),
)
FCN.add_module(
    'transpose_conv',
    nn.ConvTranspose2d(num_classes, num_classes,
                       kernel_size=64, padding=16, stride=32),
)

# Start the upsampling layer from a bilinear interpolation filter.
W = bilinear_kernel(num_classes, num_classes, 64)
FCN.transpose_conv.weight.data.copy_(W);

# The Sequential above is the model; there is no separate class to instantiate.
model = FCN


# Loss: per-pixel cross-entropy (label 255 is ignored), averaged over the two
# spatial axes so one value per sample is returned.
def criterion(inputs, targets):
    per_pixel = F.cross_entropy(inputs, targets, reduction='none', ignore_index=255)
    return per_pixel.mean(1).mean(1)


# Plain SGD with weight decay.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, weight_decay=1e-3)

In [None]:
# Sanity-check the output resolution on a dummy input.
# Run in eval mode and without autograd: a training-mode forward pass on random
# noise would update the BatchNorm running statistics of the pretrained
# backbone and build an unused autograd graph.
X = torch.rand(size=(1, 3, 320, 480))
was_training = FCN.training
FCN.eval()
with torch.no_grad():
    out_shape = FCN(X).shape
FCN.train(was_training)  # restore the previous mode for the training cell

# Should be (1, num_classes, 320, 480): the stride-32 transposed convolution
# undoes the backbone's 32x downsampling.
out_shape

# Training

In [None]:
# %%time
num_epochs = 5
total_step = len(train_loader)
losses = []  # one entry per optimizer step; used by the loss-curve cell

# Log every 100 steps, or once per epoch when an epoch has fewer than 100
# steps (otherwise nothing would ever be printed).
log_every = max(1, min(100, total_step))

# Make sure BatchNorm layers are in training mode even if this cell is
# re-run after the evaluation cell switched the model to eval().
model.train()

for epoch in range(num_epochs):
    for i, (images, masks) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)
        # Masks are assumed to be floats holding class_index / 255 with a
        # channel axis of size 1; convert them to integer class labels.
        labels = (masks.squeeze(1) * 255).long()
        loss = criterion(outputs, labels).mean()  # Average mean loss

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)  # Save the loss

        if (i + 1) % log_every == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch + 1, num_epochs, i + 1, total_step, loss_value))

# Save the model
torch.save(model.state_dict(), 'fcn_resnet18.pth')

In [None]:
# Training loss, one point per optimizer step
plt.plot(losses)
plt.title('Loss curve')
plt.xlabel('Iteration')
plt.ylabel('Loss')

# Dashed red marker at the iteration count where each epoch finishes
for finished_epochs in range(1, num_epochs + 1):
    plt.axvline(x=total_step * finished_epochs, color='r', linestyle='--')

plt.show()

# Test

In [None]:
# Load the model
# Restores the weights written by the training cell ('fcn_resnet18.pth') into the
# existing `model`, which must have the same architecture as when it was saved.
model.load_state_dict(torch.load('fcn_resnet18.pth'))

In [None]:
model.eval()

num_images = 4
# NOTE(review): the samples come from train_loader (test_loader yields one
# image per batch), so this figure shows fit on training data rather than
# held-out performance — confirm this is intended.
images, masks = next(iter(train_loader))
images = images[:num_images]
masks = masks[:num_images]

# Only the forward pass needs to run without autograd.
with torch.no_grad():
    output = model(images)
_, predicted = torch.max(output, 1)  # per-pixel class index
# Add the channel dimension and encode the class indices like the loader's
# masks (class_index / 255), which is the float encoding mask2label decodes.
# Passing raw integer indices through a x255 rescale would scramble the colors.
predicted = predicted.unsqueeze(1).float() / 255

# Per-channel mean/std used to undo the loader's normalization, shaped
# (3, 1, 1) so they broadcast over a (3, H, W) image.
# NOTE(review): these must match the Normalize transform in VOCLoader — confirm.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

plt.figure(figsize=(10, 9))

# show images (row 1)
for i, image in enumerate(images):
    # Reverse of the transformation used in the dataloader, computed out of
    # place and clamped to the displayable [0, 1] range.
    npimg = (image * norm_std + norm_mean).clamp(0, 1).numpy()
    plt.subplot(3, num_images, i + 1)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')

# show ground truth (row 2)
labels = map(mask2label, masks)
for i, label in enumerate(labels):
    plt.subplot(3, num_images, i + 1 + num_images)
    plt.imshow(label.permute(1, 2, 0))
    plt.axis('off')

# show predictions (row 3)
predictions = map(mask2label, predicted)
for i, prediction in enumerate(predictions):
    plt.subplot(3, num_images, i + 1 + 2*num_images)
    plt.imshow(prediction.permute(1, 2, 0))
    plt.axis('off')

plt.show()