# PicsArt AI Hackathon Online

## Детектирование фона на изображениях

In [1]:
%pylab inline

import os
import tqdm

import pandas as pd
from PIL import Image
from skimage.morphology import remove_small_objects, remove_small_holes

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vgg13, resnet50

from utils import rle_encode

Populating the interactive namespace from numpy and matplotlib


# Загрузим список фотографий из обучающей выборки.

In [2]:
# List training image names (without extension).
# sorted(): os.listdir returns entries in arbitrary, filesystem-dependent
# order, which would make the train/validation split below non-reproducible.
# os.path.splitext strips only the final extension, so names that contain
# extra dots (e.g. "a.b.jpg") are kept intact and still resolve to a file.
path_images = sorted(
    os.path.splitext(filename)[0]
    for filename in os.listdir('data/train/')
    if filename.endswith('.jpg'))

Разделим на обучающую и валидационную выборки: последние 500 изображений списка отложим для валидации, остальные используем для обучения.

In [3]:
# Hold out the last N_VAL_IMAGES names for validation; train on the rest.
N_VAL_IMAGES = 500
split_at = len(path_images) - N_VAL_IMAGES
train_images = path_images[:split_at]
val_images = path_images[split_at:]

Опишем датасет. Предусмотрим загрузку масок из другой директории с теми же названиями файлов.

In [4]:
class FaceDataset(Dataset):
    """Segmentation dataset of ``.jpg`` images with optional ``.png`` masks.

    Images are read from ``images_dir/<name>.jpg``. When ``target_dir`` is
    given, masks are read from ``target_dir/<name>.png`` (same base name);
    otherwise the returned mask is an empty list.

    When ``transforms`` is set, the image and the mask are each pasted into
    the top-left corner of a black RGB canvas of ``canvas_size`` (padding
    smaller pictures). The image canvas is passed through ``transforms`` and
    the mask canvas is converted with ``ToTensor`` (3-channel float tensor
    with values in [0, 1]).

    Args:
        images_dir: directory with the input ``.jpg`` files.
        images_name: list of file names without extension.
        target_dir: optional directory with the ``.png`` masks.
        transforms: optional callable applied to the padded image canvas.
        canvas_size: ``(width, height)`` of the padding canvas, as accepted by
            ``PIL.Image.new``; defaults to the previously hard-coded 320x320.
    """

    def __init__(self, images_dir, images_name, target_dir=None,
                 transforms=None, canvas_size=(320, 320)):

        self.images_dir = images_dir
        self.target_dir = target_dir
        self.images_name = images_name
        self.transforms = transforms
        self.canvas_size = canvas_size

        print('{} images'.format(len(self.images_name)))

    def __len__(self):
        return len(self.images_name)

    @staticmethod
    def _load(path):
        """Read an image fully into memory and close its file handle."""
        # Image.open alone is lazy and keeps the file open until the object
        # is garbage-collected; copy() forces the pixel data to be loaded.
        with Image.open(path) as src:
            return src.copy()

    def _pad(self, picture):
        """Paste ``picture`` into the top-left corner of a black RGB canvas."""
        # paste() converts the source to the canvas mode (RGB) if they differ.
        canvas = Image.new('RGB', self.canvas_size)
        canvas.paste(picture)
        return canvas

    def __getitem__(self, idx):
        name = self.images_name[idx]
        img = self._load(os.path.join(self.images_dir, name + '.jpg'))

        # Decide on the mask from the configuration, not from the truthiness
        # of a PIL image object.
        has_mask = bool(self.target_dir)
        if has_mask:
            mask = self._load(os.path.join(self.target_dir, name + '.png'))
        else:
            mask = []

        if self.transforms:
            img = self.transforms(self._pad(img))
            if has_mask:
                # `transforms` here is the torchvision module; the constructor
                # argument of the same name lives on self.transforms.
                mask = transforms.ToTensor()(self._pad(mask))

        return {'img': img, 'mask': mask}

В качестве трансформации возьмём только нормализацию с параметрами от ImageNet, так как будем использовать предобученный кодировщик.

In [5]:
# Per-channel ImageNet statistics, the usual normalisation for torchvision
# models with pretrained weights.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

image_transforms = transforms.Compose([
    transforms.ToTensor(),  # PIL image -> CHW float tensor in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [6]:
def _make_labeled_dataset(names):
    """FaceDataset over data/train/ images with masks from data/train_mask/."""
    return FaceDataset(images_dir='data/train/',
                       images_name=names,
                       target_dir='data/train_mask/',
                       transforms=image_transforms)


train_dataset = _make_labeled_dataset(train_images)
val_dataset = _make_labeled_dataset(val_images)

991 images
500 images


Загрузчики данных (DataLoader) для обучения и валидации сети; обучающая выборка перемешивается на каждой эпохе.

In [7]:
# Only the training stream is shuffled; validation order stays fixed.
train_data_loader = DataLoader(dataset=train_dataset, batch_size=8,
                               shuffle=True)
val_data_loader = DataLoader(dataset=val_dataset, batch_size=4,
                             shuffle=False)

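Для решения задачи сегментации будем использовать UNet. Энкодер сети сделаем из первых блоков предобученного VGG13. Ниже также определена сеть LinkNet с энкодером ResNet-18 — именно она в итоге обучается и используется для предсказаний.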

In [8]:
# Prefer the first GPU, but fall back to the CPU on machines without CUDA
# (a bare "cuda:0" device fails at the first .to(device) there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
class VGG13Encoder(torch.nn.Module):
    """UNet encoder made of the first ``num_blocks`` stages of VGG13.

    Stage ``i`` keeps layers ``5*i .. 5*i+3`` of ``vgg13().features`` and
    drops layer ``5*i+4``. The 2x2 max-pooling between stages is applied
    explicitly in ``forward`` instead.

    Args:
        num_blocks: number of encoder stages to keep.
        pretrained: forwarded to ``torchvision.models.vgg13``.
    """

    def __init__(self, num_blocks, pretrained=True):
        super().__init__()
        self.num_blocks = num_blocks
        feature_extractor = vgg13(pretrained=pretrained).features
        # ModuleList (not a plain Python list) registers the stages as
        # submodules. Their weights then appear in .parameters() (so the
        # optimizer sees them) and in .state_dict(), and they follow
        # .train()/.eval()/.to() of the parent model.
        self.blocks = torch.nn.ModuleList(
            torch.nn.Sequential(*[feature_extractor[j]
                                  for j in range(i * 5, i * 5 + 4)])
            for i in range(self.num_blocks)
        ).to(device)

    def forward(self, x):
        """Return the output of every stage, shallowest first."""
        activations = []
        last = len(self.blocks) - 1
        for i, block in enumerate(self.blocks):
            x = block(x)
            activations.append(x)
            if i != last:
                # Downsample between stages, not after the deepest one.
                x = torch.nn.functional.max_pool2d(x, kernel_size=2, stride=2)
        return activations

Опишем блок декодера.

In [None]:
class DecoderBlock(torch.nn.Module):
    """UNet decoder stage: upsample, fuse with the skip connection, refine.

    Args:
        out_channels: channels produced by the block. ``down`` is expected to
            carry ``2 * out_channels`` channels. The skip input ``left`` is
            expected to carry ``out_channels`` channels at twice the spatial
            size of ``down``.
    """

    def __init__(self, out_channels):
        super().__init__()

        self.upconv = torch.nn.Conv2d(
            in_channels=out_channels * 2, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1)
        self.conv1 = torch.nn.Conv2d(
            in_channels=out_channels * 2, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1)
        self.conv2 = torch.nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels,
            kernel_size=3, padding=1, dilation=1)

    def forward(self, down, left):
        # Nearest-neighbour 2x upsampling followed by a channel-halving conv.
        x = torch.nn.functional.interpolate(down, scale_factor=2)
        x = self.upconv(x)
        # Each refining conv is followed by a ReLU. Without a nonlinearity the
        # stacked convolutions collapse into one linear map of the features.
        # Functional ReLU keeps the module/state_dict layout unchanged.
        x = torch.nn.functional.relu(self.conv1(torch.cat([left, x], 1)))
        x = torch.nn.functional.relu(self.conv2(x))
        return x

Сформируем весь декодер из блоков.

In [None]:
class Decoder(torch.nn.Module):
    """UNet decoder: a chain of ``DecoderBlock`` stages, deepest first.

    Args:
        num_filters: channels of the shallowest stage. Stage ``i`` works with
            ``num_filters * 2**(num_blocks - i - 1)`` channels.
        num_blocks: number of upsampling stages.
    """

    def __init__(self, num_filters, num_blocks):
        super().__init__()
        # ModuleList (not a plain list) registers the stages, so they are
        # optimized, saved and switched between train/eval with the parent.
        self.blocks = torch.nn.ModuleList(
            DecoderBlock(num_filters * 2 ** (num_blocks - i - 1))
            for i in range(num_blocks)
        ).to(device)

    def forward(self, activations):
        """Fuse encoder ``activations`` (ordered shallow to deep) into one map."""
        up = activations[-1]
        # Walk the skip connections from the second-deepest to the shallowest.
        for i, left in enumerate(activations[-2::-1]):
            up = self.blocks[i](up, left)
        return up

А теперь и всю сеть целиком.

In [None]:
class UNet(torch.nn.Module):
    """UNet with a VGG13 encoder and a ``Decoder`` of ``DecoderBlock`` stages.

    Args:
        num_classes: channels of the output logit map.
        num_filters: channels of the shallowest decoder stage, which is also
            the input width of the final 1x1 classifier. NOTE(review): the
            VGG13 encoder's first stage presumably has 64 channels, so values
            other than 64 likely mismatch the skip connections; confirm before
            changing it.
        num_blocks: number of encoder stages; the decoder gets one fewer.
    """

    def __init__(self, num_classes=1, num_filters=64, num_blocks=4):
        super().__init__()
        self.encoder = VGG13Encoder(num_blocks=num_blocks)
        # Forward num_filters (instead of a hard-coded 64) so the decoder's
        # output width always matches self.final's in_channels.
        self.decoder = Decoder(num_filters=num_filters,
                               num_blocks=num_blocks - 1)
        self.final = torch.nn.Conv2d(
            in_channels=num_filters, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw (pre-sigmoid) logits with ``num_classes`` channels."""
        acts = self.encoder(x)
        x = self.decoder(acts)
        return self.final(x)

In [None]:
import torch
import torch.nn as nn


from torchvision.models import resnet18, resnet34


class ResNetEncoder(nn.Module):
    """Multi-scale feature extractor built from a torchvision-style ResNet.

    Args:
        arch: callable ``arch(pretrained=...)`` returning a model that exposes
            ``conv1``, ``bn1``, ``relu``, ``maxpool`` and ``layer1``..``layer4``
            (e.g. ``torchvision.models.resnet18``).
        pretrained: forwarded to ``arch``.

    Attributes:
        filters: output channels of ``layer1``..``layer4``. Each value is read
            from the last block of the stage: ``conv3`` when the block has
            one, ``conv2`` otherwise.
    """

    def __init__(self, arch, pretrained=False):
        super().__init__()

        backbone = arch(pretrained=pretrained)

        # Stem: conv -> BN -> ReLU -> max-pool.
        self.encoder0 = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool
        )
        self.encoder1 = backbone.layer1
        self.encoder2 = backbone.layer2
        self.encoder3 = backbone.layer3
        self.encoder4 = backbone.layer4

        self.filters = [
            self._stage_out_channels(stage)
            for stage in (self.encoder1, self.encoder2,
                          self.encoder3, self.encoder4)
        ]

    @staticmethod
    def _stage_out_channels(stage):
        """Output channels of a residual stage (an indexable block sequence)."""
        last_block = stage[-1]
        # hasattr resolves registered submodules through nn.Module's public
        # attribute lookup instead of poking at __dict__['_modules'].
        if hasattr(last_block, 'conv3'):
            return last_block.conv3.out_channels
        return last_block.conv2.out_channels

    def forward(self, x):
        """Return the outputs of layer1..layer4, shallowest first."""
        x = self.encoder0(x)
        acts = []
        for stage in (self.encoder1, self.encoder2,
                      self.encoder3, self.encoder4):
            x = stage(x)
            acts.append(x)
        return acts


class DecoderBlock(nn.Module):
    """LinkNet decoder block: 1x1 reduce, transposed-conv upsample, 1x1 expand.

    Every convolution is followed by BatchNorm and ReLU. ``forward`` always
    requests an output of twice the input height and width.

    Args:
        m: input channels (reduced to ``m // 4`` inside the block).
        n: output channels.
        stride: stride of the transposed convolution.
    """

    def __init__(self, m, n, stride=2):
        super().__init__()
        squeezed = m // 4

        # (B, m, H, W) -> (B, m/4, H, W)
        self.conv1 = nn.Conv2d(m, squeezed, 1)
        self.norm1 = nn.BatchNorm2d(squeezed)
        self.relu1 = nn.ReLU(inplace=False)

        # (B, m/4, H, W) -> (B, m/4, 2H, 2W); exact size pinned in forward()
        self.conv2 = nn.ConvTranspose2d(squeezed, squeezed, 3,
                                        stride=stride, padding=1)
        self.norm2 = nn.BatchNorm2d(squeezed)
        self.relu2 = nn.ReLU(inplace=False)

        # (B, m/4, 2H, 2W) -> (B, n, 2H, 2W)
        self.conv3 = nn.Conv2d(squeezed, n, 1)
        self.norm3 = nn.BatchNorm2d(n)
        self.relu3 = nn.ReLU(inplace=False)

    def forward(self, x):
        target_hw = (2 * x.size(-2), 2 * x.size(-1))
        y = self.relu1(self.norm1(self.conv1(x)))
        y = self.relu2(self.norm2(self.conv2(y, output_size=target_hw)))
        return self.relu3(self.norm3(self.conv3(y)))


class FinalBlock(nn.Module):
    """LinkNet head: transposed-conv 2x upsample, 3x3 conv, 1x1 classifier.

    Maps (B, num_filters, H, W) to raw logits (B, num_classes, 2H, 2W).
    """

    def __init__(self, num_filters, num_classes=2):
        super().__init__()
        half = num_filters // 2

        self.conv1 = nn.ConvTranspose2d(num_filters, half, 3,
                                        stride=2, padding=1)
        self.relu1 = nn.ReLU(inplace=False)
        self.conv2 = nn.Conv2d(half, half, 3, padding=1)
        self.relu2 = nn.ReLU(inplace=False)
        self.conv3 = nn.Conv2d(half, num_classes, 1)

    def forward(self, inputs):
        height, width = inputs.shape[-2:]
        x = self.relu1(self.conv1(inputs, output_size=(height * 2, width * 2)))
        x = self.relu2(self.conv2(x))
        return self.conv3(x)


class LinkNet(nn.Module):
    """LinkNet segmentation model on top of a multi-scale encoder.

    Each decoder stage upsamples and is summed with the encoder feature map
    of matching resolution (additive skip connections). ``FinalBlock`` then
    upsamples once more and predicts logits.

    Args:
        encoder: module returning four feature maps (shallow to deep) and
            exposing their channel counts as ``encoder.filters``.
        num_classes: number of output channels.
    """

    def __init__(self, encoder, num_classes):
        super().__init__()

        self.encoder = encoder
        chans = self.encoder.filters

        self.decoder4 = DecoderBlock(chans[3], chans[2])
        self.decoder3 = DecoderBlock(chans[2], chans[1])
        self.decoder2 = DecoderBlock(chans[1], chans[0])
        self.decoder1 = DecoderBlock(chans[0], chans[0])

        self.final = FinalBlock(chans[0], num_classes)

    def forward(self, x):
        e1, e2, e3, e4 = self.encoder(x)
        merged = self.decoder4(e4) + e3
        merged = self.decoder3(merged) + e2
        merged = self.decoder2(merged) + e1
        logits = self.final(self.decoder1(merged))
        # Drops the channel axis when num_classes == 1, giving (B, H, W);
        # squeeze(1) leaves the tensor unchanged for more than one class.
        return logits.squeeze(1)


def linknet18(num_classes=1, pretrained=False):
    """Build a LinkNet on a ResNet-18 encoder.

    The model is returned where it was built; callers move it with ``.to()``.
    """
    resnet_encoder = ResNetEncoder(resnet18, pretrained=pretrained)
    return LinkNet(resnet_encoder, num_classes=num_classes)


def linknet34(num_classes=1, pretrained=False):
    """Build a LinkNet on a ResNet-34 encoder, moved to the global ``device``.

    Unlike ``linknet18``, the returned model already lives on ``device``.
    """
    backbone_encoder = ResNetEncoder(resnet34, pretrained=pretrained)
    backbone_encoder = backbone_encoder.to(device)
    net = LinkNet(backbone_encoder, num_classes=num_classes)
    return net.to(device)

In [None]:
# Despite the variable name, `unet` holds a LinkNet-18 (pretrained ResNet-18
# encoder); the name is kept because the cells below refer to it.
unet = linknet18(num_classes=1, pretrained=True)
unet = unet.to(device)

Проверим размерность выхода.

In [None]:
# Smoke test: push one training batch through the model and compare shapes.
batch = next(iter(train_data_loader))
with torch.no_grad():
    # Call the module itself (not .forward) so registered hooks run; no_grad
    # skips building an autograd graph that a shape check does not need.
    out = unet(batch['img'].to(device))
print(batch['img'].shape)
print(out.shape)

Обучим сеть.

In [None]:
LEARNING_RATE = 0.01
MOMENTUM = 0.99

# Both losses take raw logits and apply the sigmoid internally.
criterion = torch.nn.BCEWithLogitsLoss()
val_criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(params=unet.parameters(), lr=LEARNING_RATE,
                            momentum=MOMENTUM)

In [None]:
# nn.Module.to moves the parameters in place and returns the same module,
# so re-running this cell is harmless.
unet = unet.to(device=device)

In [None]:
num_epoch = 10
eval_every = 10  # run validation every `eval_every` optimizer steps
steps = 0        # optimizer steps taken so far (previously started at 10)


def _validation_loss(model, loader, loss_fn):
    """Mean of the per-batch losses over ``loader``.

    Runs in eval mode, so BatchNorm uses its running statistics and is not
    updated by validation data. Runs under ``torch.no_grad()``, so no autograd
    graph is built. Switches the model back to train mode afterwards.
    """
    model.eval()
    total, n_batches = 0.0, 0
    with torch.no_grad():
        for val_batch in loader:
            logits = model(val_batch['img'].to(device))
            target = val_batch['mask'][:, 0, :, :].to(device)
            total += float(loss_fn(logits, target))
            n_batches += 1
    model.train()
    # loss_fn already averages within a batch, so divide by the number of
    # batches (not samples) to keep the value comparable to the train loss.
    return total / max(n_batches, 1)


unet.train()
for epoch in range(num_epoch):
    for batch in train_data_loader:
        optimizer.zero_grad()
        output = unet(batch['img'].to(device))
        # Channel 0 of the mask tensor is used as the target.
        loss = criterion(output, batch['mask'][:, 0, :, :].to(device))
        loss.backward()
        optimizer.step()
        steps += 1

        if steps % eval_every == 0:
            val_loss = _validation_loss(unet, val_data_loader, val_criterion)
            print('steps: {},\ttrain loss: {},\tval loss: {}'.format(
                steps, round(float(loss.detach()), 3), round(val_loss, 3)))

Подготовим итератор по тестовым изображениям.

In [None]:
# List test image names (without extension).
# sorted(): os.listdir order is arbitrary; a fixed order keeps the submission
# reproducible (the same list also drives the test loader and the CSV rows).
# os.path.splitext strips only the final extension, so names with extra dots
# still resolve to the right file.
path_images = sorted(
    os.path.splitext(filename)[0]
    for filename in os.listdir('data/test/')
    if filename.endswith('.jpg'))

In [None]:
# Unlabeled test set (no target_dir). It is not shuffled, so predictions come
# out in the order of path_images.
test_dataset = FaceDataset('data/test', path_images,
                           transforms=image_transforms)
test_data_loader = DataLoader(test_dataset, batch_size=4)

Сделаем предсказания. К выходу сети применим сигмоиду (сама сеть возвращает логиты без нелинейности), выполним бинаризацию по порогу и небольшую постобработку: удалим мелкие отдельные компоненты маски и закрасим небольшие дыры. Для кодирования масок-массивов в формат Run-length encoding (RLE) используем rle_encode.

In [None]:
from tqdm.auto import tqdm as progress_bar  # replaces deprecated tqdm_notebook

threshold = 0.25
# Inputs are padded to a 320x320 canvas by FaceDataset. Only the first 240
# columns are encoded; presumably the original images are 240 px wide.
ORIGINAL_WIDTH = 240
predictions = []

# Eval mode: BatchNorm must use running statistics at inference time.
# no_grad: no autograd graph is needed for predictions.
unet.eval()
with torch.no_grad():
    for batch in progress_bar(test_data_loader):
        output = torch.sigmoid(unet(batch['img'].to(device))).cpu()
        for i in range(output.shape[0]):
            img = output[i].numpy()
            post_img = remove_small_holes(remove_small_objects(img > threshold))
            rle = rle_encode(post_img[:, :ORIGINAL_WIDTH])
            predictions.append(rle)

In [None]:
from matplotlib import pyplot as plt

%matplotlib inline

In [None]:
# Last post-processed test mask, cropped to the 240 columns that the
# prediction loop passes to rle_encode (the rest is canvas padding).
plt.imshow(post_img[:, :240], cmap='gray')
plt.title('Predicted mask (last test image)')
plt.show()

In [None]:
# Show the test image that `post_img` was computed from. The prediction loop
# ends on the LAST item of the final batch, not the first. Undo the ImageNet
# normalisation from image_transforms so colours render correctly instead of
# being clipped.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])
shown_img = batch['img'][-1].cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
plt.imshow(np.clip(shown_img * imagenet_std + imagenet_mean, 0, 1))
plt.title('Input image (last test image)')
plt.show()

In [None]:
# One row per test image; path_images and predictions share the loader order.
df = pd.DataFrame(
    data={'image': path_images, 'rle_mask': predictions},
    columns=['image', 'rle_mask'])
df.to_csv(path_or_buf='baseline_submission.csv', index=False)