In [1]:
import sys
import os
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
from torchvision import datasets, transforms
import csv
import time
import json
import hashlib
import random
from SAE_model import StackedAutoencoder  # Import the SAE class
from itertools import product

# Function to set all random seeds for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # If you are using multi-GPU.

# Set the base seed
set_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Data transformations
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # Normalize to [-1, 1]
    transforms.Lambda(lambda x: x.view(-1))  # Flatten the image to a vector
])

# Load the KMNIST dataset
# Load the full training dataset
full_train_dataset = datasets.KMNIST(root='../data', train=True, transform=transform, download=True)

# Define the sizes for the training and validation datasets
val_size = 10000
train_size = len(full_train_dataset) - val_size

# Split the training data into train and validation sets
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

# Load the saved autoencoder model
model_checkpoint = torch.load('SAE_best_model.pth', map_location=device)

# Get the model configuration
model_config = model_checkpoint['config']
input_size = model_config['input_size']
layer_sizes = model_config['layer_sizes']
dropout_rates = model_config['dropout_rates']
activation_functions = [nn.LeakyReLU(negative_slope=0.01) for _ in layer_sizes]  # Assuming LeakyReLU with negative_slope=0.01

# Rebuild the autoencoder model
autoencoder = StackedAutoencoder(
    input_size=input_size,
    layer_sizes=layer_sizes,
    activation_functions=activation_functions,
    dropout_rates=dropout_rates,
    weight_init=None  # We don't need to initialize weights, as we'll load them
).to(device)

# Load the state_dict
autoencoder.load_state_dict(model_checkpoint['state_dict'])

# Set the autoencoder to evaluation mode
autoencoder.eval()

# Get the output size of the encoder
encoder_output_size = layer_sizes[-1]

# Define the classification model
class SAEClassifier(nn.Module):
    def __init__(self, encoder, encoder_output_size, classifier_hidden_sizes, num_classes, leaky_relu_negative_slope=0.01, batch_norm=True):
        super(SAEClassifier, self).__init__()
        # Encoder (pre-trained)
        self.encoder = encoder  # We will freeze this
        for param in self.encoder.parameters():
            param.requires_grad = False
        # Classifier head
        layers = []
        prev_size = encoder_output_size  # The output size of the encoder
        for idx, hidden_size in enumerate(classifier_hidden_sizes):
            layers.append(nn.Linear(prev_size, hidden_size))
            if batch_norm:
                layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.LeakyReLU(negative_slope=leaky_relu_negative_slope))
            prev_size = hidden_size
        # Output layer
        layers.append(nn.Linear(prev_size, num_classes))
        self.classifier = nn.Sequential(*layers)
    def forward(self, x):
        x = self.encoder(x)
        x = self.classifier(x)
        return x

# Hyperparameter grid for tuning
hyperparameter_grid = {
    'learning_rate': [1e-3, 1e-4],
    'classifier_hidden_sizes': [[100], [200], [100, 50]],
    'batch_size': [32, 64],
    'weight_decay': [1e-4, 1e-3],
    'leaky_relu_negative_slope': [0.01],
    'num_epochs': [100],  # Max number of training epochs
    'batch_norm': [True],
    # Early Stopping hyperparameters
    'early_stopping_patience': [5],
    'early_stopping_min_delta': [1e-4],
    # Scheduler hyperparameters (ReduceLROnPlateau)
    'scheduler_mode': ['min'],
    'scheduler_factor': [0.1],
    'scheduler_patience': [3],
    'scheduler_threshold': [1e-4],
    'scheduler_cooldown': [0]
}

# Generate all combinations of hyperparameters
keys, values = zip(*hyperparameter_grid.items())
hyperparameter_combinations = [dict(zip(keys, v)) for v in product(*values)]

csv_filename = 'SAE_classifier_hyperparameter_tuning_results.csv'

# Initialize an empty set to store hashes of existing results
existing_hashes = set()

# Initialize best_val_loss by reading existing CSV
best_val_loss =  float('inf')

if os.path.exists(csv_filename):
    with open(csv_filename, 'r', newline='') as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            hparams_hash = row['hparams_hash']
            existing_hashes.add(hparams_hash)
            try:
                val_loss = float(row['val_loss'])
                if val_loss > best_val_loss:
                    best_val_loss = val_loss
            except ValueError:
                # If val_accuracy is not a float, skip
                continue
    print(f"Current best validation accuracy from CSV: {best_val_loss:.3f}%")
else:
    print("CSV file does not exist. Starting fresh.")

