In [1]:
import os

import albumentations as A
import torch
import torch.nn as nn
import torch.optim as optim
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

import FILE_PATHS
from model import UNET
from utils import (
    load_checkpoint,
    save_checkpoint,
    get_loaders_masks,
    check_accuracy,
    save_predictions_as_imgs,
    train_fn,
)

In [2]:
# Hyperparameters
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 64
NUM_EPOCHS = 5
NUM_WORKERS = 2
IMAGE_HEIGHT = 112
IMAGE_WIDTH = 112
# Pinned host memory only helps host-to-GPU copies; on a CPU-only machine the
# DataLoader just warns and ignores it, so tie it to CUDA availability.
PIN_MEMORY = DEVICE == "cuda"
LOAD_MODEL = False
# Seed torch's global RNG (CPU and CUDA) so weight initialisation and
# DataLoader shuffling are reproducible across Restart & Run All.
SEED = 42
# Target type forwarded to train_fn, check_accuracy and save_predictions_as_imgs;
# presumably selects mask vs. heatmap targets — confirm in utils.
model_type = "masks"

torch.manual_seed(SEED)

# Dataset directories. FILE_PATHS.split(<split name>) returns three paths; by
# their names these are the image, mask and heatmap directories — confirm in FILE_PATHS.
TRAIN_IMG_DIR, TRAIN_MASK_DIR, TRAIN_HEATMAPS = FILE_PATHS.split("train")
VAL_IMG_DIR, VAL_MASK_DIR, VAL_HEATMAPS = FILE_PATHS.split("validation")

# Normalisation shared by the train and validation pipelines. With mean 0 and
# std 1, A.Normalize only divides by max_pixel_value, i.e. rescales [0, 255] to [0, 1].
NORM_MEAN = [0.0, 0.0, 0.0]
NORM_STD = [1.0, 1.0, 1.0]
MAX_PIXEL_VALUE = 255.0

# Transformations
train_transform = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        # Augmentations are currently disabled; uncomment to enable them.
        # A.Rotate(limit=35, p=1.0),
        # A.HorizontalFlip(p=0.5),
        # A.VerticalFlip(p=0.1),
        A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=MAX_PIXEL_VALUE),
        ToTensorV2(),
    ]
)

val_transforms = A.Compose(
    [
        A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH),
        A.Normalize(mean=NORM_MEAN, std=NORM_STD, max_pixel_value=MAX_PIXEL_VALUE),
        ToTensorV2(),
    ]
)

# Model: 3 input channels, 1 output channel. The output is treated as raw logits,
# since BCEWithLogitsLoss applies the sigmoid itself.
model = UNET(in_channels=3, out_channels=1).to(DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Data loaders for the train and validation splits.
train_loader, val_loader = get_loaders_masks(
    TRAIN_IMG_DIR,
    TRAIN_MASK_DIR,
    VAL_IMG_DIR,
    VAL_MASK_DIR,
    BATCH_SIZE,
    train_transform,
    val_transforms,
    NUM_WORKERS,
    PIN_MEMORY,
)

In [3]:
def main(save_model=False):
    """Train the UNet for NUM_EPOCHS epochs, evaluating after each epoch.

    Each epoch runs ``train_fn`` over ``train_loader``, then ``check_accuracy``
    on ``val_loader``, and finally writes validation predictions to
    ``../saved_images/epoch_<n>``.

    Args:
        save_model: when True, pass the model and optimizer state dicts to
            ``save_checkpoint`` after every epoch. Defaults to False, which
            keeps the previous behaviour (checkpoint saving was disabled).
    """
    if LOAD_MODEL:
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

    # Loss scaling for mixed precision only applies on CUDA. Disable it explicitly
    # on CPU rather than relying on GradScaler's warn-and-disable fallback.
    scaler = torch.cuda.amp.GradScaler(enabled=DEVICE == "cuda")

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, DEVICE, model_type)

        # Only build the checkpoint when it will actually be saved.
        if save_model:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint)

        # NOTE(review): in the recorded run the printed "Dice score" equals the
        # pixel accuracy on every epoch; check the Dice computation in utils.
        check_accuracy(val_loader, model, device=DEVICE, model_type=model_type)

        path = f"../saved_images/epoch_{epoch}"
        # Create the per-epoch output folder so saving does not depend on it pre-existing.
        os.makedirs(path, exist_ok=True)

        # Save predictions
        save_predictions_as_imgs(
            val_loader, model, folder=path, device=DEVICE, model_type=model_type
        )


if __name__ == "__main__":
    main()

100%|██████████| 8/8 [00:07<00:00,  1.03it/s, loss=0.614]


Got 1123221/6422528 with acc 17.49
Dice score: 17.49


100%|██████████| 8/8 [00:05<00:00,  1.52it/s, loss=0.475]


Got 5362594/6422528 with acc 83.50
Dice score: 83.50


100%|██████████| 8/8 [00:05<00:00,  1.49it/s, loss=0.392]


Got 5897651/6422528 with acc 91.83
Dice score: 91.83


100%|██████████| 8/8 [00:05<00:00,  1.48it/s, loss=0.356]


Got 6107756/6422528 with acc 95.10
Dice score: 95.10


100%|██████████| 8/8 [00:05<00:00,  1.52it/s, loss=0.329]


Got 6253661/6422528 with acc 97.37
Dice score: 97.37
