In [0]:
# Mount Google Drive inside the Colab VM; Config.DATA_DIR is a path relative to /content under this mount.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


### Импорты

In [0]:
from google.colab import files

import warnings
warnings.filterwarnings('ignore')

import os
import math
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from skimage.color import gray2rgb

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Function, Variable

### Утилитарные функции

In [0]:
# В PyTorch отсутствует общепринятая функция округления
def round_tensor(tensor, digits):
    return (tensor * 10 ** digits).round() / (10 ** digits)

In [0]:
# Наложение маски на изображение
def put_mask(image, mask):
    image[:,:,0][mask == 1] = 255
    return image

In [0]:
# Функция загрузки модели из облака
def download_model(name):
    files.download(name + '.pth')

### Конфигурация обучения

In [0]:
class Config:
    __options = [
        # Для обучения
        {'batch_size': 16, 'lr': 2e-3, 'n_epochs': 10, 'momentum': 1e-5, 'eps': 1e-5},
        {'batch_size': 16, 'lr': 2e-3, 'n_epochs': 500, 'momentum': 1e-5, 'eps': 1e-5},

        # Для тестирования
        {'batch_size': 8},
    ]

    # Данные об изображениях
    WIDTH = 224
    HEIGHT = 224

    # Пути, использущиеся в работе
    # DATA_DIR = "drive/My Drive/CocoMiniPersonsData"
    DATA_DIR = "drive/My Drive/Colab Notebooks/CocoMiniPersonsData"
    OUTPUT_DIR = "output/"
    TRAIN_OUTPUT = "train_output/"

    # Классы объектов датасета
    NUM_CLASSES = 2
    TRAIN_CLASS_PROBS = torch.Tensor([0.92439456, 0.07560544])
    TEST_CLASS_PROBS = torch.Tensor([0.91990379, 0.08009621])

    @staticmethod
    def get_option(idx):
        return Config.__options[idx]

In [0]:
try:
    os.mkdir(Config.OUTPUT_DIR)
    os.mkdir(Config.TRAIN_OUTPUT)
    os.mkdir('temp/')
except FileExistsError:
    pass

### Функции бинаризации

In [0]:
# Функция бинаризации входа. 
# На выходе дает либо 1, либо -1
class BinarizeF(Function):

    @staticmethod
    def forward(cxt, input):
        output = input.new(input.size())
        output[input >= 0] = 1
        output[input < 0] = -1
        return output

    @staticmethod
    def backward(cxt, grad_output):
        grad_input = grad_output.clone()
        return grad_input


binarize = BinarizeF.apply

### Бинаризованные модули

In [0]:
class BinaryTanh(nn.Module):
    def __init__(self):
        super(BinaryTanh, self).__init__()
        self.hardtanh = nn.Hardtanh()

    def forward(self, x):
        output = self.hardtanh(x)
        output = binarize(output)
        return output


class BinaryLinear(nn.Linear):
    def forward(self, x):
        binary_weight = binarize(self.weight)
        if self.bias is None:
            return F.linear(x, binary_weight)
        else:
            return F.linear(x, binary_weight, self.bias)

    def reset_parameters(self):
        # Glorot-инициализация
        in_features, out_features = self.weight.size()
        stdv = math.sqrt(1.5 / (in_features + out_features))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

        self.weight.lr_scale = 1. / stdv


class BinaryConv2d(nn.Conv2d):
    def forward(self, x):
        bw = binarize(self.weight)
        return F.conv2d(x, bw, self.bias, self.stride,
                               self.padding, self.dilation, self.groups)

    def reset_parameters(self):
        # Glorot-инициализация
        in_features = self.in_channels
        out_features = self.out_channels
        for k in self.kernel_size:
            in_features *= k
            out_features *= k
        stdv = math.sqrt(1.5 / (in_features + out_features))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.zero_()

        self.weight.lr_scale = 1. / stdv

### Функция потерь

In [0]:
def dice_loss(pred, target, smooth=.3):
    pred = pred.contiguous()
    target = target.contiguous()

    intersection = (pred * target).sum(dim=2).sum(dim=2)
    numerator = (2. * intersection + smooth)
    denominator = (pred.sum(dim=2).sum(dim=2) + target.sum(dim=2).sum(dim=2) + smooth) 
    loss = 1 - numerator / denominator
    return loss.mean() 

def calculate_loss(pred, target, class_weights, bce_weight=0.5):
    bce = F.binary_cross_entropy_with_logits(pred, target, pos_weight=class_weights[1])
    pred = F.sigmoid(pred)
    dice = dice_loss(pred, target)
    loss = bce * bce_weight + dice * (1 - bce_weight)
    return loss

