## Загружаем библиотеки

In [None]:
import os
import glob
import random
from pathlib import Path
from typing import List
from datetime import datetime

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

from sklearn.model_selection import train_test_split
import segmentation_models_pytorch as smp


# Seed handed to seed_everything() so runs are repeatable.
SEED = 1975
# Prefer the GPU, but fall back to CPU when no CUDA device is visible so that
# `.to(DEVICE)` calls (e.g. in the inference cell) do not crash on such machines.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# These three strings are only used to build checkpoint file names.
ARCH = 'DeepLabV3Plus'
ENCODER = 'tu-xception71'
VERSION = '1807_resize_640x864'

# DeepLabV3+ with ImageNet encoder weights and 4 output channels
# (class indices 0..3; the dataset remaps raw mask values 6, 7, 10 to 1, 2, 3).
model = smp.DeepLabV3Plus(encoder_name=ENCODER, encoder_weights='imagenet', classes=4)

## Формируем список файлов

In [None]:
# All dataset locations are resolved relative to the current working directory.
ROOT = Path(".")

train_image_path = ROOT / "train" / "images"
train_mask_path = ROOT / "train" / "mask"
test_image_path = ROOT / "test"


def _sorted_pngs(folder):
    """Return the ``*.png`` entries directly inside *folder*, sorted by path."""
    found = list(folder.glob("*.png"))
    found.sort()
    return found


ALL_IMAGES = _sorted_pngs(train_image_path)
ALL_MASKS = _sorted_pngs(train_mask_path)

# Images and masks are paired by position later on, so the counts must agree.
assert len(ALL_IMAGES) == len(ALL_MASKS)

print(len(ALL_IMAGES))

## Фиксируем ГСЧ

In [None]:
def seed_everything(seed=1234):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN flags.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random``,
            ``torch`` and ``torch.cuda``; its string form is also stored
            in the ``PYTHONHASHSEED`` environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for every generator the notebook draws from.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        apply_seed(seed)

    # Ask cuDNN for deterministic kernels and disable its auto-tuner.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed every RNG once, up front, with the notebook-wide SEED.
seed_everything(SEED)

## Подготавливаем трансформации для обучения и валидации

In [None]:
# Network input resolution (height, width) and the per-channel normalisation
# statistics applied as the last step before tensor conversion.
_INPUT_HW = (640, 864)
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)


def _resize_step():
    """Build a fresh resize transform to the network input resolution."""
    return A.Resize(*_INPUT_HW)


def _output_steps():
    """Build the closing normalise + to-tensor steps used by both pipelines."""
    return [
        A.Normalize(mean=list(_NORM_MEAN), std=list(_NORM_STD)),
        ToTensorV2(),
    ]


# Training pipeline: resize first, then the augmentations, in this exact order.
_train_steps = [
    _resize_step(),
    # Random rectangular holes, zero-filled in both the image and the mask.
    A.CoarseDropout(
        min_holes=2, max_holes=6,
        min_height=0.01, max_height=0.1,
        min_width=0.01, max_width=0.1,
        fill_value=0, mask_fill_value=0,
        p=0.5,
    ),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.05),
    A.GaussNoise(p=0.2),
]

# At most one blur variant.
_train_steps.append(A.OneOf(
    [
        A.MotionBlur(p=0.2),
        A.MedianBlur(blur_limit=3, p=0.1),
        A.Blur(blur_limit=3, p=0.1),
    ],
    p=0.2,
))

_train_steps.append(
    A.ShiftScaleRotate(shift_limit=0.15, scale_limit=0.2, rotate_limit=12, p=0.50)
)

# At most one synthetic weather / lighting effect.
_train_steps.append(A.OneOf(
    [
        A.RandomFog(fog_coef_lower=0.1, fog_coef_upper=0.6, p=0.2),
        A.RandomRain(p=0.2),
        A.RandomShadow(p=0.2),
    ],
    p=0.3,
))

# At most one geometric distortion.
_train_steps.append(A.OneOf(
    [
        A.OpticalDistortion(p=0.3),
        A.GridDistortion(p=0.1),
        A.PiecewiseAffine(p=0.3),
    ],
    p=0.2,
))

