Sequential Quadratic Programming

Dependencies: `jax`

In [None]:
import jax.numpy as jnp
from jax import jacfwd, jacrev

Define your own objective functions and constraints here:

In [6]:
# Rosenbrock-style objective (both terms carry unit weight).
def f(x):
    """Evaluate (1 - x[1])^2 + (x[2] - x[1]^2)^2.

    Only the entries at indices 1 and 2 of ``x`` are read; ``x[0]`` does not
    influence the result.

    NOTE(review): the textbook Rosenbrock function weights the second term by
    100 and is usually written over the first two coordinates -- confirm that
    the indexing and weighting used here are intentional.
    """
    u, v = x[1], x[2]
    linear_term = jnp.power(1 - u, 2)
    valley_term = jnp.power(v - jnp.power(u, 2), 2)
    return linear_term + valley_term


# Quick evaluation at the point (1, 2, 3).
print(f(jnp.array([1., 2., 3.])))

2.0


In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Root PRNG key built from the fixed seed 0.
key = random.PRNGKey(0)



In [2]:
from jax import jacfwd, jacrev

def sigmoid(x):
    """Logistic function, evaluated as 0.5 * (tanh(x / 2) + 1)."""
    half_x = x / 2
    return 0.5 * (jnp.tanh(half_x) + 1)

# Model output: probability of a label being true.
def predict(W, b, inputs):
    """Return ``sigmoid(dot(inputs, W) + b)``.

    ``W`` and ``b`` are the model coefficients; ``inputs`` is the data they
    are applied to.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)

# Build a toy dataset: four examples with three features each (`inputs`),
# and one boolean label per example (`targets`).
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss: negative log-likelihood of the training examples.
def loss(W, b):
    """Return the summed negative log-probability assigned to the labels.

    Reads the module-level ``inputs`` and ``targets`` arrays.
    """
    p_true = predict(W, b, inputs)
    p_false = 1 - p_true
    # Keep p_true where the target is True and p_false where it is False.
    likelihoods = p_true * targets + p_false * (1 - targets)
    return -jnp.sum(jnp.log(likelihoods))

# Draw the initial model coefficients from a normal distribution, giving each
# draw its own subkey split off the running PRNG key (which is itself replaced).
key, W_key, b_key = random.split(key, num=3)
W = random.normal(W_key, shape=(3,))  # one weight per input feature
b = random.normal(b_key, shape=())    # scalar bias


# Isolate the function from the weight matrix to the predictions
def f(W):
    """Predictions as a function of the weights only (``b`` and ``inputs`` fixed)."""
    return predict(W, b, inputs)


# Jacobian of the predictions with respect to W, first in forward mode and
# then in reverse mode; J is left holding the reverse-mode result.
for jac_transform, banner in ((jacfwd, "jacfwd result, with shape"),
                              (jacrev, "jacrev result, with shape")):
    J = jac_transform(f)(W)
    print(banner, J.shape)
    print(J)

def hessian(f):
    """Return a function that evaluates the Hessian of ``f``.

    Built as a forward-mode Jacobian (``jacfwd``) of the reverse-mode
    Jacobian (``jacrev``) of ``f``.
    """
    first_derivative = jacrev(f)
    return jacfwd(first_derivative)

# Hessian of the predictions f with respect to the weights, evaluated at W.
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)

# Show the weights at which H was evaluated.
print("W: ",W)

jacfwd result, with shape (4, 3)
[[ 0.05981758  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.04922541  0.03384247]
  [ 0.04922541  0.10602397  0.07289147]
  [ 0.03384247  0.07289147  0.05011288]]

 [[-0.03195215  0.03921401 -0.00544639]
  [ 0.03921401 -0.04812629  0.00668421]
  [-0.00544639  0.00668421 -0.00092836]]

 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]

 [[-0.00103524  0.00348343 -0.00194457]
  [ 0.00348343 -0.01172127  0.0065432 ]
  [-0.00194457  0.0065432  -0.00365263]]]
W:  [-0.36838785 -2.275689    0.01144757]


In [3]:
from jax import jacfwd, jacrev

def f(x):
    """Return the sum of the cubes of the entries of ``x``."""
    cubes = jnp.power(x, 3)
    return jnp.sum(cubes)


# Example evaluation at (1, 2, 3): 1 + 8 + 27.
print(f(jnp.array([1., 2., 3.])))

def hessian(f):
    """Build the Hessian function of ``f`` as ``jacfwd`` applied to ``jacrev(f)``."""
    reverse_mode_jacobian = jacrev(f)
    hessian_fn = jacfwd(reverse_mode_jacobian)
    return hessian_fn

def jacobian(f):
    """Return a function that evaluates the forward-mode Jacobian of ``f``."""
    jacobian_fn = jacfwd(f)
    return jacobian_fn

# Point at which both derivatives are evaluated.
x_eval = jnp.array([1., 2., 3.])

# First derivatives of f at x_eval.
J = jacobian(f)(x_eval)
print("jacobian, with shape", J.shape)
print(J)

# Second derivatives of f at x_eval.
H = hessian(f)(x_eval)
print("hessian, with shape", H.shape)
print(H)

36.0
jacobian, with shape (3,)
[ 3. 12. 27.]
hessian, with shape (3, 3)
[[ 6.  0.  0.]
 [ 0. 12.  0.]
 [ 0.  0. 18.]]
