# Первый этап обучения

## Решаем задачу сегментации на разрешении изображений 512x512

In [None]:
%%capture

# Загрузка датасета
try:
    import google.colab
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    !cp /content/drive/MyDrive/crimea/train_crimea.zip /content/
    !unzip -oq /content/train_crimea.zip

    !cp /content/drive/MyDrive/crimea/test_crimea.zip /content/
    !unzip -oq /content/test_crimea.zip

    !mkdir /content/test_masks

    !python -m pip install --upgrade pip
    !pip install -U segmentation_models_pytorch

    COLAB      = True
    ROOT       = '/content/'
    ROOT_DRIVE = '/content/drive/MyDrive/crimea/'

except:
    COLAB      = False
    ROOT       = './'
    ROOT_DRIVE = './'


In [None]:
# Show which GPU this runtime was given (model, driver/CUDA version, memory use).
!nvidia-smi

Sun Nov 20 13:38:58 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P8     8W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import os
import glob
import random
import warnings
from pathlib import Path
from typing import List
from datetime import datetime

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from sklearn.model_selection import train_test_split


# --- run identity ---------------------------------------------------------
SEED = 1
STAGE = 1                          # index of this training stage
PREVIOUS = f'stage_{STAGE - 1}'    # checkpoint tag this stage loads its weights from
VERSION = f'stage_{STAGE}'         # checkpoint tag this stage saves to

# --- training hyper-parameters ------------------------------------------
SIZE = 512                         # side length images are resized to
DEVICE = 'cuda'
EPOCHS = 15
MAX_LR = 1e-4
BATCHSIZE = 12

# --- model ----------------------------------------------------------------
ARCH = 'DeepLabV3Plus'
ENCODER = 'tu-xception71'

# Silence the warnings PyTorch emits when the notebook runs without CUDA.
for _message in ('User provided device_type of \'cuda\', but CUDA is not available. Disabling',
                 'torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.'):
    warnings.filterwarnings("ignore", _message)

### Заменяем выходной слой сегментационной головы (3 канала → 4 класса)

In [None]:
# Rebuild the architecture used by the previous stage (3 output channels,
# marked "RGB" there) so that its checkpoint loads without shape mismatches.
model = smp.DeepLabV3Plus(encoder_name=ENCODER,
                          encoder_weights=None, classes=3) # RGB

model.load_state_dict(torch.load(f"{ROOT_DRIVE}{ARCH}_{ENCODER}_{PREVIOUS}.pth",
                                 map_location='cpu'))

# Seed before creating the new head: otherwise its random initialisation
# happens before any seeding in this notebook and is not reproducible.
torch.manual_seed(SEED + STAGE)

# Swap the final 1x1 conv for a fresh one with 4 output channels. The losses
# score classes 1..3 and treat `mask > 0` as foreground, so channel 0 acts as
# background. The input width is taken from the layer being replaced instead
# of being hard-coded.
NUM_CLASSES = 4
old_head = model.segmentation_head[0]
new_head = nn.Conv2d(old_head.in_channels, NUM_CLASSES,
                     kernel_size=(1, 1),
                     stride=(1, 1))
nn.init.xavier_uniform_(new_head.weight)
nn.init.zeros_(new_head.bias)
model.segmentation_head[0] = new_head

scaler = torch.cuda.amp.GradScaler()

In [None]:
# Training data layout: <ROOT>/images/*.png and <ROOT>/masks/*.png.
ROOT = Path(ROOT)

train_image_path, train_mask_path = (ROOT / sub for sub in ("images", "masks"))

# Both listings are sorted so images and masks line up by position
# (assumes corresponding files sort into the same order — confirm naming).
ALL_IMAGES, ALL_MASKS = (sorted(folder.glob("*.png"))
                         for folder in (train_image_path, train_mask_path))

n_pairs = len(ALL_IMAGES)
assert n_pairs == len(ALL_MASKS)

print(n_pairs)

3500


