In [1]:
# Based on proposed model from https://arxiv.org/pdf/1505.04597.pdf

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from unet import UNet
# from data_util import *
from torch.utils.tensorboard.writer import SummaryWriter


# Device configuration: prefer the first CUDA GPU when one is visible,
# otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [2]:
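# Data preprocessing — TODO: not implemented yet.
# Goal: load batches of N three-channel H x W images together with their masks and
# expose them as `train_loader`, which the training cell below iterates over.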



In [3]:
# Training hyperparameters
num_epochs = 100
batch_size = 4
learning_rate = 0.001
depth = 3
initial_channel_size = 64
channel_multiplier = 2
dropout = 0.5
pool_kernel_size = 2
pool_stride = 2
double_conv_kernel_size = 3
up_conv_kernel_size = 2
up_conv_stride = 2

# Reproducibility: seed torch's RNGs so the random smoke-test input and the
# model's initial weights are identical on every Restart & Run All.
seed = 42
torch.manual_seed(seed)

# TensorBoard logging: scalars are written under ./loss_plots; `step` is the
# global_step counter passed to writer.add_scalar in the training cell.
writer = SummaryWriter("loss_plots")
step = 0

# Checkpointing: whether to save / resume, and how often (in epochs) to save.
save_model = True
load_model = False
save_epoch = 10

In [6]:
# Smoke test: build the UNet and push one random 572 x 572 image through it to
# check tensor shapes (the per-level shape printouts presumably come from inside
# UNet's forward, since this cell only prints the input shape).
# NOTE(review): the recorded output shows four down-sampling levels although
# depth = 3, and this cell ran as In [6] after In [3] — the output may be stale,
# or UNet may count depth differently; confirm.
# NOTE(review): the input here has 1 channel while the data-preprocessing note
# mentions 3-channel images; UNet's expected input channel count is not visible
# here — confirm before wiring up real data.
image_size = 572
image = torch.rand((1, 1, image_size, image_size)).to(device)
model = UNet(
    depth=depth,
    initial_channel_size=initial_channel_size,
    channel_multiplier=channel_multiplier,
    dropout=dropout,
    pool_kernel=pool_kernel_size,
    pool_stride=pool_stride,
    double_conv_kernel=double_conv_kernel_size,
    up_conv_kernel=up_conv_kernel_size,
    up_conv_stride=up_conv_stride
).to(device)
print('------------------ Original Image ------------------')
print(image.shape)
# Shape check only: no_grad stops autograd from recording this forward pass, so
# the module-level `mask` does not keep the activation graph alive in kernel memory.
with torch.no_grad():
    mask = model(image)

------------------ Original Image ------------------
torch.Size([1, 1, 572, 572])
----------------- Downsampling -----------------
torch.Size([1, 64, 284, 284])
torch.Size([1, 128, 140, 140])
torch.Size([1, 256, 68, 68])
torch.Size([1, 512, 32, 32])
----------------- Bottom of U-Net -----------------
torch.Size([1, 1024, 28, 28])
----------------- Upsampling -----------------
torch.Size([1, 1024, 28, 28])
torch.Size([1, 512, 52, 52])
torch.Size([1, 256, 100, 100])
torch.Size([1, 128, 196, 196])
torch.Size([1, 64, 388, 388])
----------------- Final Mask -----------------
torch.Size([1, 2, 388, 388])


In [None]:
# Training Model
# Loss and optimizer
# CrossEntropyLoss over the model's per-class output channels (2 in the smoke-test
# output); presumably targets are class-index masks — TODO confirm with the data pipeline.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Optionally resume training. `load_checkpoint` is not defined in this notebook
# (its data_util import is commented out), so restore the state directly from a
# dict with "state_dict" and "optimizer" keys, the layout the training loop saves.
# map_location lets a checkpoint written on GPU be loaded on a CPU-only machine.
if load_model:
    print("=> Loading checkpoint...")
    checkpoint = torch.load("my_checkpoint.pth.tar", map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

# Train the model
# NOTE(review): `train_loader` is not defined anywhere in this notebook (the
# data-preprocessing cell is empty); it must yield (input, target) batches.
for epoch in range(num_epochs):
    print(f"[Epoch {epoch} / {num_epochs}]")

    model.train()
    for batch_idx, (batch_src, batch_trg) in enumerate(train_loader):
        # The model was moved to `device`, so each batch must be moved there too.
        batch_src = batch_src.to(device)
        batch_trg = batch_trg.to(device)

        optimizer.zero_grad()
        output = model(batch_src)
        # NOTE(review): the smoke test maps a 572 x 572 input to a 388 x 388 output,
        # so targets presumably must be cropped to the output size — confirm.
        loss = criterion(output, batch_trg)
        loss.backward()

        # Clip the total gradient norm to 1 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

        optimizer.step()

        # One TensorBoard point per batch. .item() logs a plain float rather than a
        # graph-attached tensor, and advancing `step` here (not once per epoch)
        # gives every batch its own global_step.
        writer.add_scalar("Training loss", loss.item(), global_step=step)
        step += 1

    # Checkpoint AFTER this epoch's updates, every `save_epoch` epochs and after the
    # final epoch, so the saved weights include the training just performed.
    if save_model and ((epoch + 1) % save_epoch == 0 or epoch + 1 == num_epochs):
        print("=> Saving checkpoint...")
        checkpoint = {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        # `save_checkpoint` is not defined in this notebook; write directly to the
        # path the resume branch loads from.
        torch.save(checkpoint, "my_checkpoint.pth.tar")

    model.eval()

    # Choose one image with its mask at random in each epoch to visualize the training process