# Импорты

In [10]:

# массивы, рандом
import random
import pandas as pd
import numpy as np

# файлы
from os.path import join as pjoin
import os
import json

import seaborn as sns

# модельки
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils

# torch
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

# аугментация
import albumentations as album
import torchvision.transforms as transforms

# отображение
# from tqdm import tqdm
from tqdm.notebook import tqdm
import cv2
import matplotlib.pyplot as plt
import torchvision
import torchinfo
from torch.utils.tensorboard import SummaryWriter
from board import uniqufy_path, create_image_plot
# writer = SummaryWriter()
# writer.add_scalar("Loss/train", loss, epoch)

# метрики
from torchmetrics.classification import BinaryJaccardIndex

# оптимизаторы, изменение lr
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingWarmRestarts

# лосы
from segmentation_models_pytorch.losses import LovaszLoss, DiceLoss


# Константы

In [11]:
# Dataset root, class table and output folders for predictions.
DATA_DIR = 'data/tiff/'
DATA_CLASSES = "data/label_class_dict.csv"
PRED_TEST = 'pred/test/'
PRED_VALID = 'pred/valid/'

# Per-channel statistics passed to album.Normalize (mean / std).
MEAN_IMAGE_TRANSFORM = [0.4363, 0.4328, 0.3291]
MEAN_IMAGE_STD = [0.2129, 0.2075, 0.2038]
# Side length of the random training crop, and the training batch size.
SIZE_IMAGE = 256
BATCH_SIZE = 32

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

# BOARD = SummaryWriter()

# Параметры

In [12]:
# Image / label folders for the train, validation and test splits.
x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')

x_valid_dir = os.path.join(DATA_DIR, 'val')
y_valid_dir = os.path.join(DATA_DIR, 'val_labels')

x_test_dir = os.path.join(DATA_DIR, 'test')
y_test_dir = os.path.join(DATA_DIR, 'test_labels')

# Class names and their RGB colours. The path comes from DATA_CLASSES so the
# location of the class table is defined in exactly one place.
class_dict = pd.read_csv(DATA_CLASSES)
CLASSES = class_dict['name'].tolist()
CLASSES_RGB = class_dict[['r', 'g', 'b']].values.tolist()

print('Классы: ', CLASSES)
print('Классы RGB значений: ', CLASSES_RGB)

# Experiment configurations (one entry per training run) loaded from JSON.
with open('data/parametrs.json', 'r', encoding='utf-8') as f:
    PARAMETERS = json.load(f)

Классы:  ['background', 'road']
Классы RGB значений:  [[0, 0, 0], [255, 255, 255]]


# Функции

In [13]:
# ОТОБРАЖЕНИЕ
def print_image(**images):
    """Show the given images side by side in one row.

    Every keyword argument becomes one panel: the keyword (underscores
    replaced by spaces, title-cased) is used as the panel title and the value
    is handed to ``plt.imshow``. Axis ticks are hidden.
    """
    total = len(images)
    plt.figure(figsize=(16, 4))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, total, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title(), fontsize=20)
        plt.imshow(picture)
    plt.show()

# ПРЕОБРАЗОВАНИЯ


def one_hot_encode(label, label_values):
    """Turn a colour-coded label image into a stack of per-class boolean maps.

    For every colour in ``label_values`` a map is built that is True where all
    components along the last axis of ``label`` equal that colour. The maps
    are stacked along a new last axis, in the order of ``label_values``.
    """
    class_maps = [
        np.all(np.equal(label, colour), axis=-1) for colour in label_values
    ]
    return np.stack(class_maps, axis=-1)


def reverse_one_hot(image):
    """Collapse the last axis of ``image`` to the index of its largest entry."""
    return np.argmax(image, axis=-1)


def colour_code_segmentation(image, label_values):
    """Map a class-index image to colours.

    ``label_values`` is converted to an array and indexed with ``image`` cast
    to int, so each index is replaced by the matching row of ``label_values``.
    """
    palette = np.array(label_values)
    class_indices = image.astype(int)
    return palette[class_indices]


def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose(2, 0, 1)
    return channels_first.astype('float32')


