In [None]:
import os
import random
from glob import glob

import albumentations as albu
import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
from segmentation_models_pytorch import utils
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
from torch.utils.tensorboard import SummaryWriter

In [None]:
# Reproducibility: seed every RNG in use (PyTorch CPU/CUDA, NumPy, Python) and
# force deterministic cuDNN behaviour.
seed = 42
for _seed_rng in (
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
    np.random.seed,
    random.seed,
):
    _seed_rng(seed)

torch.backends.cudnn.benchmark = False  # disable cuDNN benchmark autotuning
torch.backends.cudnn.deterministic = True

# TensorBoard writer; event files are written under runs/segmentation_experiment
writer = SummaryWriter(log_dir="runs/segmentation_experiment")

In [None]:
# Dataset root folder (placeholder — set it before running). The constant keeps its
# existing spelling because other cells reference it by this name.
DATSET_NAME = r"..."

# Images and colour-coded masks for each split
X_TRAIN_DIR, Y_TRAIN_DIR = f"{DATSET_NAME}/Train", f"{DATSET_NAME}/Trainannot"
X_VALID_DIR, Y_VALID_DIR = f"{DATSET_NAME}/Validation", f"{DATSET_NAME}/Validationannot"

# NOTE(review): the "test" split points at the validation folders, so any test score
# is not independent of model selection (the best model is picked on validation IoU).
X_TEST_DIR, Y_TEST_DIR = X_VALID_DIR, Y_VALID_DIR

# Text file with one "R G B label" line per class
LABEL_COLORS_FILE = f"{DATSET_NAME}/label_colors.txt"

In [None]:
# Classes (in mask channel order) and the RGB colours used to display them
CLASSES = ["background", "drop"]
colors_imshow = {
    "background": np.array([0, 0, 0]),       # black for background
    "drop": np.array([0, 0, 255]),           # blue for drops (RGB, as matplotlib expects)
}

# Model
ENCODER = 'resnet34'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'softmax2d'
# Fall back to CPU so the notebook still runs (slowly) on machines without CUDA;
# a hard-coded 'cuda' device fails as soon as the model or data is moved to it.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training
EPOCHS = 100
BATCH_SIZE = 6

INIT_LR = 0.0005

# Network input size (used by the augmentations and for TorchScript tracing)
INFER_WIDTH = 512
INFER_HEIGHT = 512

loss = smp.utils.losses.DiceLoss()


In [None]:
def _convert_multichannel2singlechannel(mc_mask: np.ndarray):
    """Collapse a channel-first multi-class mask into one RGB image for display.

    Channel ``k`` of ``mc_mask`` is treated as the mask of ``CLASSES[k]``. Pixels where
    that channel is > 0 are painted with ``colors_imshow[CLASSES[k]]``; colours from
    different channels are added together in uint8.

    For each class, the function also computes the channel's sum divided by its number
    of elements, and reports it as a percentage in the returned title.

    Args:
        mc_mask: array indexed as ``mc_mask[channel]``; each channel is 2-D (H, W) or
            squeezes to it.

    Returns:
        (rgb_mask, title): an (H, W, 3) uint8 image and a multi-line title string.
    """
    first_channel = mc_mask[0]
    rgb_mask = np.zeros((first_channel.shape[0], first_channel.shape[1], 3), dtype=np.uint8)
    ratios = {}

    for channel_idx, channel in enumerate(mc_mask):
        class_name = CLASSES[channel_idx]
        channel = channel.squeeze()

        # Share of the channel "covered" by this class
        ratios[class_name] = channel.sum() / channel.size

        # Paint the pixels of this class with its display colour
        painted = np.multiply.outer(channel > 0, colors_imshow[class_name]).astype(np.uint8)
        rgb_mask += painted

    ratio_lines = (f"{name}: {ratios[name] * 100:.1f}%" for name in CLASSES)
    title = "Процентное соотношение классов:\n" + "\n".join(ratio_lines)
    return rgb_mask, title