# At most one colour-removal effect.
_train_steps.append(A.OneOf(
    [
        A.ToGray(p=0.1),
        A.ToSepia(p=0.1),
    ],
    p=0.1,
))

# At most one contrast / sharpness tweak.
_train_steps.append(A.OneOf(
    [
        A.CLAHE(clip_limit=2),
        A.Sharpen(),
        A.Emboss(),
        A.RandomBrightnessContrast(),
    ],
    p=0.3,
))

_train_steps.append(A.HueSaturationValue(p=0.3))

transform_train = A.Compose(_train_steps + _output_steps())

# Validation pipeline: deterministic resize + normalise only.
transform_val = A.Compose([_resize_step()] + _output_steps())

## Описываем класс для работы с датасетом

In [None]:
class SegmentationDataset(Dataset):
    """Dataset yielding an RGB image and, optionally, its class-index mask.

    Args:
        images: Sequence of image file paths.
        masks: Optional sequence of mask file paths; ``masks[i]`` belongs to
            ``images[i]``. When ``None`` no ``"mask"`` entry is produced.
        transforms: Optional callable invoked as ``transforms(**sample)`` and
            expected to return the transformed sample dict.

    Raises:
        ValueError: If ``masks`` is given but its length differs from ``images``.
    """

    # Raw mask pixel value -> class index used for training.
    # Pixel values not listed here are left untouched.
    LABEL_MAP = {6: 1, 7: 2, 10: 3}

    def __init__(
        self,
        images: List[Path],
        masks: List[Path] = None,
        transforms=None
    ) -> None:
        if masks is not None and len(masks) != len(images):
            raise ValueError(
                f"images and masks differ in length: {len(images)} vs {len(masks)}"
            )
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self) -> int:
        """Return the number of images."""
        return len(self.images)

    def __getitem__(self, idx: int) -> dict:
        """Load sample ``idx``.

        Returns:
            A dict with ``"image"``, ``"hw"`` (``[height, width]`` of the image
            as read from disk, before any transform), ``"mask"`` (only when
            masks were supplied) and ``"filename"`` (base name of the image).

        Raises:
            FileNotFoundError: If the image or mask cannot be read.
        """
        image_path = self.images[idx]
        image = cv2.imread(str(image_path))
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"missing or unreadable image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w = image.shape[:2]

        result = {"image": image, "hw": [h, w]}

        if self.masks is not None:
            mask_path = self.masks[idx]
            mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise FileNotFoundError(f"missing or unreadable mask: {mask_path}")
            # Targets (1, 2, 3) are never sources, so the remap order is irrelevant.
            for raw_value, class_index in self.LABEL_MAP.items():
                mask[mask == raw_value] = class_index
            result["mask"] = mask

        if self.transforms is not None:
            result = self.transforms(**result)

        # Added after the transforms so they never see the file name.
        result["filename"] = image_path.name

        return result

## Определяем даталоадеры, функции потерь, метрику.

In [None]:
# NOTE(review): the training and validation datasets below wrap the SAME files
# (no hold-out split), so validation numbers are measured on data the model has
# trained on — confirm this is intentional.
all_images, all_masks = (np.asarray(paths) for paths in (ALL_IMAGES, ALL_MASKS))

dataset_train = SegmentationDataset(all_images, masks=all_masks, transforms=transform_train)
dataset_val = SegmentationDataset(all_images, masks=all_masks, transforms=transform_val)

# Loader settings for the first two training stages.
_train_loader_cfg = dict(batch_size=6, shuffle=True, num_workers=6, drop_last=True)
_val_loader_cfg = dict(batch_size=8, shuffle=False, num_workers=4, drop_last=False)

loader_train = DataLoader(dataset_train, **_train_loader_cfg)
loader_val = DataLoader(dataset_val, **_val_loader_cfg)

model.to(DEVICE)
model.train()

# Dice loss and the IoU-based metric are both restricted to class indices 1..3.
_scored_classes = (1, 2, 3)

dice = smp.losses.DiceLoss(
    mode='multiclass',
    classes=list(_scored_classes),
    log_loss=False,
    from_logits=True,
    smooth=1.0,
    eps=1e-07,
)
criterion = torch.nn.BCEWithLogitsLoss()

# Jaccard *loss*; the training cells report ``1 - iou(...)`` as the metric.
iou = smp.losses.JaccardLoss(
    mode='multiclass',
    classes=list(_scored_classes),
    log_loss=False,
    smooth=1.0,
)

## Предварительное обучение для инициализации модели

In [None]:
# Warm-up stage: one pass over the training data updating only the
# segmentation head, with everything else frozen.
for param in model.parameters():
    param.requires_grad = False

for param in model.segmentation_head.parameters():
    param.requires_grad = True

optimizer = torch.optim.Adam(model.segmentation_head.parameters(), lr=1e-5, weight_decay=1e-4)
print('start at', datetime.now().strftime("%H:%M:%S"))

for batch in loader_train:
    model.zero_grad()
    pred = model(batch['image'].to(DEVICE))
    mask = batch['mask'].to(DEVICE)

    dc = dice(pred, mask.long())

    # Foreground-vs-background BCE. The previous version fed
    # (argmax > 0).float() into BCEWithLogitsLoss: argmax is non-differentiable,
    # so the term produced no gradient, and a hard 0/1 value is not a logit.
    # The exact logit of P(class != 0) under softmax is
    # logsumexp(z[1:]) - z[0], which is differentiable.
    fg_logit = torch.logsumexp(pred[:, 1:], dim=1) - pred[:, 0]
    y_true = (mask > 0).float()
    bce = criterion(fg_logit, y_true)

    loss = dc + bce * 0.3

    loss.backward()
    optimizer.step()


torch.cuda.empty_cache()

print('done at', datetime.now().strftime("%H:%M:%S"))

## Обучаем только декодер + голову сегментации

In [None]:
# head + decoder train


def _segmentation_losses(pred, mask):
    """Return ``(dice_loss, bce_loss)`` for a batch of logits and label masks.

    The BCE term scores foreground (label > 0) against background. It uses the
    exact foreground logit under softmax, ``logsumexp(z[1:]) - z[0]``, so it is
    differentiable. The previous ``argmax``-based input gave no gradient and
    was not a logit.
    """
    dc = dice(pred, mask.long())
    fg_logit = torch.logsumexp(pred[:, 1:], dim=1) - pred[:, 0]
    bce = criterion(fg_logit, (mask > 0).float())
    return dc, bce


# Freeze everything, then unfreeze only the decoder and the segmentation head.
for param in model.parameters():
    param.requires_grad = False

trainable_params = list(model.decoder.parameters()) + list(model.segmentation_head.parameters())
for param in trainable_params:
    param.requires_grad = True

optimizer = torch.optim.Adam(trainable_params, lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                                 T_0=5,
                                                                 T_mult=1,
                                                                 eta_min=1e-6,
                                                                 last_epoch=-1)
