# Project 2 Unsupervised & Supervised Learning

## Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from itertools import product
import os
import hashlib
import matplotlib.pyplot as plt
import gc

## Iterable Data and Constants

In [2]:
def ret_search_param_grid(archtype="ANN", searchkey="learning_rate"):
    '''
    Return a hyperparameter search grid (or the available grid names).

    Parameters:
    - archtype: architecture whose grids are wanted; one of "ANN", "SAE",
      "SAEANN" or "VAE"
    - searchkey: name of the search to fetch (e.g. "learning_rate",
      "unit-test"); pass None to get the names of all searches defined for
      the architecture

    Returns:
    - if searchkey is None: the dict keys view of the architecture's grids
    - otherwise: a dict mapping each hyperparameter name to the list of
      candidate values to try for that search

    Raises:
    - ValueError: if archtype is unknown or searchkey is not defined for it
    '''
    # Possible activations
    activations = ["linear", "relu", "sigmoid", "tanh", "lrelu"]
    lin_activations = ["linear", "relu", "lrelu"]
    nonlin_activations = ["sigmoid", "tanh"]
    # Generate all combinations
    # Encoder: [piecewise-linear, saturating, same piecewise-linear] -> 3 * 2 = 6 triples
    encoder_combinations = [
        [first, second, first]
        for first, second in product(lin_activations, nonlin_activations)
    ]
    # Decoder: [piecewise-linear, saturating, any activation] -> 3 * 2 * 5 = 30 triples
    decoder_combinations = [
        [first, second, third]
        for first, second, third in product(lin_activations, nonlin_activations, activations)
    ]
    # Grid Search Parameters
    p_griddy_ANN = {
        "unit-test": {
            "learning_rate": [0.01],
            "batch_size": [32],
            "epochs": [1],
            "dropout_rate": [0.0],
            "input_size": [784],
            "layer_PE_nodes": [[2**13, 64, 10]],
            "layer_PE_activations": [["linear", "linear", "softmax"]],
            "loss": ["CE"]
        },
        # In Order...
        "learning_rate": {
            "learning_rate": [0.001, 0.01, 0.1, 0.2, 0.3, 0.4, 0.5],
            "batch_size": [128],
            "epochs": [256],
            "dropout_rate": [0.0],
            "input_size": [784],
            "layer_PE_nodes": [[2**11, 128, 64, 10]],
            "layer_PE_activations": [["lrelu", "lrelu", "lrelu", "softmax"]],
            "loss": ["CE"]
        },
        # "dropout_rate": {
        #     "learning_rate": [0.01],
        #     "batch_size": [32],
        #     "epochs": [256],
        #     "dropout_rate": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5],
        #     "input_size": [784],
        #     "layer_PE_nodes": [[2**11, 128, 64, 10]],
        #     "layer_PE_activations": [["lrelu", "lrelu", "lrelu", "softmax"]],
        #     "loss": ["CE"]
        # },
        # "layer_PE_activations": {
        #     "learning_rate": [0.01],
        #     "batch_size": [32],
        #     "epochs": [256],
        #     "dropout_rate": [0.0],
        #     "input_size": [784],
        #     "layer_PE_nodes": [[2**11, 128, 64, 10]],
        #     "layer_PE_activations": [["linear", "relu", "softmax"], ["linear", "sigmoid", "softmax"], ["linear", "tanh", "softmax"], ["linear", "lrelu", "softmax"]],
        #     "loss": ["CE"]
        # },
    }
    p_griddy_SAE = {
        "unit-test": {
            "learning_rate": [0.01],
            "batch_size": [32],
            "epochs": [1],
            "dropout_rate": [0.0],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["linear", "linear", "linear"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["linear", "linear", "linear"]],
            'loss': ['MSE'],
            '_lambda': [0],
            'norm_type': [2]
        },
        # In Order...
        "learning_rate": {
            "learning_rate": [2**-30, 2**-28, 2**-26, 2**-24, 2**-22, 2**-20, 2**-18, 2**-16, 2**-14, 2**-12, 2**-10, 2**-8, 2**-6, 2**-4],
            "batch_size": [256],
            "epochs": [1024],
            "dropout_rate": [0.0],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["linear", "linear", "linear"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["linear", "linear", "linear"]],
            'loss': ['MSE'],
            '_lambda': [0],
            'norm_type': [2]
        },
        "dropout_rate": {
            "learning_rate": [0.01],
            "batch_size": [256],
            "epochs": [256],
            "dropout_rate": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["linear", "linear", "linear"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["linear", "linear", "linear"]],
            'loss': ['MSE'],
            '_lambda': [0],
            'norm_type': [2]
        },
        "layer_PE_activations": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [256],
            "dropout_rate": [0.0],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": encoder_combinations,
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": decoder_combinations,
            'loss': ['MSE'],
            '_lambda': [0],
            'norm_type': [2]
        },
        "code_length": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.0],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10+4*k] for k in range(30)],
            "encoder_PE_activations": [["relu", "sigmoid", "relu"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["relu", "sigmoid" , "sigmoid"]],
            'loss': ['MSE'],
            '_lambda': [2**-6],
            'norm_type': [2]
        },
        "_lambda": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [256],
            "dropout_rate": [0.1],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["relu", "sigmoid", "relu"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["relu", "sigmoid", "sigmoid"]],
            'loss': ['custom'],
            '_lambda': [2**-20, 2**-19, 2**-18, 2**-17, 2**-16, 2**-15, 2**-14, 2**-13, 2**-12, 2**-11, 2**-10, 2**-9, 2**-8, 2**-7, 2**-6, 2**-5, 2**-4, 2**-3, 2**-2, 2**-1, 2**0],
            'norm_type': [2]
        },
        "norm_type": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [256],
            "dropout_rate": [0.1],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["relu", "sigmoid", "relu"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["relu", "sigmoid", "sigmoid"]],
            'loss': ['custom'],
            '_lambda': [2**-12],
            'norm_type': [0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.25, 2.5, 2.75, 3, 3.25, 3.5, 3.75, 4],
        },
        "3d_grid": {
            "learning_rate": [2**-10],
            "batch_size": [32, 64, 128, 256, 512],
            "epochs": [2048],
            "dropout_rate": [0.0],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["relu", "sigmoid", "relu"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["relu", "sigmoid", "sigmoid"]],
            'loss': ['custom'],
            '_lambda': [2**-10, 2**-9, 2**-8, 2**-7, 2**-6, 2**-5, 2**-4, 2**-3, 2**-2, 2**-1, 2**0],
            'norm_type': [1, 1.25, 1.5, 1.75, 2, 2.25, 2.5, 2.75, 3, 3.25, 3.5],
        }
    }
    p_griddy_SAEANN = {
        "unit-test": {
            "learning_rate": [0.01],
            "batch_size": [32],
            "epochs": [1],
            "dropout_rate": [0.0],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["linear", "linear", "linear"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["linear", "linear", "linear"]],
            'loss': ['CE'],
            "sae_model_path": [None],  # Update with actual path
            "freeze_encoder": [False],  # Option to freeze or fine-tune the encoder
            "layer_PE_nodes": [[2**6, 10]],
            "layer_PE_activations": [["relu", "softmax"]],
            "_lambda": [1],
        },
        # In Order...
        "learning_rate": {
            "learning_rate": [2**-20, 2**-18, 2**-16, 2**-14, 2**-12, 2**-10, 2**-8, 2**-6, 2**-4],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.1],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            'loss': ['CE'],
            "sae_model_path": [None],  # Update with actual path
            "freeze_encoder": [False],  # Option to freeze or fine-tune the encoder
            "layer_PE_nodes": [[2**6, 10]],
            "layer_PE_activations": [["relu", "softmax"]],
            "_lambda": [1],
        },
        "dropout_rate": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.0, 0.025, 0.05, 0.075, 0.1, 0.125, 0.15, 0.175, 0.2, 0.225, 0.25],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            'loss': ['CE'],
            "sae_model_path": [None],  # Update with actual path
            "freeze_encoder": [False],  # Option to freeze or fine-tune the encoder
            "layer_PE_nodes": [[2**6, 10]],
            "layer_PE_activations": [["relu", "softmax"]],
            "_lambda": [1],
        },
        "batch_size": {
            "learning_rate": [2**-10],
            "batch_size": [32, 64, 128, 256, 512],
            "epochs": [1024],
            "dropout_rate": [0.1],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            'loss': ['CE'],
            "sae_model_path": [None],  # Update with actual path
            "freeze_encoder": [False],  # Option to freeze or fine-tune the encoder
            "layer_PE_nodes": [[2**6, 10]],
            "layer_PE_activations": [["relu", "softmax"]],
            "_lambda": [1],
        },
        "_lambda": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.1],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            'loss': ['CE'],
            "sae_model_path": [None],  # Update with actual path
            "freeze_encoder": [False],  # Option to freeze or fine-tune the encoder
            "layer_PE_nodes": [[2**6, 10]],
            "layer_PE_activations": [["relu", "softmax"]],
            "_lambda": [2**-12, 2**-10, 2**-8, 2**-6, 2**-4, 2**-2, 2**0, 2**2, 2**4],
        },
        "code_length": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.1],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10+4*k] for k in range(30)],
            "encoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            'loss': ['CE'],
            "sae_model_path": [None],  # Update with actual path
            "freeze_encoder": [False],  # Option to freeze or fine-tune the encoder
            "layer_PE_nodes": [[2**6, 10]],
            "layer_PE_activations": [["relu", "softmax"]],
            "_lambda": [1],
        },
        "3d_grid": {
            "learning_rate": [2**-10],
            "batch_size": [32, 64, 128, 256, 512],
            "epochs": [1024],
            "dropout_rate": [0.1],
            "input_size": [784],
            "encoder_PE_nodes": [[800, 200, 10+4*k] for k in range(20)],
            "encoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["lrelu", "lrelu", "sigmoid"]],
            'loss': ['CE'],
            "sae_model_path": [None],  # Update with actual path
            "freeze_encoder": [False],  # Option to freeze or fine-tune the encoder
            "layer_PE_nodes": [[2**7, 10]],
            "layer_PE_activations": [["relu", "softmax"]],
            "_lambda": [2**-12, 2**-10, 2**-8, 2**-6, 2**-4, 2**-2, 2**0, 2**2, 2**4],
        },
        # ... Add other hyperparameters if needed
    }
    p_griddy_VAE = {
        "unit-test": {
            "learning_rate": [0.01],
            "batch_size": [32],
            "epochs": [5],
            "dropout_rate": [0.0],
            "input_size": [784],
            "latent_dim": [10],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["linear", "linear", "linear"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["linear", "linear", "sigmoid"]],
            'loss': ['MSE'],
            '_lambda': [0],
            'norm_type': [2]
        },
        # In Order...
        "learning_rate": {
            "learning_rate": [2**-30, 2**-28, 2**-26, 2**-24, 2**-22, 2**-20, 2**-18, 2**-16, 2**-14, 2**-12, 2**-10, 2**-8, 2**-6, 2**-4],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.1],
            "input_size": [784],
            "latent_dim": [10],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["relu", "sigmoid", "relu"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["relu", "sigmoid", "sigmoid"]],
            'loss': ['custom'],
            '_lambda': [2**-12],
            'norm_type': [2]
        },
        "dropout_rate": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5],
            "input_size": [784],
            "latent_dim": [10],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["relu", "sigmoid", "relu"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["relu", "sigmoid", "sigmoid"]],
            'loss': ['custom'],
            '_lambda': [2**-12],
            'norm_type': [2]
        },
        "layer_PE_activations": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.1],
            "input_size": [784],
            "latent_dim": [10],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": encoder_combinations,
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": decoder_combinations,
            'loss': ['custom'],
            '_lambda': [2**-12],
            'norm_type': [2]
        },
        "_lambda": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.1],
            "input_size": [784],
            "latent_dim": [10],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["relu", "sigmoid", "relu"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["relu", "sigmoid", "sigmoid"]],
            'loss': ['custom'],
            '_lambda': [2**-20, 2**-19, 2**-18, 2**-17, 2**-16, 2**-15, 2**-14, 2**-13, 2**-12, 2**-11, 2**-10, 2**-9, 2**-8, 2**-7, 2**-6, 2**-5, 2**-4, 2**-3, 2**-2, 2**-1, 2**0],
            'norm_type': [2]
        },
        "norm_type": {
            "learning_rate": [2**-10],
            "batch_size": [128],
            "epochs": [1024],
            "dropout_rate": [0.1],
            "input_size": [784],
            "latent_dim": [10],
            "encoder_PE_nodes": [[800, 200, 10]],
            "encoder_PE_activations": [["relu", "sigmoid", "relu"]],
            "decoder_PE_nodes": [[200, 800, 784]],
            "decoder_PE_activations": [["relu", "sigmoid", "sigmoid"]],
            'loss': ['custom'],
            '_lambda': [2**-12],
            'norm_type': [0, 0.25, 0.5, 0.75, 1, 1.25, 1.5, 1.75, 2, 2.25, 2.5, 2.75, 3, 3.25, 3.5, 3.75, 4],
        },
    }
    # One lookup table instead of a per-architecture if/elif chain
    grids_by_archtype = {
        "ANN": p_griddy_ANN,
        "SAE": p_griddy_SAE,
        "SAEANN": p_griddy_SAEANN,
        "VAE": p_griddy_VAE,
    }
    # Non-string archtypes can never match; skip the lookup so they fall
    # through to the ValueError below instead of failing on hashing.
    grid = grids_by_archtype.get(archtype) if isinstance(archtype, str) else None
    if grid is not None:
        # If just keys are needed
        if searchkey is None:
            return grid.keys()
        # If the search grids are needed
        if searchkey in grid:
            return grid[searchkey]
    raise ValueError(f"Unsupported archtype '{archtype}' or hyperparameter '{searchkey}'.")

