# Single-Hidden-Layer Perceptron (MLP) for MNIST Classification

This notebook implements a simple perceptron network with a single hidden layer (the `MLP` class defined below) to classify handwritten digits from the MNIST dataset.

## Import Libraries

In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# Softmax normalised along dim 1; used in the inference example to turn logits into probabilities.
Softmax_fn = nn.Softmax(dim=1)

def show_data(data_tuple):
    """Draw one ``(image, label)`` sample with pyplot.

    The image is squeezed, rendered with the gray colormap, titled with its
    label and shown without axis decorations. ``plt.show()`` is not called
    here; that is left to the caller.
    """
    image, label = data_tuple
    pixels = image.squeeze()
    plt.imshow(pixels, cmap='gray')
    plt.axis("off")
    plt.title(f"Label: {label}")

## Load MNIST Dataset

In [None]:
# Training split: stored under ./data (downloaded if missing), images converted with ToTensor().
train_dataset = dsets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
print("Print the training dataset:\n", train_dataset)

# The MNIST test split (train=False) is used as the validation set in this notebook.
validation_dataset = dsets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
print("Print the validating dataset:\n ", validation_dataset)

## Visualize Sample Images

In [None]:
# Show a 5x5 grid of randomly drawn training images, each titled with its label.
cols, rows = 5, 5
h = plt.figure(figsize=(10, 8))
for i in range(1, cols * rows + 1):
    # Draw one random index in [0, len(train_dataset)).
    random_idx = torch.randint(len(train_dataset), size=(1,)).item()
    img, label = train_dataset[random_idx]
    # Subplot positions are 1-based, hence the loop starting at 1.
    h.add_subplot(rows, cols, i)
    plt.imshow(img.squeeze(), cmap='gray')
    plt.title(label)
    plt.axis("off")
plt.show()

## Define the Model (MLP with One Hidden Layer)

In [None]:
class MLP(nn.Module):
    """Feed-forward network with one ReLU hidden layer.

    ``hidden`` is a Linear layer (``input_size`` -> ``hidden_size``) followed
    by ReLU; ``out`` is a Linear layer (``hidden_size`` -> ``output_size``).
    ``forward`` returns the output layer's raw values; no softmax is applied.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Kept inside a Sequential so the parameters stay addressable as
        # 'hidden.0.weight' / 'hidden.0.bias' in the state dict.
        hidden_linear = nn.Linear(input_size, hidden_size)
        self.hidden = nn.Sequential(hidden_linear, nn.ReLU())
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the output-layer values for the input batch ``x``."""
        return self.out(self.hidden(x))

## Initialize Model and Set Device

In [None]:
# Network dimensions: a flattened 28x28 image in, 10 outputs out.
input_dim = 28 * 28
hidden_dim = 20  # can adjust e.g. 32 or 130
output_dim = 10

# Device preference: CUDA first, then MPS, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print("Using device:", device)

# Build the network and move its parameters to the selected device.
model = MLP(input_dim, hidden_dim, output_dim).to(device)
print('The model: \n', model)

## Inspect Model Parameters

In [None]:
# Shapes of the first two parameter tensors returned by model.parameters().
first_weight, first_bias = list(model.parameters())[:2]
print('W: ', first_weight.size())
print('b: ', first_bias.size())

In [None]:
# Display the model's full state dict (parameter name -> tensor) as the cell output.
model.state_dict()

In [None]:
# Print the shape of each layer's weight and bias, looked up by state-dict key.
for caption, key in (
    ('Hidden layer W: ', 'hidden.0.weight'),
    ('Hidden layer b: ', 'hidden.0.bias'),
    ('Output layer W: ', 'out.weight'),
    ('Output layer b: ', 'out.bias'),
):
    print(caption, model.state_dict()[key].size())

## Visualize Model Weights

