# Model/Dataset Learn Check

In this notebook we deliberately overfit the model to a small slice of the test data. This checks that our model's architecture is capable of learning a specific output image as expected, and confirms that the test data contains enough information for the network to differentiate between the points.

In [None]:
# This is very important! Do not delete!
# This adds the parent directory to python's path so it can correctly find
# our submodules

import sys

# '..' is resolved relative to the kernel's working directory, so the notebook
# must be started from its own folder.
# The membership guard keeps the cell idempotent: re-running it does not pile
# up duplicate '..' entries on sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import torch
from PIL import Image
import numpy as np
import logging

from networks.ReferenceCNN import ReferenceCNN as Model
from trainer import Trainer
from hyperparameters import Hyperparameters

Load dataset

In [None]:
# Read the serialized test slice from disk: network inputs first, then the
# matching target outputs.
inputs, outputs = map(torch.load, ("../data/inputs.pt", "../data/outputs.pt"))

# Show both shapes as the cell output.
inputs.shape, outputs.shape

Inspect one input/output sample

In [None]:
# Project-local plotting helper (presumably resolved through the '..' entry
# added to sys.path in the first cell — confirm the module location).
from plotting import plot_data_samples

# Visual sanity check of the loaded tensors before any training happens.
# NOTE(review): no sample index is passed here (the last cell passes one as a
# third argument), so which sample is drawn depends on the helper's default.
plot_data_samples(inputs, outputs)


We take 10 random data points (the exact count is set by `params.n_samples`) and create data loaders with them to train and test the model.

In [None]:
params = Hyperparameters(epochs=4000)
n_in_channels = 1
n_out_channels = 2
# NOTE(review): the two channel counts above are not referenced by any other
# cell visible in this notebook — confirm whether Model() should receive them.

# Fix the RNG seed so that "Restart & Run All" selects the same subset (and,
# when cells run in order, the same weight initialisation) every time.
SEED = 0
torch.manual_seed(SEED)

# expected shape within the dataloader/train-loop: (batch_size, n_in_channels, height, width)
# randperm yields distinct indices, so no data point is picked twice
# (torch.randint draws with replacement and could repeat a sample).
random_idxs = torch.randperm(inputs.shape[0])[:params.n_samples]
x = inputs[random_idxs, :, :, :]
y = outputs[random_idxs, :, :, :]

# create tensor dataset and dataloader
# Both loaders wrap the same subset: this notebook is an overfit check, so the
# "validation" loss only measures how well the training points are memorised.
train_dataset = torch.utils.data.TensorDataset(x, y)
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batch_size)
validation_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batch_size)

# Show the chosen indices as the cell output.
random_idxs

In [None]:
# Display the shapes of the sampled inputs and targets as the cell output.
x.shape, y.shape

Instantiate and train the model with the data

In [None]:
# Loss function and network for the overfit run; the model is built with its
# default constructor arguments.
criterion = torch.nn.MSELoss()
network = Model()

# Adam optimises every parameter of the network with the configured learning rate.
optimizer = torch.optim.Adam(network.parameters(), lr=params.learning_rate)

# The trainer receives the optimizer, the loss and both data loaders.
trainer = Trainer(
    optimizer,
    criterion,
    training_dataloader,
    validation_dataloader,
)

# Train for the configured number of epochs and keep what the trainer returns.
losses = trainer.train(network, epochs=params.epochs)

In [None]:
# Switch the trained network to evaluation mode before running inference
# (assumes Model is a torch.nn.Module — confirm). The call is the cell's last
# expression, so its return value is displayed as the cell output.
network.eval()

In [None]:
# Call the module itself instead of `.forward(...)` so that any registered
# hooks also run, and disable gradient tracking since this is pure inference.
# `.detach()` is kept as a cheap safeguard before converting to NumPy.
with torch.no_grad():
    y_predict = network(x.float()).detach().numpy()

# Show the prediction's shape as the cell output.
y_predict.shape

In [None]:
# `plot_data_samples` is already imported in the data-inspection cell earlier
# in this notebook, so it is not imported again here.

# Position (within the sampled subset) of the example to compare.
idx = 7

# Fail early with a readable message if the subset holds fewer samples than
# the hard-coded index needs (assumes idx selects along the first dimension).
assert 0 <= idx < x.shape[0], f"idx={idx} is out of range for {x.shape[0]} samples"

# Same input twice: first with the ground-truth target, then with the
# network's prediction, for a side-by-side visual comparison.
plot_data_samples(x, y, idx)
plot_data_samples(x, y_predict, idx)