In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from matplotlib import pyplot as plt
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import transforms

import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

In [3]:
from dataset import Dataset
from models import UNet
from Trainer import Trainer
from utils import accuracy

In [19]:
train_folder = 'train'
batch_size = 1
device = 'cuda'

train_transforms = A.Compose(
    [
        A.Resize(height=512, width=512),
        #A.Rotate(limit=35, p=1),
        #A.HorizontalFlip(p=0.5),
        #A.VerticalFlip(p=0.5),
        # Это реальные среднее и дисперсия выборки, но такая нормализация не сильно помогает
        #tensor([0.8418, 0.8288, 0.8200]), tensor([0.2174, 0.2178, 0.2234]))
        #A.Normalize(mean=([0.8418, 0.8288, 0.8200]), std=([0.2174, 0.2178, 0.2234])),
        #A.Normalize(),
        A.ColorJitter(),
        
        ToTensorV2()
    ])

train_dataset = Dataset(train_folder, train_transforms)
train_loader =  torch.utils.data.DataLoader(train_dataset, batch_size, pin_memory=True, shuffle=True)

In [20]:
model = UNet(n_filters=16)
criterion = nn.BCEWithLogitsLoss()
metric = {'name' : 'accuracy', 'func' : accuracy}
config = {
    'lr': 1e-3,
    'epochs': 50,
    'early_stopping': 5
}
trainer = Trainer(model, criterion, metric, config)

In [21]:
trainer.fit(train_loader)

Epoch 0: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.93it/s]
Epoch 1:   5%|███▎                                                                      | 1/22 [00:00<00:02,  7.52it/s]

Epoch 0, loss: 0.16817723206159743,                   accuracy: 60.483360290527344


Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  4.27it/s]
Epoch 2:   5%|███▎                                                                      | 1/22 [00:00<00:02,  7.40it/s]

Epoch 1, loss: -1.6079603650353171,                   accuracy: 64.88874608820134


Epoch 2: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  4.05it/s]
Epoch 3:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 2, loss: -2.952475615523078,                   accuracy: 64.60250507701527


Epoch 3: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.93it/s]
Epoch 4:   5%|███▎                                                                      | 1/22 [00:00<00:03,  5.30it/s]

Epoch 3, loss: -4.1088043617253955,                   accuracy: 66.4534482088956


Epoch 4: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.90it/s]
Epoch 5:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 4, loss: -5.2788901127536185,                   accuracy: 65.78788757324219


Epoch 5: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  4.08it/s]
Epoch 6:   5%|███▎                                                                      | 1/22 [00:00<00:03,  5.84it/s]

Epoch 5, loss: -6.3197976154359905,                   accuracy: 66.74057353626598


Epoch 6: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  4.01it/s]
Epoch 7:   5%|███▎                                                                      | 1/22 [00:00<00:03,  6.90it/s]

Epoch 6, loss: -8.077602177858353,                   accuracy: 66.92123413085938


Epoch 7: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.87it/s]
Epoch 8:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 7, loss: -9.703822080384601,                   accuracy: 66.36437502774325


Epoch 8: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.97it/s]
Epoch 9:   0%|                                                                                  | 0/22 [00:00<?, ?it/s]

Epoch 8, loss: -10.654900342226028,                   accuracy: 66.54539975253019


Epoch 9: 100%|█████████████████████████████████████████████████████████████████████████| 22/22 [00:06<00:00,  3.49it/s]
Epoch 10:   5%|███▎                                                                     | 1/22 [00:00<00:03,  5.93it/s]

Epoch 9, loss: -12.914805374362253,                   accuracy: 68.26386885209517


Epoch 10: 100%|████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.83it/s]
Epoch 11:   5%|███▎                                                                     | 1/22 [00:00<00:03,  6.23it/s]

Epoch 10, loss: -14.584787401286038,                   accuracy: 65.87787974964489


Epoch 11: 100%|████████████████████████████████████████████████████████████████████████| 22/22 [00:06<00:00,  3.57it/s]
Epoch 12:   5%|███▎                                                                     | 1/22 [00:00<00:04,  5.09it/s]

Epoch 11, loss: -17.04214844378558,                   accuracy: 69.04780647971414


Epoch 12: 100%|████████████████████████████████████████████████████████████████████████| 22/22 [00:06<00:00,  3.49it/s]
Epoch 13:   0%|                                                                                 | 0/22 [00:00<?, ?it/s]

Epoch 12, loss: -18.419309404763307,                   accuracy: 65.2688980102539


Epoch 13: 100%|████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.86it/s]
Epoch 14:   5%|███▎                                                                     | 1/22 [00:00<00:03,  5.34it/s]

Epoch 13, loss: -21.27339068597013,                   accuracy: 67.72916967218572


Epoch 14: 100%|████████████████████████████████████████████████████████████████████████| 22/22 [00:05<00:00,  3.95it/s]
Epoch 15:   0%|                                                                                 | 0/22 [00:00<?, ?it/s]

Epoch 14, loss: -23.905456911433827,                   accuracy: 66.29399386319247


Epoch 15: 100%|████████████████████████████████████████████████████████████████████████| 22/22 [00:06<00:00,  3.32it/s]
Epoch 16:   0%|                                                                                 | 0/22 [00:00<?, ?it/s]

Epoch 15, loss: -25.30314500765367,                   accuracy: 64.21467174183239


Epoch 16: 100%|████████████████████████████████████████████████████████████████████████| 22/22 [00:06<00:00,  3.31it/s]

Epoch 16, loss: -28.616344598206606,                   accuracy: 67.2585747458718





In [22]:
def save_predictions_as_imgs(loader, model, thr=0.5, folder="saved_images/"):
    model.eval()
    for idx, data in enumerate(loader):
        
        x = data['image'].to(device=device)
        y = data['mask']
        with torch.no_grad():
            preds = torch.sigmoid(model(x))
            preds = (preds > thr).float()
        x = x.float() / 255
        torchvision.utils.save_image(x.data.cpu(), f"{folder}/orig_{idx}.png")
        torchvision.utils.save_image(preds, f"{folder}/pred_{idx}.png")
        torchvision.utils.save_image(y.unsqueeze(1), f"{folder}{idx}.png")

    model.train()
    
model.load_state_dict(torch.load(os.path.join(Trainer.CHECKPOINTS_PATH, 'weights.pth')))
save_predictions_as_imgs(train_loader, model)