### Libraries

In [None]:
import os
import sys
sys.path.append('../src')
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import matplotlib.pyplot as plt

# custom imports
from classes.autoencoder.autoencoder import AutoEncoder
from classes.autoencoder.autoencoder4x100 import AutoEncoder4x100
from classes.autoencoder.autoencoderCNN import AutoEncoderCNN

from utils.foldergen import generate_folder
from utils.dataset import load_datasets
from utils.tracker import Tracker
from utils.drawImgs import view_reconstructed

# Device configuration: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Training Configuration

In [None]:
# Dataset identifier; handed to load_datasets() and to the network constructor.
datasettype = "MNIST"

# Run-mode switches:
#   continue_training - resume from the saved 'complete.pth' / 'data.json'
#   evaluate          - run the evaluation cell at the end of the notebook
continue_training, evaluate = False, False

# Hyperparameters: epochs to train in this run, mini-batch size, Adam step size.
num_epochs, batch_size, learning_rate = 50, 25, 1e-4

# Output settings (in epochs): interval for plotting the loss graph, and the
# setting that controls the input-vs-reconstruction preview in the training loop.
graph_every_epoch = comparision_every_epoch = 5

### Get Data

In [None]:
# Preprocessing applied to every sample: convert to a tensor, then normalise
# with mean 0.5 and std 0.5 (single channel).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Folder structure used by the notebook; 'save/' receives the per-epoch
# checkpoints and 'eval/' the evaluation figures.
# NOTE(review): generate_folder presumably creates missing folders - confirm.
folders = ["output", "save", "eval"]
generate_folder(folders)

# Training split (default arguments of load_datasets).
dataset_train, loader_train = load_datasets(
    datasettype,
    transform,
    batch_size,
)
# Test split; download is disabled here.
dataset_test, loader_test = load_datasets(
    datasettype,
    transform,
    batch_size,
    train=False,
    download=False,
)

### Network

In [None]:
# Build the convolutional autoencoder for the selected dataset and move it to
# the compute device (Module.to returns the module itself).
network = AutoEncoderCNN(datasettype=datasettype)
network = network.to(device)

# Reconstruction quality is measured as the mean squared error between the
# network output and its input.
lossFunction = nn.MSELoss()

# Adam over all network parameters with a very small L2 weight decay.
optimizer = torch.optim.Adam(
    params=network.parameters(),
    lr=learning_rate,
    weight_decay=1e-8,
)

### Model loading
Load a previously saved model and its tracker data if `continue_training` is set to `True`.

In [None]:
# Tracker collects per-epoch statistics (losses, learning rate, epoch counter).
tracker = Tracker()
if continue_training:
    # 'complete.pth' holds the whole pickled module (written with
    # torch.save(network, ...)). map_location lets a checkpoint saved on a GPU
    # be restored on a CPU-only machine and vice versa.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True, which rejects fully pickled modules - confirm against
    # the installed version.
    checkpoint = torch.load('complete.pth', map_location=device)

    # Copy the weights into the existing network instead of rebinding the name:
    # the optimizer was built from this network's parameters, so replacing the
    # object would leave the optimizer updating a model that is no longer used.
    network.load_state_dict(checkpoint.state_dict())

    tracker.load("data.json")
    # Resume counting right after the last epoch recorded by the tracker.
    start_epoch = tracker.epochs_completed + 1
else:
    start_epoch = 1

### Training