### UNet

In [0]:
class DoubleConvUnit(nn.Module):
    def __init__(self, in_channels, out_channels, binary=-1):
        super(DoubleConvUnit, self).__init__()

        if binary == -1:
            self.unit = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, 3, padding=1),
                nn.ReLU(inplace=True)
            )
        elif binary == 0:
            self.unit = nn.Sequential(
                BinaryConv2d(in_channels, out_channels, 3, padding=1),
                BinaryTanh(),
                Conv2d(out_channels, out_channels, 3, padding=1),
                nn.ReLU(inplace=True)
            )
        elif binary == 1:
            self.unit = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 3, padding=1),
                nn.ReLU(inplace=True),
                BinaryConv2d(out_channels, out_channels, 3, padding=1),
                BinaryTanh(),
            )

    def forward(self, x):
        return self.unit(x)

In [0]:
class UNet(nn.Module):
    def __init__(self, input_channels, n_classes, name='unet'):
        super(UNet, self).__init__()

        self.input_channels = input_channels
        self.n_classes = n_classes
        self.name = name
        self.device = 'cuda:0'

        self.down_1 = DoubleConvUnit(3, 64)
        self.down_2 = DoubleConvUnit(64, 128, binary=0)
        self.down_3 = DoubleConvUnit(128, 256, binary=0)
        self.down_4 = DoubleConvUnit(256, 512, binary=0)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.up_3 = DoubleConvUnit(512 + 256, 256, binary=1)
        self.up_2 = DoubleConvUnit(256 + 128, 128, binary=1)
        self.up_1 = DoubleConvUnit(128 + 64, 64, binary=1)

        self.output_conv = nn.Conv2d(64, n_classes, 1)

    def forward(self, x):
        conv1 = self.down_1(x)
        x = self.maxpool(conv1)

        conv2 = self.down_2(x)
        x = self.maxpool(conv2)

        conv3 = self.down_3(x)
        x = self.maxpool(conv3)

        x = self.down_4(x)

        x = self.upsample(x)
        x = torch.cat([x, conv3], dim=1)

        x = self.up_3(x)
        x = self.upsample(x)
        x = torch.cat([x, conv2], dim=1)

        x = self.up_2(x)
        x = self.upsample(x)
        x = torch.cat([x, conv1], dim=1)

        x = self.up_1(x)
        
        out = self.output_conv(x)
        out = F.softmax(out, dim=1)
        return out

    def load(self, path):
        self.to(self.device)
        self.load_state_dict(torch.load(path))

### Функции выполнения модели

In [0]:
def save_results(softmaxed, image_batch, mask_batch, path, name):
    for idx, predicted_mask in enumerate(softmaxed):
            input_image = image_batch[idx].detach().cpu().numpy()
            input_image = input_image.transpose((1, 2, 0))
            target_mask = mask_batch[idx].detach().cpu().numpy()
            pr_mask = predicted_mask.detach().cpu().numpy().argmax(axis=0)

            fig = plt.figure()

            plot = fig.add_subplot(1, 2, 1)
            with_mask = put_mask(input_image.copy(), pr_mask)
            plt.imshow(with_mask)
            plot.set_title("Predicted")

            plot = fig.add_subplot(1, 2, 2)
            with_mask = put_mask(input_image.copy(), target_mask)
            plt.imshow(with_mask)
            plot.set_title("Ground truth")

            fig.savefig(os.path.join(path, name + f'_id:{idx}.png'))
            plt.close()

In [0]:
def train(net, data_loader, n_epochs, lr, class_weights, verbose=0):
    net.to(net.device)

    print(net.device)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    class_weights = 1.0 / class_weights
    loss_fn = nn.CrossEntropyLoss(weight=class_weights).to(net.device)

    training_time = time.time()
    for epoch in range(n_epochs):
        epoch_time = time.time()

        train_loss = 0.0
        processed = 0
        for X_batch, y_batch in data_loader:
            image_batch = Variable(X_batch).to(net.device)
            mask_batch = Variable(y_batch).to(net.device)

            output_batch = net(image_batch)

            optimizer.zero_grad()
            loss = calculate_loss(output_batch, mask_batch.unsqueeze(1), class_weights=class_weights)
            # loss = loss_fn(output_batch, mask_batch)
            loss.backward()
            optimizer.step()

            train_loss += loss.float()
            processed += data_loader.batch_size
            processed_percent = round(
                100.0 * processed / (len(data_loader) * data_loader.batch_size), 
                ndigits=3
            )
            if verbose == 2:
                print(f"Эпоха {epoch}: {processed_percent}%,"
                    f"loss: {round_tensor(train_loss / processed, 5)}")
                
        # Вывод результатов эпохи
        epoch_time = time.time() - epoch_time
        if verbose:
            print(f"Окончание эпохи. Эпоха: {epoch}, "
                  f"train_loss: {round_tensor(train_loss / processed, 5)}; {epoch_time} сек.; "
                  f"Сохранение...")
        
        # Сохраняем модель по окончанию эпохи
        with open(f'{net.name}.pth', 'w'):
            torch.save(net.state_dict(), f"{net.name}.pth")
        download_model(net.name)

    print(f"Время обучения: {time.time() - training_time} сек.")

