In [1]:
import os

import albumentations as A
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2

import FILE_PATHS
from model import UNET
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders_masks,
    get_loaders_landmarks,
    check_accuracy,
    save_predictions_as_imgs,
    train_fn,
)

In [2]:
# --- Hyperparameters / configuration --------------------------------------
SEED = 42  # fixed so weight init and DataLoader shuffling are reproducible
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
NUM_EPOCHS = 5
NUM_WORKERS = 2
IMAGE_HEIGHT = 112
IMAGE_WIDTH = 112
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine the
# DataLoader warns and ignores it, so only request it when training on CUDA.
PIN_MEMORY = DEVICE == "cuda"
LOAD_MODEL = False
# Passed as `model_type` to train_fn / check_accuracy / save_predictions_as_imgs.
model_type = "landmarks"

torch.manual_seed(SEED)

# --- Dataset locations ----------------------------------------------------
# FILE_PATHS.split(name) is a project helper returning three entries per split;
# presumably (image dir, mask dir, heatmap dir) — confirm against FILE_PATHS.
TRAIN_IMG_DIR, TRAIN_MASK_DIR, TRAIN_HEATMAPS = FILE_PATHS.split("train")
VAL_IMG_DIR, VAL_MASK_DIR, VAL_HEATMAPS = FILE_PATHS.split("validation")


# --- Transformations ------------------------------------------------------
def _make_transform(height: int, width: int) -> A.Compose:
    """Build the resize -> scale-to-[0, 1] -> tensor pipeline.

    With mean 0, std 1 and max_pixel_value 255, ``A.Normalize`` reduces to
    dividing pixel values by 255. No random augmentation is applied.
    """
    return A.Compose(
        [
            A.Resize(height=height, width=width),
            A.Normalize(
                mean=[0.0, 0.0, 0.0],
                std=[1.0, 1.0, 1.0],
                max_pixel_value=255.0,
            ),
            ToTensorV2(),
        ]
    )


# Train and validation pipelines are currently identical; they stay separate
# objects so augmentation can later be added to the training side only.
train_transform = _make_transform(IMAGE_HEIGHT, IMAGE_WIDTH)
val_transforms = _make_transform(IMAGE_HEIGHT, IMAGE_WIDTH)

# --- Model ----------------------------------------------------------------
# NOTE(review): out_channels=7 presumably means one output heatmap per landmark —
# confirm against the number of heatmap channels in the dataset.
model = UNET(in_channels=3, out_channels=7).to(DEVICE)
# BCEWithLogitsLoss applies the sigmoid itself, so UNET must emit raw logits.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Batches --------------------------------------------------------------
train_loader, val_loader = get_loaders_landmarks(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    TRAIN_HEATMAPS,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    VAL_HEATMAPS,
    BATCH_SIZE,
    train_transform,
    val_transforms,
    NUM_WORKERS,
    PIN_MEMORY,
)

In [3]:

def main(save_model: bool = False):
    """Train the module-level ``model`` for ``NUM_EPOCHS`` epochs, validating after each.

    Reads ``model``, ``optimizer``, ``loss_fn``, ``train_loader``, ``val_loader``,
    ``DEVICE``, ``LOAD_MODEL`` and ``model_type`` from the configuration cell.
    Re-running this cell keeps training the same ``model`` object rather than
    starting from fresh weights; re-run the configuration cell to reset it.

    Args:
        save_model: if True, write a checkpoint holding the model and optimizer
            state via ``save_checkpoint`` after every epoch. Defaults to False,
            matching the earlier behaviour where the save call was commented out.
    """
    if LOAD_MODEL:
        # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine.
        load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

    # Loss scaling is only meaningful on CUDA. Disabling it explicitly elsewhere turns
    # the scaler into a pass-through without the "CUDA is not available" warning.
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, DEVICE, model_type)

        if save_model:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint)

        # NOTE(review): in the recorded outputs the "Dice score" reported here is
        # identical to the pixel accuracy every epoch — check check_accuracy's
        # landmarks branch.
        check_accuracy(val_loader, model, device=DEVICE, model_type=model_type)

        path = f"../saved_images/epoch_{epoch}"
        # Ensure the per-epoch folder exists. This is harmless if
        # save_predictions_as_imgs creates it, and required if it only writes into an
        # existing folder.
        os.makedirs(path, exist_ok=True)
        save_predictions_as_imgs(
            val_loader, model, folder=path, device=DEVICE, model_type=model_type
        )


if __name__ == "__main__":
    main()

100%|██████████| 8/8 [00:16<00:00,  2.11s/it, loss=0.734]


Got 4998102/6422528 with acc 77.82
Dice score: 77.82


100%|██████████| 8/8 [00:16<00:00,  2.09s/it, loss=0.691]


Got 4916784/6422528 with acc 76.56
Dice score: 76.56


100%|██████████| 8/8 [00:16<00:00,  2.05s/it, loss=0.651]


Got 5029347/6422528 with acc 78.31
Dice score: 78.31


100%|██████████| 8/8 [00:18<00:00,  2.27s/it, loss=0.62] 


Got 5194405/6422528 with acc 80.88
Dice score: 80.88


100%|██████████| 8/8 [00:17<00:00,  2.16s/it, loss=0.603]


Got 5285110/6422528 with acc 82.29
Dice score: 82.29
