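This notebook gets a basic UNet segmentation model working end to end on the Bob Ross segmented-images dataset: data loading, a forward pass, and a single training step.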

In [2]:
import os
from pathlib import Path

import torch
from torchvision.datasets import ImageFolder

# Root of the Bob Ross Kaggle dataset download. Set the BOB_ROSS_DATAROOT environment
# variable to point elsewhere instead of editing this user-specific default path.
dataroot = Path(os.environ.get('BOB_ROSS_DATAROOT', '/Users/alex/Desktop/bob-ross-kaggle-dataset/'))

# ImageFolder treats every subdirectory of the root as a class label; display the
# classes it inferred to see why that layout does not fit segmentation data.
imgf = ImageFolder((dataroot / 'train').as_posix())
imgf.classes

`torchvision.datasets.ImageFolder` expects the images to be laid out by class, e.g. `train/dog/img_1.png`, and uses each subdirectory name as a class label. Our data instead pairs every image with a separate segmentation mask (`train/images/…` and `train/labels/…`), which this layout cannot express: `ImageFolder` would simply treat `images` and `labels` as two classes. Despite its generic name, `ImageFolder` is only appropriate for image classification.

Instead, we write our own custom `torch.utils.data.Dataset`.

In [4]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from torchvision import transforms


class BobRossSegmentedImagesDataset(Dataset):
    """Paired (image, segmentation mask) samples from the Bob Ross segmented-images dataset.

    Reads ``<dataroot>/<split>/images/**/*.png`` and ``<dataroot>/<split>/labels/**/*.png``.
    Both file lists are sorted and paired by position, so an image and its mask must sort
    into the same place (e.g. by sharing a file name).

    Each item is ``(img, seg)``:
      * ``img`` -- the image converted to RGB, resized to ``image_size`` x ``image_size``
        and passed through ``transforms.ToTensor()``.
      * ``seg`` -- an int64 array of class indices (0-8), the centre
        ``mask_size`` x ``mask_size`` window of the resized mask. The window matches the
        region covered by the UNet's unpadded-convolution output.

    Args:
        dataroot: dataset root directory (``str`` or ``Path``).
        split: subdirectory holding the ``images``/``labels`` folders.
        image_size: side length images and masks are resized to.
        mask_size: side length of the centre crop of the mask that is returned.
    """

    def __init__(self, dataroot, split='train', image_size=256, mask_size=164):
        super().__init__()
        self.dataroot = Path(dataroot)
        split_dir = self.dataroot / split
        # rglob yields files in filesystem order, which need not match between the two
        # directories; sort both so that imgs[i] and segs[i] refer to the same painting.
        self.imgs = sorted(split_dir.joinpath('images').rglob('*.png'))
        self.segs = sorted(split_dir.joinpath('labels').rglob('*.png'))
        self.image_size = image_size
        self.mask_size = mask_size
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)), transforms.ToTensor()
        ])
        # Raw mask label value -> contiguous class index. The raw labels follow the ADE20K
        # ontology and are not consecutive, so they have to be remapped for the loss.
        self.color_key = {
            3: 0,
            5: 1,
            10: 2,
            14: 3,
            17: 4,
            18: 5,
            22: 6,
            27: 7,
            61: 8,
        }
        assert len(self.imgs) == len(self.segs), (
            f'{len(self.imgs)} images but {len(self.segs)} masks'
        )
        # TODO: remean images to N(0, 1)?

    def __len__(self):
        return len(self.imgs)

    def _remap_labels(self, mask):
        """Map raw label values in ``mask`` to class indices via ``self.color_key``.

        Returns an int64 array shaped like ``mask``. Raises KeyError if ``mask`` contains a
        value with no entry in ``color_key``.
        """
        # Look up each distinct value once, then scatter the results back through the
        # inverse index; much cheaper than a per-pixel Python call.
        values, inverse = np.unique(mask, return_inverse=True)
        lut = np.array([self.color_key[v] for v in values.tolist()], dtype='int64')
        return lut[inverse].reshape(mask.shape)

    def __getitem__(self, i):
        # convert('RGB') guarantees three channels even for greyscale/RGBA/palette PNGs.
        with Image.open(self.imgs[i]) as raw_img:
            img = self.transform(raw_img.convert('RGB'))

        # Masks must be resized with nearest-neighbour sampling: any interpolating filter
        # (PIL's default for many modes) blends label ids into values that are not labels.
        with Image.open(self.segs[i]) as raw_seg:
            seg = np.array(
                raw_seg.resize((self.image_size, self.image_size), Image.NEAREST)
            )
        seg = self._remap_labels(seg)

        # The original UNet predicts only the centre of its input, not the whole image:
        # keep the centre mask_size x mask_size window of the mask as the target
        # (for 256 -> 164, the [46:210, 46:210] window).
        offset = (self.image_size - self.mask_size) // 2
        seg = seg[offset:offset + self.mask_size, offset:offset + self.mask_size]

        return img, seg
        

# `dataroot` is defined once, in the first code cell; don't redefine it here.
# Seed the global torch RNG so the shuffle order and the model's weight initialisation
# in later cells are reproducible.
torch.manual_seed(0)

dataset = BobRossSegmentedImagesDataset(dataroot)
dataloader = DataLoader(dataset, shuffle=True)

Annoyingly, it appears that some transforms in `torchvision.transforms` operate only on PIL images and others only on tensors, so the preprocessing pipeline needs careful conversion between the two representations.

![](https://raw.githubusercontent.com/jaxony/unet-pytorch/master/unet-architecture.png)

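* The paper uses `572x572` inputs; we start with `256x256`, and so have one fewer level of depth (three max-pools instead of four).
* The paper takes single-channel (greyscale) input, while we have RGB input. This needs no special handling beyond giving the first convolution three input channels.
* With a `256x256` input, the third max-pool operates on a `57x57` (odd-sized) map and drops its last row and column. As a result, the centre-cropped skip connections line up with the upsampled features only to within a pixel or two. A `252x252` input gives the same `164x164` output with every pool on an even-sized map, so its alignment is exact.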

In [5]:
from torch import nn


class UNet(nn.Module):
    """Three-level UNet (Ronneberger et al., 2015) built from unpadded 3x3 convolutions.

    Each unpadded 3x3 convolution trims one pixel from every border, so the output is
    smaller than the input: an (N, in_channels, 256, 256) batch yields
    (N, n_classes, 164, 164) per-pixel class logits covering the centre of the image.
    Skip connections are centre-cropped to the spatial size of the upsampled features
    before concatenation, with crop offsets computed from the actual tensor shapes. The
    network therefore also accepts other input sizes, as long as the feature maps stay
    large enough (e.g. 92x92 -> 4x4).

    With a 256x256 input the third max-pool sees a 57x57 (odd-sized) map and drops its
    last row/column, so the centre crops agree with the upsampled features only to
    within a pixel or two. For inputs where every pool sees an even-sized map, the
    alignment is exact; 252x252 is one such size, and it also yields a 164x164 output.

    Args:
        in_channels: number of channels in the input image (3 for RGB).
        n_classes: number of segmentation classes, i.e. output channels.
    """

    def __init__(self, in_channels=3, n_classes=9):
        super().__init__()
        # Contracting path: (3x3 conv -> ReLU) x 2, then 2x2 max-pool, per level.
        self.conv_1_1 = nn.Conv2d(in_channels, 64, 3)
        self.relu_1_2 = nn.ReLU()
        self.conv_1_3 = nn.Conv2d(64, 64, 3)
        self.relu_1_4 = nn.ReLU()
        self.pool_1_5 = nn.MaxPool2d(2)

        self.conv_2_1 = nn.Conv2d(64, 128, 3)
        self.relu_2_2 = nn.ReLU()
        self.conv_2_3 = nn.Conv2d(128, 128, 3)
        self.relu_2_4 = nn.ReLU()
        self.pool_2_5 = nn.MaxPool2d(2)

        self.conv_3_1 = nn.Conv2d(128, 256, 3)
        self.relu_3_2 = nn.ReLU()
        self.conv_3_3 = nn.Conv2d(256, 256, 3)
        self.relu_3_4 = nn.ReLU()
        self.pool_3_5 = nn.MaxPool2d(2)

        # Bottleneck.
        self.conv_4_1 = nn.Conv2d(256, 512, 3)
        self.relu_4_2 = nn.ReLU()
        self.conv_4_3 = nn.Conv2d(512, 512, 3)
        self.relu_4_4 = nn.ReLU()

        # Expanding path: 2x2 stride-2 transposed conv (doubles H and W), concatenate the
        # centre-cropped skip connection, then (3x3 conv -> ReLU) x 2.
        self.deconv_5_1 = nn.ConvTranspose2d(512, 256, (2, 2), 2)
        self.conv_5_4 = nn.Conv2d(512, 256, 3)
        self.relu_5_5 = nn.ReLU()
        self.conv_5_6 = nn.Conv2d(256, 256, 3)
        self.relu_5_7 = nn.ReLU()

        self.deconv_6_1 = nn.ConvTranspose2d(256, 128, (2, 2), 2)
        self.conv_6_4 = nn.Conv2d(256, 128, 3)
        self.relu_6_5 = nn.ReLU()
        self.conv_6_6 = nn.Conv2d(128, 128, 3)
        self.relu_6_7 = nn.ReLU()

        self.deconv_7_1 = nn.ConvTranspose2d(128, 64, (2, 2), 2)
        self.conv_7_4 = nn.Conv2d(128, 64, 3)
        self.relu_7_5 = nn.ReLU()
        self.conv_7_6 = nn.Conv2d(64, 64, 3)
        self.relu_7_7 = nn.ReLU()

        # 1x1 conv ~= a per-pixel fully connected layer onto the class logits.
        self.conv_8_1 = nn.Conv2d(64, n_classes, 1)

    @staticmethod
    def _crop_like(skip, target):
        """Centre-crop ``skip`` so its last two (spatial) dims match those of ``target``."""
        target_h, target_w = target.shape[-2:]
        skip_h, skip_w = skip.shape[-2:]
        top = (skip_h - target_h) // 2
        left = (skip_w - target_w) // 2
        return skip[..., top:top + target_h, left:left + target_w]

    def forward(self, x):
        """Return per-pixel class logits of shape (N, n_classes, H_out, W_out)."""
        # Contracting path; keep each level's pre-pool activations for the skips.
        x = self.relu_1_2(self.conv_1_1(x))
        x_residual_1 = self.relu_1_4(self.conv_1_3(x))
        x = self.pool_1_5(x_residual_1)

        x = self.relu_2_2(self.conv_2_1(x))
        x_residual_2 = self.relu_2_4(self.conv_2_3(x))
        x = self.pool_2_5(x_residual_2)

        x = self.relu_3_2(self.conv_3_1(x))
        x_residual_3 = self.relu_3_4(self.conv_3_3(x))
        x = self.pool_3_5(x_residual_3)

        x = self.relu_4_2(self.conv_4_1(x))
        x = self.relu_4_4(self.conv_4_3(x))

        # For a 256x256 input: upsample 24 -> 48, crop skip 57 -> 48 (offset 4).
        x = self.deconv_5_1(x)
        x = torch.cat((self._crop_like(x_residual_3, x), x), dim=1)
        x = self.relu_5_5(self.conv_5_4(x))
        x = self.relu_5_7(self.conv_5_6(x))

        # For a 256x256 input: upsample 44 -> 88, crop skip 122 -> 88 (offset 17).
        x = self.deconv_6_1(x)
        x = torch.cat((self._crop_like(x_residual_2, x), x), dim=1)
        x = self.relu_6_5(self.conv_6_4(x))
        x = self.relu_6_7(self.conv_6_6(x))

        # For a 256x256 input: upsample 84 -> 168, crop skip 252 -> 168 (offset 42).
        x = self.deconv_7_1(x)
        x = torch.cat((self._crop_like(x_residual_1, x), x), dim=1)
        x = self.relu_7_5(self.conv_7_4(x))
        x = self.relu_7_7(self.conv_7_6(x))

        return self.conv_8_1(x)

In [6]:
model = UNet()
batch, seg = next(iter(dataloader))

# Forward pass only, no gradients needed: this just checks shapes.
with torch.no_grad():
    y = model(batch)

# The network's output must cover exactly the centre-cropped mask it is trained against.
assert y.shape[-2:] == seg.shape[-2:], (y.shape, seg.shape)
y.shape

torch.Size([1, 9, 164, 164])

In [7]:
import torch.optim as optim

# With SMOKE_TEST enabled, only a single batch is trained. This confirms that the forward
# pass, loss and backward pass wire together; set it to False for a real training run.
SMOKE_TEST = True
N_EPOCHS = 10
LOG_EVERY = 50  # batches between loss printouts

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

model.train()
for epoch in range(N_EPOCHS):
    for i, (batch, segmap) in enumerate(dataloader):
        optimizer.zero_grad()

        # output: (N, n_classes, H, W) logits; segmap: (N, H, W) int64 class indices.
        output = model(batch)
        loss = criterion(output, segmap)
        loss.backward()
        optimizer.step()

        if i % LOG_EVERY == 0:
            print(f'Finished epoch {epoch}, batch {i}. Loss: {loss.item():.3f}.')
        if SMOKE_TEST:
            break
    if SMOKE_TEST:
        break

Finished epoch 0, batch 0. Loss: 2.242.


Next on the to-do list after this: use an LR finder to set a learning rate, five-fold cross-validation, regression across cross-validated models, expand the model to predict full mask instead of the center crop subregion, and start working on the various model optimizations.