# Training

In [None]:
import torch

import train as tr
from unet import UNet

In [None]:
# Upsampling strategy handed to the network:
#   True  -> Bilinear
#   False -> UpConvolution
up_bilinear = True

# Loss selected by name: "iou" or "bce".
loss = "iou"

# Number of epochs to train for.
n_epoch = 20

In [None]:
# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Three-channel input network with the upsampling mode chosen above,
# moved onto the selected device.
net = UNet(n_channels=3, bilinear=up_bilinear)
net.to(device=device)
print(device)

In [None]:
# Launch training with the settings chosen above. Checkpointing is switched
# on (save_cp=True) -- presumably this produces the checkpoints/CP_epoch<N>.pth
# files read back in the Testing section (confirm in train.py).
tr.train_net(
    net,
    device,
    epochs=n_epoch,
    batch_size=1,
    lr=0.001,
    save_cp=True,
    img_scale=1,
    loss=loss,
)

# Testing

In [None]:
import torch
from torch.utils.data import DataLoader

import eval
from unet import UNet
import utils.dataset as ds

In [None]:
# Reuse the training settings for testing: same upsampling mode, and the
# checkpoint file numbered with the training epoch count.
up_bilinear_test = up_bilinear

n_epoch_test = n_epoch
weights_path = "checkpoints/CP_epoch{}.pth".format(n_epoch_test)

In [None]:
# Rebuild the network for testing with the upsampling mode configured in
# up_bilinear_test (which mirrors the training choice) instead of a
# hard-coded True. Presumably the two modes yield different layers, in which
# case a mismatched architecture would not accept the checkpoint loaded into
# this network afterwards -- confirm in unet.py.
net = UNet(n_channels=3, bilinear=up_bilinear_test)

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net.to(device=device)
print(device)

In [None]:
# Read the checkpoint at weights_path, remapping its tensors onto `device`,
# and copy the stored parameters into the network.
net.load_state_dict(
    torch.load(weights_path, map_location=device)
)

In [None]:
# Test split of the custom dataset, at full image scale and without
# data augmentation.
test_ds = ds.BasicDataset(
    "dataset_custom/",
    img_scale=1,
    my_set="test",
    data_aug=False,
)
print(len(test_ds))

# One sample per batch, in dataset order (no shuffling).
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

In [None]:
# Score the loaded network on the test loader and report the result.
test_score = eval.eval_net(net, test_loader, device)
print(test_score)

# Prediction

In [None]:
import predict as pred
import numpy as np
from PIL import Image
from torchvision import transforms

In [None]:
# Index of the test sample to visualise; the modulo wraps it into the valid
# range when the dataset holds fewer samples than the requested index.
nbr_sample = 40 % len(test_ds)

# Each dataset item unpacks into the input tensor and its ground-truth mask.
img, grnd = test_ds[nbr_sample]

print("Sample {}/{}".format(nbr_sample, len(test_ds) - 1))

In [None]:
# Run the network on the selected sample. out_threshold=0.5 is forwarded to
# predict_img (presumably the cut-off used to binarise the predicted mask --
# confirm in predict.py).
mask = pred.predict_img(net, img, device, out_threshold=0.5)

In [None]:
# Split the input tensor into its channels. Per the captions below:
# channel 0 = moving grayscale, channel 1 = moving mask, channel 2 = fixed
# grayscale. The ground truth is the mask annotation of the fixed image.
gray1 = transforms.ToPILImage()(img[0])
gray2 = transforms.ToPILImage()(img[2])
mask1 = transforms.ToPILImage()(img[1])
mask2 = transforms.ToPILImage()(grnd)

# Scale the predicted mask to 0/255 so it renders as an 8-bit image.
mask2_pred = Image.fromarray(np.uint8(mask.astype(int) * 255))

# Show every image under its caption, inputs first, then ground truth and
# prediction.
for caption, picture in (
    ("Grayscale window from moving image (input 1)", gray1),
    ("Mask annotation from moving image (input 2)", mask1),
    ("Grayscale window from fixed image (input 3)", gray2),
    ("Mask annotation from fixed image (ground truth)", mask2),
    ("Mask annotation predicted from fixed image (output)", mask2_pred),
):
    print(caption)
    display(picture)