# PicsArt AI Hackathon Online

## Детектирование фона на изображениях

In [None]:
%pylab inline

import os
import tqdm

import pandas as pd
from PIL import Image
from skimage.morphology import remove_small_objects, remove_small_holes

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vgg13

from utils import rle_encode

In [None]:
# Common prefix prepended to every data path; empty means the paths are
# resolved relative to the current working directory.
prefix = ''
# Directories with the training photos, the training masks and the test photos.
train_im_dir = '{}data/train/'.format(prefix)
train_mask_dir = '{}data/train_mask/'.format(prefix)
test_im_dir = '{}data/test/'.format(prefix)

Загрузим список фотографий из обучающей выборки.

In [None]:
# Collect the base names (without extension) of all training JPEGs.
# os.path.splitext strips only the final extension, so names that contain
# extra dots stay intact. Sorting makes the order -- and therefore the
# train/validation split taken from it -- reproducible, because os.listdir
# returns entries in arbitrary order.
path_images = sorted(
    os.path.splitext(file_name)[0]
    for file_name in os.listdir(train_im_dir)
    if file_name.endswith('.jpg'))

Разделим на обучающую и валидационную выборки. Для ускорения оставим только 24 изображения для валидации.

In [None]:
# Hold out the last 24 names for validation; everything before them is used
# for training.
num_val_images = 24
train_images = path_images[:-num_val_images]
val_images = path_images[-num_val_images:]

Опишем датасет. Предусмотрим загрузку масок из другой директории с теми же названиями файлов.

In [None]:
class FaceDataset(Dataset):
    """Dataset of JPEG photos with optional PNG masks of the same base name.

    Args:
        images_dir: Directory containing the ``<name>.jpg`` photos.
        images_name: Sequence of file base names (without extension).
        target_dir: Optional directory containing ``<name>.png`` masks.
            When falsy, no masks are loaded.
        transforms: Optional callable applied to every loaded photo.
    """

    def __init__(self, images_dir, images_name, target_dir=None,
                 transforms=None):
        self.images_dir = images_dir
        self.target_dir = target_dir
        self.images_name = images_name
        self.transforms = transforms

        print('{} images'.format(len(self.images_name)))

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images_name)

    def __getitem__(self, idx):
        """Load the sample at position ``idx``.

        Returns:
            A dict with keys ``'img'`` and ``'mask'``. ``'img'`` is the photo
            (passed through ``self.transforms`` when set). ``'mask'`` is an
            empty list when no ``target_dir`` was given; otherwise it is the
            mask, converted with ``ToTensor`` when ``self.transforms`` is set.
        """
        name = self.images_name[idx]

        # Force three channels: a grayscale or CMYK JPEG would otherwise not
        # match a per-channel transform written for RGB input.
        img_filename = os.path.join(self.images_dir, name + '.jpg')
        img = Image.open(img_filename).convert('RGB')

        has_mask = bool(self.target_dir)
        if has_mask:
            mask_filename = os.path.join(self.target_dir, name + '.png')
            mask = Image.open(mask_filename)
        else:
            # Placeholder kept as an empty list for callers that expect the
            # 'mask' key to always be present.
            mask = []

        if self.transforms:
            img = self.transforms(img)
            # Explicit flag instead of relying on the truthiness of the mask
            # object itself.
            if has_mask:
                mask = transforms.ToTensor()(mask)

        return {'img': img, 'mask': mask}

В качестве трансформации возьмём только нормализацию с параметрами от ImageNet, так как будем использовать предобученный кодировщик.

In [None]:
# Channel-wise mean and standard deviation used for input normalization
# (the ImageNet statistics).
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Convert the PIL image to a tensor, then normalize each channel.
image_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

In [None]:
# Both splits read photos and masks from the same directories and share the
# same preprocessing; only the list of file names differs.
dataset_kwargs = dict(
    images_dir=train_im_dir,
    target_dir=train_mask_dir,
    transforms=image_transforms)

train_dataset = FaceDataset(images_name=train_images, **dataset_kwargs)

val_dataset = FaceDataset(images_name=val_images, **dataset_kwargs)

Генераторы для обучения и валидации сети.

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 4

# Only the training loader shuffles; validation keeps the dataset order.
train_data_loader = DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_data_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)

Для решения задачи сегментации будем использовать UNet. Энкодер сети сделаем из первых блоков предобученного VGG13.

In [None]:
class VGG13Encoder(torch.nn.Module):
    """Encoder assembled from the first layer groups of ``vgg13().features``.

    Args:
        num_blocks: Number of layer groups to take from the feature extractor.
        pretrained: Forwarded to ``vgg13(pretrained=...)``.
    """

    def __init__(self, num_blocks, pretrained=True):
        super().__init__()
        self.num_blocks = num_blocks
        self.blocks = torch.nn.ModuleList()
        feature_extractor = vgg13(pretrained=pretrained).features
        for i in range(self.num_blocks):
            # Take four consecutive layers starting at every fifth index.
            # The layer at index first + 4 (presumably VGG's own pooling
            # layer -- confirm) is skipped; forward() pools explicitly.
            first = i * 5
            self.blocks.append(torch.nn.Sequential(
                *[feature_extractor[j] for j in range(first, first + 4)]))

    def forward(self, x):
        """Run all blocks and return the list of per-block activations.

        A 2x2 max-pool with stride 2 is applied between blocks, but not
        after the last one. The returned activations are the un-pooled
        outputs of each block, ordered from first to last.
        """
        activations = []
        last_index = self.num_blocks - 1
        for i in range(self.num_blocks):
            x = self.blocks[i](x)
            activations.append(x)
            if i != last_index:
                # Use the public torch.nn.functional namespace rather than
                # the incidental ``F`` alias inside torch.functional.
                x = torch.nn.functional.max_pool2d(
                    x, kernel_size=2, stride=2)
        return activations

Опишем блок декодера.

In [None]:
class DecoderBlock(torch.nn.Module):
    """One decoder stage: upsample, merge with the skip input, refine.

    Args:
        out_channels: Number of channels the block produces. The ``down``
            input must carry ``2 * out_channels`` channels and the ``left``
            input ``out_channels`` channels.
    """

    def __init__(self, out_channels):
        super().__init__()
        doubled_channels = out_channels * 2

        def conv3x3(in_channels):
            # 3x3 convolution that preserves the spatial size.
            return torch.nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels,
                kernel_size=3, padding=1, dilation=1)

        # Creation order is kept (upconv, conv1, conv2) so parameter
        # registration and random initialization are unchanged.
        self.upconv = conv3x3(doubled_channels)
        self.conv1 = conv3x3(doubled_channels)
        self.conv2 = conv3x3(out_channels)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, down, left):
        """Upsample ``down`` by 2, concatenate with ``left`` and convolve."""
        upsampled = torch.nn.functional.interpolate(down, scale_factor=2)
        upsampled = self.upconv(upsampled)
        # Skip input first, upsampled features second, along the channel axis.
        merged = torch.cat([left, upsampled], 1)
        hidden = self.relu(self.conv1(merged))
        return self.relu(self.conv2(hidden))

Сформируем весь декодер из блоков.

In [None]:
class Decoder(torch.nn.Module):
    """Chain of DecoderBlocks applied from the deepest activation upwards.

    Args:
        num_filters: Channel count produced by the last block.
        num_blocks: Number of decoder blocks; block ``i`` produces
            ``num_filters * 2**(num_blocks - i - 1)`` channels.
    """

    def __init__(self, num_filters, num_blocks):
        super().__init__()
        self.blocks = torch.nn.ModuleList()
        # Widest block first: the exponent runs num_blocks - 1 down to 0.
        for exponent in reversed(range(num_blocks)):
            self.blocks.append(DecoderBlock(num_filters * 2**exponent))

    def forward(self, activations):
        """Fold the activations into one tensor.

        Starts from the last activation and merges it with the remaining
        ones in reverse order, one block per merge.
        """
        up = activations[-1]
        skips = activations[:-1]
        for index, left in enumerate(reversed(skips)):
            up = self.blocks[index](up, left)
        return up

А теперь и всю сеть целиком.

In [None]:
def init_weights(m, nonlinearity_type):
    """Xavier-uniform initialize the weight of a Conv2d or Linear module.

    Modules of any other type are left untouched, and only ``m.weight`` is
    modified (in place).

    Args:
        m: Module to initialize.
        nonlinearity_type: Nonlinearity name passed to
            ``torch.nn.init.calculate_gain`` to obtain the gain.
    """
    if not isinstance(m, (torch.nn.Conv2d, torch.nn.Linear)):
        return
    gain = torch.nn.init.calculate_gain(nonlinearity_type)
    torch.nn.init.xavier_uniform_(m.weight, gain=gain)

class UNet(torch.nn.Module):
    """U-Net made of a VGG13 encoder, a convolutional decoder and a 1x1 head.

    Args:
        num_classes: Number of output channels of the final 1x1 convolution.
        num_filters: Channel count produced by the last decoder block and
            consumed by the final convolution.
        num_blocks: Number of encoder blocks; the decoder has one fewer.
    """

    def __init__(self, num_classes=1, num_filters=64, num_blocks=4):
        super().__init__()
        self.encoder = VGG13Encoder(num_blocks=num_blocks)
        # Pass the caller's num_filters through (it used to be hard-coded to
        # 64 here) so the decoder output width always matches the input
        # width of ``final``.
        # NOTE(review): the encoder's channel widths come from VGG13, so
        # values other than the default presumably still mismatch at the
        # skip connections -- confirm before relying on this parameter.
        self.decoder = Decoder(
            num_filters=num_filters, num_blocks=num_blocks - 1)
        self.final = torch.nn.Conv2d(
            in_channels=num_filters, out_channels=num_classes, kernel_size=1)

        # Only the newly created layers are re-initialized; the encoder
        # keeps the weights it was constructed with.
        self.decoder.apply(lambda m: init_weights(m, 'relu'))
        self.final.apply(lambda m: init_weights(m, 'relu'))

    def forward(self, x):
        """Return the output of the final convolution (no activation)."""
        acts = self.encoder(x)
        x = self.decoder(acts)
        x = self.final(x)
        return x

