## Train a convolutional neural network

In [None]:
%load_ext autoreload
%autoreload 2

# External
import torch
from torchvision import transforms
from torchvision import datasets
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
from PIL import Image, ImageOps

# Internal
from model import ConvNet
from dataset import HeightReconstructionDataset

In [None]:
# Transform applied to every sample: converts the loaded image to a PyTorch tensor.
tensor_transform = transforms.ToTensor()

# Build the train / dev / test splits. Only the CSV index differs between
# them; the two directory arguments are shared.
# NOTE(review): presumably the first directory holds the network inputs and
# the second the target height maps -- confirm against HeightReconstructionDataset.
train_dataset, dev_dataset, test_dataset = [
    HeightReconstructionDataset(
        csv_path,
        './grayscale_tensors',
        './quadratic_±100um',
        transform=tensor_transform,
    )
    for csv_path in (
        './dataset_csv/161_train_dataset.csv',
        './dataset_csv/161_dev_dataset.csv',
        './dataset_csv/161_test_dataset.csv',
    )
]

# Bare expression on the last line so the notebook displays the test split's
# `img_labels` attribute as the cell output.
test_dataset.img_labels

In [None]:
# Wrap each split in a DataLoader. Training is served in shuffled mini-batches
# of 8; the dev and test splits are served one sample per batch.
# NOTE(review): the dev and test loaders are shuffled as well, so the order of
# evaluated samples differs from run to run -- confirm this is intended.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
)
dev_loader = torch.utils.data.DataLoader(
    dev_dataset,
    batch_size=1,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=True,
)

In [None]:
# Pick the compute device: use the GPU when CUDA is available, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Instantiate the network and move its parameters to the chosen device.
model = ConvNet()
model = model.to(device)

# Mean-squared-error loss, used for both the training and validation passes.
loss_function = torch.nn.MSELoss()

# Adam optimizer with learning rate 1e-2 and a very small weight decay (1e-8).
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-2,
    weight_decay=1e-8,
)

In [None]:
# Print the model's repr to inspect the network that was just built
# (for a torch.nn.Module this lists its submodules).
print(model)

In [None]:
# Dump the model's state_dict to inspect the freshly initialised tensors
# before any training step has run.
print(model.state_dict())

In [None]:
# Train for a fixed number of epochs, running a validation pass after each one.
epochs = 20
outputs = []       # (epoch, target, prediction) of the last training batch per epoch
train_losses = []  # one loss value per training batch
dev_losses = []    # one loss value per dev batch

for epoch in range(epochs):
    # ---- training pass ----
    # Make sure train-time behaviour (dropout, batch-norm updates, ...) is
    # active again after the previous epoch's evaluation pass.
    model.train()

    # Reset so an empty loader cannot leak stale tensors from an earlier
    # epoch (or an earlier run of this cell) into `outputs`.
    heightmap = reconstructed = None

    for input_tensor, heightmap in train_loader:
        # Move the batch to the compute device.
        input_tensor, heightmap = input_tensor.to(device), heightmap.to(device)

        # Forward pass and loss against the target.
        reconstructed = model(input_tensor)
        loss = loss_function(reconstructed, heightmap)

        # Clear old gradients, back-propagate, then update the parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep the scalar loss for plotting.
        train_loss = loss.item()
        train_losses.append(train_loss)

        print(train_loss)

    # Snapshot the last training batch of this epoch. Record the current
    # epoch index (not the total epoch count), and detach / move to the CPU
    # so the snapshot does not keep the autograd graph or device memory alive.
    if reconstructed is not None:
        outputs.append(
            (epoch, heightmap.detach().cpu(), reconstructed.detach().cpu())
        )

    # ---- validation pass ----
    # Evaluation mode plus no_grad: no parameter updates happen here, so
    # building the autograd graph would only waste memory and time.
    model.eval()
    with torch.no_grad():
        for input_tensor, heightmap in dev_loader:
            # Move the batch to the compute device.
            input_tensor, heightmap = input_tensor.to(device), heightmap.to(device)

            # Forward pass and loss against the target.
            reconstructed = model(input_tensor)
            dev_loss = loss_function(reconstructed, heightmap).item()

            dev_losses.append(dev_loss)

    print(f"<-------------------------- Epoch {epoch} -------------------------->")

In [None]:
# Run the trained model over the test split and collect predictions and
# targets as NumPy arrays for plotting.
reconstructed = []  # predicted maps, one per test batch
actual = []         # matching ground-truth maps

# Inference only: switch to evaluation mode and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    for input_tensor, heightmap in test_loader:
        # Use a separate name for the network output; assigning it to
        # `reconstructed` would overwrite the result list itself.
        prediction = model(input_tensor.to(device))

        # Keep the first sample of the batch: index [0, 0] on the prediction
        # and [0] on the target.
        # NOTE(review): assumes the prediction is (batch, channel, H, W) and
        # the target is (batch, H, W) -- confirm against ConvNet / the dataset.
        reconstructed.append(prediction.detach().cpu().numpy()[0, 0])
        actual.append(heightmap.detach().cpu().numpy()[0])

In [None]:
# Plot every test sample as a ground-truth / reconstruction pair, followed by
# one final row holding the training and validation loss curves.
n_samples = len(reconstructed)
rows = n_samples + 1  # one row per test sample, plus the loss-curve row
cols = 2

# squeeze=False keeps `axes` two-dimensional even when there is a single row
# (no test samples), so the row indexing below always works.
fig, axes = plt.subplots(rows, cols, figsize=(10, rows * 6), squeeze=False)
fig.set_dpi(200)

# Sample rows: ground truth on the left, network output on the right. Titles
# report NaN-ignoring min / median / max of each map.
for i in range(n_samples):
    truth_ax, recon_ax = axes[i]

    truth_ax.imshow(actual[i])
    truth_ax.set_axis_off()
    truth_ax.set_title(f"Ground Truth\nmin: {np.nanmin(actual[i])}\nmed: {np.nanmedian(actual[i])}\nmax: {np.nanmax(actual[i])}")

    recon_ax.imshow(reconstructed[i])
    recon_ax.set_axis_off()
    recon_ax.set_title(f"Reconstructed\nmin: {np.nanmin(reconstructed[i])}\nmed: {np.nanmedian(reconstructed[i])}\nmax: {np.nanmax(reconstructed[i])}")

# Last row: training loss on the left, validation loss on the right.
train_ax, dev_ax = axes[-1]

train_ax.set_xlabel('Iterations')
train_ax.set_ylabel('Loss')
train_ax.plot(train_losses)
train_ax.set_title('Training Loss')

dev_ax.set_xlabel('Measurements')
dev_ax.set_ylabel('Loss')
dev_ax.plot(dev_losses)
dev_ax.set_title('Validation Loss')

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('./outputs', exist_ok=True)
plt.savefig('./outputs/Model_7_±100um_Overview.png', dpi=200, facecolor='w')