# Training on online dataset

## Module imports

In [6]:
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.nn import MSELoss
from tqdm import tqdm # progress bar
from torch.optim import Adam
import torch
import pickle
import numpy as np

# custom modules
from unet import (StentOnlineDataset, # custom on-the-go augmentation dataset
                   UNet) # our PyTorch U-Net model

## Model and cached weight imports

In [7]:
# --- checkpoint configuration ---
# Set to True to resume from the model cached at a checkpoint instead of
# starting from a freshly initialised network.
use_cached_weights = False
# Location of the pickled model used when resuming.
path_weights = "./weights/unet-model_7000.pkl"
# Number of iterations the model was trained for before launching this
# notebook; it is added to n_images to name the files saved at the end.
previous_iter_count = 7000

if not use_cached_weights:
    # start from scratch: single-channel input, single-channel output
    model = UNet(in_channels=1, out_channels=1)
    # switch the freshly created model to double precision
    model.double()
else:
    # resume: the checkpoint is a whole pickled model object
    # NOTE: unpickling runs arbitrary code, so only load checkpoints that
    # were produced by this project.
    with open(path_weights, "rb") as checkpoint_file:
        model = pickle.load(checkpoint_file)

## Training Parameters

In [8]:
# --- training hyper-parameters ---
n_images = 2000  # number of images per epoch

# Learning-rate history across the training stages (only the last is active):
#   iter    1 to   500 : lr = 0.005
#   iter  501 to  3000 : lr = 0.001
#   iter 3001 to  7000 : lr = 0.00005
#   iter 7001 to 10000 : lr = 0.00001, excluding base image 7, 8, 9 for
#                        better capturing low contrast stents
lr = 1e-5

# custom image dataset class with on-the-go augmentation
dataset = StentOnlineDataset(
    base_image_path="data/dataset/base_png",
    n_images=n_images,
)
# PyTorch loader: one image per batch, no shuffling
data_loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

In [9]:
# Loss: mean squared error, chosen to keep the model from over-whitening the
# output image when the stent sizes are small.
criterion = MSELoss()
# Optimiser: Adam over all parameters of the model, using the learning rate lr.
optimizer = Adam(params=model.parameters(), lr=lr)

### Prediction with current weights of the model

In [10]:
# prediction with current weights
# print("Input image:")
# plt.figure()
# test_input = dataset[0][0]
# plt.imshow(test_input[96:-96, 96:-96], cmap="gray")
# plt.show()
#
# print("Output image:")
# prediction = model(test_input.reshape(1, 512, 512))
# plt.figure()
# plt.imshow(prediction.detach().numpy()[0], cmap="gray")
# plt.show()

In [11]:
# Loss history for later analysis: the training loop appends one float per
# iteration, and the cells after it plot this list and pickle it to disk.
# Re-running this cell resets the history.
losses = []

In [None]:
# training: one optimisation step per batch yielded by data_loader
n_steps = len(data_loader)  # constant for the whole run, so compute it once
for i, (inputs, targets) in enumerate(tqdm(data_loader)):
    # zero the parameter gradients left over from the previous step
    optimizer.zero_grad()
    # forward + backward + optimize
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

    # record the scalar loss once and reuse it for logging
    loss_value = loss.item()
    losses.append(loss_value)

    # Report progress every 10th step.
    # The image preview used to be guarded by
    #     i < 200 and i % 20 == 0 or i % 10 == 0
    # which Python parses as `(i < 200 and i % 20 == 0) or (i % 10 == 0)`.
    # Every multiple of 20 is also a multiple of 10, so that expression is
    # exactly `i % 10 == 0`; the dead clause is removed and both reports
    # share one guard. Behaviour is unchanged.
    # NOTE(review): if a sparser preview schedule after step 200 was intended,
    # adjust the condition here.
    if i % 10 == 0:
        print(f"Step [{i + 1}/{n_steps}], Loss: {loss_value:.10f}")

        # First element of each batch. The input is cropped by 94 px on every
        # side (presumably to line up with a smaller network output; TODO
        # confirm against the U-Net definition).
        previews = (
            ("Original image:", inputs.detach().numpy()[0][94:-94, 94:-94]),
            ("Predicted image:", outputs.detach().numpy()[0]),
            ("Ground truth:", targets.detach().numpy()[0]),
        )
        for title, image in previews:
            print(title)
            plt.figure()
            plt.imshow(image, cmap="gray")
            plt.show()

In [None]:
# Plot the natural log of the per-iteration loss, which makes late-stage
# improvements visible once the raw values become very small.
plt.plot(np.log(losses))
plt.xlabel("Iteration")
# The criterion is MSELoss, so the label is "MSE" (it previously read "MES").
plt.ylabel("Log MSE Loss")
plt.show()

# Plot the raw per-iteration loss on a linear scale.
plt.plot(losses)
plt.xlabel("Iteration")
plt.ylabel("MSE Loss")
plt.show()

In [None]:
import os

# Make sure the output directory exists before writing, so a finished
# training run is not lost to a FileNotFoundError on a fresh checkout.
os.makedirs("./weights", exist_ok=True)

# NOTE(review): the file names always use previous_iter_count + n_images,
# even when use_cached_weights is False and the model was trained from
# scratch. Set previous_iter_count to 0 in that case if the suffix should
# reflect the true iteration count.

# save the model weights only (state dict)
torch.save(model.state_dict(), f"./weights/unet-weights_{previous_iter_count + n_images}.pt")

# also pickle the whole model object; this is the format the notebook reads
# back when use_cached_weights is True
with open(f"./weights/unet-model_{previous_iter_count + n_images}.pkl", "wb") as f:
    pickle.dump(model, f)

In [None]:
# Persist the loss history of this run next to the model checkpoints; the
# file name carries the same iteration count suffix as the saved model.
losses_path = f"./weights/unet-losses-{previous_iter_count + n_images}.pkl"
with open(losses_path, "wb") as losses_file:
    pickle.dump(losses, losses_file)