In [49]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [50]:
from sklearn.datasets import fetch_openml

# Download the 'mnist_784' dataset (version 1) from OpenML.
# NOTE(review): depending on the installed scikit-learn version, the returned
# data/target may be pandas objects rather than NumPy arrays - confirm the
# downstream cells are happy with either.
mnist = fetch_openml('mnist_784', version=1)

# Class labels, cast to 64-bit integers.
y = mnist['target'].astype(np.int64)

# Pixel features divided by 255.0 (presumably rescaling 0-255 intensities
# into [0, 1] - TODO confirm against the dataset description).
X = mnist['data'] / 255.0

In [51]:
#one hot encoding
def one_hot(y, num_classes=10):
  """Convert integer class labels into a one-hot matrix.

  Args:
    y: Array-like of integer class labels (NumPy array, pandas Series,
      list, ...). Input of any shape is flattened to 1-D first.
    num_classes: Number of columns in the result.

  Returns:
    Float array of shape (number of labels, num_classes) containing a single
    1 per row, in the column given by that row's label, and 0 elsewhere.

  Raises:
    ValueError: If any label is negative.
    IndexError: If a label is >= num_classes or the labels are not integers
      (raised by NumPy's indexing).
  """
  # np.asarray lets plain lists through; the original required a `.shape`.
  labels = np.asarray(y).ravel()
  m = labels.shape[0]
  encode = np.zeros((m, num_classes))
  if m == 0:
    # An empty list becomes a float array, which NumPy refuses as an index.
    return encode
  if np.any(labels < 0):
    # NumPy would treat a negative label as a from-the-end column index and
    # silently mark the wrong class.
    raise ValueError("one_hot: class labels must be non-negative")
  encode[np.arange(m), labels] = 1
  return encode
# One-hot targets for every label in y (10 columns, one per class).
Y = one_hot(y, num_classes=10)

In [52]:
#defining neural net
# Layer widths: input -> hidden -> output.
input_size, hidden_size, output_size = 784, 256, 10

# Seed NumPy's global RNG so the initial weights are reproducible.
np.random.seed(42)

# Hidden layer: standard-normal draws scaled by sqrt(1 / fan_in), zero biases.
# W1 must be drawn before W2 so the random stream is consumed in the same order.
W1 = np.sqrt(1. / input_size) * np.random.randn(input_size, hidden_size)
W2 = np.sqrt(1. / hidden_size) * np.random.randn(hidden_size, output_size)

# Biases are row vectors so they broadcast across the rows of a batch.
b1 = np.zeros(shape=(1, hidden_size))
b2 = np.zeros(shape=(1, output_size))

In [53]:
#activation functions
def relu(x):
  """Rectified linear unit: element-wise maximum of 0 and x.

  Args:
    x: Scalar or array of pre-activations.

  Returns:
    Same shape as x, with every negative entry replaced by 0.
  """
  # Keep 0 as the first operand: the operand order decides which signed
  # zero np.maximum returns for ties.
  floor = 0
  return np.maximum(floor, x)

def softmax(x, axis=1):
  """Numerically stable softmax along one axis.

  Args:
    x: Array of scores.
    axis: Axis along which to normalize. Defaults to 1 (each row of a 2-D
      batch), which is the previously hard-coded behavior. Pass 0 or -1 for
      a 1-D score vector.

  Returns:
    Array with the same shape as x whose entries are non-negative and sum
    to 1 along `axis`.
  """
  # Subtracting the max along the axis does not change the result
  # mathematically, but keeps np.exp from overflowing on large scores.
  shifted = x - np.max(x, axis=axis, keepdims=True)
  exp_x = np.exp(shifted)
  return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

#forward pass
def forward_pass(X, W1, b1, W2, b2):
  """Run the two-layer network forward on a batch.

  Args:
    X: Input batch; multiplied on the right by W1, so its columns must match
      W1's rows.
    W1, b1: Weights and bias of the hidden layer.
    W2, b2: Weights and bias of the output layer.

  Returns:
    Tuple (Z1, A1, Z2, A2): hidden pre-activation, hidden ReLU activation,
    output pre-activation, and the softmax of the output pre-activation.
  """
  # hidden layer: affine map followed by ReLU
  hidden_pre = np.dot(X, W1) + b1
  hidden_act = relu(hidden_pre)

  # output layer: affine map followed by softmax
  logits = np.dot(hidden_act, W2) + b2
  probs = softmax(logits)

  return hidden_pre, hidden_act, logits, probs

In [54]:
# One forward pass over all of X with the freshly initialised parameters.
Z1, A1, Z2, A2 = forward_pass(X=X, W1=W1, b1=b1, W2=W2, b2=b2)

