In [1]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders,
    check_accuracy,
    save_predictions_as_imgs
)
import toml
import import_ipynb
from model import UNET
import gc

importing Jupyter notebook from model.ipynb


In [2]:
# Read every tunable setting (paths, image size, batch size, epochs, ...) from hp.toml.
with open('hp.toml', 'r') as hp_file:
    hp = toml.load(hp_file)

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
device

'cuda'

In [None]:
def train(loader, model, optimizer, loss_fn, scaler):
    """Run one epoch of (optionally mixed-precision) training over ``loader``.

    Args:
        loader: iterable yielding ``(data, targets)`` batches. ``targets`` are
            cast to float and unsqueezed at dim 1 before the loss, so they
            presumably arrive as (N, H, W) masks -- TODO confirm against the
            dataset built by ``get_loaders``.
        model: network being optimized; it is put into training mode here.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(predictions, targets)``.
        scaler: ``GradScaler`` used to scale the loss for the backward pass
            and to step the optimizer.

    Returns:
        The mean of the per-batch loss values for this epoch, or ``nan`` if
        ``loader`` yielded no batches. (Earlier versions returned ``None``;
        callers that ignore the result are unaffected.)

    Batches are moved to the notebook-level ``device``.
    """
    # The validation helper may leave the model in eval mode; make sure
    # dropout / batch-norm use training behaviour for every epoch.
    model.train()

    # Only enable CUDA autocast when actually running on a GPU. The old
    # torch.cuda.amp.autocast() merely warned and disabled itself on CPU.
    use_amp = torch.device(device).type == "cuda"

    loop = tqdm(loader)
    total_loss = 0.0
    num_batches = 0

    for data, targets in loop:
        data = data.to(device=device)
        targets = targets.float().unsqueeze(1).to(device=device)

        with torch.autocast(device_type="cuda", enabled=use_amp):
            predictions = model(data)
            loss = loss_fn(predictions, targets)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # One host sync per batch, shared by the progress bar and the epoch mean.
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        loop.set_postfix(loss=batch_loss)

    return total_loss / num_batches if num_batches else float("nan")

def _build_transform():
    """Build the shared preprocessing pipeline: resize, normalize, to tensor.

    Images are resized to ``hp['IMAGE_HEIGHT']`` x ``hp['IMAGE_WIDTH']``.
    They are then divided by 255 (mean 0, std 1) and converted to a tensor.
    Train and validation currently use the same pipeline (no augmentation).
    Build a separate one for ``train_transform`` if augmentation is wanted.
    """
    return A.Compose(
        [
            A.Resize(height=hp['IMAGE_HEIGHT'], width=hp['IMAGE_WIDTH']),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],      # Normalized_Pixel = ((original_pixel/max_pixel_value) - mean)/std
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ],
        # Disable albumentations' image/mask shape-consistency check.
        is_check_shapes=False,
    )


train_transform = _build_transform()
val_transform = _build_transform()

model = UNET(in_channels=3, out_channels=1).to(device)
loss_fn = nn.BCEWithLogitsLoss()
# Bound to `optimizer`, not `optim`: rebinding `optim` would shadow the
# `torch.optim` module imported as `optim`. A re-run of this cell would then fail
# with AttributeError: 'Adam' object has no attribute 'Adam'.
optimizer = optim.Adam(model.parameters(), lr=hp['LEARNING_RATE'])
train_loader, val_loader = get_loaders(
    hp['TRAIN_IMG_DIR'],
    hp['TRAIN_MASK_DIR'],
    hp['VAL_IMG_DIR'],
    hp['VAL_MASK_DIR'],
    hp['BATCH_SIZE'],
    train_transform,
    val_transform,
    hp['NUM_WORKERS'],
    hp['PIN_MEMORY'],
)

# Loss scaling only matters for CUDA autocast. Turn it off explicitly on CPU
# instead of relying on GradScaler's warn-and-disable fallback.
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
gc.collect()
torch.cuda.empty_cache()

# Best validation score seen so far. Starting at -inf guarantees that the
# first epoch always writes a checkpoint, even if its score is 0.0.
acc = float("-inf")

for epoch in range(hp['NUM_EPOCHS']):
    train(train_loader, model, optimizer, loss_fn, scaler)

    # NOTE(review): assumes check_accuracy returns a comparable score
    # (higher is better) -- confirm in utils.
    curr_acc = check_accuracy(val_loader, model, device=device)

    if curr_acc > acc:
        acc = curr_acc
        print('Saving new best model...')

        # Saving the model
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        save_checkpoint(checkpoint)

        save_predictions_as_imgs(
            val_loader, model, device=device
        )