In [None]:
def seed_everything(seed=1234):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    PYTHONHASHSEED is exported too. It only affects Python processes started
    afterwards; the running interpreter's hash seed is fixed at startup.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    torch.backends.cudnn.benchmark     = False
    torch.backends.cudnn.deterministic = True

# Seed all RNGs; adding the stage index gives each training stage its own seed.
seed_everything(SEED + STAGE)

In [None]:
def _to_model_input():
    """Return fresh Normalize + ToTensorV2 steps (ImageNet statistics)."""
    return [
        A.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]


# Geometric transforms: albumentations applies these to the mask as well.
_geometric_aug = [
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=1.0),
    A.Resize(SIZE, SIZE, interpolation=cv2.INTER_AREA),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.PiecewiseAffine(p=0.5),
]

# Photometric transforms: image only.
_photometric_aug = [
    A.OneOf([A.CLAHE(clip_limit=2, p=1), A.Sharpen(p=1), A.Emboss(p=1)], p=0.3),
    A.ColorJitter(hue=0.03, p=1.0),
]

transform_train = A.Compose(_geometric_aug + _photometric_aug + _to_model_input())

# Validation: deterministic resize + normalisation only.
transform_val = A.Compose([A.Resize(SIZE, SIZE, interpolation=cv2.INTER_AREA)]
                          + _to_model_input())

