In [None]:
import jax
import jax.numpy as jnp
import tensorflow as tf

In [None]:
cpu = jax.devices("cpu")[0] if jax.devices("cpu") else None
gpu = jax.devices("METAL")[0] if jax.devices("METAL") else None
jax.config.update("jax_platform_name", "cpu")

key = jax.random.PRNGKey(0)

In [None]:
# Load MNIST; per the Keras docs, images are uint8 arrays (N, 28, 28) and labels are integer class ids.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Count the classes once and reuse the value for reporting and one-hot encoding.
num_classes = len(jnp.unique(y_train))

print(f"Data range: {x_train.min()} to {x_train.max()}")
print(f"Number of classes: {num_classes}")

# Flatten every image into one feature vector per sample: (N, ...) -> (N, features).
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)

# One-hot encode the integer labels by selecting rows of the identity matrix.
y_train = jnp.eye(num_classes)[y_train]
y_test = jnp.eye(num_classes)[y_test]

print(f"Training data shape: {x_train.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Test data shape: {x_test.shape}")
print(f"Test labels shape: {y_test.shape}")


In [None]:
# Cast pixel intensities to float32 and rescale them from [0, 255] into [0, 1].
x_train, x_test = (
    x_train.astype('float32') / 255.0,
    x_test.astype('float32') / 255.0,
)

In [None]:
# Ensure every dataset split is a JAX array. `jnp.asarray` converts the NumPy
# image arrays but, unlike `jnp.array` (copy=True by default), does not make a
# redundant copy of inputs that are already JAX arrays — presumably the one-hot
# labels produced by the loading cell.
x_train = jnp.asarray(x_train)
y_train = jnp.asarray(y_train)
x_test = jnp.asarray(x_test)
y_test = jnp.asarray(y_test)

In [None]:
def jnp_log(x: jnp.ndarray, eps: float = 1e-10, max_value: float = 1e+10) -> jnp.ndarray:
    """Numerically safe natural logarithm.

    ``x`` is clipped into ``[eps, max_value]`` before ``jnp.log`` is applied,
    so zeros (e.g. saturated softmax probabilities) yield a large but finite
    negative value instead of ``-inf``.

    Args:
        x: Input array.
        eps: Lower clipping bound (default 1e-10).
        max_value: Upper clipping bound (default 1e+10).

    Returns:
        ``log(clip(x, eps, max_value))`` elementwise.
    """
    return jnp.log(jnp.clip(x, eps, max_value))

In [None]:
class Relu:
    """Rectified linear unit with a hand-written backward pass."""

    @staticmethod
    def forward(x: jnp.array) -> jnp.array:
        """Elementwise max(0, x)."""
        activated = jnp.maximum(0, x)
        return activated

    @staticmethod
    def backward(dout: jnp.array, x: jnp.array) -> jnp.array:
        """Propagate ``dout`` only through positions where the input ``x`` was positive."""
        mask = x > 0
        return dout * mask.astype(dout.dtype)