def visualize_multichennel_mask(img: np.ndarray, multichennel_mask: np.ndarray):
    """Show an image side by side with its colourised multi-class mask.

    Args:
        img: image passed straight to ``imshow`` (e.g. an HxWx3 RGB array).
        multichennel_mask: channel-last mask (H, W, C). It is transposed to
            channel-first and colourised with ``_convert_multichannel2singlechannel``.
            The per-class percentages from that helper become the figure title.
    """
    fig, (ax_image, ax_mask) = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: the input image
    ax_image.imshow(img)
    ax_image.set_title("Исходное изображение")
    ax_image.axis("off")

    # Right panel: the colourised mask
    mask_rgb, ratio_title = _convert_multichannel2singlechannel(multichennel_mask.transpose(2, 0, 1))
    ax_mask.imshow(mask_rgb)
    ax_mask.set_title("Маска")
    ax_mask.axis("off")

    plt.suptitle(ratio_title, fontsize=12, y=0.95)

    # One round legend marker per class, in that class's display colour
    handles = []
    for class_name, rgb in colors_imshow.items():
        handles.append(
            plt.Line2D([0], [0], marker="o", color="w", markerfacecolor=rgb / 255,
                       markersize=10, label=class_name)
        )
    fig.legend(handles=handles, loc="lower center", ncol=len(CLASSES), bbox_to_anchor=(0.5, -0.05))

    plt.tight_layout()
    plt.show()

