In [None]:
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import jax
import jax.numpy as jnp
from keras.datasets import mnist
from typing import Tuple

In [None]:
def _first_device(backend: str):
    """Return the first device of `backend`, or None if that backend is unavailable."""
    try:
        devices = jax.devices(backend)
    except RuntimeError:
        # jax.devices raises RuntimeError for a backend that is unknown or
        # failed to initialise, instead of returning an empty list.
        return None
    return devices[0] if devices else None


# Probe the available devices without crashing on machines lacking an accelerator.
cpu = _first_device("cpu")
gpu = _first_device("METAL")
if gpu is None:
    gpu = _first_device("gpu")

# Run computations on the CPU platform regardless of which accelerators were found.
jax.config.update("jax_platform_name", "cpu")

# Fixed PRNG seed used for parameter initialisation.
key = jax.random.PRNGKey(0)

In [None]:
# Load the MNIST train/test splits.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Count the distinct labels once and reuse the value for reporting and encoding.
num_classes = len(jnp.unique(y_train))

print(f"Data range: {x_train.min()} to {x_train.max()}")
print(f"Number of classes: {num_classes}")

# Flatten every sample into a single feature vector (all trailing axes merged).
x_train = x_train.reshape(x_train.shape[0], -1)
x_test = x_test.reshape(x_test.shape[0], -1)

# One-hot encode the labels by selecting rows of an identity matrix.
y_train = jnp.eye(num_classes)[y_train]
y_test = jnp.eye(num_classes)[y_test]

# Hold out 10,000 training samples for validation. The fixed random_state makes
# the split (and therefore every downstream metric) reproducible across runs.
x_train, x_valid, y_train, y_valid = train_test_split(
    x_train, y_train, test_size=10000, random_state=0
)

print(f"Training data shape: {x_train.shape}")
print(f"Training labels shape: {y_train.shape}")
print(f"Validation data shape: {x_valid.shape}")
print(f"Validation labels shape: {y_valid.shape}")
print(f"Test data shape: {x_test.shape}")
print(f"Test labels shape: {y_test.shape}")

In [None]:
# Cast each input split to float32 and divide by 255
# (presumably mapping 0-255 pixel values into [0, 1] — confirm against the data range printed above).
x_train, x_valid, x_test = (
    split.astype('float32') / 255.0
    for split in (x_train, x_valid, x_test)
)

In [None]:
# Convert every input/label split to a JAX array.
x_train, y_train, x_valid, y_valid, x_test, y_test = (
    jnp.array(split)
    for split in (x_train, y_train, x_valid, y_valid, x_test, y_test)
)

In [None]:
def jnp_log(x: jnp.array) -> jnp.array:
    """Element-wise natural log of `x` after clamping it to [1e-10, 1e+10].

    The clamp keeps zero (or negative) inputs from producing -inf/nan.
    """
    return jnp.log(jnp.clip(x, 1e-10, 1e+10))

In [None]:
def softmax(x: jnp.array) -> jnp.array:
    """Softmax over the last axis of `x`.

    The per-row maximum is subtracted before exponentiating so large inputs
    do not overflow; the result is unchanged by this shift.
    """
    unnormalized = jnp.exp(x - jnp.max(x, axis=-1, keepdims=True))
    return unnormalized / jnp.sum(unnormalized, axis=-1, keepdims=True)

In [None]:
def cross_entropy(y_hat: jnp.array, y: jnp.array) -> jnp.array:
    """Mean cross-entropy between predictions `y_hat` and targets `y`.

    Sums ``y * log(y_hat)`` over the last axis for each sample (using the
    clamped log helper), averages over the remaining axes, and negates.
    """
    per_sample = jnp.sum(y * jnp_log(y_hat), axis=-1)
    return -jnp.mean(per_sample)

In [None]:
# Weight matrix drawn uniformly between -0.08 and 0.08; bias vector starts at zero.
# (784 inputs x 10 outputs — presumably flattened pixels x classes; confirm against the data cell.)
W = jax.random.uniform(key, shape=(784, 10), minval=-0.08, maxval=0.08).astype('float32')
b = jnp.zeros((10,)).astype('float32')

In [None]:
def train(x: jnp.array, y: jnp.array, eta=1.0) -> jnp.array:
    """Run one gradient-descent step of softmax regression on the batch (x, y).

    Updates the module-level parameters ``W`` and ``b`` and returns the
    cross-entropy loss measured before the update.

    Args:
        x: Input batch, one sample per row; multiplied against ``W``.
        y: Targets with the same shape as the softmax output
            (one-hot rows in this notebook).
        eta: Learning rate applied to both gradients.

    Returns:
        Scalar mean cross-entropy loss of the pre-update parameters on (x, y).
    """
    global W, b

    # Forward pass.
    y_hat = softmax(jnp.dot(x, W) + b)
    loss = cross_entropy(y_hat, y)

    # Gradient of the mean cross-entropy w.r.t. the logits (for targets whose rows sum to 1).
    delta = y_hat - y

    # Use the batch that was passed in, not the global training set, so the
    # function is correct for any (x, y) pair, including mini-batches.
    batch_size = x.shape[0]
    dW = jnp.dot(x.T, delta) / batch_size
    # No keepdims here: db must keep the same shape as b, otherwise the update
    # below would broadcast b from (n,) to (1, n).
    db = jnp.mean(delta, axis=0)

    W = W - eta * dW
    b = b - eta * db

    return loss

In [None]:
def valid(x: jnp.array, y: jnp.array) -> Tuple[jnp.array, jnp.array]:
    """Evaluate the current global parameters ``W`` and ``b`` on (x, y).

    Returns:
        A ``(loss, y_hat)`` pair: the mean cross-entropy loss and the softmax
        outputs for every row of `x`. The parameters are not modified.
    """
    probabilities = softmax(jnp.dot(x, W) + b)
    return cross_entropy(probabilities, y), probabilities

In [None]:
epochs = 100

for epoch in range(epochs):
    # One full-batch parameter update; the training loss it returns is not reported.
    train(x_train, y_train)

    # Score the updated parameters on the validation split.
    loss, y_pred = valid(x_valid, y_valid)

    # Only report on the first epoch and on every tenth epoch.
    if epoch != 0 and epoch % 10 != 9:
        continue

    print('EPOCH: {}, Valid Cost: {:.3f}, Valid Accuracy: {:.3f}'.format(
        epoch + 1,
        loss,
        accuracy_score(jnp.argmax(y_valid, axis=1), jnp.argmax(y_pred, axis=1)),
    ))

In [None]:
# Forward pass of the trained parameters on the test split.
y_hat = softmax(jnp.dot(x_test, W) + b)

# Reduce predicted probabilities and one-hot targets to class indices.
y_pred = y_hat.argmax(axis=-1)
y_true = y_test.argmax(axis=-1)

# Fraction of test samples whose predicted class matches the target.
acc = (y_pred == y_true).mean()
print(f"Test accuracy: {acc * 100:.2f}%")