In [1]:
import glob
import os

import numpy as np
from PIL import Image

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torchvision.models import vgg16

In [3]:
# PASCAL VOC class names. Position i is the name of integer label i, and
# len(class_names) is used as the number of model output channels.
class_names = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat',
    'bottle', 'bus', 'car', 'cat', 'chair',
    'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'potted-plant', 'sheep', 'sofa', 'train',
    'tv/monitor',
]

# RGB colour that encodes each label in the colour-coded segmentation masks:
# row i is the colour used for class_names[i]. Despite its name this table
# maps index -> colour; create_label_mask inverts it pixel by pixel.
color_2_index = np.asarray([
    (0, 0, 0),        # 0  background
    (128, 0, 0),      # 1  aeroplane
    (0, 128, 0),      # 2  bicycle
    (128, 128, 0),    # 3  bird
    (0, 0, 128),      # 4  boat
    (128, 0, 128),    # 5  bottle
    (0, 128, 128),    # 6  bus
    (128, 128, 128),  # 7  car
    (64, 0, 0),       # 8  cat
    (192, 0, 0),      # 9  chair
    (64, 128, 0),     # 10 cow
    (192, 128, 0),    # 11 diningtable
    (64, 0, 128),     # 12 dog
    (192, 0, 128),    # 13 horse
    (64, 128, 128),   # 14 motorbike
    (192, 128, 128),  # 15 person
    (0, 64, 0),       # 16 potted-plant
    (128, 64, 0),     # 17 sheep
    (0, 192, 0),      # 18 sofa
    (128, 192, 0),    # 19 train
    (0, 64, 128),     # 20 tv/monitor
])

In [4]:
# Side length (pixels) of the square images and masks fed to the model.
img_size = 224
# Root of the PASCAL VOC 2012 devkit, relative to the notebook's working directory.
path = "data/VOCdevkit/VOC2012/"

In [5]:
class PascalVoc(Dataset):
    """PASCAL VOC 2012 semantic-segmentation dataset.

    Each item is built from one file in ``<path>/SegmentationClass/``; the
    matching photo is looked up in ``<path>/JPEGImages/`` by file stem.

    Args:
        path: Root of the VOC2012 devkit (the folder that contains
            ``SegmentationClass`` and ``JPEGImages``). A trailing slash is
            optional.
        img_size: Side length, in pixels, that both image and mask are
            resized to (items are always square).

    Items:
        ``(img, mask)`` where ``img`` is a float tensor of shape
        ``(img_size, img_size, 3)`` (channels LAST) normalised with the
        ImageNet mean/std, and ``mask`` is an int32 tensor of shape
        ``(img_size, img_size)`` holding class indices. Mask pixels whose
        colour is not in the palette (e.g. VOC's boundary "void" colour)
        are labelled 0.

    NOTE(review): the model consumes channels-first input, and
    ``nn.NLLLoss`` expects int64 targets, so a training loop presumably needs
    ``img.permute(0, 3, 1, 2)`` and ``mask.long()`` — confirm with callers.
    """

    def __init__(self, path, img_size):

        self.seg_folder = "SegmentationClass/"
        self.img_folder = "JPEGImages/"

        # Stored so items resolve against *this* root rather than a
        # module-level global that happens to share the name.
        self.path = path

        # glob() order is filesystem-dependent; sort so index -> file is
        # reproducible across machines and runs.
        self.segmentation_imgs = sorted(
            glob.glob(os.path.join(path, self.seg_folder, "*")))
        self.img_size = img_size

        # Standard ImageNet channel statistics (as used by torchvision's
        # pretrained models), shaped to broadcast over an (H, W, 3) tensor.
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view((1, 1, 3))
        self.std = torch.tensor([0.229, 0.224, 0.225]).view((1, 1, 3))

    def __len__(self):
        return len(self.segmentation_imgs)

    @staticmethod
    def create_label_mask(mask_img, palette=None):
        """Convert a colour-coded RGB mask into a 2-D array of class indices.

        Args:
            mask_img: RGB mask — a PIL image or anything ``np.array`` turns
                into an ``(H, W, 3)`` array.
            palette: Sequence of RGB triples; pixels equal to ``palette[i]``
                receive label ``i``. Defaults to the module-level
                ``color_2_index`` table.

        Returns:
            ``(H, W)`` integer array of labels. Pixels matching no palette
            colour stay 0.
        """
        if palette is None:
            palette = color_2_index

        mask = np.array(mask_img).astype(int)
        label_mask = np.zeros(mask.shape[:2], dtype=np.int16)

        for idx, color in enumerate(palette):
            # (H, W) boolean selector: all three channels equal this colour.
            label_mask[np.all(mask == color, axis=-1)] = idx

        return label_mask.astype(int)

    def _image_path_for(self, mask_path):
        """Return the JPEG path whose file stem matches ``mask_path``."""
        # basename/splitext work with any OS separator and with dots
        # elsewhere in the path, unlike splitting on "\\" (Windows-only) or
        # on the first "." (breaks on e.g. "./data/...").
        stem = os.path.splitext(os.path.basename(mask_path))[0]
        return os.path.join(self.path, self.img_folder, stem + ".jpg")

    def __getitem__(self, idx):
        mask_path = self.segmentation_imgs[idx]

        with Image.open(mask_path) as raw_mask:
            mask_img = raw_mask.convert('RGB')
        # NEAREST keeps every pixel an exact palette colour. Pillow's default
        # resampling (bicubic in current versions) blends colours at class
        # boundaries, and those blended pixels would silently become label 0.
        mask_img = mask_img.resize((self.img_size, self.img_size),
                                   resample=Image.NEAREST)
        mask_img = PascalVoc.create_label_mask(mask_img)
        mask_img = torch.from_numpy(mask_img).int()

        with Image.open(self._image_path_for(mask_path)) as raw_img:
            # convert('RGB') guarantees 3 channels even for greyscale JPEGs.
            img = raw_img.convert('RGB')
        img = img.resize((self.img_size, self.img_size))
        img = torch.from_numpy(np.array(img)).float() / 255
        img = (img - self.mean) / self.std

        return (img, mask_img)