In [0]:
def predict(net, data_loader, class_weights):
    with torch.no_grad():
        loss_func = nn.CrossEntropyLoss(1.0 / class_weights).to(net.device)

        batch_id = 0
        for X_batch, y_batch in data_loader:
            image_batch = Variable(X_batch).to(net.device)
            mask_batch = Variable(y_batch).to(net.device)

            softmaxed = net(image_batch)
            loss = loss_func(softmaxed, mask_batch)

            save_results(softmaxed, image_batch, mask_batch, 
                        path=Config.OUTPUT_DIR, name=f'batch:{batch_id}')
            
            batch_id += 1

### Набор данных

In [0]:
class CocoPersons_Segmentation(torch.utils.data.Dataset):
    def __init__(self, dataset_dir, train=True):
        super(CocoPersons_Segmentation, self).__init__()

        set_ = 'train/' if train else 'test/'
        self.images_dir = os.path.join(dataset_dir, set_ , 'images/')
        self.masks_dir = os.path.join(dataset_dir, set_, 'masks/')

        self.masks_files = os.listdir(self.masks_dir)
        self.file_names = [mask.split('_')[-1].split('.')[0] for mask in self.masks_files if mask.endswith('.png')]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        name = self.file_names[idx]
        mask_path = os.path.join(self.masks_dir, 'seg_' + name + '.png')
        image_path = os.path.join(self.images_dir, name + '.jpg')

        image = torch.FloatTensor(self.__load_image(image_path))
        mask = torch.LongTensor(self.__load_mask(mask_path))
        return image, mask

    def __load_image(self, path):
        img = np.array(Image.open(path).resize((Config.WIDTH, Config.HEIGHT)))

        try:
            image = img.transpose((2, 0, 1))
        except ValueError:
            image = gray2rgb(img).transpose((2, 0, 1))

        image = np.array(image, dtype=np.float32) / 255.0
        return image

    def __load_mask(self, path):
        mask = Image.open(path).resize((Config.WIDTH, Config.HEIGHT))
        mask = np.array(mask, dtype=np.uint8) / 255.0
        return mask

### Обучение модели

In [0]:
# Выбираем параметры обучения
train_options = Config.get_option(1)

In [0]:
# Создаем объект-DataLoader на тренировочных данных
coco_train_set = CocoPersons_Segmentation(dataset_dir=Config.DATA_DIR, train=True)
train_loader = torch.utils.data.DataLoader(
    coco_train_set,
    batch_size=train_options['batch_size'],
    shuffle=True,
    num_workers=4
)

In [0]:
# Создаем и строим модель
model = UNet(input_channels=3, n_classes=Config.NUM_CLASSES, name='unet0')

# Запускаем обучение модели
train(
    model, 
    data_loader=train_loader,
    n_epochs=train_options['n_epochs'],
    lr=train_options['lr'],
    class_weights=Config.TRAIN_CLASS_PROBS,
    verbose=1
)

NameError: ignored

### Тестирование модели

In [0]:
# Делаем похожие шаги
test_options = Config.get_option(-1)

coco_test_set = CocoPersons_Segmentation(dataset_dir=Config.DATA_DIR, train=False)
test_loader = torch.utils.data.DataLoader(
    coco_test_set,
    batch_size=test_options['batch_size'],
    shuffle=True,
    num_workers=4
)

In [0]:
# Загружаем обученную модель
model = UNet(input_channels=3, n_classes=Config.NUM_CLASSES)
model.load('unet0.pth')

In [0]:
# Строим предсказание на тестовых данных
predict(
    model,
    test_loader,
    class_weights=Config.TEST_CLASS_PROBS
)