In [None]:
def PlotParameters(model, hiddenDim):
    """Plot the incoming weights of the first ``hiddenDim`` hidden neurons.

    Each row of ``hidden.0.weight`` is reshaped to 28x28 and drawn as an image
    in a grid with 10 columns. All images share one colour scale, the global
    minimum and maximum of the weight matrix, so neurons can be compared.

    Args:
        model: Module whose ``state_dict()`` contains ``'hidden.0.weight'``
            with 784 (28*28) columns.
        hiddenDim: Number of hidden neurons to draw. Values larger than the
            number of rows in the weight matrix are clamped to that number.
    """
    W = model.state_dict()['hidden.0.weight'].detach().cpu()
    w_min = W.min().item()
    w_max = W.max().item()
    n_neurons = min(hiddenDim, W.shape[0])
    # At least one row, so hiddenDim == 0 does not request an empty grid.
    rows = max(1, int(np.ceil(n_neurons / 10.0)))
    # squeeze=False always yields a 2-D array of Axes, so .flat is always valid.
    fig, axes = plt.subplots(rows, 10, figsize=(20, rows * 3), squeeze=False)
    fig.subplots_adjust(hspace=0.01, wspace=0.1)
    for i, ax in enumerate(axes.flat):
        if i >= n_neurons:
            # Hide grid cells that have no neuron instead of showing empty frames.
            ax.axis("off")
            continue
        ax.set_xlabel(f"neuron: {i}")
        Img = W[i, :].view(28, 28)
        ax.imshow(Img, vmin=w_min, vmax=w_max, cmap='seismic')
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

# Plot the hidden-layer weights of the current model (no training has run yet at this point in the notebook).
PlotParameters(model=model, hiddenDim=hidden_dim)

## Setup Training Components

In [None]:
# Optimizer, loss function and data loaders shared by the training and validation steps.
learning_rate = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Alternative:
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# CrossEntropyLoss is applied directly to the model's outputs (no softmax is applied beforehand).
loss_func = nn.CrossEntropyLoss()
# Training batches are reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True, num_workers=1)
# Validation metrics are sums/averages over the whole split, so sample order is
# irrelevant; keep it fixed (shuffle=False) so every evaluation pass is reproducible.
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset, batch_size=100, shuffle=False, num_workers=1)

## Define Training Function

In [None]:
# Initialize lists to track metrics
train_accuracy_list = []
loss_list = []

# Training function

def train(dataloader, model, loss_func, optimizer):
    """Run one training epoch and record its metrics.

    Every batch is moved to the global ``device``, flattened to
    ``(-1, input_dim)``, passed through ``model`` and used for one optimizer
    step. Progress is printed every 100 batches.

    Side effects: appends the epoch accuracy to ``train_accuracy_list`` and
    the average per-sample loss to ``loss_list``.

    Args:
        dataloader: Iterable of ``(X, y)`` batches exposing ``.dataset``.
        model: Network to train (switched to training mode).
        loss_func: Callable ``(pred, y) -> scalar tensor``. Assumed to return
            the batch mean (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer bound to ``model``'s parameters.
    """
    size = len(dataloader.dataset)
    model.train()
    train_loss, correct = 0.0, 0
    for batchNr, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()

        pred = model(X.view(-1, input_dim))
        loss = loss_func(pred, y)
        loss.backward()
        optimizer.step()

        # Convert to a Python float immediately: accumulating the loss tensor
        # itself would keep autograd history alive for the whole epoch.
        loss_val = loss.item()
        _, yhat = torch.max(pred.detach(), 1)
        correct += (yhat == y).sum().item()
        # Weight the batch mean by the batch size so that dividing by the
        # dataset size below gives the true per-sample average.
        train_loss += loss_val * len(y)
        if (batchNr+1) % 100 == 0:
            current = (batchNr+1)*len(y)
            print(f"loss: {loss_val:>7f} [{current:>5d}/{size:>5d}]")
        # PlotParameters(model=model, hiddenDim=hidden_dim)

    accuracy = correct / size
    train_accuracy_list.append(accuracy)
    loss_list.append(train_loss / size)

In [None]:
val_loss_list = []
val_accuracy_list = []

# Validation function

