### Импорты

In [0]:
import warnings
warnings.filterwarnings('ignore')

import os
import math
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from skimage.color import gray2rgb

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Function, Variable

### Утилитарные функции

In [0]:
# PyTorch has no built-in "round to N decimal digits" helper
def round_tensor(tensor, digits):
    """Round ``tensor`` to ``digits`` decimal places.

    The value is scaled by ``10 ** digits``, rounded with its own
    ``round()`` method and scaled back.
    """
    factor = 10 ** digits
    scaled = tensor * factor
    return scaled.round() / factor

In [0]:
# Overlay a mask on an image
def put_mask(image, mask):
    """Set channel 0 of ``image`` to 255 wherever ``mask == 1``.

    ``image`` is indexed as ``[row, column, channel]`` and is modified in
    place; the same object is returned.
    """
    first_channel = image[:, :, 0]  # a view: writes land in ``image`` itself
    selected = mask == 1
    first_channel[selected] = 255
    return image

### Конфигурация обучения

In [0]:
class Config:
    """Static settings of the notebook: presets, image size, paths, classes."""

    # Hyper-parameter presets; get_option() addresses them by list index.
    __options = [
        # Training presets
        dict(batch_size=128, lr=5e-3, n_epochs=40, momentum=1e-5, eps=1e-5),
        dict(batch_size=64, lr=2e-3, n_epochs=20, momentum=1e-5, eps=1e-5),
        dict(batch_size=16, lr=2e-3, n_epochs=10, momentum=1e-5, eps=1e-5),
        dict(batch_size=16, lr=2e-3, n_epochs=500, momentum=1e-5, eps=1e-5),
        dict(batch_size=32, lr=2e-3, n_epochs=25, momentum=1e-5, eps=1e-5),
        dict(batch_size=64, lr=5e-3, n_epochs=30, momentum=1e-5, eps=1e-5),

        # Testing preset
        dict(batch_size=8),
    ]

    # Image dimensions
    WIDTH = 224
    HEIGHT = 224

    # Paths used by the notebook
    DATA_DIR = "drive/My Drive/Colab Notebooks/CocoMiniPersonsData"
    OUTPUT_DIR = "output/"
    TRAIN_OUTPUT = "train_output/"

    # Dataset classes
    # NOTE(review): the *_CLASS_PROBS values are presumably per-class pixel
    # frequencies of the respective split — confirm against the dataset.
    NUM_CLASSES = 2
    TRAIN_CLASS_PROBS = torch.Tensor([0.92439456, 0.07560544])
    TEST_CLASS_PROBS = torch.Tensor([0.91990379, 0.08009621])

    @staticmethod
    def get_option(idx):
        """Return the preset dictionary stored at position ``idx``."""
        preset = Config.__options[idx]
        return preset

In [0]:
# Create the directories needed for training and testing.
# Each directory is created independently with exist_ok=True: with a single
# try/except around several mkdir calls, an already existing first directory
# would abort the block before the remaining ones are created.
for _directory in (Config.OUTPUT_DIR, Config.TRAIN_OUTPUT, 'temp/'):
    os.makedirs(_directory, exist_ok=True)

### Функция бинаризации

