# Лабораторная работа №7
## Модели сегментации

In [2]:
!pip install segmentation_models_pytorch

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cublas_cu12-12.4.5.8-

In [3]:
import os
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import VOCSegmentation
from torchvision import transforms as T

import segmentation_models_pytorch as smp
from segmentation_models_pytorch.metrics.functional import (
    get_stats,
    precision,
    recall,
    iou_score
)

Зададим основные параметры

In [4]:
DATA_ROOT   = "data/VOCdevkit/"
BATCH_SIZE  = 4
NUM_EPOCHS  = 3
LR          = 1e-3
NUM_CLASSES = 21  # в VOC 21 класс (20 объектов + фон)
IMG_SIZE    = (256, 256)
DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"

DEVICE

'cuda'

Преобразования изображений

In [5]:
img_transform = T.Compose([
    T.Resize(IMG_SIZE),
    T.ToTensor(),
    T.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
])

In [6]:
def mask_transform(mask: Image.Image) -> torch.Tensor:
    mask = mask.resize(IMG_SIZE, resample=Image.NEAREST)
    mask_np = np.array(mask, dtype=np.int64)

    mask_np[mask_np == 255] = 0
    return torch.from_numpy(mask_np)

Добавим датасеты

In [7]:
train_ds = VOCSegmentation(
    root=DATA_ROOT,
    year="2012",
    image_set="train",
    download=True,
    transform=img_transform,
    target_transform=mask_transform,
)

val_ds = VOCSegmentation(
    root=DATA_ROOT,
    year="2012",
    image_set="val",
    download=False,
    transform=img_transform,
    target_transform=mask_transform,
)

100%|██████████| 2.00G/2.00G [00:59<00:00, 33.9MB/s]


Создадим даталоадеры

In [8]:
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False)

Опишем модель с конволюционными слоями

In [9]:
model_cnn = smp.Unet(
    encoder_name="resnet18",
    in_channels=3,
    classes=NUM_CLASSES,
).to(DEVICE)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/46.8M [00:00<?, ?B/s]

Опишем модель с трансформером

In [10]:
model_transformer = smp.create_model(
    arch="segformer",
    in_channels=3,
    classes=NUM_CLASSES
).to(DEVICE)

config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/87.3M [00:00<?, ?B/s]

In [11]:
criterion = smp.losses.DiceLoss(mode="multiclass")

optimizer_cnn = torch.optim.Adam(model_cnn.parameters(), lr=LR)
optimizer_transformer = torch.optim.Adam(model_transformer.parameters(), lr=LR)

Функция обучения

In [12]:
def train_epoch(model, optimizer):
    model.train()
    running_loss = 0
    for imgs, masks in train_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)

        preds = model(imgs)
        loss  = criterion(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)
    return running_loss / len(train_loader.dataset)

Функция валидации

In [13]:
@torch.no_grad()
def valid_epoch(model):
    model.eval()
    total_loss = 0.0
    total_prec = 0.0
    total_rec  = 0.0
    total_iou50 = 0.0
    n_samples = 0

    for imgs, masks in val_loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        preds = model(imgs)

        loss = criterion(preds, masks)
        batch_size = imgs.size(0)
        n_samples += batch_size


        tp, fp, fn, tn = get_stats(
            preds.argmax(dim=1), masks,
            mode="multiclass",
            threshold=None,
            ignore_index=None,
            num_classes=NUM_CLASSES
        )
        total_loss += loss.item() * batch_size
        total_prec += precision(tp, fp, fn, tn).mean().item() / batch_size
        total_rec  += recall(tp, fp, fn, tn).mean().item() / batch_size
        total_iou50 += iou_score(tp, fp, fn, tn).mean().item() / batch_size



    return {
        "val_loss":  total_loss / n_samples,
        "precision": total_prec / n_samples,
        "recall":    total_rec  / n_samples,
        "mAP50":     total_iou50 / n_samples
    }


Опишем цикл обучения

In [14]:
def train_and_evaluate(model, optimizer):
    for epoch in range(1, NUM_EPOCHS+1):
        train_loss = train_epoch(model, optimizer)
        metrics = valid_epoch(model)
        print(
            f"Epoch {epoch:02d}  "
            f"train_loss={train_loss:.4f}  "
            f"val_loss={metrics['val_loss']:.4f}  "
            f"Prec={metrics['precision']:.4f}  "
            f"Rec={metrics['recall']:.4f}  "
            f"mAP50={metrics['mAP50']:.4f}  "
        )

In [14]:
print("============ Segmentation with CNN ============")
train_and_evaluate(model_cnn, optimizer_cnn)

Epoch 01  train_loss=0.2419  val_loss=0.2375  Prec=0.7482  Rec=0.7585  mAP50=0.7446  
Epoch 02  train_loss=0.2364  val_loss=0.2333  Prec=0.7438  Rec=0.7585  mAP50=0.7404  
Epoch 03  train_loss=0.2331  val_loss=0.2311  Prec=0.7528  Rec=0.7587  mAP50=0.7598  


In [15]:
print("============ Segmentation with transformer ============")
train_and_evaluate(model_transformer, optimizer_transformer)