def validate(dataloader, model, loss_func):
    """Evaluate ``model`` on ``dataloader`` and record the metrics.

    Runs in evaluation mode with gradients disabled. Every batch is moved to
    the global ``device`` and flattened to ``(-1, input_dim)``.

    Side effects: appends the accuracy to ``val_accuracy_list`` and the
    average per-sample loss to ``val_loss_list``, then prints both.

    Args:
        dataloader: Iterable of ``(x, y)`` batches exposing ``.dataset``.
        model: Network to evaluate (switched to evaluation mode).
        loss_func: Callable ``(pred, y) -> scalar tensor``. Assumed to return
            the batch mean (the default reduction of ``nn.CrossEntropyLoss``).
    """
    size = len(dataloader.dataset)
    model.eval()
    val_loss, correct = 0.0, 0
    with torch.no_grad():
        for x_test, y_test in dataloader:
            x_test, y_test = x_test.to(device), y_test.to(device)
            pred = model(x_test.view(-1, input_dim))
            # Weight each batch mean by its size: a plain mean over batches
            # would over-weight a shorter final batch.
            val_loss += loss_func(pred, y_test).item() * len(y_test)
            _, yhat = torch.max(pred, 1)
            correct += (yhat == y_test).sum().item()

    val_loss /= size
    accuracy = correct / size
    val_accuracy_list.append(accuracy)
    val_loss_list.append(val_loss)
    print(f"Validation Error: \n Accuracy: {(100*accuracy):>0.1f} Avg loss: {val_loss:>8f}")

# Train for a fixed number of epochs, validating after every epoch.
n_epochs = 10

# Rebind the history lists to fresh empty lists before the run starts.
loss_list, train_accuracy_list = [], []
val_loss_list, val_accuracy_list = [], []

N_train = len(train_dataset)
N_test = len(validation_dataset)
print(N_test)
print(len(train_loader.dataset))

for epoch in range(1, n_epochs + 1):
    print(f"n_epoch {epoch}\n-------------------------------")
    train(train_loader, model, loss_func, optimizer)
    validate(validation_loader, model, loss_func)
print("Done!")

# Plot the per-epoch history on twin axes: the training loss on the left axis,
# both accuracies and the validation loss on the right axis.
fig, ax1 = plt.subplots()
color = 'tab:red'
ax1.plot(loss_list, color=color, label='train loss')
ax1.set_xlabel('epoch', color=color)
ax1.set_ylabel('total loss', color=color)
ax1.tick_params(axis='y', color=color)

ax2 = ax1.twinx()
color = 'tab:blue'
# The right axis also carries the validation loss, so its label says so.
ax2.set_ylabel('accuracy / validation loss', color=color)
ax2.plot(train_accuracy_list, color=color, label='train accuracy')
ax2.tick_params(axis='y', color=color)
color = 'tab:orange'
ax2.plot(val_accuracy_list, color=color, label='validation accuracy')
color = 'tab:purple'
ax2.plot(val_loss_list, color=color, label='validation loss')

# Twin axes keep separate artist lists; merge them into a single legend so
# every curve can be identified.
handles_left, labels_left = ax1.get_legend_handles_labels()
handles_right, labels_right = ax2.get_legend_handles_labels()
ax2.legend(handles_left + handles_right, labels_left + labels_right, loc='center right')
# Lay out the figure only after all artists and labels exist.
fig.tight_layout()

# Inference example: classify one validation image and inspect intermediate values.
X, y = validation_dataset[63]
X = X.to(device)
label_tensor = torch.tensor([y], device=device)  # not used further in this cell
x_flat = X.reshape(-1, input_dim)

# Pure inference: evaluation mode, and no autograd graph for the forward passes.
model.eval()
with torch.no_grad():
    z = model(x_flat)
    hidden_out = model.hidden(x_flat)
    probs = Softmax_fn(z)

_, yhat = torch.max(z, 1)
yhat_int = int(yhat.item())
X_cpu = X.cpu()
show_data((X_cpu, y))
plt.show()
print("Predicted yhat:", yhat_int)
print("Raw logits z:", z)
print("Probability of predicted class", torch.max(probs).item())
# Report the shape actually fed to the model, matching the printed caption.
print("Input reshaped:", x_flat.shape)
print("Hidden layer output shape:", hidden_out.shape)

b = model.state_dict()['hidden.0.bias'].detach().cpu()
print("Bias sample (first 5):", b[:5])