In [136]:
import torch
import pandas as pd
import math

In [137]:
def relu(x):
    """Element-wise rectified linear unit: max(x, 0).

    Uses ``torch.clamp`` instead of the ``(x + |x|) / 2`` identity. That identity
    returns NaN for ``-inf`` (since ``-inf + inf`` is NaN). It also overflows to
    ``inf`` for large finite inputs, because ``x + x`` exceeds the float range.

    Non-floating inputs are promoted to the default float dtype first, matching
    the true-division result the ``/ 2`` formulation produced.
    """
    if not x.is_floating_point():
        x = x.to(torch.get_default_dtype())
    return torch.clamp(x, min=0)

def relu_derivative(x):
    """(Sub)gradient of ReLU with respect to its input.

    Returns a float32 mask that is 1.0 where ``x > 0`` and 0.0 elsewhere.
    The value at ``x == 0`` is 0.0. The result is float32 regardless of
    ``x``'s dtype.
    """
    positive_mask = torch.gt(x, 0)
    return positive_mask.to(torch.float32)

def softmax(x):
    """Softmax along dim=1, so each row of a 2-D input sums to 1.

    The per-row maximum is subtracted before exponentiating. This leaves the
    result mathematically unchanged but keeps ``exp`` from overflowing on
    large logits.
    """
    row_max = x.max(dim=1, keepdim=True).values
    shifted_exp = (x - row_max).exp()
    normalizer = shifted_exp.sum(dim=1, keepdim=True)
    return shifted_exp / normalizer

def cross_entropy_loss(probs, labels, eps=1e-12):
    """Mean negative log-likelihood of the true class.

    Args:
        probs: (N, C) tensor of class probabilities (e.g. the output of ``softmax``).
        labels: (N,) integer tensor of class indices into dim 1 of ``probs``.
        eps: lower bound applied to the selected probabilities before the log.
            A float32 softmax can underflow to exactly 0 when logit gaps are
            large. Without the floor, ``-log(0)`` would be ``inf`` and poison
            any running loss. Probabilities above ``eps`` are unaffected.

    Returns:
        0-dim tensor: the loss averaged over the N samples.
    """
    N = probs.shape[0]
    # Pick probs[n, labels[n]] for every row n.
    correct_probs = probs[torch.arange(N, device=probs.device), labels]
    loss = -torch.log(correct_probs.clamp_min(eps))
    return loss.mean()

def load_data(csv_path, device):
    """Read a label-first pixel CSV into ``(images, labels)`` tensors on ``device``.

    The first line of the file is skipped. This assumes it is a header row
    (TODO confirm for your CSV: a header-less file would silently lose its
    first sample). Column 0 is taken as the integer label. All remaining
    columns are pixel values, divided by 255 (so 0..255 maps to [0, 1]).

    Returns:
        images: float32 tensor of shape (N, num_columns - 1) on ``device``.
        labels: int64 tensor of shape (N,) on ``device``.
    """
    raw = pd.read_csv(csv_path, skiprows=1, low_memory=False, header=None).values
    label_column = raw[:, 0].astype(int)
    pixel_block = raw[:, 1:]
    scaled_pixels = pixel_block.astype('float32') / 255.0
    # Build the tensors on the CPU, then copy them to the requested device.
    image_tensor = torch.tensor(scaled_pixels, dtype=torch.float32).to(device)
    label_tensor = torch.tensor(label_column, dtype=torch.long).to(device)
    return image_tensor, label_tensor

In [138]:
# Choose where tensors live: the GPU if PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [139]:
torch.manual_seed(42)

in_features = 784
hidden1 = 512
hidden2 = 256
hidden3 = 128
out_features = 10