def prepare_to_network():
    """Return an albumentations ``Lambda`` applying ``to_tensor`` to image and mask."""
    to_network_layout = album.Lambda(image=to_tensor, mask=to_tensor)
    return to_network_layout


def transform_train():
    """Build the training augmentation pipeline.

    Steps, in order: a random SIZE_IMAGE x SIZE_IMAGE crop; with probability
    0.75 one of horizontal flip, vertical flip or a random 90-degree rotation;
    normalisation with MEAN_IMAGE_TRANSFORM / MEAN_IMAGE_STD.
    """
    crop = album.RandomCrop(
        height=SIZE_IMAGE, width=SIZE_IMAGE, always_apply=True)
    flip_or_rotate = album.OneOf(
        [
            album.HorizontalFlip(p=1),
            album.VerticalFlip(p=1),
            album.RandomRotate90(p=1),
        ],
        p=0.75,
    )
    normalize = album.Normalize(
        mean=MEAN_IMAGE_TRANSFORM, std=MEAN_IMAGE_STD, always_apply=True)
    return album.Compose([crop, flip_or_rotate, normalize])


def transform_test():
    """Build the deterministic pipeline used for validation and test images.

    Images are padded up to at least 1536 x 1536 (``border_mode=0``) and then
    normalised with the same MEAN_IMAGE_TRANSFORM / MEAN_IMAGE_STD statistics
    that ``transform_train`` applies. Without the normalisation step the
    network would be evaluated on raw pixel values while having been trained
    on normalised ones.
    """
    test_transform = album.Compose([
        album.PadIfNeeded(min_height=1536, min_width=1536,
                          always_apply=True, border_mode=0),
        # Keep evaluation inputs on the same scale as the training inputs.
        album.Normalize(mean=MEAN_IMAGE_TRANSFORM,
                        std=MEAN_IMAGE_STD, always_apply=True),
    ])
    return test_transform

# ПОЛУЧЕНИЕ ПАРАМЕТРОВ ДЛЯ ОБУЧЕНИЯ


def get_model(param, encoder, encoder_weights, classes, activation):
    """Create the segmentation network.

    ``param`` is the requested architecture name. Only ``smp.Unet`` is wired
    up, so every value of ``param`` (``'unet'`` included) yields a U-Net with
    the given encoder, encoder weights, ``len(classes)`` output channels and
    activation.
    """
    unet_settings = {
        'encoder_name': encoder,
        'encoder_weights': encoder_weights,
        'classes': len(classes),
        'activation': activation,
    }
    # `param` does not change the result yet: U-Net is both the explicit
    # choice and the fallback.
    return smp.Unet(**unet_settings)


def get_function_loss(param):
    """Return ``DiceLoss`` for ``'DiceLoss'``, otherwise ``LovaszLoss`` (both in binary mode)."""
    loss_class = DiceLoss if param == 'DiceLoss' else LovaszLoss
    return loss_class(mode='binary')


def get_optimizer(param, model, lr, weight_decay: float = 0.1, momentum: float = 0.9):
    """Create the optimizer for ``model``.

    Args:
        param: Optimizer name. ``'Adam'`` selects Adam; anything else selects SGD.
        model: Module whose ``parameters()`` are optimised.
        lr: Learning rate.
        weight_decay: Weight decay, applied to Adam only.
        momentum: Momentum, applied to SGD only.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.
    """
    if param == 'Adam':
        return torch.optim.Adam(
            params=model.parameters(), lr=lr, weight_decay=weight_decay)

    # Honour the caller's `momentum` instead of a hard-coded 0.9 (the default
    # is still 0.9, so existing calls behave as before).
    # NOTE(review): `weight_decay` is not forwarded to SGD — confirm whether
    # that is intentional before changing it.
    return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)


def get_scheduler(param, optimizer):
    """Create the learning-rate scheduler described by ``param``.

    When ``param['name']`` is ``'ReduceLROnPlateau'`` the scheduler is built in
    ``'min'`` mode from the ``patience``, ``threshold``, ``cooldown`` and
    ``factor`` entries of ``param``. Any other name yields
    ``CosineAnnealingWarmRestarts`` with fixed settings.
    """
    if param['name'] != 'ReduceLROnPlateau':
        return CosineAnnealingWarmRestarts(
            optimizer, T_0=1, T_mult=2, eta_min=5e-5,
        )

    plateau_settings = {
        'patience': param['patience'],
        'threshold': param['threshold'],
        'cooldown': param['cooldown'],
        'factor': param['factor'],
    }
    return ReduceLROnPlateau(optimizer, 'min', **plateau_settings)


