# In [None]:
import numpy as np
import torch
from torchvision import datasets, transforms

def sigmoid(z):
    """Return the element-wise logistic sigmoid ``1 / (1 + exp(-z))``.

    The value is computed in a numerically stable form: only ``exp(-|z|)``
    is ever evaluated, so large negative inputs cannot overflow the
    exponential the way the direct formula does.

    Args:
        z: Scalar or array-like of real numbers.

    Returns:
        Sigmoid of ``z``: an array of the same shape for array input, a
        NumPy scalar for scalar input.
    """
    z = np.asarray(z)
    # exp(-|z|) always lies in [0, 1], so it can underflow but never overflow.
    e = np.exp(-np.abs(z))
    # For z >= 0 this is the textbook formula; for z < 0 the algebraically
    # equal form exp(z) / (1 + exp(z)) is used instead.
    out = np.where(z >= 0, 1 / (1 + e), e / (1 + e))
    # Indexing with () unwraps a 0-d result so scalar input yields a scalar.
    return out[()]

def softmax(z):
    """Return the softmax of ``z``, normalised over all of its entries.

    The global maximum of ``z`` is subtracted before exponentiating; this
    leaves the result unchanged mathematically but keeps the exponentials
    from overflowing.

    Args:
        z: Array of scores.

    Returns:
        Array of the same shape as ``z`` whose entries sum to 1.
    """
    shifted = np.subtract(z, np.max(z))
    weights = np.exp(shifted)
    return np.divide(weights, np.sum(weights))

def cross_entropy_loss(y, y_hat, eps=1e-12):
    """Return the cross-entropy ``-sum(y * log(y_hat))``.

    Args:
        y: Target weights (e.g. a one-hot vector).
        y_hat: Predicted probabilities, same shape as ``y``.
        eps: Lower bound applied to ``y_hat`` before taking the logarithm.

    Returns:
        The summed cross-entropy as a scalar.
    """
    # Without the clamp a predicted probability of exactly 0 gives
    # log(0) = -inf, and 0 * -inf = nan would poison any running total.
    safe_y_hat = np.maximum(y_hat, eps)
    return -np.sum(y * np.log(safe_y_hat))

def forward(x, W1, W2, W3):
    """Run the forward pass of the network.

    Two sigmoid layers are applied in turn, followed by a softmax output
    layer. No bias terms are used.

    Args:
        x: Input fed to the first layer.
        W1: Weights of the first layer.
        W2: Weights of the second layer.
        W3: Weights of the output layer.

    Returns:
        Tuple ``(a1, z1, a2, z2, a3, z3, y_hat)``: the input itself, then
        the pre-activation and activation of each sigmoid layer, the
        output-layer pre-activation, and the softmax output.
    """
    pre_hidden1 = np.dot(W1, x)
    hidden1 = sigmoid(pre_hidden1)

    pre_hidden2 = np.dot(W2, hidden1)
    hidden2 = sigmoid(pre_hidden2)

    logits = np.dot(W3, hidden2)
    return x, pre_hidden1, hidden1, pre_hidden2, hidden2, logits, softmax(logits)

def backward(x, y, W1, W2, W3, a1, z1, a2, z2, a3, z3, y_hat, lr):
    """Backpropagate one sample's error and apply a gradient step in place.

    Args:
        x: Network input (not used here; ``a1`` is used instead).
        y: Target vector.
        W1, W2, W3: Layer weights; each is modified in place.
        a1, a2, a3: Inputs to the first, second and output layers.
        z1, z2, z3: Layer pre-activations (accepted but not used).
        y_hat: Network output.
        lr: Step size multiplying each gradient.

    Returns:
        Tuple ``(W1, W2, W3)`` of the same, now updated, weight arrays.
    """
    # Every error signal is computed from the weights as they were before
    # this step's update.
    err_out = y_hat - y
    err_hidden2 = np.dot(W3.T, err_out) * a3 * (1 - a3)
    err_hidden1 = np.dot(W2.T, err_hidden2) * a2 * (1 - a2)

    # A layer's gradient is the outer product of its error signal with the
    # activations that fed it.
    updates = (
        (W3, err_out, a3),
        (W2, err_hidden2, a2),
        (W1, err_hidden1, a1),
    )
    for weights, err, layer_input in updates:
        weights -= lr * np.outer(err, layer_input)
    return W1, W2, W3