def _he_layer(fan_in, fan_out):
    """Build one dense layer's parameters with He initialization.

    Returns ``(W, b)``:
    - ``W`` has shape (fan_in, fan_out). It is drawn from N(0, 1) and scaled by
      sqrt(2 / fan_in).
    - ``b`` is a zero tensor of shape (1, fan_out).

    Both are float32 and placed on ``device``.
    """
    scale = math.sqrt(2.0 / fan_in)
    weight = torch.randn(fan_in, fan_out, dtype=torch.float32, device=device) * scale
    bias = torch.zeros(1, fan_out, dtype=torch.float32, device=device)
    return weight, bias


# Layers are created in order 1 -> 4, so the seeded RNG draws land on the same weights.
W1, b1 = _he_layer(in_features, hidden1)
W2, b2 = _he_layer(hidden1, hidden2)
W3, b3 = _he_layer(hidden2, hidden3)
W4, b4 = _he_layer(hidden3, out_features)

In [140]:
# Load the data. load_data already places both tensors on `device`,
# so no further .to(device) calls are needed.
train_images, train_labels = load_data('mnist_train.csv', device)
test_images, test_labels   = load_data('mnist_test.csv', device)

# Sanity checks: one label per image, and the pixel count must match the first layer's input size.
assert train_images.shape[0] == train_labels.shape[0], "train images/labels length mismatch"
assert test_images.shape[0] == test_labels.shape[0], "test images/labels length mismatch"
assert train_images.shape[1] == in_features == test_images.shape[1], "pixel columns must equal in_features"

# Parameters
learning_rate = 0.05
lambda_l2 = 0.0001   # L2 penalty coefficient; applied to W1..W4 only (biases are not penalized)
num_epochs = 10
batch_size = 32
num_train = train_images.shape[0]

In [141]:
def _forward(X):
    """Forward pass of the 4-layer ReLU MLP using the current global W1..W4 / b1..b4.

    Returns every intermediate the manual backward pass needs, as
    ``(z1, a1, z2, a2, z3, a3, logits, probs)``:
    - ``z*`` are the pre-activations.
    - ``a* = relu(z*)``.
    - ``probs = softmax(logits)``.

    Training and test evaluation both use this helper, so the two forward
    passes cannot drift apart.
    """
    z1 = torch.matmul(X, W1) + b1           # (B, hidden1)
    a1 = relu(z1)
    z2 = torch.matmul(a1, W2) + b2          # (B, hidden2)
    a2 = relu(z2)
    z3 = torch.matmul(a2, W3) + b3          # (B, hidden3)
    a3 = relu(z3)
    logits = torch.matmul(a3, W4) + b4      # (B, out_features)
    probs = softmax(logits)                 # (B, out_features)
    return z1, a1, z2, a2, z3, a3, logits, probs