def load_and_preprocess_data(data_dir="."):
    '''
    Load the KMNIST train/test archives and scale the images.

    Parameters:
    - data_dir: directory containing the four 'kmnist-*.npz' files
      (default: current working directory)

    Returns:
    - tr_X, tr_Y: training images divided by 255.0, and training labels
    - ts_X, ts_Y: test images divided by 255.0, and test labels
    '''
    def _load_array(filename):
        # np.load on an .npz archive returns an object that keeps the file
        # open; the context manager closes it once 'arr_0' has been read.
        with np.load(os.path.join(data_dir, filename)) as archive:
            return archive['arr_0']

    tr_X = _load_array('kmnist-train-imgs.npz')
    tr_Y = _load_array('kmnist-train-labels.npz')
    ts_X = _load_array('kmnist-test-imgs.npz')
    ts_Y = _load_array('kmnist-test-labels.npz')

    # Normalize Data
    tr_X = tr_X / 255.0
    ts_X = ts_X / 255.0

    # Do not one-hot encode labels
    return tr_X, tr_Y, ts_X, ts_Y

def train_val_split(X, Y, val_size=0.2, seed=None):
    '''
    Shuffle the data and split it into training and validation sets.

    Parameters:
    - X: input data, indexed along its first axis
    - Y: targets aligned with X along the first axis
    - val_size: fraction of the dataset to be used as validation (default: 0.2)
    - seed: optional seed for a reproducible shuffle. When given, a private
      RandomState is used so the global NumPy RNG is left untouched. When
      None (default), the global NumPy RNG is used, as before.

    Returns:
    - X_train, Y_train: training data (first int(n_samples * (1 - val_size)) shuffled samples)
    - X_val, Y_val: validation data (the remaining shuffled samples)
    '''
    n_samples = X.shape[0]
    split_idx = int(n_samples * (1 - val_size))

    # Shuffle the data; X and Y are indexed with the same permutation so
    # sample/target pairs stay aligned.
    if seed is None:
        perm = np.random.permutation(n_samples)
    else:
        perm = np.random.RandomState(seed).permutation(n_samples)
    X_shuffled, Y_shuffled = X[perm], Y[perm]

    # Split into training and validation sets
    X_train, Y_train = X_shuffled[:split_idx], Y_shuffled[:split_idx]
    X_val, Y_val = X_shuffled[split_idx:], Y_shuffled[split_idx:]

    return X_train, Y_train, X_val, Y_val

def split_by_class(X, Y, num_classes=10):
    '''
    Split data by class label.

    Parameters:
    - X: input data, indexed along its first axis
    - Y: class labels aligned with X; compared elementwise against each
      class index to select samples
    - num_classes: number of class labels to split on; labels
      0 .. num_classes - 1 are used (default: 10)

    Returns:
    - class_data: dictionary where keys are class labels and values are tuples (X_class, Y_class).
      Every label in range(num_classes) gets an entry, even if no sample has it.
    '''
    class_data = {}
    for label in range(num_classes):
        # Compute the boolean mask once and reuse it for both arrays
        mask = Y == label
        class_data[label] = (X[mask], Y[mask])
    return class_data

def split_data(X, Y, k, seed=0):
    '''
    Build k train/validation partitions of the data for cross-validation.

    The global NumPy RNG is seeded with `seed`, the sample indices are
    permuted once, and fold i holds out the i-th consecutive chunk of
    n_samples // k permuted indices for validation. Any leftover indices
    (when n_samples is not a multiple of k) stay in every training set.

    Parameters:
    - X: input data, indexed along its first axis
    - Y: targets aligned with X along the first axis
    - k: number of folds
    - seed: value passed to np.random.seed before shuffling (default: 0)

    Returns:
    - folds: list of k tuples (X_train, Y_train, X_val, Y_val)
    '''
    np.random.seed(seed)
    total = X.shape[0]
    width = total // k
    shuffled = np.random.permutation(total)

    def _make_fold(fold_no):
        # [lo, hi) is the slice of the permutation held out for this fold
        lo = fold_no * width
        hi = lo + width
        held_out = shuffled[lo:hi]
        kept = np.concatenate((shuffled[:lo], shuffled[hi:]))
        return (X[kept], Y[kept], X[held_out], Y[held_out])

    return [_make_fold(fold_no) for fold_no in range(k)]

## Model Definitions

