In [4]:
import torch
from torchvision import datasets
import torch.nn.functional

class MLP:
    def __init__(self, input_size, hidden_size, output_size):
        self.weights1 = torch.randn(input_size, hidden_size, requires_grad=True)
        self.bias1 = torch.randn(1, hidden_size, requires_grad=True)
        self.weights2 = torch.randn(hidden_size, output_size, requires_grad=True)
        self.bias2 = torch.randn(1, output_size, requires_grad=True)
        self.training = True  # Initialize the mode as 'training'
        self.dropout_prob = 0.5
        self.l2_lambda = 0.001

        # Batch normalization parameters
        self.bn1_running_mean = torch.zeros(1, hidden_size)
        self.bn1_running_var = torch.ones(1, hidden_size)
        self.bn1_mean = 0
        self.bn1_var = 0

    def forward(self, x):
        # Layer 1 with batch normalization and ReLU activation
        x = torch.mm(x, self.weights1) + self.bias1

        if self.training:
            self.bn1_mean = x.mean(0, keepdim=True)
            self.bn1_var = x.var(0, unbiased=False, keepdim=True)
            x = (x - self.bn1_mean) / torch.sqrt(self.bn1_var + 1e-5)
        else:
            x = (x - self.bn1_running_mean) / torch.sqrt(self.bn1_running_var + 1e-5)

        x = torch.relu(x)

        # Dropout
        if self.training:
            mask = torch.rand(x.shape) > self.dropout_prob
            x = x * mask

        # Layer 2
        x = torch.mm(x, self.weights2) + self.bias2
        return x

    def backward(self, x, target, lr):
        loss = torch.nn.functional.cross_entropy(x, target)

        # Add L2 regularization term to the loss
        loss += 0.5 * self.l2_lambda * (torch.sum(self.weights1 ** 2) + torch.sum(self.weights2 ** 2))

        # Backpropagation
        loss.backward()

        # Update weights and biases during training
        if self.training:
            with torch.no_grad():
                self.weights1 -= lr * self.weights1.grad
                self.bias1 -= lr * self.bias1.grad
                self.weights2 -= lr * self.weights2.grad
                self.bias2 -= lr * self.bias2.grad

                # Batch normalization running statistics are updated with a weighted combination of the current running statistics and the batch statistics (bn1_mean and bn1_var) computed during training
                self.bn1_running_mean = 0.9 * self.bn1_running_mean + 0.1 * self.bn1_mean
                self.bn1_running_var = 0.9 * self.bn1_running_var + 0.1 * self.bn1_var

                # Zero the gradients for the next iteration
                self.weights1.grad.zero_()
                self.bias1.grad.zero_()
                self.weights2.grad.zero_()
                self.bias2.grad.zero_()

        return loss.item()

# Set device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cuda" if device == "gpu" and torch.cuda.is_available() else "cpu")

# Loading MNIST dataset
train_dataset = datasets.MNIST('data', train=True, download=True, transform=None)
validation_dataset = datasets.MNIST('data', train=False, download=True, transform=None)

# Initialize the model
model = MLP(784, 100, 10)
model.weights1 = model.weights1.to(device)
model.bias1 = model.bias1.to(device)
model.weights2 = model.weights2.to(device)
model.bias2 = model.bias2.to(device)

# Training loop
lr = 0.1
epochs = 20
batch_size = 32
train_data = train_dataset.data.float()
train_targets = train_dataset.targets

for epoch in range(epochs):
    model.training = True  # Set the model in training mode
    total_loss = 0
    correct = 0
    total = 0
    num_batches = len(train_data) // batch_size

    for batch_idx in range(num_batches):
        start_idx = batch_idx * batch_size
        end_idx = (batch_idx + 1) * batch_size

        data = train_data[start_idx:end_idx].view(-1, 784).to(device)
        target = train_targets[start_idx:end_idx].to(device)

        output = model.forward(data)
        loss = model.backward(output, target, lr)

        # Calculate accuracy
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
        total_loss += loss

    # Print training statistics
    accuracy = 100 * correct / total
    avg_loss = total_loss / num_batches
    print(f"Epoch [{epoch+1}/{epochs}] loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

# Validation loop
model.training = False  # Set the model in evaluation mode
total_loss = 0
correct = 0
total = 0
validation_data = validation_dataset.data.float()
validation_targets = validation_dataset.targets
num_validation_batches = len(validation_data) // batch_size

for batch_idx in range(num_validation_batches):
    start_idx = batch_idx * batch_size
    end_idx = (batch_idx + 1) * batch_size

    data = validation_data[start_idx:end_idx].view(-1, 784).to(device)
    target = validation_targets[start_idx:end_idx].to(device)

    # Forward pass
    output = model.forward(data)

    # Calculate loss and backpropagate
    loss = model.backward(output, target, lr)

    # Calculate accuracy
    _, predicted = output.max(1)
    total += target.size(0)
    correct += predicted.eq(target).sum().item()
    total_loss += loss

accuracy = 100 * correct / total
avg_loss = total_loss / num_validation_batches
print(f"Validation loss: {avg_loss:.4f}, accuracy: {accuracy:.2f}%")

Epoch [1/20] loss: 34.7022, Accuracy: 58.18%
Epoch [2/20] loss: 23.3214, Accuracy: 76.33%
Epoch [3/20] loss: 16.1527, Accuracy: 79.98%
Epoch [4/20] loss: 11.2571, Accuracy: 82.29%
Epoch [5/20] loss: 7.8833, Accuracy: 84.01%
Epoch [6/20] loss: 5.5610, Accuracy: 85.35%
Epoch [7/20] loss: 3.9527, Accuracy: 86.80%
Epoch [8/20] loss: 2.8440, Accuracy: 87.94%
Epoch [9/20] loss: 2.0734, Accuracy: 88.91%
Epoch [10/20] loss: 1.5469, Accuracy: 89.45%
Epoch [11/20] loss: 1.1745, Accuracy: 90.32%
Epoch [12/20] loss: 0.9183, Accuracy: 90.80%
Epoch [13/20] loss: 0.7428, Accuracy: 91.30%
Epoch [14/20] loss: 0.6183, Accuracy: 91.78%
Epoch [15/20] loss: 0.5366, Accuracy: 91.95%
Epoch [16/20] loss: 0.4799, Accuracy: 92.04%
Epoch [17/20] loss: 0.4393, Accuracy: 92.22%
Epoch [18/20] loss: 0.4160, Accuracy: 92.33%
Epoch [19/20] loss: 0.4022, Accuracy: 92.32%
Epoch [20/20] loss: 0.3877, Accuracy: 92.41%
Validation loss: 0.2999, accuracy: 95.07%