print('start at', datetime.now().strftime("%H:%M:%S"))
best_metric = 0.0
for epoch in range(20):
    losses     = []
    losses_bce = []
    losses_dc  = []

    model.train()
    torch.cuda.empty_cache()
    for batch in loader_train:
        model.zero_grad()
        pred = model(batch['image'].to(DEVICE))
        mask = batch['mask'].to(DEVICE)

        dc, bce = _segmentation_losses(pred, mask)
        loss = dc + bce * 0.3

        losses.append(loss.item())
        losses_bce.append(bce.item())
        losses_dc.append(dc.item())
        loss.backward()
        optimizer.step()

    # Advance the cosine schedule by one epoch. step() takes an optional
    # *epoch* index, not a metric; passing the mean loss pinned the LR near
    # its initial value.
    scheduler.step()
    print(datetime.now().strftime("%H:%M:%S"),
          f'epoch {epoch:02d} loss {np.mean(losses):.3f} bce {np.mean(losses_bce):.3f} '
          f'dice {np.mean(losses_dc):.3f}')

    val_losses     = []
    val_losses_bce = []
    val_losses_dc  = []
    metrics        = []
    model.eval()
    torch.cuda.empty_cache()
    with torch.no_grad():
        for batch in loader_val:
            pred = model(batch['image'].to(DEVICE))
            mask = batch['mask'].to(DEVICE)

            dc, bce = _segmentation_losses(pred, mask)
            loss = dc + bce * 0.3
            # `iou` is a Jaccard loss, so 1 - loss is the score (higher is better).
            metric = 1 - iou(pred, mask.long()).item()

            val_losses.append(loss.item())
            val_losses_bce.append(bce.item())
            val_losses_dc.append(dc.item())

            metrics.append(metric)

    epoch_metric = np.mean(metrics)
    if best_metric <= epoch_metric:
        best_metric = epoch_metric
        state = model.state_dict()
        # The "_1" file is the one the next stage reloads; the second copy
        # records the score in its name.
        torch.save(state, f"{ARCH}_{ENCODER}_{VERSION}_1.pth")
        torch.save(state, f"{ARCH}_{ENCODER}_{VERSION}_{best_metric:.4f}.pth")

    print(datetime.now().strftime("%H:%M:%S"), f'valid   loss {np.mean(val_losses):.3f} bce {np.mean(val_losses_bce):.3f} '
          f'dice {np.mean(val_losses_dc):.3f} metric {epoch_metric:.3f}')


torch.cuda.empty_cache()
print('done at', datetime.now().strftime("%H:%M:%S"))

## Основной цикл обучения (4 дня)

In [None]:
# Smaller batches for this stage, where every parameter is trainable;
# gradient accumulation below makes up for the reduced batch size.
loader_train = DataLoader(
  dataset_train,
  batch_size=2,
  shuffle=True,
  num_workers=2,
  drop_last=True,
)

loader_val = DataLoader(
  dataset_val,
  batch_size=4,
  shuffle=False,
  num_workers=4,
  drop_last=False,
)


def _segmentation_losses(pred, mask):
    """Return ``(dice_loss, bce_loss)`` for a batch of logits and label masks.

    The BCE term scores foreground (label > 0) against background. It uses the
    exact foreground logit under softmax, ``logsumexp(z[1:]) - z[0]``, so it is
    differentiable. The previous ``argmax``-based input gave no gradient and
    was not a logit.
    """
    dc = dice(pred, mask.long())
    fg_logit = torch.logsumexp(pred[:, 1:], dim=1) - pred[:, 0]
    bce = criterion(fg_logit, (mask > 0).float())
    return dc, bce


# Resume from the best checkpoint written by the decoder + head stage.
model.load_state_dict(torch.load(f"{ARCH}_{ENCODER}_{VERSION}_1.pth", map_location='cpu'))
seed_everything(SEED)

model.to(DEVICE)

for param in model.parameters():
    param.requires_grad = True

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                                 T_0=5,
                                                                 T_mult=1,
                                                                 eta_min=1e-6,
                                                                 last_epoch=-1)

# Number of mini-batches whose gradients are summed before one optimizer step.
accumulation_steps = 8

