# Нулевой этап обучения

## Заимствуем идею из трансформеров: на нулевом этапе обучаем модель восстанавливать замаскированные участки изображения, затем заменяем выходной слой и дообучаем её на задаче сегментации целевых классов.

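Маскировать будем фрагментами готовых масок: области масок разрезаются случайными горизонтальными и вертикальными полосами, которые остаются видимыми. Ошибку восстановления оцениваем через MSELoss, в которую вклад дают только замаскированные пиксели.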

In [2]:
%%capture

# Загрузка датасета
try:
    import google.colab
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    !cp /content/drive/MyDrive/crimea/train_crimea.zip /content/
    !unzip -oq /content/train_crimea.zip

    !python -m pip install --upgrade pip
    !pip install -U segmentation_models_pytorch

    COLAB      = True
    ROOT       = '/content/'
    ROOT_DRIVE = '/content/drive/MyDrive/crimea/'

except:
    COLAB      = False
    ROOT       = './'
    ROOT_DRIVE = './'


In [3]:
# Show which GPU (if any) the runtime was assigned.
!nvidia-smi

Sun Nov 20 10:08:25 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
import os
import glob
import random
import warnings
from pathlib import Path
from typing import List
from datetime import datetime

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

SEED = 1
SIZE = 512
STAGE = 0
# Fall back to CPU when CUDA is unavailable; the two warning filters below
# silence the AMP notices PyTorch emits in exactly that situation.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EPOCHS = 30
MAX_LR = 1e-4
BATCHSIZE = 12

ARCH = 'DeepLabV3Plus'
ENCODER = 'tu-xception71'
VERSION = f'stage_{STAGE}'

warnings.filterwarnings("ignore", 'User provided device_type of \'cuda\', but CUDA is not available. Disabling')
warnings.filterwarnings("ignore", 'torch.cuda.amp.GradScaler is enabled, but CUDA is not available.  Disabling.')

# Stage 0 reconstructs the RGB input, so the head has 3 output channels.
# The model class is looked up from ARCH so it always matches the ARCH used
# in the checkpoint file name.
model = getattr(smp, ARCH)(encoder_name=ENCODER,
                           encoder_weights='imagenet', classes=3)  # RGB
scaler = torch.cuda.amp.GradScaler()

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth" to /root/.cache/torch/hub/checkpoints/tf_xception_71-8eec7df1.pth


In [5]:
ROOT = Path(ROOT)

train_image_path = ROOT / "images"
train_mask_path = ROOT / "masks"

ALL_IMAGES = sorted(train_image_path.glob("*.png"))
ALL_MASKS = sorted(train_mask_path.glob("*.png"))

# Images and masks are paired by position after sorting, so the counts must
# agree. An empty list most likely means the data was not unpacked into ROOT;
# fail here with a clear message rather than later inside the DataLoader.
if not ALL_IMAGES:
    raise FileNotFoundError(f"No *.png images found in {train_image_path}")
if len(ALL_IMAGES) != len(ALL_MASKS):
    raise ValueError(f"{len(ALL_IMAGES)} images but {len(ALL_MASKS)} masks")

print(len(ALL_IMAGES))

3500


In [6]:
def seed_everything(seed=1234):
    """Seed all random generators used by the notebook and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's CPU generator and
    the current CUDA device's generator, and exports ``PYTHONHASHSEED``.
    NOTE(review): setting PYTHONHASHSEED at runtime likely only affects
    subprocesses started afterwards, not this interpreter — confirm.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False  # no autotuning: it may pick nondeterministic kernels
    cudnn.deterministic = True

# Offset the seed by STAGE so each training stage uses a different but
# reproducible random stream.
seed_everything(SEED + STAGE)

In [7]:
# Training augmentation: resize to SIZE x SIZE, random flips and 90-degree
# rotations, ImageNet mean/std normalisation (the encoder is loaded with
# 'imagenet' weights), then conversion to PyTorch tensors.
transform_train = A.Compose(
    [
        A.Resize(height=SIZE, width=SIZE, interpolation=cv2.INTER_AREA),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
        ToTensorV2(),
    ]
)

In [8]:
class SegmentationDataset(Dataset):
    """Stage-0 pretraining dataset: reconstruct image regions hidden by masks.

    For each sample the image (converted to RGB) and its grayscale mask are
    read from disk and passed through ``transforms`` together. Pixels where
    the transformed mask is positive are hidden (set to 0 in the model
    input). The exceptions are ``mask_count`` random horizontal and
    ``mask_count`` random vertical stripes, which stay visible and so cut the
    hidden regions into fragments.

    Keys set on the returned dict (on top of what ``transforms`` returns):
        image        -- transformed image with hidden pixels zeroed
        image_target -- transformed image without hiding (reconstruction target)
        mask         -- uint8 array (H, W): 1 where a pixel is hidden, else 0
        filename     -- name of the image file
    ``hw`` (the on-disk [height, width]) is handed to ``transforms`` alongside
    image and mask; presumably it is passed through unchanged — confirm.
    """

    def __init__(
        self,
        images: List[Path],
        masks: List[Path],
        transforms: "A.Compose",
        mask_count: int = 8
    ) -> None:
        """
        Args:
            images: image file paths.
            masks: mask file paths, paired with ``images`` by index.
            transforms: albumentations pipeline applied jointly to image and
                mask; the resulting image is expected channel-first (C, H, W),
                e.g. by ending with ToTensorV2.
            mask_count: number of visible horizontal and of visible vertical
                stripes per sample; stripe thickness stays below
                ``size // mask_count`` pixels (at least 1 pixel).
        """
        self.images = images
        self.masks = masks
        self.transforms = transforms
        self.mask_count = mask_count

    def __len__(self) -> int:
        return len(self.images)

    def _visible_map(self, mask) -> np.ndarray:
        """Return a uint8 map of ``mask``'s shape: 1 = visible to the model, 0 = hidden.

        ``mask`` is the transformed 2-D mask (NumPy array or CPU tensor). Its
        own size defines the stripe coordinates, so stripes always fall
        inside the image the model actually sees.
        """
        mask = np.asarray(mask)
        visible = (~(mask > 0)).astype(np.uint8)
        h, w = visible.shape[:2]

        # randint's upper bound is exclusive; keep it >= 2 so that small
        # inputs still get 1-pixel stripes instead of raising ValueError.
        per_stripe = max(self.mask_count, 1)
        max_rows = max(h // per_stripe, 2)
        max_cols = max(w // per_stripe, 2)

        for _ in range(self.mask_count):
            r1 = np.random.randint(0, h)
            r2 = np.random.randint(1, max_rows)
            visible[r1 : r1 + r2, :] = 1

            c1 = np.random.randint(0, w)
            c2 = np.random.randint(1, max_cols)
            visible[:, c1 : c1 + c2] = 1

        return visible

    def __getitem__(self, idx: int) -> dict:
        image_path = self.images[idx]
        image = cv2.imread(str(image_path))
        if image is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Cannot read image {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        h, w, _ = image.shape

        mask_path = self.masks[idx]
        mask = cv2.imread(str(mask_path), 0)  # flag 0: single-channel grayscale
        if mask is None:
            raise FileNotFoundError(f"Cannot read mask {mask_path}")

        result = self.transforms(image=image, hw=[h, w], mask=mask)

        image = result['image']
        # The un-hidden augmented image is what the model must reconstruct.
        result['image_target'] = image

        # Use the transformed mask's size, not the on-disk h, w: after Resize
        # they can differ, which put stripes out of range in the old code.
        visible = self._visible_map(result['mask'])

        # Channel-first image (C, H, W) broadcast against the (1, H, W) map.
        result["mask"] = 1 - visible
        result["image"] = image * visible[np.newaxis, ...]

        result["filename"] = image_path.name

        return result

In [9]:
# Wrap the sorted path lists in NumPy arrays and build the pretraining dataset.
all_images, all_masks = (np.asarray(paths) for paths in (ALL_IMAGES, ALL_MASKS))

dataset_train = SegmentationDataset(
    all_images,
    masks=all_masks,
    transforms=transform_train,
)

In [10]:
# Put the network on the target device in training mode.
model.to(DEVICE).train()

# Stage 0 regresses pixel values, hence a plain mean-squared-error loss.
criterion = nn.MSELoss()

## Start training

In [11]:
# Shuffled mini-batches; the trailing incomplete batch is dropped so every
# optimisation step sees exactly BATCHSIZE samples.
loader_train = DataLoader(dataset_train, batch_size=BATCHSIZE, shuffle=True,
                          num_workers=2, drop_last=True)

seed_everything(SEED + STAGE)

model.to(DEVICE)
# Train the whole network, encoder included.
model.requires_grad_(True)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=MAX_LR, weight_decay=1e-6)

# One scheduler step per epoch (steps_per_epoch=1). Per the OneCycleLR docs the
# LR starts at MAX_LR / div_factor, rises to MAX_LR, then anneals to
# (MAX_LR / div_factor) / final_div_factor.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    epochs=EPOCHS,
    steps_per_epoch=1,
    div_factor=10.0,
    final_div_factor=5.0,
)

best_loss = 1e9
print('start at', datetime.now().strftime("%H:%M:%S"))
for epoch in range(EPOCHS):
    losses     = []

    model.train()
    torch.cuda.empty_cache()

    for batch in loader_train:
        # (B, 1, H, W) so it broadcasts over the 3 RGB channels; 1 = hidden pixel.
        mask = batch['mask'].to(DEVICE).unsqueeze(1)
        image_target = batch['image_target'].to(DEVICE)
        with torch.cuda.amp.autocast(enabled=True):
            # Call the module rather than .forward() so nn.Module hooks run.
            pred = model(batch['image'].to(DEVICE))

            # Zero visible pixels in both prediction and target, so only the
            # hidden pixels contribute to the reconstruction error.
            y_pred = (pred * mask).float()
            y_true = (image_target * mask).float()
            loss = criterion(y_pred, y_true)

        losses.append(loss.item())

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    scheduler.step()  # OneCycleLR is configured with one step per epoch

    # Guard against an epoch without batches (np.mean([]) warns and gives nan).
    epoch_loss = float(np.mean(losses)) if losses else float('nan')
    print(datetime.now().strftime("%H:%M:%S"),
          f'epoch {epoch:02d} loss {epoch_loss:.3f} lr={optimizer.param_groups[0]["lr"]:.8f}')

    # Keep the checkpoint with the lowest mean training loss
    # (this stage has no validation split).
    if losses and epoch_loss <= best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), f"{ROOT_DRIVE}{ARCH}_{ENCODER}_{VERSION}.pth")

torch.cuda.empty_cache()
print('done at', datetime.now().strftime("%H:%M:%S"))

start at 10:38:36
10:43:18 epoch 00 loss 54.206 lr=0.00001343
10:47:59 epoch 01 loss 0.647 lr=0.00002318
10:52:40 epoch 02 loss 0.162 lr=0.00003778
10:57:19 epoch 03 loss 0.093 lr=0.00005500
11:01:59 epoch 04 loss 0.071 lr=0.00007222
11:06:37 epoch 05 loss 0.059 lr=0.00008682
11:11:15 epoch 06 loss 0.050 lr=0.00009657
11:15:54 epoch 07 loss 0.046 lr=0.00010000
11:20:33 epoch 08 loss 0.044 lr=0.00009945
11:25:12 epoch 09 loss 0.042 lr=0.00009782
11:29:50 epoch 10 loss 0.041 lr=0.00009515
11:34:29 epoch 11 loss 0.039 lr=0.00009149
11:39:07 epoch 12 loss 0.038 lr=0.00008692
11:43:45 epoch 13 loss 0.038 lr=0.00008155
11:48:25 epoch 14 loss 0.037 lr=0.00007550
11:53:04 epoch 15 loss 0.036 lr=0.00006890
11:57:44 epoch 16 loss 0.036 lr=0.00006190
12:02:22 epoch 17 loss 0.036 lr=0.00005466
12:07:01 epoch 18 loss 0.035 lr=0.00004734
12:11:40 epoch 19 loss 0.034 lr=0.00004010
12:16:19 epoch 20 loss 0.033 lr=0.00003310
12:20:59 epoch 21 loss 0.034 lr=0.00002650
12:25:37 epoch 22 loss 0.033 lr=0.0