In [1]:
import jax
import jax.numpy as jnp
import optax

In [2]:
SEED = 42
key = jax.random.PRNGKey(SEED)

In [3]:
# criando nossos dados
X = jnp.array([
    [0., 0.],
    [0., 1.],
    [1., 0.],
    [1., 1.]
])

y = jnp.array([
    [0.],
    [1.],
    [1.],
    [0.]
])

In [4]:
k1, k2 = jax.random.split(key)

In [5]:
def init_params(key):
    # inicializando os parametros
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (2, 4)),
        "b1": jnp.zeros((4,)),
        "W2": jax.random.normal(k2, (4, 1)),
        "b2": jnp.zeros((1,))
    }

In [6]:
def forward(params, x):
    # construção do modelo
    z = jax.nn.tanh(x @ params["W1"] + params["b1"])
    out = jax.nn.sigmoid(z @ params["W2"] + params["b2"])
    return out

In [7]:
def loss_fn(params, x, y):
    preds = forward(params, x)
    eps = 1e-7
    return -jnp.mean(
        y * jnp.log(preds + eps) +
        (1 - y) * jnp.log(1 - preds + eps)
    )

In [8]:
@jax.jit
def update(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    
    return params, opt_state, loss

In [47]:
LR = 1e-1
optimizer = optax.sgd(learning_rate=LR)

In [48]:
params = init_params(key)
params

{'W1': Array([[ 0.07592554, -0.48634264,  1.2903206 ,  0.5196119 ],
        [ 0.30040437,  0.31034866,  0.5761609 , -0.8074621 ]],      dtype=float32),
 'b1': Array([0., 0., 0., 0.], dtype=float32),
 'W2': Array([[ 0.60576403],
        [ 0.7990441 ],
        [-0.908927  ],
        [-0.63525754]], dtype=float32),
 'b2': Array([0.], dtype=float32)}

In [49]:
params['W1']

Array([[ 0.07592554, -0.48634264,  1.2903206 ,  0.5196119 ],
       [ 0.30040437,  0.31034866,  0.5761609 , -0.8074621 ]],      dtype=float32)

In [50]:
opt_state = optimizer.init(params)

In [51]:
EPOCHS = 500_000
for epoch in range(EPOCHS):
    params, opt_state, loss = update(params, opt_state, X, y)
    if epoch % 50_000 == 0:
        print(f"step ={epoch:6d}, loss ={loss:10.6f}")
        epoch += 1

step =     0, loss =  0.818074
step = 50000, loss =  0.465172
step =100000, loss =  0.092970
step =150000, loss =  0.034529
step =200000, loss =  0.020073
step =250000, loss =  0.013953
step =300000, loss =  0.010632
step =350000, loss =  0.008564
step =400000, loss =  0.007153
step =450000, loss =  0.006142


In [53]:
preds = forward(params, X)
print(jnp.round(preds, 3))

[[0.001     ]
 [0.99500006]
 [0.99300003]
 [0.008     ]]


In [54]:
y

Array([[0.],
       [1.],
       [1.],
       [0.]], dtype=float32)