In [None]:
"""
In this fourth example, we will take a closer look at the model optimization.

At first the optimization might seem a bit complicated, but the core of the optimization is actually rather simple.
What makes it look complicated is all the reporting and tracking of the loss and accuracy.
Try to step through the optimization code and see whether you can understand the various parts of the code.

Here we will briefly talk about some of the design choices that are behind this pipeline.

1) Typically you want to train a model many times, easily compare the current run to previous runs, and often run these experiments on a remote server.
    Our current implementation does not really offer this.
    To get something like this, we suggest using a reporting framework such as mlflow, TensorBoard, or similar tools that log metrics and hyperparameters and make it easy to visualize and compare runs through an interface you can access remotely.
    This also allows your entire team to run various models on various computers and compare them all in one central place.

2) Currently, we do not use the validation dataset for anything, but typically you would use it to find the best model, for example by keeping the weights from the epoch with the lowest validation loss.

3) In a more realistic pipeline we also need to be able to save and load the model.
    Saving and loading models is something that needs to fit into the overall framework you run your machine learning models in.
    For information on how saving and loading models might be done see:
    https://pytorch.org/tutorials/beginner/saving_loading_models.html
"""

In [None]:
from config.unet import Configuration

# load configuration information
configuration = Configuration()


In [None]:
import torch

from src.dataloader import load_data_wrapper
from src.losses import select_loss_fnc
from src.model import model_loader_wrapper
from src.viz import eval_unlabelled_images, plot_loss_and_accuracy


In [None]:
dataloaders = load_data_wrapper(**configuration)
model = model_loader_wrapper(**configuration)


In [None]:
import matplotlib.pyplot as plt

# DON'T FORGET
# - you can use device="cpu" to run everything on the CPU for easier debugging; remember to move both the model and the data to that device
# - you can wrap the whole for-loop in a `with torch.autograd.detect_anomaly():` context; autograd will then raise an error pointing at the operation that produced NaN gradients
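Both debugging tips can be tried on a toy example. This sketch assumes an `nn.Linear` layer in place of the UNet; it puts the model and the data on the same device, then deliberately creates a NaN gradient (the derivative of `sqrt(x)` at a negative `x`) so that `torch.autograd.detect_anomaly()` raises a `RuntimeError` naming the backward function responsible.

```python
import warnings

import torch
from torch import nn

# Tip 1: pick one device and put BOTH the model and the data on it
device = torch.device("cpu")
model = nn.Linear(3, 1).to(device)  # toy stand-in for the UNet
data = torch.randn(8, 3).to(device)
out = model(data)  # works because model and data share a device

# Tip 2: anomaly detection reports which backward operation produced NaNs.
# sqrt(-1) is NaN, so the gradient 0.5 / sqrt(x) is NaN as well.
x = torch.tensor([-1.0], requires_grad=True, device=device)
anomaly_caught = False
with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # silence the "anomaly detection is slow" warning
    with torch.autograd.detect_anomaly():
        y = torch.sqrt(x).sum()
        try:
            y.backward()
        except RuntimeError as err:
            anomaly_caught = True
            print(err)  # the message names the backward function that returned NaN
```

Anomaly detection slows training down noticeably, so it is best switched on only while hunting for a bug.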

In [None]:
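# we will start by setting up the optimization parameters
loss_fnc = select_loss_fnc(**configuration)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 1e-3 is Adam's default; lr=100 would make the weights diverge

# grab a single batch and move it to the same device as the model
data, target = next(iter(dataloaders['train']))
data = data.to(configuration.device)
target = target.to(configuration.device)

losses = []  # keep track of our losses

model.train()  # make sure layers such as dropout or batch norm are in training mode

# train repeatedly on the same batch: if everything is wired up correctly, the loss should go down
for i in range(100):
    # print(i)  # sometimes it can be good to keep track of how far along we are

    # zero the gradients from the previous iteration
    optimizer.zero_grad()

    # calculate the model's output and apply a sigmoid to map it to the range 0-1
    output = model(data)
    output = output.sigmoid()

    # calculate the loss by comparing the prediction with the target
    loss = loss_fnc(output, target)

    # track the loss
    losses.append(loss.item())

    # calculate gradients and update weights
    loss.backward()
    optimizer.step()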

plt.plot(losses)
plt.show()
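The corrected loop can also be tried in isolation on synthetic data. This sketch assumes a tiny two-layer convolutional network in place of the UNet, `nn.BCELoss` in place of whatever `select_loss_fnc` returns, and a random batch whose mask is simply `data > 0`. Repeatedly training on one batch like this is a common sanity check: if the loss does not go down, something in the model, loss, or optimizer setup is likely broken.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic batch (assumption): 4 single-channel 16x16 images and binary masks
data = torch.randn(4, 1, 16, 16)
target = (data > 0).float()

# Tiny fully convolutional stand-in for the UNet
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)
loss_fnc = nn.BCELoss()  # expects probabilities, hence the sigmoid below
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

losses = []
for i in range(100):
    optimizer.zero_grad()                # clear old gradients
    output = model(data).sigmoid()       # logits -> probabilities in 0-1
    loss = loss_fnc(output, target)      # compare prediction with target
    losses.append(loss.item())
    loss.backward()                      # compute gradients
    optimizer.step()                     # update the weights

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```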

# Questions/exercises:

1. As mentioned, the functions in optimizer.py contain a lot of reporting code that is not essential for training a model.
    Try to make a copy of the two functions in optimizer.py, remove all the reporting code, and make the code as simple as possible.
    Can you train your model with these new functions?

2. At each epoch of the optimization, the model goes through the training dataset followed by the validation dataset.
    As previously stated, the purpose of the validation dataset is to help find the best model. How should the validation dataset help with this, and how would you incorporate that into your optimization function?
3. How does the validation dataset differ from the test dataset?
4. How come the validation loss is lower than the training loss?