In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

from MLP_model import MLPClassifier  # Import the MLPClassifier class

def init_weights(m, init_type='he'):
    """Initialise an ``nn.Linear`` module in place; any other module is left untouched.

    Presumably invoked once per submodule by the model (e.g. through
    ``nn.Module.apply``) — TODO confirm against MLPClassifier — which is why
    non-Linear modules are skipped rather than rejected.

    Args:
        m: The module to (possibly) initialise.
        init_type: ``'xavier'`` for Xavier/Glorot uniform weights, or
            ``'he'`` (default) for Kaiming/He uniform weights with
            ``nonlinearity='relu'``. In both cases a present bias is zeroed.

    Raises:
        ValueError: If ``init_type`` is not ``'xavier'`` or ``'he'``. Failing
            loudly avoids a typo silently leaving layers at PyTorch's default
            initialisation.
    """
    if init_type not in ('xavier', 'he'):
        raise ValueError(
            f"Unsupported init_type {init_type!r}; expected 'xavier' or 'he'"
        )
    if not isinstance(m, nn.Linear):
        return

    if init_type == 'xavier':
        nn.init.xavier_uniform_(m.weight)
    else:  # 'he' — gain chosen for ReLU activations
        nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')

    # Shared by both schemes: start every bias at zero.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# ---- Hyperparameters --------------------------------------------------------
# Network shape: flattened 28x28 images -> one 1024-unit hidden layer -> 10 classes.
input_size = 784
layer_sizes = [1024]
output_size = 10
# Per-hidden-layer settings; presumably one entry per element of `layer_sizes`
# (TODO confirm against MLPClassifier).
activation_functions = [nn.ReLU()]
dropout_rates = [0.5]
batch_norm = [False]  # batch normalisation disabled

# Optimisation.
optimizer_type = 'Adam'
learning_rate = 5e-5
weight_decay = 0
loss_function = nn.CrossEntropyLoss()

# Schedule: a generous epoch budget that early stopping is expected to cut short.
num_epochs = 100
batch_size = 128
early_stopping = True
patience = 5  # epochs without validation improvement before stopping

# Device configuration: use CUDA when PyTorch can see a GPU, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the classifier from the hyperparameters and move it onto `device`.
# `weight_init` requests He initialisation via init_weights; how and when
# MLPClassifier calls it is defined in MLP_model (not visible here).
model = MLPClassifier(
    input_size=input_size,
    layer_sizes=layer_sizes,
    output_size=output_size,
    activation_functions=activation_functions,
    dropout_rates=dropout_rates,
    batch_norm=batch_norm,
    weight_init=lambda m: init_weights(m, init_type='he'),
).to(device)

# Choose optimizer: only Adam is supported, so reject anything else up front.
if optimizer_type != 'Adam':
    raise ValueError("Unsupported optimizer type")
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=weight_decay,
)

# Data transformations: image -> tensor in [0, 1] -> rescaled to [-1, 1] via
# (x - 0.5) / 0.5 -> flattened to a 1-D vector for the MLP's input layer.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
        transforms.Lambda(lambda x: x.view(-1)),
    ]
)

# Load KMNIST into ../data (downloading if needed). The training split is
# used for optimisation; the official test split (train=False) serves as the
# validation set.
# NOTE(review): because early stopping selects on this split, accuracy on it
# is not an unbiased test estimate; consider holding out part of the training
# split for validation instead.
train_dataset = datasets.KMNIST(
    root='../data',
    train=True,
    transform=transform,
    download=True,
)
val_dataset = datasets.KMNIST(
    root='../data',
    train=False,
    transform=transform,
    download=True,
)

# Batching: reshuffle training data each epoch; keep validation order fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False
)

# Training loop with optional early stopping on validation loss.
# NOTE(review): `val_loader` is built from the KMNIST test split, so the
# checkpoint kept here is selected on test data.


def _run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the model is trained: ``model.train()``, gradients
    enabled, one optimizer step per batch. Without one it is evaluated:
    ``model.eval()``, gradients disabled.

    The loss is averaged per sample (each batch weighted by its size), so a
    short final batch does not count as much as a full one. This assumes
    ``loss_fn`` uses the default ``reduction='mean'``.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            n = labels.size(0)
            loss_sum += loss.item() * n
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += n

    if seen == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / seen, 100 * correct / seen


best_val_loss = float('inf')
epochs_no_improve = 0
best_checkpoint_saved = False  # set once this run writes best_mlp_model.pth
for epoch in range(num_epochs):
    train_loss, train_accuracy = _run_epoch(
        model, train_loader, loss_function, device, optimizer=optimizer
    )
    val_loss, val_accuracy = _run_epoch(model, val_loader, loss_function, device)

    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%, "
          f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    if early_stopping:
        # Count it as an improvement only if the loss drops by more than 0.001.
        if val_loss < best_val_loss - 0.001:
            best_val_loss = val_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), 'best_mlp_model.pth')
            best_checkpoint_saved = True
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping!")
                break

# Restore the best checkpoint whether training stopped early or reached the
# epoch limit, so `model` always ends up with the best-validation weights.
# Only load a file written by this run, never a stale one from an earlier one.
if early_stopping and best_checkpoint_saved:
    model.load_state_dict(torch.load('best_mlp_model.pth', map_location=device))

# Persist the weights `model` holds at this point as the final artefact.
torch.save(obj=model.state_dict(), f='mlp_model_final.pth')


Epoch [1/100], Training Loss: 1.2495, Training Accuracy: 60.41%, Validation Loss: 1.0591, Validation Accuracy: 66.44%
Epoch [2/100], Training Loss: 0.6768, Training Accuracy: 79.09%, Validation Loss: 0.8742, Validation Accuracy: 72.23%
Epoch [3/100], Training Loss: 0.5428, Training Accuracy: 83.63%, Validation Loss: 0.7957, Validation Accuracy: 74.70%
Epoch [4/100], Training Loss: 0.4648, Training Accuracy: 86.03%, Validation Loss: 0.7111, Validation Accuracy: 77.50%
Epoch [5/100], Training Loss: 0.4119, Training Accuracy: 87.61%, Validation Loss: 0.6608, Validation Accuracy: 78.94%
Epoch [6/100], Training Loss: 0.3742, Training Accuracy: 88.92%, Validation Loss: 0.6150, Validation Accuracy: 80.57%
Epoch [7/100], Training Loss: 0.3407, Training Accuracy: 89.92%, Validation Loss: 0.5797, Validation Accuracy: 81.55%
Epoch [8/100], Training Loss: 0.3166, Training Accuracy: 90.59%, Validation Loss: 0.5507, Validation Accuracy: 82.57%
Epoch [9/100], Training Loss: 0.2946, Training Accuracy: