In [1]:
# Imports
import torch
from tqdm import tqdm
import numpy as np
import torch.nn as nn
import torch.optim as optim
from unet import unet as UNET
import albumentations as A 
from albumentations.pytorch import ToTensorV2

from utils import (
    check_accuracy,
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    save_predictions
)

In [2]:
# HyperParams
# Training configuration: everything a reader might want to tune lives in this cell.
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 20
NUM_WORKERS = 1
PIN_MEMORY = False  # presumably forwarded to DataLoader(pin_memory=...) by get_loaders — confirm
LOAD_MODEL = False  # restore weights from a saved checkpoint before training
SEED = 42  # seeds torch's global RNG: weight init and any torch-driven shuffling

# Dataset locations, relative to this notebook; every directory ends in exactly one "/".
TRAIN_IMG_DIR = "../dataset/original/train_img/"
TRAIN_MASK_DIR = "../dataset/original/train_mask/"
VAL_IMG_DIR = "../dataset/original/val_img/"
VAL_MASK_DIR = "../dataset/original/val_mask/"

# Seed early so a Restart-&-Run-All reproduces the same initialisation.
torch.manual_seed(SEED)

In [3]:
# Train function

def train_fn(loader, model, optimizer, loss_fn, scaler, device=None):
    """Run one training epoch over ``loader`` with optional CUDA mixed precision.

    Args:
        loader: iterable yielding ``(data, targets)`` batches. Targets are cast to
            float and get a new axis at dim 1 (presumably (B, H, W) masks becoming
            (B, 1, H, W) to match a single-channel output — confirm with the dataset).
        model: network being optimised; it is put into training mode here.
        optimizer: optimiser over ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(predictions, targets)``.
        scaler: gradient scaler exposing ``scale``/``step``/``update``
            (e.g. ``torch.cuda.amp.GradScaler``).
        device: device the batches are moved to; defaults to the notebook's ``DEVICE``.

    Returns:
        float: mean of the per-batch losses, or ``nan`` when ``loader`` is empty.
    """
    if device is None:
        device = DEVICE
    # Autocast only when training on CUDA. On the CPU fallback, a hard-coded CUDA
    # autocast would warn and disable itself on every batch anyway.
    use_amp = torch.device(device).type == "cuda"

    # check_accuracy / save_predictions come from utils and may switch the model to
    # eval mode; set training mode explicitly instead of relying on that.
    model.train()

    loop = tqdm(loader)
    total_loss = 0.0
    num_batches = 0

    for data, targets in loop:
        data = data.to(device=device)
        targets = targets.float().unsqueeze(1).to(device=device)

        # forward
        with torch.autocast(device_type="cuda", enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        # backward (the scaler handles loss scaling when AMP is enabled)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # update tqdm loop and the running epoch statistics
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        loop.set_postfix(loss=loss_value)

    return total_loss / num_batches if num_batches else float("nan")

In [4]:
# Transformations
# Spatial size every image/mask is resized to before reaching the model.
IMAGE_HEIGHT = 160
IMAGE_WIDTH = 240

# Passed as the single `transform` to get_loaders (presumably applied to both the
# train and val splits — confirm in utils). With mean 0, std 1 and
# max_pixel_value 255, Normalize computes img / 255, i.e. it only rescales 8-bit
# pixels to [0, 1]; no dataset statistics are applied.
transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        ),
        ToTensorV2(),
    ],
)


In [5]:
# Model, loss and optimiser.
# One output channel + BCEWithLogitsLoss: the network emits raw logits for a
# binary mask and the sigmoid is folded into the loss.
model = UNET(in_channels=3, out_channels=1)
model = model.to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Train/val loaders built from the image and mask directories in the config cell.
loader_dirs = (TRAIN_IMG_DIR, TRAIN_MASK_DIR, VAL_IMG_DIR, VAL_MASK_DIR)
train_loader, val_loader = get_loaders(
    *loader_dirs,
    BATCH_SIZE,
    NUM_WORKERS,
    PIN_MEMORY,
    transform=transform,
)


In [6]:
# Optionally resume from a previous run's checkpoint.
# NOTE(review): assumes save_checkpoint writes to this same path — confirm in utils.
CHECKPOINT_PATH = "my_checkpoint.pth.tar"

if LOAD_MODEL:
    # map_location lets a checkpoint written on a GPU be restored on the CPU
    # fallback (DEVICE == "cpu") instead of failing during deserialisation.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    load_checkpoint(checkpoint, model)
    # Checkpoints built by the training loop also carry the optimiser state;
    # restore it so Adam's moment estimates carry over when resuming.
    if "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])

In [7]:
# Hand cached-but-unused CUDA allocator memory back before training (nothing to free without a GPU).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [8]:
# Baseline validation metrics before any training step (freshly initialised
# weights, or the restored checkpoint when LOAD_MODEL is True).
check_accuracy(val_loader, model, device=DEVICE)

Got 3815253/4915200 with acc 77.62
Dice score: 0.0


In [9]:
# Gradient scaler for mixed-precision training. Enabled only when training on
# CUDA; on the CPU fallback a disabled scaler is a pass-through (scale() returns
# the loss unchanged, step() just calls optimizer.step()), which avoids the
# "CUDA is not available" warning an unconditional GradScaler() emits.
scaler = torch.cuda.amp.GradScaler(enabled=DEVICE == "cuda")

