# Binary Semantic Segmentation

Этот ноутбук — готовая «игровая площадка» для обучения модели бинарной семантической сегментации: менять в нём нужно только конфигурацию. Перед запуском убедитесь, что в директории `datasets` лежат нужные датасеты:

```bash
                datasets/
                |-- PipeBoxSegmentation_augmented/
                |   |-- images/
                |   |   |-- train/
                |   |   |-- val/
                |   |-- masks/
                |   |   |-- train/
                |   |   |-- val/
                |-- PipeSegmentation_augmented/
                |   |-- images/
                |   |   |-- train/
                |   |   |-- val/
                |   |-- masks/
                |   |   |-- train/
                |   |   |-- val/
```

Проверка доступности вычислений на GPU

In [None]:
import torch

# Report which compute devices PyTorch can see before any training starts.
gpu_available = torch.cuda.is_available()
print("CUDA доступен:", gpu_available)
print("Число GPU:", torch.cuda.device_count())
if gpu_available:
    # Details are only meaningful when at least one CUDA device exists.
    print("Имя GPU:", torch.cuda.get_device_name(0))
    print("Версия CUDA:", torch.version.cuda)


##### **Configuration**

**Weak Supervised Learning - 5000 images**

Здесь используется датасет, маски сегментации в котором получены простой заливкой bounding boxes из YOLO-разметки для задачи детекции, — отсюда и *Weak Supervised*. Данных в этом датасете много. Посмотреть его можно в директории:

```bash
    datasets/PipeBoxSegmentation_augmented
```

**Strong Supervised Learning ~ 300 images**

Здесь обучение идёт на датасете, маски в котором нарисованы вручную. Данных в нём очень мало. Посмотреть его можно в директории:

```bash
    datasets/PipeSegmentation_augmented
```


In [None]:
class WSLConf:    # Weak Supervised Learning Config
    """Hyperparameters for the weakly supervised stage (box-filled masks).

    Attribute names are upper-case so that ``vars(instance)`` can be printed
    and logged as run parameters as-is.
    """

    def __init__(self):
        # Insertion order is kept on purpose: it is the order in which
        # ``vars(self)`` lists the values.
        vars(self).update(
            EPOCHS=10,
            LEARNING_RATE=1e-4,
            BATCH_SIZE=4,
            VISUALIZATION_SAMPLES=20,
        )


class SSLConf:    # Strong Supervised Learning Config
    """Hyperparameters for the strongly supervised stage (hand-drawn masks).

    Attribute names are upper-case so that ``vars(instance)`` can be printed
    and logged as run parameters as-is.
    """

    def __init__(self):
        # Insertion order is kept on purpose: it is the order in which
        # ``vars(self)`` lists the values.
        vars(self).update(
            EPOCHS=1,
            LEARNING_RATE=1e-4,
            BATCH_SIZE=2,
            VISUALIZATION_SAMPLES=20,
        )

# One shared instance per training stage; the training cell picks one of them.
WSLCONF, SSLCONF = WSLConf(), SSLConf()

Общая конфигурация

In [None]:
from dotenv import load_dotenv
import torch
import os


# Load variables from a .env file into the process environment, so that
# GeneralConfig can read MLFLOW_TRACKING_URI via os.getenv.
load_dotenv()


