In [1]:
import jax
import jax.numpy as jnp
import tensorflow as tf

In [2]:
cpu = jax.devices("cpu")[0] if jax.devices("cpu") else None
gpu = jax.devices("METAL")[0] if jax.devices("METAL") else None
jax.config.update("jax_platform_name", "cpu")

key = jax.random.PRNGKey(0)

I0000 00:00:1757988873.080649 23686962 service.cc:145] XLA service 0x158e55810 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1757988873.080786 23686962 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1757988873.082262 23686962 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1757988873.082302 23686962 mps_client.cc:384] XLA backend will use up to 11452776448 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M3

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Number of distinct labels, computed once and reused for the report and the one-hot encoding.
num_classes = len(jnp.unique(y_train))

print(f"Data range: {x_train.min()} to {x_train.max()}")
print(f"Number of classes: {num_classes}")

# Flatten every sample into a single feature vector (one row per image).
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)

# One-hot encode the integer labels by selecting rows of the identity matrix.
y_train = jnp.eye(num_classes)[y_train]
y_test = jnp.eye(num_classes)[y_test]

print(f"Training data shape: {x_train.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Test data shape: {x_test.shape}")
print(f"Test labels shape: {y_test.shape}")


Data range: 0 to 255
Number of classes: 10
Training data shape: (60000, 784)
Training labels shape: (60000, 10)
Test data shape: (10000, 784)
Test labels shape: (10000, 10)


In [4]:
# Scale the 0..255 pixel intensities into [0, 1] as float32.
# NOTE: not idempotent -- re-running this cell divides by 255 a second time.
x_train, x_test = (
    images.astype('float32') / 255.0 for images in (x_train, x_test)
)

In [5]:
# Materialise all four splits as JAX arrays (jnp.array copies its input).
x_train, y_train, x_test, y_test = map(jnp.array, (x_train, y_train, x_test, y_test))

In [6]:
def jnp_log(x: jnp.ndarray, eps: float = 1e-10, max_value: float = 1e10) -> jnp.ndarray:
    """Numerically safe natural logarithm.

    Clips ``x`` into ``[eps, max_value]`` before taking the log, so zero or
    negative inputs (e.g. probabilities that underflowed to 0) give the finite
    value ``log(eps)`` instead of ``-inf``/``nan``.

    Args:
        x: input array.
        eps: lower clipping bound; ``log(eps)`` is the smallest value returned.
        max_value: upper clipping bound.

    Returns:
        ``jnp.log(jnp.clip(x, eps, max_value))``, with the same shape as ``x``.
    """
    return jnp.log(jnp.clip(x, eps, max_value))