def train_and_evaluate(hparams):
    # Unpack hyperparameters
    learning_rate = hparams['learning_rate']
    classifier_hidden_sizes = hparams['classifier_hidden_sizes']
    batch_size = hparams['batch_size']
    weight_decay = hparams['weight_decay']
    leaky_relu_negative_slope = hparams['leaky_relu_negative_slope']
    num_epochs = hparams.get('num_epochs', 20)
    batch_norm = hparams['batch_norm']

    # Early Stopping hyperparameters
    early_stopping_patience = hparams['early_stopping_patience']
    early_stopping_min_delta = hparams['early_stopping_min_delta']

    # Scheduler hyperparameters
    scheduler_mode = hparams['scheduler_mode']
    scheduler_factor = hparams['scheduler_factor']
    scheduler_patience = hparams['scheduler_patience']
    scheduler_threshold = hparams['scheduler_threshold']
    scheduler_cooldown = hparams['scheduler_cooldown']

    # Define the model
    model = SAEClassifier(
        encoder=autoencoder.encoder,
        encoder_output_size=encoder_output_size,
        classifier_hidden_sizes=classifier_hidden_sizes,
        num_classes=10,  # KMNIST has 10 classes
        leaky_relu_negative_slope=leaky_relu_negative_slope,
        batch_norm=batch_norm
    ).to(device)

    # Apply weight initialization to the classifier head
    def init_weights(m):
        if isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm1d):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    model.classifier.apply(init_weights)

    # Set up the optimizer (only trainable parameters)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, weight_decay=weight_decay)

    # Scheduler: ReduceLROnPlateau
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=scheduler_mode,
        factor=scheduler_factor,
        patience=scheduler_patience,
        threshold=scheduler_threshold,
        threshold_mode='rel',
        cooldown=scheduler_cooldown,
    )

    # Set up the loss function
    criterion = nn.CrossEntropyLoss()

    # Data loaders
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    # Early Stopping variables
    best_val_loss = float('inf')
    epochs_no_improve = 0
    early_stop = False

    # Initialize variables for logging
    initial_lr = optimizer.param_groups[0]['lr']
    current_lr = initial_lr
    lr_reduction_epochs = []
    epoch_logs = []

    # Training loop
    for epoch in range(num_epochs):
        if early_stop:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Training phase
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0
        for data in train_loader:
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        train_loss /= len(train_loader)
        train_accuracy = 100 * correct / total

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for data in val_loader:
                inputs, labels = data
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_loss /= len(val_loader)
        val_accuracy = 100 * correct / total

        # Scheduler step
        scheduler.step(val_loss)

        # Check if learning rate has been reduced
        new_lr = optimizer.param_groups[0]['lr']
        if new_lr < current_lr:
            lr_reduction_epochs.append(epoch+1)  # Epochs are 1-indexed
            current_lr = new_lr

        # Record per-epoch logs
        epoch_logs.append({
            'epoch': epoch+1,
            'train_loss': train_loss,
            'train_accuracy': train_accuracy,
            'val_loss': val_loss,
            'val_accuracy': val_accuracy,
            'learning_rate': current_lr
        })

        # Early Stopping check
        if val_loss < best_val_loss - early_stopping_min_delta:
            best_val_loss = val_loss
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= early_stopping_patience:
            early_stop = True

        # Optionally, print progress
        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%, "
              f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    # Prepare training log
    training_log = {
        'final_epoch': epoch+1,
        'lr_reduction_epochs': lr_reduction_epochs,
        'epoch_logs': epoch_logs
    }

    # Return final validation loss, accuracy, training log, and the trained model
    return val_loss, val_accuracy, training_log, model, optimizer, scheduler

# Main loop for hyperparameter tuning
results = []