In [0]:
# Binarisation function: every element of the input becomes either 1 or -1.
# Defined for both the forward and the backward pass.
class BinarizeF(Function):
    """Element-wise binarisation whose backward pass forwards the gradient."""

    @staticmethod
    def forward(ctx, input):
        """Return a tensor like ``input`` with 1 where ``input >= 0``, else -1.

        The result is built from an all-ones tensor rather than from
        uninitialised storage, so elements that satisfy neither comparison
        (NaN) deterministically come out as 1 instead of arbitrary memory.
        """
        output = torch.ones_like(input)
        output[input < 0] = -1
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return a copy of ``grad_output`` as the gradient w.r.t. the input."""
        return grad_output.clone()


binarize = BinarizeF.apply

### Бинаризованные модули

In [0]:
# Binary counterparts of the standard layer and activation modules follow.

# Binary variant of the hyperbolic-tangent activation
class BinaryTanh(nn.Module):
    """Applies ``nn.Hardtanh`` and then ``binarize`` to its input."""

    def __init__(self):
        super().__init__()
        self.hardtanh = nn.Hardtanh()

    def forward(self, x):
        clipped = self.hardtanh(x)
        return binarize(clipped)


# Fully connected layer whose weights are binarised in the forward pass
class BinaryLinear(nn.Linear):
    """``nn.Linear`` that multiplies by ``binarize(self.weight)``."""

    def forward(self, x):
        # F.linear treats bias=None as "no bias", so one call covers both
        # the biased and the bias-free configuration.
        return F.linear(x, binarize(self.weight), self.bias)

    def reset_parameters(self):
        # Glorot-style initialisation: uniform in [-bound, bound], where the
        # bound depends on the sum of the two weight dimensions.
        rows, cols = self.weight.size()
        bound = math.sqrt(1.5 / (rows + cols))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

        # Stored on the parameter; no reader of lr_scale is visible in this file.
        self.weight.lr_scale = 1. / bound


# Convolution layer whose weights are binarised in the forward pass
class BinaryConv2d(nn.Conv2d):
    """``nn.Conv2d`` that convolves with ``binarize(self.weight)``."""

    def forward(self, x):
        return F.conv2d(
            x,
            binarize(self.weight),
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

    def reset_parameters(self):
        # Glorot-style initialisation: both channel counts are multiplied by
        # every kernel dimension before they enter the bound.
        kernel_area = 1
        for side in self.kernel_size:
            kernel_area *= side
        fan_total = (self.in_channels + self.out_channels) * kernel_area
        bound = math.sqrt(1.5 / fan_total)

        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

        # Stored on the parameter; no reader of lr_scale is visible in this file.
        self.weight.lr_scale = 1. / bound


# Transposed convolution layer whose weights are binarised in the forward pass
class BinaryConvTranspose2d(nn.ConvTranspose2d):
    """``nn.ConvTranspose2d`` that uses ``binarize(self.weight)``."""

    def forward(self, x):
        binary_weight = binarize(self.weight)
        return F.conv_transpose2d(
            x, binary_weight, self.bias,
            self.stride, self.padding, self.output_padding,
            self.groups, self.dilation,
        )

    def reset_parameters(self):
        # Glorot-style initialisation: scale each channel count by all kernel
        # dimensions, then derive the uniform bound from their sum.
        fan_in, fan_out = self.in_channels, self.out_channels
        for side in self.kernel_size:
            fan_in, fan_out = fan_in * side, fan_out * side
        bound = math.sqrt(1.5 / (fan_in + fan_out))

        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.zero_()

        # Stored on the parameter; no reader of lr_scale is visible in this file.
        self.weight.lr_scale = 1. / bound

### SegNet

In [0]:
# Downsampling (encoder) unit
class SegnetDownsampleUnit(nn.Module):
    """3x3 convolution -> batch normalisation -> activation.

    With ``binary=True`` the unit uses ``BinaryConv2d`` and ``BinaryTanh``;
    otherwise ``nn.Conv2d`` and ``nn.ReLU``.
    """

    def __init__(self, in_channels, out_channels, binary=False):
        super().__init__()

        if binary:
            conv_type, activation_type = BinaryConv2d, BinaryTanh
        else:
            conv_type, activation_type = nn.Conv2d, nn.ReLU

        # Sub-modules are registered in the order conv, activation, bn.
        self.conv = conv_type(in_channels, out_channels, kernel_size=3, padding=1)
        self.activation = activation_type()
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))


# Upsampling (decoder) unit
class SegnetUpsampleUnit(nn.Module):
    """3x3 transposed convolution -> optional batch norm -> activation.

    With ``binary=True`` the unit uses ``BinaryConvTranspose2d`` and
    ``BinaryTanh``; otherwise ``nn.ConvTranspose2d`` and ``nn.ReLU``.
    ``use_bn=False`` leaves ``self.bn`` as ``None`` and skips normalisation.
    """

    def __init__(self, in_channels, out_channels, binary=False, use_bn=True):
        super().__init__()

        if binary:
            conv_type, activation_type = BinaryConvTranspose2d, BinaryTanh
        else:
            conv_type, activation_type = nn.ConvTranspose2d, nn.ReLU

        # Sub-modules are registered in the order conv_tr, activation, bn.
        self.conv_tr = conv_type(in_channels, out_channels, kernel_size=3, padding=1)
        self.activation = activation_type()
        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = None

    def forward(self, x):
        x = self.conv_tr(x)
        if self.bn is not None:
            x = self.bn(x)
        return self.activation(x)

In [0]:
# SegNet architecture that mixes regular and binarised layers.
class SegNet(nn.Module):
    """Encoder/decoder network built from Segnet down- and upsample units.

    The encoder has five stages, each ending in a 2x2 max-pooling that keeps
    its indices; the decoder undoes the poolings in reverse order with
    ``max_unpool2d`` and runs the matching upsample units.
    """

    def __init__(self, input_channels, output_channels, name='segnet'):
        super().__init__()

        self.input_channels = input_channels
        self.output_channels = output_channels
        self.name = name
        self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

        # Channel widths used by the stages: 32, 64, 128 and 256.
        width_0 = 32
        width_1 = width_0 * 2
        width_2 = width_1 * 2
        width_3 = width_1 * 4

        # Encoder units, declared in execution order.
        self.ds_00 = SegnetDownsampleUnit(self.input_channels, width_0)
        self.ds_01 = SegnetDownsampleUnit(width_0, width_0)

        self.ds_10 = SegnetDownsampleUnit(width_0, width_1)
        self.ds_11 = SegnetDownsampleUnit(width_1, width_1, binary=True)

        self.ds_20 = SegnetDownsampleUnit(width_1, width_2)
        self.ds_21 = SegnetDownsampleUnit(width_2, width_2)
        self.ds_22 = SegnetDownsampleUnit(width_2, width_2, binary=True)

        self.ds_30 = SegnetDownsampleUnit(width_2, width_3)
        self.ds_31 = SegnetDownsampleUnit(width_3, width_3, binary=True)
        self.ds_32 = SegnetDownsampleUnit(width_3, width_3, binary=True)

        self.ds_40 = SegnetDownsampleUnit(width_3, width_3)
        self.ds_41 = SegnetDownsampleUnit(width_3, width_3, binary=True)
        self.ds_42 = SegnetDownsampleUnit(width_3, width_3, binary=True)

        # Decoder units, declared in execution order.
        self.us_42 = SegnetUpsampleUnit(width_3, width_3, binary=True)
        self.us_41 = SegnetUpsampleUnit(width_3, width_3, binary=True)
        self.us_40 = SegnetUpsampleUnit(width_3, width_3)

        self.us_32 = SegnetUpsampleUnit(width_3, width_3, binary=True)
        self.us_31 = SegnetUpsampleUnit(width_3, width_3, binary=True)
        self.us_30 = SegnetUpsampleUnit(width_3, width_2)

        self.us_22 = SegnetUpsampleUnit(width_2, width_2, binary=True)
        self.us_21 = SegnetUpsampleUnit(width_2, width_2)
        self.us_20 = SegnetUpsampleUnit(width_2, width_1)

        self.us_11 = SegnetUpsampleUnit(width_1, width_1)
        self.us_10 = SegnetUpsampleUnit(width_1, width_0)

        self.us_01 = SegnetUpsampleUnit(width_0, width_0)
        self.us_00 = SegnetUpsampleUnit(width_0, output_channels, use_bn=False)

    def forward(self, x):
        """Return ``(x, softmax)``: the output of the last unit and its
        softmax over dimension 1."""
        encoder_stages = (
            (self.ds_00, self.ds_01),
            (self.ds_10, self.ds_11),
            (self.ds_20, self.ds_21, self.ds_22),
            (self.ds_30, self.ds_31, self.ds_32),
            (self.ds_40, self.ds_41, self.ds_42),
        )
        decoder_stages = (
            (self.us_42, self.us_41, self.us_40),
            (self.us_32, self.us_31, self.us_30),
            (self.us_22, self.us_21, self.us_20),
            (self.us_11, self.us_10),
            (self.us_01, self.us_00),
        )

        # Encoder: remember, per stage, the pooling indices and the shape the
        # tensor had on entry to the stage (used as the unpooling target size).
        pooled = []
        for stage in encoder_stages:
            entry_shape = x.shape
            for unit in stage:
                x = unit(x)
            x, pool_indices = F.max_pool2d(x, kernel_size=2, stride=2, return_indices=True)
            pooled.append((pool_indices, entry_shape))

        # Decoder: consume the remembered poolings last-in, first-out.
        for stage in decoder_stages:
            pool_indices, entry_shape = pooled.pop()
            x = F.max_unpool2d(x, pool_indices, kernel_size=2, stride=2, output_size=entry_shape)
            for unit in stage:
                x = unit(x)

        return x, F.softmax(x, dim=1)

    # Load model weights from a file
    def load(self, path):
        """Move the model to ``self.device`` and load the state dict at ``path``."""
        self.to(self.device)
        state = torch.load(path, map_location=self.device)
        self.load_state_dict(state)

### Функции оценки качества модели

In [0]:
def pixel_accuracy(true, pred):
    """Return the fraction of elements where ``pred`` equals ``true``.

    Args:
        true: tensor of target labels.
        pred: tensor of predicted labels, same shape as ``true``.

    Returns:
        Number of matching positions divided by the total number of
        positions, as a NumPy float (NaN for empty input).
    """
    true = true.detach().cpu().numpy()
    pred = pred.detach().cpu().numpy()
    # Count agreeing positions. Summing the label *values* at those positions
    # would ignore every correctly predicted pixel whose label is 0.
    return np.mean(true == pred)


def dice_coefficient(true, pred):
    """Return the Dice coefficient of label 1, averaged over the batch.

    Args:
        true: tensor of target labels; dimension 0 indexes the samples.
        pred: tensor of predicted labels, same shape as ``true``.

    Returns:
        Mean over samples of ``(2*TP + eps) / (2*TP + FP + FN + eps)``.
        ``eps`` keeps the ratio defined when a sample has no label-1 pixels
        in either mask; such a sample scores 1.0 instead of NaN.
    """
    eps = 1e-6

    true = true.detach().cpu().numpy()
    pred = pred.detach().cpu().numpy()

    true_fg = true == 1
    pred_fg = pred == 1
    # Reduce over every dimension except the batch one.
    per_sample = tuple(range(1, true.ndim))

    tp = np.sum(true_fg & pred_fg, axis=per_sample)
    fp = np.sum(~true_fg & pred_fg, axis=per_sample)
    fn = np.sum(true_fg & ~pred_fg, axis=per_sample)

    dices = (2 * tp + eps) / (2 * tp + fp + fn + eps)
    return np.mean(dices)

### Функции выполнения модели

In [0]:
# Save images that show the model's predictions next to the ground truth
def save_results(softmaxed, image_batch, mask_batch, path, name):
    """Write one two-panel PNG per sample of the batch.

    The left panel shows the input image overlaid (via ``put_mask``) with the
    arg-max of the sample's ``softmaxed`` scores, the right panel the same
    image overlaid with the target mask. Files are saved to
    ``path`` as ``<name>_id:<index>.png``.
    """
    for idx, class_scores in enumerate(softmaxed):
        # Channel-first tensor -> channel-last array for plotting.
        picture = image_batch[idx].detach().cpu().numpy().transpose((1, 2, 0))
        truth = mask_batch[idx].detach().cpu().numpy()
        guess = class_scores.detach().cpu().numpy().argmax(axis=0)

        fig = plt.figure()
        panels = ((guess, "Predicted"), (truth, "Ground truth"))
        for position, (mask, title) in enumerate(panels, start=1):
            plot = fig.add_subplot(1, 2, position)
            # put_mask writes into its argument, hence the fresh copy per panel.
            plt.imshow(put_mask(picture.copy(), mask))
            plot.set_title(title)

        fig.savefig(os.path.join(path, name + f'_id:{idx}.png'))
        plt.close()

In [0]:
# Model training function
def train(net, data_loader, n_epochs, lr, class_weights, verbose=0):
    """Train ``net`` with Adam and class-weighted cross-entropy.

    Args:
        net: module exposing ``device`` and ``name`` attributes whose forward
            pass returns a pair; the first element is fed to the loss.
        data_loader: iterable of ``(inputs, targets)`` batches that also
            provides ``batch_size`` and ``len()``.
        n_epochs: number of passes over ``data_loader``.
        lr: learning rate for Adam.
        class_weights: tensor; its element-wise reciprocal is used as the
            per-class weight of ``nn.CrossEntropyLoss``.
        verbose: 0 prints only the total time, any truthy value adds a
            per-epoch summary, 2 additionally prints per-batch progress.

    Side effects:
        After every epoch the state dict is saved to ``"<net.name>.pth"``.
    """
    net.to(net.device)
    # Make sure layers such as batch normalisation run in training mode even
    # if the model was previously switched to evaluation mode.
    net.train()
    optimizer = optim.Adam(net.parameters(), lr=lr)
    loss_func = nn.CrossEntropyLoss(1.0 / class_weights).to(net.device)

    checkpoint_path = f"{net.name}.pth"
    total_to_process = len(data_loader) * data_loader.batch_size

    training_time = time.time()

    for epoch in range(n_epochs):
        epoch_time = time.time()

        train_loss = 0.0
        processed = 0
        for X_batch, y_batch in data_loader:
            image_batch = X_batch.to(net.device)
            mask_batch = y_batch.to(net.device)

            output_batch, _ = net(image_batch)

            optimizer.zero_grad()
            loss = loss_func(output_batch, mask_batch)
            loss.backward()
            optimizer.step()

            # .item() yields a plain float, so the running sum does not keep
            # every batch's autograd graph alive.
            train_loss += loss.item()
            processed += data_loader.batch_size
            if verbose == 2:
                processed_percent = round(100.0 * processed / total_to_process, ndigits=3)
                print(f"Эпоха {epoch}: {processed_percent}%,"
                      f"loss: {round(train_loss / processed, 5)}")

        epoch_time = time.time() - epoch_time
        if verbose:
            # max(..., 1) avoids a division by zero when the loader is empty.
            print(f"Окончание эпохи. Эпоха: {epoch}, "
                  f"train_loss: {round(train_loss / max(processed, 1), 5)}; {epoch_time} сек.; "
                  f"Сохранение...")

        # torch.save opens the file itself; no surrounding open() is needed.
        torch.save(net.state_dict(), checkpoint_path)

    print(f"Время обучения: {time.time() - training_time} сек.")

In [0]:
# Prediction / evaluation function
def predict(net, data_loader, class_weights, save_prediction_images=False):
    """Evaluate ``net`` on ``data_loader`` without computing gradients.

    Args:
        net: module exposing a ``device`` attribute whose forward pass returns
            ``(output, softmaxed)``.
        data_loader: iterable of ``(inputs, targets)`` batches.
        class_weights: tensor; its element-wise reciprocal is used as the
            per-class weight of ``nn.CrossEntropyLoss``.
        save_prediction_images: when true, ``save_results`` is called for
            every batch with ``Config.OUTPUT_DIR`` as the target directory.

    Returns:
        Tuple ``(accuracy, dice, loss)``: the means over all batches of
        ``pixel_accuracy``, ``dice_coefficient`` and the cross-entropy loss.
    """
    loss_func = nn.CrossEntropyLoss(1.0 / class_weights).to(net.device)

    # Tracked metrics, one entry per batch
    accs = []
    dices = []
    losses = []

    # Evaluate in eval mode and restore the caller's mode afterwards.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch_id, (X_batch, y_batch) in enumerate(data_loader):
                image_batch = X_batch.to(net.device)
                mask_batch = y_batch.to(net.device)

                output_batch, softmaxed = net(image_batch)
                # No squeeze(): it would drop the batch dimension of a
                # single-sample batch and break the loss computation.
                loss = loss_func(output_batch, mask_batch)

                if save_prediction_images:
                    save_results(softmaxed, image_batch, mask_batch,
                                 path=Config.OUTPUT_DIR, name=f'batch:{batch_id}')

                predicted = softmaxed.argmax(dim=1)
                accs.append(pixel_accuracy(mask_batch, predicted))
                dices.append(dice_coefficient(mask_batch, predicted))
                # Plain float, so np.mean also works for losses computed on GPU.
                losses.append(loss.item())
    finally:
        net.train(was_training)

    return np.mean(accs), np.mean(dices), np.mean(losses)

### Набор данных

In [0]:
# Dataset class
class CocoPersons_Segmentation(torch.utils.data.Dataset):
    """Image/mask pairs read from ``<dataset_dir>/<train|test>/{images,masks}/``.

    Sample names are derived from the mask file names: the text after the
    last ``_`` and before the first following ``.``. For a name ``N`` the
    mask is ``seg_N.png`` and the image is ``N.jpg``.
    """

    def __init__(self, dataset_dir, train=True):
        super(CocoPersons_Segmentation, self).__init__()

        set_ = 'train/' if train else 'test/'
        self.images_dir = os.path.join(dataset_dir, set_, 'images/')
        self.masks_dir = os.path.join(dataset_dir, set_, 'masks/')

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> sample mapping reproducible between runs.
        self.masks_files = sorted(os.listdir(self.masks_dir))
        self.file_names = [mask.split('_')[-1].split('.')[0] for mask in self.masks_files]

    def __len__(self):
        return len(self.file_names)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` as a ``FloatTensor`` and a ``LongTensor``."""
        name = self.file_names[idx]
        mask_path = os.path.join(self.masks_dir, 'seg_' + name + '.png')
        image_path = os.path.join(self.images_dir, name + '.jpg')

        image = torch.FloatTensor(self.__load_image(image_path))
        mask = torch.LongTensor(self.__load_mask(mask_path))
        return image, mask

    def __load_image(self, path):
        """Load an image as a float32 array of shape (3, HEIGHT, WIDTH) in [0, 1]."""
        # The context manager closes the file; convert('RGB') gives three
        # channels for greyscale and other non-RGB files alike.
        with Image.open(path) as img:
            img = img.convert('RGB').resize((Config.WIDTH, Config.HEIGHT))

        image = np.asarray(img, dtype=np.float32).transpose((2, 0, 1)) / 255.0
        return image

    def __load_mask(self, path):
        """Load a mask, resize it and scale its values by 1/255."""
        # Nearest-neighbour resampling never blends label values, unlike an
        # interpolating filter.
        # NOTE(review): assumes foreground pixels are stored as 255, so that
        # they become 1 after scaling and the LongTensor cast — confirm.
        with Image.open(path) as mask:
            mask = mask.resize((Config.WIDTH, Config.HEIGHT), resample=Image.NEAREST)
        mask = np.array(mask, dtype=np.uint8) / 255.0
        return mask