In [7]:
class Relu:
    """Rectified linear unit, exposed as a stateless forward/backward pair."""

    @staticmethod
    def forward(x: jnp.ndarray) -> jnp.ndarray:
        """Return ``max(0, x)`` element-wise."""
        return jnp.maximum(0, x)

    @staticmethod
    def backward(dout: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
        """Back-propagate ``dout`` through a ReLU evaluated at pre-activation ``x``.

        The local derivative is 1 where ``x > 0`` and 0 elsewhere (including
        ``x == 0``). The mask is cast to ``dout``'s dtype before multiplying.
        """
        passes = x > 0
        gate = passes.astype(dout.dtype)
        return dout * gate

In [8]:
class Softmax:
    """Softmax over the last axis, exposed as a stateless forward/backward pair."""

    @staticmethod
    def forward(x: jnp.ndarray) -> jnp.ndarray:
        """Return ``exp(x) / sum(exp(x))`` along the last axis.

        The per-row maximum is subtracted first. This does not change the
        result mathematically but keeps ``exp`` from overflowing on large logits.
        """
        x_shifted = x - jnp.max(x, axis=-1, keepdims=True)
        exp_x = jnp.exp(x_shifted)
        return exp_x / jnp.sum(exp_x, axis=-1, keepdims=True)

    @staticmethod
    def backward(dout: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
        """Vector-Jacobian product of softmax at ``x`` with upstream gradient ``dout``.

        With ``s = softmax(x)`` the Jacobian is ``diag(s) - s s^T``. Row-wise
        this gives ``s * (dout - sum(dout * s))``.
        """
        # Softmax is evaluated once and reused in both terms.
        s = Softmax.forward(x)
        return s * (dout - jnp.sum(dout * s, axis=-1, keepdims=True))

In [9]:
def cross_entropy(y_hat: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Mean categorical cross-entropy between predictions and targets.

    Args:
        y_hat: predicted class probabilities; the last axis indexes classes.
        y: targets broadcast-compatible with ``y_hat`` (one-hot in this notebook).

    Returns:
        Scalar ``-mean_over_rows(sum_over_classes(y * log(y_hat)))``. It uses the
        clipped ``jnp_log``, so zero probabilities stay finite.
    """
    log_probs = jnp_log(y_hat)
    per_sample_nll = -jnp.sum(y * log_probs, axis=-1)
    return jnp.mean(per_sample_nll)

In [10]:
# Layer widths of the 3-layer MLP: 784 = flattened image size, 10 = number of classes.
input_dim, hidden_dim, output_dim = 784, 256, 10

In [11]:
# JAX keys are not stateful: reusing `key` for every draw reproduces the same
# random stream (e.g. two same-shaped biases would come out identical).
# Derive an independent subkey per weight matrix instead.
k1, k2, k3 = jax.random.split(key, 3)

# He initialisation (std = sqrt(2 / fan_in)) for the ReLU layers. Unit-variance
# N(0, 1) weights make pre-activations grow with fan-in (784 / 256 here) and
# saturate the softmax. Biases start at zero.
W1 = jax.random.normal(k1, shape=(input_dim, hidden_dim)) * jnp.sqrt(2.0 / input_dim)
b1 = jnp.zeros((hidden_dim,))
W2 = jax.random.normal(k2, shape=(hidden_dim, hidden_dim)) * jnp.sqrt(2.0 / hidden_dim)
b2 = jnp.zeros((hidden_dim,))
W3 = jax.random.normal(k3, shape=(hidden_dim, output_dim)) * jnp.sqrt(2.0 / hidden_dim)
b3 = jnp.zeros((output_dim,))

In [None]:
epochs = 1000
eta = 0.1
log_every = 100  # print the training loss after epoch 1 and every `log_every` epochs
batch_size = x_train.shape[0]  # full-batch gradient descent: each step uses every sample


def _train_step(params, x, y, eta):
    """Run one full-batch gradient-descent step of the 3-layer ReLU/softmax MLP.

    Args:
        params: tuple ``(W1, b1, W2, b2, W3, b3)`` of current parameters.
        x: input batch, one row per sample.
        y: one-hot targets, one row per sample.
        eta: learning rate.

    Returns:
        ``(new_params, y_hat, loss)``: the updated parameter tuple (same order),
        and the softmax predictions and cross-entropy computed with the
        *incoming* parameters.
    """
    W1, b1, W2, b2, W3, b3 = params
    n = x.shape[0]

    # forward
    u1 = jnp.dot(x, W1) + b1
    h1 = Relu.forward(u1)
    u2 = jnp.dot(h1, W2) + b2
    h2 = Relu.forward(u2)
    u3 = jnp.dot(h2, W3) + b3
    y_hat = Softmax.forward(u3)

    # backward. For softmax followed by cross-entropy, the gradient w.r.t. the
    # logits collapses to (y_hat - y). The 1/n of the mean loss is applied below.
    delta_3 = y_hat - y
    delta_2 = Relu.backward(dout=jnp.dot(delta_3, W3.T), x=u2)
    delta_1 = Relu.backward(dout=jnp.dot(delta_2, W2.T), x=u1)

    # gradients, in the same order as `params`
    grads = (
        jnp.dot(x.T, delta_1) / n, jnp.mean(delta_1, axis=0),
        jnp.dot(h1.T, delta_2) / n, jnp.mean(delta_2, axis=0),
        jnp.dot(h2.T, delta_3) / n, jnp.mean(delta_3, axis=0),
    )
    new_params = tuple(p - eta * g for p, g in zip(params, grads))
    return new_params, y_hat, cross_entropy(y_hat, y)


# Compile the whole step into one XLA program. Otherwise every forward/backward
# op would be dispatched eagerly from Python on each epoch.
train_step = jax.jit(_train_step)

for epoch in range(1, epochs + 1):
    (W1, b1, W2, b2, W3, b3), y_hat, loss = train_step(
        (W1, b1, W2, b2, W3, b3), x_train, y_train, eta
    )
    if epoch == 1 or epoch % log_every == 0:
        print(f"epoch {epoch:>4}/{epochs}  train loss {float(loss):.4f}")


(60000, 10)