In [6]:
# Full VOC2012 segmentation set, every item resized to img_size x img_size.
dataset = PascalVoc(path=path, img_size=img_size)

In [7]:
class trans_conv2d(nn.Module):
    """Learned upsampling block: 3x3 transposed convolution followed by ReLU.

    The spatial size is multiplied exactly by ``stride`` (H -> H * stride).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output feature map.
        stride: Integer upsampling factor (default 2, i.e. doubles H and W).
    """

    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        # With kernel 3 / padding 1 the output size is
        # (H - 1) * stride + 1 + output_padding, so output_padding =
        # stride - 1 yields exactly H * stride. A hard-coded 1 only matches
        # stride=2 and makes stride=1 raise (output_padding must be < stride).
        self.conv = nn.ConvTranspose2d(in_channels, out_channels,
                                       kernel_size=3, stride=stride,
                                       bias=True, padding=1,
                                       output_padding=stride - 1)

    def forward(self, x):
        """Map (N, in_channels, H, W) to (N, out_channels, H*stride, W*stride)."""
        return F.relu(self.conv(x))

class fcn_8(nn.Module):
    """FCN-8s-style segmentation network on a frozen, pretrained VGG16 encoder.

    The decoder fuses the pool3 / pool4 / pool5 feature maps via additive skip
    connections. Five stride-2 transposed convolutions (x32 in total) then
    restore the input resolution.

    Args:
        num_classes: Number of output channels (one per segmentation class).

    ``forward`` returns per-pixel log-probabilities of shape
    (N, num_classes, H, W), to be paired with ``nn.NLLLoss``. H and W should
    be multiples of 32 so the skip additions line up and the output matches
    the input size.
    """

    def __init__(self, num_classes):
        super().__init__()

        self.num_classes = num_classes

        # [start, end) slices of vgg16().features that end just after pool3,
        # pool4 and pool5 (strides 8 / 16 / 32; 256 / 512 / 512 channels).
        self.fcn_block_indicies = [(0, 17), (17, 24), (24, 31)]
        self.init_vgg()

        self.pool_5_upsample = trans_conv2d(512, 512)
        self.pool_4_upsample = trans_conv2d(512, 256)
        self.pool_3_upsample = trans_conv2d(256, 128)

        self.upsample_score_1 = trans_conv2d(128, 64)
        self.upsample_score_2 = trans_conv2d(64, 32)

        # 1x1 convolution projecting the 32-channel map to per-class scores.
        self.out_conv = nn.Conv2d(32, self.num_classes, kernel_size=1)

    def init_vgg(self):
        """Build the three frozen VGG16 encoder stages into ``self.blocks``."""
        # NOTE(review): torchvision >= 0.13 deprecates ``pretrained=`` in
        # favour of ``weights=VGG16_Weights.IMAGENET1K_V1``.
        model = vgg16(pretrained=True)

        # nn.ModuleList (not a plain list) registers the stages as submodules,
        # so .to(device)/.cuda(), .eval() and state_dict() include them.
        self.blocks = nn.ModuleList(
            [model.features[start:end] for start, end in self.fcn_block_indicies])

        # The network does not update the VGG weights.
        for param in self.blocks.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return (N, num_classes, H, W) log-probabilities for a batch ``x``.

        ``x`` is expected channels-first, (N, 3, H, W).
        """
        x_3 = self.blocks[0](x)      # pool3 features, stride 8
        x_4 = self.blocks[1](x_3)    # pool4 features, stride 16
        x_5 = self.blocks[2](x_4)    # pool5 features, stride 32

        # Decoder: upsample x2, add the skip from the next-shallower stage, repeat.
        x_5 = self.pool_5_upsample(x_5)
        x_4 = self.pool_4_upsample(x_4 + x_5)
        x_3 = self.pool_3_upsample(x_3 + x_4)

        score = self.upsample_score_1(x_3)
        score = self.upsample_score_2(score)

        score = self.out_conv(score)
        # Log-probabilities over the class dimension.
        score = F.log_softmax(score, dim=1)
        return score

In [8]:
# One output channel per entry in class_names.
model = fcn_8(num_classes=len(class_names))

In [9]:
# Reproducible random batch shaped like one channels-first image
# (N=1, C=3, H=W=img_size); only used to check the model's output shape.
dummy_input = torch.from_numpy(
    np.random.default_rng(0).uniform(-1, 1, size=(1, 3, img_size, img_size))
).float()

In [10]:
# Shape-only sanity check: no_grad skips building an autograd graph.
with torch.no_grad():
    x = model(dummy_input)

In [11]:
# Expect one log-probability map per class at the full input resolution.
assert x.shape == (dummy_input.shape[0], len(class_names), *dummy_input.shape[2:])
x.shape

torch.Size([1, 21, 224, 224])