In [None]:
"""
This part is designed to teach you what to do when training seems to fail.
    In general, it is good practice to apply some of these tricks before training even begins, but sometimes things go wrong anyway.

1) Normally you train a model for many epochs over the full dataset.
    During debugging, we try to isolate the problem. Here, that means checking whether the model can overfit a single minibatch: if it cannot drive the loss close to zero on one fixed batch, the problem most likely lies in the model, the loss, or the training loop itself rather than in the data.

2) We reuse the structures and code from the previous steps so that the setup stays the same, but replace the training loop with a simpler loop that overfits one minibatch.
"""

In [None]:
from config.unet import Configuration

# load configuration information
configuration = Configuration()


In [None]:
import torch

from src.dataloader import load_data_wrapper
from src.losses import select_loss_fnc
from src.model import model_loader_wrapper
from src.viz import eval_unlabelled_images, plot_loss_and_accuracy


In [None]:
dataloaders = load_data_wrapper(**configuration)
model = model_loader_wrapper(**configuration)


In [None]:
import matplotlib.pyplot as plt

# DON'T FORGET
# - you can set device="cpu" to run everything on the CPU for easier debugging; remember to move both the model and the data to that device
# - you can wrap the whole for-loop in a `with torch.autograd.detect_anomaly():` context to get a traceback pointing at the operation that produced NaN gradients
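The two tips above can be tried out in a small, self-contained sketch that uses a toy computation rather than the notebook's model. It runs on the CPU and uses `torch.autograd.detect_anomaly()` to turn a silent NaN gradient into an error; here the NaN comes from taking `torch.sqrt` of a negative number. With the real model you would instead move both `model` and the minibatch to `"cpu"` and wrap the training loop in the same context manager. Anomaly detection slows execution down, so it is best kept for debugging runs.

```python
import torch

# Run on the CPU: errors surface immediately and tracebacks are easier to read
device = torch.device("cpu")

x = torch.tensor([-1.0, 4.0], device=device, requires_grad=True)

caught = None
try:
    with torch.autograd.detect_anomaly():
        y = torch.sqrt(x)    # sqrt(-1.0) is NaN in the forward pass
        y.sum().backward()   # the backward of sqrt then produces a NaN gradient, which raises
except RuntimeError as err:
    caught = err
    print("anomaly detected:", err)
```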

In [None]:
# we will start by setting up the optimization parameters
loss_fnc = select_loss_fnc(**configuration)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # 1e-3 is Adam's default learning rate

# grab a single minibatch and reuse it at every iteration
data, target = next(iter(dataloaders['train']))
data = data.to(configuration.device)
target = target.to(configuration.device)

losses = []  # keep track of our losses

for i in range(100):
    # print(i)  # sometimes it can be good to keep track of how far along we are

    # zero gradients
    optimizer.zero_grad()

    # compute the model's output and apply a sigmoid to map it to the range 0-1
    output = model(data)
    output = output.sigmoid()

    # calculate the loss between the prediction and the target
    loss = loss_fnc(output, target)

    # track the loss
    losses.append(loss.item())

    # calculate gradients and update weights
    loss.backward()
    optimizer.step()

plt.plot(losses)
plt.show()
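A training loop only learns if the gradients are computed and the optimizer step is actually applied; if either is silently skipped, the loss curve stays flat. A quick sanity check is to snapshot the parameters, take one step, and verify that every trainable parameter tensor changed. The sketch below uses a hypothetical helper `params_changed_after_step` and a toy model with random data rather than the notebook's U-Net.

```python
import torch
import torch.nn as nn

def params_changed_after_step(model, loss_fnc, optimizer, data, target):
    """Hypothetical helper: run one optimization step and report whether every
    trainable parameter tensor was updated."""
    before = [p.detach().clone() for p in model.parameters() if p.requires_grad]
    optimizer.zero_grad()
    loss = loss_fnc(model(data), target)
    loss.backward()
    optimizer.step()
    after = [p.detach() for p in model.parameters() if p.requires_grad]
    # torch.equal is True only if the tensors are identical, i.e. the step did nothing
    return all(not torch.equal(b, a) for b, a in zip(before, after))

torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
toy_data, toy_target = torch.randn(16, 4), torch.randn(16, 1)
print(params_changed_after_step(toy_model, nn.MSELoss(), toy_optimizer, toy_data, toy_target))
```

Passing an optimizer built with `lr=0.0` makes the helper return `False`, which mimics a step that never updates the weights.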

# Questions/exercises:

1. As mentioned, the functions in optimizer.py contain a lot of reporting code that is not essential for training a model.
    Make a copy of the two functions in optimizer.py, remove all the reporting code, and make the code as simple as possible.
    Can you train your model with these new functions?

2. At each epoch of the optimization, the model goes through the training dataset followed by the validation dataset.
    As previously stated, the purpose of the validation dataset is to help find the best model. How should the validation dataset help with this, and how would you incorporate it into your optimization function?

3. How does the validation dataset differ from the test dataset?

4. Why is the validation loss lower than the training loss?