In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

from SAE_model import StackedAutoencoder  # Import the SAE class

def init_weights(m):
    """Initialise one module in place; non-``nn.Linear`` modules are ignored.

    ``nn.Linear`` weights are drawn with Kaiming-uniform using the LeakyReLU
    gain, and any bias is reset to zero. Returns ``None``.

    This function is passed to ``StackedAutoencoder`` as ``weight_init``;
    presumably it is applied to every submodule (e.g. via ``Module.apply``) —
    confirm in SAE_model.

    NOTE(review): ``a`` (the negative slope used for the gain) is left at its
    default rather than matching ``leaky_relu_negative_slope``; the resulting
    gain difference is negligible for a slope of 0.01.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.kaiming_uniform_(m.weight, nonlinearity='leaky_relu')
    bias = m.bias
    if bias is not None:
        nn.init.zeros_(bias)

# =========================
# Hyperparameters
# =========================

# Model parameters
input_size = 784  # For 28x28 images
layer_sizes = [800, 200, 50]  # Sizes of each hidden layer
leaky_relu_negative_slope = 0.01  # Negative slope for LeakyReLU
dropout_rates = [0.0, 0.0, 0.0]  # Dropout rates for each layer (one per entry in layer_sizes)

# Training parameters
learning_rate = 1e-3
num_epochs = 20
batch_size = 128
weight_decay = 1e-5
loss_function = nn.MSELoss()  # default reduction='mean' (mean over every element of the batch)

# Early stopping parameters
early_stopping = True
early_stopping_patience = 5  # Number of epochs to wait for improvement
early_stopping_min_delta = 0.0  # Minimum change to qualify as improvement

# Learning rate scheduler parameters (ReduceLROnPlateau)
scheduler_mode = 'min'  # Mode for scheduler ('min' or 'max')
scheduler_factor = 0.1  # Factor to reduce the learning rate
# Number of epochs with no improvement after which to reduce LR.
# Must be smaller than early_stopping_patience: previously 10 (> 5), so on a
# plateau training was stopped before the LR could ever be reduced.
scheduler_patience = 2
scheduler_threshold = 1e-4  # Threshold for measuring the new optimum
scheduler_cooldown = 0  # Number of epochs to wait before resuming normal operation

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# =========================
# Model Initialization
# =========================

# Construct the stacked autoencoder, then move its parameters to `device`.
# Each hidden layer receives its own, separate LeakyReLU instance.
model = StackedAutoencoder(
    input_size=input_size,
    layer_sizes=layer_sizes,
    activation_functions=[
        nn.LeakyReLU(leaky_relu_negative_slope)
        for _hidden in range(len(layer_sizes))
    ],
    dropout_rates=dropout_rates,
    # Custom initializer defined above; how StackedAutoencoder invokes it is not
    # visible here (presumably Module.apply) — confirm in SAE_model.
    weight_init=init_weights,
)
model = model.to(device)

# =========================
# Optimizer and Scheduler
# =========================

# AdamW over all model parameters with the configured learning rate and weight decay.
optimizer = optim.AdamW(
    model.parameters(),
    weight_decay=weight_decay,
    lr=learning_rate,
)

# Plateau-based LR decay. The training loop steps it with each epoch's validation
# loss; threshold_mode='rel' makes scheduler_threshold a relative improvement margin.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode=scheduler_mode,
    threshold_mode='rel',
    threshold=scheduler_threshold,
    patience=scheduler_patience,
    cooldown=scheduler_cooldown,
    factor=scheduler_factor,
)

# =========================
# Data Preparation
# =========================

# Per-sample pipeline: image -> tensor -> normalised to [-1, 1] -> flat vector.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(lambda img: img.view(-1)),
])

# Kuzushiji-MNIST, downloaded into ../data on first use.
# NOTE(review): the held-out split (train=False) is used as the validation set
# for LR scheduling and early stopping, so it is no longer an unbiased test set.
train_dataset = datasets.KMNIST(root='../data', train=True, transform=transform, download=True)
val_dataset = datasets.KMNIST(root='../data', train=False, transform=transform, download=True)

# Shuffle only the training data; validation order is fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# =========================
# Training Loop
# =========================

best_val_loss = float('inf')
epochs_no_improve = 0
# True once this run has written best_sae_model.pth; prevents loading a missing
# file or a stale checkpoint left over from an earlier run.
best_checkpoint_saved = False

for epoch in range(num_epochs):
    # Training phase
    model.train()
    train_loss_sum = 0.0
    for inputs, _ in train_loader:
        inputs = inputs.to(device)

        # Forward pass: the autoencoder reconstructs its own input.
        # NOTE(review): inputs are normalised to [-1, 1]; confirm the decoder's
        # output activation can actually produce that range.
        outputs = model(inputs)
        loss = loss_function(outputs, inputs)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss is a per-batch mean; weight it by batch size so the smaller final
        # batch is not over-counted in the epoch average.
        train_loss_sum += loss.item() * inputs.size(0)
    train_loss = train_loss_sum / len(train_loader.dataset)

    # Validation phase
    model.eval()
    val_loss_sum = 0.0
    with torch.no_grad():
        for inputs, _ in val_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            val_loss_sum += loss_function(outputs, inputs).item() * inputs.size(0)
    val_loss = val_loss_sum / len(val_loader.dataset)

    # Step the scheduler with validation loss
    scheduler.step(val_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Training Loss: {train_loss:.4f}, "
          f"Validation Loss: {val_loss:.4f}")

    # Early stopping logic
    if early_stopping:
        if val_loss < best_val_loss - early_stopping_min_delta:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), 'best_sae_model.pth')
            best_checkpoint_saved = True
        else:
            epochs_no_improve += 1
            print(f"No improvement in validation loss for {epochs_no_improve} epoch(s).")
            if epochs_no_improve >= early_stopping_patience:
                print("Early stopping triggered!")
                break

# Restore the best-validation weights on BOTH exit paths (early stop or epoch
# budget exhausted), so the saved final model is always the best one seen.
if early_stopping and best_checkpoint_saved:
    model.load_state_dict(torch.load('best_sae_model.pth', map_location=device))

# Save the final model
torch.save(model.state_dict(), 'sae_model_final.pth')
print("Training complete. Model saved.")


Epoch [1/20], Training Loss: 0.4463, Validation Loss: 0.3614
Epoch [2/20], Training Loss: 0.3529, Validation Loss: 0.3286
Epoch [3/20], Training Loss: 0.3246, Validation Loss: 0.3149
Epoch [4/20], Training Loss: 0.3129, Validation Loss: 0.3063
Epoch [5/20], Training Loss: 0.3039, Validation Loss: 0.2987
Epoch [6/20], Training Loss: 0.2963, Validation Loss: 0.2928
Epoch [7/20], Training Loss: 0.2896, Validation Loss: 0.2868
Epoch [8/20], Training Loss: 0.2837, Validation Loss: 0.2812
Epoch [9/20], Training Loss: 0.2782, Validation Loss: 0.2765
Epoch [10/20], Training Loss: 0.2733, Validation Loss: 0.2720
Epoch [11/20], Training Loss: 0.2691, Validation Loss: 0.2702
Epoch [12/20], Training Loss: 0.2654, Validation Loss: 0.2650
Epoch [13/20], Training Loss: 0.2613, Validation Loss: 0.2619
Epoch [14/20], Training Loss: 0.2578, Validation Loss: 0.2591
Epoch [15/20], Training Loss: 0.2540, Validation Loss: 0.2555
Epoch [16/20], Training Loss: 0.2506, Validation Loss: 0.2530
Epoch [17/20], Tr