In [None]:
import jax
import jax.numpy as jnp

In [None]:
cpu = jax.devices("cpu")[0] if jax.devices("cpu") else None
gpu = jax.devices("METAL")[0] if jax.devices("METAL") else None
jax.config.update("jax_platform_name", "cpu")

key = jax.random.PRNGKey(0)

In [None]:
# XOR truth-table inputs: every (a, b) pair over {0, 1}, in row-major order.
x_train = jnp.array([[a, b] for a in (0, 1) for b in (0, 1)])

In [None]:
# XOR targets for the rows of x_train, one column per output.
y_train = jnp.array([[a ^ b] for a in (0, 1) for b in (0, 1)])

In [None]:
# Evaluation inputs: the two XOR-positive pairs, each repeated twice.
# NOTE(review): every row here is a positive case, so this set cannot detect a
# model that always predicts 1 -- consider adding (0, 0) and (1, 1) rows.
x_test = jnp.array([[1, 0], [0, 1]] * 2)

In [None]:
# Labels for x_test: XOR of (1, 0) and (0, 1) is always 1.
y_test = jnp.array([[1]] * 4)

In [None]:
def jnp_log(x: jax.Array, eps: float = 1e-10, max_value: float = 1e+10) -> jax.Array:
    """Natural logarithm with the input clipped to ``[eps, max_value]``.

    Clipping keeps the result finite when ``x`` is 0 (e.g. a saturated
    probability fed into a cross-entropy), instead of returning ``-inf``.

    Args:
        x: Input values.
        eps: Lower clip bound; ``log(eps)`` is the smallest possible output.
        max_value: Upper clip bound; ``log(max_value)`` is the largest output.

    Returns:
        ``log(clip(x, eps, max_value))`` element-wise.
    """
    return jnp.log(jnp.clip(x, eps, max_value))

In [None]:
class Sigmoid:
    """Logistic sigmoid activation with a hand-written backward pass."""

    @staticmethod
    def forward(x: jax.Array) -> jax.Array:
        """Compute ``1 / (1 + exp(-x))`` element-wise without overflow.

        Rewritten as ``exp(min(x, 0)) / (1 + exp(-|x|))``: both exponents are
        <= 0, so neither ``exp`` can overflow for large ``|x|``.
        """
        return jnp.exp(jnp.minimum(x, 0)) / (1 + jnp.exp(-jnp.abs(x)))

    @staticmethod
    def backward(dout: jax.Array, x: jax.Array) -> jax.Array:
        """Chain rule through the sigmoid: ``dout * s * (1 - s)``.

        Args:
            dout: Upstream gradient with respect to the sigmoid output.
            x: The pre-activation input that was passed to ``forward``.

        Returns:
            Gradient with respect to ``x``.
        """
        s = Sigmoid.forward(x)  # evaluate once instead of twice
        return dout * (1.0 - s) * s

In [None]:
class Relu:
    """Rectified linear unit activation with a hand-written backward pass."""

    @staticmethod
    def forward(x: jax.Array) -> jax.Array:
        """Compute ``max(0, x)`` element-wise."""
        return jnp.maximum(0, x)

    @staticmethod
    def backward(dout: jax.Array, x: jax.Array) -> jax.Array:
        """Pass ``dout`` through where ``x > 0`` and zero it elsewhere.

        The gradient at exactly ``x == 0`` is taken to be 0.

        Args:
            dout: Upstream gradient with respect to the ReLU output; Python
                scalars are accepted and broadcast against ``x``.
            x: The pre-activation input that was passed to ``forward``.

        Returns:
            Gradient with respect to ``x``, in ``dout``'s dtype.
        """
        # asarray lets Python scalars through (they have no .dtype / .astype);
        # it is a no-op for arrays that are already jax arrays.
        dout = jnp.asarray(dout)
        mask = (jnp.asarray(x) > 0).astype(dout.dtype)
        return dout * mask

In [None]:
def binary_cross_entropy(y: jnp.array, y_hat: jnp.array) -> jnp.array:
    """Mean binary cross-entropy between 0/1 targets ``y`` and probabilities ``y_hat``.

    The logs go through ``jnp_log``, which clips its input, so predictions of
    exactly 0 or 1 give a large finite loss instead of ``inf``/``nan``. The
    mean is taken over every element of ``y``.
    """
    positive_term = y * jnp_log(y_hat)
    negative_term = (1 - y) * jnp_log(1 - y_hat)
    return -jnp.mean(positive_term + negative_term)

In [None]:
# Network sizes: 2 inputs -> 8 hidden ReLU units -> 1 sigmoid output.
input_dim, hidden_dim, output_dim = 2, 8, 1

In [None]:
# Give every parameter its own subkey. JAX keys are not stateful: reusing one
# key for all four draws correlates them, and same-sized tensors (b1 and W2,
# 8 values each) come out with identical values. `key` itself is not rebound,
# so re-running this cell reproduces the same initialisation.
k_W1, k_b1, k_W2, k_b2 = jax.random.split(key, 4)
W1 = jax.random.normal(k_W1, shape=(input_dim, hidden_dim))
b1 = jax.random.normal(k_b1, shape=(hidden_dim,))
W2 = jax.random.normal(k_W2, shape=(hidden_dim, output_dim))
b2 = jax.random.normal(k_b2, shape=(output_dim,))

In [None]:
epochs = 10000
eta = 0.1


@jax.jit
def _train_step(W1, b1, W2, b2, x, y, eta):
    """One full-batch gradient-descent step for the ReLU -> sigmoid MLP.

    Returns ``((W1, b1, W2, b2), loss)``: the updated parameters and the loss
    measured *before* the update (matching what the loop prints).
    """
    # forward
    u1 = jnp.dot(x, W1) + b1  # (batch_size, hidden_dim)
    h1 = Relu.forward(u1)  # (batch_size, hidden_dim)
    u2 = jnp.dot(h1, W2) + b2  # (batch_size, output_dim)
    y_hat = Sigmoid.forward(u2)  # (batch_size, output_dim)
    loss = binary_cross_entropy(y, y_hat)

    # backward: for sigmoid + BCE, dL/du2 simplifies to (y_hat - y) scaled by
    # 1 / (number of averaged elements). The loss is a mean over every
    # (sample, output) element, so divide by y.size -- this equals batch_size
    # only when output_dim == 1.
    delta_2 = (y_hat - y) / y.size
    delta_1 = Relu.backward(dout=jnp.dot(delta_2, W2.T), x=u1)

    # gradients (the averaging factor is already folded into delta_2)
    dW2 = jnp.dot(h1.T, delta_2)
    db2 = jnp.sum(delta_2, axis=0)
    dW1 = jnp.dot(x.T, delta_1)
    db1 = jnp.sum(delta_1, axis=0)

    return (W1 - eta * dW1, b1 - eta * db1, W2 - eta * dW2, b2 - eta * db2), loss


# NOTE: training resumes from the current W1/b1/W2/b2, so re-running this cell
# without re-running the initialisation cell keeps training the same model.
for epoch in range(1, epochs + 1):
    (W1, b1, W2, b2), loss = _train_step(W1, b1, W2, b2, x_train, y_train, eta)

    if epoch % 100 == 0:
        print(f"epoch: {epoch}, loss: {loss}")

In [None]:
# Forward pass over the evaluation inputs using the trained parameters.
u1 = x_test @ W1 + b1  # (batch_size, hidden_dim)
h1 = Relu.forward(u1)  # (batch_size, hidden_dim)
u2 = h1 @ W2 + b2  # (batch_size, output_dim)
y_hat = Sigmoid.forward(u2)  # (batch_size, output_dim): predicted P(y = 1)

print("Predictions:", y_hat, sep="\n")
print("\nTrue labels:", y_test, sep="\n")