# Add the project root to the import path, import libraries

In [None]:
import os
import sys

# Put the parent of the kernel's working directory on the import path,
# presumably the project root that holds `utils` and `sparse_autoencoder`.
current_dir = os.path.abspath("")
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))

# Guard so that re-executing this cell does not append duplicate entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [None]:
import idx2numpy
import numpy as np

# torch is used directly (e.g. torch.float32, torch.optim); import it
# explicitly instead of relying on the project star imports to re-export it.
import torch
#from torchvision import transforms
from torchvision.transforms import v2 # for torchvision > 0.15
import torch.nn as nn
import torcheval.metrics as metrics

%matplotlib inline 
import matplotlib.pyplot as plt

from utils.dataloaders import *
from utils.preprocessing import *
from utils.training_utils import *
from utils.metrics import *
from utils.losses import *

from sparse_autoencoder.models import *
from datetime import date

# Load datasets

In [None]:
# Root of the raw MNIST IDX files, relative to the kernel's working directory.
# Expected layout: <data_dir>/train/train-*-ubyte and <data_dir>/test/t10k-*-ubyte.
data_dir = "../datasets/MNIST"

In [None]:
def _read_idx(split, filename):
    """Decode the IDX file stored at ``<data_dir>/<split>/<filename>``."""
    return idx2numpy.convert_from_file(os.path.join(data_dir, split, filename))


# Raw arrays exactly as stored on disk (scaling happens in a later cell).
train_images = _read_idx("train", "train-images-idx3-ubyte")
train_labels = _read_idx("train", "train-labels-idx1-ubyte")
test_images = _read_idx("test", "t10k-images-idx3-ubyte")
test_labels = _read_idx("test", "t10k-labels-idx1-ubyte")

In [None]:
# Report the shape of every loaded array as a quick sanity check.
for _caption, _array in (
    ("Training set", train_images),
    ("Test set", test_images),
    ("Training label", train_labels),
    ("Test label", test_labels),
):
    print(f"[INFO]: {_caption} shape: {_array.shape}")

In [None]:
# Visualize the first few training images, each titled with its label, to
# check that images and labels line up.
n_rows, n_cols = 2, 4
fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols)
# zip stops after the n_rows * n_cols subplots, so only that many images are drawn.
for ax, image, label in zip(axes.flat, train_images, train_labels):
    ax.imshow(image, cmap="gray")
    ax.set_title(str(label))
    ax.axis("off")
plt.show()

# Preprocessing Data

In [None]:
# Rescale raw byte pixel intensities into [0, 1]; true division yields float arrays.
pixel_max = 255
train_images, test_images = train_images / pixel_max, test_images / pixel_max

# Global normalization statistics, computed from the training split only.
mean, std = np.mean(train_images), np.std(train_images)

In [None]:
# Per-sample preprocessing pipeline. Flatten() and Linear_Normalize() are
# project transforms from the star imports (presumably flatten to a vector and
# apply (x - mean) / std with the training statistics; confirm in utils).
_preprocessing_steps = [
    v2.ToTensor(),  # deprecated in torchvision v2; kept to preserve current behavior
    Flatten(),
    Linear_Normalize(mean, std),
    v2.ToDtype(torch.float32, scale=False),  # cast only, no value rescaling
]
train_transform = v2.Compose(_preprocessing_steps)

In [None]:
# Both splits share one pipeline, so test data is normalized with the
# training-set mean/std rather than its own statistics.
eval_transform = train_transform
train_dataset = MnistDataset(images=train_images, labels=train_labels, transforms=train_transform)
test_dataset = MnistDataset(transforms=eval_transform, images=test_images, labels=test_labels)

# Define training configurations

In [None]:
# Model configuration
in_features = 784  # flattened MNIST image size (28 x 28)
enc_hidden_features = [256, 128]
emb_features = 64
dec_hidden_features = [128, 256]

