### Libraries

In [None]:
import os
import sys
sys.path.append('../src')
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# custom imports

from classes.autoencoder import AutoEncoder
from utils.foldergen import generate_folder
from utils.dataset import load_datasets
from utils.tracker import Tracker

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
# Tensors and the model are moved to this device in the cells below.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Training Configuration

In [None]:
# Hyperparameters
# num_epochs    -- number of epochs added by one run of the training cell
# batch_size    -- mini-batch size handed to the data loaders
# learning_rate -- step size handed to the Adam optimizer
num_epochs, batch_size = 2, 25
learning_rate = 0.0001

# Model settings
# continue_training -- when True, the saved model and tracker data are reloaded
#                      before training and epoch numbering continues from there
# evaluate          -- when True, the evaluation cell at the end is executed
continue_training = evaluate = False

### Get Data

In [None]:
# Preprocessing pipeline: convert each image to a tensor, then normalise it
# with mean 0.5 and standard deviation 0.5 (single channel).
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# create folder structure
folders = ["output", "save", "eval"]
generate_folder(folders)

# The third positional argument presumably selects the train/test split
# (omitted for training, False for test) -- confirm against utils.dataset.
dataset_train, loader_train = load_datasets(transform, batch_size)
dataset_test, loader_test = load_datasets(transform, batch_size, False)

### Network

In [None]:
# Build the autoencoder and move its parameters to the selected device.
AE = AutoEncoder()
AE = AE.to(device)

# Reconstruction loss: binary cross entropy between output and input image.
# NOTE(review): BCELoss is documented for targets in [0, 1], but the transform
# above applies Normalize([0.5], [0.5]) after ToTensor, which yields values in
# [-1, 1]. Confirm the loss/normalisation pairing against AutoEncoder's output
# activation.
lossFunction = nn.BCELoss()

# Adam over all model parameters with a very small L2 penalty.
optimizer = torch.optim.Adam(
    params=AE.parameters(),
    lr=learning_rate,
    weight_decay=1e-8,
)

### Model loading
Load the previously saved model and tracker data if `continue_training` is set to `True`.

In [None]:
# Bookkeeping object for epochs, losses and learning rates of this run.
tracker = Tracker()
if continue_training:
    # Restore the fully pickled model onto the active device, so a checkpoint
    # written on a GPU machine can also be resumed on a CPU-only one.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints you trust. On PyTorch >= 2.6 the default weights_only=True
    # refuses fully pickled modules; pass weights_only=False there.
    AE = torch.load('complete.pth', map_location=device)

    # The optimizer created earlier still references the parameters of the
    # freshly initialised model that AE pointed to before this reload. Rebuild
    # it so that optimizer.step() actually updates the loaded weights.
    # (Adam's moment estimates are not part of the checkpoint and start fresh.)
    optimizer = torch.optim.Adam(AE.parameters(), lr=learning_rate, weight_decay=1e-8)

    tracker.load("data.json")

### Training

In [None]:
# Resume epoch numbering after the epochs already recorded by the tracker;
# a fresh run starts at epoch 1.
start_epoch = tracker.epochs_completed + 1 if continue_training else 1
end_epoch = start_epoch + num_epochs

# Stays None when num_epochs == 0 so the final-plot check below never reads an
# unbound loop variable.
epoch = None
for epoch in range(start_epoch, end_epoch):
    total_loss = 0.0
    samples_seen = 0

    tracker.epochs_completed += 1
    for image, _ in loader_train:
        # take image from loader and flatten it
        image = image.reshape(-1, 28 * 28).to(device)

        # pass (flattened) image through autoencoder
        reconstructed = AE(image)

        # evaluate loss by comparing reconstructed image with actual image
        loss = lossFunction(reconstructed, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the real batch size (the last batch
        # may be smaller than batch_size) and keep only the Python float, so
        # no loss tensors stay alive across iterations.
        current_batch = image.size(0)
        total_loss += loss.item() * current_batch
        samples_seen += current_batch

    # Per-sample mean loss of this epoch.
    # NOTE: histories written before this fix stored
    # mean_loss * batch_size / 60000, i.e. they are on a different scale.
    average_loss = total_loss / samples_seen
    tracker.y_loss["train"].append(average_loss)

    tracker.x_epoch.append(epoch)
    tracker.learning_rate.append(learning_rate)

    if epoch % 10 == 0:
        tracker.plotLossGraph()

        # Show input and reconstructed images side by side
        # (first sample of the last batch of this epoch).
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
        axes[0].imshow(image[0].reshape(28, 28).to("cpu"), cmap="gray")
        axes[0].axis("off")
        axes[0].set_title("Input Image")
        axes[1].imshow(
            reconstructed[0].detach().to("cpu").numpy().reshape(28, 28), cmap="gray"
        )
        axes[1].axis("off")
        axes[1].set_title("Reconstructed Image")
        plt.tight_layout()
        plt.savefig(os.path.join("output", f"epoch_{epoch}.png"))  # Save the figure
        plt.show()
        plt.close()

# Make sure the loss graph reflects the final epoch even when that epoch did
# not fall on a multiple of 10 (and skip it when no epoch ran at all).
if epoch is not None and epoch % 10 != 0:
    tracker.plotLossGraph()

### Save model

In [None]:
# Persist the whole model object (not just a state_dict); the model-loading
# and evaluation cells read this file back with torch.load.
torch.save(AE, f='complete.pth')

# Persist the tracker's recorded data alongside the model.
tracker.save("data.json")

### Evaluation

In [None]:
if evaluate:
    # Reload the saved model onto the active device.
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # checkpoints you trust. On PyTorch >= 2.6 the default weights_only=True
    # refuses fully pickled modules; pass weights_only=False there.
    AE = torch.load('complete.pth', map_location=device)
    # Switch layers such as dropout / batch norm (if any) to inference mode.
    AE.eval()

    dataset_test = datasets.MNIST(
        root='../data', train=False, transform=transform, download=False
    )

    loader_test = torch.utils.data.DataLoader(
        dataset=dataset_test, batch_size=batch_size, shuffle=False
    )

    # Save a comparison figure for every plot_every-th batch only.
    plot_every = 1000
    total_loss = 0.0
    samples_seen = 0

    with torch.no_grad():
        for batch_idx, (image, _) in enumerate(loader_test):
            # take image from loader and flatten it
            image = image.reshape(-1, 28 * 28).to(device)

            # pass (flattened) image through autoencoder
            reconstructed = AE(image)

            # evaluate loss by comparing reconstructed image with actual image
            loss = lossFunction(reconstructed, image)
            total_loss += loss.item() * image.size(0)
            samples_seen += image.size(0)

            # Show input and reconstructed images side by side.
            # The explicit "== 0" matters: a bare "batch_idx % plot_every" is
            # truthy for every batch EXCEPT the multiples of plot_every.
            if batch_idx % plot_every == 0:
                fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
                axes[0].imshow(image[0].reshape(28, 28).to("cpu"), cmap="gray")
                axes[0].axis("off")
                axes[0].set_title("Input Image")
                axes[1].imshow(
                    reconstructed[0].detach().to("cpu").numpy().reshape(28, 28), cmap="gray"
                )
                axes[1].axis("off")
                axes[1].set_title("Reconstructed Image")
                plt.tight_layout()
                plt.savefig(os.path.join("eval/", f"test_{batch_idx}.png"))  # Save the figure
                plt.show()
                plt.close()

    # Report the per-sample mean reconstruction loss over the test set.
    if samples_seen:
        print(f"Average test loss: {total_loss / samples_seen:.6f}")