In [55]:
def compute_loss(Y, A2):
  """Mean cross-entropy between targets Y and predicted probabilities A2.

  Args:
    Y: Target array with one row per sample (one-hot rows in this notebook).
    A2: Predicted probabilities, same shape as Y.

  Returns:
    Sum of Y * -log(A2 + 1e-8) over all entries, divided by the number of
    rows of Y.
  """
  eps = 1e-8  # keeps np.log away from log(0) when a probability is exactly 0
  n_samples = Y.shape[0]
  neg_log_probs = -np.log(A2 + eps)
  return np.sum(Y * neg_log_probs) / n_samples

In [56]:
# Loss of the untrained network, computed from the forward pass above.
loss = compute_loss(Y=Y, A2=A2)
print(loss)

2.3564865538229967


In [57]:
def relu_derivative(Z):
  """Boolean mask of where Z is strictly positive.

  Used as the ReLU derivative: True (1) where Z > 0, False (0) elsewhere,
  including at exactly 0.
  """
  positive_mask = Z > 0
  return positive_mask

#backward pass
def backward_pass(X, Y, Z1, A1, Z2, A2, W2):
  """Backpropagate through the two-layer network.

  Args:
    X: Input batch that was fed to the forward pass.
    Y: Targets, same shape as A2.
    Z1: Hidden pre-activation from the forward pass.
    A1: Hidden activation from the forward pass.
    Z2: Output pre-activation; accepted but not used here.
    A2: Output probabilities from the forward pass.
    W2: Output-layer weights.

  Returns:
    Tuple (dW1, db1, dW2, db2), each averaged over the rows of X.
  """
  n_samples = X.shape[0]

  # output layer: A2 - Y is the error signal at the output pre-activation
  delta_out = A2 - Y
  grad_W2 = np.dot(A1.T, delta_out) / n_samples
  grad_b2 = np.sum(delta_out, axis=0, keepdims=True) / n_samples

  # hidden layer: push the error back through W2, then gate it by the
  # ReLU derivative mask
  delta_hidden = np.dot(delta_out, W2.T) * relu_derivative(Z1)
  grad_W1 = np.dot(X.T, delta_hidden) / n_samples
  grad_b1 = np.sum(delta_hidden, axis=0, keepdims=True) / n_samples

  return grad_W1, grad_b1, grad_W2, grad_b2

In [58]:
alpha = 0.1     # learning rate
epochs = 1000   # number of full-batch gradient steps
for epoch in range(epochs):
  # Forward pass on the whole dataset; the loss is measured before this
  # epoch's parameter update.
  Z1, A1, Z2, A2 = forward_pass(X, W1, b1, W2, b2)
  loss = compute_loss(Y, A2)

  # Gradients of the loss with respect to every parameter.
  dW1, db1, dW2, db2 = backward_pass(X, Y, Z1, A1, Z2, A2, W2)

  # Gradient-descent step, written in place so the existing parameter
  # arrays are updated rather than rebound.
  np.subtract(W1, alpha * dW1, out=W1)
  np.subtract(b1, alpha * db1, out=b1)
  np.subtract(W2, alpha * dW2, out=W2)
  np.subtract(b2, alpha * db2, out=b2)

  print(f"Epoch {epoch}: Loss = {loss:.4f}")



Epoch 0: Loss = 2.3565
Epoch 1: Loss = 2.2864
Epoch 2: Loss = 2.2232
Epoch 3: Loss = 2.1642
Epoch 4: Loss = 2.1078
Epoch 5: Loss = 2.0529
Epoch 6: Loss = 1.9989
Epoch 7: Loss = 1.9452
Epoch 8: Loss = 1.8918
Epoch 9: Loss = 1.8384
Epoch 10: Loss = 1.7854
Epoch 11: Loss = 1.7326
Epoch 12: Loss = 1.6805
Epoch 13: Loss = 1.6292
Epoch 14: Loss = 1.5789
Epoch 15: Loss = 1.5299
Epoch 16: Loss = 1.4824
Epoch 17: Loss = 1.4366
Epoch 18: Loss = 1.3925
Epoch 19: Loss = 1.3502
Epoch 20: Loss = 1.3099
Epoch 21: Loss = 1.2714
Epoch 22: Loss = 1.2348
Epoch 23: Loss = 1.2001
Epoch 24: Loss = 1.1672
Epoch 25: Loss = 1.1361
Epoch 26: Loss = 1.1066
Epoch 27: Loss = 1.0787
Epoch 28: Loss = 1.0523
Epoch 29: Loss = 1.0273
Epoch 30: Loss = 1.0037
Epoch 31: Loss = 0.9813
Epoch 32: Loss = 0.9601
Epoch 33: Loss = 0.9400
Epoch 34: Loss = 0.9209
Epoch 35: Loss = 0.9028
Epoch 36: Loss = 0.8856
Epoch 37: Loss = 0.8693
Epoch 38: Loss = 0.8537
Epoch 39: Loss = 0.8389
Epoch 40: Loss = 0.8248
Epoch 41: Loss = 0.8113
Ep