Sequential Quadratic Programming

In [8]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.PRNGKey(0)

In [24]:
from jax import jacfwd, jacrev

def sigmoid(x):
    """Logistic sigmoid of ``x``, evaluated through the hyperbolic tangent.

    Computes ``0.5 * (tanh(x / 2) + 1)``.
    """
    shifted_tanh = jnp.tanh(x / 2) + 1
    return 0.5 * shifted_tanh

def predict(W, b, inputs):
    """Output the probability of a label being true.

    Applies ``sigmoid`` to the affine score ``jnp.dot(inputs, W) + b``.
    """
    score = jnp.dot(inputs, W) + b
    return sigmoid(score)

# Toy dataset: four examples with three features each, and one boolean
# label per example.
inputs = jnp.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])
targets = jnp.array([True, True, False, True])

def loss(W, b):
    """Training loss: negative log-likelihood of the training examples.

    Reads the module-level ``inputs`` and ``targets``.
    """
    preds = predict(W, b, inputs)
    # Probability assigned to each example's actual label: ``preds`` where
    # the target is true, ``1 - preds`` where it is false.
    prob_where_true = preds * targets
    prob_where_false = (1 - preds) * (1 - targets)
    label_probs = prob_where_true + prob_where_false
    log_likelihood = jnp.sum(jnp.log(label_probs))
    return -log_likelihood

# Initialize random model coefficients
# The PRNG key is split three ways: `key` is replaced by a fresh key, and the
# other two keys are consumed by the weight and bias draws below.
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))  # weight vector of shape (3,)
b = random.normal(b_key, ())    # scalar bias (shape ())


# Restrict `predict` to a function of the weights alone; the bias `b` and the
# dataset `inputs` are taken from the enclosing scope and held fixed.
def f(W):
    return predict(W, b, inputs)

# Jacobian of the predictions with respect to W, computed with jacfwd.
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

# The same Jacobian, computed with jacrev.
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)

def hessian(f):
    """Return a function that evaluates the Hessian of ``f``.

    Built by applying ``jacfwd`` to ``jacrev(f)``.
    """
    first_derivative = jacrev(f)
    return jacfwd(first_derivative)

# Second derivatives of the predictions with respect to W, using the
# `hessian` helper (jacfwd applied to jacrev).
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)

# Show the weights at which the derivatives were evaluated.
print("W: ",W)

jacfwd result, with shape (4, 3)
[[ 0.0395953   0.08528218  0.0586315 ]
 [ 0.1939475  -0.23802648  0.03305923]
 [ 0.12545583  0.01447567 -0.31363955]
 [ 0.17280862 -0.58147764  0.32459995]]
jacrev result, with shape (4, 3)
[[ 0.0395953   0.08528218  0.0586315 ]
 [ 0.1939475  -0.23802648  0.03305923]
 [ 0.12545581  0.01447567 -0.31363955]
 [ 0.17280862 -0.58147764  0.32459995]]
hessian, with shape (4, 3, 3)
[[[ 1.71700176e-02  3.69815752e-02  2.54248381e-02]
  [ 3.69815789e-02  7.96526223e-02  5.47611937e-02]
  [ 2.54248343e-02  5.47611788e-02  3.76483165e-02]]

 [[ 5.87327257e-02 -7.20810965e-02  1.00112613e-02]
  [-7.20810741e-02  8.84631649e-02 -1.22865485e-02]
  [ 1.00112604e-02 -1.22865504e-02  1.70646503e-03]]

 [[ 1.21969143e-02  1.40733621e-03 -3.04922890e-02]
  [ 1.40733621e-03  1.62384953e-04 -3.51834111e-03]
  [-3.04922853e-02 -3.51834064e-03  7.62307197e-02]]

 [[ 3.28274891e-02 -1.10460043e-01  6.16624281e-02]
  [-1.10460065e-01  3.71683121e-01 -2.07485735e-01]
  [ 6.166244

In [29]:
from jax import jacfwd, jacrev

def f(x):
    """Return the sum of the cubes of the elements of ``x``."""
    cubes = jnp.power(x, 3)
    return cubes.sum()
# Evaluate f at [1, 2, 3]: 1 + 8 + 27.
print(f(jnp.array([1.,2.,3.])))

def hessian(f):
    """Return a function that evaluates the Hessian of ``f``.

    Composes ``jacfwd`` over ``jacrev(f)``.
    """
    gradient_fn = jacrev(f)
    return jacfwd(gradient_fn)

def jacobian(f):
    """Return a function that evaluates the Jacobian of ``f`` via ``jacfwd``."""
    jacobian_fn = jacfwd(f)
    return jacobian_fn

# Jacobian of f at [1, 2, 3]. f returns a scalar (a sum), so the result has
# the same shape as the input vector.
J = jacobian(f)(jnp.array([1.,2.,3.]))
print("jacobian, with shape", J.shape)
print(J)

# Hessian of f at the same point.
H = hessian(f)(jnp.array([1.,2.,3.]))
print("hessian, with shape", H.shape)
print(H)

36.0
jacobian, with shape (3,)
[ 3. 12. 27.]
hessian, with shape (3, 3)
[[ 6.  0.  0.]
 [ 0. 12.  0.]
 [ 0.  0. 18.]]
