In [1]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch.nn as nn
import torch.optim as optim
from model import UNET
from dataloader import get_loaders_masks
import FILE_PATHS

from utils import (
    load_checkpoint,
    save_checkpoint,
    check_accuracy,
    save_predictions_as_imgs,
    train_fn,
)

In [2]:
# Hyperparameters
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# With the recorded 16 training pairs, a batch size of 64 means one batch per epoch.
BATCH_SIZE = 64
NUM_EPOCHS = 5
NUM_WORKERS = 2
IMAGE_HEIGHT = 112
IMAGE_WIDTH = 112
# NOTE(review): PIN_MEMORY is not passed to get_loaders_masks in this notebook;
# confirm whether the loader helper should receive it.
PIN_MEMORY = True
LOAD_MODEL = False
SEED = 42

# Seed torch's global RNG so model weight initialisation (and anything else
# drawing from torch's RNG) is reproducible across Restart & Run All.
torch.manual_seed(SEED)

In [4]:
# Transformations
# Train and validation pipelines share the same resize -> scale -> tensor steps;
# only the training pipeline is meant to receive augmentations (currently disabled).


def _build_transforms(augmentations=()):
    """Build an albumentations pipeline for image/mask pairs.

    Steps: resize to IMAGE_HEIGHT x IMAGE_WIDTH, any given augmentations,
    Normalize (with mean 0 / std 1 / max_pixel_value 255 this just divides
    pixel values by 255), then conversion to a torch tensor.
    """
    steps = [A.Resize(height=IMAGE_HEIGHT, width=IMAGE_WIDTH)]
    steps.extend(augmentations)
    steps.append(
        A.Normalize(
            mean=[0.0, 0.0, 0.0],
            std=[1.0, 1.0, 1.0],
            max_pixel_value=255.0,
        )
    )
    steps.append(ToTensorV2())
    return A.Compose(steps)


train_transforms = _build_transforms(
    augmentations=[
        # A.Rotate(limit=35, p=1.0),
        # A.HorizontalFlip(p=0.5),
        # A.VerticalFlip(p=0.1),
    ]
)
val_transforms = _build_transforms()

In [5]:
# MODEL
# Tag forwarded to the training/evaluation helpers to select mask handling.
model_type = "masks"

# 3-channel input, 1-channel output: a single logit map for a binary mask.
model = UNET(in_channels=3, out_channels=1).to(DEVICE)

# BCEWithLogitsLoss applies the sigmoid itself, so it expects raw logits
# (presumably UNET has no final sigmoid — confirm in model.py).
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

In [6]:
# BATCHES
# Split the image/mask pairs into train and validation DataLoaders.
# test_size=0.2 holds out a fraction of the pairs for validation (the recorded
# output shows 16 train / 4 validation); seed=42 presumably fixes the split.
train_loader, val_loader = get_loaders_masks(
    FILE_PATHS.IMAGES, FILE_PATHS.MASKS,
    BATCH_SIZE,
    train_transforms, val_transforms,
    NUM_WORKERS,
    seed=42,
    test_size=0.2,
)

TRAIN PATHS LENGTHS: images, masks
16, 16
VALIDATION PATHS LENGTHS: images, masks
4, 4


In [7]:
def main(save_model: bool = False, output_dir: str = "../saved_images"):
    """Train the UNET for NUM_EPOCHS epochs, evaluating after each epoch.

    Relies on the notebook-level globals ``model``, ``optimizer``, ``loss_fn``,
    ``train_loader``, ``val_loader``, ``DEVICE``, ``model_type``,
    ``NUM_EPOCHS`` and ``LOAD_MODEL``.

    Args:
        save_model: when True, pass a checkpoint dict (model and optimizer
            state dicts) to ``save_checkpoint`` after every epoch. Defaults to
            False, matching the earlier behaviour where saving was disabled.
        output_dir: parent directory for per-epoch prediction folders; each
            epoch writes to ``f"{output_dir}/epoch_{epoch}"``.
    """
    if LOAD_MODEL:
        # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
        load_checkpoint(torch.load("my_checkpoint.pth.tar", map_location=DEVICE), model)

    # Loss scaling only applies on CUDA; disabling it explicitly on CPU avoids the
    # "CUDA is not available" warning GradScaler otherwise emits, with the same effect.
    scaler = torch.cuda.amp.GradScaler(enabled=(DEVICE == "cuda"))

    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler, DEVICE, model_type)

        if save_model:
            checkpoint = {
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            save_checkpoint(checkpoint)

        # NOTE(review): the recorded runs print the identical accuracy and Dice
        # score (9.41) every epoch while the training loss keeps falling —
        # verify check_accuracy and the mask preprocessing.
        check_accuracy(val_loader, model, device=DEVICE, model_type=model_type)

        path = f"{output_dir}/epoch_{epoch}"

        # Save predictions
        save_predictions_as_imgs(
            val_loader, model, folder=path, device=DEVICE, model_type=model_type
        )


if __name__ == "__main__":
    main()

100%|██████████| 1/1 [00:12<00:00, 12.14s/it, loss=0.856]


Got 4721/50176 with acc 9.41
Dice score: 9.41


100%|██████████| 1/1 [00:05<00:00,  5.99s/it, loss=0.824]


Got 4721/50176 with acc 9.41
Dice score: 9.41


100%|██████████| 1/1 [00:06<00:00,  6.19s/it, loss=0.801]


Got 4721/50176 with acc 9.41
Dice score: 9.41


100%|██████████| 1/1 [00:06<00:00,  6.40s/it, loss=0.782]


Got 4721/50176 with acc 9.41
Dice score: 9.41


100%|██████████| 1/1 [00:05<00:00,  5.46s/it, loss=0.762]


Got 4721/50176 with acc 9.41
Dice score: 9.41
