In [None]:
import jax
import jax.numpy as jnp

In [None]:
def first_device(platform):
    """Return the first JAX device for ``platform``, or None if it is unavailable.

    ``jax.devices(platform)`` raises ``RuntimeError`` for a backend that is not
    registered or failed to initialize (it does not return an empty list), so
    the lookup is wrapped in ``try`` to make a missing backend map to None.
    """
    try:
        devices = jax.devices(platform)
    except RuntimeError:
        return None
    return devices[0] if devices else None


# Optional handles to specific devices; `gpu` is None on machines without Metal.
cpu = first_device("cpu")
gpu = first_device("METAL")
# Make CPU the default platform for computations in this notebook.
jax.config.update("jax_platform_name", "cpu")

# Root PRNG key; derive per-use subkeys with jax.random.split rather than reusing it.
key = jax.random.PRNGKey(0)

In [None]:
# Every combination of two binary inputs, in truth-table order:
# (0, 0), (0, 1), (1, 0), (1, 1).
x_train = jnp.array([[x1, x2] for x1 in (0, 1) for x2 in (0, 1)])

In [None]:
y_train = jnp.array([0, 0, 0, 1])[:, None]  # AND targets as a (4, 1) column

In [None]:
# The pair [1, 0], [0, 1] repeated twice.
# NOTE(review): these rows are all negative-class inputs for AND; the (1, 1)
# case never appears, so this set cannot show the positive class is learned.
x_test = jnp.array([[1, 0], [0, 1]] * 2)

In [None]:
y_test = jnp.array([[0]] * 4)  # four negative labels, one per x_test row

In [None]:
def jnp_log(x: jnp.array) -> jnp.array:
    """Natural logarithm of ``x`` after clamping it into [1e-10, 1e10].

    The lower bound keeps a zero input from producing -inf, so a saturated
    probability yields a large but finite log value.
    """
    return jnp.log(jnp.clip(x, 1e-10, 1e10))

In [None]:
def sigmoid(x: jnp.array) -> jnp.array:
    """Numerically stable logistic function 1 / (1 + exp(-x)).

    Both exponentials receive non-positive arguments, so neither can overflow:
    for x >= 0 the ratio is 1 / (1 + exp(-x)); for x < 0 it is
    exp(x) / (1 + exp(x)).
    """
    numerator = jnp.exp(jnp.minimum(x, 0))
    denominator = 1 + jnp.exp(-jnp.abs(x))
    return numerator / denominator

In [None]:
def binary_cross_entropy(y: jnp.array, y_hat: jnp.array) -> jnp.array:
    """Mean binary cross-entropy between targets ``y`` and probabilities ``y_hat``.

    Logs go through ``jnp_log``, which clamps its input, so probabilities of
    exactly 0 or 1 give a large finite penalty instead of inf/nan.
    """
    log_p = jnp_log(y_hat)
    log_not_p = jnp_log(1 - y_hat)
    per_element = y * log_p + (1 - y) * log_not_p
    return -jnp.mean(per_element)

In [None]:
# Derive two subkeys so W and b come from independent random streams; passing
# the same `key` to both calls would make the draws statistically dependent.
# `key` itself is left untouched so re-running this cell gives the same init.
key_W, key_b = jax.random.split(key)
W = jax.random.normal(key_W, shape=(2, 1))  # weights: (n_features=2, 1)
b = jax.random.normal(key_b, shape=(1,))    # bias, broadcast over the batch

In [None]:
epochs = 10000
eta = 0.1         # learning rate
log_every = 1000  # print the training loss every `log_every` epochs

# Full-batch gradient descent: the batch is the whole training set (loop-invariant).
batch_size = y_train.shape[0]

# NOTE: this cell updates the global W and b, so re-running it continues
# training from the current values; re-run the initialization cell to restart.
for epoch in range(1, epochs+1):
    # Forward pass: probabilities of shape (batch_size, 1).
    y_hat = sigmoid(jnp.dot(x_train, W) + b)

    # For sigmoid + mean BCE, the gradient w.r.t. the logits is
    # (y_hat - y) / batch_size; the division happens in dW / db below.
    delta = y_hat - y_train

    dW = jnp.dot(x_train.T, delta) / batch_size
    # Average over the batch axis only, giving shape (1,) to match b.
    # (keepdims=True would give (1, 1) and silently reshape b to (1, 1).)
    db = jnp.mean(delta, axis=0)

    if epoch % log_every == 0:
        # .item() forces a device-to-host sync, so only pay for it when logging.
        # The loss is for the parameters before this epoch's update.
        loss = binary_cross_entropy(y_train, y_hat).item()
        print(f"epoch: {epoch}, loss: {loss}")

    W -= eta * dW
    b -= eta * db

In [None]:
# Forward pass over the test inputs with the trained W and b.
y_hat = sigmoid(x_test @ W + b)
# One print call; sep="\n" reproduces the four separate print() lines exactly.
print("Predictions:", y_hat, "\nTrue labels:", y_test, sep="\n")