In [None]:
import jax
import jax.numpy as jnp
from typing import Tuple

In [None]:
cpu = jax.devices("cpu")[0] if jax.devices("cpu") else None
gpu = (jax.devices("METAL")[0] if jax.devices("METAL") else
       jax.devices("gpu")[0] if jax.devices("gpu") else None)

jax.config.update("jax_platform_name", "cpu")

key = jax.random.PRNGKey(0)

In [None]:
# Logical-AND truth table: each row of an x_* array is an input pair (x1, x2)
# and the matching row of the y_* array is the target x1 AND x2.
x_train = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])  # (4, 2)
x_valid = jnp.array([[1, 1], [0, 1], [0, 0], [1, 0]])  # (4, 2)
x_test = jnp.array([[1, 0], [0, 1], [1, 0], [0, 1]])   # (4, 2)

y_train = jnp.array([[0], [0], [0], [1]])  # (4, 1)
y_valid = jnp.array([[1], [0], [0], [0]])  # (4, 1)
y_test = jnp.array([[0], [0], [0], [0]])   # (4, 1)

In [None]:
def jnp_log(x: jnp.array) -> jnp.array:
    """Natural logarithm with its argument clamped to [1e-10, 1e10].

    The lower clamp keeps log(0) from producing -inf, e.g. when
    binary_cross_entropy evaluates log(y_hat) or log(1 - y_hat) on a
    saturated probability.
    """
    lower, upper = 1e-10, 1e+10
    clamped = jnp.clip(x, lower, upper)
    return jnp.log(clamped)

In [None]:
def sigmoid(x: jnp.array) -> jnp.array:
    """Numerically stable logistic function 1 / (1 + exp(-x)).

    Both exponentials only receive non-positive arguments, so neither can
    overflow: for x >= 0 this is exp(0) / (1 + exp(-x)), and for x < 0 it is
    exp(x) / (1 + exp(x)).
    """
    decay = jnp.exp(-jnp.abs(x))             # exp(-|x|), in (0, 1]
    numerator = jnp.exp(jnp.minimum(x, 0))   # exp(min(x, 0)), in (0, 1]
    return numerator / (1 + decay)

In [None]:
def binary_cross_entropy(y: jnp.array, y_hat: jnp.array) -> jnp.array:
    """Mean binary cross-entropy between targets `y` and predicted probabilities `y_hat`.

    Computes -mean(y * log(y_hat) + (1 - y) * log(1 - y_hat)). The logs go
    through `jnp_log`, which clamps its argument, so a saturated probability
    yields a large finite loss instead of inf/nan.
    """
    positive_term = y * jnp_log(y_hat)
    negative_term = (1 - y) * jnp_log(1 - y_hat)
    return -jnp.mean(positive_term + negative_term)

In [None]:
# Weights: (in_dim=2, out_dim=1), drawn uniformly from [-0.08, 0.08) with the notebook's key.
W = jax.random.uniform(key, shape=(2, 1), minval=-0.08, maxval=0.08).astype(jnp.float32)
# Bias: one entry per output unit, starting at zero.
b = jnp.zeros((1,))

In [None]:
def train(x: jnp.array, y: jnp.array, eta=1.0) -> jnp.array:
    """Run one full-batch gradient-descent step of logistic regression.

    Reads and rebinds the module-level parameters ``W`` (in_dim, out_dim) and
    ``b`` (out_dim,). Because the state is global, re-running the training
    cell continues from the current parameters rather than from the
    initialization.

    Args:
        x: inputs, shape (batch_size, in_dim).
        y: binary targets, shape (batch_size, out_dim).
        eta: learning rate.

    Returns:
        The binary cross-entropy loss of the parameters *before* this update.
    """
    global W, b

    # forward
    y_hat = sigmoid(jnp.dot(x, W) + b)          # (batch_size, out_dim)
    loss = binary_cross_entropy(y, y_hat)

    # backward: for sigmoid followed by BCE, the per-element gradient of the
    # loss w.r.t. the logits is y_hat - y (before averaging)
    delta = y_hat - y                           # (batch_size, out_dim)

    # calculate gradients
    batch_size = x.shape[0]
    dW = jnp.dot(x.T, delta) / batch_size       # (in_dim, out_dim)
    # No keepdims: db must match b's (out_dim,) shape; a (1, out_dim) db would
    # silently broadcast b to (1, out_dim) in `b -= eta * db` below.
    db = jnp.mean(delta, axis=0)                # (out_dim,)

    # update parameters
    W -= eta * dW
    b -= eta * db

    return loss

In [None]:
def valid(x: jnp.array, y: jnp.array) -> Tuple[jnp.array, jnp.array]:
    """Evaluate the current global parameters ``W`` and ``b`` without updating them.

    Returns:
        (loss, y_hat): the mean binary cross-entropy on (x, y) and the
        predicted probabilities sigmoid(x @ W + b).
    """
    logits = jnp.dot(x, W) + b
    probabilities = sigmoid(logits)
    return binary_cross_entropy(y, probabilities), probabilities

In [None]:
epochs = 1000

# NOTE: `train` mutates the global W and b, so re-running this cell resumes
# training from the current parameters instead of starting over.
for epoch in range(epochs):
    # Keep the training loss separately; it used to be overwritten by the
    # validation loss on the next line and never reported.
    train_loss = train(x_train, y_train)
    valid_loss, y_pred = valid(x_valid, y_valid)

    # Log the first epoch and then every 10th epoch (10, 20, ...).
    if epoch % 10 == 9 or epoch == 0:
        print(f"EPOCH: {epoch + 1}, Train Loss: {train_loss}, Valid Loss: {valid_loss}")

In [None]:
# Predicted probabilities on the held-out AND inputs; the test loss is discarded.
_, y_pred = valid(x_test, y_test)

print("Predictions:", y_pred, sep="\n")
print("\nTrue Labels:", y_test, sep="\n")