# Imports

In [2]:
import os
from os.path import join as pjoin

import cv2
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import segmentation_models_pytorch as smp

import torchmetrics.classification as metrics

import torchvision
from torchvision import transforms

from torch.utils.tensorboard import SummaryWriter

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm import tqdm
import torchinfo

import matplotlib.pyplot as plt

from additonFunc import uniqufy_path, create_image_plot

# Иницилизация ключевых значений

In [3]:
EPOCHS = 10
LEARNING_RATE = 0.01

BATCH_SIZE = 32

WEIGHT_SAVER = "last" # "all" / "nothing" / "last"

CLASS_NAMES = ['other', 'road']
CLASS_RGB_VALUES = [[0,0,0], [255, 255, 255]]

NORMALIZE_MEAN = (0, 0, 0)
NORMALIZE_DEVIATIONS = (1, 1, 1)

CROP_SIZE = (128, 128)

NORMALIZE_MEAN_IMG = [0.4295, 0.4325, 0.3961]
NORMALIZE_DEVIATIONS_IMG = [0.2267, 0.2192, 0.2240]

NUM_WORKERS = 0
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# TBwriter = SummaryWriter(uniqufy_path("TB_cache\\roads"))

# Transform's 

In [4]:
def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

prepare_to_network = A.Lambda(image=to_tensor, mask=to_tensor)

train_transform= A.Compose(
    [
        A.RandomCrop(*CROP_SIZE, always_apply=True),
        A.OneOf(
            [
                A.HorizontalFlip(p=1),
                A.VerticalFlip(p=1),
                A.RandomRotate90(p=1),
            ],
            p=0.75,
        ),
        A.Normalize(mean=NORMALIZE_MEAN_IMG, std=NORMALIZE_DEVIATIONS_IMG, always_apply=True)
    ]
)

valid_transform = A.Compose(
    [
        A.PadIfNeeded(min_height=1536, min_width=1536, always_apply=True, border_mode=cv2.BORDER_CONSTANT),
        A.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_DEVIATIONS),
    ]
)

In [5]:
def one_hot_encode(label, label_values):
    semantic_map = []
    for colour in label_values:
        equality = np.equal(label, colour)
        class_map = np.all(equality, axis = -1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1)
    return semantic_map

def reverse_one_hot(image):
    x = np.argmax(image, axis = -1)
    return x

def colour_code_segmentation(image, label_values):
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]
    return x

# Dataset's


In [6]:
class RoadsDataset(Dataset):
    def __init__(self, values_dir, labels_dir, class_rgb_values=None, transform=None, readyToNetwork=None):
        self.values_dir = values_dir
        self.labels_dir = labels_dir
        
        self.images = [pjoin(self.values_dir, filename) for filename in sorted(os.listdir(self.values_dir))]
        self.labels = [pjoin(self.labels_dir, filename) for filename in sorted(os.listdir(self.labels_dir))]
        
        self.class_rgb_values = class_rgb_values
        self.transform = transform
        self.readyToNetwork = readyToNetwork

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        label_path = self.labels[index]

        image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        label = cv2.cvtColor(cv2.imread(label_path), cv2.COLOR_BGR2RGB)
        

        label = one_hot_encode(label, self.class_rgb_values).astype('float')

        if self.transform:
            sample = self.transform(image=image, mask=label)
            image, label = sample['image'], sample['mask']
        if self.readyToNetwork:
            sample = self.readyToNetwork(image=image, mask=label)
            image, label = sample['image'], sample['mask']
        return image, label

# Dataloader's 

In [7]:
train_dataset = RoadsDataset("../Segmentation/data\\tiff\\train", "../Segmentation/data\\tiff\\train_labels",
                       class_rgb_values=CLASS_RGB_VALUES, transform=train_transform, readyToNetwork=prepare_to_network)
# valid_dataset = RoadsDataset("../Segmentation/data\\tiff\\val", "../Segmentation/data\\tiff\\val_labels",
#                        class_rgb_values=CLASS_RGB_VALUES, transform=valid_transform, readyToNetwork=prepare_to_network)
# test_dataset = RoadsDataset("../Segmentation/data\\tiff\\test", "../Segmentation/data\\tiff\\test_labels",
#                        class_rgb_values=CLASS_RGB_VALUES, transform=valid_transform, readyToNetwork=prepare_to_network)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)
# valid_dataloader = DataLoader(
#     valid_dataset,
#     batch_size=1,
#     shuffle=True,
#     num_workers=0,
# )
# test_dataloader = DataLoader(
#     test_dataset,
#     batch_size=1,
#     shuffle=True,
#     num_workers=0,
# )
import random, tqdm