In [None]:
class Dataset(BaseDataset):
    """Segmentation dataset that reads images and colour-coded masks from two folders.

    Every file in ``images_dir`` and ``masks_dir`` is used. Images and masks are paired
    by position after sorting their paths. The i-th image (in sorted file-name order)
    must therefore correspond to the i-th mask.

    Each class colour listed in the label-colours file becomes one channel of a
    channel-last one-hot float mask. Channels are ordered like ``CLASSES``.

    Args:
        images_dir: folder with the input images.
        masks_dir: folder with the colour-coded masks.
        augmentation: optional callable ``f(image=..., mask=...) -> {"image", "mask"}``.
        preprocessing: optional callable with the same convention, applied after
            ``augmentation``.
        label_colors_file: path to the "R G B label" file. When None, defaults to
            ``LABEL_COLORS_FILE``.

    Raises:
        ValueError: if the two folders contain different numbers of files.
    """

    def __init__(
        self,
        images_dir,
        masks_dir,
        augmentation=None,
        preprocessing=None,
        label_colors_file=None,
    ):
        # glob() returns files in arbitrary filesystem order. Sort both lists so that
        # image i and mask i are the same sample.
        self.images_paths = sorted(glob(f"{images_dir}/*"))
        self.masks_paths = sorted(glob(f"{masks_dir}/*"))
        if len(self.images_paths) != len(self.masks_paths):
            raise ValueError(
                f"Found {len(self.images_paths)} images in {images_dir!r} but "
                f"{len(self.masks_paths)} masks in {masks_dir!r}; they must pair one-to-one."
            )

        if label_colors_file is None:
            label_colors_file = LABEL_COLORS_FILE
        self.cls_colors = self._get_classes_colors(label_colors_file)

        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def _get_classes_colors(self, label_colors_dir):
        """Read the class colours file and return the colours ordered like ``CLASSES``.

        Each non-empty line has the form ``R G B label``. Colours are stored as BGR
        uint8 arrays, because masks are loaded with ``cv2.imread`` (BGR channel order).

        A "background" class missing from the file defaults to black. Any other class
        in ``CLASSES`` that is missing from the file raises ValueError.
        """
        cls_colors = {}
        with open(label_colors_dir) as file:
            for line in file:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines (e.g. trailing empty lines)
                R, G, B, label = fields
                cls_colors[label] = np.array([int(B), int(G), int(R)], dtype=np.uint8)

        cls_colors_ordered = {}
        for k in CLASSES:
            if k in cls_colors:
                cls_colors_ordered[k] = cls_colors[k]
            elif k == "background":
                cls_colors_ordered[k] = np.array([0, 0, 0], dtype=np.uint8)
            else:
                raise ValueError(f"unexpected label {k}, cls colors: {cls_colors}")

        return cls_colors_ordered

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i``, after any augmentation/preprocessing.

        Before the transforms, ``image`` is the RGB image as read by OpenCV, and ``mask``
        is a float array with one 0/1 channel per class (channel-last).
        """
        image = cv2.imread(self.images_paths[i])
        if image is None:
            raise OSError(f"Could not read image: {self.images_paths[i]}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks_paths[i])
        if mask is None:
            raise OSError(f"Could not read mask: {self.masks_paths[i]}")
        # One binary channel per class: pixels whose BGR colour equals the class colour
        # exactly. Lossy mask formats (e.g. JPEG) would break this exact match.
        masks = [cv2.inRange(mask, color, color) for color in self.cls_colors.values()]
        masks = [(m > 0).astype("float32") for m in masks]
        mask = np.stack(masks, axis=-1).astype("float")

        # Augmentations (geometric transforms are applied to image and mask together)
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample["image"], sample["mask"]

        # Preprocessing (normalisation, conversion to CHW)
        if self.preprocessing:
            sample = self.preprocessing(image=image, mask=mask)
            image, mask = sample["image"], sample["mask"]

        return image, mask

    def __len__(self):
        """Number of samples (the number of image files found)."""
        return len(self.images_paths)

In [None]:
# Sanity check: show one random raw (unaugmented) training sample with its mask.
dataset = Dataset(X_TRAIN_DIR, Y_TRAIN_DIR)
# Fail with a readable message instead of numpy's "high <= 0" when no files were found
assert len(dataset) > 0, f"No training images found in {X_TRAIN_DIR!r} - check DATSET_NAME"
image, mask = dataset[np.random.randint(len(dataset))]
visualize_multichennel_mask(image, mask)

In [None]:
def get_training_augmentation():
    """Build the random augmentation pipeline used for training samples.

    The pipeline has two stages:

    * Geometry: scale so the longest side equals ``INFER_HEIGHT``, pad with reflected
      borders to 110% of the network input size, then take a random
      ``INFER_HEIGHT x INFER_WIDTH`` crop.
    * Photometric: one blur (p=0.7), one noise/compression artefact (p=0.7),
      brightness/contrast jitter (p=0.5) and hue/saturation/value jitter (p=0.5).

    Returns:
        albumentations.Compose
    """
    train_transform = [
        albu.LongestMaxSize(max_size=INFER_HEIGHT),
        albu.PadIfNeeded(
            min_height=int(INFER_HEIGHT * 1.1),
            min_width=int(INFER_WIDTH * 1.1),
            border_mode=cv2.BORDER_REFLECT,  # mirror the image edge instead of a magic number
        ),
        albu.RandomCrop(height=INFER_HEIGHT, width=INFER_WIDTH),

        # Blur
        albu.OneOf([
            albu.MotionBlur(blur_limit=7, p=0.3),
            albu.GaussianBlur(blur_limit=(3, 7), p=0.3),
            albu.GlassBlur(sigma=0.7, max_delta=3, iterations=2, p=0.3),
        ], p=0.7),

        # Noise / compression artefacts
        albu.OneOf([
            albu.GaussNoise(p=0.3),
            albu.ISONoise(
                color_shift=(0.01, 0.05),
                intensity=(0.1, 0.5),
                p=0.3,
            ),
            albu.ImageCompression(p=0.3),
        ], p=0.7),

        # Colour distortions
        albu.RandomBrightnessContrast(
            brightness_limit=(-0.2, 0.2),
            contrast_limit=(-0.2, 0.2),
            p=0.5,
        ),
        albu.HueSaturationValue(
            hue_shift_limit=10,
            sat_shift_limit=20,
            val_shift_limit=10,
            p=0.5,
        ),
    ]
    return albu.Compose(train_transform)

def get_validation_augmentation():
    """Build the deterministic resize/pad/crop pipeline for validation samples.

    The pipeline scales so the longest side equals ``INFER_HEIGHT``, pads with a
    constant border up to ``INFER_HEIGHT x INFER_WIDTH``, then center-crops to
    exactly that size.

    NOTE(review): with a constant border, the padded mask area presumably gets 0 in
    every class channel (neither background nor drop). Confirm the mask fill value for
    the installed albumentations version, since it affects IoU/Dice near the padding.

    Returns:
        albumentations.Compose
    """
    resize = albu.LongestMaxSize(max_size=INFER_HEIGHT)
    pad = albu.PadIfNeeded(
        min_height=INFER_HEIGHT,
        min_width=INFER_WIDTH,
        border_mode=cv2.BORDER_CONSTANT,
    )
    crop = albu.CenterCrop(height=INFER_HEIGHT, width=INFER_WIDTH)
    return albu.Compose([resize, pad, crop])


def to_tensor(x, **kwargs):
    """Convert an HWC (or HW) array into a CHW float32 array for PyTorch.

    Intended as an albumentations ``Lambda`` target function. Any extra keyword
    arguments passed in are accepted and ignored.

    Args:
        x: array of shape (H, W, C), or (H, W), which is treated as a single channel.

    Returns:
        float32 array of shape (C, H, W); 2-D input gives (1, H, W).
    """
    if x.ndim == 2:
        # Single-channel image/mask: add the channel axis so the transpose is valid
        x = x[..., np.newaxis]
    return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn=None):
    """Construct the preprocessing transform applied after augmentation.

    The image is first normalised with ``preprocessing_fn`` (for example the function
    matching an encoder's pretraining statistics). Image and mask are then both
    converted from HWC to CHW float32 with ``to_tensor``.

    Args:
        preprocessing_fn (callable | None): image normalisation function (may be
            specific to each pretrained network). If None, normalisation is skipped
            and only the tensor conversion is applied.

    Returns:
        transform: albumentations.Compose
    """
    _transform = []
    if preprocessing_fn is not None:
        _transform.append(albu.Lambda(image=preprocessing_fn))
    _transform.append(albu.Lambda(image=to_tensor, mask=to_tensor))
    return albu.Compose(_transform)

In [None]:
# Show the same random training sample three times. Each access re-runs the random
# training augmentations, so the three figures differ.
train_augmentation = get_training_augmentation()
augmented_dataset = Dataset(X_TRAIN_DIR, Y_TRAIN_DIR, augmentation=train_augmentation)

indx = np.random.randint(len(augmented_dataset))

for i in range(3):
    image, mask = augmented_dataset[indx]
    visualize_multichennel_mask(image, mask)

In [None]:
# Validation transforms are deterministic (resize, pad, center-crop), so a single
# random validation sample is enough to check them.
valid_augmentation = get_validation_augmentation()
augmented_dataset = Dataset(X_VALID_DIR, Y_VALID_DIR, augmentation=valid_augmentation)

indx = np.random.randint(len(augmented_dataset))
image, mask = augmented_dataset[indx]
visualize_multichennel_mask(image, mask)

In [None]:
# Segmentation model: U-Net with a pretrained encoder (ENCODER / ENCODER_WEIGHTS),
# one output channel per class, and ACTIVATION applied to the output.
unet_config = dict(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
)
model = smp.Unet(**unet_config)

In [None]:
# Normalisation function matching the encoder's pretraining weights.
# It must be defined here: the datasets below depend on it.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

train_dataset = Dataset(
    X_TRAIN_DIR,
    Y_TRAIN_DIR,
    augmentation=get_training_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

valid_dataset = Dataset(
    X_VALID_DIR,
    Y_VALID_DIR,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

# Shuffle only the training data; validation runs one image at a time in a fixed order
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False)

In [None]:
# Metrics reported for every train/validation epoch
metrics = [
    utils.metrics.Fscore(),
    utils.metrics.IoU(),
    utils.metrics.Precision(),
    utils.metrics.Recall(),
]

optimizer = torch.optim.Adam([{"params": model.parameters(), "lr": INIT_LR}])

# Halve the learning rate when the monitored value (validation IoU, to be maximised)
# has not improved for 5 epochs. `verbose` is intentionally not passed.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5,
)


In [None]:
# Train / validation epoch runners sharing the same loss, metrics, device and verbosity
_epoch_kwargs = dict(loss=loss, metrics=metrics, device=DEVICE, verbose=True)

train_epoch = utils.train.TrainEpoch(model, optimizer=optimizer, **_epoch_kwargs)
valid_epoch = utils.train.ValidEpoch(model, **_epoch_kwargs)

In [None]:
# Training loop: one train and one validation pass per epoch.
# - Losses and metrics are stored in `loss_logs` / `metric_logs` and sent to TensorBoard.
# - The best model by validation IoU is saved (state_dict plus a TorchScript trace).
# - A checkpoint is written every 10 epochs.
os.makedirs("models", exist_ok=True)  # torch.save does not create missing directories

# The epoch runners return a dict keyed by each loss/metric object's `__name__`.
# Look values up by key rather than unpacking the dict's values by position.
LOSS_KEY = loss.__name__
METRIC_KEYS = {
    short_name: next(m.__name__ for m in metrics if isinstance(m, metric_cls))
    for short_name, metric_cls in [
        ("iou", utils.metrics.IoU),
        ("fscore", utils.metrics.Fscore),
        ("precision", utils.metrics.Precision),
        ("recall", utils.metrics.Recall),
    ]
}
TB_TAGS = {"iou": "IoU", "fscore": "Fscore", "precision": "Precision", "recall": "Recall"}

max_score = 0
loss_logs = {"train": [], "val": []}
metric_logs = {split: {name: [] for name in METRIC_KEYS} for split in ("train", "val")}


def _record_epoch(split, logs, step):
    """Append one epoch's loss and metrics for `split` ("train" or "val") to the
    history dicts, and write them to TensorBoard at `step`."""
    split_loss = logs[LOSS_KEY]
    loss_logs[split].append(split_loss)
    writer.add_scalar(f"Loss/{split}", split_loss, step)
    for name, key in METRIC_KEYS.items():
        metric_logs[split][name].append(logs[key])
        writer.add_scalar(f"{TB_TAGS[name]}/{split}", logs[key], step)


try:
    for i in range(EPOCHS):
        print(f'\nEpoch: {i}')
        train_logs = train_epoch.run(train_loader)
        _record_epoch("train", train_logs, i)

        valid_logs = valid_epoch.run(valid_loader)
        _record_epoch("val", valid_logs, i)
        val_iou = valid_logs[METRIC_KEYS["iou"]]

        writer.add_scalar("Learning Rate", optimizer.param_groups[0]['lr'], i)

        # Save the model whenever validation IoU improves
        if max_score < val_iou:
            max_score = val_iou
            torch.save(model.state_dict(), 'models/best_model_new.pth')
            print('Model saved (.pth)!')

            # TorchScript trace with a fixed example input size. This is best effort:
            # a tracing failure is reported but does not stop training.
            try:
                model.eval()
                trace_image = torch.randn(BATCH_SIZE, 3, INFER_HEIGHT, INFER_WIDTH)
                traced_model = torch.jit.trace(model, trace_image.to(DEVICE))
                torch.jit.save(traced_model, 'models/best_model_new.pt')
                print('Traced model saved (.pt)!')
            except Exception as e:
                print(f"Tracing failed: {e}. Skipping JIT save.")

        # Periodic checkpoint every 10 epochs (including epoch 0)
        if i % 10 == 0:
            torch.save(model.state_dict(), f'models/checkpoint_epoch_{i}.pth')
            print(f'Checkpoint saved for epoch {i}!')

        # Let the plateau scheduler adjust the learning rate based on validation IoU
        scheduler.step(val_iou)
        print(f"Current LR: {optimizer.param_groups[0]['lr']}")
finally:
    # Flush and close TensorBoard even if training fails or is interrupted
    writer.close()
