# Imports

In [1]:
import os
from os.path import join as pjoin

import cv2
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import ReduceLROnPlateau

import segmentation_models_pytorch as smp

import torchmetrics.classification as metrics

import torchvision
from torchvision import transforms

from torch.utils.tensorboard import SummaryWriter

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm
import torchinfo

import matplotlib.pyplot as plt

from additonFunc import uniqufy_path, create_image_plot, save_imgs

# Иницилизация ключевых значений

In [2]:
# LAUNCH_NAME = "MyUnet_FixedSet_default_augs"
LAUNCH_NAME = "resnet50_megaSet_default_augs_lr"


STARTING_EPOCH = 0
LOAD_WEIGHTS = None #
LOAD_ADAM_STATE = None #
USE_MANUAL_TENSORBOARD_FOLDER = None # "/home/sega/progs/AI_Tasks/02_Task/TB_cache/MyUnet_FixedSet_skip_connections_1" #

SAVED_MODEL_PATH = None # "/home/sega/progs/AI_Tasks/02_Task/TB_cache/MyUnet_FixedSet_skip_connections_1/weights_55.pth" #

EPOCHS = 15
LEARNING_RATE = 1E-5 # 0.0001 #1E-5 for resnet-50
WEIGHT_DECAY = 0 # 1E-7

BATCH_SIZE = 96 # 20

SAVE_METHOD = "TORCH" # "TORCH" / "ONNX"
WEIGHT_SAVER = "last" # "all" / "nothing" / "last"

CLASS_NAMES = ['other', 'road']
CLASS_RGB_VALUES = [[0,0,0], [255, 255, 255]]

NORMALIZE_MEAN_IMG =  [0.4295, 0.4325, 0.3961]       #[0.485, 0.456, 0.406]
NORMALIZE_DEVIATIONS_IMG =  [0.2267, 0.2192, 0.2240] #[0.229, 0.224, 0.225]
 
CROP_SIZE = (256, 256)

NUM_WORKERS = 20
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

DATASET_DIR = '/usr/src/app/roads_dataset_cropped/tiff'
VALID_SET   = (pjoin(DATASET_DIR, "val"), pjoin(DATASET_DIR, "val_labels"))
TEST_SET   =  (pjoin(DATASET_DIR, "test"), pjoin(DATASET_DIR, "test_labels"))
TRAIN_SET   = (pjoin(DATASET_DIR, "train"), pjoin(DATASET_DIR, "train_labels"))

trained = False

In [3]:
TBpath = uniqufy_path(f"TB_cache/{LAUNCH_NAME}") if USE_MANUAL_TENSORBOARD_FOLDER is None else USE_MANUAL_TENSORBOARD_FOLDER
TBwriter = SummaryWriter(TBpath)