def get_metric(param, iou, acc, dice):
    """Pick one of the three metric values by name.

    ``param`` is compared against ``'iou'``, ``'acc'`` and ``'dice'`` in that
    order; the matching value is returned, or ``0`` for any other name.
    """
    for metric_name, metric_value in (('iou', iou), ('acc', acc), ('dice', dice)):
        if param == metric_name:
            return metric_value
    return 0

def dice_score(pred: torch.Tensor, mask: torch.Tensor):
    """Return the Dice coefficient ``2 * sum(pred * mask) / (sum(pred) + sum(mask))``.

    Args:
        pred: Prediction tensor.
        mask: Reference tensor, multiplied element-wise with ``pred``.

    Returns:
        The coefficient as a Python float. When the denominator is zero (for
        example both tensors are all zeros) ``1.0`` is returned instead of
        NaN, treating two empty masks as a perfect match so that averages
        over batches stay finite.
    """
    intersection = torch.sum(pred * mask)
    denominator = torch.sum(pred) + torch.sum(mask)
    if denominator.item() == 0:
        return 1.0
    return ((2. * intersection) / denominator).item()

def pixel_accuracy(pred: torch.Tensor, mask: torch.Tensor):
    """Return the fraction of elements where ``pred`` equals ``mask`` exactly."""
    matches = torch.eq(pred, mask)
    matched_count = matches.sum().item()
    return float(matched_count) / float(matches.numel())

# САМО ОБУЧЕНИЕ

def train_step(model, criterion, optimizer, dataloader, epoch, epochs):
    """Run one training epoch and return the mean batch loss.

    Args:
        model: Network to train; switched to train mode.
        criterion: Loss called as ``criterion(output, labels)``.
        optimizer: Optimizer stepped once per batch.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        epoch: Current epoch index (not used here).
        epochs: Total number of epochs (not used here).

    Returns:
        The sum of the batch losses divided by ``len(dataloader)``, as a float.
    """
    model.train()
    running_loss = 0.

    for images, labels in dataloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        output = model(images)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Accumulate a plain float: adding the loss tensor itself would keep
        # every iteration's autograd history referenced until the epoch ends.
        running_loss += loss.item()

    return running_loss / len(dataloader)