In [10]:
for epoch in range(NUM_EPOCHS):
    print(f"Epoch: {epoch}")

    # One pass over the training set.
    train_fn(train_loader, model, optimizer, loss_fn, scaler)

    # Bundle weights and optimiser state and hand them to save_checkpoint.
    checkpoint = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    save_checkpoint(checkpoint)

    # Validation metrics after this epoch.
    check_accuracy(val_loader, model, device=DEVICE)

    # Save example predictions into the saved_images/ folder for visual inspection.
    save_predictions(val_loader, model, folder="saved_images/", device=DEVICE)

Epoch: 0


100%|██████████| 310/310 [01:55<00:00,  2.69it/s, loss=0.0843]


=> Saving checkpoint
Got 4880741/4915200 with acc 99.30
Dice score: 0.9843461513519287
Epoch: 1


100%|██████████| 310/310 [01:52<00:00,  2.76it/s, loss=0.0491]


=> Saving checkpoint
Got 4886373/4915200 with acc 99.41
Dice score: 0.9869333505630493
Epoch: 2


100%|██████████| 310/310 [02:01<00:00,  2.54it/s, loss=0.0339]


=> Saving checkpoint
Got 4885128/4915200 with acc 99.39
Dice score: 0.9866297245025635
Epoch: 3


100%|██████████| 310/310 [02:00<00:00,  2.58it/s, loss=0.023] 


=> Saving checkpoint
Got 4891395/4915200 with acc 99.52
Dice score: 0.9891982674598694
Epoch: 4


100%|██████████| 310/310 [01:59<00:00,  2.60it/s, loss=0.0175]


=> Saving checkpoint
Got 4892426/4915200 with acc 99.54
Dice score: 0.9897316098213196
Epoch: 5


100%|██████████| 310/310 [01:59<00:00,  2.60it/s, loss=0.015] 


=> Saving checkpoint
Got 4893569/4915200 with acc 99.56
Dice score: 0.9902461767196655
Epoch: 6


100%|██████████| 310/310 [01:55<00:00,  2.69it/s, loss=0.0127]


=> Saving checkpoint
Got 4893920/4915200 with acc 99.57
Dice score: 0.9904050827026367
Epoch: 7


100%|██████████| 310/310 [01:55<00:00,  2.68it/s, loss=0.0121]


=> Saving checkpoint
Got 4892657/4915200 with acc 99.54
Dice score: 0.9898416996002197
Epoch: 8


100%|██████████| 310/310 [01:56<00:00,  2.65it/s, loss=0.0107] 


=> Saving checkpoint
Got 4892631/4915200 with acc 99.54
Dice score: 0.9897710084915161
Epoch: 9


100%|██████████| 310/310 [02:02<00:00,  2.53it/s, loss=0.00938]


=> Saving checkpoint
Got 4894440/4915200 with acc 99.58
Dice score: 0.990737795829773
Epoch: 10


100%|██████████| 310/310 [02:02<00:00,  2.53it/s, loss=0.00836]


=> Saving checkpoint
Got 4894703/4915200 with acc 99.58
Dice score: 0.9907465577125549
Epoch: 11


100%|██████████| 310/310 [01:56<00:00,  2.65it/s, loss=0.00746]


=> Saving checkpoint
Got 4895691/4915200 with acc 99.60
Dice score: 0.9911710619926453
Epoch: 12


100%|██████████| 310/310 [02:02<00:00,  2.53it/s, loss=0.00692]


=> Saving checkpoint
Got 4895671/4915200 with acc 99.60
Dice score: 0.991190493106842
Epoch: 13


100%|██████████| 310/310 [02:00<00:00,  2.58it/s, loss=0.00635]


=> Saving checkpoint
Got 4896287/4915200 with acc 99.62
Dice score: 0.9914529323577881
Epoch: 14


100%|██████████| 310/310 [01:55<00:00,  2.68it/s, loss=0.00707]


=> Saving checkpoint
Got 4896384/4915200 with acc 99.62
Dice score: 0.9914913177490234
Epoch: 15


100%|██████████| 310/310 [02:01<00:00,  2.55it/s, loss=0.00634]


=> Saving checkpoint
Got 4896738/4915200 with acc 99.62
Dice score: 0.9916954040527344
Epoch: 16


100%|██████████| 310/310 [02:01<00:00,  2.54it/s, loss=0.00605]


=> Saving checkpoint
Got 4896484/4915200 with acc 99.62
Dice score: 0.9915444254875183
Epoch: 17


100%|██████████| 310/310 [01:59<00:00,  2.60it/s, loss=0.00577]


=> Saving checkpoint
Got 4896346/4915200 with acc 99.62
Dice score: 0.9914595484733582
Epoch: 18


100%|██████████| 310/310 [02:00<00:00,  2.56it/s, loss=0.00529]


=> Saving checkpoint
Got 4896814/4915200 with acc 99.63
Dice score: 0.9917051792144775
Epoch: 19


100%|██████████| 310/310 [02:13<00:00,  2.32it/s, loss=0.0049] 


=> Saving checkpoint
Got 4897136/4915200 with acc 99.63
Dice score: 0.9918534755706787