class GeneralConfig:
    """Settings shared by both training stages.

    All attributes are upper-case. ``train`` logs ``vars(GCONF)`` as MLflow
    run parameters, so keep the values simple and free of secrets.
    """

    def __init__(self):

        # NOTE(review): nothing in the visible notebook seeds random/numpy/torch
        # with this value yet, so runs are not reproducible.
        self.RANDOM_SEED = 42

        # --- BETTER LEAVE THESE ALONE ---

        # (height, width) passed to torchvision ``Resize``. A U-Net with a
        # ResNet encoder downsamples by a factor of 32, so both sides must be
        # multiples of 32. Otherwise the decoder's skip connections do not
        # line up and the forward pass fails. The former (700, 500) was not;
        # this is the nearest valid size.
        self.IMAGE_SIZE = (704, 512)
        self.IN_CHANNELS = 3
        self.CLASSES = 1                       # single output channel: binary mask

        self.MODEL_ENCODER_NAME = "resnet34"
        self.MODEL_ENCODER_WEIGHTS = "imagenet"

        # None when the variable is not set in the environment / .env file.
        # NOTE(review): this value is logged as a run parameter; make sure the
        # URI carries no credentials.
        self.MLFLOW_TRACKING_URI = os.getenv('MLFLOW_TRACKING_URI')
        self.EXPERIMENT_NAME = "Pipeline Defects Detection"

        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        # Roughly how many validation passes run per training epoch.
        self.EVAL_FREQUENCY = 5

        # Dataset roots. "WEEK" is a misspelling of "WEAK", kept because other
        # code reads this attribute name.
        self.WEEK_DS_PATH = "datasets/PipeBoxSegmentation_augmented"
        self.STRONG_DS_PATH = "datasets/PipeSegmentation_augmented"


# Shared instance read by ``train`` and by the model-setup cell.
GCONF = GeneralConfig()

##### **Internal logic code**

Кастомный класс датасета

In [None]:
from torch.utils.data import Dataset
from glob import glob
from torchvision import transforms
from PIL import Image
import os


