In [2]:
import os
# Expose only GPU index 3 to CUDA-based libraries. This is set before
# `import jax` so the restriction is already in place when JAX initialises
# its backend.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [3]:
# jax.grad takes a scalar-valued function and returns a new function that evaluates its gradient.

In [4]:
import jax
import jax.numpy as jnp

In [5]:
# jax.grad returns a new function that evaluates d/dx tanh(x).
grad_tanh = jax.grad(jnp.tanh)

In [6]:
# Evaluate the derivative at x = 2.0; analytically this is 1 - tanh(2)**2.
grad_tanh(2.0)

Array(0.07065094, dtype=float32, weak_type=True)

In [7]:
def f(x):
    """Evaluate the cubic polynomial x**3 + x**2 + x + 2.

    Written as a named function rather than a lambda bound to a name, so
    that tracebacks and reprs show ``f`` instead of ``<lambda>``.

    Args:
        x: A scalar (or any value supporting ``**``, ``+``).

    Returns:
        The polynomial evaluated at ``x``.
    """
    return x ** 3 + x ** 2 + x + 2

In [8]:
# df is the first derivative of f with respect to its single argument.
df = jax.grad(f)

In [9]:
# Evaluate the first derivative of f at x = 2.0.
df(2.0)

Array(17., dtype=float32, weak_type=True)

In [13]:
# df is already the first derivative of f, so three more applications of
# jax.grad give the fourth derivative of f, evaluated here at x = 2.0.
jax.grad(jax.grad(jax.grad(df)))(2.0)

Array(0., dtype=float32, weak_type=True)

## Logistic Regression with JAX

In [14]:
# Create the root PRNG key from a fixed seed so the random draws below are
# reproducible; JAX random functions take an explicit key argument.
key = jax.random.key(0)

In [15]:
def sigmoid(x):
    """Logistic sigmoid computed through the identity 0.5 * (tanh(x / 2) + 1).

    Args:
        x: Scalar or array of pre-activation values.

    Returns:
        Values in [0, 1] with the same shape as ``x``.
    """
    shifted_tanh = jnp.tanh(x / 2) + 1
    return 0.5 * shifted_tanh

In [31]:
def predict(W, b, inputs):
    """Return the sigmoid of the affine map ``inputs . W + b``.

    Args:
        W: Weights combined with ``inputs`` via ``jnp.dot``.
        b: Bias added to the dot product.
        inputs: Input data passed as the first operand of ``jnp.dot``.

    Returns:
        ``sigmoid`` applied to the resulting logits.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

In [17]:
# Toy data set: four examples (rows), three features each (columns).
inputs = jnp.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])
# One boolean label per example (row of `inputs`).
targets = jnp.array([True, True, False, True])

In [18]:
def loss(W, b):
    """Summed negative log-probability of the labels under the model.

    Reads the module-level ``inputs`` and ``targets``; only ``W`` and ``b``
    are arguments, so ``jax.grad`` can differentiate with respect to them.

    Args:
        W: Weights forwarded to ``predict``.
        b: Bias forwarded to ``predict``.

    Returns:
        Scalar: minus the sum of the log of each example's label probability.
    """
    prob_true = predict(W, b, inputs)
    prob_false = 1 - prob_true
    # Select prob_true where the target is set and prob_false otherwise.
    likelihoods = prob_true * targets + prob_false * (1 - targets)
    total_log_likelihood = jnp.sum(jnp.log(likelihoods))
    return -total_log_likelihood


In [22]:
# Derive independent subkeys for W and b instead of reusing `key` directly.
# NOTE: `key` is rebound here, so re-running this cell without re-creating
# the root key yields different W and b.
key, W_key, b_key = jax.random.split(key, 3)
# W: vector of three normally distributed weights.
W = jax.random.normal(W_key, (3, ))
# b: scalar bias (shape ()).
b = jax.random.normal(b_key, ())

In [25]:
# Inspect the sampled parameters.
W, b

(Array([-0.09005956, -0.71841365,  2.4771307 ], dtype=float32),
 Array(0.09016261, dtype=float32))

In [32]:
# argnums=0 differentiates loss with respect to its first positional argument, W.
W_grad = jax.grad(loss, argnums=0)(W, b)

In [33]:
# argnums=1 differentiates loss with respect to its second positional argument, b.
b_grad = jax.grad(loss, argnums=1)(W, b)

In [38]:
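# Alternatively, pass a tuple of argnums to get both gradients from a single
# call; the result is a tuple with one gradient per requested argument:
# jax.grad(loss, (0, 1))(W, b)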

In [40]:
# Display the gradient of the loss with respect to W and to b.
W_grad, b_grad

(Array([-0.31957766,  0.00316601, -0.27873003], dtype=float32),
 Array(-0.44693938, dtype=float32))

In [41]:
def loss2(params_dict):
    """Same loss computation as above, with parameters packed in one dict.

    Reads the module-level ``inputs`` and ``targets``.

    Args:
        params_dict: Mapping with keys ``'W'`` (weights) and ``'b'`` (bias).

    Returns:
        Scalar: minus the sum of the log of each example's label probability.
    """
    weights, bias = params_dict['W'], params_dict['b']
    preds = predict(weights, bias, inputs)
    # Select preds where the target is set and (1 - preds) otherwise.
    per_example = preds * targets + (1 - preds) * (1 - targets)
    log_probs = jnp.log(per_example)
    return -jnp.sum(log_probs)

In [43]:
# jax.grad differentiates with respect to the whole dict argument; the
# gradient comes back as a dict with the same keys ('W' and 'b').
print(jax.grad(loss2)({'W': W, 'b': b}))

{'W': Array([-0.31957766,  0.00316601, -0.27873003], dtype=float32), 'b': Array(-0.44693938, dtype=float32)}


In [45]:
# jax.value_and_grad returns the function value together with its gradient
# in a single call. With argnums=(0, 1), Wb_grad is a tuple holding the
# gradients with respect to W and b.
loss_value, Wb_grad = jax.value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
# Sanity check: a direct call to loss should print the same value.
print('loss value', loss(W, b))

loss value 0.59347546
loss value 0.59347546
