# In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

from SAE_model import StackedAutoencoder  # Import the SAE class

def init_weights(m, init_type='xavier'):
    """Initialize an ``nn.Linear`` module in place; leave other modules alone.

    Args:
        m: Module to initialize. Only ``nn.Linear`` instances are modified.
        init_type: ``'xavier'`` for Xavier uniform initialization, or ``'he'``
            for Kaiming (He) uniform initialization with the ReLU gain.

    Raises:
        ValueError: If ``init_type`` is neither ``'xavier'`` nor ``'he'``.
    """
    # Fail loudly on a misspelled scheme instead of silently keeping the
    # layer's default initialization.
    if init_type not in ('xavier', 'he'):
        raise ValueError(
            f"Unsupported init_type {init_type!r}; expected 'xavier' or 'he'"
        )

    if not isinstance(m, nn.Linear):
        return

    if init_type == 'xavier':
        nn.init.xavier_uniform_(m.weight)
    else:  # He initialization
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')

    # Both schemes start from a zero bias when the layer has one.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
# Architecture
input_size = 784  # one flattened 28x28 image
layer_sizes = [800, 200, 50]  # width of each stacked layer
activation_functions = [nn.ReLU() for _ in layer_sizes]  # one activation per layer
dropout_rates = [0.0] * len(layer_sizes)  # one dropout rate per layer

# Optimization
learning_rate = 1e-3
num_epochs = 20
batch_size = 128
optimizer_type = 'Adam'
weight_decay = 1e-5
loss_function = nn.MSELoss()

# Early stopping
early_stopping = False
patience = 5

# Step learning-rate schedule
scheduler_step_size = 10
scheduler_gamma = 0.1

# ---------------------------------------------------------------------------
# Device: use CUDA when it is available, otherwise fall back to the CPU
# ---------------------------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


def _he_init(module):
    """Apply ``init_weights`` with the 'he' scheme to ``module``."""
    return init_weights(module, init_type='he')


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
model = StackedAutoencoder(
    input_size=input_size,
    layer_sizes=layer_sizes,
    activation_functions=activation_functions,
    dropout_rates=dropout_rates,
    weight_init=_he_init,
)
model = model.to(device)

# ---------------------------------------------------------------------------
# Optimizer: looked up by name; unknown names are rejected
# ---------------------------------------------------------------------------
_optimizer_classes = {'Adam': optim.Adam, 'SGD': optim.SGD}
if optimizer_type not in _optimizer_classes:
    raise ValueError("Unsupported optimizer type")
optimizer = _optimizer_classes[optimizer_type](
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# ---------------------------------------------------------------------------
# Scheduler: multiply the learning rate by `scheduler_gamma` every
# `scheduler_step_size` scheduler steps
# ---------------------------------------------------------------------------
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=scheduler_step_size,
    gamma=scheduler_gamma,
)

def _flatten_image(image):
    """Return ``image`` viewed as a 1-D tensor."""
    return image.view(-1)


# Per-sample preprocessing: convert to a tensor, normalize with mean 0.5 and
# std 0.5 (adjust if necessary), then flatten the image.
_preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
    transforms.Lambda(_flatten_image),
]
transform = transforms.Compose(_preprocessing_steps)

# Kuzushiji-MNIST: the train split is used for training and the test split
# for validation. Both are downloaded to ../data if missing.
_dataset_kwargs = {'root': '../data', 'transform': transform, 'download': True}
train_dataset = datasets.KMNIST(train=True, **_dataset_kwargs)
val_dataset = datasets.KMNIST(train=False, **_dataset_kwargs)

# Data loaders: only the training data is shuffled.
_loader_kwargs = {'batch_size': batch_size}
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, shuffle=True, **_loader_kwargs
)
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset, shuffle=False, **_loader_kwargs
)

# Training loop
# Each epoch runs one optimization pass over `train_loader`, one
# gradient-free pass over `val_loader`, and then advances the LR scheduler.
#
# Epoch losses are sample-weighted: each batch loss is multiplied by its batch
# size and the sum is divided by the number of samples seen. A plain mean of
# batch losses would over-weight the final, smaller batch whenever the dataset
# size is not a multiple of `batch_size`.
# NOTE: this assumes `loss_function` returns a per-batch mean, which is the
# default reduction of nn.MSELoss.
best_val_loss = float('inf')
epochs_no_improve = 0
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_samples = 0
    for inputs, _ in train_loader:  # labels are unused
        inputs = inputs.to(device)

        # Forward pass: the loss compares the model output with its own input
        outputs = model(inputs)
        loss = loss_function(outputs, inputs)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_count = inputs.size(0)
        train_loss += loss.item() * batch_count
        train_samples += batch_count
    train_loss /= train_samples

    # Validation
    model.eval()
    val_loss = 0.0
    val_samples = 0
    with torch.no_grad():
        for inputs, _ in val_loader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            loss = loss_function(outputs, inputs)

            batch_count = inputs.size(0)
            val_loss += loss.item() * batch_count
            val_samples += batch_count
    val_loss /= val_samples

    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

    # Early stopping logic (if enabled): checkpoint on every improvement and
    # stop after `patience` consecutive epochs without one.
    if early_stopping:
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), 'best_sae_model.pth')
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping!")
                # Restore the best checkpoint so the final save below holds
                # the best weights rather than the last epoch's.
                model.load_state_dict(torch.load('best_sae_model.pth'))
                break

# Save the final model
torch.save(model.state_dict(), 'sae_model_final.pth')


# Recorded cell output: ModuleNotFoundError: No module named 'torch'