for idx, hparams in enumerate(hyperparameter_combinations):
    # Set the base seed
    set_seed(42)
    # Generate a unique id for the hyperparameters
    hparams_str = json.dumps(hparams, sort_keys=True)
    hparams_hash = hashlib.md5(hparams_str.encode('utf-8')).hexdigest()

    # Check if this combination exists in existing_results
    if hparams_hash in existing_hashes:
        print(f"Skipping hyperparameter combination {idx+1}/{len(hyperparameter_combinations)}: {hparams} (already tested)")
        continue
    else:
        print(f"Testing hyperparameter combination {idx+1}/{len(hyperparameter_combinations)}: {hparams}")
        try:
            # Call train_and_evaluate and record the computation time
            start_time = time.time()
            val_loss, val_accuracy, training_log, model, optimizer, scheduler = train_and_evaluate(hparams)
            end_time = time.time()
            computation_time = end_time - start_time

            # Save the per-epoch logs to a file
            log_filename = f"misc/classifier_training_log_{hparams_hash}.json"
            os.makedirs(os.path.dirname(log_filename), exist_ok=True)
            with open(log_filename, 'w') as f:
                json.dump(training_log, f)

            # Append the result to the CSV file
            with open(csv_filename, 'a', newline='') as csvfile:
                fieldnames = ['hparams_hash'] + list(hparams.keys()) + ['val_loss', 'val_accuracy', 'computation_time', 'final_num_epochs', 'lr_reduction_epochs', 'log_filename']
                writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
                # If the file is new, write the header
                if csvfile.tell() == 0:
                    writer.writeheader()
                result_row = {'hparams_hash': hparams_hash}
                for key, value in hparams.items():
                    # Convert value to string using json.dumps
                    result_row[key] = json.dumps(value)
                result_row.update({
                    'val_loss': val_loss,
                    'val_accuracy': val_accuracy,
                    'computation_time': computation_time,
                    'final_num_epochs': training_log['final_epoch'],
                    'lr_reduction_epochs': json.dumps(training_log['lr_reduction_epochs']),
                    'log_filename': log_filename
                })
                writer.writerow(result_row)

            # Update the set of existing hashes
            existing_hashes.add(hparams_hash)

            # Check and save the best model
            if val_loss < best_val_loss:
                torch.save({
                    'state_dict': model.state_dict(),         # Model weights
                    'config': {                               # Model configuration
                        'encoder_output_size': encoder_output_size,
                        'classifier_hidden_sizes': hparams['classifier_hidden_sizes'],
                        'leaky_relu_negative_slope': hparams['leaky_relu_negative_slope'],
                        'batch_norm': hparams['batch_norm'],
                    },
                    'hyperparameters': hparams,              # Hyperparameters
                    'training_log': training_log,            # Training logs
                    'best_val_loss': val_loss,               # Best validation loss
                    'best_val_accuracy': val_accuracy,       # Best validation accuracy
                    'optimizer_state': optimizer.state_dict(), # Optimizer state
                    'scheduler_state': scheduler.state_dict()  # Scheduler state
                }, 'SAE_classifier_best_model.pth')
                best_val_loss = val_loss
                print(f"New best model saved with validation accuracy: {val_accuracy:.2f}%\n")
            else:
                print(f"Validation accuracy {val_loss:.3f}% did not improve over the best loss {best_val_loss:.3f}%\n")

            results.append({'hparams': hparams, 'val_loss': val_loss, 'val_accuracy': val_accuracy})
        except Exception as e:
            print(f"Error with hyperparameters {hparams}: {e}\n")
            continue

# Find the best hyperparameters based on validation accuracy
if results:
    best_result = max(results, key=lambda x: x['val_accuracy'])
    print("Best hyperparameters based on validation accuracy:")
    print(best_result['hparams'])
    print(f"Validation Loss: {best_result['val_loss']:.4f}, Validation Accuracy: {best_result['val_accuracy']:.2f}%")
    print("The best model has been saved as 'SAE_classifier_best_model.pth'")
else:
    print("No successful runs to report.")


  model_checkpoint = torch.load('SAE_best_model.pth', map_location=device)


Current best validation accuracy from CSV: inf%
Testing hyperparameter combination 1/24: {'learning_rate': 0.001, 'classifier_hidden_sizes': [100], 'batch_size': 32, 'weight_decay': 0.0001, 'leaky_relu_negative_slope': 0.01, 'num_epochs': 100, 'batch_norm': True, 'early_stopping_patience': 5, 'early_stopping_min_delta': 0.0001, 'scheduler_mode': 'min', 'scheduler_factor': 0.1, 'scheduler_patience': 3, 'scheduler_threshold': 0.0001, 'scheduler_cooldown': 0}
Epoch [1/100], Training Loss: 0.9309, Training Accuracy: 70.50%, Validation Loss: 0.6825, Validation Accuracy: 77.93%
Epoch [2/100], Training Loss: 0.6940, Training Accuracy: 77.37%, Validation Loss: 0.6111, Validation Accuracy: 80.40%
Epoch [3/100], Training Loss: 0.6544, Training Accuracy: 78.39%, Validation Loss: 0.5758, Validation Accuracy: 81.30%
Epoch [4/100], Training Loss: 0.6342, Training Accuracy: 79.10%, Validation Loss: 0.5640, Validation Accuracy: 81.72%
Epoch [5/100], Training Loss: 0.6215, Training Accuracy: 79.67%, Va

KeyboardInterrupt: 