In [None]:
class SegmentationDataset(Dataset):
    """Images (and optionally masks) read from disk with OpenCV.

    ``__getitem__`` returns a dict with:
      * ``image``    -- the image converted to RGB,
      * ``hw``       -- the original ``[height, width]``,
      * ``mask``     -- the mask read as single-channel grayscale
                        (only when ``masks`` was given),
      * ``filename`` -- the image file name.
    When ``transforms`` is set it is called as ``transforms(**item)`` and its
    result replaces the dict before ``filename`` is added.

    Raises:
        ValueError: in ``__init__`` if ``masks`` is given with a different
            length than ``images`` (items are paired by index).
        OSError: in ``__getitem__`` if OpenCV cannot read an image or mask.
    """

    def __init__(
        self,
        images: List[Path],
        masks: List[Path] = None,
        transforms=None
    ) -> None:
        if masks is not None and len(masks) != len(images):
            raise ValueError(
                f"Got {len(images)} images but {len(masks)} masks; "
                "they are paired by index and must have the same length")
        self.images = images
        self.masks = masks
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.images)

    @staticmethod
    def _read(path: Path, flags: int = cv2.IMREAD_COLOR):
        """Read ``path`` with OpenCV, raising instead of returning ``None``."""
        data = cv2.imread(str(path), flags)
        if data is None:
            # cv2.imread signals missing/unreadable files by returning None.
            raise OSError(f"Could not read file: {path}")
        return data

    def __getitem__(self, idx: int) -> dict:
        image_path = self.images[idx]
        image = cv2.cvtColor(self._read(image_path), cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape

        result = {"image": image, "hw": [h, w]}

        if self.masks is not None:
            result["mask"] = self._read(self.masks[idx], cv2.IMREAD_GRAYSCALE)

        if self.transforms is not None:
            result = self.transforms(**result)

        result["filename"] = image_path.name

        return result

In [None]:
# Hold out 5% of the image/mask pairs for validation; SEED fixes the split.
all_images, all_masks = np.asarray(ALL_IMAGES), np.asarray(ALL_MASKS)

split = train_test_split(all_images, all_masks, test_size=0.05, random_state=SEED)
train_img, test_img, train_mask, test_mask = split

dataset_train = SegmentationDataset(train_img, masks=train_mask, transforms=transform_train)
dataset_val = SegmentationDataset(test_img, masks=test_mask, transforms=transform_val)

In [None]:
# Put the model in training mode on the target device and build the losses.
model.to(DEVICE)
model.train()

# Class ids scored by Dice and Jaccard; class 0 is left out.
foreground_classes = [1, 2, 3]

dice = smp.losses.DiceLoss(
    mode='multiclass',
    classes=foreground_classes,
    from_logits=True,
    log_loss=False,
    smooth=1.0,
    eps=1e-08,
)
criterion = nn.BCEWithLogitsLoss()

# Soft Jaccard (IoU) loss over the same classes.
iou = smp.losses.JaccardLoss(
    mode='multiclass',
    classes=foreground_classes,
    log_loss=False,
    smooth=1.0,
)

## Pretrain the new head (all other layers frozen)

In [None]:
# One pass over the training set that updates only the freshly initialised
# segmentation head; every other parameter is frozen.
# NOTE: the model is in train mode, so BatchNorm running statistics in the
# frozen layers are still updated during this pass.
for param in model.parameters():
    param.requires_grad = False

for param in model.segmentation_head.parameters():
    param.requires_grad = True

model.train()

# Weight of the foreground/background BCE term in the total loss.
BCE_WEIGHT = 0.1


def _foreground_bce(pred, mask):
    """BCE on the foreground-vs-background split of multiclass logits.

    ``pred`` holds per-pixel class logits on dim 1 with class 0 as background
    (presumably (B, C, H, W) -- TODO confirm); ``mask`` holds class ids.
    Under a softmax the log-odds of "any class > 0" are
    ``logsumexp(pred[:, 1:]) - pred[:, 0]``. That is differentiable, unlike
    thresholding ``argmax``, which produces no gradient.
    """
    logits = pred.float()
    fg_logit = torch.logsumexp(logits[:, 1:], dim=1) - logits[:, 0]
    return criterion(fg_logit, (mask > 0).float())


# Only the head is trainable, so only its parameters go to the optimizer.
optimizer = torch.optim.AdamW(model.segmentation_head.parameters(),
                              lr=1e-5, weight_decay=1e-6)

loader_train = DataLoader(
  dataset_train,
  batch_size=8,
  shuffle=True,
  num_workers=2,
  drop_last=True,
)

print('start at', datetime.now().strftime("%H:%M:%S"))

for batch in loader_train:
    mask = batch['mask'].to(DEVICE)
    with torch.cuda.amp.autocast(enabled=True):
        pred = model(batch['image'].to(DEVICE))

        dc = dice(pred, mask.long())
        bce = _foreground_bce(pred, mask)

        loss = dc + bce * BCE_WEIGHT

    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

print('done at', datetime.now().strftime("%H:%M:%S"))

start at 13:41:28
done at 13:50:33


## Start training

In [None]:
# Main training: unfreeze the whole network and train with AMP and a one-cycle
# LR schedule stepped once per epoch. The checkpoint with the best validation
# IoU is kept. After more than 3 epochs without improvement, the best weights
# are reloaded.
loader_train = DataLoader(
  dataset_train,
  batch_size=BATCHSIZE,
  shuffle=True,
  num_workers=2,
  drop_last=True,
)

loader_val = DataLoader(
  dataset_val,
  batch_size=BATCHSIZE,
  shuffle=False,
  num_workers=2,
  drop_last=False,
)

seed_everything(SEED + STAGE)

for param in model.parameters():
    param.requires_grad = True

# Weight of the foreground/background BCE term. The same weight is used for
# training and validation so the two logged losses are comparable.
BCE_WEIGHT = 0.1
CHECKPOINT = f"{ROOT_DRIVE}{ARCH}_{ENCODER}_{VERSION}.pth"


def _foreground_bce(pred, mask):
    """BCE on the foreground-vs-background split of multiclass logits.

    ``pred`` holds per-pixel class logits on dim 1 with class 0 as background
    (presumably (B, C, H, W) -- TODO confirm); ``mask`` holds class ids.
    Under a softmax the log-odds of "any class > 0" are
    ``logsumexp(pred[:, 1:]) - pred[:, 0]``. That is differentiable, unlike
    thresholding ``argmax``, which produces no gradient.
    """
    logits = pred.float()
    fg_logit = torch.logsumexp(logits[:, 1:], dim=1) - logits[:, 0]
    return criterion(fg_logit, (mask > 0).float())


torch.cuda.empty_cache()
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=MAX_LR,
                              weight_decay=1e-6)