# Training configuration
epochs = 25
batchsize = 128
learning_rate = 0.001
num_workers = 4
# Fall back to the CPU instead of failing later on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
use_wandb = True
use_tensorboard = True

# Sparse autoencoder loss configuration (passed to SparseAutoencoderLoss;
# presumably beta weights the sparsity penalty and rho is the target mean
# activation; confirm in utils.losses)
beta = 1.0
rho = 0.05

# Checkpoint directory: ./checkpoints/<exp_name>, one directory per calendar day.
exp_name = f"linear_sparse_autoencoder_{date.today()}"
save_dir = os.path.join("./checkpoints", exp_name)
# makedirs also creates a missing ./checkpoints parent and, with exist_ok,
# makes re-running the cell a no-op. It does not go through a shell.
os.makedirs(save_dir, exist_ok=True)

In [None]:
# get_loader is a project helper; only the training split is shuffled, and
# drop_last=False is requested for both splits.
train_loader = get_loader(train_dataset, batchsize, num_workers, drop_last=False, shuffle=True)
test_loader = get_loader(test_dataset, batchsize, num_workers, drop_last=False, shuffle=False)

In [None]:
# Build the autoencoder from the config cell. Linear_AutoEncoder is a project
# model from the star imports; presumably linear layers
# in_features -> enc_hidden_features -> emb_features -> dec_hidden_features -> in_features.
model = Linear_AutoEncoder(in_features, enc_hidden_features, emb_features, dec_hidden_features)

In [None]:
# Sparse-autoencoder objective (project loss); hyperparameters come from the config cell.
loss_func = SparseAutoencoderLoss(rho=rho, beta=beta)

In [None]:
# Adam over all model parameters, using the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Evaluation metrics handed to train_epochs; metric_weights holds one weight
# per metric. torcheval is aliased separately here because the list below is
# bound to `metrics`, which shadows the `torcheval.metrics as metrics` module
# alias. Without a separate alias, re-running this cell fails with
# AttributeError ('list' object has no attribute ...).
import torcheval.metrics as torcheval_metrics

metrics = [
    # The first three metrics only work on images, so they are disabled here:
    # torcheval_metrics.FrechetInceptionDistance(device=device),
    # torcheval_metrics.StructuralSimilarity(device=device),
    # torcheval_metrics.PeakSignalNoiseRatio(device=device),
    torcheval_metrics.MeanSquaredError(device=device),
]
metric_weights = [1.0]

In [None]:
if use_wandb:
    # NOTE(review): `wandb` is never imported explicitly in this notebook;
    # presumably it arrives through one of the project star imports. Confirm.
    wandb.login()
    wandb.init(
        # Project the run is logged under.
        project="linear-mnist-sparse-autoencoder",
        # Explicit run name (W&B would otherwise assign a random one).
        name=exp_name,
        # Record every hyperparameter needed to reproduce the run, including
        # the layer sizes and the sparse-loss settings.
        config={
            "learning_rate": learning_rate,
            "architecture": type(model).__name__,
            "dataset": "MNIST",
            "epochs": epochs,
            "batch_size": batchsize,
            "in_features": in_features,
            "enc_hidden_features": enc_hidden_features,
            "emb_features": emb_features,
            "dec_hidden_features": dec_hidden_features,
            "beta": beta,
            "rho": rho,
        },
    )

In [None]:
# Run the training loop (project helper from the star imports).
train_epochs(
    # what to optimize and where
    model=model,
    loss_func=loss_func,
    optimizer=optimizer,
    device=device,
    num_epochs=epochs,
    # data; the test split doubles as the validation set
    train_loader=train_loader,
    val_loader=test_loader,
    # evaluation
    metrics=metrics,
    metric_weights=metric_weights,
    # logging / checkpointing options (semantics defined by train_epochs)
    log_rate=5,
    save_rate=10,
    interval=0,
    save_dir=save_dir,
    use_tensorboard=use_tensorboard,
    use_wandb=use_wandb,
)