In [None]:
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import glob
import torch.nn as nn
import torch.nn.functional as F
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import utils
import albumentations as A
import random
%matplotlib inline

In [None]:
CLASS_LIST = [
    "background",
    "green_moss",
    "soil",
]
ENCODER = 'resnet18'
ENCODER_WEIGHTS = 'imagenet'
# NOTE(review): not referenced by the visible cells — the model is built with
# activation=None and the loss below uses from_logits=True.
ACTIVATION = 'softmax2d'
MACHINE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

EPOCHS = 60
BATCH_SIZE = 8

# Spatial size of the dummy input used for the TorchScript trace
# (the augmentation pipelines resize/pad to the same 512).
INFER_SIZE = 512

DATA_DIRECTORY = "project"
TRAIN_DATA = f"{DATA_DIRECTORY}/Train"
TRAIN_DATA_CAMVID = f"{DATA_DIRECTORY}/Trainannot"
VALIDATION_DATA = f"{DATA_DIRECTORY}/Validation"
VALIDATION_DATA_CAMVID = f"{DATA_DIRECTORY}/Validationannot"
MASKS_COLORS = f"{DATA_DIRECTORY}/label_colors.txt"

# Seed the RNGs this notebook draws from (`random` for preview picks, numpy,
# and torch for weight init / DataLoader shuffling) so re-runs are repeatable.
# NOTE(review): depending on the albumentations version, augmentations may use
# their own RNG, and cudnn.benchmark below may pick different kernels between
# runs, so results are not guaranteed bit-identical.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


class NamedLoss(smp.losses.DiceLoss):
    """DiceLoss with a fixed ``__name__``; smp's epoch runners use it as the loss key in their logs."""
    __name__ = "DiceLoss_ignore255"


# Multiclass Dice computed from raw logits; pixels labelled 255 (the mask
# padding value used by the augmentation pipelines) are excluded.
loss = NamedLoss(mode='multiclass', from_logits=True, ignore_index=255)

torch.backends.cudnn.benchmark = True

In [None]:
def merge_class_masks(channels: np.ndarray, color_map=None, class_names=None):
    """Convert a per-class mask stack into one colour image plus a coverage caption.

    Args:
        channels: array of shape (C, H, W); channel ``i`` belongs to
            ``class_names[i]``. Non-zero pixels are painted with that class's
            colour, and ``channel.sum() / channel.size`` is reported as the
            class's area fraction.
        color_map: optional mapping class name -> 3-element colour used for
            drawing. Defaults to the notebook's fixed palette.
        class_names: optional ordered class names; defaults to ``CLASS_LIST``.

    Returns:
        (composite, label_info): ``composite`` is an (H, W, 3) uint8 image;
        ``label_info`` is a caption listing each class's coverage in percent.
    """
    if color_map is None:
        color_map = {
            "background": np.array([0, 0, 0]),
            "green_moss": np.array([42, 125, 209]),
            "soil": np.array([170, 240, 209]),
        }
    names = CLASS_LIST if class_names is None else class_names

    h, w = channels.shape[1:]
    composite = np.zeros((h, w, 3), dtype=np.uint8)
    area_distribution = {}

    for idx, binary_mask in enumerate(channels):
        class_name = names[idx]
        binary_mask = np.asarray(binary_mask)

        area_distribution[class_name] = binary_mask.sum() / binary_mask.size

        # Assign instead of accumulating: adding colours of overlapping
        # channels would silently wrap around in uint8.
        composite[binary_mask > 0] = color_map[class_name]

    label_info = "Покрытие: " + "\n".join([
        f"{name}: {area_distribution[name]*100:.1f}%" for name in names
    ])
    return composite, label_info


