# Model/Dataset Train Test

In this notebook we deliberately overfit the model on a single sample to check that its architecture is capable of learning a specific output image from a randomized input.

In [None]:
# This is very important! Do not delete!
# The local modules imported further down live in the parent directory, so
# that directory has to be on Python's module search path.

import sys

# Guard the append so that re-running this cell does not keep adding
# duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import logging

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image

from networks.ReferenceCNN import ReferenceCNN as Model
from trainer import Trainer
from hyperparameters import Hyperparameters

In [None]:
# Set up root logging at INFO level for the rest of the notebook
logging.basicConfig(
    level=logging.INFO,
)

We first generate the dummy input and target data.

In [None]:
# The target (expected network output) is the logo image, converted to
# 8-bit greyscale ("L" mode) so that it has a single channel.
image_path = "../img/sgs_logo.webp"
# The context manager closes the underlying file once the pixel data has
# been copied into the array.
with Image.open(image_path) as target_image:
    out_img = np.array(target_image.convert("L"))

# The input is a dummy white-noise image of the same size as the target.
# A seeded generator keeps the noise identical across kernel restarts, and
# the exclusive upper bound of 256 covers the full uint8 range 0..255
# (an upper bound of 255 can never produce a pure-white pixel).
noise_rng = np.random.default_rng(seed=0)
in_img = noise_rng.integers(0, 256, size=out_img.shape, dtype=np.uint8)

In [None]:
# confirm in- and output images are the same size
assert out_img.shape == in_img.shape, (
    f"shape mismatch: input {in_img.shape} vs. target {out_img.shape}"
)

# visualize in- and output images side by side; both are single-channel
# 8-bit images, so draw them with a grey colormap over the fixed 0..255
# range instead of matplotlib's default false-colour map with autoscaling
fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(10, 4))
ax_in.imshow(in_img, cmap="gray", vmin=0, vmax=255)
ax_in.set_title("Input (white noise)")
ax_out.imshow(out_img, cmap="gray", vmin=0, vmax=255)
ax_out.set_title("Target")
plt.show()

We use the dummy data to build the dataloaders.

In [None]:
# Settings for the overfit run: 4000 epochs with a batch size of one
params = Hyperparameters(
    epochs=4000,
    batch_size=1,
)

In [None]:
n_in_channels = 1
n_out_channels = 1
# The dataset consists of exactly one (input, target) pair. This is the size
# of the leading sample dimension and is independent of the loader's batch
# size, so changing params.batch_size can no longer break the reshape.
n_samples = 1

# expected shape within the dataloader/train-loop: (batch_size, n_channels, height, width)
# dividing the uint8 pixel values by 255.0 converts them to floats in [0, 1]
x = torch.tensor(in_img).reshape(n_samples, n_in_channels, *in_img.shape) / 255.0
y = torch.tensor(out_img).reshape(n_samples, n_out_channels, *out_img.shape) / 255.0

# sanity-check the layout before handing the tensors to the loaders
assert x.shape == (n_samples, n_in_channels, *in_img.shape)
assert y.shape == (n_samples, n_out_channels, *out_img.shape)

# create tensor dataset and dataloaders; training and validation iterate over
# the same single-sample dataset, since this notebook only checks that the
# model is able to overfit
train_dataset = torch.utils.data.TensorDataset(x, y)
training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batch_size)
validation_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=params.batch_size)

We now train the network

In [None]:
# Build the model; "." is handed straight to its constructor
# (presumably the directory used for saving/loading the model -- TODO confirm)
network = Model(".")

# Mean-squared-error loss, minimized with Adam at the configured learning rate
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(params=network.parameters(), lr=params.learning_rate)

# Train for the configured number of epochs and keep whatever train() returns
trainer = Trainer(optimizer, criterion, training_dataloader, validation_dataloader)
losses = trainer.train(network, params.epochs)

In [None]:
# NOTE(review): presumably restores the model state saved during training -- confirm against the model class
network.load_model()


We now evaluate the network

In [None]:
# Switch the network to evaluation mode before running inference
# (assumes Model is a torch.nn.Module -- TODO confirm)
network.eval()

In [None]:
# Run the forward pass on the training input. Gradients are not needed for
# evaluation, so autograd tracking is disabled to avoid building the graph.
with torch.no_grad():
    y_predict = network(x)
y_predict.shape

In [None]:
# represent y predict as image and show it
# squeeze() drops every size-1 dimension: the batch dimension and, for a
# single-channel output, the channel dimension as well
y_predict_img = y_predict.squeeze().detach().numpy()

# With a single output channel the squeezed prediction is already a 2-D
# (height, width) image; indexing it with [1] would select one pixel row,
# which imshow cannot draw. Only index into the array if a channel
# dimension is still present (same channel index as before).
if y_predict_img.ndim == 3:
    y_predict_img = y_predict_img[1]

# grey colormap on the normalized [0, 1] range so the plot is directly
# comparable to the normalized target
plt.imshow(y_predict_img, cmap="gray", vmin=0, vmax=1)
plt.title("Prediction")
plt.show()