In [None]:
import sys
import os
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import csv
import time
import json
import hashlib
import random
from MLP_model import MLPClassifier  # Import the MLPClassifier class
from itertools import product

# Seed every random number generator the script relies on so runs are repeatable.
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and CUDA).

    Args:
        seed: integer seed passed unchanged to every generator below.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,      # current CUDA device
        torch.cuda.manual_seed_all,  # every CUDA device (multi-GPU)
    )
    for seeder in seeders:
        seeder(seed)


# Seed once at start-up with the script's base seed.
set_seed(42)

def init_weights(m):
    """Initialise an ``nn.Linear`` layer in place; ignore every other module.

    Weights get Kaiming-uniform initialisation for a leaky-ReLU nonlinearity and
    biases (when present) are zeroed. Meant to be used as a per-module hook.

    NOTE(review): ``a`` (the negative slope) is left at its default of 0, so the
    gain matches plain ReLU rather than the LeakyReLU slope being tuned.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Use CUDA when it is available; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Per-sample preprocessing:
#   1. PIL image -> float tensor in [0, 1]
#   2. (x - 0.5) / 0.5 -> values in [-1, 1]  (adjust if necessary)
#   3. flatten to a 1-D vector for the MLP's input layer
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
        transforms.Lambda(lambda img: img.view(-1)),
    ]
)

# KMNIST is downloaded into ../data if it is not already there.
# NOTE(review): the official *test* split doubles as the validation set used to
# select hyperparameters, so the chosen configuration's score on it is
# optimistically biased; a held-out slice of the training set would avoid that.
train_dataset = datasets.KMNIST(
    root='../data', train=True, download=True, transform=transform
)
val_dataset = datasets.KMNIST(
    root='../data', train=False, download=True, transform=transform
)

# Model input/output dimensions.
input_size = 28 * 28  # flattened 28x28 image (784)
output_size = 10      # number of classes

# Search space for the grid search. Each key maps to the values to try;
# single-element lists pin a setting without searching over it.
hyperparameter_grid = {
    # Optimiser / architecture
    'learning_rate': [1e-3, 1e-4, 1e-5],
    'layer_sizes': [[128, 64], [256, 128], [512, 256]],
    'dropout_rates': [[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4], [0.5, 0.5]],
    'batch_norm': [[True, True], [False, False]],
    'weight_decay': [0.0, 1e-5, 1e-4, 1e-3],
    'batch_size': [16, 32, 64, 128],
    'leaky_relu_negative_slope': [0.01, 0.1],  # LeakyReLU negative slope
    'num_epochs': [200],  # upper bound on epochs; early stopping may end a run sooner
    # ReduceLROnPlateau settings
    'scheduler_mode': ['min'],         # the monitored metric should decrease
    'scheduler_factor': [0.1],         # new_lr = lr * factor on a plateau
    'scheduler_patience': [3],         # non-improving epochs tolerated before reducing
    'scheduler_threshold': [1e-4],     # minimum change that counts as a new optimum
    'scheduler_cooldown': [0],         # epochs to wait after a reduction
    # Early-stopping settings
    'early_stopping_mode': ['min'],    # lower validation loss is better
    'early_stopping_patience': [5],    # non-improving epochs tolerated before stopping
    'early_stopping_delta': [1e-4],    # minimum change that counts as improvement
}

# Expand the grid into its Cartesian product: one dict per configuration, keys
# in declaration order, with the right-most key varying fastest.
keys = tuple(hyperparameter_grid.keys())
values = tuple(hyperparameter_grid.values())
hyperparameter_combinations = [
    dict(zip(keys, combination)) for combination in product(*values)
]

csv_filename = 'MLP_hyperparameter_tuning_results.csv'

# Initialize an empty set to store hashes of existing results
existing_hashes = set()

# Initialize best_val_loss to infinity
best_val_loss = float('inf')

if os.path.exists(csv_filename):
    with open(csv_filename, 'r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            hparams_hash = row['hparams_hash']
            existing_hashes.add(hparams_hash)
            try:
                current_val_loss = float(row['val_loss'])
                if current_val_loss < best_val_loss:
                    best_val_loss = current_val_loss
            except ValueError:
                # Handle cases where val_loss might not be a valid float
                continue
    print(f"Existing best validation loss from CSV: {best_val_loss:.4f}")
else:
    print("CSV file does not exist. Starting fresh.")
    # If the file does not exist, we'll create it later

def _train_one_epoch(model, loader, loss_function, optimizer):
    """Run one optimisation pass over *loader*; return ``(mean loss, accuracy %)``.

    The loss is the unweighted mean of per-batch mean losses, so a smaller final
    batch counts as much as a full one (kept for comparability with logged runs).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    seen = 0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = loss_function(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        predicted = outputs.detach().argmax(dim=1)
        seen += labels.size(0)
        correct += (predicted == labels).sum().item()
    return running_loss / len(loader), 100 * correct / seen


def _evaluate(model, loader, loss_function):
    """Score *model* on *loader* without gradients; return ``(mean loss, accuracy %)``.

    The loss is averaged the same way as in ``_train_one_epoch``.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            running_loss += loss_function(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            seen += labels.size(0)
            correct += (predicted == labels).sum().item()
    return running_loss / len(loader), 100 * correct / seen


def train_and_evaluate(hparams):
    """Train one MLP configuration on KMNIST and evaluate it every epoch.

    Training uses AdamW, ReduceLROnPlateau on the validation loss, and early
    stopping on the validation loss.

    Args:
        hparams: one entry of ``hyperparameter_combinations``. It must provide
            learning_rate, layer_sizes, dropout_rates, batch_norm, weight_decay,
            batch_size, leaky_relu_negative_slope, the scheduler_* keys and the
            early_stopping_* keys. num_epochs is optional and defaults to 50.

    Returns:
        ``(val_loss, val_accuracy, training_log, model, activation_functions,
        optimizer, scheduler)``.
        val_loss and val_accuracy come from the *last* trained epoch; the model
        is not rolled back to its best epoch. ``training_log['final_epoch']``
        equals the number of epochs actually trained.

    Raises:
        ValueError: num_epochs < 1, or an unsupported early_stopping_mode.
    """
    learning_rate = hparams['learning_rate']
    layer_sizes = hparams['layer_sizes']
    dropout_rates = hparams['dropout_rates']
    batch_norm = hparams['batch_norm']
    weight_decay = hparams['weight_decay']
    batch_size = hparams['batch_size']
    leaky_relu_negative_slope = hparams['leaky_relu_negative_slope']
    num_epochs = hparams.get('num_epochs', 50)
    if num_epochs < 1:
        # Previously this surfaced as an UnboundLocalError on `epoch`/`val_loss`.
        raise ValueError(f"num_epochs must be >= 1, got {num_epochs}")

    scheduler_mode = hparams['scheduler_mode']
    scheduler_factor = hparams['scheduler_factor']
    scheduler_patience = hparams['scheduler_patience']
    scheduler_threshold = hparams['scheduler_threshold']
    scheduler_cooldown = hparams['scheduler_cooldown']

    early_stopping_mode = hparams['early_stopping_mode']
    early_stopping_patience = hparams['early_stopping_patience']
    early_stopping_delta = hparams['early_stopping_delta']

    # One LeakyReLU (with the tuned slope) per hidden layer.
    activation_functions = [nn.LeakyReLU(negative_slope=leaky_relu_negative_slope) for _ in layer_sizes]

    model = MLPClassifier(
        input_size=input_size,
        layer_sizes=layer_sizes,
        output_size=output_size,
        activation_functions=activation_functions,
        dropout_rates=dropout_rates,
        batch_norm=batch_norm,
        weight_init=init_weights
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=scheduler_mode,
        factor=scheduler_factor,
        patience=scheduler_patience,
        threshold=scheduler_threshold,
        threshold_mode='rel',
        cooldown=scheduler_cooldown,
    )

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    loss_function = nn.CrossEntropyLoss()

    # Early stopping always monitors the validation loss; the mode decides the
    # direction of improvement. NOTE(review): 'max' only makes sense for a
    # higher-is-better metric, so it is unlikely to be wanted with a loss.
    if early_stopping_mode == 'min':
        best_metric = float('inf')

        def is_improvement(current, best):
            return current < best - early_stopping_delta
    elif early_stopping_mode == 'max':
        best_metric = float('-inf')

        def is_improvement(current, best):
            return current > best + early_stopping_delta
    else:
        raise ValueError(f"Unsupported early_stopping_mode: {early_stopping_mode}")

    current_lr = optimizer.param_groups[0]['lr']
    lr_reduction_epochs = []
    epoch_logs = []
    epochs_no_improve = 0

    for epoch in range(1, num_epochs + 1):  # 1-indexed for logging
        train_loss, train_accuracy = _train_one_epoch(model, train_loader, loss_function, optimizer)
        val_loss, val_accuracy = _evaluate(model, val_loader, loss_function)

        # ReduceLROnPlateau needs the monitored metric.
        scheduler.step(val_loss)

        # Record the epochs at which the scheduler lowered the learning rate.
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr < current_lr:
            lr_reduction_epochs.append(epoch)
            current_lr = new_lr

        epoch_logs.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'learning_rate': current_lr
        })

        if is_improvement(val_loss, best_metric):
            best_metric = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        print(f"Epoch [{epoch}/{num_epochs}], "
              f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%, "
              f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

        # Stop right here, rather than at the top of the next iteration, so the
        # reported epoch and final_epoch are the last epoch actually trained.
        if epochs_no_improve >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch}")
            break

    training_log = {
        'final_epoch': len(epoch_logs),
        'lr_reduction_epochs': lr_reduction_epochs,
        'epoch_logs': epoch_logs
    }

    return val_loss, val_accuracy, training_log, model, activation_functions, optimizer, scheduler

# Main loop for hyperparameter tuning: train every configuration not yet in
# the CSV, append its result, and keep the checkpoint with the lowest final
# validation loss so far (best_val_loss includes values loaded from the CSV).
import traceback

results = []  # configurations trained in this session, for the final report
best_model_filename = 'MLP_best_model.pth'
log_dir = 'misc'
# Per-run JSON logs are written under log_dir. Create it before any training:
# otherwise open() raises FileNotFoundError only *after* a run has finished,
# the except below swallows it, and that run's result never reaches the CSV.
os.makedirs(log_dir, exist_ok=True)

for idx, hparams in enumerate(hyperparameter_combinations):
    # Re-seed so every configuration starts from the same RNG state.
    set_seed(42)
    # Stable id for the configuration: MD5 of its key-sorted JSON form.
    hparams_str = json.dumps(hparams, sort_keys=True)
    hparams_hash = hashlib.md5(hparams_str.encode('utf-8')).hexdigest()
    progress = f"{idx+1}/{len(hyperparameter_combinations)}"

    if hparams_hash in existing_hashes:
        print(f"Skipping hyperparameter combination {progress}: {hparams} (already tested)")
        continue

    print(f"Testing hyperparameter combination {progress}: {hparams}")
    try:
        start_time = time.time()
        (val_loss, val_accuracy, training_log, model,
         activation_functions, optimizer, scheduler) = train_and_evaluate(hparams)
        computation_time = time.time() - start_time

        # Full per-epoch history goes to a JSON file; the CSV keeps a summary.
        log_filename = f"{log_dir}/training_log_{hparams_hash}.json"
        with open(log_filename, 'w') as f:
            json.dump(training_log, f)

        with open(csv_filename, 'a', newline='') as csvfile:
            fieldnames = ['hparams_hash'] + list(hparams.keys()) + ['val_loss', 'val_accuracy', 'computation_time', 'final_num_epochs', 'lr_reduction_epochs', 'log_filename']
            writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
            # In append mode the position starts at the end, so 0 means a new file.
            if csvfile.tell() == 0:
                writer.writeheader()
            result_row = {'hparams_hash': hparams_hash}
            # JSON-encode hyperparameter values so lists survive the round trip.
            result_row.update({key: json.dumps(value) for key, value in hparams.items()})
            result_row.update({
                'val_loss': val_loss,
                'val_accuracy': val_accuracy,
                'computation_time': computation_time,
                'final_num_epochs': training_log['final_epoch'],
                'lr_reduction_epochs': json.dumps(training_log['lr_reduction_epochs']),
                'log_filename': log_filename
            })
            writer.writerow(result_row)

        existing_hashes.add(hparams_hash)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_data = {
                'model_state_dict': model.state_dict(),
                'model_config': {
                    'input_size': input_size,
                    'output_size': output_size,
                    'layer_sizes': hparams['layer_sizes'],
                    'activation_functions': [str(type(act)) for act in activation_functions],
                    'dropout_rates': hparams['dropout_rates'],
                    'batch_norm': hparams['batch_norm'],
                },
                'optimizer_state_dict': optimizer.state_dict(),  # for resuming training
                'scheduler_state_dict': scheduler.state_dict(),
                'hyperparameters': hparams,
                'training_log': training_log
            }
            torch.save(save_data, best_model_filename)
            print(f"New best model saved to {best_model_filename} with validation loss: {val_loss:.4f}")

        results.append({'hparams': hparams, 'val_loss': val_loss, 'val_accuracy': val_accuracy})
        print(f"Result - Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%\n")
    except Exception as e:
        # Deliberately best-effort: one failing configuration must not abort
        # the whole sweep, but print the traceback so the failure is debuggable.
        print(f"Error with hyperparameters {hparams}: {e}\n")
        traceback.print_exc()
        continue

# Report the best configuration trained in *this* session, i.e. the one with
# the lowest final validation loss. Rows loaded from an earlier CSV are not in
# `results`, so they are not considered here.
if not results:
    print("No results to report.")
else:
    best_result = min(results, key=lambda entry: entry['val_loss'])
    print("Best hyperparameters based on validation loss:")
    print(best_result['hparams'])
    print(
        f"Validation Loss: {best_result['val_loss']:.4f}, "
        f"Validation Accuracy: {best_result['val_accuracy']:.2f}%"
    )


Existing best validation loss from CSV: 0.3055
Skipping hyperparameter combination 1/2880: {'learning_rate': 0.001, 'layer_sizes': [128, 64], 'dropout_rates': [0.1, 0.1], 'batch_norm': [True, True], 'weight_decay': 0.0, 'batch_size': 16, 'leaky_relu_negative_slope': 0.01, 'num_epochs': 200, 'scheduler_mode': 'min', 'scheduler_factor': 0.1, 'scheduler_patience': 3, 'scheduler_threshold': 0.0001, 'scheduler_cooldown': 0, 'early_stopping_mode': 'min', 'early_stopping_patience': 5, 'early_stopping_delta': 0.0001} (already tested)
Skipping hyperparameter combination 2/2880: {'learning_rate': 0.001, 'layer_sizes': [128, 64], 'dropout_rates': [0.1, 0.1], 'batch_norm': [True, True], 'weight_decay': 0.0, 'batch_size': 16, 'leaky_relu_negative_slope': 0.1, 'num_epochs': 200, 'scheduler_mode': 'min', 'scheduler_factor': 0.1, 'scheduler_patience': 3, 'scheduler_threshold': 0.0001, 'scheduler_cooldown': 0, 'early_stopping_mode': 'min', 'early_stopping_patience': 5, 'early_stopping_delta': 0.0001} (