# steps_per_epoch=1 because scheduler.step() is called once per epoch below.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                epochs=EPOCHS,
                                                max_lr=MAX_LR,
                                                div_factor=25.0,
                                                final_div_factor=10.0,
                                                steps_per_epoch=1)

print('start at', datetime.now().strftime("%H:%M:%S"))
best_metric = 0.0
best_cnt    = 0
for epoch in range(EPOCHS):
    losses     = []
    losses_bce = []
    losses_dc  = []

    model.train()
    torch.cuda.empty_cache()

    for batch in loader_train:
        mask = batch['mask'].to(DEVICE)
        with torch.cuda.amp.autocast(enabled=True):
            pred = model(batch['image'].to(DEVICE))

            dc = dice(pred, mask.long())
            bce = _foreground_bce(pred, mask)

            loss = dc + bce * BCE_WEIGHT

        losses.append(loss.item())
        losses_bce.append(bce.item())
        losses_dc.append(dc.item())

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    scheduler.step()
    print(datetime.now().strftime("%H:%M:%S"),
          f'epoch {epoch:02d} loss {np.mean(losses):.3f} bce {np.mean(losses_bce):.3f} '
          f'dice {np.mean(losses_dc):.3f} lr={optimizer.param_groups[0]["lr"]:.8f}')

    val_losses     = []
    val_losses_bce = []
    val_losses_dc  = []
    metrics        = []
    model.eval()
    torch.cuda.empty_cache()
    with torch.no_grad():
        for batch in loader_val:
            pred = model(batch['image'].to(DEVICE))
            mask = batch['mask'].to(DEVICE)

            dc = dice(pred, mask.long())
            bce = _foreground_bce(pred, mask)

            loss = dc + bce * BCE_WEIGHT
            # Validation metric: soft IoU = 1 - Jaccard loss.
            metric = 1 - iou(pred, mask.long()).item()

            val_losses.append(loss.item())
            val_losses_bce.append(bce.item())
            val_losses_dc.append(dc.item())

            metrics.append(metric)

    mean_metric = float(np.mean(metrics))
    print(datetime.now().strftime("%H:%M:%S"), f'valid    loss {np.mean(val_losses):.3f} bce {np.mean(val_losses_bce):.3f} '
          f'dice {np.mean(val_losses_dc):.3f} metric {mean_metric:.3f}')

    # best_cnt counts epochs since the last improvement.
    best_cnt += 1
    if best_metric <= mean_metric:
        best_cnt    = 0
        best_metric = mean_metric
        torch.save(model.state_dict(), CHECKPOINT)
        print('Save model')
    if best_cnt > 3:
        # Roll back to the best weights; optimizer and scheduler state are kept.
        best_cnt = 0
        model.load_state_dict(torch.load(CHECKPOINT, map_location='cpu'))
        print('Reload model')


torch.cuda.empty_cache()
print('done at', datetime.now().strftime("%H:%M:%S"))

start at 13:50:35
14:00:16 epoch 00 loss 0.988 bce 0.853 dice 0.902 lr=0.00002207
14:00:27 valid    loss 1.020 bce 0.850 dice 0.850 metric 0.089
Save model
14:10:04 epoch 01 loss 0.645 bce 0.724 dice 0.573 lr=0.00006268
14:10:14 valid    loss 0.513 bce 0.683 dice 0.377 metric 0.470
Save model
14:20:03 epoch 02 loss 0.425 bce 0.678 dice 0.357 lr=0.00009525
14:20:13 valid    loss 0.437 bce 0.673 dice 0.302 metric 0.556
Save model
14:29:57 epoch 03 loss 0.375 bce 0.673 dice 0.307 lr=0.00009944
14:30:08 valid    loss 0.412 bce 0.671 dice 0.278 metric 0.588
Save model
14:39:52 epoch 04 loss 0.356 bce 0.672 dice 0.289 lr=0.00009507
14:40:03 valid    loss 0.393 bce 0.669 dice 0.259 metric 0.609
Save model
14:49:44 epoch 05 loss 0.337 bce 0.671 dice 0.270 lr=0.00008671
14:49:55 valid    loss 0.386 bce 0.669 dice 0.253 metric 0.618
Save model
14:59:33 epoch 06 loss 0.326 bce 0.670 dice 0.259 lr=0.00007510
14:59:44 valid    loss 0.371 bce 0.668 dice 0.238 metric 0.639
Save model
15:09:26 epoch 0

