In [None]:
import jax
import jax.numpy as jnp
from typing import Tuple

In [None]:
cpu = jax.devices("cpu")[0] if jax.devices("cpu") else None
gpu = (jax.devices("METAL")[0] if jax.devices("METAL") else
       jax.devices("gpu")[0] if jax.devices("gpu") else None)

jax.config.update("jax_platform_name", "cpu")

key = jax.random.PRNGKey(0)

In [None]:
# Input patterns for the XOR problem; each row is one (a, b) pair.
# Training set: all four possible input combinations.
x_train = jnp.array([[0, 0], [0, 1], [1, 0], [1, 1]])

# Validation set: the same four combinations in a different order.
x_valid = jnp.array([[0, 0], [0, 1], [1, 1], [1, 0]])

# Test set: only the two mixed patterns, each repeated twice.
x_test = jnp.array([[1, 0], [0, 1], [1, 0], [0, 1]])

In [None]:
# Target labels as column vectors of shape (4, 1): one 0/1 label per input row.
y_train = jnp.array([0, 1, 1, 0]).reshape(-1, 1)
y_valid = jnp.array([0, 1, 0, 1]).reshape(-1, 1)
y_test = jnp.array([1, 1, 1, 1]).reshape(-1, 1)

In [None]:
def jnp_log(x: jnp.array) -> jnp.array:
    """Natural logarithm that never receives a non-positive argument.

    The input is first clamped to the range [1e-10, 1e+10], so zeros and
    negative values map to log(1e-10) instead of -inf / NaN.
    """
    return jnp.log(jnp.clip(x, 1e-10, 1e+10))

In [None]:
class Sigmoid:
    """Logistic sigmoid activation with a hand-written backward pass."""

    def __call__(self, x: jnp.array) -> jnp.array:
        """Evaluate the sigmoid of `x` element-wise.

        Both exponentials only ever receive non-positive arguments
        (min(x, 0) and -|x|), so neither can overflow for large |x|.
        """
        numerator = jnp.exp(jnp.minimum(x, 0))
        denominator = 1 + jnp.exp(-jnp.abs(x))
        return numerator / denominator

    def backward(self, dout: jnp.array, x: jnp.array) -> jnp.array:
        """Scale the upstream gradient `dout` by the sigmoid derivative at `x`.

        The derivative is s * (1 - s), where s is the sigmoid of `x`.
        """
        activation = self(x)
        return dout * (1.0 - activation) * activation


# Shared instance used by the network code in this notebook.
sigmoid = Sigmoid()

In [None]:
class Relu:
    """Rectified linear unit activation with a hand-written backward pass."""

    def __call__(self, x: jnp.array) -> jnp.array:
        """Return `x` with every negative entry replaced by zero."""
        return jnp.maximum(0, x)

    def backward(self, dout: jnp.array, x: jnp.array) -> jnp.array:
        """Pass the upstream gradient `dout` through only where `x` is strictly positive."""
        # 1 where the unit was active, 0 elsewhere, in the gradient's dtype.
        active_mask = (x > 0).astype(dout.dtype)
        return dout * active_mask


# Shared instance used by the network code in this notebook.
relu = Relu()

In [None]:
def binary_cross_entropy(y: jnp.array, y_hat: jnp.array) -> jnp.array:
    """Mean binary cross-entropy between targets `y` and predictions `y_hat`.

    Logarithms go through `jnp_log`, which clamps its argument, so predictions
    of exactly 0 or 1 yield a finite loss.
    """
    positive_term = y * jnp_log(y_hat)
    negative_term = (1 - y) * jnp_log(1 - y_hat)
    return -jnp.mean(positive_term + negative_term)

In [None]:
# Layer widths of the network: input features, hidden units, output units.
input_dim, hidden_dim, output_dim = 2, 8, 1

In [None]:
# Give each weight matrix its own subkey. Passing the same PRNG key to two
# sampling calls makes them draw from the same random stream, so W1 and W2
# would not be independent. `key` itself is not rebound, so re-running this
# cell reproduces the same initial parameters.
w1_key, w2_key = jax.random.split(key)

# Small uniform weights in [-0.08, 0.08) and zero biases, all float32.
W1 = jax.random.uniform(
    w1_key, shape=(input_dim, hidden_dim), dtype=jnp.float32, minval=-0.08, maxval=0.08
)
b1 = jnp.zeros((hidden_dim,), dtype=jnp.float32)
W2 = jax.random.uniform(
    w2_key, shape=(hidden_dim, output_dim), dtype=jnp.float32, minval=-0.08, maxval=0.08
)
b2 = jnp.zeros((output_dim,), dtype=jnp.float32)

In [None]:
def train(x: jnp.array, y: jnp.array, eta=0.05) -> jnp.array:
    """Run one gradient-descent step on the batch (x, y) and return the loss.

    The module-level parameters W1, b1, W2 and b2 are rebound to their updated
    values. The returned loss is the binary cross-entropy computed with the
    parameters as they were *before* this update.

    Args:
        x: Input batch; the first axis is the batch dimension.
        y: Targets matching the network output.
        eta: Learning rate applied to every parameter.
    """
    global W1, b1, W2, b2

    n_samples = x.shape[0]

    # Forward pass: affine -> ReLU -> affine -> sigmoid.
    hidden_pre = jnp.dot(x, W1) + b1
    hidden_act = relu(hidden_pre)
    output_pre = jnp.dot(hidden_act, W2) + b2
    y_hat = sigmoid(output_pre)
    loss = binary_cross_entropy(y, y_hat)

    # Backward pass. For a sigmoid output with binary cross-entropy the
    # per-sample error at the output pre-activation simplifies to y_hat - y.
    output_delta = y_hat - y
    hidden_delta = relu.backward(dout=jnp.dot(output_delta, W2.T), x=hidden_pre)

    # Batch-averaged gradients; all are computed before any parameter changes.
    grad_W2 = jnp.dot(hidden_act.T, output_delta) / n_samples
    grad_b2 = jnp.mean(output_delta, axis=0)
    grad_W1 = jnp.dot(x.T, hidden_delta) / n_samples
    grad_b1 = jnp.mean(hidden_delta, axis=0)

    # Plain gradient-descent update.
    W1 = W1 - eta * grad_W1
    b1 = b1 - eta * grad_b1
    W2 = W2 - eta * grad_W2
    b2 = b2 - eta * grad_b2

    return loss

In [None]:
def valid(x: jnp.array, y: jnp.array) -> Tuple[jnp.array, jnp.array]:
    """Evaluate the current network on (x, y) without changing any parameter.

    Returns:
        A `(loss, y_hat)` pair: the binary cross-entropy against `y` and the
        sigmoid outputs of the network for `x`.
    """
    # Forward pass only, using the current module-level W1, b1, W2, b2.
    hidden_act = relu(jnp.dot(x, W1) + b1)
    y_hat = sigmoid(jnp.dot(hidden_act, W2) + b2)
    return binary_cross_entropy(y, y_hat), y_hat

In [None]:
epochs = 3000

# Full-batch training: every epoch performs one update on the whole training set.
for epoch in range(epochs):
    loss = train(x=x_train, y=y_train)

# Evaluate on the validation inputs and show the predicted probabilities.
loss, y_pred = valid(x=x_valid, y=y_valid)
print(y_pred)

In [None]:
# Score the held-out test inputs; only the predictions are needed here.
_, y_pred = valid(x_test, y_test)

# Show predictions and ground truth one after the other for a visual comparison.
print("Predictions:", y_pred, "\nTrue labels:", y_test, sep="\n")