In [None]:
class Softmax:
    """Softmax over the last axis, with a hand-written backward pass."""

    @staticmethod
    def forward(x: jnp.ndarray) -> jnp.ndarray:
        """Return softmax(x) along the last axis.

        The per-row maximum is subtracted before exponentiating so large logits
        do not overflow; this does not change the result mathematically.
        """
        x_max = jnp.max(x, axis=-1, keepdims=True)
        exp_x = jnp.exp(x - x_max)
        return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)

    @staticmethod
    def backward(dout: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
        """Vector-Jacobian product of softmax.

        Given ``dout`` = dL/ds for s = softmax(x), returns dL/dx using
        ``s * (dout - sum(dout * s))`` along the last axis. The softmax is
        evaluated once and reused for both factors.
        """
        s = Softmax.forward(x)
        return s * (dout - jnp.sum(dout * s, axis=-1, keepdims=True))

In [None]:
def cross_entropy(y_hat: jnp.array, y: jnp.array) -> jnp.array:
    """Mean categorical cross-entropy between predictions and targets.

    Each sample contributes ``-sum(y * log(y_hat))`` over the last axis, with
    the log taken via the clipped ``jnp_log`` helper; the result is averaged
    over all remaining axes.
    """
    per_sample_log_likelihood = jnp.sum(y * jnp_log(y_hat), axis=-1)
    return -jnp.mean(per_sample_log_likelihood)

In [None]:
# Layer widths. Input and output sizes are read from the prepared data
# (flattened images and one-hot labels) so they cannot drift out of sync with it;
# for MNIST they come out as 784 and 10.
input_dim = x_train.shape[1]
hidden_dim = 256
output_dim = y_train.shape[1]

In [None]:
# Draw each weight matrix from its own subkey: reusing a single PRNG key for
# every tensor makes them come from the same random stream (correlated init).
# `key` itself is not rebound, so re-running this cell reproduces the same init.
k_w1, k_w2 = jax.random.split(key)

# He scaling (std = sqrt(2 / fan_in)) for the ReLU layer and LeCun scaling
# (std = sqrt(1 / fan_in)) for the softmax layer keep pre-activations O(1);
# unit-variance weights over hundreds of inputs can saturate the softmax.
# Biases start at zero.
W1 = jax.random.normal(k_w1, shape=(input_dim, hidden_dim)) * jnp.sqrt(2.0 / input_dim)
b1 = jnp.zeros((hidden_dim,))
W2 = jax.random.normal(k_w2, shape=(hidden_dim, output_dim)) * jnp.sqrt(1.0 / hidden_dim)
b2 = jnp.zeros((output_dim,))

In [None]:
epochs = 1000
eta = 0.1


@jax.jit
def _train_step(W1, b1, W2, b2, x, y, eta):
    """Run one full-batch gradient-descent step of the 2-layer MLP.

    Returns the updated ``(W1, b1, W2, b2)`` followed by the cross-entropy loss
    computed with the parameters *before* the update. Compiled with
    ``jax.jit`` so the forward/backward ops run as one fused XLA program rather
    than being dispatched op-by-op every epoch. The data are passed in as
    arguments (not closed over) so they are not baked into the compiled program
    as constants.
    """
    n = x.shape[0]

    # forward
    u1 = jnp.dot(x, W1) + b1
    h1 = Relu.forward(u1)
    u2 = jnp.dot(h1, W2) + b2
    y_hat = Softmax.forward(u2)
    loss = cross_entropy(y_hat, y)

    # backward: for softmax followed by cross-entropy, dL/du2 per sample is y_hat - y
    delta_2 = y_hat - y
    delta_1 = Relu.backward(dout=jnp.dot(delta_2, W2.T), x=u1)

    # gradients of the batch-mean loss (division by n / mean over the sample axis)
    dW1 = jnp.dot(x.T, delta_1) / n
    db1 = jnp.mean(delta_1, axis=0)
    dW2 = jnp.dot(h1.T, delta_2) / n
    db2 = jnp.mean(delta_2, axis=0)

    # parameter update
    return W1 - eta * dW1, b1 - eta * db1, W2 - eta * dW2, b2 - eta * db2, loss


# NOTE: parameters are updated in the notebook namespace, so re-running this
# cell continues training from the current weights (re-run the init cell to restart).
for epoch in range(1, epochs + 1):
    W1, b1, W2, b2, loss = _train_step(W1, b1, W2, b2, x_train, y_train, eta)

    if epoch % 100 == 0:
        print(f"epoch: {epoch}, loss: {loss}")


In [None]:
# Forward pass over the held-out test set with the trained parameters.
u1 = x_test @ W1 + b1
h1 = Relu.forward(u1)
u2 = h1 @ W2 + b2
y_hat = Softmax.forward(u2)

# Accuracy: fraction of samples whose most probable class matches the one-hot label.
y_pred = y_hat.argmax(axis=-1)
y_true = y_test.argmax(axis=-1)
acc = (y_pred == y_true).mean()
print(f"Test accuracy: {acc * 100:.2f}%")