In [3]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from model import UNET
from utils import(
load_checkpoint,
save_checkpoint,
get_loaders,
check_accuracy,
save_predictions_as_imgs
)

In [4]:
LEARNING_RATE = 1e-4
DEVICE = "cpu" # if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
NUM_EPOCHS = 3
NUM_WORKERS = 2
IMAGE_HEIGHT = 80
IMAGE_WIDTH  = 120
PIN_MEMORY = True
LOAD_MODEL = True
TRAIN_IMG_DIR = "images/"
TRAIN_MASK_DIR = "masked_images/"
VAL_IMG_DIR = "val_images/"
VAL_MASK_DIR = "val_masks/"

def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)
    for batch_idx, (data, targets) in enumerate(loop):
        data = data.to(device=DEVICE)
        targets = targets.float().unsqueeze(1).to(device=DEVICE)

        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()


        loop.set_postfix(loss=loss.item())


def main():
    train_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0,1.0,1.0],
            max_pixel_value=255.0
        ),
        ToTensorV2()
    ])

    val_transform = A.Compose([
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0,1.0,1.0],
            max_pixel_value=255.0
        ),
        ToTensorV2()
    ])

    model = UNET(in_channels=3, out_channels=1).to(DEVICE)
    loss_fn =  nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    train_loader, val_loader = get_loaders(
        TRAIN_IMG_DIR,
        TRAIN_MASK_DIR,
        VAL_IMG_DIR,
        VAL_MASK_DIR,
        BATCH_SIZE,
        train_transform,
        val_transform,
        NUM_WORKERS,
        PIN_MEMORY
    )
    if LOAD_MODEL:
        load_checkpoint(torch.load("my_checkpoint.pth.tar"), model)
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)


        checkpoint = {
            "state_dict": model.state_dict(),
            "optmizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)
        check_accuracy(train_loader, model, device=DEVICE)

        save_predictions_as_imgs(
            train_loader, model, folder = "saved_images", device=DEVICE
        )


In [5]:
if __name__ == "__main__":
    main()



=> Loading checkpoint


100%|██████████| 10/10 [00:44<00:00,  4.40s/it, loss=-26] 


=> Saving checkpoint
got 948098/1440001 with acc 65.84
Dice score: 1.5585976839065552


100%|██████████| 10/10 [00:48<00:00,  4.86s/it, loss=-30.3]


=> Saving checkpoint
got 923130/1440001 with acc 64.11
Dice score: 1.364037275314331


100%|██████████| 10/10 [00:48<00:00,  4.86s/it, loss=-34.2]


=> Saving checkpoint
got 901371/1440001 with acc 62.60
Dice score: 1.7285871505737305