def valid_step(model, criterion, dataloader, epoch, BOARD):
    """Evaluate the model on ``dataloader`` and log one figure per batch.

    Args:
        model: Network to evaluate; switched to eval mode.
        criterion: Loss called as ``criterion(output, labels)``.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        epoch: Step under which the figures are logged.
        BOARD: Writer exposing ``add_figure`` (a ``SummaryWriter`` in this file).

    Returns:
        Tuple ``(mean loss, IoU, mean dice, mean pixel accuracy)``. IoU is
        accumulated over all batches by ``BinaryJaccardIndex``; the other
        three are per-batch values averaged over ``len(dataloader)``.
    """
    model.eval()
    running_loss = 0.

    # BinaryJaccardIndex does not define a `num_classes` parameter (the task
    # is binary by construction), so it is created without one.
    iou = BinaryJaccardIndex()
    iou.to(DEVICE)

    dice = acc = 0

    with torch.no_grad():
        for i, (images, labels) in enumerate(dataloader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            output = model(images)

            # NOTE(review): dice and pixel accuracy are computed on the raw
            # network output, not on a thresholded mask — confirm intended.
            iou(output, labels)
            dice += dice_score(output, labels)
            acc += pixel_accuracy(output, labels)

            # Plain float accumulation; no tensor needs to outlive the batch.
            running_loss += criterion(output, labels).item()

            # to_tensor stored the arrays channel-first via transpose(2, 0, 1);
            # transpose(1, 2, 0) is its inverse. The previous (2, 1, 0) also
            # swapped height and width, so the logged images were transposed.
            origin = images[0].cpu().numpy().transpose(1, 2, 0)
            true_mask = labels[0].cpu().numpy().transpose(1, 2, 0)
            pred_mask = output[0].cpu().numpy().transpose(1, 2, 0)

            BOARD.add_figure('valid_sample_' + str(i), create_image_plot(
                origin=origin,
                true=colour_code_segmentation(
                    reverse_one_hot(true_mask), CLASSES_RGB),
                pred=colour_code_segmentation(
                    reverse_one_hot(pred_mask), CLASSES_RGB)),
                epoch)

        dice_value = dice / len(dataloader)
        acc_value = acc / len(dataloader)

        valid_loss = running_loss / len(dataloader)
        return valid_loss, iou.compute().item(), dice_value, acc_value


# Классы

In [14]:
class RoadsDataset(Dataset):
    """Dataset of image / mask pairs read from two directories.

    The files of ``images_dir`` and ``masks_dir`` are listed in sorted order
    and paired by position.

    Args:
        images_dir: Directory with the input images.
        masks_dir: Directory with the colour-coded masks.
        class_rgb_values: RGB colour of each class, used to one-hot encode masks.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` returning a dict with
            ``'image'`` and ``'mask'``.
        preprocessing: Optional callable with the same calling convention,
            applied after ``augmentation``.
    """

    def __init__(
            self,
            images_dir,
            masks_dir,
            class_rgb_values=None,
            augmentation=None,
            preprocessing=None,
    ):
        # NOTE(review): pairing relies on both listings having the same
        # length and order — verify against the data layout.
        self.image_paths = [os.path.join(images_dir, image_id) for image_id in sorted(os.listdir(images_dir))]
        self.mask_paths = [os.path.join(masks_dir, image_id) for image_id in sorted(os.listdir(masks_dir))]

        self.class_rgb_values = class_rgb_values
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it in RGB channel order.

        Raises:
            OSError: If ``cv2.imread`` returns ``None`` for the path (it does
                not raise by itself when a file cannot be read).
        """
        bgr = cv2.imread(path)
        if bgr is None:
            raise OSError(f"cv2.imread could not read image: {path!r}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __getitem__(self, i):
        """Return the ``(image, mask)`` pair at index ``i`` after all transforms."""
        image = self._read_rgb(self.image_paths[i])
        mask = self._read_rgb(self.mask_paths[i])

        mask = one_hot_encode(mask, self.class_rgb_values).astype('float')

        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']

        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        """Return the number of image files."""
        return len(self.image_paths)

# Загрузка валидации и теста

In [15]:
# Validation and test data go through the same deterministic pipeline:
# transform_test() as augmentation, prepare_to_network() as preprocessing.
valid_dataset = RoadsDataset(
    x_valid_dir,
    y_valid_dir,
    class_rgb_values=CLASSES_RGB,
    augmentation=transform_test(),
    preprocessing=prepare_to_network(),
)
test_dataset = RoadsDataset(
    x_test_dir,
    y_test_dir,
    class_rgb_values=CLASSES_RGB,
    augmentation=transform_test(),
    preprocessing=prepare_to_network(),
)

# One image per batch, loaded in the main process.
valid_dataloader = DataLoader(valid_dataset, batch_size=1, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=1, num_workers=0)


# ОБУЧЕНИЕ

In [16]:
# Number of training epochs per configuration and DataLoader worker count.
EPOCHS = 12
NUM_WORKERS = 0
# mobilenet_v2

# http://84.237.51.132:9032/
# SECURITY: a container password and a login token used to be written here in
# plain text. They were removed from the source; treat them as leaked, rotate
# them, and keep credentials outside the notebook (environment variable or
# secrets store).

In [17]:
# exec(PARAMETERS["paramerts"][0]['func'])

In [18]:
# Train and evaluate one model per configuration in PARAMETERS["paramerts"].
# Each run gets its own TensorBoard directory under res/.
param_grid = PARAMETERS["paramerts"]
for (index_param, param) in tqdm(enumerate(param_grid), total=len(param_grid)):
    # Named `run_id` rather than `id` so the builtin id() is not shadowed.
    run_id = f"{param['model']}_1_{param['encoder']}_{param['learning_rate']}_{param['metric']}_{param['optimizer']}_{param['scheduler']['name']}_{param['loss']}"
    TBpath = uniqufy_path('res/' + run_id)
    BOARD = SummaryWriter(TBpath)

    # try/finally so the writer is flushed and closed even if a run fails.
    try:
        LEARNING_RATE = param['learning_rate']

        # NOTE(review): nn.ReLU as the final activation together with the
        # 'binary' losses looks unusual — confirm this is intended.
        ACTIVATION = nn.ReLU

        # Model
        model = get_model(param['model'], param['encoder'],
                          'imagenet', CLASSES, ACTIVATION)
        model = model.to(DEVICE)

        # Training data
        train_dataset = RoadsDataset(x_train_dir, y_train_dir,
                                     class_rgb_values=CLASSES_RGB, augmentation=transform_train(), preprocessing=prepare_to_network())
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=NUM_WORKERS,
        )

        # Loss, optimizer and LR scheduler for this configuration
        criterion = get_function_loss(param['loss'])
        optimizer = get_optimizer(
            param['optimizer'], model, LEARNING_RATE, weight_decay=param['weight_decay'])
        scheduler = get_scheduler(param['scheduler'], optimizer)

        # Number of training batches per epoch (kept for plotting)
        len_steps = len(train_dataloader)

        # Training
        pbar = tqdm(range(EPOCHS))
        for epoch in pbar:
            for i, param_group in enumerate(optimizer.param_groups):
                BOARD.add_scalar('learning rate', float(param_group['lr']), epoch)

            train_loss = train_step(
                model, criterion, optimizer, train_dataloader, epoch, EPOCHS)
            valid_loss, metric_iou, metric_dice, metric_acc = valid_step(
                model, criterion, valid_dataloader, epoch, BOARD)

            monitored = get_metric(
                param["metric"], metric_iou, metric_acc, metric_dice)

            if isinstance(scheduler, ReduceLROnPlateau):
                # The plateau scheduler must watch a quantity that moves in
                # the direction of its mode. iou / acc / dice grow as the
                # model improves, so in 'min' mode step on the validation
                # loss; feeding the metric would cut the LR while improving.
                scheduler.step(
                    valid_loss if scheduler.mode == 'min' else monitored)
            else:
                # CosineAnnealingWarmRestarts.step() takes an optional epoch
                # index; a metric value passed there would be read as the
                # epoch, so it is stepped without arguments.
                scheduler.step()

            BOARD.add_scalar('loss_valid', valid_loss, epoch)
            BOARD.add_scalar('loss_train', train_loss, epoch)

            BOARD.add_scalar('metric_iou', metric_iou, epoch)
            BOARD.add_scalar('metric_dice', metric_dice, epoch)
            BOARD.add_scalar('metric_acc', metric_acc, epoch)

            pbar.set_description(
                f'{param["metric"]}: {monitored:.2f}  | train/valid loss: {train_loss:.4f}/{valid_loss:.4f}')

        # Log predictions on the test set. eval() is set explicitly instead
        # of relying on valid_step having left the model in eval mode.
        model.eval()
        with torch.no_grad():
            for i, (images, labels) in enumerate(test_dataloader):
                images = images.to(DEVICE)
                labels = labels.to(DEVICE)
                output = model(images)

                # transpose(1, 2, 0) undoes the channel-first layout made by
                # to_tensor; the previous (2, 1, 0) also swapped height and
                # width, so the logged images were transposed.
                origin = images[0].cpu().numpy().transpose(1, 2, 0)
                true_mask = labels[0].cpu().numpy().transpose(1, 2, 0)
                pred_mask = output[0].cpu().numpy().transpose(1, 2, 0)

                BOARD.add_figure('test_sample', create_image_plot(
                    origin=origin,
                    true=colour_code_segmentation(
                        reverse_one_hot(true_mask), CLASSES_RGB),
                    pred=colour_code_segmentation(
                        reverse_one_hot(pred_mask), CLASSES_RGB)),
                    i)
    finally:
        BOARD.close()


0it [00:00, ?it/s]

  0%|          | 0/15 [00:00<?, ?it/s]

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i