In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from matplotlib import pyplot as plt
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [8]:
from dataset import Dataset
from models import UNet
from Trainer import Trainer
from utils import accuracy
from utils import save_predictions_as_imgs

In [4]:
train_folder = 'train'
batch_size = 1
device = 'cuda'

train_transforms = A.Compose(
    [
        A.Resize(height=512, width=512),
        #A.Rotate(limit=35, p=1),
        #A.HorizontalFlip(p=0.5),
        #A.VerticalFlip(p=0.5),
        #A.ColorJitter(),
        # Это реальные среднее и дисперсия выборки, но такая нормализация не сильно помогает
        #tensor([0.8418, 0.8288, 0.8200]), tensor([0.2174, 0.2178, 0.2234]))
        #A.Normalize(mean=([0.8418, 0.8288, 0.8200]), std=([0.2174, 0.2178, 0.2234])),
        #A.Normalize(),
        ToTensorV2()
    ])

train_dataset = Dataset(train_folder, train_transforms)
train_loader =  torch.utils.data.DataLoader(train_dataset, batch_size, pin_memory=True, shuffle=True)

In [5]:
model = UNet(n_filters=16)
criterion = nn.BCEWithLogitsLoss()
metric = {'name' : 'accuracy', 'func' : accuracy}
config = {
    'lr': 1e-3,
    'epochs': 50,
    'early_stopping': 5
}
trainer = Trainer(model, criterion, metric, config)

In [6]:
trainer.fit(train_loader)

Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:33<00:00,  1.54s/it]
Epoch 1:   5%|███▎                                                                      | 1/22 [00:00<00:02,  7.16it/s]

Epoch 0, loss: -3.117213314229792,                   accuracy: 59.12593494762074


Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  4.00it/s]
Epoch 2:   5%|███▎                                                                      | 1/22 [00:00<00:03,  5.81it/s]

Epoch 1, loss: -5.30955069038001,                   accuracy: 63.841039484197445


Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.68it/s]
Epoch 3:   5%|███▎                                                                      | 1/22 [00:00<00:03,  6.51it/s]

Epoch 2, loss: -7.146934200416911,                   accuracy: 62.44257146661932


Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.97it/s]
Epoch 4:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 3, loss: -9.601265945217826,                   accuracy: 65.63053131103516


Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.82it/s]
Epoch 5:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 4, loss: -11.58519631895152,                   accuracy: 67.30107394131747


Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.93it/s]
Epoch 6:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 5, loss: -13.592935499819843,                   accuracy: 65.72936664928089


Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.98it/s]
Epoch 7:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 6, loss: -16.1587425253608,                   accuracy: 66.55374006791548


Epoch 7: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  4.01it/s]
Epoch 8:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 7, loss: -18.17015093023127,                   accuracy: 67.1845869584517


Epoch 8: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  4.02it/s]
Epoch 9:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 8, loss: -19.76015875285322,                   accuracy: 65.67129655317827


Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.82it/s]

Epoch 9, loss: -23.444617466493085,                   accuracy: 66.84811332009055





In [11]:
model.load_state_dict(torch.load(os.path.join(Trainer.CHECKPOINTS_PATH, 'weights.pth')))
save_predictions_as_imgs(train_loader, model)