def train(X_train, y_train, n_epochs, lr, d1, d2, k, batch_size):
    """Train the three-layer network with per-sample gradient steps.

    Weights are drawn from a standard normal distribution. The data is
    reshuffled every epoch. Samples are walked in chunks of ``batch_size``,
    but the weights are updated after every single sample, so
    ``batch_size`` only affects how the data is sliced, not the gradient.

    Args:
        X_train: 2-D array of inputs, one sample per row.
        y_train: Targets, one row per sample; the predicted class is
            compared with ``argmax`` of the row.
        n_epochs: Number of passes over the data.
        lr: Step size passed to ``backward``.
        d1: Number of rows of ``W1``.
        d2: Number of rows of ``W2``.
        k: Number of rows of ``W3``.
        batch_size: Chunk size used when walking through the data.

    Returns:
        Tuple ``(W1, W2, W3, train_acc, train_loss)`` where the last two are
        lists holding one per-epoch average each. Both are measured on the
        predictions made before each sample's update.
    """
    n_train, d = X_train.shape

    # Initial weights, drawn in layer order.
    W1 = np.random.randn(d1, d)
    W2 = np.random.randn(d2, d1)
    W3 = np.random.randn(k, d2)

    train_loss, train_acc = [], []

    for epoch in range(n_epochs):
        # Reshuffle inputs and targets together.
        order = np.random.permutation(n_train)
        X_train, y_train = X_train[order], y_train[order]

        total_loss = 0
        n_correct = 0
        for start in range(0, n_train, batch_size):
            stop = start + batch_size
            for x, y in zip(X_train[start:stop], y_train[start:stop]):
                a1, z1, a2, z2, a3, z3, y_hat = forward(x, W1, W2, W3)
                n_correct += int(np.argmax(y_hat) == np.argmax(y))
                total_loss += cross_entropy_loss(y, y_hat)
                W1, W2, W3 = backward(x, y, W1, W2, W3, a1, z1, a2, z2, a3, z3, y_hat, lr)

        train_loss.append(total_loss / n_train)
        train_acc.append(n_correct / n_train)

        print(f"Epoch {epoch+1}/{n_epochs}, train loss: {train_loss[-1]:.3f}, train accuracy: {train_acc[-1]:.3f}")

    return W1, W2, W3, train_acc, train_loss

def test(X_test, y_test, W1, W2, W3):
    """Evaluate the network on a test set and print its error rate.

    Args:
        X_test: Array of inputs, one sample per row.
        y_test: Targets, one row per sample; the predicted class is
            compared with ``argmax`` of the row.
        W1, W2, W3: Layer weights passed to ``forward``.

    Returns:
        The test error, i.e. the fraction of misclassified samples. The
        same value is printed with three decimals.
    """
    n_test = X_test.shape[0]
    n_correct = 0
    for i in range(n_test):
        # Only the final output of the forward pass is needed here.
        *_, y_hat = forward(X_test[i], W1, W2, W3)
        if np.argmax(y_hat) == np.argmax(y_test[i]):
            n_correct += 1

    test_error = 1 - n_correct / n_test
    print(f"test error: {test_error:.3f}")
    return test_error
    
# Preprocessing applied to every image: convert it to a tensor, then
# normalise it with the dataset's mean and standard deviation.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Training and test splits of MNIST, stored under mnist_data/.
train_dataset = datasets.MNIST(root='mnist_data/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='mnist_data/', train=False, download=True, transform=transform)

# No batch_size is given, so both loaders use the DataLoader default.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False)

def _loader_to_arrays(loader, n_classes=10):
    """Collect a data loader into flattened inputs and one-hot targets.

    Assumes the loader yields exactly one ``(image, label)`` pair per
    iteration: the whole image tensor is flattened into one vector and only
    the first label entry is read.

    Args:
        loader: Iterable yielding ``(image, label)`` pairs.
        n_classes: Length of each one-hot target vector.

    Returns:
        Tuple ``(inputs, targets)`` of two lists: 1-D input arrays and the
        matching one-hot target arrays.
    """
    inputs, targets = [], []
    for image, label in loader:
        # ndarray.flatten returns an independent 1-D copy of the image data.
        inputs.append(np.asarray(image).flatten())
        one_hot = np.zeros(n_classes)
        one_hot[int(label[0])] = 1
        targets.append(one_hot)
    return inputs, targets


X_train, y_train = _loader_to_arrays(train_loader)
X_test, y_test = _loader_to_arrays(test_loader)
    
# Fit the network, keeping the per-epoch accuracy and loss curves, then
# report its error on the test set.
W1, W2, W3, acc_curve, err_curve = train(
    np.array(X_train),
    np.array(y_train),
    n_epochs=100,
    lr=0.05,
    d1=300,
    d2=200,
    k=10,
    batch_size=32,
)

test(np.array(X_test), np.array(y_test), W1, W2, W3)