In [None]:
"""
    Класс содержащий разные методы для поворта изображения
"""
class RotationMethods():

    __target_prefix = "_RotationMethods__rotate_"

    def get_method(self, method_name):
        methods = {key.removeprefix(self.__target_prefix):val for key, val in RotationMethods.__dict__.items() if key.startswith(self.__target_prefix)}
        target_method = methods.get(method_name, None)
        if target_method is None:
            raise AttributeError(f"Not found methods with name: '{method_name}'\n\tAvailable methods: {list(methods.keys())}")
        return target_method

    @staticmethod
    def __rotate_PIL(image, angle):
        i = PImage.fromarray(image)
        return np.array(i.rotate(angle, expand=True))

    @staticmethod
    def __rotate_rotMatrix(image, angle, scale = 1.0):
        (h, w) = image.shape[:2]
        center = (w//2, h//2)

        rotation_matrix = cv2.getRotationMatrix2D(center, angle, scale)

        rotated_img = cv2.warpAffine(image, rotation_matrix, (w, h))

        return rotated_img
        
    @staticmethod
    def __rotate_manualImpl_1(image, angle):
        (h, w) = image.shape[:2]
        angle = 360-angle

        angle_rad = np.deg2rad(angle)
        sin = np.sin(angle_rad)
        cos = np.cos(angle_rad)
        
        new_w = int(abs(h * sin) + abs(w * cos))
        new_h = int(abs(h * cos) + abs(w * sin))

        rotated_img = np.zeros((new_h, new_w, image.shape[2]), dtype=np.uint8)

        center = (w//2, h//2)
        new_center = (new_w // 2, new_h // 2)

        for i in range(new_h):
            for j in range(new_w):
                new_x = j - new_center[0]
                new_y = i - new_center[1]

                x = int(new_x * cos + new_y * sin)
                y = int(-new_x * sin + new_y * cos)

                x += center[0]
                y += center[1]

                if 0 <= x < w and 0 <= y < h:
                    rotated_img[i, j] = image[y, x]
        return rotated_img
    
    @staticmethod
    def __rotate_manualImpl_1vec1(image, angle):
        (h, w) = image.shape[:2]
        angle = np.deg2rad(360 - angle)

        sin = np.sin(angle)
        cos = np.cos(angle)

        new_w = int(abs(h * sin) + abs(w * cos))
        new_h = int(abs(h * cos) + abs(w * sin))

        rotated_img = np.zeros((new_h, new_w, image.shape[2]), dtype=np.uint8)

        center = np.array([w // 2, h // 2])
        new_center = np.array([new_w // 2, new_h // 2])

        new_x, new_y = np.meshgrid(np.arange(new_w), np.arange(new_h))
        new_x = new_x - new_center[0]
        new_y = new_y - new_center[1]

        x = np.round(new_x * cos + new_y * sin).astype(int)
        y = np.round(-new_x * sin + new_y * cos).astype(int)

        x += center[0]
        y += center[1]

        valid_indices = np.logical_and.reduce((x >= 0, x < w, y >= 0, y < h))
        rotated_img[new_y[valid_indices] + new_center[1], new_x[valid_indices] + new_center[0]] = image[y[valid_indices], x[valid_indices]]

        return rotated_img
        
    @staticmethod
    def __rotate_manualImpl_1vec2(image, angle):
        (h, w) = image.shape[:2]
        angle = np.deg2rad(360 - angle)

        sin = np.sin(angle)
        cos = np.cos(angle)

        new_w = int(abs(h * sin) + abs(w * cos))
        new_h = int(abs(h * cos) + abs(w * sin))

        rotated_img = np.zeros((new_h, new_w, image.shape[2]), dtype=np.uint8)

        center = np.array([w // 2, h // 2])
        new_center = np.array([new_w // 2, new_h // 2])

        new_x = np.arange(new_w) - new_center[0]
        new_y = np.arange(new_h) - new_center[1]
        new_x = new_x[:, np.newaxis]
        new_y = new_y[np.newaxis, :]

        x = np.round(new_x * cos + new_y * sin).astype(int)
        y = np.round(-new_x * sin + new_y * cos).astype(int)

        x += center[0]
        y += center[1]

        valid_indices = np.logical_and.reduce((x >= 0, x < w, y >= 0, y < h))
        valid_indices_2d = np.nonzero(valid_indices)

        valid_x = x[valid_indices]
        valid_y = y[valid_indices]
        valid_rotated_x = valid_indices_2d[1] + new_center[0]
        valid_rotated_y = valid_indices_2d[0] + new_center[1]

        mask = np.logical_and.reduce((valid_rotated_x >= 0, valid_rotated_x < new_w, valid_rotated_y >= 0, valid_rotated_y < new_h))
        valid_rotated_x = valid_rotated_x[mask]
        valid_rotated_y = valid_rotated_y[mask]
        valid_x = valid_x[mask]
        valid_y = valid_y[mask]

        rotated_img[valid_rotated_y, valid_rotated_x] = image[valid_y, valid_x]


        return rotated_img

    @staticmethod
    def __rotate_manualImpl_2(image, angle):
        (h, w) = image.shape[:2]

        angle_rad = np.deg2rad(angle)
        sin = np.sin(angle_rad)
        cos = np.cos(angle_rad)
        
        new_w = int(abs(h * sin) + abs(w * cos))
        new_h = int(abs(h * cos) + abs(w * sin))

        rotated_img = np.zeros((new_h, new_w, image.shape[2]), dtype=np.uint8)

        center = (w//2, h//2)
        new_center = (new_w // 2, new_h // 2)

        for i in range(h):
            for j in range(w):
                x = j-center[0]
                y = i-center[1]

                new_x = x * cos + y * sin
                new_y = -x * sin + y * cos

                new_x = int(new_center[0] + new_x)
                new_y = int(new_center[1] + new_y)

                if 0 <= new_x < new_w and 0 <= new_y < new_h:
                    rotated_img[new_y, new_x] = image[i,j]

        return rotated_img


In [None]:
class ImageRotater():

    def __init__(self, mode : str = 'PIL', angle = None):
        self.mode = mode
        self.angle = angle
        self.rotation_func = RotationMethods().get_method(mode)
        
    def __call__(self, image, angle = None, **kwargs):
        if self.angle is not None:
            angle = self.angle
        return self.rotation_func(image, angle, **kwargs)

    def forward(self, image, angle = None, **kwargs):
        if self.angle is not None:
            angle = self.angle
        return self(image, angle, **kwargs)
    
    def apply(self, image, angle = None, **kwargs):
        if self.angle is not None:
            angle = self.angle
        return self(image, angle, **kwargs)    

# Transform's 

In [4]:
def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

prepare_to_network = A.Lambda(image=to_tensor, mask=to_tensor)

train_transform = A.Compose(
    [
        A.OneOf(
            [
                A.HorizontalFlip(p=1),
                A.VerticalFlip(p=1),
                A.RandomRotate90(p=1),
            ],
            p=0.75,
        ),
        A.Normalize(mean=NORMALIZE_MEAN_IMG, std=NORMALIZE_DEVIATIONS_IMG, always_apply=True)
    ]
)


valid_transform = A.Compose(
    [
        A.Normalize(mean=NORMALIZE_MEAN_IMG, std=NORMALIZE_DEVIATIONS_IMG, always_apply=True),
    ]
)

In [5]:
def one_hot_encode(label, label_values):
    semantic_map = []
    for colour in label_values:
        equality = np.equal(label, colour)
        class_map = np.all(equality, axis = -1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1)
    return semantic_map

def reverse_one_hot(image):
    x = np.argmax(image, axis = -1)
    return x

def colour_code_segmentation(image, label_values):
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]
    return x

# Dataset's


In [6]:
class RoadsDataset(Dataset):
    def __init__(self, values_dir, labels_dir, class_rgb_values=None, transform=None, readyToNetwork=None):
        self.values_dir = values_dir
        self.labels_dir = labels_dir
        self.class_rgb_values = class_rgb_values
        self.images = [pjoin(self.values_dir, filename) for filename in sorted(os.listdir(self.values_dir))]
        self.labels = [pjoin(self.labels_dir, filename) for filename in sorted(os.listdir(self.labels_dir))]
        self.transform = transform
        self.readyToNetwork = readyToNetwork

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        label_path = self.labels[index]

        image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        label = cv2.cvtColor(cv2.imread(label_path), cv2.COLOR_BGR2RGB)
        label = one_hot_encode(label, self.class_rgb_values).astype('float')

        if self.transform:
            sample = self.transform(image=image, mask=label)
            image, label = sample['image'], sample['mask']
        if self.readyToNetwork:
            sample = self.readyToNetwork(image=image, mask=label)
            image, label = sample['image'], sample['mask']
        return image, label

In [7]:
sample_dataset = RoadsDataset(*TEST_SET,
                       class_rgb_values=CLASS_RGB_VALUES, transform=valid_transform)

for i in range(10):
    image, mask = sample_dataset[np.random.randint(0, len(sample_dataset))]
    TBwriter.add_figure(f'train samples', create_image_plot(origin=image, true=colour_code_segmentation(
        reverse_one_hot(mask), CLASS_RGB_VALUES)), global_step=i)
del(sample_dataset)

# Model

In [8]:
class ConvBlock(nn.Module):
    def __init__(self, inC : int, outC : int, kernel_size, **kwargs) -> None:
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.in_channels = inC
        self.out_channels = outC
        self.kernel_size = kernel_size
        self.conv = nn.Conv2d(inC, outC, kernel_size, **kwargs)
        self.bn = nn.BatchNorm2d(outC)
        self.activation = nn.ReLU()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        return x

class DeConvBlock(nn.Module):
    def __init__(self, inC : int, outC : int, kernel_size, **kwargs) -> None:
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.in_channels = inC
        self.out_channels = outC
        self.kernel_size = kernel_size
        self.conv = nn.ConvTranspose2d(inC, outC, kernel_size, **kwargs)
        self.bn = nn.BatchNorm2d(outC)
        self.activation = nn.ReLU(True)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.activation(x)

class ResidualBlock(nn.Module):
    def __init__(self, inC, outC = 0, interC = -1, block_expansion = 4, skip_connection = True):
        super().__init__()
        self.block_expansion = block_expansion

        self.downConv = None
        self.upConv = None
        self.skip_connection = None

        if interC > 0:
            self.downConv = ConvBlock(inC, interC, 1)
        elif interC < 0:
            if interC == -1:
                interC = inC * self.block_expansion
            else:
                interC = inC * -interC
            self.downConv = ConvBlock(inC, interC, 1)
        else:
            interC = inC

        self.mainConv = ConvBlock(interC, interC, 3, padding=1, groups=interC)

        if outC > 0:
            self.upConv = ConvBlock(interC, outC, 1)
        else:
            outC = interC

        if skip_connection:
            self.skip_connection = ConvBlock(inC, outC, 1)
        
    def forward(self, x):
        inX = x.clone()
        if self.downConv:
            x = self.downConv(x)
        x = self.mainConv(x)
        if self.upConv:
            x = self.upConv(x)
        if self.skip_connection:
            x = x + self.skip_connection(inX)
        return x

class ResidualStepBlock(nn.Module):
    def __init__(self, inC, outC, global_block_expansion, interSize = 4, inter_block_expansion = 2, skip_connection = True):
        super().__init__()

        self.skip_connection = None

        _innerConvs = []
        previousC = inC
        for stepC in range(inC, outC, (outC-inC)//interSize):
            _innerConvs.append(ResidualBlock(previousC, stepC, block_expansion=inter_block_expansion))
            previousC = stepC
        _innerConvs.append(ResidualBlock(previousC, outC, block_expansion=inter_block_expansion))
        
        self.innerConvs = nn.Sequential(*_innerConvs)
        if global_block_expansion > 0:
            self.spatialConv = ConvBlock(inC, inC, 3, stride=global_block_expansion, padding=1)
        else:
            self.spatialConv = DeConvBlock(inC, inC, 4, stride=-global_block_expansion, padding=1)

        if skip_connection:
            self.skip_connection = ConvBlock(inC, outC, 1)
    
    def forward(self, x):
        x = self.spatialConv(x)
        resizedX = x.clone()
        x = self.innerConvs(x)
        if self.skip_connection:
            x += self.skip_connection(resizedX)
        return x

class MyEncoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv_1 = ConvBlock(3, 64, 3, stride=4, padding=1)

        self.conv_2_residual = ResidualStepBlock(64, 96, global_block_expansion=2, interSize=2)

        self.conv_3_residual = ResidualStepBlock(96, 128, global_block_expansion=2, interSize=4)

        self.conv_4_residual = ResidualStepBlock(128, 256, global_block_expansion=2, interSize=4)

        self.conv_5_residual = ResidualStepBlock(256, 512, global_block_expansion=2, interSize=8)


    def forward(self, x):
        xs = [x]
        xs.append(self.conv_1(xs[-1]))
        xs.append(self.conv_2_residual(xs[-1]))
        xs.append(self.conv_3_residual(xs[-1]))
        xs.append(self.conv_4_residual(xs[-1]))
        xs.append(self.conv_5_residual(xs[-1]))
        return xs

class MyDecoder(nn.Module):
    def __init__(self):
        super().__init__()

        self.deconv_1_residual = ResidualStepBlock(512, 256, global_block_expansion=-2, interSize=8)

        self.deconv_2_residual = ResidualStepBlock(256, 128, global_block_expansion=-2, interSize=4)

        self.deconv_3_residual = ResidualStepBlock(128, 96, global_block_expansion=-2, interSize=4)

        self.deconv_4_residual = ResidualStepBlock(96, 64, global_block_expansion=-2, interSize=2)

        self.deconv_5 = DeConvBlock(64, 3, 4, stride=4)
    
    def forward(self, encoder_samples):
        x = self.deconv_1_residual(encoder_samples[-1]) + encoder_samples[-2]
        x = self.deconv_2_residual(x) + encoder_samples[-3]
        x = self.deconv_3_residual(x) + encoder_samples[-4]
        x = self.deconv_4_residual(x) + encoder_samples[-5]
        x = self.deconv_5(x)
        return x


class MyUnet(nn.Module):
    def __init__(self, outClasses : int):
        super().__init__()

        self.encoder = MyEncoder()
        self.decoder = MyDecoder()

        self.classificator = ConvBlock(3, outClasses, 3, padding=1)

    def forward(self, x):
        encoder_xs = self.encoder(x)
        x = self.decoder(encoder_xs)
        x = self.classificator(x)
        return x

In [9]:
# model = MyUnet(2)

ENCODER = 'resnet50'
CLASSES = CLASS_NAMES
ACTIVATION = nn.ReLU

model = smp.Unet(
    encoder_name=ENCODER, 
    classes=len(CLASSES),
    activation=ACTIVATION,
)

# print(model_sum := torchinfo.summary(model, depth=3, input_size=(BATCH_SIZE, 3, *CROP_SIZE), row_settings=["var_names"], verbose=0, col_names=[
#       "input_size", "output_size", "num_params", "params_percent", "kernel_size", "mult_adds", "trainable"]))

# dummy_input = torch.randn(1, 3, *CROP_SIZE)
# torch.onnx.export(
#             model.cpu(),
#             dummy_input,
#             "model.onnx",
#         )

# Dataloader's 

In [10]:
train_dataset = RoadsDataset(*TRAIN_SET,
                       class_rgb_values=CLASS_RGB_VALUES, transform=train_transform, readyToNetwork=prepare_to_network)
valid_dataset = RoadsDataset(*VALID_SET,
                       class_rgb_values=CLASS_RGB_VALUES, transform=valid_transform, readyToNetwork=prepare_to_network)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE//4,
    num_workers=NUM_WORKERS,
)

In [11]:
# images, _ = next(iter(valid_dataloader))
# TBwriter.add_graph(model, images)

if "ONNX" in SAVE_METHOD:
    model_path = f"{TBpath}/model_first.onnx"
    torch.onnx.export(model, torch.empty(size=(BATCH_SIZE, 3, *CROP_SIZE)), model_path)

model = model.to(DEVICE)
print(model_sum := torchinfo.summary(model, input_size=(BATCH_SIZE, 3, *CROP_SIZE), row_settings=["var_names"], verbose=0, col_names=[
      "input_size", "output_size", "num_params", "params_percent", "kernel_size", "mult_adds", "trainable"]))

Layer (type (var_name))                            Input Shape               Output Shape              Param #                   Param %                   Kernel Shape              Mult-Adds                 Trainable
Unet (Unet)                                        [96, 3, 256, 256]         [96, 2, 256, 256]         --                             --                   --                        --                        True
├─ResNetEncoder (encoder)                          [96, 3, 256, 256]         [96, 3, 256, 256]         --                             --                   --                        --                        True
│    └─Conv2d (conv1)                              [96, 3, 256, 256]         [96, 64, 128, 128]        9,408                       0.03%                   [7, 7]                    14,797,504,512            True
│    └─BatchNorm2d (bn1)                           [96, 64, 128, 128]        [96, 64, 128, 128]        128                         0.00%           

# Optimizer's

In [12]:
loss = smp.losses.DiceLoss(mode='binary')

optimizer_encoder = torch.optim.Adam(model.encoder.parameters(), lr=1E-6, weight_decay=WEIGHT_DECAY)
optimizer_decoder = torch.optim.Adam([{'params':model.decoder.parameters()}, {'params':model.segmentation_head.parameters()}], lr=1E-3, weight_decay=WEIGHT_DECAY)


scheduler = ReduceLROnPlateau(optimizer_decoder, 'min', patience=3, threshold=1e-3, cooldown=1, factor=0.5)

# Шаги обучения

In [13]:
def train_step(net, criterion, optimizers, dataloader, epoch: int = None):
    net.train()
    running_loss = 0.
    for images, labels in dataloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        for optimizer in optimizers:
            optimizer.zero_grad()

        output = net(images)
        loss = criterion(output, labels)
        loss.backward()
        for optimizer in optimizers:
            optimizer.step()
        running_loss += loss

    with torch.no_grad():
        train_loss = running_loss / len(dataloader)
    return train_loss.item()


def valid_step(net, criterion, dataloader, epoch: int = None):
    net.eval()
    running_loss = 0.
    IoU = metrics.BinaryJaccardIndex()
    IoU.to(DEVICE)

    with torch.no_grad():
        for step, (images, labels) in enumerate(dataloader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            output = net(images)

            IoU(output, labels)
            loss = criterion(output, labels)
            running_loss += loss

            save_imgs(pjoin(TBpath, f"valid_samples/samples_{epoch}"), name=f"img_{step}",
                origin=images[0].cpu().numpy().transpose(2, 1, 0),
                true=colour_code_segmentation(reverse_one_hot(
                    labels[0].cpu().numpy().transpose(2, 1, 0)), CLASS_RGB_VALUES),
                pred=colour_code_segmentation(reverse_one_hot(
                    output[0].cpu().numpy().transpose(2, 1, 0)), CLASS_RGB_VALUES))

        TBwriter.add_figure('valid_sample', create_image_plot(
                origin=images[0].cpu().numpy().transpose(2, 1, 0),
                true=colour_code_segmentation(reverse_one_hot(
                    labels[0].cpu().numpy().transpose(2, 1, 0)), CLASS_RGB_VALUES),
                pred=colour_code_segmentation(reverse_one_hot(
                    output[0].cpu().numpy().transpose(2, 1, 0)), CLASS_RGB_VALUES)),
                  epoch)

        valid_loss = running_loss / len(valid_dataloader)

        return valid_loss.item(), IoU.compute().item()


In [14]:
epoch = STARTING_EPOCH


In [15]:
if LOAD_WEIGHTS is not None:
    model.state_dict(torch.load(LOAD_WEIGHTS))
if LOAD_ADAM_STATE is not None:
    optimizer.load_state_dict(torch.load(LOAD_ADAM_STATE))
    
None

# Цикл обучения

In [16]:
best_loss = 10000
trained = True

pbar = tqdm(range(EPOCHS))
pbar.update(epoch)

while(epoch < EPOCHS):
    train_loss = train_step(model, loss, [optimizer_encoder, optimizer_decoder], train_dataloader, epoch)
    valid_loss, iou_score = valid_step(model, loss, valid_dataloader, epoch)
    scheduler.step(valid_loss)

    if WEIGHT_SAVER != "nothing" and valid_loss < best_loss and epoch > 3:
        best_loss = valid_loss

        print(f"[{epoch}] Saved weights with IoU: {iou_score:.2f} | loss: {valid_loss:.4f}")
    
        
        if WEIGHT_SAVER == "all":
            weights_path = f"{TBpath}/weights_{epoch}.pth"
            model_path = f"{TBpath}/model_{epoch}.onnx"
            optimizer_path = f"{TBpath}/optimizer_{epoch}.pth"
            
        elif WEIGHT_SAVER == "last":
            weights_path = f"{TBpath}/weights_last.pth"
            model_path =   f"{TBpath}/model_last.onnx"
            optimizer_path = f"{TBpath}/optimizer_last.pth"

        if "TORCH" in SAVE_METHOD:
            torch.save(model.state_dict(), weights_path)
        
        if "ONNX" in SAVE_METHOD:
            torch.onnx.export(model, torch.empty(size=(BATCH_SIZE, 3, *CROP_SIZE)), model_path)
        
#         torch.save(optimizer.state_dict(), optimizer_path)


    TBwriter.add_scalar('valid loss', valid_loss, epoch)
    TBwriter.add_scalar('train loss', train_loss, epoch)
    
    TBwriter.add_scalar('IoU', iou_score, epoch)

    for ind, optimizer in enumerate([optimizer_encoder, optimizer_decoder]):
        for i, param_group in enumerate(optimizer.param_groups):
            TBwriter.add_scalar(f'learning rate {ind}', float(param_group['lr']), epoch)

    epoch += 1
    pbar.update()
    pbar.set_description(
        f'IoU: {iou_score:.2f}  | train/valid loss: {train_loss:.4f}/{valid_loss:.4f}')


IoU: 0.93  | train/valid loss: 0.2099/0.2140:  33%|███▎      | 5/15 [22:33<45:08, 270.88s/it]  

[4] Saved weights with IoU: 0.93 | loss: 0.2140
[5] Saved weights with IoU: 0.93 | loss: 0.2132


IoU: 0.93  | train/valid loss: 0.2096/0.2132:  40%|████      | 6/15 [27:05<40:39, 271.06s/it]

[6] Saved weights with IoU: 0.93 | loss: 0.2129


IoU: 0.93  | train/valid loss: 0.2092/0.2129:  60%|██████    | 9/15 [40:38<27:06, 271.04s/it]

[9] Saved weights with IoU: 0.93 | loss: 0.2128


IoU: 0.93  | train/valid loss: 0.2091/0.2128:  67%|██████▋   | 10/15 [45:10<22:36, 271.35s/it]

[10] Saved weights with IoU: 0.94 | loss: 0.2126


IoU: 0.94  | train/valid loss: 0.2090/0.2126:  73%|███████▎  | 11/15 [49:42<18:06, 271.53s/it]

[11] Saved weights with IoU: 0.94 | loss: 0.2126


IoU: 0.94  | train/valid loss: 0.2089/0.2126:  80%|████████  | 12/15 [54:14<13:35, 271.68s/it]

[12] Saved weights with IoU: 0.94 | loss: 0.2125


IoU: 0.94  | train/valid loss: 0.2088/0.2125:  87%|████████▋ | 13/15 [58:46<09:03, 271.74s/it]

[13] Saved weights with IoU: 0.94 | loss: 0.2124


IoU: 0.94  | train/valid loss: 0.2087/0.2125: 100%|██████████| 15/15 [1:07:49<00:00, 271.73s/it]

# Тестирование

In [17]:
test_transform = A.Compose(
    [
        A.Normalize(mean=NORMALIZE_MEAN_IMG, std=NORMALIZE_DEVIATIONS_IMG, always_apply=True),
    ]
)

test_dataset = RoadsDataset(*TEST_SET,
       class_rgb_values=CLASS_RGB_VALUES, transform=valid_transform, readyToNetwork=prepare_to_network)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=36,
    num_workers=NUM_WORKERS,
)
if not trained:
    print(f"Используется не обученная модель, происходит загрузка модели из {LOAD_TESTSAVED_MODEL_PATH_WEIGHTS}")
    model = None
    if "ONNX" in SAVE_METHOD and model is None:
        print(f"Попытка импорта модели из onnx файла")
        try:
            import onnx
            model = onnx.load(SAVED_MODEL_PATH)
        except:
            pass
    if "TORCH" in SAVE_METHOD and model is None:
        print(f"Попытка импорта модели из pth файла")
        model = MyUnet(2)
        #     model = smp.Unet(
        #     encoder_name=ENCODER, 
        #     classes=len(CLASS_NAMES),
        #     activation=ACTIVATION,
        # )
        model.state_dict(torch.load(f=SAVED_MODEL_PATH))

    model.to(DEVICE)
    

In [18]:
TEST_METRIC = metrics.MulticlassStatScores(num_classes=len(CLASS_NAMES), average=None)

In [19]:
def saveDivide(x, y): return torch.nan_to_num(x/y)

def calculate_metric_by_errors(numerator, denominator, classes):
    with torch.no_grad():
        metric_values = saveDivide(numerator, denominator)
        metric_per_class = {classname: val.item()
                            for classname, val in zip(classes, metric_values)}
        metric_average = torch.sum(metric_values)/len(classes)
        metric_average_micro = saveDivide(torch.sum(numerator), torch.sum(denominator))
        return (metric_values, metric_per_class, metric_average, metric_average_micro)

def test_step(model, loader, metric : metrics.MulticlassStatScores):
    classes = CLASS_NAMES
    metric.to(DEVICE)

    iou = metrics.JaccardIndex(task="multiclass", num_classes=2).to(DEVICE)

    with torch.no_grad():
        model.eval()
        for id, (images, labels) in enumerate(loader):
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)
            output = model(images)
            TBwriter.add_figure('test_sample', create_image_plot(
                origin=images[0].cpu().numpy().transpose(2, 1, 0),
                true=colour_code_segmentation(reverse_one_hot(
                    labels[0].cpu().numpy().transpose(2, 1, 0)), CLASS_RGB_VALUES),
                pred=colour_code_segmentation(reverse_one_hot(
                    output[0].cpu().numpy().transpose(2, 1, 0)), CLASS_RGB_VALUES)),
                  id)
            iou.update(output, labels)
            metric.update(output, labels)

    tp, fp, tn, fn = metric._final_state()

    acc = calculate_metric_by_errors((tp+tn), (tp+fp+tn+fn), classes=classes)
    rec = calculate_metric_by_errors(tp, (tp+fn), classes=classes)
    prec = calculate_metric_by_errors(tp, (tp+fp), classes=classes)

    jaccard = calculate_metric_by_errors(tp, (tp+fp+fn), classes=classes)
    dice = calculate_metric_by_errors(2*tp, 2*tp+tn+tp, classes=classes)

    metric_values = {
        "accuracy": acc,
        "recall": rec,
        "precision": prec,
        "jaccard": jaccard,
        "dice": dice
    }

    return metric_values, iou.compute().cpu()

In [20]:
metric_values, iou = test_step(model, valid_dataloader, TEST_METRIC)

RuntimeError: shape '[2, 2]' is invalid for input of size 64

In [None]:
print(metric_values["jaccard"][0])

In [None]:
metric_count = len(metric_values)

fig, axes = plt.subplots(1, metric_count, figsize=(16,8))

result_string = ""

for (metricName, mValues), ax in zip(metric_values.items(), axes):
    result_string += f"""{metricName}:
    \tmacro: {mValues[2]:.3f}
    \tmicro: {mValues[3]:.3f}\n"""
    mVal = mValues[1]
    plt.sca(ax)
    plt.bar(mVal.keys(), mVal.values())
    plt.title(metricName)
    plt.grid(axis='y')
    plt.xticks(rotation=90)
plt.show()

result_string += f"""jaccardIndex:
macro: {iou}"""

with open(f"{TBpath}/save.txt", 'w') as f:
    f.write(result_string)

print(result_string)

In [None]:
TBwriter.close()

IoU: 0.94  | train/valid loss: 0.2087/0.2125: 100%|██████████| 15/15 [1:22:39<00:00, 330.63s/it]