class SegmentationDataset(Dataset):
    """Image/mask dataset for binary segmentation.

    Images and masks are paired by position after sorting the file names in
    each directory. The two directories must therefore hold corresponding
    files in the same sorted order.

    Args:
        images_dir: directory with the input images.
        masks_dir: directory with the masks; every non-zero pixel is foreground.
        img_size: size passed to ``transforms.Resize``, i.e. (height, width).

    Raises:
        ValueError: if the directories contain different numbers of files.
    """

    def __init__(self, images_dir, masks_dir, img_size=(700, 500)):
        self.images = sorted(glob(os.path.join(images_dir, '*')))
        self.masks = sorted(glob(os.path.join(masks_dir, '*')))
        self.img_size = img_size

        # Pairing is by sorted index, so a count mismatch would silently
        # mis-pair images and masks instead of failing.
        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in '{images_dir}' but "
                f"{len(self.masks)} masks in '{masks_dir}'"
            )

        self.transform_img = transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
        ])
        # Nearest-neighbour resizing keeps mask values discrete (no blended edges).
        self.transform_mask = transforms.Compose([
            transforms.Resize(self.img_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)``.

        ``image`` is a (3, H, W) float tensor in [0, 1]. ``mask`` is a
        (1, H, W) float tensor with values in {0, 1}.
        """
        # ``with`` closes the underlying file; ``convert`` returns an
        # independent in-memory image.
        with Image.open(self.images[idx]) as src:
            img = src.convert('RGB')
        with Image.open(self.masks[idx]) as src:
            mask = src.convert('L')

        img = self.transform_img(img)
        mask = self.transform_mask(mask)

        # Binarize mask (0 = background, 1 = object).
        mask = (mask > 0).float()
        return img, mask

Проверка правильности структуры директорий датасета

In [None]:
import os


def check_dataset_dirs(dataset_pash: str) -> bool:
    """Return True if *dataset_pash* has ``{images,masks}/{train,val}`` subdirectories.

    Stops at the first missing directory.
    """
    required_dirs = (
        os.path.join(dataset_pash, kind, split)
        for kind in ('images', 'masks')
        for split in ('train', 'val')
    )
    return all(os.path.isdir(path) for path in required_dirs)


Метрики

In [None]:
def _binarize_predictions(preds: torch.Tensor, threshold: float, from_logits) -> torch.Tensor:
    """Turn raw model outputs into a float mask with values in {0, 1}."""
    if from_logits is None:
        # Heuristic: values outside [0, 1] can only be logits. A batch of
        # logits that happens to lie entirely inside [0, 1] is misread as
        # probabilities; callers that know they pass logits should set
        # from_logits=True.
        from_logits = bool(preds.min() < 0 or preds.max() > 1)
    if from_logits:
        preds = torch.sigmoid(preds)
    return (preds > threshold).float()


def _per_sample_sum(x: torch.Tensor) -> torch.Tensor:
    """Sum over every non-batch dimension, giving one value per sample."""
    return x.reshape(x.shape[0], -1).sum(dim=1)


def compute_iou(preds: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5,
                *, from_logits: "bool | None" = None) -> float:
    """Return the mean IoU (Jaccard index) over the batch.

    Args:
        preds: model outputs of shape (N, ...), logits or probabilities.
        targets: binary ground truth with the same shape as *preds*.
        threshold: probability above which a pixel counts as foreground.
        from_logits: True to always apply sigmoid, False to never apply it,
            None (default) to guess from the value range of *preds*.

    A sample where both prediction and target are empty scores 1.0, because
    the 1e-6 smoothing term turns 0/0 into 1.
    """
    preds = _binarize_predictions(preds, threshold, from_logits)

    intersection = _per_sample_sum(preds * targets)
    union = _per_sample_sum(preds + targets - preds * targets)
    iou = (intersection + 1e-6) / (union + 1e-6)
    return iou.mean().item()


def compute_dice(preds: torch.Tensor, targets: torch.Tensor, threshold: float = 0.5,
                 *, from_logits: "bool | None" = None) -> float:
    """Return the mean Dice coefficient (F1) over the batch.

    Takes the same arguments as ``compute_iou``. Empty-vs-empty samples score
    1.0, for the same smoothing reason.
    """
    preds = _binarize_predictions(preds, threshold, from_logits)

    intersection = _per_sample_sum(preds * targets)
    denominator = _per_sample_sum(preds) + _per_sample_sum(targets)
    dice = (2 * intersection + 1e-6) / (denominator + 1e-6)
    return dice.mean().item()

Ввод данных с клавиатуры

In [None]:
def _input(message, yes=['y', 'yes'], no=['n','no'], wrong_ans_equals_no=True):

    ans = yes + no

    inp = input(message)
    while inp.lower() not in ans and not wrong_ans_equals_no:
        inp = input(message)
    
    return inp.lower() in yes


Визуализация работы модели

In [None]:
import random
import torch
import numpy as np
import matplotlib.pyplot as plt


def _blend(image, region, color, alpha):
    """Return a copy of *image* with the pixels selected by *region* tinted towards *color*.

    ``image`` is an (H, W, 3) float array, ``region`` an (H, W) boolean mask,
    and ``alpha`` the weight given to *color*.
    """
    out = image.copy()
    out[region] = color * alpha + out[region] * (1 - alpha)
    return out


def visualize_predictions(
    model,
    dataset,
    device,
    save_path="predictions.png",
    num_samples=12,
    threshold=0.5,
    overlay_alpha=0.5,
    gt_alpha=0.7,
):
    """Save a grid of random samples as ``[Image] | [Ground Truth] | [Overlay]``.

    In the overlay, ground-truth pixels are tinted green and predicted pixels
    red. The red tint is applied second, so overlapping pixels mix both.

    Args:
        model: torch.nn.Module; assumed to output a single logit channel
            (sigmoid + squeeze are applied to it below).
        dataset: dataset whose items are ``(image, mask)`` tensors, with the
            image as (C, H, W) in [0, 1].
        device: device the model lives on.
        save_path: path of the PNG to write.
        num_samples: number of rows to draw; capped at ``len(dataset)``.
        threshold: probability above which a pixel is predicted foreground.
        overlay_alpha: tint strength in the overlay column.
        gt_alpha: tint strength in the ground-truth column.

    Raises:
        ValueError: if there is nothing to draw (empty dataset or
            ``num_samples`` < 1).

    The model's previous train/eval mode is restored before returning.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples < 1:
        raise ValueError(
            f"Nothing to visualize: dataset has {len(dataset)} items, num_samples={num_samples}"
        )
    indices = random.sample(range(len(dataset)), num_samples)

    gt_color = np.array([0.0, 1.0, 0.0])    # green
    pred_color = np.array([1.0, 0.0, 0.0])  # red

    # squeeze=False keeps ``axes`` two-dimensional even for a single row.
    fig, axes = plt.subplots(num_samples, 3, figsize=(12, 4 * num_samples), squeeze=False)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for row, img_idx in enumerate(indices):
                image, mask = dataset[img_idx]
                img_orig = np.clip(image.permute(1, 2, 0).cpu().numpy(), 0, 1)
                gt_region = mask.squeeze().cpu().numpy() > 0.5

                logits = model(image.unsqueeze(0).to(device))
                prob = torch.sigmoid(logits).cpu().squeeze().numpy()
                pred_region = prob > threshold

                img_gt = _blend(img_orig, gt_region, gt_color, gt_alpha)
                img_overlay = _blend(
                    _blend(img_orig, gt_region, gt_color, overlay_alpha),
                    pred_region, pred_color, overlay_alpha,
                )

                panels = (
                    (img_orig, f"Image {img_idx}"),
                    (img_gt, "Ground Truth"),
                    (img_overlay, "Overlay (GT=Green, Pred=Red)"),
                )
                for col, (picture, title) in enumerate(panels):
                    ax = axes[row, col]
                    ax.imshow(picture)
                    ax.set_title(title, fontsize=10)
                    ax.axis("off")

        fig.suptitle("Model Predictions Overview", fontsize=14, fontweight="bold")
        fig.tight_layout(rect=[0, 0, 1, 0.97])
        fig.savefig(save_path, bbox_inches="tight", dpi=150)
    finally:
        # Free the figure and restore the caller's mode even if drawing fails.
        plt.close(fig)
        model.train(was_training)

    print(f"Saved visualization to: {save_path}")


Сохранение модели

In [None]:
import os
import torch
from datetime import datetime

def save_model(model, model_name="model", save_folder="models"):
    """Save ``model.state_dict()`` under a timestamped file name.

    The file is written to ``<save_folder>/<model_name>_<YYYYmmdd_HHMMSS>.pth``.
    The folder is created if missing. Two saves with the same name within the
    same second overwrite each other.

    Args:
        model: PyTorch model (anything with ``state_dict()``).
        model_name (str): base file name, without extension.
        save_folder (str): target directory (default ``"models"``).

    Returns:
        str: path of the saved file.
    """
    os.makedirs(save_folder, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join(save_folder, f"{model_name}_{timestamp}.pth")

    # Only the weights are stored; the architecture must be rebuilt to load them.
    torch.save(model.state_dict(), save_path)

    print(f"Model saved locally: {save_path}")
    return save_path


Функция обучения

In [None]:
import mlflow
from torch.utils.data import DataLoader
from tqdm import tqdm


def _build_datasets(train_type):
    """Return ``(train_ds, val_ds)`` for the stage 'Weak' or 'Strong'."""
    def split(root, part):
        return SegmentationDataset(
            os.path.join(root, 'images', part),
            os.path.join(root, 'masks', part),
            GCONF.IMAGE_SIZE,
        )

    if train_type == 'Weak':
        # Weak stage: trains on box-filled masks and validates on the
        # hand-drawn *train* split. Presumably this keeps the strong val split
        # unseen until the strong stage. NOTE(review): confirm this is intended.
        return split(GCONF.WEEK_DS_PATH, 'train'), split(GCONF.STRONG_DS_PATH, 'train')
    return split(GCONF.STRONG_DS_PATH, 'train'), split(GCONF.STRONG_DS_PATH, 'val')


def _evaluate(model, loader, criterion):
    """Evaluate *model* on *loader* in eval mode.

    Returns:
        tuple: per-batch means of (loss, IoU, Dice), or NaNs for an empty loader.
    """
    n_batches = len(loader)
    if n_batches == 0:
        nan = float('nan')
        return nan, nan, nan

    model.eval()
    total_loss = total_iou = total_dice = 0.0
    with torch.no_grad():
        for imgs, masks in loader:
            imgs, masks = imgs.to(GCONF.DEVICE), masks.to(GCONF.DEVICE)
            preds = model(imgs)
            total_loss += criterion(preds, masks).item()
            total_iou += compute_iou(preds, masks)
            total_dice += compute_dice(preds, masks)
    return total_loss / n_batches, total_iou / n_batches, total_dice / n_batches


def train(model, optimizer, criterion, CFG):
    """Train *model* for ``CFG.EPOCHS`` epochs and log the run to MLflow.

    The stage is chosen from the type of *CFG*: ``WSLConf`` means weak
    supervision, anything else means strong supervision.

    Because epochs are long, validation runs about ``GCONF.EVAL_FREQUENCY``
    times per epoch, plus after the last batch. After training, the function:
    - saves the weights locally;
    - asks on stdin whether to upload the model to MLflow;
    - logs a prediction visualization as a run artifact.
    """
    train_type = 'Weak' if isinstance(CFG, WSLConf) else 'Strong'

    print("# ==================================================")
    print(f"# {train_type} Supervised Learning")
    print("# ==================================================")

    print("\nLoading dataset...", end=' ')
    train_ds, val_ds = _build_datasets(train_type)

    # Pinned host memory only helps CUDA transfers; on CPU it just warns.
    pin = GCONF.DEVICE == "cuda"
    train_loader = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=CFG.BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=pin)

    print("success\n")

    run_name = f"Binary Semantic Segmentation ({train_type})"
    with mlflow.start_run(run_name=run_name):

        print(f"Start run {run_name} with params:")
        hyperparams = vars(CFG) | vars(GCONF) | {
            "optimizer": optimizer.__class__.__name__,
            "criterion": criterion.__class__.__name__,
        }
        for key, value in hyperparams.items():
            print(f"\t{key:<{30}} {str(value):<{30}}")
        print()
        mlflow.log_params(hyperparams)

        epochs = CFG.EPOCHS
        steps_per_epoch = len(train_loader)
        val_steps = max(1, steps_per_epoch // GCONF.EVAL_FREQUENCY)
        # Defined up front so the checkpoint name below is valid even when no
        # validation ran (EPOCHS == 0 or an empty training set).
        avg_val_iou = avg_val_dice = float('nan')

        for epoch in range(epochs):

            model.train()
            total_loss = 0.0

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]", leave=False)
            # Count steps ourselves: tqdm refreshes ``progress_bar.n`` only every
            # ``mininterval`` seconds, so it is not a reliable step index.
            for step, (imgs, masks) in enumerate(progress_bar, start=1):
                imgs, masks = imgs.to(GCONF.DEVICE), masks.to(GCONF.DEVICE)
                optimizer.zero_grad()
                preds = model(imgs)
                loss = criterion(preds, masks)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

                # Validation happens mid-epoch because epochs are long.
                if step % val_steps == 0 or step == steps_per_epoch:
                    avg_val_loss, avg_val_iou, avg_val_dice = _evaluate(model, val_loader, criterion)
                    # _evaluate switched to eval mode; switch back so BatchNorm
                    # and Dropout keep training behaviour for the rest of the epoch.
                    model.train()
                    avg_train_loss = total_loss / step

                    progress_bar.set_postfix({
                        "train_loss": f"{avg_train_loss:.4f}",
                        "val_loss": f"{avg_val_loss:.4f}",
                        "IoU": f"{avg_val_iou:.4f}",
                        "Dice": f"{avg_val_dice:.4f}"
                    })

                    mlflow.log_metrics({
                        "train_loss": avg_train_loss,
                        "val_loss": avg_val_loss,
                        "val_iou": avg_val_iou,
                        "val_dice": avg_val_dice
                    }, step=epoch * steps_per_epoch + step)

        print("Training completed!\n")
        model.eval()

        # --- Save Model --- (name carries the metrics of the last validation pass)
        save_model(model, f"unet_bin_seg_dice{avg_val_dice:.4f}_iou{avg_val_iou:.4f}")

        # _input returns False on an empty answer, hence "(y/N)".
        if _input("Save the model to the server (y/N)? "):
            print("Saving model...", end=' ')
            model_path = "model"
            torch.save(model.state_dict(), model_path)
            mlflow.log_artifact(model_path)
            mlflow.pytorch.log_model(model, name="UNetBimarySemanticSegmentation")
            print("success")

        # --- Visualize ---
        print("Making visualization...", end=' ')
        vis_path = "predictions.png"
        visualize_predictions(model, val_ds, GCONF.DEVICE, save_path=vis_path, num_samples=CFG.VISUALIZATION_SAMPLES)
        mlflow.log_artifact(vis_path)
        print("success")

        print(f"\n# {train_type} Supervised Learning Finish")
        print("# ==================================================")

    

In [None]:
import os
import mlflow
from torch import nn
from segmentation_models_pytorch import Unet

print("# ==================================================")
print("# Creating Binary Semantic Segmentation Model")
print("# ==================================================")


USERNAME = input("Enter your name: ")
# MLflow's default run user comes from getpass.getuser(), which reads LOGNAME
# before USER. Set both so the entered name is actually picked up.
os.environ["LOGNAME"] = USERNAME
os.environ["USER"] = USERNAME

print("Connecting to MLFlow...", end=' ')

mlflow.set_tracking_uri(GCONF.MLFLOW_TRACKING_URI)
mlflow.set_experiment(GCONF.EXPERIMENT_NAME)

print('success')

# _input returns False on an empty answer, hence "(y/N)".
load_new_model = _input("Load model (y/N)? ")
if load_new_model:
    # Restoring a checkpoint: no need to download the ImageNet encoder weights.
    model = Unet(
        encoder_name=GCONF.MODEL_ENCODER_NAME,
        encoder_weights=None,
        in_channels=GCONF.IN_CHANNELS,
        classes=GCONF.CLASSES
    )
    model.to(GCONF.DEVICE)

    # Checkpoint names produced by save_model contain dots
    # (e.g. "unet_bin_seg_dice0.8123_iou0.7000_<timestamp>.pth"). So only
    # append a missing ".pth" instead of cutting the name at its first dot.
    model_name = input("Enter model name in models/ directory: ").strip()
    if not model_name.endswith('.pth'):
        model_name += '.pth'
    model_path = os.path.join('models', model_name)

    print(f"Loading '{model_path}'... ", end=' ')

    # The file holds a plain state_dict; weights_only=True refuses to
    # unpickle arbitrary objects from it.
    state_dict = torch.load(model_path, map_location=GCONF.DEVICE, weights_only=True)
    model.load_state_dict(state_dict)

    print("success")
else:
    print(f"Loading weights '{GCONF.MODEL_ENCODER_WEIGHTS}'... ", end=' ')

    model = Unet(
        encoder_name=GCONF.MODEL_ENCODER_NAME,
        encoder_weights=GCONF.MODEL_ENCODER_WEIGHTS,
        in_channels=GCONF.IN_CHANNELS,
        classes=GCONF.CLASSES
    )
    model.to(GCONF.DEVICE)

    print("success")

# --- Freeze encoder (transfer learning) ---
for param in model.encoder.parameters():
    param.requires_grad = False


weak_supervised_learning = _input("Choose type of learning (Weak [ W ] / Strong [ S ] )", ['w', 'weak'], ['s','strong'], False)
stage_cfg = WSLCONF if weak_supervised_learning else SSLCONF

# Optimize only the parameters left trainable, i.e. everything but the frozen encoder.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=stage_cfg.LEARNING_RATE
)

criterion = nn.BCEWithLogitsLoss()  # the model outputs raw logits for one class

train(model, optimizer=optimizer, criterion=criterion, CFG=stage_cfg)