# NOTE: the parameters are updated in place (`W1 -= ...`). Re-running this cell
# therefore continues training from the current weights; re-run the
# initialization cell to start over.
for epoch in range(num_epochs):
    permutation = torch.randperm(num_train, device=device)
    running_loss = 0.0
    correct_train = 0

    for i in range(0, num_train, batch_size):
        indices = permutation[i:i+batch_size]
        X = train_images[indices]  # (B, 784)
        y = train_labels[indices]  # (B,)
        B = X.shape[0]  # the final batch may be smaller than batch_size

        # Forward Pass
        z1, a1, z2, a2, z3, a3, logits, probs = _forward(X)

        loss = cross_entropy_loss(probs, y)
        # Regularization loss (weights only, not biases) added to cross-entropy loss:
        reg_loss = lambda_l2 * (torch.sum(W1**2) + torch.sum(W2**2) + torch.sum(W3**2) + torch.sum(W4**2))
        loss_total = loss + reg_loss

        # Backward Pass (manual gradients)
        one_hot = torch.zeros_like(probs)
        one_hot[torch.arange(B, device=device), y] = 1

        # Gradient of mean cross-entropy w.r.t. the logits when probs = softmax(logits).
        d_logits = (probs - one_hot) / B  # (B, 10)

        dW4 = torch.matmul(a3.t(), d_logits)           # (hidden3, 10)
        db4 = d_logits.sum(dim=0, keepdim=True)          # (1, 10)

        d_a3 = torch.matmul(d_logits, W4.t())            # (B, hidden3)
        d_z3 = d_a3 * relu_derivative(z3)                # (B, hidden3)

        dW3 = torch.matmul(a2.t(), d_z3)                 # (hidden2, hidden3)
        db3 = d_z3.sum(dim=0, keepdim=True)              # (1, hidden3)

        d_a2 = torch.matmul(d_z3, W3.t())                # (B, hidden2)
        d_z2 = d_a2 * relu_derivative(z2)                # (B, hidden2)

        dW2 = torch.matmul(a1.t(), d_z2)                 # (hidden1, hidden2)
        db2 = d_z2.sum(dim=0, keepdim=True)              # (1, hidden2)

        d_a1 = torch.matmul(d_z2, W2.t())                # (B, hidden1)
        d_z1 = d_a1 * relu_derivative(z1)                # (B, hidden1)

        dW1 = torch.matmul(X.t(), d_z1)                  # (in_features, hidden1)
        db1 = d_z1.sum(dim=0, keepdim=True)              # (1, hidden1)

        # Add regularization gradients: derivative of reg_loss is 2 * lambda_l2 * W
        dW1 += 2 * lambda_l2 * W1
        dW2 += 2 * lambda_l2 * W2
        dW3 += 2 * lambda_l2 * W3
        dW4 += 2 * lambda_l2 * W4

        # Update Parameters (Gradient Descent). This runs only after every
        # gradient above has been computed with the pre-update weights.
        W1 -= learning_rate * dW1
        b1 -= learning_rate * db1
        W2 -= learning_rate * dW2
        b2 -= learning_rate * db2
        W3 -= learning_rate * dW3
        b3 -= learning_rate * db3
        W4 -= learning_rate * dW4
        b4 -= learning_rate * db4

        # Weighted by B so the epoch average is per-sample even with a short final batch.
        # This accumulates loss_total, so the reported "Train Loss" includes the L2 penalty.
        running_loss += loss_total.item() * B
        # Running accuracy uses this batch's predictions from before the update.
        preds = torch.argmax(probs, dim=1)
        correct_train += (preds == y).sum().item()

    train_loss = running_loss / num_train
    train_accuracy = 100.0 * correct_train / num_train

    # Evaluate on Test Set: the whole test set in one batch, using the same forward pass as training.
    z1_test, a1_test, z2_test, a2_test, z3_test, a3_test, logits_test, probs_test = _forward(test_images)
    preds_test = torch.argmax(probs_test, dim=1)
    test_accuracy = 100.0 * (preds_test == test_labels).sum().item() / test_labels.shape[0]

    print(f"Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, " +
          f"Train Acc: {train_accuracy:.2f}%, Test Acc: {test_accuracy:.2f}%")

Epoch 1/10: Train Loss: 0.4507, Train Acc: 91.98%, Test Acc: 96.07%
Epoch 2/10: Train Loss: 0.2895, Train Acc: 96.65%, Test Acc: 96.73%
Epoch 3/10: Train Loss: 0.2505, Train Acc: 97.76%, Test Acc: 97.34%
Epoch 4/10: Train Loss: 0.2258, Train Acc: 98.35%, Test Acc: 97.65%
Epoch 5/10: Train Loss: 0.2089, Train Acc: 98.74%, Test Acc: 97.82%
Epoch 6/10: Train Loss: 0.1955, Train Acc: 99.09%, Test Acc: 97.92%
Epoch 7/10: Train Loss: 0.1831, Train Acc: 99.38%, Test Acc: 97.92%
Epoch 8/10: Train Loss: 0.1744, Train Acc: 99.51%, Test Acc: 97.99%
Epoch 9/10: Train Loss: 0.1661, Train Acc: 99.66%, Test Acc: 97.95%
Epoch 10/10: Train Loss: 0.1591, Train Acc: 99.77%, Test Acc: 98.19%