Epoch 01  train_loss=0.2393  val_loss=0.2346  Prec=0.7292  Rec=0.7584  mAP50=0.7255  
Epoch 02  train_loss=0.2308  val_loss=0.2296  Prec=0.7354  Rec=0.7585  mAP50=0.7323  
Epoch 03  train_loss=0.2246  val_loss=0.2274  Prec=0.7300  Rec=0.7585  mAP50=0.7472  


Улучшение бейзлайна

In [16]:
data_transforms = T.Compose([
    T.Resize((256, 256)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.2),
    T.RandomRotation(degrees=20),
    T.ToTensor(),
])

In [17]:
train_ds = VOCSegmentation(
    root=DATA_ROOT,
    year="2012",
    image_set="train",
    download=False,
    transform=data_transforms,
    target_transform=mask_transform,
)

val_ds = VOCSegmentation(
    root=DATA_ROOT,
    year="2012",
    image_set="val",
    download=False,
    transform=data_transforms,
    target_transform=mask_transform,
)

In [18]:
print("============ Improved segmentation with CNN ============")
train_and_evaluate(model_cnn, optimizer_cnn)

Epoch 01  train_loss=0.2310  val_loss=0.2297  Prec=0.8341  Rec=0.8585  mAP50=0.8111  
Epoch 02  train_loss=0.2282  val_loss=0.2296  Prec=0.8481  Rec=0.8585  mAP50=0.8254  
Epoch 03  train_loss=0.2258  val_loss=0.2263  Prec=0.8554  Rec=0.8586  mAP50=0.8325  


In [19]:
print("============ Improved segmentation with transformer ============")
train_and_evaluate(model_transformer, optimizer_transformer)

Epoch 01  train_loss=0.2209  val_loss=0.2254  Prec=0.8325  Rec=0.8585  mAP50=0.8297  
Epoch 02  train_loss=0.2159  val_loss=0.2185  Prec=0.8389  Rec=0.8585  mAP50=0.8263  
Epoch 03  train_loss=0.2138  val_loss=0.2206  Prec=0.8310  Rec=0.8585  mAP50=0.8282  


Опишем имплементацию модели U-net

In [20]:
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=21, features=[64, 128, 256, 512]):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()

        for feature in features:
            self.downs.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, feature, 3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(feature, feature, 3, padding=1),
                    nn.ReLU(inplace=True),
                )
            )
            in_channels = feature

        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(feature*2, feature, 3, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(feature, feature, 3, padding=1),
                    nn.ReLU(inplace=True),
                )
            )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features[-1], features[-1]*2, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(features[-1]*2, features[-1]*2, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)
        x = self.bottleneck(x)
        skip_connections = skip_connections[::-1]
        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)
            skip = skip_connections[idx//2]
            if x.shape != skip.shape:
                x = torch.nn.functional.interpolate(x, size=skip.shape[2:])
            x = torch.cat((skip, x), dim=1)
            x = self.ups[idx+1](x)
        return self.final_conv(x)

In [21]:
model_unet = UNet(in_channels=3, out_channels=NUM_CLASSES).to(DEVICE)
optimizer_unet = torch.optim.Adam(model_unet.parameters(), lr=LR)

Обучим модель

In [22]:
print("============ Segmentation with U-net ============")
train_and_evaluate(model_unet, optimizer_unet)

Epoch 01  train_loss=0.2488  val_loss=0.2487  Prec=0.8624  Rec=0.8587  mAP50=0.8579  
Epoch 02  train_loss=0.2451  val_loss=0.2487  Prec=0.8624  Rec=0.8587  mAP50=0.8579  
Epoch 03  train_loss=0.2476  val_loss=0.2487  Prec=0.8624  Rec=0.8587  mAP50=0.8579  


Опишем имплементацию SegFormer

In [17]:
class SimpleSegFormer(nn.Module):
    def __init__(self, in_channels=3, out_channels=21, embed_dim=64, num_heads=4):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_channels, embed_dim, kernel_size=16, stride=16)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=embed_dim*4),
            num_layers=2
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embed_dim, embed_dim, kernel_size=16, stride=16),
            nn.ReLU(inplace=True),
            nn.Conv2d(embed_dim, out_channels, kernel_size=1)
        )

    def forward(self, x):
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        x = x.flatten(2).permute(2, 0, 1)
        x = self.transformer(x)
        x = x.permute(1, 2, 0).reshape(B, C, H, W)
        x = self.decoder(x)
        return x

In [18]:
model_segformer = SimpleSegFormer(in_channels=3, out_channels=NUM_CLASSES).to(DEVICE)
optimizer_segformer = torch.optim.Adam(model_segformer.parameters(), lr=LR)

In [19]:
print("============ Segmentation with SegFormer ============")
train_and_evaluate(model_segformer, optimizer_segformer)

Epoch 01  train_loss=0.2473  val_loss=0.2454  Prec=0.8597  Rec=0.8586  mAP50=0.8559  
Epoch 02  train_loss=0.2449  val_loss=0.2467  Prec=0.8600  Rec=0.8587  mAP50=0.8562  
Epoch 03  train_loss=0.2444  val_loss=0.2459  Prec=0.8570  Rec=0.8585  mAP50=0.8533  