In [3]:
class ConfigurableANN(nn.Module):
    """
    Configurable Artificial Neural Network (ANN) implementation.

    The network is built from ``config``: every entry of
    ``config['layer_PE_nodes']`` adds Linear -> BatchNorm1d -> activation
    (plus Dropout when ``config['dropout_rate'] > 0``), and the first Linear
    takes ``config['input_size']`` features. The stack is stored in
    ``self.network``.

    NOTE(review): fit() and evaluate() use nn.CrossEntropyLoss, which PyTorch
    documents as expecting unnormalized logits. The search grids in this
    notebook end the network with "softmax", so the loss receives
    probabilities instead -- confirm this is intended before changing it,
    since removing the softmax would alter training results.
    """
    def __init__(self, config):
        super(ConfigurableANN, self).__init__()
        layers = []
        input_size = config['input_size']

        # Extract layer configurations from the config dictionary
        layer_PE_nodes = config['layer_PE_nodes']              # units per layer
        layer_PE_activations = config['layer_PE_activations']  # activation name per layer
        use_dropout = 'dropout_rate' in config and config['dropout_rate'] > 0

        # Add fully connected layers
        for i, units in enumerate(layer_PE_nodes):
            layers.append(nn.Linear(input_size, units))
            layers.append(nn.BatchNorm1d(units))
            layers.append(self._get_activation(layer_PE_activations[i]))
            if use_dropout:
                layers.append(nn.Dropout(config['dropout_rate']))
            input_size = units  # next layer consumes this layer's output

        self.network = nn.Sequential(*layers)

    def _get_activation(self, activation_name):
        """
        Helper method to return the activation function based on the name.

        The lookup is case-insensitive; names that are not recognised fall
        back to nn.Identity. A fresh module instance is returned on each call.
        """
        # Factories rather than instances, so only the requested module is built
        factories = {
            "relu": nn.ReLU,
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
            "linear": nn.Identity,
            "softmax": lambda: nn.Softmax(dim=1),
            "lrelu": nn.LeakyReLU
        }
        return factories.get(activation_name.lower(), nn.Identity)()

    def forward(self, x):
        """
        Run the input through the layer stack.
        """
        return self.network(x)

    def fit(self, train_loader, val_loader, config, device='cuda'):
        """
        Train the ANN model using the given configuration.

        Parameters:
        - train_loader, val_loader: iterables of (inputs, targets) batches;
          inputs are flattened to (batch, -1) before the forward pass
        - config: needs 'learning_rate' and 'epochs'; optional 'patience'
          (default 16) is the number of consecutive non-improving validation
          epochs tolerated before early stopping
        - device: device the batches are moved to; the model itself is not
          moved here

        Returns:
        - train_losses, val_losses: per-epoch mean of the batch losses
        """

        # Define loss function and optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.parameters(), lr=config['learning_rate'])
        max_patience = config.get('patience', 16)
        patience = max_patience
        best_val_loss = np.inf

        # Training loop
        train_losses, val_losses = [], []
        for epoch in range(config['epochs']):
            self.train()
            train_loss = 0.0
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                inputs = inputs.view(inputs.size(0), -1)  # Flatten the input
                optimizer.zero_grad()
                outputs = self(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            train_losses.append(train_loss / len(train_loader))

            # Validation loss
            self.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    inputs = inputs.view(inputs.size(0), -1)  # Flatten the input
                    outputs = self(inputs)
                    loss = criterion(outputs, targets)
                    val_loss += loss.item()
            val_losses.append(val_loss / len(val_loader))

            # Early stopping: count consecutive epochs that are worse than the
            # best validation loss; any epoch that is not worse resets the count.
            if epoch > 0 and val_losses[-1] > best_val_loss:
                patience -= 1
                if patience <= 0:
                    print(f"Early stopping at epoch {epoch+1} - Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")
                    break
            else:
                best_val_loss = val_losses[-1]
                patience = max_patience

            # Print progress
            if epoch % 16 == 15:
                print(f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")

        return train_losses, val_losses

    def predict(self, X, device='cuda'):
        """
        Make predictions using the trained model.

        X is converted to a float32 tensor, flattened to (n_samples, -1) and
        run in eval mode. Returns a NumPy array with the argmax class index
        of each sample.
        """
        self.eval()
        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
        X_tensor = X_tensor.view(X_tensor.size(0), -1)  # Flatten the input
        with torch.no_grad():
            outputs = self(X_tensor)
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()
        return predictions

    def evaluate(self, val_loader, device='cuda'):
        """
        Evaluate the model using the given data.

        The whole dataset behind the loader is read through
        ``val_loader.dataset.tensors`` and evaluated in a single forward pass
        (the loader's batching is not used).

        Returns:
        - loss: cross-entropy loss over the dataset (float)
        - accuracy: fraction of samples whose argmax prediction equals the label
        """
        self.eval()
        criterion = nn.CrossEntropyLoss()
        X, Y = val_loader.dataset.tensors
        X, Y = X.to(device), Y.to(device)
        X = X.view(X.size(0), -1)  # Flatten the input
        with torch.no_grad():
            outputs = self(X)
            loss = criterion(outputs, Y).item()
            predictions = torch.argmax(outputs, dim=1).cpu().numpy()
            Y = Y.cpu().numpy()
            accuracy = (predictions == Y).mean()
        return loss, accuracy

    def save(self, path):
        """
        Save the model's state dict to the given path.
        """
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None):
        """
        Load the model's state dict from the given path.

        map_location is forwarded to torch.load; pass e.g. 'cpu' to load a
        checkpoint that was saved from a GPU on a machine without one.
        """
        self.load_state_dict(torch.load(path, map_location=map_location))

class ConfigurableSAE(nn.Module):
    """
    Configurable Stacked Autoencoder implementation.

    ``self.encoder`` and ``self.decoder`` are ModuleLists built from
    ``config``: each entry of ``encoder_PE_nodes`` / ``decoder_PE_nodes``
    adds Linear -> BatchNorm1d -> activation (plus Dropout when
    ``dropout_rate > 0``). The encoder starts from ``config['input_size']``
    features and the decoder starts from the last encoder layer's size,
    which is also stored as ``self.code_size``.
    """
    def __init__(self, config):
        super(ConfigurableSAE, self).__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self._lambda = config.get('_lambda', 0.0)
        self.norm_type = config.get('norm_type', 2)  # Default to L2 norm
        dropout_rate = config.get('dropout_rate', 0.0)

        # Build encoder layers
        input_size = self._append_layers(
            self.encoder, config['input_size'],
            config['encoder_PE_nodes'], config['encoder_PE_activations'], dropout_rate)

        # Latent representation
        self.code_size = config['encoder_PE_nodes'][-1]

        # Build decoder layers; the first one consumes the encoder's output size
        self._append_layers(
            self.decoder, input_size,
            config['decoder_PE_nodes'], config['decoder_PE_activations'], dropout_rate)

    def _append_layers(self, module_list, input_size, nodes, activation_names, dropout_rate):
        """
        Append one Linear/BatchNorm1d/activation(/Dropout) group per entry of
        ``nodes`` to ``module_list`` and return the size of the last layer.
        """
        for i, units in enumerate(nodes):
            module_list.append(nn.Linear(input_size, units))
            module_list.append(nn.BatchNorm1d(units))
            module_list.append(self._get_activation(activation_names[i]))
            if dropout_rate > 0:
                module_list.append(nn.Dropout(dropout_rate))
            input_size = units  # Update for the next layer
        return input_size

    def _get_activation(self, activation_name):
        """
        Helper method to return the activation function based on the name.

        The lookup is case-insensitive; names that are not recognised fall
        back to nn.Identity. A fresh module instance is returned on each call.
        """
        # Factories rather than instances, so only the requested module is built
        factories = {
            "relu": nn.ReLU,
            "sigmoid": nn.Sigmoid,
            "tanh": nn.Tanh,
            "linear": nn.Identity,
            "softmax": lambda: nn.Softmax(dim=1),
            "lrelu": nn.LeakyReLU
        }
        return factories.get(activation_name.lower(), nn.Identity)()

    def encode(self, x, device='cuda'):
        """
        Encode the input data by applying the encoder layers in order.

        ``device`` is accepted for interface compatibility but is not used.
        """
        for layer in self.encoder:
            x = layer(x)
        return x

    def decode(self, x, device='cuda'):
        """
        Decode the latent representation by applying the decoder layers in order.

        ``device`` is accepted for interface compatibility but is not used.
        """
        for layer in self.decoder:
            x = layer(x)
        return x

    def forward(self, x):
        """
        Forward pass of the SAE model.

        Returns the reconstruction; the latent codes of the most recent call
        are kept on ``self.codes``.
        """

        x = self.encode(x) # Encode the input
        self.codes = x     # Save the latent representation
        x = self.decode(x) # Decode the latent representation

        return x

    def _batch_loss(self, criterion, inputs, targets, config):
        """
        Run a forward pass on ``inputs`` and return the loss for one batch.

        The 'custom' loss also receives the latent codes, the labels, the norm
        type and lambda; every other loss compares the reconstruction with
        the (flattened) inputs.
        """
        outputs = self(inputs)
        if config['loss'] == 'custom':
            return criterion(inputs, outputs, self.codes, targets, config['norm_type'], config['_lambda'])
        return criterion(outputs, inputs)

    def fit(self, train_loader, val_loader, config, device='cuda'):
        """
        Train the SAE model using the given configuration.

        Parameters:
        - train_loader, val_loader: iterables of (inputs, targets) batches;
          inputs are flattened to (batch, -1) before the forward pass
        - config: needs 'loss' ('MSE', 'custom' or 'CE'), 'learning_rate' and
          'epochs'; 'custom' additionally needs 'norm_type' and '_lambda'.
          Optional 'patience' (default 16) is the number of consecutive
          non-improving validation epochs tolerated before early stopping
        - device: device the batches are moved to; the model itself is not
          moved here

        Returns:
        - train_losses, val_losses: per-epoch mean of the batch losses

        Raises:
        - ValueError: if config['loss'] is not one of the supported names
        """

        # Define loss function and optimizer
        if config['loss'] == 'MSE':
            criterion = nn.MSELoss()
        elif config['loss'] == 'custom':
            criterion = scsae_loss
        elif config['loss'] == 'CE':
            criterion = nn.CrossEntropyLoss()
        else:
            raise ValueError(f"Unsupported loss function '{config['loss']}'")
        optimizer = optim.Adam(self.parameters(), lr=config['learning_rate'])
        max_patience = config.get('patience', 16)
        patience = max_patience
        best_val_loss = np.inf

        # Training loop
        train_losses, val_losses = [], []
        for epoch in range(config['epochs']):
            self.train()
            train_loss = 0.0
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                inputs = inputs.view(inputs.size(0), -1)  # Flatten the input
                optimizer.zero_grad()
                loss = self._batch_loss(criterion, inputs, targets, config)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            train_losses.append(train_loss / len(train_loader))

            # Validation loss
            self.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    inputs = inputs.view(inputs.size(0), -1)
                    loss = self._batch_loss(criterion, inputs, targets, config)
                    val_loss += loss.item()
            val_losses.append(val_loss / len(val_loader))

            # Early stopping: count consecutive epochs that are worse than the
            # best validation loss; any epoch that is not worse resets the count.
            if epoch > 0 and val_losses[-1] > best_val_loss:
                patience -= 1
                if patience <= 0:
                    print(f"Early stopping at epoch {epoch+1} - Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")
                    break
            else:
                best_val_loss = val_losses[-1]
                patience = max_patience

            # Print progress
            if epoch % 16 == 15:
                print(f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")

        return train_losses, val_losses

    def predict(self, X, device='cuda'):
        """
        Make predictions using the trained model.

        X is converted to a float32 tensor, flattened to (n_samples, -1) and
        run in eval mode. Returns the reconstructions as a NumPy array.
        """
        self.eval()
        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
        X_tensor = X_tensor.view(X_tensor.size(0), -1)
        with torch.no_grad():
            outputs = self(X_tensor)
            predictions = outputs.cpu().numpy()
        return predictions

    def evaluate_reconstruction(self, val_loader, device='cuda'):
        """
        Evaluate reconstruction quality on the given data.

        Returns the sample-weighted average of the per-batch MSE between the
        reconstruction and the flattened input.
        """
        self.eval()
        total_loss = 0.0
        total_samples = 0
        criterion = nn.MSELoss()
        with torch.no_grad():
            for inputs, _ in val_loader:
                inputs = inputs.to(device)
                inputs = inputs.view(inputs.size(0), -1)
                outputs = self(inputs)
                loss = criterion(outputs, inputs)
                # Weight each batch mean by its size so a short last batch
                # does not skew the average
                total_loss += loss.item() * inputs.size(0)
                total_samples += inputs.size(0)
        avg_loss = total_loss / total_samples
        return avg_loss

    def evaluate_code(self, val_loader, device='cuda'):
        """
        Evaluate the latent codes against the labels.

        Labels are one-hot encoded with ``self.code_size`` classes, so every
        label must be smaller than the code size.

        Returns:
            avg_loss: Average MSE loss between codes and one-hot encoded labels.
            avg_accuracy: Classification accuracy based on codes (argmax of the code vs. label).
        """
        self.eval()
        total_loss = 0.0
        correct = 0
        total_samples = 0
        criterion = nn.MSELoss()
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                inputs = inputs.view(inputs.size(0), -1)
                codes = self.encode(inputs)
                labels_onehot = F.one_hot(labels, num_classes=self.code_size).float()
                loss = criterion(codes, labels_onehot)
                total_loss += loss.item() * inputs.size(0)
                total_samples += inputs.size(0)
                predictions = torch.argmax(codes, dim=1)
                correct += (predictions == labels).sum().item()
        avg_loss = total_loss / total_samples
        avg_accuracy = correct / total_samples
        return avg_loss, avg_accuracy

    def evaluate(self, val_loader, device='cuda'):
        """
        Evaluate the model using evaluate_reconstruction and evaluate_code.

        Returns (reconstruction loss + code loss, code accuracy).
        """
        avg_reconstruction_loss = self.evaluate_reconstruction(val_loader, device)
        avg_code_loss, avg_accuracy = self.evaluate_code(val_loader, device)
        return avg_reconstruction_loss + avg_code_loss, avg_accuracy

    def save(self, path):
        """
        Save the model's state dict to the given path.
        """
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None):
        """
        Load the model's state dict from the given path.

        map_location is forwarded to torch.load; pass e.g. 'cpu' to load a
        checkpoint that was saved from a GPU on a machine without one.
        """
        self.load_state_dict(torch.load(path, map_location=map_location))

class ConfigurableSAEANN(nn.Module):
    """
    Configurable Stacked Autoencoder - Artificial Neural Network (SAE-ANN) implementation.

    Combines a ConfigurableSAE (loaded from a checkpoint or built from `config`)
    with a feed-forward classifier that consumes the encoder's codes.
    """
    def __init__(self, sae_model_path, config, ann_model_path=None, freeze_encoder=False, device='cuda'):
        """
        Args:
            sae_model_path: Path to an SAE checkpoint containing
                'model_state_dict' and 'config', or None to build a fresh SAE
                from `config`.
            config: Dict with 'encoder_PE_nodes', 'layer_PE_nodes',
                'layer_PE_activations' and optionally 'dropout_rate'.
            ann_model_path: Optional path to a classifier state dict.
            freeze_encoder: If True, encoder parameters are excluded from
                gradient updates.
            device: Device used for map_location and for the loaded SAE.
        """
        super(ConfigurableSAEANN, self).__init__()

        if sae_model_path is not None:
            # Load the SAE checkpoint and rebuild the SAE from its saved config
            checkpoint = torch.load(sae_model_path, map_location=device)
            sae_state_dict = checkpoint['model_state_dict']
            sae_config = checkpoint['config']

            self.sae_model = ConfigurableSAE(sae_config).to(device)
            self.sae_model.load_state_dict(sae_state_dict)
        else:
            self.sae_model = self._build_sae(config)

        # Optionally freeze the encoder (the bool argument shadows the method
        # name only locally; the attribute lookup below resolves to the method)
        if freeze_encoder:
            self.freeze_encoder()

        # Initialize the ANN classifier
        self.classifier = self._build_classifier(config)

        # Optionally load ANN classifier weights
        if ann_model_path is not None:
            self.classifier.load_state_dict(torch.load(ann_model_path, map_location=device))

    def _build_sae(self, config):
        """
        Builds the SAE part of the SAE-ANN based on the configuration.
        """
        return ConfigurableSAE(config)

    def _build_classifier(self, config):
        """
        Builds the classifier part of the ANN based on the configuration.

        Each entry of 'layer_PE_nodes' becomes Linear -> BatchNorm1d ->
        activation (-> Dropout when 'dropout_rate' > 0).

        NOTE(review): `fit` trains with nn.CrossEntropyLoss, which already
        applies log-softmax; a final 'softmax' activation therefore applies
        softmax twice. Left unchanged here - confirm whether that is intended.
        """
        layers = []
        input_size = config['encoder_PE_nodes'][-1]  # Should match the encoder's output size
        layer_PE_nodes = config['layer_PE_nodes']
        layer_PE_activations = config['layer_PE_activations']

        for i, units in enumerate(layer_PE_nodes):
            layers.append(nn.Linear(input_size, units))
            layers.append(nn.BatchNorm1d(units))
            layers.append(self._get_activation(layer_PE_activations[i]))
            if 'dropout_rate' in config and config['dropout_rate'] > 0:
                layers.append(nn.Dropout(config['dropout_rate']))
            input_size = units

        return nn.Sequential(*layers)

    def _get_activation(self, activation_name):
        """
        Helper method to return the activation function based on the name.
        Unknown names fall back to nn.Identity().
        """
        activations = {
            "relu": nn.ReLU(),
            "sigmoid": nn.Sigmoid(),
            "tanh": nn.Tanh(),
            "linear": nn.Identity(),
            "softmax": nn.Softmax(dim=1),
            "lrelu": nn.LeakyReLU()
        }
        return activations.get(activation_name.lower(), nn.Identity())

    def forward(self, x, device='cuda'):
        """
        Encode `x`, then classify and reconstruct from the codes.

        `device` is unused; it is kept so existing callers keep working.

        Returns:
            (classification, reconstruction)
        """
        codes = self.sae_model.encode(x)
        classification = self.classifier(codes)
        reconstruction = self.sae_model.decode(codes)
        return classification, reconstruction

    def freeze_encoder(self):
        """
        Disable gradient updates for the SAE encoder parameters.
        """
        # The encoder lives on the wrapped SAE; this class has no `encoder` attribute.
        for param in self.sae_model.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        """
        Re-enable gradient updates for the SAE encoder parameters.
        """
        for param in self.sae_model.encoder.parameters():
            param.requires_grad = True

    def fit(self, train_loader, val_loader, config, device='cuda'):
        """
        Train the SAE-ANN model in two stages per epoch: classification and reconstruction.

        Args:
            train_loader, val_loader: Iterables of (inputs, targets) batches.
            config: Dict with 'learning_rate', 'epochs' and optional '_lambda'
                (weight on the training classification loss, default 1.0).
            device: Device the model and batches are moved to.

        Returns:
            (train_losses, val_losses): per-epoch lists. Each entry is the sum
            of per-batch classification and reconstruction losses divided by
            the number of batches.
        """
        self.to(device)
        criterion_classification = nn.CrossEntropyLoss()
        optimizer_classifier = optim.Adam(
            list(self.classifier.parameters()) + list(self.sae_model.encoder.parameters()),
            lr=config['learning_rate']
        )

        criterion_reconstruction = nn.MSELoss()
        optimizer_reconstruction = optim.Adam(
            list(self.sae_model.decoder.parameters()) + list(self.sae_model.encoder.parameters()),
            lr=config['learning_rate']
        )
        patience = 16
        best_val_loss = np.inf
        _lambda = config.get('_lambda', 1.0)

        train_losses, val_losses = [], []
        for epoch in range(config['epochs']):
            self.train()
            train_loss = 0.0

            # Classification phase
            for inputs, targets in train_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                inputs = inputs.view(inputs.size(0), -1)
                optimizer_classifier.zero_grad()
                classification, _ = self(inputs)
                loss_classification = _lambda*criterion_classification(classification, targets)
                loss_classification.backward()
                optimizer_classifier.step()
                train_loss += loss_classification.item()

            # Reconstruction phase
            for inputs, _ in train_loader:
                inputs = inputs.to(device)
                inputs = inputs.view(inputs.size(0), -1)
                optimizer_reconstruction.zero_grad()
                _, reconstruction = self(inputs)
                loss_reconstruction = criterion_reconstruction(reconstruction, inputs)
                loss_reconstruction.backward()
                optimizer_reconstruction.step()
                train_loss += loss_reconstruction.item()

            train_losses.append(train_loss / len(train_loader))

            # Validation phase
            self.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, targets in val_loader:
                    inputs, targets = inputs.to(device), targets.to(device)
                    inputs = inputs.view(inputs.size(0), -1)
                    classification, reconstruction = self(inputs)
                    loss_classification = criterion_classification(classification, targets)
                    loss_reconstruction = criterion_reconstruction(reconstruction, inputs)
                    val_loss += loss_classification.item() + loss_reconstruction.item()
                val_losses.append(val_loss / len(val_loader))

            # Early stopping: count consecutive epochs that are worse than the
            # most recent accepted validation loss; any other epoch resets it.
            if epoch > 0 and val_losses[-1] > best_val_loss:
                patience -= 1
                if patience == 0:
                    print(f"Early stopping at epoch {epoch+1} - Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")
                    break
            else:
                best_val_loss = val_losses[-1]
                patience = 16

            # Print progress
            if epoch % 16 == 15:
                print(f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")

        return train_losses, val_losses

    def predict(self, val_loader, device='cuda'):
        """
        Make predictions using the trained model.

        Returns:
            List with one predicted class index per sample, in loader order.
        """
        self.eval()
        predictions = []
        with torch.no_grad():
            for inputs, _ in val_loader:
                inputs = inputs.to(device)
                # Flatten like fit/evaluate do, so image-shaped batches work too.
                inputs = inputs.view(inputs.size(0), -1)
                codes = self.sae_model.encode(inputs)  # Pass inputs through encoder
                outputs = self.classifier(codes)  # Classify using the ANN
                batch_predictions = torch.argmax(outputs, dim=1).cpu().numpy()
                predictions.extend(batch_predictions)
        return predictions

    def evaluate(self, val_loader, device='cuda'):
        """
        Evaluate the model using the given data.

        Returns:
            (avg_loss, accuracy): mean over batches of (cross-entropy +
            reconstruction MSE), and the fraction of correctly classified
            samples.
        """
        self.eval()
        total_loss = 0.0
        correct = 0
        total_samples = 0
        criterion_classification = nn.CrossEntropyLoss()
        criterion_reconstruction = nn.MSELoss()

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                inputs = inputs.view(inputs.size(0), -1)
                codes = self.sae_model.encode(inputs)  # Pass inputs through encoder
                classification = self.classifier(codes)  # Classify using the ANN
                reconstruction = self.sae_model.decode(codes)  # Reconstruct the input
                loss = criterion_classification(classification, targets) + criterion_reconstruction(reconstruction, inputs)
                total_loss += loss.item()
                predictions = torch.argmax(classification, dim=1)
                correct += (predictions == targets).sum().item()
                total_samples += targets.size(0)

        avg_loss = total_loss / len(val_loader)
        accuracy = correct / total_samples
        return avg_loss, accuracy

    def save(self, path):
        """
        Save the model's state dict to the given path.
        """
        torch.save(self.state_dict(), path)

    def load(self, path, map_location=None):
        """
        Load the model's state dict from the given path.

        Args:
            path: File written by `save`.
            map_location: Forwarded to torch.load (e.g. 'cpu' to load
                CUDA-saved weights on a CPU-only host).
        """
        self.load_state_dict(torch.load(path, map_location=map_location))

class ConfigurableVAE(nn.Module):
    """
    Configurable Variational Autoencoder (VAE) implementation.

    The encoder and decoder are stacks of Linear + activation layers described
    by the config; two linear heads produce the latent mean and log-variance.
    """

    # Config entries that `save` writes and `load` restores as attributes.
    _CONFIG_KEYS = (
        'input_size',
        'encoder_PE_nodes',
        'encoder_PE_activations',
        'latent_dim',
        'decoder_PE_nodes',
        'decoder_PE_activations',
        'output_size',
        'output_activation',
    )

    def __init__(self, config):
        """
        Args:
            config: Dict with 'input_size', 'encoder_PE_nodes',
                'encoder_PE_activations', 'latent_dim', 'decoder_PE_nodes',
                'decoder_PE_activations'. 'output_size' and
                'output_activation' are optional bookkeeping entries.
        """
        super(ConfigurableVAE, self).__init__()

        # Keep the architecture description on the instance so `save` can
        # write it out next to the weights.
        self.input_size = config["input_size"]
        self.encoder_PE_nodes = list(config["encoder_PE_nodes"])
        self.encoder_PE_activations = list(config["encoder_PE_activations"])
        self.latent_dim = config["latent_dim"]
        self.decoder_PE_nodes = list(config["decoder_PE_nodes"])
        self.decoder_PE_activations = list(config["decoder_PE_activations"])
        # Not used to build layers; when absent they are derived from the last
        # decoder layer. NOTE(review): confirm this is the intended meaning.
        self.output_size = config.get(
            "output_size",
            self.decoder_PE_nodes[-1] if self.decoder_PE_nodes else self.latent_dim,
        )
        self.output_activation = config.get(
            "output_activation",
            self.decoder_PE_activations[-1] if self.decoder_PE_activations else "linear",
        )

        # Build encoder
        self.encoder = nn.ModuleList()
        in_features = self.input_size
        for i, units in enumerate(self.encoder_PE_nodes):
            self.encoder.append(nn.Linear(in_features, units))
            self.encoder.append(self._get_activation(self.encoder_PE_activations[i]))
            in_features = units

        # Latent space
        self.mu_layer = nn.Linear(in_features, self.latent_dim)
        self.logvar_layer = nn.Linear(in_features, self.latent_dim)

        # Build decoder. A local width tracker is used so self.latent_dim
        # keeps describing the latent space after construction.
        self.decoder = nn.ModuleList()
        in_features = self.latent_dim
        for i, units in enumerate(self.decoder_PE_nodes):
            self.decoder.append(nn.Linear(in_features, units))
            self.decoder.append(self._get_activation(self.decoder_PE_activations[i]))
            in_features = units

    def _get_activation(self, activation_name):
        """
        Helper method to return the activation function based on the name.
        Unknown names fall back to nn.Identity().
        """
        activations = {
            "relu": nn.ReLU(),
            "sigmoid": nn.Sigmoid(),
            "tanh": nn.Tanh(),
            "linear": nn.Identity(),
            "softmax": nn.Softmax(dim=1),
            "lrelu": nn.LeakyReLU()
        }
        return activations.get(activation_name.lower(), nn.Identity())

    def encode(self, x):
        """
        Encoder pass: compute latent mean and log-variance.
        """
        for layer in self.encoder:
            x = layer(x)
        mu = self.mu_layer(x)
        logvar = self.logvar_layer(x)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick to sample from N(mu, var) using N(0, 1).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """
        Decoder pass: reconstruct input from latent space.
        """
        for layer in self.decoder:
            z = layer(z)

        return z

    def forward(self, x):
        """
        Full VAE forward pass.

        Returns:
            (recon_x, mu, logvar)
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar

    def fit(self, train_loader, val_loader, config, device='cuda'):
        """
        Train the VAE model using the given configuration.

        Uses the module-level `vae_loss` and Adam with config['learning_rate']
        for up to config['epochs'] epochs, with early stopping.

        Returns:
            (train_losses, val_losses): per-epoch means of the batch losses.
        """
        self.to(device)
        optimizer = optim.Adam(self.parameters(), lr=config['learning_rate'])
        best_val_loss = np.inf
        patience = 16

        # Training loop
        train_losses, val_losses = [], []
        for epoch in range(config['epochs']):
            self.train()
            train_loss = 0.0
            for inputs, _ in train_loader:
                inputs = inputs.to(device)
                inputs = inputs.view(inputs.size(0), -1)  # Flatten the input
                optimizer.zero_grad()
                recon_x, mu, logvar = self(inputs)
                loss = vae_loss(recon_x, inputs, mu, logvar)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
            train_losses.append(train_loss / len(train_loader))

            # Validation loss
            self.eval()
            val_loss = 0.0
            with torch.no_grad():
                for inputs, _ in val_loader:
                    inputs = inputs.to(device)
                    inputs = inputs.view(inputs.size(0), -1)  # Flatten the input
                    recon_x, mu, logvar = self(inputs)
                    loss = vae_loss(recon_x, inputs, mu, logvar)
                    val_loss += loss.item()
                val_losses.append(val_loss / len(val_loader))

            # Early stopping: count consecutive epochs that are worse than the
            # most recent accepted validation loss; any other epoch resets it.
            if epoch > 0 and val_losses[-1] > best_val_loss:
                patience -= 1
                if patience == 0:
                    print(f"Early stopping at epoch {epoch+1} - Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")
                    break
            else:
                best_val_loss = val_losses[-1]
                patience = 16

            # Print progress
            if epoch % 16 == 15:
                print(f"Epoch {epoch+1}/{config['epochs']} - Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")

        return train_losses, val_losses

    def predict(self, X, device='cuda'):
        """
        Reconstruct `X` with the trained model.

        Args:
            X: Array-like converted to a float32 tensor and fed to `forward`
                as-is (no flattening is applied here).

        Returns:
            The reconstruction as a NumPy array.
        """
        self.eval()
        X_tensor = torch.tensor(X, dtype=torch.float32).to(device)
        with torch.no_grad():
            outputs, _, _ = self(X_tensor)
            predictions = outputs.cpu().numpy()
        return predictions

    def evaluate_reconstruction(self, val_loader, device='cuda'):
        """
        Average reconstruction MSE over the loader, weighted per sample.
        """
        self.eval()
        criterion = nn.MSELoss()
        total_loss = 0.0
        total_samples = 0
        with torch.no_grad():
            for inputs, _ in val_loader:
                inputs = inputs.to(device)
                inputs = inputs.view(inputs.size(0), -1)
                recon_x, mu, logvar = self(inputs)
                # MSELoss is a batch mean; scale by batch size so dividing by
                # the sample count below yields a true per-sample average.
                total_loss += criterion(recon_x, inputs).item() * inputs.size(0)
                total_samples += inputs.size(0)
        avg_loss = total_loss / total_samples
        return avg_loss

    def evaluate_code(self, val_loader, device='cuda'):
        """
        Evaluate a sampled latent code against the labels.

        For each batch a latent sample z is drawn, decoded and compared to the
        inputs with MSE; the argmax of z is compared to the labels.

        Returns:
            (avg_loss, avg_accuracy): per-sample mean reconstruction MSE of the
            decoded sample, and the fraction of samples whose argmax latent
            index equals the label.
        """
        self.eval()
        criterion = nn.MSELoss()
        total_loss = 0.0
        correct = 0
        total_samples = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)  # Ensure labels are on the correct device
                inputs = inputs.view(inputs.size(0), -1)
                # Only mu/logvar are needed, so skip the full forward pass.
                mu, logvar = self.encode(inputs)
                z = self.reparameterize(mu, logvar)
                outputs = self.decode(z)
                loss = criterion(outputs, inputs)
                total_loss += loss.item() * inputs.size(0)
                total_samples += inputs.size(0)
                predictions = torch.argmax(z, dim=1)
                correct += (predictions == labels).float().sum().item()
        avg_loss = total_loss / total_samples
        avg_accuracy = correct / total_samples
        return avg_loss, avg_accuracy

    def evaluate(self, val_loader, device='cuda'):
        """
        Evaluate the model using evaluate_reconstruction and evaluate_code.

        NOTE(review): both summed terms are reconstruction MSEs (each from an
        independent latent sample) - confirm the double counting is intended.

        Returns:
            (loss, accuracy)
        """
        avg_reconstruction_loss = self.evaluate_reconstruction(val_loader, device)
        avg_code_loss, accuracy = self.evaluate_code(val_loader, device)
        return avg_reconstruction_loss + avg_code_loss, accuracy

    def save(self, path):
        """
        Save the model weights and its architecture config to the given path.

        The checkpoint is a dict with 'model_state_dict' and 'config'.
        """
        torch.save({
            'model_state_dict': self.state_dict(),
            'config': {key: getattr(self, key) for key in self._CONFIG_KEYS}
        }, path)

    def load(self, path, map_location=None):
        """
        Load weights (and any saved config entries) from the given path.

        Args:
            path: Checkpoint containing 'model_state_dict' and optionally
                'config'.
            map_location: Forwarded to torch.load (e.g. 'cpu' to load
                CUDA-saved weights on a CPU-only host).
        """
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['model_state_dict'])
        saved_config = checkpoint.get('config', {})
        # Restore only the entries that are present, so checkpoints whose
        # config lacks the optional keys still load.
        for key in self._CONFIG_KEYS:
            if key in saved_config:
                setattr(self, key, saved_config[key])

## Loss Functions

In [4]:
def vae_loss(recon_x, x, mu, logvar):
    """
    VAE objective: mean-squared reconstruction error plus a KL term.

    The KL term is -0.5 times the mean (over all elements) of
    1 + logvar - mu^2 - exp(logvar).
    """
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_elements.mean()
    reconstruction_term = F.mse_loss(recon_x, x)
    return reconstruction_term + kl_term

def scsae_loss(inputs, outputs, codes, targets, norm_type=2, _lambda=0.0):
    """
    Stacked-autoencoder loss: reconstruction MSE plus a weighted code penalty.

    The penalty is the batch mean of the p-norm (p = norm_type, taken along
    dim 1) of the difference between the codes and the one-hot targets, where
    the number of classes is the code width. `_lambda` scales the penalty.
    """
    n_classes = codes.shape[1]
    # One-hot targets, placed on the same device as the codes.
    onehot_targets = F.one_hot(targets, num_classes=n_classes).float().to(codes.device)
    code_error = codes - onehot_targets
    penalty = torch.norm(code_error, p=norm_type, dim=1).mean()
    reconstruction = F.mse_loss(outputs, inputs)
    return reconstruction + _lambda * penalty

## Model Information Searching

In [5]:
def _grid_search_model_tag(model_class):
    """Return the file-name prefix for a supported model class, else raise ValueError."""
    if model_class == ConfigurableANN:
        return "ConfigurableANN"
    if model_class == ConfigurableSAE:
        return "ConfigurableSAE"
    if model_class == ConfigurableSAEANN:
        return "ConfigurableSAEANN"
    if model_class == ConfigurableVAE:
        return "ConfigurableVAE"
    raise ValueError(f"Unsupported model class '{model_class}'")


def _grid_search_build_model(model_class, model_tag, params, device):
    """Instantiate one model for a parameter combination and move it to `device`."""
    if model_tag == "ConfigurableSAEANN":
        # The SAE-ANN takes its checkpoint paths and freeze flag from params.
        model = model_class(
            sae_model_path=params['sae_model_path'],
            config=params,
            ann_model_path=params.get('ann_model_path', None),
            freeze_encoder=params.get('freeze_encoder', True),
            device=device
        )
    else:
        model = model_class(params)
    return model.to(device)


def _grid_search_final_metrics(model, model_tag, val_loader, device):
    """Return (final validation loss, final validation accuracy) for a trained model."""
    if model_tag == "ConfigurableSAE":
        reconstr_loss = model.evaluate_reconstruction(val_loader, device)
        code_loss, code_accuracy = model.evaluate_code(val_loader, device)
        return code_loss + reconstr_loss, code_accuracy
    return model.evaluate(val_loader, device)


def grid_search_cv(train_data, val_data, param_grid, model_class, search_key, reps=5, metric="loss", device='cuda', save_path="./"):
    """
    Perform grid search to find the best hyperparameters for a given model class.

    Every combination of the values in `param_grid` is trained `reps` times.
    Each trained model is saved under `<save_path>/<search_key>/`, together
    with the mean/std validation-loss curves per combination and a CSV of all
    results.

    Args:
        train_data, val_data: Datasets wrapped in DataLoaders (batch size from
            params['batch_size'], default 32).
        param_grid: Dict mapping parameter name to a list of candidate values.
        model_class: One of ConfigurableANN, ConfigurableSAE,
            ConfigurableSAEANN, ConfigurableVAE.
        search_key: Name of the searched hyperparameter; used for the output
            folder, file names and a result column ('code_length' is read from
            params['encoder_PE_nodes'][-1]).
        reps: Number of independent trainings per combination (>= 1).
        metric: "loss" (lower mean final loss wins) or "accuracy" (higher mean
            final accuracy wins).
        device: Device passed to the models.
        save_path: Root directory for all outputs.

    Returns:
        (results_df, best_model, best_params). `best_model` is the model from
        the last repetition of the best-scoring combination.

    Raises:
        ValueError: If `reps` < 1, `metric` is unsupported, or `model_class`
            is unsupported.
    """
    # Validate arguments before touching the file system.
    if reps < 1:
        raise ValueError(f"reps must be at least 1, got {reps}")
    if metric == "loss":
        best_score = float('inf')
        score_compare = lambda new, best: new < best
    elif metric == "accuracy":
        best_score = float('-inf')
        score_compare = lambda new, best: new > best
    else:
        raise ValueError(f"Unsupported metric '{metric}'")

    param_keys = list(param_grid.keys())
    param_values = [param_grid[key] for key in param_keys]
    param_combinations = [dict(zip(param_keys, v)) for v in product(*param_values)]
    results = []
    best_model = None
    best_params = None

    # Create Save Path
    folder_path = os.path.join(save_path, search_key)
    os.makedirs(folder_path, exist_ok=True)

    model_tag = _grid_search_model_tag(model_class)

    def get_hyperparam_value(params):
        # 'code_length' is not a config key; it is the width of the last encoder layer.
        if search_key == 'code_length':
            return params['encoder_PE_nodes'][-1]
        return params.get(search_key)

    for params in param_combinations:
        print(f"Testing parameters: {params}")
        # Prepare DataLoaders
        batch_size = params.get('batch_size', 32)
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
        # Short hash of the combination, shared by all files written for it
        param_hash = hashlib.md5(str(params).encode()).hexdigest()[:8]

        reps_val_losses = []   # validation-loss curve of every repetition
        fin_losses = []        # final evaluation loss of every repetition
        fin_accuracies = []    # final evaluation accuracy of every repetition
        model = None
        for i in range(reps):
            model = _grid_search_build_model(model_class, model_tag, params, device)
            # Train and evaluate model
            train_losses, val_losses = model.fit(train_loader, val_loader, params, device)
            val_loss, val_accuracy = _grid_search_final_metrics(model, model_tag, val_loader, device)
            reps_val_losses.append(val_losses)
            fin_losses.append(val_loss)
            fin_accuracies.append(val_accuracy)

            # File names carry the last training-curve loss, except for the
            # SAE-ANN which has always used its final evaluation loss.
            filename_loss = val_loss if model_tag == "ConfigurableSAEANN" else val_losses[-1]
            model_filename = os.path.join(
                folder_path,
                f"{model_tag}_{search_key}_{i}_{filename_loss:.4f}_{param_hash}.pth"
            )
            # Save model weights (best effort: a failed save must not abort the search)
            try:
                torch.save({'model_state_dict': model.state_dict(),'config': params}, model_filename)
                print(f"Model saved to {model_filename}")
            except Exception as e:
                print(f"Error saving model: {e}")

            # The SAE rows record the last training-curve loss; all other
            # models record the final evaluation loss.
            recorded_loss = val_losses[-1] if model_tag == "ConfigurableSAE" else val_loss
            results.append({
                **params,
                search_key: get_hyperparam_value(params),
                'repetition': i,
                'val_loss': recorded_loss,
                'val_accuracy': val_accuracy,
                'model_filename': model_filename
            })

        # Early stopping gives curves of different lengths; pad with NaN so
        # the per-epoch statistics ignore the missing epochs.
        max_length = max(len(vl) for vl in reps_val_losses)
        reps_val_losses_padded = [np.pad(vl, (0, max_length - len(vl)), 'constant', constant_values=np.nan) for vl in reps_val_losses]
        reps_val_losses_array = np.array(reps_val_losses_padded)
        avg_val_losses = np.nanmean(reps_val_losses_array, axis=0)
        std_val_losses = np.nanstd(reps_val_losses_array, axis=0)
        # Save the aggregated validation-loss curves
        np.save(os.path.join(folder_path, f"{model_tag}_avg_val_losses_{search_key}_{param_hash}.npy"), avg_val_losses)
        np.save(os.path.join(folder_path, f"{model_tag}_std_val_losses_{search_key}_{param_hash}.npy"), std_val_losses)

        # Score with the quantity that matches the comparison direction.
        score = np.mean(fin_losses) if metric == "loss" else np.mean(fin_accuracies)
        if score_compare(score, best_score):
            best_score = score
            best_model = model
            best_params = params

    # Save results
    results_df = pd.DataFrame(results)
    results_df.to_csv(os.path.join(folder_path, f"{model_tag}_results_{search_key}.csv"), index=False)
    return results_df, best_model, best_params

## Model Information Query & Evaluation

In [6]:
def query_best_hyperparameters(results_df, model_class, target_metric=None):
    """
    Query the dataframe for the best performing hyperparameters.

    Args:
        results_df: DataFrame of grid-search results, one row per trained model.
        model_class: Model class the results belong to. It is only used to pick
            the default metric: 'val_loss' for ConfigurableSAE, 'val_accuracy'
            for every other class.
        target_metric: Name of the column to optimise. When omitted, the default
            described above is used.

    Returns:
        dict: The best row of ``results_df`` as ``{column: value}``.

    Raises:
        ValueError: If the metric column is missing or holds no usable numbers.
    """
    # Set default target metric based on model class
    if target_metric is None:
        target_metric = 'val_loss' if model_class == ConfigurableSAE else 'val_accuracy'

    if target_metric not in results_df or results_df[target_metric].isnull().all():
        raise ValueError(f"Target metric '{target_metric}' is not available or contains no valid data.")

    # Coerce to numbers so None / non-numeric cells become NaN and are ignored
    # instead of breaking idxmin/idxmax on an object column.
    metric_values = pd.to_numeric(results_df[target_metric], errors='coerce').dropna()
    if metric_values.empty:
        raise ValueError(f"No valid results available for metric '{target_metric}'.")

    # The optimisation direction follows the metric, not the model class:
    # losses are minimised, everything else (e.g. accuracy) is maximised.
    # Tying it to the class returned the *worst* row whenever a caller asked
    # for 'val_loss' on a classifier or 'val_accuracy' on an SAE.
    minimise = 'loss' in str(target_metric).lower()
    best_index = metric_values.idxmin() if minimise else metric_values.idxmax()
    return results_df.loc[best_index].to_dict()

def filter_hyperparameters(results_df, **criteria):
    """
    Return all entries that match the specified hyperparameters.

    Args:
        results_df: DataFrame of grid-search results.
        **criteria: ``column=value`` pairs; a row is kept only if it matches
            every pair. ``value`` may be a scalar or a container such as a list
            of layer sizes.

    Returns:
        DataFrame: The matching rows (``results_df`` itself when no criteria
        are given).

    Raises:
        KeyError: If a criterion names a column that is not in ``results_df``.
    """
    filtered_df = results_df
    for key, value in criteria.items():
        if key not in filtered_df.columns:
            raise KeyError(f"Column '{key}' not found in results.")
        column = filtered_df[key]

        if isinstance(value, (list, tuple, dict, set)):
            # `Series == list` is an element-wise comparison against the rows
            # (and raises on a length mismatch), so container-valued
            # hyperparameters have to be compared cell by cell.
            value_as_text = str(value)

            def cell_matches(cell):
                # plot_hyperparameters-style stringified cells (or values read
                # back from CSV) are matched against the container's repr.
                if isinstance(cell, str):
                    return cell == value_as_text
                try:
                    return bool(cell == value)
                except (TypeError, ValueError):
                    return False

            mask = pd.Series(
                [cell_matches(cell) for cell in column],
                index=column.index,
                dtype=bool,
            )
        else:
            mask = column == value

        filtered_df = filtered_df[mask]
    return filtered_df

def plot_hyperparameters(results_df, hyperparameter, metric='val_loss', save_path=None):
    """
    Plot the evaluation metric vs hyperparameters using Matplotlib,
    including error bars representing the standard deviation.

    Args:
        results_df: DataFrame of grid-search results. It is not modified.
        hyperparameter: Column whose values form the x-axis (rows are grouped
            by this column).
        metric: Column whose per-group mean and standard deviation are plotted.
        save_path: Optional file path for the figure; missing parent folders
            are created.
    """
    # Build the grouping keys as a separate Series. List-valued hyperparameters
    # are unhashable and must be stringified for groupby, but doing that on the
    # DataFrame column itself would silently rewrite the caller's data.
    group_keys = results_df[hyperparameter]
    if group_keys.apply(lambda x: isinstance(x, list)).any():
        group_keys = group_keys.apply(lambda x: str(x))

    # Group the metric by the hyperparameter and summarise each group
    grouped = results_df[metric].groupby(group_keys)
    mean_metrics = grouped.mean()
    std_metrics = grouped.std()  # NaN for groups with a single repetition

    X = mean_metrics.index
    Y = mean_metrics.values
    Yerr = std_metrics.values

    plt.figure(figsize=(12, 6))

    # If hyperparameter values are strings (non-numeric), plot them at integer
    # positions and use the original values as tick labels
    if not np.issubdtype(np.array(X).dtype, np.number):
        X_labels = X
        X = np.arange(len(X))
        plt.xticks(X, X_labels, rotation='vertical')

    # Plot the mean metric with error bars for the standard deviation
    plt.errorbar(X, Y, yerr=Yerr, fmt='o-', capsize=5, elinewidth=2, markeredgewidth=2)

    plt.xlabel(hyperparameter)
    plt.ylabel(metric)
    if hyperparameter in ['learning_rate', '_lambda']:
        plt.xscale('log')
    plt.title(f"{metric.capitalize()} vs {hyperparameter}")
    plt.grid(True)
    plt.tight_layout()  # Adjust layout to prevent label cutoff
    if save_path:
        # savefig does not create folders, so make sure the target exists
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        plt.savefig(save_path)
    plt.show()

def load_model(model_class, filename, device='cuda', **model_kwargs):
    """
    Load a model from a file.

    Args:
        model_class: Class to instantiate before loading the weights.
        filename: Path to a file written with ``torch.save``. Both a bare
            ``state_dict`` and a checkpoint dict holding the weights under
            ``'model_state_dict'`` are accepted.
        device: Device the weights are mapped to and the model is moved to.
        **model_kwargs: Keyword arguments forwarded to ``model_class(...)``;
            none are passed by default.

    Returns:
        The instantiated model with the stored weights, on ``device``.
    """
    checkpoint = torch.load(filename, map_location=device)

    # The grid search stores {'model_state_dict': ..., 'config': ...}; handing
    # that wrapper straight to load_state_dict fails, so unwrap it first.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    model = model_class(**model_kwargs)
    model.load_state_dict(state_dict)
    model.to(device)
    return model

## Model Information Generation

In [7]:
def generate_image_from_code(model, code, device='cuda'):
    """
    Decode one code vector into a 28x28 image using the SAE model.

    The code is converted to a float32 tensor, moved to ``device`` and shaped
    into a single-row batch before being passed to ``model.code_forward``.
    The decoded output is returned as a NumPy array of shape (28, 28).
    """
    model.to(device)
    # Single-sample batch: (1, code_size)
    single_batch = torch.tensor(code, dtype=torch.float32).to(device).view(1, -1)
    decoded = model.code_forward(single_batch)
    pixels = decoded.cpu().detach().numpy()
    return pixels.reshape(28, 28)

def generate_eigenimages(model, device='cuda'):
    """
    Decode every one-hot code of the SAE model into an "eigenimage".

    For each of the ``model.code_size`` code units, a code that is 1.0 at that
    unit and 0.0 elsewhere is passed through ``model.code_forward``; the result
    is returned as a list of (28, 28) NumPy arrays, one per unit, in unit order.
    """
    model.to(device)
    n_units = model.code_size

    def decode_unit(unit_index):
        # One-hot code of shape (1, code_size) with only this unit switched on
        one_hot = torch.zeros(1, n_units).to(device)
        one_hot[0, unit_index] = 1.0
        decoded = model.code_forward(one_hot)
        return decoded.cpu().detach().numpy().reshape(28, 28)

    return [decode_unit(unit_index) for unit_index in range(n_units)]

def plot_eigenimages(eigenimages, n_cols=8, save_path=None):
    """
    Plot the eigenimages using Matplotlib.

    Args:
        eigenimages: Sequence of 2-D arrays, drawn in order on a grid.
        n_cols: Number of grid columns.
        save_path: Optional file path for the figure; missing parent folders
            are created.
    """
    # Ceiling division: `len // n_cols + 1` allocated an empty extra row
    # whenever the image count was an exact multiple of n_cols. Keep at least
    # one row so an empty input still yields a valid figure size.
    n_rows = max(1, (len(eigenimages) + n_cols - 1) // n_cols)
    plt.figure(figsize=(16, 2 * n_rows))
    for i, eigenimage in enumerate(eigenimages):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(eigenimage, cmap='gray')
        plt.axis('off')
    if save_path:
        # savefig does not create folders, so make sure the target exists
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        plt.savefig(save_path)
    plt.show()

def generate_class_codes(model, data_loader, device='cuda', n_classes=10):
    """
    Generate class codes from the SAE model. Get 1 sample from each class, encode it, and store the code.

    Args:
        model: Model exposing an ``encoder`` callable.
        data_loader: Iterable of ``(X_batch, Y_batch)`` tensor pairs.
        device: Device the model and batches are moved to.
        n_classes: Number of distinct labels after which the scan stops.

    Returns:
        dict: ``{label: code}`` where ``code`` is the NumPy encoding of the
        first sample seen with that label. Holds fewer than ``n_classes``
        entries if the loader runs out first.
    """
    model.to(device)
    class_codes = {}
    with torch.no_grad():
        for X_batch, Y_batch in data_loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
            X_batch = X_batch.view(X_batch.size(0), -1)  # Flatten for input to the model
            outputs = model.encoder(X_batch)
            for i, code in enumerate(outputs):
                label = Y_batch[i].item()
                if label not in class_codes:
                    class_codes[label] = code.cpu().detach().numpy()
                    if len(class_codes) == n_classes:
                        # Return straight away: a `break` here only left the
                        # inner loop and kept encoding every remaining batch.
                        return class_codes
    # The original ended in a bare `return`, handing callers None.
    return class_codes

def generate_class_images(model, class_codes, device='cuda', variances=None):
    """
    Generate class images from the SAE model using the class codes. Add a 0 mean, varying variance gaussian noise to the code. Decode the noisy code to get the class image.

    Args:
        model: Model exposing ``code_forward``.
        class_codes: ``{label: code}`` mapping; each code may be a NumPy array,
            a tensor or a plain sequence of numbers.
        device: Device the model and codes are moved to.
        variances: Noise scales to apply. Defaults to ten values from 2**-4.5
            to 2**0 in half-power steps. NOTE: each value multiplies standard
            normal noise directly, i.e. it acts as the noise's standard
            deviation even though it is named "variance".

    Returns:
        dict: ``{(label, variance): image}`` with one (28, 28) NumPy image per
        label/scale pair.
    """
    model.to(device)
    if variances is None:
        variances = [2**-4.5, 2**-4, 2**-3.5, 2**-3, 2**-2.5, 2**-2, 2**-1.5, 2**-1, 2**-0.5, 2**0]
    class_images = {}
    with torch.no_grad():
        for label, code in class_codes.items():
            # generate_class_codes stores NumPy arrays, which torch.randn_like
            # rejects; convert here and use a (1, code_size) batch, the same
            # layout generate_image_from_code feeds to code_forward.
            code_tensor = torch.as_tensor(code, dtype=torch.float32).to(device).reshape(1, -1)
            for variance in variances:
                # One draw per (label, variance): the result is stored under a
                # single key, so the former 10-iteration loop only discarded
                # nine decoded samples.
                noisy_code = code_tensor + torch.randn_like(code_tensor) * variance
                image = model.code_forward(noisy_code)
                class_images[(label, variance)] = image.cpu().detach().numpy().reshape(28, 28)
    return class_images

def plot_class_images(class_images, n_cols=10, save_path=None):
    """
    Plot the class images using Matplotlib.

    Args:
        class_images: ``{(label, variance): image}`` mapping; images are drawn
            in the mapping's iteration order and titled with their key.
        n_cols: Number of grid columns.
        save_path: Optional file path for the figure; missing parent folders
            are created.
    """
    # Ceiling division: `len // n_cols + 1` added a blank row whenever the
    # image count was an exact multiple of n_cols. Keep at least one row so an
    # empty input still yields a valid figure size.
    n_rows = max(1, (len(class_images) + n_cols - 1) // n_cols)
    plt.figure(figsize=(16, 2 * n_rows))
    for i, (key, image) in enumerate(class_images.items()):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(image, cmap='gray')
        plt.title(f"Class {key[0]}, Variance {key[1]:.4f}")
        plt.axis('off')
    if save_path:
        # savefig does not create folders, so make sure the target exists
        target_dir = os.path.dirname(save_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        plt.savefig(save_path)
    plt.show()

## Main

In [8]:
def _resolve_model_class(archtype):
    """Return the model class that implements the given architecture name."""
    # Built on each call so the classes only need to exist at call time.
    model_classes = {
        "SAE": ConfigurableSAE,
        "ANN": ConfigurableANN,
        "SAEANN": ConfigurableSAEANN,
        "VAE": ConfigurableVAE,
    }
    if archtype not in model_classes:
        raise ValueError(f"Unsupported archtype '{archtype}'.")
    return model_classes[archtype]


def _build_datasets(val_size=0.35):
    """Load the data and return (train, validation, test) TensorDatasets."""
    tr_X, tr_Y, ts_X, ts_Y = load_and_preprocess_data()
    X_train, Y_train, X_val, Y_val = train_val_split(tr_X, tr_Y, val_size=val_size)

    def as_dataset(features, labels):
        return TensorDataset(torch.tensor(features).float(), torch.tensor(labels).long())

    return as_dataset(X_train, Y_train), as_dataset(X_val, Y_val), as_dataset(ts_X, ts_Y)


def _evaluate_classifier(model, test_loader, device, returns_tuple):
    """Return (mean cross-entropy loss, accuracy) of a classifier over test_loader.

    ``returns_tuple`` is True for models whose forward pass returns a pair
    whose first element holds the class scores (the SAEANN case).
    """
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for X_batch, Y_batch in test_loader:
            X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
            X_batch = X_batch.view(X_batch.size(0), -1)
            outputs = model(X_batch)
            if returns_tuple:
                outputs, _ = outputs
            loss = criterion(outputs, Y_batch)
            # criterion averages over the batch; re-weight by batch size so the
            # final division by `total` gives a per-sample mean.
            loss_sum += loss.item() * X_batch.size(0)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == Y_batch).sum().item()
            total += Y_batch.size(0)
    return loss_sum / total, correct / total


def _evaluate_best_model(archtype, best_model, best_params, test_data, device='cuda'):
    """Evaluate the best model on the test set, print the metrics and return them.

    Returns (test_loss, test_accuracy); test_accuracy is None for the VAE.
    The SAE / VAE branches call the model's own evaluate_* methods without a
    device argument, exactly as the callers did before.
    """
    best_model.eval()
    test_loader = DataLoader(test_data, batch_size=best_params.get('batch_size', 32), shuffle=False)

    if archtype == "ANN":
        test_loss, test_accuracy = _evaluate_classifier(best_model, test_loader, device, returns_tuple=False)
    elif archtype == "SAE":
        # For SAE, use the evaluate methods
        test_loss_rc = best_model.evaluate_reconstruction(test_loader)
        test_loss_cd, test_accuracy = best_model.evaluate_code(test_loader)
        test_loss = test_loss_rc + test_loss_cd
        print(f"Test Reconstruction Loss: {test_loss_rc:.4f}")
        print(f"Test Code Loss: {test_loss_cd:.4f}")
    elif archtype == "SAEANN":
        # For SAEANN, the forward pass returns a pair; the first item is scored
        test_loss, test_accuracy = _evaluate_classifier(best_model, test_loader, device, returns_tuple=True)
    elif archtype == "VAE":
        # For VAE, use appropriate evaluation methods
        test_loss = best_model.evaluate_reconstruction(test_loader)
        test_accuracy = None  # VAE may not have accuracy metric
    else:
        raise ValueError(f"Unsupported archtype '{archtype}'.")

    print(f"Test Loss: {test_loss:.4f}")
    if test_accuracy is not None:
        print(f"Test Accuracy: {test_accuracy:.4f}")
    return test_loss, test_accuracy


def _search_and_evaluate(archtype, searchkey, plot_spec=None, device='cuda'):
    """Run one grid search for (archtype, searchkey) and test the best model.

    ``plot_spec`` is either None (no plot) or a ``(metric, save_path)`` pair
    handed to plot_hyperparameters right after the search.
    Returns the (test_loss, test_accuracy) pair from _evaluate_best_model.
    """
    # Load and preprocess the data, split it and wrap it in TensorDatasets
    train_data, val_data, test_data = _build_datasets()

    model_class = _resolve_model_class(archtype)
    param_grid = ret_search_param_grid(archtype=archtype, searchkey=searchkey)

    # Perform grid search
    results_df, best_model, best_params = grid_search_cv(
        train_data, val_data, param_grid, model_class, searchkey, device=device
    )

    # Plot the results (optional)
    if plot_spec is not None:
        metric, save_path = plot_spec
        plot_hyperparameters(results_df, searchkey, metric=metric, save_path=save_path)

    # Display results
    print("Grid Search Results:")
    print(results_df)

    # Identify the best hyperparameters
    best_hyperparams = query_best_hyperparameters(results_df, model_class)
    print("Best Hyperparameters:", best_hyperparams)

    # Evaluate the best model on the test set
    print("Evaluating best model on the test set...")
    if best_model is None:
        raise RuntimeError("No valid model found. Check grid search results.")
    return _evaluate_best_model(archtype, best_model, best_params, test_data, device=device)


def unit_test():
    """Run the "unit-test" parameter grid once for every architecture.

    For each of ANN, SAE, SAEANN and VAE this performs the grid search, plots
    the SAE (val_loss) and ANN (val_accuracy) results, prints the results and
    the best hyperparameters, and evaluates the best model on the test set.
    """
    searchkey = "unit-test"
    plot_specs = {
        "SAE": ('val_loss', f"{searchkey}/val_loss.png"),
        "ANN": ('val_accuracy', f"{searchkey}/val_accuracy.png"),
    }
    for model_archtype in ["ANN", "SAE", "SAEANN", "VAE"]:
        _search_and_evaluate(model_archtype, searchkey, plot_spec=plot_specs.get(model_archtype))
        # Datasets and models are locals of the helper and are already out of
        # scope here; collect so their memory is reclaimed before the next run.
        gc.collect()


def testing(model_archtypes=("SAE",), keys=("3d_grid",)):
    """Run the grid search and test-set evaluation for the given sweeps.

    Args:
        model_archtypes: Architecture names to run; defaults to the SAE only.
            Pass ["ANN", "SAE", "SAEANN", "VAE"] to cover every architecture.
        keys: Search keys (parameter-grid names) to run for each architecture;
            "unit-test" is always skipped.
    """
    plottable_keys = ["learning_rate", "batch_size", "_lambda", "dropout_rate", "code_length"]
    for model_archtype in model_archtypes:
        for key in keys:
            if key == "unit-test":
                continue
            plot_spec = None
            if key in plottable_keys:
                plot_spec = ('val_loss', f"{key}/{model_archtype}_val_loss.png")
            _search_and_evaluate(model_archtype, key, plot_spec=plot_spec)


In [None]:
# Driver cell: runs the hyperparameter search defined in testing().
# To run the "unit-test" grid for every architecture instead, uncomment
# unit_test() and comment out testing().
#unit_test()
testing()

Testing parameters: {'learning_rate': 0.0009765625, 'batch_size': 32, 'epochs': 2048, 'dropout_rate': 0.0, 'input_size': 784, 'encoder_PE_nodes': [800, 200, 10], 'encoder_PE_activations': ['relu', 'sigmoid', 'relu'], 'decoder_PE_nodes': [200, 800, 784], 'decoder_PE_activations': ['relu', 'sigmoid', 'sigmoid'], 'loss': 'custom', '_lambda': 0.0009765625, 'norm_type': 1}
Epoch 16/2048 - Train Loss: 0.0501, Val Loss: 0.0461
Epoch 32/2048 - Train Loss: 0.0471, Val Loss: 0.0433
Epoch 48/2048 - Train Loss: 0.0455, Val Loss: 0.0430
Epoch 64/2048 - Train Loss: 0.0446, Val Loss: 0.0422
Epoch 80/2048 - Train Loss: 0.0439, Val Loss: 0.0426
Epoch 96/2048 - Train Loss: 0.0434, Val Loss: 0.0426
Early stopping at epoch 97 - Train Loss: 0.0434, Val Loss: 0.0421
Model saved to ./3d_grid/ConfigurableSAE_3d_grid_0_0.0421_deb9ba71.pth
Epoch 16/2048 - Train Loss: 0.0500, Val Loss: 0.0464
Epoch 32/2048 - Train Loss: 0.0469, Val Loss: 0.0438
Epoch 48/2048 - Train Loss: 0.0455, Val Loss: 0.0438
Epoch 64/2048 -

### Main Generating Code