### Обучение модели

In [0]:
# Select the training preset (entry 3 of the Config option list)
train_options = Config.get_option(3)

In [0]:
# Build the dataset and a shuffling DataLoader over the training data
coco_train_set = CocoPersons_Segmentation(dataset_dir=Config.DATA_DIR, train=True)
train_loader = torch.utils.data.DataLoader(
    coco_train_set,
    batch_size=train_options['batch_size'],
    shuffle=True,
    num_workers=4
)

In [0]:
# Create the model; train() saves its weights as '<name>.pth'
model = SegNet(input_channels=3, output_channels=Config.NUM_CLASSES, name='bin_SegNet0')

# Run the training with the selected preset
train(
    model, 
    data_loader=train_loader, 
    n_epochs=train_options['n_epochs'],
    lr=train_options['lr'], 
    class_weights=Config.TRAIN_CLASS_PROBS,
    verbose=1
)

### Предсказание на тренировочных данных

In [0]:
# Load the trained model from the checkpoint file
model = SegNet(input_channels=3, output_channels=Config.NUM_CLASSES)
model.load('bin_SegNet0.pth')

In [26]:
# Evaluate on the training data. The loss is weighted with the training-set
# class probabilities, the same ones used for the training loss.
accuracy, dice_coef, ce_loss = predict(
    model,
    train_loader,
    class_weights=Config.TRAIN_CLASS_PROBS
)
print(f"Train Accuracy: {accuracy}, Train Dice coefficient: {dice_coef}, CE loss: {ce_loss}")

Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f8e74ff3240>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 941, in _shutdown_workers
    w.join()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 122, in join
    assert self._parent_pid == os.getpid(), 'can only join a child process'
AssertionError: can only join a child process
Exception ignored in: <bound method _MultiProcessingDataLoaderIter.__del__ of <torch.utils.data.dataloader._MultiProcessingDataLoaderIter object at 0x7f8e74ff3240>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 961, in __del__
    self._shutdown_workers()
  File "/usr/local/l

KeyboardInterrupt: ignored

### Предсказание на тестовых данных

In [0]:
# Same steps for the test data: take the last preset and build a DataLoader
test_options = Config.get_option(-1)

coco_test_set = CocoPersons_Segmentation(dataset_dir=Config.DATA_DIR, train=False)
test_loader = torch.utils.data.DataLoader(
    coco_test_set,
    batch_size=test_options['batch_size'],
    shuffle=True,
    num_workers=4
)

In [0]:
# Evaluate on the test data. predict() returns three values (accuracy, Dice
# coefficient, cross-entropy loss), so all three are unpacked here.
accuracy, dice_coef, ce_loss = predict(
    model,
    test_loader,
    class_weights=Config.TEST_CLASS_PROBS,
    save_prediction_images=False
)
print(f"Test Accuracy: {accuracy}, Test Dice coefficient: {dice_coef}, CE loss: {ce_loss}")