# Basic Usage

A typical use case in deep learning (DL) is differentiating the loss function, which in turn calls the model.

In [1]:
import jax
import jax.numpy as jnp

In [2]:
# Create the PRNG key (seed 0) used by the random draws in this notebook.
key = jax.random.PRNGKey(0)

In [12]:
def model(params, X):
    """Apply the affine map ``X @ W + b`` and return the raw logits.

    Args:
        params: mapping holding the weights under ``"W"`` and the bias
            under ``"b"``.
        X: input array that is matrix-multiplied with ``params["W"]``.

    Returns:
        The un-normalised logits (no sigmoid is applied here).
    """
    weights = params["W"]
    bias = params["b"]
    return X @ weights + bias


def loss(params, batch):
    """Mean binary cross-entropy of ``model`` on one batch.

    Args:
        params: parameters forwarded unchanged to ``model``.
        batch: an ``(X, y)`` pair; ``X`` is fed to ``model`` and ``y`` holds
            the binary targets the logits are scored against.

    Returns:
        The cross-entropy averaged over all elements (``jnp.mean``).
    """
    inputs, targets = batch
    z = model(params, inputs)
    # Work in log-space throughout. The negative-class term relies on the
    # identity log(1 - sigmoid(z)) == log_sigmoid(-z), which is the more
    # numerically stable way to evaluate it.
    positive_term = -targets * jax.nn.log_sigmoid(z)
    negative_term = (1. - targets) * jax.nn.log_sigmoid(-z)
    return jnp.mean(positive_term - negative_term)


def step(params, batch, learning_rate=0.01):
    """Run one plain gradient-descent update on ``params``.

    Args:
        params: pytree of parameters accepted by ``loss``.
        batch: the ``(X, y)`` pair forwarded to ``loss``.
        learning_rate: factor applied to every gradient leaf before it is
            subtracted from the matching parameter. Defaults to 0.01, the
            value that used to be hard-coded, so existing two-argument
            calls behave exactly as before.

    Returns:
        ``(new_params, train_loss)``. ``train_loss`` is the loss evaluated
        at the incoming ``params``, i.e. before the update is applied.
    """
    # value_and_grad evaluates the loss and its gradient with respect to
    # the first argument (params) in a single pass.
    train_loss, grads = jax.value_and_grad(loss)(params, batch)
    # tree_map pairs each parameter leaf with the gradient leaf at the same
    # position in the pytree.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads
    )
    return new_params, train_loss

## Wirecheck

In [13]:
# Build a tiny random batch: X is a (3, 4) array of normal draws and y holds
# 3 labels sampled from {0, 1}.
# NOTE(review): `key` is passed unchanged to both draws; JAX's guidance is to
# derive separate subkeys with jax.random.split rather than reuse a key.
X = jax.random.normal(key, (3, 4))
y = jax.random.choice(key, jnp.array([0, 1]), (3,))
batch = (X, y)
print(batch)

(Array([[ 1.1901639 , -1.0996888 ,  0.44367844,  0.5984697 ],
       [-0.39189556,  0.69261974,  0.46018356, -2.068578  ],
       [-0.21438177, -0.9898306 , -0.6789304 ,  0.27362573]],      dtype=float32), Array([0, 1, 1], dtype=int32))


In [14]:
# Initial parameters for `model`.
params = {
    # Glorot-normal initializer called with shape (4, 1); squeeze() drops the
    # trailing unit axis so W has shape (4,).
    "W": jax.nn.initializers.glorot_normal()(key, (4,1)).squeeze(),
    # Bias starts at 1.
    "b": jnp.array([1.])
}
# Bare expression so the notebook displays the dict.
params

{'W': Array([ 1.1436434 , -0.51325184,  0.23285529, -0.36541915], dtype=float32),
 'b': Array([1.], dtype=float32)}

In [15]:
# Forward pass on the batch inputs only (batch[0] is X): one logit per row.
model(params, batch[0])

Array([2.8101609, 1.0593771, 1.0047755], dtype=float32)

In [16]:
# Gradient of the loss with respect to its first argument (params); the
# result has the same dict structure as params.
grads = jax.grad(loss)(params, batch)
grads

{'W': Array([ 0.42697653, -0.3167577 ,  0.16065963,  0.3412228 ], dtype=float32),
 'b': Array([0.13926351], dtype=float32)}

In [17]:
# One manual gradient-descent update with step size 0.01, applied leaf by
# leaf; the result is only displayed, `params` is not rebound.
jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

{'W': Array([ 1.1393737 , -0.5100843 ,  0.23124869, -0.36883137], dtype=float32),
 'b': Array([0.99860734], dtype=float32)}

In [18]:
# step() returns (updated params, loss); the result is only displayed here,
# so `params` itself is left unchanged by this cell.
step(params, batch)

({'W': Array([ 1.1393737 , -0.5100843 ,  0.23124869, -0.36883137], dtype=float32),
  'b': Array([0.99860734], dtype=float32)},
 Array(1.15941, dtype=float32))

## Typical Train Loop

In [19]:
for epoch in range(2):
    # Derive per-epoch subkeys. Passing the unchanged `key` to every call
    # would reproduce exactly the same X and y on each iteration, and would
    # also tie the X draw to the y draw.
    epoch_key = jax.random.fold_in(key, epoch)
    x_key, y_key = jax.random.split(epoch_key)
    X = jax.random.normal(x_key, (3, 4))
    y = jax.random.choice(y_key, jnp.array([0, 1]), (3,))
    # Rebuild the batch from the fresh draws; without this, step() keeps
    # training on the stale `batch` created in an earlier cell.
    batch = (X, y)

    params, train_loss = step(params, batch)
    