In [None]:
import sys
import os
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import csv
import time
import json
import hashlib
import random
from SAE_model import StackedAutoencoder  # Import the SAE class
from itertools import product

# Seeding helper used to make runs repeatable.
def set_seed(seed):
    """Seed every random number generator this script relies on.

    Seeds, in this order: Python's ``random`` module, NumPy's global RNG,
    PyTorch's CPU generator, the current CUDA device's generator and the
    generators of all CUDA devices.

    Args:
        seed: integer seed passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for seed_fn in seeders:
        seed_fn(seed)


# Base seed applied once when this cell runs.
set_seed(42)

def init_weights(m):
    """Initialise ``m`` in place when it is an ``nn.Linear`` layer.

    Linear weights get Kaiming-uniform initialisation with the
    ``'leaky_relu'`` gain, and biases (if the layer has one) are set to zero.
    Modules of any other type are left untouched, so this function can be
    applied recursively over a whole model.

    NOTE(review): ``a`` (the negative slope used in the gain formula) is not
    passed, so PyTorch's default of 0 is used regardless of the LeakyReLU
    slope the model is configured with.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Per-image preprocessing pipeline.
transform = transforms.Compose(
    [
        transforms.ToTensor(),                          # image -> float tensor in [0, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),  # (x - 0.5) / 0.5 -> [-1, 1]
        transforms.Lambda(lambda img: img.view(-1)),    # flatten to a 1-D vector
    ]
)

# KMNIST, downloaded into ../data on first use.
# NOTE(review): train=False loads the KMNIST *test* split. It serves as the
# validation set for early stopping, LR scheduling and hyperparameter
# selection, so it is no longer an unbiased held-out estimate. Consider
# carving a validation subset out of the training split instead.
train_dataset = datasets.KMNIST('../data', train=True, download=True, transform=transform)
val_dataset = datasets.KMNIST('../data', train=False, download=True, transform=transform)

# Hyperparameter grid for tuning. Every combination of the value lists below is
# trained once: 6 learning rates x 4 layer layouts x 5 dropout settings x
# 4 weight decays x 5 batch sizes x 4 LeakyReLU slopes = 9600 runs. Entries
# holding a single value fix that setting without enlarging the grid.
hyperparameter_grid = dict(
    learning_rate=[1e-5, 5e-5, 1e-4, 5e-4, 1e-3, 5e-3],
    # Hidden-layer widths handed to StackedAutoencoder (the last entry is
    # presumably the bottleneck size -- confirm in SAE_model).
    layer_sizes=[
        [800, 200, 25],
        [800, 200, 50],
        [800, 200, 100],
        [800, 200, 150],
    ],
    # Three identical rates, matching the three entries of each layer_sizes list.
    dropout_rates=[[rate, rate, rate] for rate in (0.1, 0.2, 0.3, 0.4, 0.5)],
    weight_decay=[0.0, 1e-5, 1e-4, 1e-3],
    batch_size=[16, 32, 64, 128, 256],
    leaky_relu_negative_slope=[0.01, 0.05, 0.1, 0.2],
    # Maximum number of training epochs; early stopping may end a run sooner.
    num_epochs=[200],
    # Early stopping
    early_stopping_patience=[5],
    early_stopping_min_delta=[1e-4],
    # ReduceLROnPlateau
    scheduler_mode=['min'],
    scheduler_factor=[0.1],
    scheduler_patience=[3],
    scheduler_threshold=[1e-4],
    scheduler_cooldown=[0],
)

# Expand the grid into one dict per combination. itertools.product varies the
# last key fastest; since the trailing keys are single-valued, consecutive runs
# first differ in leaky_relu_negative_slope.
keys = tuple(hyperparameter_grid.keys())
values = tuple(hyperparameter_grid.values())
hyperparameter_combinations = [dict(zip(keys, combo)) for combo in product(*values)]

csv_filename = 'SAE_hyperparameter_tuning_results.csv'

