In [1]:
import jax
import jax.numpy as jnp
import optax

In [3]:
# Global seed: every random draw in this notebook is derived from this one key.
SEED = 42
key = jax.random.PRNGKey(seed=SEED)

In [4]:
# criando nossos dados
X = jnp.array([
    [0., 0.],
    [0., 1.],
    [1., 0.],
    [1., 1.]
])

y = jnp.array([
    [0.],
    [1.],
    [1.],
    [0.]
])

In [5]:
# No module-level key split is needed: `init_params` splits the key it receives
# into its own per-layer subkeys.

In [6]:
def init_params(key, *, n_in=2, n_hidden=4, n_out=1):
    """Initialise the weights of a one-hidden-layer MLP.

    Weight matrices are drawn from a standard normal distribution and biases
    start at zero. The defaults (2 -> 4 -> 1) reproduce the XOR network.

    Args:
        key: JAX PRNG key; it is split into one subkey per weight matrix.
        n_in: number of input features.
        n_hidden: width of the hidden layer.
        n_out: number of outputs.

    Returns:
        dict with "W1" (n_in, n_hidden), "b1" (n_hidden,),
        "W2" (n_hidden, n_out) and "b2" (n_out,).
    """
    # Independent subkeys so the two weight matrices are not correlated.
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (n_in, n_hidden)),
        "b1": jnp.zeros((n_hidden,)),
        "W2": jax.random.normal(k2, (n_hidden, n_out)),
        "b2": jnp.zeros((n_out,)),
    }

In [7]:
def forward(params, x):
    """Run the MLP: tanh hidden layer followed by a sigmoid output.

    Args:
        params: dict with keys "W1", "b1", "W2", "b2".
        x: inputs whose last axis matches params["W1"].shape[0].

    Returns:
        Sigmoid activations in (0, 1), one column per output unit.
    """
    W1, b1 = params["W1"], params["b1"]
    W2, b2 = params["W2"], params["b2"]

    hidden = jnp.tanh(jnp.matmul(x, W1) + b1)
    logits = jnp.matmul(hidden, W2) + b2
    return jax.nn.sigmoid(logits)

In [None]:
def loss_fn(params, x, y):
    """Mean binary cross-entropy between `forward(params, x)` and targets `y`.

    A small epsilon is added inside both logarithms so that predictions of
    exactly 0 or 1 do not produce log(0).
    """
    probs = forward(params, x)
    eps = 1e-7
    log_p = jnp.log(probs + eps)
    log_not_p = jnp.log(1 - probs + eps)
    return -jnp.mean(y * log_p + (1 - y) * log_not_p)

In [None]:
@jax.jit
def update(params, opt_state, x, y):
    """Take one optimizer step on the batch (x, y).

    Returns:
        (new_params, new_opt_state, loss), where `loss` is evaluated at the
        incoming `params`, i.e. before the step is applied.
    """
    # NOTE: `optimizer` is a module-level global that is read when jit traces this
    # function (on the first call). If `optimizer` is rebound later, re-run this
    # cell so the compiled step picks up the new one.
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params, x, y)

    step_updates, new_opt_state = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, step_updates)

    return new_params, new_opt_state, loss

In [None]:
# Plain stochastic gradient descent with a fixed step size.
LR = 1e-3
optimizer = optax.sgd(LR)

In [None]:
# Draw the initial network weights from the global PRNG key and display them.
params = init_params(key=key)
params

In [None]:
# Inspect the input-to-hidden weight matrix.
params["W1"]

In [None]:
# Build the optimizer's state for the current parameters, then display it.
opt_state = optimizer.init(params)
opt_state

In [None]:
# Full-batch training loop: each `update` call takes one optimizer step on (X, y).
# Re-running this cell continues training from the current `params` / `opt_state`.
NUM_STEPS = 500
LOG_EVERY = 100  # keep <= NUM_STEPS, otherwise only step 0 is ever logged

# NOTE(review): with plain SGD at LR = 1e-3, 500 steps may be too few for the
# network to fit XOR. Watch the printed loss and raise NUM_STEPS or LR if it stalls.
for step in range(NUM_STEPS):
    params, opt_state, loss = update(params, opt_state, X, y)
    # Log periodically and always on the last step, so the final loss is visible.
    if step % LOG_EVERY == 0 or step == NUM_STEPS - 1:
        print(f"step={step}, loss={loss}")

In [None]:
# Predictions of the trained network on the four XOR inputs (rounded for display).
preds = forward(params, X)
print(jnp.round(preds, decimals=3))