In [1]:
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms

In [2]:
from model import Unet
from hyperparameters import *
from utils import save_checkpoint, get_loaders, check_accuracy, save_predictions_as_imgs

In [3]:
def train_fn(loader, model, optimizer, loss_fn, scaler):
    loop = tqdm(loader)

    for batch_idx, (data, targets) in enumerate(loop):
        data = data.float().to(device = DEVICE)
        targets = targets.float().to(device = DEVICE)

        # Forward
        with torch.cuda.amp.autocast():
            predictions = model(data)
            loss = loss_fn(predictions, targets)
        
        # Backward
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update Loop
        loop.set_postfix(loss = loss.item())

In [4]:
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    #transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_transform_gt = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    #transforms.Normalize(mean=[0.0], std=[1.0])
    transforms.Normalize(mean=[0.5], std=[0.5])
])

test_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    #transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

test_transform_gt = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    #transforms.Normalize(mean=[0.0], std=[1.0])
    transforms.Normalize(mean=[0.5], std=[0.5])
])

In [5]:
model = Unet(in_channels = IN_CHANNELS, out_channels = OUT_CHANNELS).to(device = DEVICE)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr = LEARNING_RATE)

train_loader, test_loader = get_loaders(train_transform, test_transform, train_transform_gt, test_transform_gt)

In [6]:
def main():
    scaler = torch.cuda.amp.GradScaler()
    for epoch in range(NUM_EPOCHS):
        train_fn(train_loader, model, optimizer, loss_fn, scaler)

    checkpoint = {
        "state_dict" : model.state_dict(),
        "optimizer" : optimizer.state_dict()
    }
    
    save_checkpoint(checkpoint)

    check_accuracy(test_loader, model, device = DEVICE)

    save_predictions_as_imgs(test_loader, model, folder = "saved_images/", device = DEVICE)

In [7]:
if __name__ == "__main__":
    main()

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
100%|██████████| 10/10 [00:58<00:00,  5.87s/it, loss=-550]  
100%|██████████| 10/10 [00:59<00:00,  5.97s/it, loss=-4.16e+3]
100%|██████████| 10/10 [00:59<00:00,  5.97s/it, loss=-6.57e+3]
100%|██████████| 10/10 [01:00<00:00,  6.06s/it, loss=-9.05e+3]
100%|██████████| 10/10 [01:04<00:00,  6.43s/it, loss=-6.21e+3]
100%|██████████| 10/10 [01:01<00:00,  6.16s/it, loss=-8.45e+3]
100%|██████████| 10/10 [01:01<00:00,  6.13s/it, loss=-7.86e+3]
100%|██████████| 10/10 [00:59<00:00,  5.99s/it, loss=-6.66e+3]
100%|██████████| 10/10 [01:00<00:00,  6.03s/it, loss=-7.99e+3]
100%|██████████| 10/10 [00:59<00:00,  5.98s/it, loss=-9.01e+3]
100%|██████████| 10/10 [00:59<00:00,  5.98s/it, loss=-9.35e+3]
100%|██████████| 10/10 [01:00<00:00,  6.00s/it, loss=-8.69e+3]
100%|██████████| 10/10 [00:59<00:00,  5.97s/it, loss=-1.19e+4]
100%|██████████| 10/10 [00:59<00:00,  5.99s/it, loss=-1.01e+4]
100%|██████████| 10/10 [01:00<00:00,  6.03s/it,

=> Saving checkpoint
Accuracy: 0.60%, Dice Score: 1.9994
