In [1]:
import torch
from albumentations.pytorch import ToTensorV2
import albumentations as A
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs
)

In [2]:
# --- Optimisation hyperparameters ---
learning_rate = 1e-4
batch_size = 16
num_epochs = 3

# --- Target size passed to A.Resize in the train/val pipelines (pixels) ---
image_height = 160
image_width = 240

# NOTE(review): in the cells shown, main() never reads this flag, so saved
# checkpoints are not restored before training — confirm intent.
load_model = True

# --- Dataset locations, relative to the notebook's working directory ---
train_img_dir, train_mask_dir = "./Dataset/train", "./Dataset/train_masks"
val_img_dir, val_mask_dir = "./Dataset/val", "./Dataset/val_masks"

In [3]:
def train_fn(loader, model, optimizer, loss_fn, device=None):
    """Run one training epoch over ``loader``.

    Args:
        loader: iterable yielding ``(data, targets)`` batches. ``targets`` is
            presumably a (batch, H, W) mask tensor — TODO confirm against
            ``get_loaders`` — and gets a channel axis added below so it lines
            up with the model's single-channel output.
        model: the network being trained; it is switched to training mode here.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(predictions, targets)``.
        device: where each batch is moved before the forward pass. When
            ``None`` (the default) it is taken from the model's first
            parameter, so inputs always live on the same device as the model.
    """
    # A validation pass between epochs may leave the model in eval mode;
    # make sure dropout / batch-norm behave as in training for this epoch.
    model.train()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    loop = tqdm(loader)
    for data, targets in loop:
        data = data.to(device)
        # Add a channel axis and cast to float so targets match the logits
        # expected by BCE-style losses.
        targets = targets.float().unsqueeze(1).to(device)

        # forward
        predictions = model(data)
        loss = loss_fn(predictions, targets)

        # backward
        optimizer.zero_grad()
        loss.backward()

        # update — PyTorch optimizers already disable gradient tracking inside
        # step(), so no torch.no_grad() wrapper is needed.
        optimizer.step()

        # Report the scalar loss. It must be *called*: passing the bound method
        # `loss.item` displayed "<built-in method item ...>" in the bar.
        loop.set_postfix(loss=loss.item())

In [4]:
def main():
    """Build augmentations, model, loss and optimizer, then train for ``num_epochs``.

    Every epoch runs one ``train_fn`` pass over the training split, saves a
    checkpoint holding the model and optimizer state, runs ``check_accuracy``
    on the validation split, and writes predictions to ``saved_images/``.
    """
    def make_resize():
        # Fresh instance per pipeline; size comes from the config cell.
        return A.Resize(height=image_height, width=image_width)

    def make_normalize():
        # Zero mean / unit std with max_pixel_value=255: for uint8 input this
        # just rescales pixels into [0, 1].
        return A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        )

    # Training pipeline: random geometric augmentation before normalising.
    train_transform = A.Compose(
        [
            make_resize(),
            A.Rotate(limit=35, p=1.0),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.1),
            make_normalize(),
            ToTensorV2(),
        ],
    )
    # Validation pipeline: deterministic resize + normalise only.
    val_transform = A.Compose([make_resize(), make_normalize(), ToTensorV2()])

    model = UNET(in_channels=3, out_channels=1)
    # One output channel of raw logits; BCEWithLogitsLoss applies the sigmoid.
    loss_fn = nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # NOTE(review): the config flag `load_model` is not consulted here, so
    # training always starts from freshly initialised weights — confirm intent.
    train_loader, val_loader = get_loaders(
        train_img_dir, train_mask_dir,
        val_img_dir, val_mask_dir,
        batch_size,
        train_transform, val_transform,
    )

    for _ in range(num_epochs):
        train_fn(train_loader, model, optimizer, loss_fn)

        save_checkpoint({
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        })
        check_accuracy(val_loader, model)
        save_predictions_as_imgs(val_loader, model, folder="saved_images/")

In [5]:
# Launch training. In a Jupyter kernel __name__ is "__main__", so running this
# cell starts main() immediately.
if __name__ == "__main__":
    main()

100%|████████████| 255/255 [53:46<00:00, 12.65s/it, loss=<built-in method item of Tensor object at 0x0000016E2921BA20>]


=> Saving checkpoint
Got 38680904/39052800 with acc 99.05
Dice score: 0.9774115085601807


100%|████████████| 255/255 [50:58<00:00, 11.99s/it, loss=<built-in method item of Tensor object at 0x0000016E2D897E30>]


=> Saving checkpoint
Got 38751381/39052800 with acc 99.23
Dice score: 0.9818095564842224


100%|████████████| 255/255 [53:46<00:00, 12.65s/it, loss=<built-in method item of Tensor object at 0x0000016E2D897C00>]


=> Saving checkpoint
Got 38772860/39052800 with acc 99.28
Dice score: 0.9831939339637756
