# Exercise 2: Multi-Layer Perceptrons (MLPs) with PyTorch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ITI-THM/EiDL/blob/master/docs/1-uebung/Perceptron.ipynb)

Welcome to this exercise on Multi-Layer Perceptrons (MLPs)! In this notebook, we'll implement MLPs using PyTorch to solve both simple problems like XOR and more complex tasks like digit recognition with MNIST.

**📖 Learning Goals:**
* Understand why MLPs are necessary for problems like XOR
* Implement MLPs using PyTorch's neural network modules
* Train models on both the XOR problem and the MNIST dataset
* Visualize decision boundaries and model performance
* Gain hands-on experience with PyTorch's data handling capabilities


## Setup

Let's import the necessary libraries for our exercises. We'll use PyTorch for building and training our neural networks, along with visualization tools like Matplotlib and Seaborn.


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm  # Import tqdm for progress bar
import seaborn as sns

# For MNIST dataset
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, random_split

# Set visual style
sns.set_context("notebook")
sns.set_style("white")

# Set a random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

This time it can be useful to run the training on a GPU. The next snippet checks whether a GPU is available and, if so, selects it as the device.

In [None]:
# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Part 1: Solving the XOR Problem with PyTorch

### The XOR Problem

The XOR (exclusive OR) problem is a classic example that demonstrates the limitations of single-layer neural networks and the power of multi-layer networks:

* Input [0, 0] → Output 0
* Input [0, 1] → Output 1
* Input [1, 0] → Output 1
* Input [1, 1] → Output 0

A single perceptron can only create a linear decision boundary (a straight line in 2D), which cannot separate the XOR pattern. MLPs overcome this by introducing one or more **hidden layers** between the input and output layers, allowing the network to learn non-linear decision boundaries.
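
To make the claim above concrete, here is a minimal sketch in plain Python (no PyTorch). It brute-forces a coarse grid of weights for a single threshold unit and then evaluates a hand-wired network with two hidden units. The helper names (`linear_unit`, `best_linear_accuracy`, `hand_wired_mlp`), the weight grid, and the hard step activation are illustrative assumptions and are not part of the exercise code; the grid search only illustrates the limitation, it does not prove it.

```python
from itertools import product

# The four XOR samples as (input, target) pairs
XOR_DATA = [((0, 0), 0), ((0, 1), 1), ((1, 0), 1), ((1, 1), 0)]


def step(z):
    """Hard threshold activation (assumption: used here instead of sigmoid)."""
    return 1 if z > 0 else 0


def linear_unit(x, w1, w2, b):
    """A single perceptron: one linear decision boundary."""
    return step(w1 * x[0] + w2 * x[1] + b)


def best_linear_accuracy(grid):
    """Best number of correctly classified XOR samples over all grid weights."""
    best = 0
    for w1, w2, b in product(grid, repeat=3):
        correct = sum(linear_unit(x, w1, w2, b) == y for x, y in XOR_DATA)
        best = max(best, correct)
    return best


def hand_wired_mlp(x):
    """Two hidden units (OR and NAND) combined by an AND output unit."""
    h_or = linear_unit(x, 1, 1, -0.5)     # fires if at least one input is 1
    h_nand = linear_unit(x, -1, -1, 1.5)  # fires unless both inputs are 1
    return step(h_or + h_nand - 1.5)      # fires only if both hidden units fire


grid = [i / 2 for i in range(-8, 9)]  # -4.0 ... 4.0 in steps of 0.5
print("Best single-unit score (out of 4):", best_linear_accuracy(grid))
print("Hand-wired MLP outputs:", [hand_wired_mlp(x) for x, _ in XOR_DATA])
```

The single unit never gets more than three of the four samples right, while the two-hidden-unit network reproduces the XOR table exactly.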

Let's first create our XOR dataset:


In [None]:
# Create XOR Dataset
X_xor = torch.tensor([
    # TODO
], dtype=torch.float32)

y_xor = torch.tensor([
    # TODO
], dtype=torch.float32)

print("XOR dataset created:")
print(f"Inputs shape: {X_xor.shape}")
print(f"Outputs shape: {y_xor.shape}")


### Building an MLP for XOR with PyTorch

Now, let's define a simple MLP model using PyTorch's neural network modules. For the XOR problem, we'll use:

* **Input Layer:** 2 neurons (for the 2 inputs of XOR)
* **Hidden Layer:** 2 neurons (allows learning non-linearity)
* **Output Layer:** 1 neuron (for the binary output of XOR)
* **Activation Function:** Sigmoid (outputs values between 0 and 1)

Resources: https://docs.pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html
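
As a rough orientation for the 2-2-1 architecture listed above, the following sketch builds an equivalent network with `nn.Sequential` and counts its parameters. It is only one possible formulation and deliberately differs from the `nn.Module` subclass the exercise asks for; the names `sketch_model`, `num_params`, and `sketch_out` are made up for this illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 2 inputs -> 2 hidden neurons -> 1 output, Sigmoid after each linear layer
sketch_model = nn.Sequential(
    nn.Linear(2, 2),
    nn.Sigmoid(),
    nn.Linear(2, 1),
    nn.Sigmoid(),
)

# (2*2 weights + 2 biases) + (2*1 weights + 1 bias)
num_params = sum(p.numel() for p in sketch_model.parameters())

# Run the four XOR inputs through the untrained network
with torch.no_grad():
    sketch_out = sketch_model(
        torch.tensor([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [1.0, 1.0]])
    )

print("Trainable parameters:", num_params)
print("Output shape:", tuple(sketch_out.shape))
```

The output has one column per sample, which is why the targets in `y_xor` should also have shape `[4, 1]` when a loss such as MSE compares them element by element.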


In [None]:
# Define our MLP model for XOR
class XOR_MLP(nn.Module):
    def __init__(self, activation_fn=nn.Sigmoid()):
        super(XOR_MLP, self).__init__()
        # TODO: Define the layers

    def forward(self, x):
        # TODO: Define the forward pass
        return x

# Create the model and move it to the available device
model = XOR_MLP().to(device)
print(model)


### Training the XOR Model

Now we'll train our MLP on the XOR problem using:
* Mean Squared Error (MSE) as the loss function
* Stochastic Gradient Descent (SGD) as the optimizer
* Early stopping to end training once the loss stops improving
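
The early stopping rule used in the training cell can be studied in isolation. The sketch below replays the same patience and `min_delta` logic over a plain list of loss values; the function name `early_stopping_epoch` and the example losses are hypothetical and only serve to illustrate the rule.

```python
def early_stopping_epoch(losses, patience, min_delta=0.0):
    """Return the 1-based epoch at which early stopping triggers, or None."""
    best_loss = float("inf")
    counter = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss - min_delta:
            # Improvement by more than min_delta: remember it and reset the counter
            best_loss = loss
            counter = 0
        else:
            counter += 1
        if counter >= patience:
            return epoch
    return None


# Made-up loss curve: improves quickly, then stalls for three epochs
example_losses = [1.0, 0.5, 0.4, 0.41, 0.40, 0.39999, 0.2]
print(early_stopping_epoch(example_losses, patience=3, min_delta=1e-3))
print(early_stopping_epoch(example_losses, patience=5, min_delta=1e-3))
```

With a patience of 3 the loop stops at epoch 6 and never sees the later drop to 0.2; with a patience of 5 it keeps going. This is the trade-off behind choosing a fairly large patience for XOR, where the loss can plateau for a long time before it falls.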


In [None]:
# Training parameters
epochs = 20000
learning_rate = 0.5
loss_history = []

# Loss function and optimizer
# TODO: Choose a loss function and optimizer
criterion = ?
optimizer = ?

# Early stopping parameters
patience = 1000  # Number of epochs to wait for improvement
min_delta = 1e-5  # Minimum change to qualify as improvement
best_loss = float('inf')
counter = 0

# Move data to device
X_xor = X_xor.to(device)
y_xor = y_xor.to(device)

# Training loop
print("--- Starting Training ---")
print(f"Learning Rate: {learning_rate}, Epochs: {epochs}, Early Stopping Patience: {patience}")

progress_bar = tqdm(range(epochs), desc="Training", unit="epoch")

for epoch in progress_bar:
    #TODO: Forward pass
    

    #TODO: Calculate loss
    loss = ?
    loss_history.append(loss.item())

    #TODO: Backward pass and optimize


    # Update progress bar with current loss
    if (epoch + 1) % 100 == 0:  # Update less frequently to avoid slowdown
        progress_bar.set_postfix(loss=f"{loss.item():.6f}")

    # Early stopping logic
    if loss.item() < best_loss - min_delta:
        best_loss = loss.item()
        counter = 0
    else:
        counter += 1

    if counter >= patience:
        print(f"\nEarly stopping triggered after {epoch+1} epochs. No improvement for {patience} epochs.")
        break

print("--- Training Finished ---")
print(f"Final loss: {loss.item():.6f}")

Plot the loss history

In [None]:
#TODO: Plot the loss history


In [None]:
# Test the trained model
print("\n--- Testing Trained Network ---")
with torch.no_grad():  # No need to track gradients during testing
    final_predictions = model(X_xor)

    for i in range(len(X_xor)):
        input_data = X_xor[i].cpu().numpy()
        target = y_xor[i].item()
        prediction = final_predictions[i].item()
        print(f"Input: {input_data}, Target: {target}, Prediction: {prediction:.4f} (Rounded: {round(prediction)})")

# Self-checks
print("\n--- Running Self-Checks ---")
with torch.no_grad():
    # Test rounded predictions match targets
    final_predictions_rounded = torch.round(final_predictions).flatten().int().cpu().numpy()
    expected_xor_outputs = y_xor.flatten().int().cpu().numpy()

    assert np.array_equal(final_predictions_rounded, expected_xor_outputs), \
        f"Test Failed: Final XOR predictions {final_predictions_rounded} do not match targets {expected_xor_outputs}"

    # Check if loss has decreased significantly
    final_loss = loss_history[-1]
    assert final_loss < 0.05, \
        f"Test Failed: Final loss {final_loss:.4f} is too high. Expected < 0.05"

    print("✅ All Self-Checks Passed Successfully!")


### Visualizing the XOR Decision Boundary

Let's visualize the decision boundary learned by our MLP. This will show how the network has created a non-linear boundary to separate the XOR classes.


In [None]:
# Function to plot decision boundary
def plot_decision_boundary(model, X, y):
    # Create a mesh grid
    x_min, x_max = -0.5, 1.5  # Extend beyond the data points
    y_min, y_max = -0.5, 1.5
    h = 0.01  # Step size in the mesh
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

    # Flatten the grid for prediction
    grid_points = np.c_[xx.ravel(), yy.ravel()]
    grid_tensor = torch.tensor(grid_points, dtype=torch.float32).to(device)

    # Make predictions with the model
    with torch.no_grad():
        Z = model(grid_tensor)
        Z = Z.cpu().numpy().reshape(xx.shape)

    # Create the contour plot
    plt.figure(figsize=(10, 8))

    # Plot the decision boundary
    contour = plt.contourf(xx, yy, Z, levels=20, cmap=plt.cm.RdBu, alpha=0.8)
    plt.colorbar(contour, label='Predicted Value')

    # Plot decision boundary at 0.5 threshold
    plt.contour(xx, yy, Z, levels=[0.5], colors='green', linestyles='--')

    # Plot the data points
    X_np = X.cpu().numpy()
    y_np = y.cpu().numpy()
    plt.scatter(X_np[:, 0], X_np[:, 1], c=y_np.flatten(), cmap=plt.cm.RdBu,
                edgecolors='k', s=200, marker='o')

    # Annotate the points
    for i, label in enumerate(['(0,0)→0', '(0,1)→1', '(1,0)→1', '(1,1)→0']):
        plt.annotate(label, (X_np[i, 0], X_np[i, 1]),
                    xytext=(5, 5), textcoords='offset points',
                    fontsize=12, fontweight='bold')

    plt.xlim(x_min, x_max)
    plt.ylim(y_min, y_max)
    plt.xlabel('X1')
    plt.ylabel('X2')
    plt.title('MLP Decision Boundary for XOR Problem')
    plt.grid(True)
    plt.show()

# Plot the decision boundary
plot_decision_boundary(model, X_xor, y_xor)


## Part 2: MNIST Digit Classification with PyTorch

Now that we've solved the XOR problem, let's move on to a more complex task: classifying handwritten digits using the MNIST dataset.

### The MNIST Dataset

MNIST is a collection of 70,000 grayscale images of handwritten digits (0-9), commonly used for benchmarking machine learning models:
* Each image is 28x28 pixels
* 60,000 images for training, 10,000 for testing
* 10 classes (digits 0-9)

PyTorch makes it easy to load and prepare this dataset for training.
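
Two preparation steps in the next cell are simple arithmetic and can be checked by hand: the `Normalize` transform computes `(pixel - mean) / std`, and the train/validation split takes 85% of the 60,000 training images. The sketch below reproduces both in plain Python; the helper names `normalize_pixel` and `split_sizes` are illustrative only.

```python
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081  # the values passed to Normalize


def normalize_pixel(value):
    """What Normalize does to a single pixel that ToTensor scaled to [0, 1]."""
    return (value - MNIST_MEAN) / MNIST_STD


def split_sizes(num_samples, train_fraction):
    """Sizes of the training and validation subsets for a given fraction."""
    train_size = int(train_fraction * num_samples)
    return train_size, num_samples - train_size


# A black pixel (0.0) and a white pixel (1.0) after normalization
print(round(normalize_pixel(0.0), 4), round(normalize_pixel(1.0), 4))
print(split_sizes(60_000, 0.85))
```

After normalization the pixel values are no longer confined to [0, 1]: black maps to roughly -0.42 and white to roughly 2.82.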


In [None]:
# Define transformations for MNIST
train_transform = transforms.Compose([
    transforms.ToTensor(),  # Convert image to PyTorch Tensor (scales to [0, 1])
    transforms.Normalize((0.1307,), (0.3081,))  # Same normalization as the test set
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # Normalize with MNIST mean and std
])

# Download and load the training data
trainset = torchvision.datasets.MNIST(root='./data',
                                     train=True,
                                     download=True,
                                     transform=train_transform)

# Split into train and validation
train_size = int(0.85 * len(trainset))
val_size = len(trainset) - train_size
train_dataset, val_dataset = random_split(trainset, [train_size, val_size])

# Download and load the test data
test_dataset = torchvision.datasets.MNIST(root='./data',
                                         train=False,
                                         download=True,
                                         transform=test_transform)

print(f"MNIST dataset loaded.")
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")


### Visualizing MNIST Data

Let's visualize some examples from the MNIST dataset to get a better understanding of what we're working with.


In [None]:
# Function to show MNIST images
def show_mnist_images(dataset, num_images=10):
    # Create a dataloader to get batches
    dataloader = DataLoader(dataset, batch_size=num_images, shuffle=True)

    # Get a batch of images
    images, labels = next(iter(dataloader))

    # Create a grid of images
    fig, axes = plt.subplots(1, num_images, figsize=(15, 3))

    for i in range(num_images):
        # Drop the channel dimension so imshow receives a 2D array
        img = images[i].squeeze().cpu().numpy()
        label = labels[i].item()

        # Display the image
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f"Label: {label}")
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()

# Show some training examples
print("Examples from the MNIST training dataset:")
show_mnist_images(train_dataset)

# Analyze the distribution of digits in the training set
def analyze_mnist_distribution(dataset):
    # Create a dataloader to get all labels
    dataloader = DataLoader(dataset, batch_size=1000, shuffle=False)

    all_labels = []
    for _, labels in dataloader:
        all_labels.append(labels)

    all_labels = torch.cat(all_labels).cpu().numpy()

    # Count occurrences of each digit
    unique, counts = np.unique(all_labels, return_counts=True)

    # Plot distribution
    plt.figure(figsize=(10, 6))
    plt.bar(unique, counts)
    plt.xlabel('Digit')
    plt.ylabel('Count')
    plt.title('Distribution of Digits in MNIST Dataset')
    plt.xticks(unique)
    plt.grid(axis='y', alpha=0.75)

    # Add count labels on top of each bar
    for i, count in enumerate(counts):
        plt.text(unique[i], count + 100, str(count), ha='center')

    plt.show()

print("\nAnalyzing the distribution of digits in the training dataset:")
analyze_mnist_distribution(train_dataset)


### Questions

1. Describe the distribution of digit classes in the MNIST dataset. Are there classes that are more or less frequently represented?
2. Take a look at the sample images shown: Are there any anomalies or similarities in the representation of the digits?
3. What pre-processing steps could you apply to the images before training a neural network? Give reasons for your suggestions.
4. Suppose one digit appears significantly less often in the dataset than the others. What impact could this have on the training and performance of a model?
5. What options are there for checking whether the loaded data is correct and complete?

### Creating DataLoaders for MNIST

For efficient training, we'll use PyTorch's `DataLoader` class to:
* Load data in batches
* Shuffle the training data
* Enable parallel data loading

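DataLoaders are essential for datasets that are too large to fit in memory. MNIST is small enough to fit, but using a `DataLoader` here keeps the training code identical to what you would write for larger datasets.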
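
To see how batching behaves before applying it to MNIST, here is a small sketch on a toy dataset of ten samples. The tensors and the names `toy_dataset`, `toy_loader`, and `toy_batch_sizes` are invented for this illustration; `TensorDataset` is used only because it is a convenient way to wrap in-memory tensors.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten toy samples with one feature each, labelled 0..9
toy_features = torch.arange(10, dtype=torch.float32).unsqueeze(1)
toy_labels = torch.arange(10)
toy_dataset = TensorDataset(toy_features, toy_labels)

# Batch size 4 does not divide 10, so the last batch is smaller
toy_loader = DataLoader(toy_dataset, batch_size=4, shuffle=False)
toy_batch_sizes = [len(labels) for _, labels in toy_loader]

print("Number of batches:", len(toy_loader))
print("Batch sizes:", toy_batch_sizes)
```

The last batch holds only the leftover samples. The same happens with MNIST whenever the dataset size is not a multiple of the batch size, which is why accuracy is better computed from sample counts than from per-batch averages.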


In [None]:
# Define batch size
batch_size = 64  # Common batch size, often a power of 2

# Create DataLoader for the training set
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Create DataLoader for the test set
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"DataLoaders created with batch size: {batch_size}")

# Let's examine one batch from train_loader
dataiter = iter(train_loader)
images_batch, labels_batch = next(dataiter)

print(f"\n--- Example Batch ---")
print(f"Shape of images batch: {images_batch.shape}")  # Shape: [batch_size, channels, height, width]
print(f"Shape of labels batch: {labels_batch.shape}")  # Shape: [batch_size]
print(f"First few labels in the batch: {labels_batch[:10]}")

# Visualize a batch
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
axes = axes.flatten()

for i in range(10):
    img = images_batch[i].squeeze().cpu().numpy()
    label = labels_batch[i].item()

    axes[i].imshow(img, cmap='gray')
    axes[i].set_title(f"Label: {label}")
    axes[i].axis('off')

plt.tight_layout()
plt.show()


### Building an MLP for MNIST

Now, let's define an MLP for the MNIST classification task:

* **Input Layer:** 784 neurons (28×28 pixels flattened)
* **Hidden Layers:** Two hidden layers with ReLU activation
* **Output Layer:** 10 neurons (one for each digit class)

We'll use ReLU activation for the hidden layers as it generally performs better than Sigmoid for deep networks by helping to mitigate the vanishing gradient problem.
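
The vanishing gradient argument can be illustrated with a few lines of plain Python. The sketch below multiplies the local derivative of an activation function across several layers, which is a strong simplification: it ignores the weight matrices that real backpropagation also multiplies in. The function names are illustrative.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_grad(z):
    """Derivative of the sigmoid; its maximum value is 0.25 at z = 0."""
    s = sigmoid(z)
    return s * (1.0 - s)


def relu_grad(z):
    """Derivative of ReLU: 1 for positive inputs, 0 otherwise."""
    return 1.0 if z > 0 else 0.0


def chained_grad(local_grad, z, depth):
    """Product of `depth` identical local derivatives, as in a chain of layers."""
    return local_grad(z) ** depth


print("Sigmoid slope at 0:", sigmoid_grad(0.0))
print("Sigmoid, 5 layers:", chained_grad(sigmoid_grad, 0.0, 5))
print("ReLU, 5 layers:", chained_grad(relu_grad, 1.0, 5))
```

Even in the best case for the sigmoid, five layers shrink the gradient signal by a factor of about a thousand, whereas ReLU passes it through unchanged for active units.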


In [None]:
class MNIST_MLP(nn.Module):
    def __init__(self, activation_fn=nn.ReLU()):
        super(MNIST_MLP, self).__init__()
        # TODO: Define the correct input size and hidden layers
        input_size = ?  
        hidden_size1 = ?    
        hidden_size2 = ?     
        output_size = ?      

        #TODO: Define the layers of the MLP
        self.flatten = nn.Flatten()  # Flatten the input -> makes the 2d image into a 1d vector
        
        # Note: We don't apply Softmax here because nn.CrossEntropyLoss
        # expects raw logits and applies softmax internally

    def forward(self, x):
        #TODO: Define the forward pass
        return x

# Instantiate the model and move it to the device
model_mnist = MNIST_MLP().to(device)

# Print the model architecture
print("PyTorch MLP Model for MNIST:")
print(model_mnist)

# Test forward pass with a dummy batch
dummy_batch = torch.randn(batch_size, 1, 28, 28).to(device)  # Example batch
output = model_mnist(dummy_batch)
print(f"\nOutput shape for a dummy batch: {output.shape}")  # Should be [batch_size, 10]


### Training the MNIST Model

Now we'll train our MLP on the MNIST dataset using:
* Cross Entropy Loss (appropriate for multi-class classification)
* SGD with momentum as the optimizer
* Early stopping based on validation loss
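
For intuition about the loss values you will see during training, the sketch below computes the cross entropy of a single sample directly from raw logits in plain Python. It mirrors the standard definition (negative log of the softmax probability of the target class) and is not meant as a replacement for `nn.CrossEntropyLoss`; the function name and the example logits are made up.

```python
import math


def cross_entropy_from_logits(logits, target):
    """Cross entropy of one sample: -log(softmax(logits)[target])."""
    # Subtract the maximum before exponentiating for numerical stability
    largest = max(logits)
    log_sum_exp = largest + math.log(sum(math.exp(v - largest) for v in logits))
    return log_sum_exp - logits[target]


# A confident, correct prediction for class 0 out of three classes
print(round(cross_entropy_from_logits([2.0, 1.0, 0.1], target=0), 4))

# An uninformed model: identical logits for all ten digit classes
print(round(cross_entropy_from_logits([0.0] * 10, target=3), 4))
```

The second case gives ln(10) ≈ 2.3026. A freshly initialized MNIST model should start near this value, so a first-epoch loss far above it usually points to a bug in the data pipeline or the model.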


In [None]:
# Hyperparameters
learning_rate = 0.01
epochs = 10

# Loss function and optimizer
criterion = ?
optimizer = ?

# Training parameters
train_loss_history = []
val_loss_history = []
train_acc_history = []
val_acc_history = []

# Early stopping parameters
patience = 3
min_delta = 1e-4
best_loss = float('inf')
counter = 0

# Training loop
print("--- Starting Training ---")
print(f"Learning Rate: {learning_rate}, Momentum: 0.9, Epochs: {epochs}, Early Stopping Patience: {patience}")

for epoch in range(epochs):
    # Training phase
    model_mnist.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0

    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")

    for batch_idx, (data, target) in enumerate(progress_bar):
        # Move data to device
        data, target = data.to(device), target.to(device)

        #TODO: Forward pass (store the model output in `outputs`; it is used below)
        

        #TODO: Calculate loss
        loss = ?
        train_loss += loss.item()

        # Calculate accuracy
        _, predicted = outputs.max(1)
        total_train += target.size(0)
        correct_train += predicted.eq(target).sum().item()

        #TODO: Backward pass and optimize
        

        # Update progress bar
        progress_bar.set_postfix(loss=f"{loss.item():.4f}", 
                                acc=f"{100.*correct_train/total_train:.2f}%")

    # Calculate average training metrics for the epoch
    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100. * correct_train / total_train
    train_loss_history.append(avg_train_loss)
    train_acc_history.append(train_accuracy)

    # Validation phase
    model_mnist.eval()
    val_loss = 0.0
    correct_val = 0
    total_val = 0

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{epochs} [Valid]")

        for batch_idx, (data, target) in enumerate(progress_bar):
            data, target = data.to(device), target.to(device)
            outputs = model_mnist(data)

            # Calculate loss
            loss = criterion(outputs, target)
            val_loss += loss.item()

            # Calculate accuracy
            _, predicted = outputs.max(1)
            total_val += target.size(0)
            correct_val += predicted.eq(target).sum().item()

            # Update progress bar
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", 
                                   acc=f"{100.*correct_val/total_val:.2f}%")

    # Calculate average validation metrics for the epoch
    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100. * correct_val / total_val
    val_loss_history.append(avg_val_loss)
    val_acc_history.append(val_accuracy)

    print(f"Epoch {epoch+1}/{epochs} - "
          f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, "
          f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_accuracy:.2f}%")

    # Early stopping logic
    if avg_val_loss < best_loss - min_delta:
        best_loss = avg_val_loss
        counter = 0
        # Save the best model
        torch.save(model_mnist.state_dict(), 'best_mnist_model.pth')
        print(f"Model improved and saved!")
    else:
        counter += 1
        print(f"No improvement for {counter} epochs")

    if counter >= patience:
        print(f"\nEarly stopping triggered after {epoch+1} epochs. No improvement for {patience} epochs.")
        break

print("--- Training Finished ---")


### Visualizing Training Results

Let's visualize the training and validation metrics to understand how our model learned over time.


In [None]:
# Plot the loss history
plt.figure(figsize=(12, 5))

# Loss subplot
plt.subplot(1, 2, 1)
plt.plot(train_loss_history, label='Training Loss')
plt.plot(val_loss_history, label='Validation Loss')
plt.title('Loss During Training')
plt.xlabel('Epoch')
plt.ylabel('Cross Entropy Loss')
plt.legend()
plt.grid(True)

# Accuracy subplot
plt.subplot(1, 2, 2)
plt.plot(train_acc_history, label='Training Accuracy')
plt.plot(val_acc_history, label='Validation Accuracy')
plt.title('Accuracy During Training')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()


### Evaluating the MNIST Model

Now let's evaluate our trained model on the test dataset to see how well it generalizes to unseen data.


In [None]:
# Load the best model
model_mnist.load_state_dict(torch.load('best_mnist_model.pth', map_location=device))
model_mnist.eval()

# Test the model on the test dataset
test_loss = 0.0
correct = 0
total = 0
all_preds = []
all_targets = []

with torch.no_grad():
    progress_bar = tqdm(test_loader, desc="Testing")

    for data, target in progress_bar:
        data, target = data.to(device), target.to(device)
        outputs = model_mnist(data)

        # Calculate loss
        loss = criterion(outputs, target)
        test_loss += loss.item()

        # Calculate accuracy
        _, predicted = outputs.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        # Store predictions and targets for confusion matrix
        all_preds.extend(predicted.cpu().numpy())
        all_targets.extend(target.cpu().numpy())

        # Update progress bar
        progress_bar.set_postfix(loss=f"{loss.item():.4f}", 
                               acc=f"{100.*correct/total:.2f}%")

# Calculate average test metrics
avg_test_loss = test_loss / len(test_loader)
test_accuracy = 100. * correct / total

print(f"\nTest Results:")
print(f"Average Loss: {avg_test_loss:.4f}")
print(f"Accuracy: {test_accuracy:.2f}%")


### Visualizing Model Performance

Let's create a confusion matrix to see which digits our model struggles with the most.

Resource: https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
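
Before reaching for a library, it can help to see what a confusion matrix actually counts. The sketch below builds one by hand in plain Python for a toy three-class problem; it is not the scikit-learn approach from the resource above, and the function names and toy labels are invented for illustration.

```python
def confusion_counts(targets, preds, num_classes):
    """matrix[t][p] = number of samples with true class t predicted as class p."""
    matrix = [[0] * num_classes for _ in range(num_classes)]
    for true_label, pred_label in zip(targets, preds):
        matrix[true_label][pred_label] += 1
    return matrix


def per_class_recall(matrix):
    """Fraction of each true class that was predicted correctly (diagonal / row sum)."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(matrix)]


# Toy example with three classes
toy_targets = [0, 0, 1, 2, 2, 2]
toy_preds = [0, 1, 1, 2, 0, 2]
toy_matrix = confusion_counts(toy_targets, toy_preds, num_classes=3)

for row in toy_matrix:
    print(row)
print("Recall per class:", per_class_recall(toy_matrix))
```

Rows correspond to the true class and columns to the predicted class, so off-diagonal entries show which digits get confused with which.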

In [None]:
#TODO: Create confusion matrix



### Visualizing Misclassifications

Let's look at some examples that our model misclassified to understand its weaknesses.


In [None]:
# Function to show misclassified examples
def show_misclassified_examples(model, test_loader, num_examples=10):
    model.eval()
    misclassified_images = []
    misclassified_labels = []
    misclassified_preds = []

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1)

            # Find misclassified examples
            incorrect_mask = pred.ne(target)

            if incorrect_mask.any():
                misclassified_idx = torch.where(incorrect_mask)[0]

                for idx in misclassified_idx:
                    if len(misclassified_images) < num_examples:
                        misclassified_images.append(data[idx].cpu())
                        misclassified_labels.append(target[idx].item())
                        misclassified_preds.append(pred[idx].item())
                    else:
                        break

            if len(misclassified_images) >= num_examples:
                break

    # Plot misclassified examples
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    axes = axes.flatten()

    for i in range(len(misclassified_images)):
        img = misclassified_images[i].squeeze().numpy()
        true_label = misclassified_labels[i]
        pred_label = misclassified_preds[i]

        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f"True: {true_label}, Pred: {pred_label}")
        axes[i].axis('off')

    plt.tight_layout()
    plt.show()

# Show misclassified examples
print("Examples of misclassified digits:")
show_misclassified_examples(model_mnist, test_loader)


## Conclusion

In this exercise, we've:

1. Implemented an MLP in PyTorch to solve the XOR problem
2. Visualized the non-linear decision boundary learned by our network
3. Built a more complex MLP for MNIST digit classification
4. Used PyTorch's data handling capabilities with Dataset and DataLoader
5. Trained, validated, and tested our model
6. Visualized model performance and analyzed results

These skills form the foundation for working with more complex neural network architectures and datasets in the future.

### Next Steps

To further improve the MNIST model, you could:
* Experiment with different network architectures (more/fewer layers, different sizes)
* Try different activation functions (Leaky ReLU, ELU, etc.)
* Implement regularization techniques (dropout, weight decay)
* Use more advanced optimizers (Adam, RMSprop)
* Apply data augmentation to increase the effective training set size
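
As a starting point for two of the suggestions above (dropout and weight decay with Adam), here is a hedged sketch. The hidden sizes (128 and 64), the dropout rate, the learning rate, and the weight decay value are arbitrary assumptions rather than recommended settings, and the names `regularized_mlp` and `adam` are illustrative.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Assumed layer sizes; the exercise leaves the hidden sizes open
regularized_mlp = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 128),
    nn.ReLU(),
    nn.Dropout(p=0.2),   # randomly zeroes 20% of activations during training
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(64, 10),   # raw logits, no softmax
)

# Adam with L2-style weight decay as an alternative to SGD with momentum
adam = optim.Adam(regularized_mlp.parameters(), lr=1e-3, weight_decay=1e-4)

# Dropout is only active in train() mode; eval() disables it
regularized_mlp.eval()
with torch.no_grad():
    sketch_logits = regularized_mlp(torch.randn(8, 1, 28, 28))

print("Logits shape:", tuple(sketch_logits.shape))
```

Because dropout behaves differently in training and evaluation, the calls to `model.train()` and `model.eval()` in the training loop matter once dropout layers are present.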
