### Libraries

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Hyperparameters

In [None]:
# Number of epochs to run in this session (added on top of any epochs
# already recorded in metadata.json when resuming).
num_epochs = 100
# Mini-batch size used by both the training and the test DataLoader.
batch_size = 25
# Step size for the Adam optimizer.
learning_rate = 1e-4
# When True, weights and the epoch counter are restored from a previous run
# (requires 'complete.pth' and 'metadata.json' to exist).
continue_training = True
# When True, the evaluation cell at the end of the notebook is executed.
# The flag is read there but was never defined, which raised a NameError.
evaluate = True

### Get Data

In [None]:
# Convert PIL images to float tensors scaled to [0, 1].
# No Normalize([0.5], [0.5]) step: it would shift pixels to [-1, 1], while the
# decoder ends in a Sigmoid and is trained with BCELoss, both of which require
# targets inside [0, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

# Directory the MNIST files are downloaded to / read from.
folder_path = "../data/"
os.makedirs(folder_path, exist_ok=True)

# Training split; shuffled every epoch.
dataset = datasets.MNIST(root=folder_path, train=True, download=True, transform=transform)
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

### Autoencoder

In [None]:
class AutoEncoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder narrows 784 -> 128 -> 64 -> 36 -> 18 -> 9 features with ReLU
    between the linear layers; the decoder mirrors those widths back to 784
    and finishes with a Sigmoid.
    """

    def __init__(self):
        super().__init__()

        widths = (28 * 28, 128, 64, 36, 18, 9)

        # Encoder is built first so the layer creation order (and therefore
        # the random initialisation order) matches encoder-then-decoder.
        self.encoder = nn.Sequential(*self._linear_relu_stack(widths))
        self.decoder = nn.Sequential(
            *self._linear_relu_stack(widths[::-1]),
            nn.Sigmoid(),
        )

    @staticmethod
    def _linear_relu_stack(widths):
        """Return Linear layers for consecutive widths, separated by ReLU."""
        layers = []
        for n_in, n_out in zip(widths, widths[1:]):
            if layers:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(n_in, n_out))
        return layers

    def forward(self, x):
        """Encode ``x`` to the 9-feature code and return its reconstruction."""
        return self.decoder(self.encoder(x))

### Network

In [None]:
# Instantiate the autoencoder and move it to the selected device.
AE = AutoEncoder()
AE = AE.to(device)

# Binary cross-entropy between the reconstruction and the input image.
lossFunction = nn.BCELoss()

# Adam over all model parameters, with a very small L2 penalty.
optimizer = torch.optim.Adam(
    params=AE.parameters(),
    lr=learning_rate,
    weight_decay=1e-8,
)

### Model loading
Load the previously saved model if `continue_training` is set to `True`.

In [None]:
if continue_training:
    # Restore the saved weights INTO the existing AE instance rather than
    # rebinding the name. The optimizer was constructed from AE.parameters();
    # replacing AE with a freshly unpickled module would leave the optimizer
    # stepping parameters that are no longer part of the model being trained,
    # so resumed training would silently learn nothing.
    #
    # map_location lets a checkpoint written on a GPU be resumed on a CPU host.
    # NOTE: 'complete.pth' is a pickled module - only load files you trust.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects pickled modules; pass weights_only=False
    # here if that applies to the installed version.
    checkpoint = torch.load('complete.pth', map_location=device)
    if isinstance(checkpoint, nn.Module):
        # The save cell stores the whole module; extract its weights.
        checkpoint = checkpoint.state_dict()
    AE.load_state_dict(checkpoint)

### Training

In [None]:
# Directory that receives the periodic input/reconstruction figures.
folder_path = "output/"
os.makedirs(folder_path, exist_ok=True)

# Epoch numbering starts at 1 for a fresh run.
start_epoch = 1
total_epochs_completed = 0

if continue_training:
    # Resume the epoch counter from the previous run so figure file names and
    # the stored total keep increasing across sessions.
    with open('metadata.json', 'r') as f:
        metadata = json.load(f)
    total_epochs_completed = metadata['total_epochs']
    start_epoch = total_epochs_completed + 1

# Exclusive upper bound, so exactly `num_epochs` epochs are executed
# (the previous inclusive bound ran one epoch too many).
end_epoch = start_epoch + num_epochs
# (epoch, loss of the last batch in that epoch) pairs.
outputs = []
for epoch in range(start_epoch, end_epoch):
    total_epochs_completed += 1
    losses = []
    for image, _ in loader:
        # take image from loader and flatten it
        image = image.reshape(-1, 28 * 28).to(device)

        # pass (flattened) image through autoencoder
        reconstructed = AE(image)

        # evaluate loss by comparing reconstructed image with actual image
        loss = lossFunction(reconstructed, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a plain float: keeping the tensor would hold on to device
        # memory and autograd references for every batch of the epoch.
        losses.append(loss.item())
    outputs.append((epoch, losses[-1]))

    if epoch % 10 == 0:
        # Show input and reconstructed images side by side
        # (first sample of the last batch of this epoch).
        fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
        axes[0].imshow(image[0].reshape(28, 28).to("cpu"), cmap="gray")
        axes[0].axis("off")
        axes[0].set_title("Input Image")
        axes[1].imshow(
            reconstructed[0].detach().to("cpu").numpy().reshape(28, 28), cmap="gray"
        )
        axes[1].axis("off")
        axes[1].set_title("Reconstructed Image")
        plt.tight_layout()
        plt.savefig(os.path.join(folder_path, f"epoch_{epoch}.png"))  # Save the figure
        plt.show()
        plt.close()

# Save the total_epochs to a JSON file
with open('metadata.json', 'w') as f:
    json.dump({'total_epochs': total_epochs_completed}, f)

### Save model

In [None]:
# Persist the whole module object (not just its state_dict); loading it back
# therefore requires the AutoEncoder class to be defined.
torch.save(obj=AE, f='complete.pth')

### Evaluation

In [None]:
if evaluate:
    # Load the saved module onto the active device and switch to inference mode.
    # NOTE: 'complete.pth' is a pickled module - only load files you trust.
    AE = torch.load('complete.pth', map_location=device)
    AE.eval()

    dataset_test = datasets.MNIST(
        root='../data', train=False, transform=transform, download=False
    )

    loader_test = torch.utils.data.DataLoader(
        dataset=dataset_test, batch_size=batch_size, shuffle=False
    )

    # create folder structure if it does not exist
    folder_path = "output/eval"
    os.makedirs(folder_path, exist_ok=True)

    # A figure is saved for every `plot_every`-th batch (starting with batch 0).
    # The previous truthiness test `batch_idx % 1000` was inverted and plotted
    # every batch EXCEPT those. Lower this value to get more sample figures.
    plot_every = 1000
    test_losses = []

    with torch.no_grad():
        for batch_idx, (image, _) in enumerate(loader_test):
            # take image from loader and flatten it
            image = image.reshape(-1, 28 * 28).to(device)

            # pass (flattened) image through autoencoder
            reconstructed = AE(image)

            # evaluate loss by comparing reconstructed image with actual image
            loss = lossFunction(reconstructed, image)
            test_losses.append(loss.item())

            # Show input and reconstructed images side by side
            if batch_idx % plot_every == 0:
                fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
                axes[0].imshow(image[0].reshape(28, 28).to("cpu"), cmap="gray")
                axes[0].axis("off")
                axes[0].set_title("Input Image")
                axes[1].imshow(
                    reconstructed[0].detach().to("cpu").numpy().reshape(28, 28), cmap="gray"
                )
                axes[1].axis("off")
                axes[1].set_title("Reconstructed Image")
                plt.tight_layout()
                plt.savefig(os.path.join(folder_path, f"test_{batch_idx}.png"))  # Save the figure
                plt.show()
                plt.close()

    # Report the per-batch loss averaged over the test set; it was previously
    # computed and then discarded.
    if test_losses:
        print(f"Mean test loss: {sum(test_losses) / len(test_losses):.6f}")