def visualize(**images):
    n_images = len(images)
    plt.figure(figsize=(16, 4))
    for idx, (name, image) in enumerate(images.items()):
        plt.subplot(1, n_images, idx + 1)
        plt.xticks([])
        plt.yticks([])
        # get title from the parameter names
        plt.title(name.replace('_', ' ').title(), fontsize=20)
        plt.imshow(image)
    plt.show()

for i in range(3):
    random_idx = random.randint(0, len(train_dataset)-1)
    image, mask = train_dataset[random_idx]
    
    visualize(
        original_image = image,
        ground_truth_mask = colour_code_segmentation(reverse_one_hot(mask), CLASS_RGB_VALUES),
        one_hot_encoded_mask = reverse_one_hot(mask)
    )


IndexError: index 10 is out of bounds for axis 0 with size 2

# Model

In [17]:
ENCODER = 'mobilenet_v2'
CLASSES = CLASS_NAMES
ACTIVATION = "sigmoid"

model = smp.Unet(
    encoder_name=ENCODER, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
)

# images, _ = next(iter(test_dataloader))
# TBwriter.add_graph(model, images)

# torch.onnx.export(model,
#                   images,
#                   "model.onnx")

model = model.to(DEVICE)
print(model_sum := torchinfo.summary(model, input_size=(BATCH_SIZE, 3, *CROP_SIZE), row_settings=["var_names"], verbose=0, col_names=[
      "input_size", "output_size", "num_params", "params_percent", "kernel_size", "mult_adds", "trainable"]))

Layer (type (var_name))                            Input Shape               Output Shape              Param #                   Param %                   Kernel Shape              Mult-Adds                 Trainable
Unet (Unet)                                        [32, 3, 128, 128]         [32, 2, 128, 128]         --                             --                   --                        --                        True
├─MobileNetV2Encoder (encoder)                     [32, 3, 128, 128]         [32, 3, 128, 128]         --                             --                   --                        --                        True
│    └─Sequential (features)                       --                        --                        --                             --                   --                        --                        True
│    │    └─Conv2dNormActivation (0)               [32, 3, 128, 128]         [32, 32, 64, 64]          928                         0.01%           

# Optimizer's

In [18]:
loss = smp.losses.DiceLoss(mode='binary')

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=0.0001)
# optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9)

In [19]:
def addTuples(a1 : tuple, a2 : tuple):
    for i in range(len(a1)):
        a1[i] += a2[i]
    return a1
    

def train_step(net, criterion, optimizer, dataloader, epoch : int = None):
    net.train()
    running_loss = 0.
    for images, labels in dataloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        output = net(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss

    with torch.no_grad():
        train_loss = running_loss / len(train_dataloader)
    return train_loss.item()

def valid_step(net, criterion, dataloader, epoch : int = None):
    net.eval()
    running_loss = 0.
    IoU = metrics.BinaryJaccardIndex()
    IoU.to(DEVICE)

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            output = net(images)

            IoU(output, labels)
            loss = criterion(output, labels)
            running_loss += loss

        valid_loss = running_loss / len(valid_dataloader)

        return valid_loss.item(), IoU.compute().item()

In [20]:
best_loss = 10000
trained = True

for epoch in (pbar := tqdm(range(EPOCHS))):
    train_loss = train_step(model, loss, optimizer, train_dataloader, epoch)
    valid_loss, iou_score = valid_step(model, loss, valid_dataloader, epoch)

    if WEIGHT_SAVER != "nothing" and valid_loss < best_loss and epoch > 3:
        best_loss = valid_loss

        print(f"Saved weights with IoU: {iou_score:.2f} | loss: {valid_loss:.4f}")

    pbar.set_description(
        f'IoU: {iou_score:.2f}  | train/valid loss: {train_loss:.4f}/{valid_loss:.4f}')

IoU: 0.83  | train/valid loss: 0.3545/0.3758:  50%|█████     | 5/10 [17:40<17:43, 212.77s/it]

Saved weights with IoU: 0.83 | loss: 0.3758


IoU: 0.83  | train/valid loss: 0.3545/0.3758:  50%|█████     | 5/10 [19:12<19:12, 230.58s/it]


KeyboardInterrupt: 

In [None]:
# TBwriter.close()