# Hashes of hyperparameter combinations already recorded in the results CSV,
# so the tuning loop can skip configurations tested in an earlier session.
# When the file does not exist yet nothing has been tested; the tuning loop
# creates the file on its first successful run.
existing_hashes = set()
if os.path.exists(csv_filename):
    with open(csv_filename, 'r', newline='') as csvfile:
        existing_hashes.update(record['hparams_hash'] for record in csv.DictReader(csvfile))

def _train_one_epoch(model, loader, loss_function, optimizer, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss.

    The model is trained to reconstruct its inputs; labels from the loader are
    ignored.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    num_samples = 0
    for inputs, _ in loader:
        inputs = inputs.to(device)

        # Forward pass: reconstruction loss against the input itself.
        outputs = model(inputs)
        loss = loss_function(outputs, inputs)

        # Backward pass and optimisation step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Assumes loss_function averages over the batch (true for nn.MSELoss()
        # with its default reduction='mean'). Weighting by batch size keeps a
        # smaller final batch from counting as much as a full one.
        batch_size = inputs.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    if num_samples == 0:
        raise ValueError("training loader yielded no samples")
    return total_loss / num_samples


def _evaluate(model, loader, loss_function, device):
    """Return the per-sample mean reconstruction loss of ``model`` over ``loader``.

    Runs in eval mode without gradient tracking.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for inputs, _ in loader:
            inputs = inputs.to(device)
            loss = loss_function(model(inputs), inputs)
            batch_size = inputs.size(0)
            total_loss += loss.item() * batch_size
            num_samples += batch_size
    if num_samples == 0:
        raise ValueError("validation loader yielded no samples")
    return total_loss / num_samples


def train_and_evaluate(hparams):
    """Train one StackedAutoencoder configuration and report its validation loss.

    Uses the module-level ``train_dataset``, ``val_dataset``, ``device``,
    ``init_weights`` and ``set_seed``. Training uses AdamW, MSE
    reconstruction loss, ReduceLROnPlateau driven by the validation loss, and
    early stopping on the validation loss.

    Args:
        hparams: dict with the keys learning_rate, layer_sizes, dropout_rates,
            weight_decay, batch_size, leaky_relu_negative_slope,
            early_stopping_patience, early_stopping_min_delta, scheduler_mode,
            scheduler_factor, scheduler_patience, scheduler_threshold and
            scheduler_cooldown. Optional keys: num_epochs (default 50) and
            seed (default 42).

    Returns:
        (val_loss, training_log) where ``val_loss`` is the validation loss of
        the last completed epoch, and ``training_log`` is a dict with:
            final_epoch: number of epochs actually run (1-indexed),
            best_epoch / best_val_loss: the epoch that last improved the
                early-stopping reference by more than min_delta, and its loss
                (0 / inf if no epoch ever improved),
            lr_reduction_epochs: epochs after which the scheduler lowered the LR,
            epoch_logs: per-epoch train/val loss and the LR after that epoch.

    Raises:
        KeyError: if a required hyperparameter is missing.
        ValueError: if num_epochs < 1.
    """
    # Unpack hyperparameters
    learning_rate = hparams['learning_rate']
    layer_sizes = hparams['layer_sizes']
    dropout_rates = hparams['dropout_rates']
    weight_decay = hparams['weight_decay']
    batch_size = hparams['batch_size']
    leaky_relu_negative_slope = hparams['leaky_relu_negative_slope']
    num_epochs = hparams.get('num_epochs', 50)
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    # Early Stopping hyperparameters
    early_stopping_patience = hparams['early_stopping_patience']
    early_stopping_min_delta = hparams['early_stopping_min_delta']

    # Scheduler hyperparameters
    scheduler_mode = hparams['scheduler_mode']
    scheduler_factor = hparams['scheduler_factor']
    scheduler_patience = hparams['scheduler_patience']
    scheduler_threshold = hparams['scheduler_threshold']
    scheduler_cooldown = hparams['scheduler_cooldown']

    # Re-seed per configuration so every run is reproducible on its own,
    # regardless of how many runs preceded it in this session or were skipped
    # when resuming from the results CSV.
    set_seed(hparams.get('seed', 42))

    # One activation module per hidden layer.
    activation_functions = [nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) for _ in layer_sizes]

    model = StackedAutoencoder(
        input_size=784,  # For 28x28 images
        layer_sizes=layer_sizes,
        activation_functions=activation_functions,
        dropout_rates=dropout_rates,
        weight_init=init_weights
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Lower the LR when the validation loss plateaus; 'rel' means an epoch
    # counts as an improvement only if it beats the best by a relative margin.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=scheduler_mode,
        factor=scheduler_factor,
        patience=scheduler_patience,
        threshold=scheduler_threshold,
        threshold_mode='rel',
        cooldown=scheduler_cooldown,
    )

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    loss_function = nn.MSELoss()

    # Early stopping state
    best_val_loss = float('inf')
    best_epoch = 0
    epochs_no_improve = 0

    # Logging state
    current_lr = optimizer.param_groups[0]['lr']
    lr_reduction_epochs = []
    epoch_logs = []

    for epoch in range(1, num_epochs + 1):  # epochs are 1-indexed
        train_loss = _train_one_epoch(model, train_loader, loss_function, optimizer, device)
        val_loss = _evaluate(model, val_loader, loss_function, device)

        scheduler.step(val_loss)

        # Detect whether the scheduler just lowered the learning rate.
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr < current_lr:
            lr_reduction_epochs.append(epoch)
            current_lr = new_lr

        epoch_logs.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'learning_rate': current_lr
        })

        # Early stopping: an epoch counts as an improvement only if it beats
        # the reference by more than min_delta.
        if val_loss < best_val_loss - early_stopping_min_delta:
            best_val_loss = val_loss
            best_epoch = epoch
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        print(f"Epoch [{epoch}/{num_epochs}], "
              f"Training Loss: {train_loss:.4f}, "
              f"Validation Loss: {val_loss:.4f}")

        # Stop right after the epoch that exhausted the patience, so the
        # reported epoch is the last one that actually ran.
        if epochs_no_improve >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch}")
            break

    training_log = {
        'final_epoch': len(epoch_logs),
        'best_epoch': best_epoch,
        'best_val_loss': best_val_loss,
        'lr_reduction_epochs': lr_reduction_epochs,
        'epoch_logs': epoch_logs
    }

    # Return final validation loss and training log
    return val_loss, training_log

# Main loop for hyperparameter tuning.
#
# Each combination is identified by the MD5 of its JSON encoding. sort_keys=True
# makes the hash independent of dict ordering, so it matches hashes stored in
# earlier sessions. Combinations whose hash is already in the results CSV are
# skipped, which lets an interrupted search resume. Every successful run
# appends one CSV row and writes its per-epoch log to training_log_<hash>.json.
import math
import traceback

results = []
num_combinations = len(hyperparameter_combinations)

for idx, hparams in enumerate(hyperparameter_combinations):
    hparams_str = json.dumps(hparams, sort_keys=True)
    hparams_hash = hashlib.md5(hparams_str.encode('utf-8')).hexdigest()

    if hparams_hash in existing_hashes:
        print(f"Skipping hyperparameter combination {idx+1}/{num_combinations}: {hparams} (already tested)")
        continue

    print(f"Testing hyperparameter combination {idx+1}/{num_combinations}: {hparams}")
    try:
        # perf_counter is monotonic, so the duration is immune to clock changes.
        start_time = time.perf_counter()
        val_loss, training_log = train_and_evaluate(hparams)
        computation_time = time.perf_counter() - start_time

        # Save the per-epoch logs to a file
        log_filename = f"training_log_{hparams_hash}.json"
        with open(log_filename, 'w') as f:
            json.dump(training_log, f)

        # Append the result to the CSV file
        with open(csv_filename, 'a', newline='') as csvfile:
            fieldnames = ['hparams_hash'] + list(hparams.keys()) + ['val_loss', 'computation_time', 'final_num_epochs', 'lr_reduction_epochs', 'log_filename']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            # In append mode the position starts at end-of-file, so 0 means a new/empty file.
            if csvfile.tell() == 0:
                writer.writeheader()
            result_row = {'hparams_hash': hparams_hash}
            # JSON-encode hyperparameters so lists survive the CSV round trip.
            result_row.update({key: json.dumps(value) for key, value in hparams.items()})
            result_row.update({
                'val_loss': val_loss,
                'computation_time': computation_time,
                'final_num_epochs': training_log['final_epoch'],
                'lr_reduction_epochs': json.dumps(training_log['lr_reduction_epochs']),
                'log_filename': log_filename
            })
            writer.writerow(result_row)

        existing_hashes.add(hparams_hash)
        results.append({'hparams': hparams, 'val_loss': val_loss})
        print(f"Result - Validation Loss: {val_loss:.4f}\n")
    except Exception as e:
        # Best effort: one failing configuration must not abort the whole
        # search, but keep the traceback so the failure can be diagnosed.
        print(f"Error with hyperparameters {hparams}: {e}\n")
        traceback.print_exc()
        continue

# Choose the best configuration over every run recorded in the CSV, not only
# this session's runs: when resuming, previously tested combinations are
# skipped above and would otherwise be left out of the comparison.
all_results = list(results)
if os.path.exists(csv_filename):
    all_results = []
    with open(csv_filename, 'r', newline='') as csvfile:
        for row in csv.DictReader(csvfile):
            try:
                row_hparams = {key: json.loads(row[key]) for key in hyperparameter_grid}
                row_val_loss = float(row['val_loss'])
            except (KeyError, TypeError, ValueError):
                # Row written under a different grid layout, or malformed.
                continue
            all_results.append({'hparams': row_hparams, 'val_loss': row_val_loss})

# Diverged runs can report NaN/inf. Nothing compares smaller than NaN, so
# without this filter min() would keep a NaN result that happened to come first.
finite_results = [r for r in all_results if math.isfinite(r['val_loss'])]

if finite_results:
    best_result = min(finite_results, key=lambda r: r['val_loss'])
    print("Best hyperparameters based on validation loss:")
    print(best_result['hparams'])
    print(f"Validation Loss: {best_result['val_loss']:.4f}")
else:
    print("No successful runs to report.")


Testing hyperparameter combination 1/9600: {'learning_rate': 1e-05, 'layer_sizes': [800, 200, 25], 'dropout_rates': [0.1, 0.1, 0.1], 'weight_decay': 0.0, 'batch_size': 16, 'leaky_relu_negative_slope': 0.01, 'num_epochs': 200, 'early_stopping_patience': 5, 'early_stopping_min_delta': 0.0001, 'scheduler_mode': 'min', 'scheduler_factor': 0.1, 'scheduler_patience': 3, 'scheduler_threshold': 0.0001, 'scheduler_cooldown': 0}
Epoch [1/200], Training Loss: 0.8338, Validation Loss: 0.6386
Epoch [2/200], Training Loss: 0.5265, Validation Loss: 0.4198
Epoch [3/200], Training Loss: 0.4698, Validation Loss: 0.4094
Epoch [4/200], Training Loss: 0.4600, Validation Loss: 0.3931
Epoch [5/200], Training Loss: 0.4406, Validation Loss: 0.3773
Epoch [6/200], Training Loss: 0.4311, Validation Loss: 0.3726
Epoch [7/200], Training Loss: 0.4273, Validation Loss: 0.3690
Epoch [8/200], Training Loss: 0.4248, Validation Loss: 0.3668
Epoch [9/200], Training Loss: 0.4232, Validation Loss: 0.3652
Epoch [10/200], Tra