In [None]:
# Train for `num_epochs` epochs starting at `start_epoch`. Each epoch:
#   1. one optimisation pass over loader_train,
#   2. a gradient-free pass over loader_test to record a validation loss,
#   3. optional reconstruction preview, checkpoint, and loss-graph plot.
training_steps = 0
end_epoch = start_epoch + num_epochs
for epoch in range(start_epoch, end_epoch):
    network.train()
    total_loss = 0.0
    samples_seen = 0

    for image, _ in loader_train:
        training_steps += 1

        # Move the batch to the network's device once and reuse it for both the
        # forward pass and the loss; comparing a device output with a CPU input
        # raises on CUDA.
        image = image.to(device)

        # pass image through autoencoder
        reconstructed = network(image)

        # evaluate loss by comparing reconstructed image with actual image
        loss = lossFunction(reconstructed, image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # MSELoss returns a per-batch mean, so weight it by the real batch size
        # (the last batch may be smaller than `batch_size`).
        total_loss += loss.item() * image.size(0)
        samples_seen += image.size(0)

    # Mean over the samples actually processed rather than a hard-coded
    # dataset size.
    average_loss = total_loss / samples_seen if samples_seen else float("nan")
    print(epoch, "/", end_epoch-1, average_loss)

    tracker.epochs_completed += 1
    tracker.y_loss["train"].append(average_loss)
    tracker.x_epoch.append(epoch)
    tracker.learning_rate.append(learning_rate)

    # Validation: held-out data, eval mode, no autograd bookkeeping. The model
    # stays on `device`, so no CPU round trip is needed.
    network.eval()
    total_loss_eval = 0.0
    samples_eval = 0
    with torch.no_grad():
        for image, _ in loader_test:
            image = image.to(device)
            reconstructed = network(image)
            total_loss_eval += lossFunction(reconstructed, image).item() * image.size(0)
            samples_eval += image.size(0)
    average_loss_eval = total_loss_eval / samples_eval if samples_eval else float("nan")
    tracker.y_loss["val"].append(average_loss_eval)

    # Preview the last validation batch next to its reconstruction every
    # `comparision_every_epoch` epochs (tensors are handed over on the CPU).
    if samples_eval and epoch % comparision_every_epoch == 0:
        view_reconstructed(image.cpu(), reconstructed.cpu())

    network.train()
    torch.save(network, f'save/ep{epoch}.pth')

    if epoch % graph_every_epoch == 0:
        tracker.plotLossGraph()

# Plot once more after training unless the last epoch already produced a graph.
if num_epochs > 0 and (end_epoch - 1) % graph_every_epoch != 0:
    tracker.plotLossGraph()

### Save model

In [None]:
# Persist the final state so a later run can resume or evaluate:
# the whole network object goes to 'complete.pth' (read back by the model
# loading and evaluation cells), the tracker data to 'data.json'.
torch.save(obj=network, f='complete.pth')
tracker.save("data.json")

### Evaluation

In [None]:
# Evaluate the saved model on the MNIST test split: report the mean
# reconstruction loss and save input/reconstruction comparison figures to 'eval/'.
# TODO: handle CIFAR10 and test CNN
if evaluate:
    # map_location allows a checkpoint written on a GPU to be evaluated on a
    # CPU-only machine; eval() switches layers such as dropout/batch-norm to
    # inference behaviour.
    network = torch.load('complete.pth', map_location=device)
    network.eval()

    dataset_test = datasets.MNIST(
        root='../data', train=False, transform=transform, download=False
    )

    loader_test = torch.utils.data.DataLoader(
        dataset=dataset_test, batch_size=batch_size, shuffle=False
    )

    test_loss_sum = 0.0
    test_samples = 0
    with torch.no_grad():
        for batch_idx, (image, _) in enumerate(loader_test):
            # Feed the batch in the same layout as the training loop does
            # (no flattening), so the same network accepts it.
            # NOTE(review): assumes the network handles any flattening
            # internally, as the training loop implies - confirm for the
            # non-CNN autoencoders.
            image = image.to(device)

            # pass image through autoencoder
            reconstructed = network(image)

            # evaluate loss by comparing reconstructed image with actual image
            loss = lossFunction(reconstructed, image)
            test_loss_sum += loss.item() * image.size(0)
            test_samples += image.size(0)

            # Show input and reconstructed images side by side for every
            # 1000th batch (starting with the first one).
            if batch_idx % 1000 == 0:
                fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
                axes[0].imshow(image[0].reshape(28, 28).to("cpu"), cmap="gray")
                axes[0].axis("off")
                axes[0].set_title("Input Image")
                axes[1].imshow(
                    reconstructed[0].detach().to("cpu").numpy().reshape(28, 28), cmap="gray"
                )
                axes[1].axis("off")
                axes[1].set_title("Reconstructed Image")
                plt.tight_layout()
                plt.savefig(os.path.join("eval/", f"test_{batch_idx}.png"))  # Save the figure
                plt.show()
                plt.close()

    if test_samples:
        print("test loss:", test_loss_sum / test_samples)