In [None]:
# Build the network with its default arguments; the bare name on the last
# line makes the notebook display the module structure as the cell output.
unet = UNet()
unet

Проверим размерность выхода.

In [None]:
# Take a single batch from the training loader for a quick sanity check.
batch = next(iter(train_data_loader))

# Call the module itself rather than .forward() so registered hooks run, and
# skip gradient tracking because the output is only inspected for its shape.
with torch.no_grad():
    out = unet(batch['img'])
print(batch['img'].shape)
print(out.shape)

Обучим сеть.

In [None]:
# Binary cross-entropy on raw logits; separate instances are kept for the
# training and the validation loss.
criterion = torch.nn.BCEWithLogitsLoss()
val_criterion = torch.nn.BCEWithLogitsLoss()
# Adam with its default hyperparameters over all parameters of the network.
optimizer = torch.optim.Adam(unet.parameters())

In [None]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Move the network to the selected device.
unet = unet.to(device)

In [None]:
num_epoch = 1
steps = 0
# Validation is run and both losses are printed every `eval_every` steps.
eval_every = 10

for epoch in range(num_epoch):
    for batch in train_data_loader:
        optimizer.zero_grad()
        batch = {k: v.to(device) for k, v in batch.items()}
        output = unet(batch['img'])
        loss = criterion(output, batch['mask'])
        loss.backward()
        optimizer.step()
        steps += 1

        if steps % eval_every == 0:
            val_loss = 0.0
            # Evaluate without gradient tracking, using separate variable
            # names so the training batch/output are not shadowed.
            unet.eval()
            with torch.no_grad():
                for val_batch in val_data_loader:
                    val_batch = {
                        k: v.to(device) for k, v in val_batch.items()}
                    val_output = unet(val_batch['img'])
                    val_loss += float(
                        val_criterion(val_output, val_batch['mask']))
            unet.train()
            # The criterion already averages within a batch, so the sum of
            # per-batch losses is divided by the number of batches (not the
            # number of samples). max() guards against an empty loader.
            val_loss = val_loss / max(len(val_data_loader), 1)

            print('steps: {},\ttrain loss: {},\tval loss: {}'.format(
                steps, round(float(loss.detach()), 3), round(val_loss, 3)))

Подготовим итератор по тестовым изображениям.

In [None]:
# Collect the base names (without extension) of all test JPEGs.
# os.path.splitext strips only the final extension, so names that contain
# extra dots stay intact. Sorting gives a deterministic order, because
# os.listdir returns entries in arbitrary order.
path_images = sorted(
    os.path.splitext(file_name)[0]
    for file_name in os.listdir(test_im_dir)
    if file_name.endswith('.jpg'))

In [None]:
# No mask directory is passed for the test photos, so the dataset yields
# images only; the loader keeps the order of `path_images`.
test_dataset = FaceDataset(
    images_dir=test_im_dir,
    images_name=path_images,
    transforms=image_transforms)
test_data_loader = DataLoader(dataset=test_dataset, batch_size=4)

Сделаем предсказания. К выходу сети применим сигмоиду (исходно выходы без нелинейности), сделаем отсечение по порогу и небольшую постобработку по удалению отдельных пикселей маски и закрашиванию дыр. Для кодирования масок в виде массива в формат Run-length encoding используем rle_encode.

In [None]:
# Probability above which a pixel is assigned to the mask.
threshold = 0.25
predictions = []

# Inference only: switch to eval mode and disable gradient tracking so no
# autograd graph is built for the test batches.
unet.eval()
with torch.no_grad():
    for batch in tqdm.tqdm_notebook(test_data_loader):
        # The network outputs logits; sigmoid turns them into probabilities.
        probs = torch.sigmoid(unet(batch['img'].to(device)))
        # Iterating the array walks the batch dimension one image at a time.
        for prob in probs.cpu().numpy():
            binary = prob > threshold
            # Drop small isolated regions, then fill small holes.
            post_img = remove_small_holes(remove_small_objects(binary))
            predictions.append(rle_encode(post_img))

In [None]:
# One row per test image: its base name and the RLE-encoded predicted mask.
submission_columns = {'image': path_images, 'rle_mask': predictions}
df = pd.DataFrame(submission_columns)
df.to_csv('baseline_submission.csv', index=False)