# Imports

In [1]:
import os
from os.path import join as pjoin

import cv2
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

import segmentation_models_pytorch as smp

import torchmetrics.classification as metrics

import torchvision
from torchvision import transforms

from torch.utils.tensorboard import SummaryWriter

import albumentations as A
from albumentations.pytorch import ToTensorV2

from tqdm.notebook import tqdm
import torchinfo

import matplotlib.pyplot as plt

from additonFunc import uniqufy_path, create_image_plot

# Иницилизация ключевых значений

In [2]:
EPOCHS = 10
LEARNING_RATE = 0.00008

BATCH_SIZE = 16

WEIGHT_SAVER = "last" # "all" / "nothing" / "last"

CLASS_NAMES = ['other', 'road']
CLASS_RGB_VALUES = [[0,0,0], [255, 255, 255]]

NORMALIZE_MEAN_IMG = [0.4295, 0.4325, 0.3961]
NORMALIZE_DEVIATIONS_IMG = [0.2267, 0.2192, 0.2240]

CROP_SIZE = (256, 256)
PADDED_SIZE = (1536, 1536)

NUM_WORKERS = 2
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


TBpath = uniqufy_path("TB_cache/roads")
TBwriter = SummaryWriter(TBpath)

# Transform's 

In [3]:
def to_tensor(x, **kwargs):
    return x.transpose(2, 0, 1).astype('float32')

prepare_to_network = A.Lambda(image=to_tensor, mask=to_tensor)

train_transform= A.Compose(
    [
        A.RandomCrop(*CROP_SIZE, always_apply=True),
        A.OneOf(
            [
                A.HorizontalFlip(p=1),
                A.VerticalFlip(p=1),
                A.RandomRotate90(p=1),
            ],
            p=0.75,
        ),
        A.Normalize(mean=NORMALIZE_MEAN_IMG, std=NORMALIZE_DEVIATIONS_IMG, always_apply=True)
    ]
)


valid_transform = A.Compose(
    [
        A.PadIfNeeded(*PADDED_SIZE, always_apply=True, border_mode=cv2.BORDER_CONSTANT),
        A.Normalize(mean=NORMALIZE_MEAN_IMG, std=NORMALIZE_DEVIATIONS_IMG, always_apply=True),
    ]
)

In [4]:
def one_hot_encode(label, label_values):
    semantic_map = []
    for colour in label_values:
        equality = np.equal(label, colour)
        class_map = np.all(equality, axis = -1)
        semantic_map.append(class_map)
    semantic_map = np.stack(semantic_map, axis=-1)
    return semantic_map

def reverse_one_hot(image):
    x = np.argmax(image, axis = -1)
    return x

def colour_code_segmentation(image, label_values):
    colour_codes = np.array(label_values)
    x = colour_codes[image.astype(int)]
    return x

# Dataset's


In [5]:
class RoadsDataset(Dataset):
    def __init__(self, values_dir, labels_dir, class_rgb_values=None, transform=None, readyToNetwork=None):
        self.values_dir = values_dir
        self.labels_dir = labels_dir
        self.class_rgb_values = class_rgb_values
        self.images = [pjoin(self.values_dir, filename) for filename in sorted(os.listdir(self.values_dir))]
        self.labels = [pjoin(self.labels_dir, filename) for filename in sorted(os.listdir(self.labels_dir))]
        self.transform = transform
        self.readyToNetwork = readyToNetwork

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image_path = self.images[index]
        label_path = self.labels[index]

        image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
        label = cv2.cvtColor(cv2.imread(label_path), cv2.COLOR_BGR2RGB)
        label = one_hot_encode(label, self.class_rgb_values).astype('float')

        if self.transform:
            sample = self.transform(image=image, mask=label)
            image, label = sample['image'], sample['mask']
        if self.readyToNetwork:
            sample = self.readyToNetwork(image=image, mask=label)
            image, label = sample['image'], sample['mask']
        return image, label

In [6]:
sample_dataset = RoadsDataset("dataset/tiff/test", "dataset/tiff/test_labels",
                       class_rgb_values=CLASS_RGB_VALUES, transform=train_transform)

for i in range(10):
    image, mask = sample_dataset[np.random.randint(0, len(sample_dataset))]
    TBwriter.add_figure(f'train samples', create_image_plot(origin=image, true=colour_code_segmentation(
        reverse_one_hot(mask), CLASS_RGB_VALUES)), global_step=i)


# Model

In [7]:
ENCODER = 'timm-mobilenetv3_large_100'
CLASSES = CLASS_NAMES
ACTIVATION = nn.ReLU

model = smp.Unet(
    encoder_name=ENCODER, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
)


# Dataloader's 

In [8]:
train_dataset = RoadsDataset("dataset/tiff/train", "dataset/tiff/train_labels",
                       class_rgb_values=CLASS_RGB_VALUES, transform=train_transform, readyToNetwork=prepare_to_network)