def show_masked_image(original_img: np.ndarray, multichannel_mask: np.ndarray):
    """Plot an image side by side with its colour-coded class mask.

    Args:
        original_img: image handed directly to ``imshow`` for the left panel.
        multichannel_mask: per-class mask laid out as (H, W, C); it is moved to
            (C, H, W) and rendered by ``merge_class_masks``, whose coverage
            caption becomes the right panel's title.
    """
    fig, (ax_src, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))

    ax_src.imshow(original_img)
    ax_src.set_title("Оригинал")

    channels_first = np.transpose(multichannel_mask, (2, 0, 1))
    colored_mask, caption = merge_class_masks(channels_first)
    ax_mask.imshow(colored_mask)
    ax_mask.set_title(caption)

    plt.tight_layout()
    plt.show()


In [None]:
class Dataset(BaseDataset):
    """Segmentation dataset of RGB images and colour-coded masks.

    Mask pixels whose BGR colour equals a class colour from the label-colour
    file get that class's index in ``CLASS_LIST``; every other pixel keeps
    label 0 ("background").

    Args:
        images_dir: directory containing the input images.
        masks_dir: directory containing the colour masks. Images and masks are
            paired by position after sorting file names, so both directories
            must hold the same number of files in matching name order.
        augmentation: optional callable invoked as ``f(image=..., mask=...)``
            returning a dict with "image" and "mask" keys.
        preprocessing: optional callable with the same contract, applied
            after ``augmentation``.
        label_colors_path: label-colour file ("R G B label" per line);
            defaults to ``MASKS_COLORS``.

    Raises:
        ValueError: if the directories hold different numbers of files, or the
            colour file is malformed or lacks a non-background class.
    """

    def __init__(self, images_dir, masks_dir, augmentation=None, preprocessing=None,
                 label_colors_path=None):
        # glob() order is filesystem-dependent; sorting both listings is what
        # keeps image i aligned with mask i.
        self.images_paths = sorted(glob.glob(os.path.join(images_dir, "*")))
        self.masks_paths = sorted(glob.glob(os.path.join(masks_dir, "*")))
        if len(self.images_paths) != len(self.masks_paths):
            raise ValueError(
                f"found {len(self.images_paths)} images in {images_dir!r} but "
                f"{len(self.masks_paths)} masks in {masks_dir!r}"
            )

        if label_colors_path is None:
            label_colors_path = MASKS_COLORS
        self.class_colors = self._parse_label_color_map(label_colors_path)
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _parse_label_color_map(self, label_colors_dir):
        """Read "R G B label" lines into {class: BGR uint8 colour}, ordered like CLASS_LIST.

        Blank lines and lines starting with "#" are skipped. "background"
        falls back to black when the file does not list it.
        """
        class_colors = {}
        with open(label_colors_dir, 'r') as file:
            for line_no, raw_line in enumerate(file, start=1):
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue
                parts = line.split(maxsplit=3)
                if len(parts) != 4:
                    raise ValueError(
                        f"{label_colors_dir}:{line_no}: expected 'R G B label', got {raw_line!r}"
                    )
                R, G, B, label = parts
                # cv2.imread loads masks as BGR, so colours are stored in BGR order.
                class_colors[label] = np.array([int(B), int(G), int(R)], dtype=np.uint8)

        class_colors_sorted = {}
        for cls in CLASS_LIST:
            if cls in class_colors:
                class_colors_sorted[cls] = class_colors[cls]
            elif cls == "background":
                class_colors_sorted[cls] = np.array([0, 0, 0], dtype=np.uint8)
            else:
                raise ValueError(f"unexpected label {cls}, cls colors: {class_colors}")

        return class_colors_sorted

    @staticmethod
    def _read_bgr(path):
        """Load ``path`` with OpenCV, raising instead of returning None on failure."""
        img = cv2.imread(path)
        if img is None:  # cv2.imread reports unreadable/missing files by returning None
            raise FileNotFoundError(f"could not read image file: {path}")
        return img

    def __getitem__(self, i):
        image = cv2.cvtColor(self._read_bgr(self.images_paths[i]), cv2.COLOR_BGR2RGB)

        mask_bgr = self._read_bgr(self.masks_paths[i])
        mask = np.zeros(mask_bgr.shape[:2], dtype="uint8")
        # Pixels matching no class colour stay 0; if two classes share a
        # colour, the later one in CLASS_LIST wins.
        for idx, color in enumerate(self.class_colors.values()):
            mask[np.all(mask_bgr == color, axis=-1)] = idx

        if self.augmentation:
            augmented = self.augmentation(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        if self.preprocessing:
            processed = self.preprocessing(image=image, mask=mask)
            image, mask = processed["image"], processed["mask"]

        return image, mask

    def __len__(self):
        return len(self.images_paths)


In [None]:
# Preview a random training image with its ground-truth mask (no augmentation).
dataset = Dataset(TRAIN_DATA, TRAIN_DATA_CAMVID)
sample_index = random.randrange(len(dataset))
image, mask = dataset[sample_index]
# Without padding every label is a valid class index, so indexing identity
# rows turns the (H, W) label map into an (H, W, C) one-hot mask.
one_hot = np.eye(len(dataset.class_colors), dtype=float)[mask]
show_masked_image(image, one_hot)

In [None]:
def _resize_and_pad(size=512):
    """Return the resize + pad steps shared by the training and validation pipelines.

    The longest side is scaled to ``size`` and the result padded to a
    ``size`` x ``size`` square: black for the image and 255 for the mask,
    matching the loss's ``ignore_index=255``.
    NOTE(review): newer albumentations releases rename ``value``/``mask_value``
    to ``fill``/``fill_mask`` — check against the installed version.
    """
    return [
        A.LongestMaxSize(max_size=size),
        A.PadIfNeeded(size, size,
                      border_mode=cv2.BORDER_CONSTANT,
                      value=(0, 0, 0),
                      mask_value=255),
    ]


def build_training_pipeline(size=512):
    """Training augmentation: random horizontal flip and 90° rotation, then resize/pad to ``size``.

    The default 512 equals ``INFER_SIZE``, the size used for the TorchScript trace.
    """
    return A.Compose([
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        *_resize_and_pad(size),
    ])


def build_validation_pipeline(size=512):
    """Validation transform: deterministic resize/pad to ``size`` only (no random augmentation)."""
    return A.Compose(_resize_and_pad(size))

def image_to_tensor(x, **kwargs):
    """Reorder an (H, W, C) image to (C, H, W) and cast it to float32.

    ``**kwargs`` absorbs extra keyword arguments, presumably those that
    albumentations' ``Lambda`` passes along; they are ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')


def build_preprocessing_chain(preprocessing_fn):
    """Compose encoder normalisation followed by HWC -> CHW float32 conversion.

    Both ``Lambda`` steps are given only an ``image`` function, so the mask is
    presumably passed through unchanged.
    """
    normalize_step = A.Lambda(image=preprocessing_fn)
    to_chw_step = A.Lambda(image=image_to_tensor)
    return A.Compose([normalize_step, to_chw_step])


In [None]:
# Preview the training augmentation pipeline on a random sample.
dataset_aug = Dataset(
    TRAIN_DATA,
    TRAIN_DATA_CAMVID,
    augmentation=build_training_pipeline()
)

sample_index = random.randrange(len(dataset_aug))
image, mask = dataset_aug[sample_index]
# PadIfNeeded fills padded mask pixels with 255, which is out of range for
# np.eye(...)[mask]; give those pixels an all-zero ("no class") one-hot row.
num_classes = len(dataset_aug.class_colors)
is_labeled = mask != 255
one_hot = np.eye(num_classes)[np.where(is_labeled, mask, 0)] * is_labeled[..., None]
show_masked_image(image, one_hot)

In [None]:
# Preview the validation transform (resize + pad only) on a random sample.
dataset_aug = Dataset(
    VALIDATION_DATA,
    VALIDATION_DATA_CAMVID,
    augmentation=build_validation_pipeline()
)

sample_index = random.randrange(len(dataset_aug))
image, mask = dataset_aug[sample_index]
# Padded mask pixels carry 255, which is out of range for np.eye(...)[mask];
# give them an all-zero ("no class") one-hot row instead.
num_classes = len(dataset_aug.class_colors)
is_labeled = mask != 255
one_hot = np.eye(num_classes)[np.where(is_labeled, mask, 0)] * is_labeled[..., None]
show_masked_image(image, one_hot)

In [None]:
# U-Net with a pretrained encoder and one output channel per class. No output
# activation: the network emits raw logits and the loss is built with
# from_logits=True.
model = smp.Unet(
    classes=len(CLASS_LIST),
    activation=None,
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
)

# Input normalisation matching the encoder's pretraining weights.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [None]:
# Adam with one parameter group; the learning rate is halved every LR_STEP epochs.
optimizer = torch.optim.Adam([{"params": model.parameters(), "lr": 0.0005}])

LR_STEP = 15

scheduler = StepLR(optimizer, step_size=LR_STEP, gamma=0.5)

In [None]:
# Training split: random flip/rotate + resize/pad, then encoder normalisation.
train_dataset = Dataset(
    TRAIN_DATA,
    TRAIN_DATA_CAMVID,
    augmentation=build_training_pipeline(),
    preprocessing=build_preprocessing_chain(preprocessing_fn),
)

# Validation split: deterministic resize/pad, same normalisation.
valid_dataset = Dataset(
    VALIDATION_DATA,
    VALIDATION_DATA_CAMVID,
    augmentation=build_validation_pipeline(),
    preprocessing=build_preprocessing_chain(preprocessing_fn),
)

# Shuffled mini-batches for training; single images in fixed order for validation.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

In [None]:
class WrappedMetric(nn.Module):
    """Adapt a one-hot segmentation metric to raw logits and integer label masks.

    Args:
        metric_fn: callable ``metric_fn(pred, target)`` taking (N, C, H, W)
            prediction and one-hot target tensors.
        num_classes: number of classes C used for one-hot encoding.
        name: value stored as ``__name__``; used as the metric's log key.
        ignore_index: label value excluded from the metric (the mask padding
            value); ``None`` disables ignoring.
        from_logits: if True, apply a softmax over the class dimension before
            calling ``metric_fn`` (the model is built with activation=None).
    """

    def __init__(self, metric_fn, num_classes, name, ignore_index=255, from_logits=True):
        super().__init__()
        self.metric_fn = metric_fn
        self.num_classes = num_classes
        self.ignore_index = ignore_index
        self.from_logits = from_logits
        self.__name__ = name

    def forward(self, y_pred, y_true):
        """Compute the metric for predictions (N, C, H, W) against labels (N, H, W)."""
        y_true = y_true.long()
        if self.ignore_index is None:
            valid = torch.ones_like(y_true, dtype=torch.bool)
        else:
            valid = y_true != self.ignore_index

        # one_hot rejects out-of-range labels such as 255, so map ignored
        # pixels to class 0 first; they are zeroed out below and never counted.
        labels = torch.where(valid, y_true, torch.zeros_like(y_true))
        y_true_oh = F.one_hot(labels, self.num_classes).permute(0, 3, 1, 2).float()

        if self.from_logits:
            # Thresholding raw logits at 0.5 is meaningless; use probabilities.
            y_pred = torch.softmax(y_pred, dim=1)

        keep = valid.unsqueeze(1).to(y_pred.dtype)
        return self.metric_fn(y_pred * keep, y_true_oh * keep)

# IoU and F1 at a 0.5 threshold, each logged under its own name.
metrics = [
    WrappedMetric(metric_cls(threshold=0.5), num_classes=len(CLASS_LIST), name=metric_name)
    for metric_name, metric_cls in (
        ("iou_score", smp.utils.metrics.IoU),
        ("f1_score", smp.utils.metrics.Fscore),
    )
]

In [None]:
# One optimisation pass over the training loader per call.
train_epoch = utils.train.TrainEpoch(
    model,
    optimizer=optimizer,
    loss=loss,
    metrics=metrics,
    device=MACHINE,
    verbose=True,
)

# Evaluation-only pass (no optimizer) with the same loss and metrics.
valid_epoch = utils.train.ValidEpoch(
    model,
    loss=loss,
    metrics=metrics,
    device=MACHINE,
    verbose=True,
)

In [None]:
# Sanity-check what the training loader actually feeds the model (after
# augmentation and preprocessing), rather than whatever preview cell ran last.
sample_image, sample_mask = train_dataset[0]
print("image:", sample_image.shape, sample_image.dtype)
print("mask :", sample_mask.shape, sample_mask.dtype)
assert sample_image.shape[0] == 3, "expected a channels-first RGB image"
assert tuple(sample_image.shape[1:]) == tuple(sample_mask.shape), "image/mask size mismatch"
assert set(np.unique(sample_mask).tolist()) <= set(range(len(CLASS_LIST))) | {255}, \
    "mask contains labels outside the class range / ignore value"

In [None]:
def save_best(model, score, best_so_far):
    """Checkpoint ``model`` when ``score`` beats ``best_so_far``.

    Writes the pickled module to ``models/best_model_new.pth`` and a TorchScript
    trace (dummy input ``BATCH_SIZE x 3 x INFER_SIZE x INFER_SIZE`` on
    ``MACHINE``) to ``models/best_model_new.pt``.

    Returns:
        The best score so far: ``score`` if it improved, else ``best_so_far``.
    """
    if score <= best_so_far:
        return best_so_far

    # torch.save / torch.jit.save fail if the target directory is missing.
    os.makedirs("models", exist_ok=True)

    # Save and trace in eval mode (inference behaviour, e.g. BatchNorm running
    # stats), then restore whatever mode the model was in.
    was_training = model.training
    model.eval()
    try:
        torch.save(model, "models/best_model_new.pth")

        dummy = torch.randn(BATCH_SIZE, 3, INFER_SIZE, INFER_SIZE).to(MACHINE)
        with torch.no_grad():
            scripted = torch.jit.trace(model, dummy)
        torch.jit.save(scripted, "models/best_model_new.pt")
    finally:
        model.train(was_training)

    print("Model saved!")
    return score


history = {"loss": {"train": [], "val": []},
           "iou": {"train": [], "val": []}}

best_iou = 0.0

# Epoch.run() returns a dict keyed by the loss's and each metric's __name__
# ("DiceLoss_ignore255", then "iou_score", "f1_score" in metrics-list order).
# Read values by key: unpacking .values() positionally would bind the F1
# score to the IoU variable.
loss_key = loss.__name__
iou_key = "iou_score"

for epoch_idx in range(EPOCHS):
    print(f"\n epoch {epoch_idx}/{EPOCHS - 1}")

    logs_tr = train_epoch.run(train_loader)
    tr_loss, tr_iou = logs_tr[loss_key], logs_tr[iou_key]
    history["loss"]["train"].append(tr_loss)
    history["iou"]["train"].append(tr_iou)

    logs_val = valid_epoch.run(valid_loader)
    val_loss, val_iou = logs_val[loss_key], logs_val[iou_key]
    history["loss"]["val"].append(val_loss)
    history["iou"]["val"].append(val_iou)

    # Model selection is driven by validation IoU.
    best_iou = save_best(model, val_iou, best_iou)

    scheduler.step()
    if (epoch_idx + 1) % LR_STEP == 0:
        lr_now = optimizer.param_groups[0]["lr"]
        print(f"Learning rate decreased to: {lr_now:.6f}")

In [None]:
# Training curves: Dice loss and IoU per epoch, train (solid) vs. validation (dashed).
fig, ax_arr = plt.subplots(1, 2, figsize=(12, 5))

for ax, (hist_key, title) in zip(ax_arr, [("loss", "DiceLoss"), ("iou", "IoU")]):
    ax.plot(history[hist_key]["train"], label="train")
    ax.plot(history[hist_key]["val"], label="val", linestyle="--")
    ax.set_title(title)
    ax.set_xlabel("epoch")
    ax.set_ylabel(title)
    ax.legend()

plt.tight_layout()
plt.show()

---