print('start at', datetime.now().strftime("%H:%M:%S"))
best_metric = 0.0
for epoch in range(80):
    losses     = []
    losses_bce = []
    losses_dc  = []

    model.train()
    torch.cuda.empty_cache()
    # Start every epoch from clean gradients: anything left on the parameters
    # by an earlier cell must not leak into the first accumulated step.
    model.zero_grad()
    n_batches = len(loader_train)
    for i, batch in enumerate(loader_train, start=1):
        pred = model(batch['image'].to(DEVICE))
        mask = batch['mask'].to(DEVICE)

        dc, bce = _segmentation_losses(pred, mask)
        loss = dc + bce * 0.3

        # Log the unscaled loss so it is comparable with the validation loss.
        losses.append(loss.item())
        losses_bce.append(bce.item())
        losses_dc.append(dc.item())

        # Scale only what is back-propagated, so the accumulated gradient is
        # the mean over the window rather than the sum.
        (loss / accumulation_steps).backward()

        # Also step on the last batch, so a trailing partial window is applied
        # within this epoch instead of spilling into the next one.
        if i % accumulation_steps == 0 or i == n_batches:
            optimizer.step()
            model.zero_grad()

    # Advance the cosine schedule by one epoch. step() takes an optional
    # *epoch* index, not a metric; passing the mean loss pinned the LR near
    # its initial value.
    scheduler.step()
    print(datetime.now().strftime("%H:%M:%S"),
          f'epoch {epoch:02d} loss {np.mean(losses):.3f} bce {np.mean(losses_bce):.3f} '
          f'dice {np.mean(losses_dc):.3f}')

    val_losses     = []
    val_losses_bce = []
    val_losses_dc  = []
    metrics        = []
    model.eval()
    torch.cuda.empty_cache()
    with torch.no_grad():
        for batch in loader_val:
            pred = model(batch['image'].to(DEVICE))
            mask = batch['mask'].to(DEVICE)

            dc, bce = _segmentation_losses(pred, mask)
            loss = dc + bce * 0.3
            # `iou` is a Jaccard loss, so 1 - loss is the score (higher is better).
            metric = 1 - iou(pred, mask.long()).item()

            val_losses.append(loss.item())
            val_losses_bce.append(bce.item())
            val_losses_dc.append(dc.item())

            metrics.append(metric)

    epoch_metric = np.mean(metrics)
    if best_metric <= epoch_metric:
        best_metric = epoch_metric
        state = model.state_dict()
        # The "_2" file is the one the inference cell reloads; the second copy
        # records the score in its name.
        torch.save(state, f"{ARCH}_{ENCODER}_{VERSION}_2.pth")
        torch.save(state, f"{ARCH}_{ENCODER}_{VERSION}_{best_metric:.4f}.pth")

    print(datetime.now().strftime("%H:%M:%S"), f'valid   loss {np.mean(val_losses):.3f} bce {np.mean(val_losses_bce):.3f} '
          f'dice {np.mean(val_losses_dc):.3f} metric {epoch_metric:.3f}')


torch.cuda.empty_cache()
print('done at', datetime.now().strftime("%H:%M:%S"))

## Формируем маски для тестового датасета

In [None]:
# Deterministic preprocessing for inference: resize + normalise, no augmentation.
transform_test = A.Compose([
    A.Resize(640, 864),
    A.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

# Load the best checkpoint written by the full-model training stage.
model.load_state_dict(torch.load(f"{ARCH}_{ENCODER}_{VERSION}_2.pth", map_location='cpu'))
model.eval()
model.to(DEVICE)

# Lookup table from predicted class index (0..3) back to the raw label values
# written to the output masks (inverse of the dataset's 6/7/10 -> 1/2/3 remap).
CLASS_TO_LABEL = np.array([0, 6, 7, 10], dtype=np.uint8)

# cv2.imwrite returns False instead of raising when the target directory is
# missing, so make sure it exists before the loop.
output_dir = Path('mask')
output_dir.mkdir(parents=True, exist_ok=True)

with torch.no_grad():
    torch.cuda.empty_cache()
    for i, image_path in enumerate(sorted(test_image_path.glob("*.png"))):
        if i % 100 == 0:
            print(f"{i:04d} {image_path}")

        img = cv2.imread(str(image_path))
        # cv2.imread signals failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"missing or unreadable image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        img = transform_test(image=img)['image'].unsqueeze(0).to(DEVICE)
        pred = model(img)

        # Class indices, with only the batch axis dropped.
        class_map = torch.argmax(pred, dim=1).squeeze(0).cpu().numpy()
        # Map back to raw labels as uint8, a dtype OpenCV resizes and
        # PNG-encodes natively (the argmax output is int64).
        mask = CLASS_TO_LABEL[class_map]

        # Nearest-neighbour keeps label values intact when scaling back to the
        # original image size; dsize is (width, height).
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

        out_path = output_dir / image_path.name
        if not cv2.imwrite(str(out_path), mask):
            raise OSError(f"failed to write mask: {out_path}")

print('done')