valid_dataset = RoadsDataset("dataset/tiff/val", "dataset/tiff/val_labels",
                       class_rgb_values=CLASS_RGB_VALUES, transform=valid_transform, readyToNetwork=prepare_to_network)
test_dataset = RoadsDataset("dataset/tiff/test", "dataset/tiff/test_labels",
                       class_rgb_values=CLASS_RGB_VALUES, transform=valid_transform, readyToNetwork=prepare_to_network)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=1,
    num_workers=0,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    num_workers=0,
)

In [9]:
images, _ = next(iter(test_dataloader))
TBwriter.add_graph(model, images)

model = model.to(DEVICE)
print(model_sum := torchinfo.summary(model, input_size=(BATCH_SIZE, 3, *CROP_SIZE), row_settings=["var_names"], verbose=0, col_names=[
      "input_size", "output_size", "num_params", "params_percent", "kernel_size", "mult_adds", "trainable"]))

  if h % output_stride != 0 or w % output_stride != 0:
  return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
  return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
  if pad_h > 0 or pad_w > 0:


Layer (type (var_name))                                      Input Shape               Output Shape              Param #                   Param %                   Kernel Shape              Mult-Adds                 Trainable
Unet (Unet)                                                  [16, 3, 256, 256]         [16, 2, 256, 256]         --                             --                   --                        --                        True
├─MobileNetV3Encoder (encoder)                               [16, 3, 256, 256]         [16, 3, 256, 256]         --                             --                   --                        --                        True
│    └─MobileNetV3Features (model)                           --                        --                        --                             --                   --                        --                        True
│    │    └─Conv2dSame (conv_stem)                           [16, 3, 256, 256]         [16, 16, 128, 128]  

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


# Optimizer's

In [10]:
loss = smp.losses.DiceLoss(mode='binary')

# optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [11]:
def addTuples(a1: tuple, a2: tuple):
    for i in range(len(a1)):
        a1[i] += a2[i]
    return a1


def train_step(net, criterion, optimizer, dataloader, epoch: int = None):
    net.train()
    running_loss = 0.
    for images, labels in dataloader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()

        output = net(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss

    with torch.no_grad():
        train_loss = running_loss / len(train_dataloader)
    return train_loss.item()


def valid_step(net, criterion, dataloader, epoch: int = None):
    net.eval()
    running_loss = 0.
    IoU = metrics.BinaryJaccardIndex()
    IoU.to(DEVICE)

    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(DEVICE)
            labels = labels.to(DEVICE)

            output = net(images)

            IoU(output, labels)
            loss = criterion(output, labels)
            running_loss += loss

        TBwriter.add_figure('valid_sample', create_image_plot(
                origin=images[0].cpu().numpy().transpose(2, 1, 0),
                true=colour_code_segmentation(reverse_one_hot(
                    labels[0].cpu().numpy().transpose(2, 1, 0)), CLASS_RGB_VALUES),
                pred=colour_code_segmentation(reverse_one_hot(
                    output[0].cpu().numpy().transpose(2, 1, 0)), CLASS_RGB_VALUES)),
                  epoch)

        valid_loss = running_loss / len(valid_dataloader)

        return valid_loss.item(), IoU.compute().item()


In [12]:
best_loss = 10000
trained = True


for epoch in (pbar := tqdm(range(EPOCHS))):
    train_loss = train_step(model, loss, optimizer, train_dataloader, epoch)
    valid_loss, iou_score = valid_step(model, loss, valid_dataloader, epoch)



    if WEIGHT_SAVER != "nothing" and valid_loss < best_loss and epoch > 3:
        best_loss = valid_loss

        print(f"Saved weights with IoU: {iou_score:.2f} | loss: {valid_loss:.4f}")
    
        if WEIGHT_SAVER == "all":
            torch.save(model.state_dict(),
                       f"weights_{epoch}.pth")
        elif WEIGHT_SAVER == "last":
            torch.save(model.state_dict(),
                       f"weights_last.pth")

    TBwriter.add_scalar('valid loss', valid_loss, epoch)
    TBwriter.add_scalar('train loss', train_loss, epoch)
    
    TBwriter.add_scalar('IoU', iou_score, epoch)

    for i, param_group in enumerate(optimizer.param_groups):
        TBwriter.add_scalar('learning rate', float(param_group['lr']), epoch)

    pbar.set_description(
        f'IoU: {iou_score:.2f}  | train/valid loss: {train_loss:.4f}/{valid_loss:.4f}')


  0%|          | 0/10 [00:00<?, ?it/s]

Saved weights with IoU: 0.84 | loss: 0.2823
Saved weights with IoU: 0.85 | loss: 0.2710
Saved weights with IoU: 0.86 | loss: 0.2626
Saved weights with IoU: 0.87 | loss: 0.2574
Saved weights with IoU: 0.87 | loss: 0.2536
Saved weights with IoU: 0.87 | loss: 0.2513


In [13]:
TBwriter.close()