## Делаем прогноз

После первого этапа результат будет не очень хорошим, но он должен подтвердить верность направления.

In [None]:
# Predict a class-id mask for every test image and save it as a PNG with the
# same file name, resized back to the image's original size.
model.load_state_dict(torch.load(f"{ROOT_DRIVE}{ARCH}_{ENCODER}_{VERSION}.pth",
                                 map_location='cpu'))
model.eval()
model.to(DEVICE)

transform_test = A.Compose([
    A.Resize(SIZE, SIZE, interpolation=cv2.INTER_AREA),
    A.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

test_image_path = Path(ROOT) / "test"

# Outside Colab nothing else creates this folder. cv2.imwrite returns False
# instead of raising when the target directory is missing.
output_dir = Path(ROOT) / "test_masks"
output_dir.mkdir(parents=True, exist_ok=True)

with torch.no_grad():
    torch.cuda.empty_cache()
    for i, image_path in enumerate(sorted(test_image_path.glob("*.png"))):
        if i % 100 == 0:
            print(f"{i:04d} {image_path}")

        img = cv2.imread(str(image_path))
        if img is None:
            raise OSError(f"Could not read image: {image_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w, _ = img.shape

        img  = transform_test(image=img)['image'].unsqueeze(0).to(DEVICE)
        pred = model(img)
        # argmax yields int64 class ids. Cast them to uint8 so cv2.resize and
        # the PNG writer get an 8-bit image explicitly.
        mask = torch.argmax(pred, dim=1).squeeze(0).to(torch.uint8).cpu().numpy()
        # Nearest-neighbour keeps the values valid class ids after resizing.
        mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)

        out_path = output_dir / image_path.name
        if not cv2.imwrite(str(out_path), mask):
            raise OSError(f"Could not write mask: {out_path}")

print('done')

0000 /content/test/00388128-357d-44bb-a630-5e8856e4dcb0.png
0100 /content/test/0c2d854d-01e5-4b6e-ba78-1b443edcd784.png
0200 /content/test/1af2e8a6-c3e3-4e80-b6b9-1ef8d693a57f.png
0300 /content/test/2aadb686-e5ec-4c3d-bcb1-b76e3030c62e.png
0400 /content/test/3aa52680-b2f0-487a-b6c7-e38050a4a04e.png
0500 /content/test/4ac6f8a4-957b-4be7-aee6-f40019dbe356.png
0600 /content/test/58b3ccd9-4c6d-477b-8dd0-6b09dfe73d5d.png
0700 /content/test/619d426b-4a34-4d46-ad9e-45a8ece730dd.png
0800 /content/test/6f431a1c-0a6c-480d-bffd-629b0fbd7a0d.png
0900 /content/test/7bec62bf-f5ef-4d0e-b525-0f30aca7f60f.png
1000 /content/test/8b3e4e4c-3499-4fad-9f87-02073fed2fd2.png
1100 /content/test/9b793582-e744-450c-a142-e34d0a7cb2ec.png
1200 /content/test/b32baacd-bf29-49e9-b2eb-9a718ea4496d.png
1300 /content/test/bf3e649d-db84-433e-9f4c-82b7f040c226.png
1400 /content/test/de171555-fcde-48e1-9d11-b5f16a3997a8.png
done


In [None]:
if COLAB:
    %cd /content/test_masks/
    !zip -q ../test_masks_1.zip *.png
    !cp ../test_masks_1.zip /content/drive/MyDrive/crimea/

/content/test_masks
