In [1]:
import warnings

import jax
import jax.numpy as jnp
import jax.scipy.stats
import numpy as np
from jax.scipy import optimize
from scipy.optimize import minimize

In [2]:
# Seed numpy's global random state so the simulated data (and every
# estimate below) are reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)
# number of observations
N = 5000
# number of parameters (intercept + K-1 regressors)
K = 2
# true parameter values, one per column of the design matrix; built from K so
# the two cannot disagree (datagen infers K from the length of beta)
beta = np.full(K, 2.0)
# true error std deviation
sigma = 0.5

def datagen(N, beta, sigma, rng=None):
    """
    Generates data for OLS regression.

    Inputs:
    N: Number of observations
    beta: true parameter values, either a length-K vector or a K x 1 column
        vector; beta[0] is the coefficient on the constant column
    sigma: std dev of the Gaussian error
    rng: optional source of standard normal draws, e.g.
        np.random.default_rng(seed) or np.random.RandomState(seed).
        Defaults to numpy's global random state (the one np.random.seed
        controls), which reproduces the original np.random.randn stream.

    Returns:
    y: length-N vector of dependent variables
    x: N x K data matrix whose first column is all ones (for estimating a
       constant) and whose remaining K-1 columns are 10 + 2 * standard normal

    Raises:
    ValueError: if beta is empty or not a vector / K x 1 column.
    """
    beta = np.asarray(beta)
    # Accept the K x 1 form the docstring has always advertised; without this,
    # x.dot(beta) would be N x 1 and adding the length-N noise would silently
    # broadcast y to N x N.
    if beta.ndim == 2 and beta.shape[1] == 1:
        beta = beta[:, 0]
    if beta.ndim != 1 or beta.shape[0] == 0:
        raise ValueError("beta must be a non-empty length-K vector or K x 1 column")

    # np.random.standard_normal draws from the same global stream as np.random.randn.
    draws = np.random if rng is None else rng
    K = beta.shape[0]
    regressors = 10 + 2 * draws.standard_normal((N, K - 1))
    # x is the N x K data matrix with column of ones
    #   in the first position for estimating a constant
    x = np.column_stack([np.ones(N), regressors])
    # y is the N x 1 vector of dependent variables
    y = x.dot(beta) + sigma * draws.standard_normal(N)
    return y, x

# Simulate the sample (dependent variable y, design matrix x) used by every cell below.
y, x = datagen(N=N, beta=beta, sigma=sigma)

In [3]:
def neg_loglike(theta, y_obs=None, x_obs=None):
    """
    Average negative Gaussian log-likelihood of the linear model y = x @ beta + e.

    Inputs:
    theta: length K+1 parameter vector; theta[:-1] are the regression
        coefficients and theta[-1] is log(sigma), so sigma = exp(theta[-1])
        is positive for any real theta (lets an unconstrained optimizer be used)
    y_obs: optional vector of dependent variables; defaults to the
        notebook-level `y`
    x_obs: optional design matrix matching y_obs; defaults to the
        notebook-level `x`

    Returns:
    Scalar negative log-likelihood divided by the number of observations.
    """
    # Only touch the notebook globals when the caller did not pass data.
    y_data = y if y_obs is None else y_obs
    x_data = x if x_obs is None else x_obs
    coef = theta[:-1]
    # transform theta[-1] so that sigma > 0
    scale = jnp.exp(theta[-1])
    mu = jnp.dot(x_data, coef)
    # Average over the observations actually supplied instead of dividing by the
    # global N, so the scaling stays right if N and the data ever disagree.
    return -jnp.mean(jax.scipy.stats.norm.logpdf(y_data, loc=mu, scale=scale))

In [4]:
# neg_loglike returns a scalar, so its forward-mode Jacobian with respect to
# theta is the gradient vector; jax.hessian is jacfwd composed with jacrev.
jacobian = jax.jacfwd(neg_loglike, argnums=0)
hessian = jax.jacfwd(jax.jacrev(neg_loglike))

In [5]:
# Evaluate the derivatives at the true parameters, with sigma on the log scale
# to match neg_loglike's parameterization.
theta = jnp.append(beta, jnp.log(sigma))
grad_at_truth = jacobian(theta)
print(f'Jacobian : {grad_at_truth} \n')
hess_at_truth = hessian(theta)
print(f'Hessian: {hess_at_truth}')

Jacobian : [-0.02804999 -0.26064894  0.04582791] 

Hessian: [[4.0000029e+00 4.0129120e+01 5.6099977e-02]
 [4.0129120e+01 4.1863495e+02 5.2129787e-01]
 [5.6099974e-02 5.2129787e-01 1.9083445e+00]]


In [6]:
# Start from beta = 0 and log(sigma) = 0, i.e. sigma = 1.
theta_start = jnp.append(jnp.zeros(len(beta)), 0.0)
# Tolerance added to aid in convergence.
# NOTE(review): neg_loglike is already averaged over the observations, so
# scaling gtol by N is unusual — confirm the intended tolerance.
bfgs_options = {'disp': True, 'gtol': 1e-7 * N}
res1 = minimize(
    neg_loglike,
    theta_start,
    jac=jacobian,
    method='BFGS',
    options=bfgs_options,
)
print("Convergence Achieved: ", res1.success)
print("Number of Function Evaluations: ", res1.nfev)

Optimization terminated successfully.
         Current function value: 0.702219
         Iterations: 22
         Function evaluations: 31
         Gradient evaluations: 31
Convergence Achieved:  True
Number of Function Evaluations:  31


In [7]:
# Full scipy OptimizeResult: estimates (x), objective (fun), gradient (jac), hess_inv, counters.
print(f"{res1}")

      fun: 0.7022186517715454
 hess_inv: array([[ 6.29904616e+00, -6.05553100e-01,  2.23532868e-02],
       [-6.05553100e-01,  6.05430164e-02, -2.04262499e-03],
       [ 2.23532868e-02, -2.04262499e-03,  5.17293897e-01]])
      jac: array([3.1219483e-06, 4.0380863e-05, 5.0994877e-06], dtype=float32)
  message: 'Optimization terminated successfully.'
     nfev: 31
      nit: 22
     njev: 31
   status: 0
  success: True
        x: array([ 2.01998299,  1.99870721, -0.71671744])


In [8]:
# Every element of res1.x except the last is a regression coefficient; printing
# them all keeps this cell correct for any K instead of only the first two.
coef_hat = ', '.join(f'{b}' for b in res1.x[:-1])
print(f'Coefficient estimates: {coef_hat}')

# The last element is log(sigma); exponentiate to report sigma on its original scale.
new_sigma = np.exp(res1.x[-1])

print(f'Original sigma: {new_sigma}')

Coefficient estimates: 2.019982993143108, 1.9987072101363004
Original sigma: 0.488352673054439


In [9]:
# Same problem solved with JAX's own BFGS. The gradient tolerance is given both
# as `tol` and explicitly as options['gtol'] so it matches the scipy run above.
# NOTE(review): some JAX versions do not forward `tol` to the BFGS solver, in
# which case only options['gtol'] takes effect — confirm for the installed version.
jax_gtol = 1e-7 * N
res2 = optimize.minimize(
    neg_loglike,
    theta_start,
    method='BFGS',
    tol=jax_gtol,
    options={'gtol': jax_gtol},
)

# jax.scipy.optimize.minimize reports failure only through the result, so make
# it visible instead of silently treating res2.x as an estimate downstream.
if not bool(res2.success):
    # NOTE(review): a plausible cause is float32 overflow of exp(theta[-1]) during
    # the line search; consider jax.config.update("jax_enable_x64", True) before
    # any arrays are created — TODO confirm.
    warnings.warn(
        f"jax.scipy.optimize BFGS did not converge (status={int(res2.status)}); "
        "res2.x should not be treated as the maximum-likelihood estimate."
    )

In [10]:
# Full jax OptimizeResults: check `success` and `status` before trusting `x`.
print(f"{res2}")

OptimizeResults(x=DeviceArray([ 0.59281933,  6.1378427 , -1.5077661 ], dtype=float32), success=DeviceArray(False, dtype=bool), status=DeviceArray(3, dtype=int32, weak_type=True), fun=DeviceArray(27.986916, dtype=float32), jac=DeviceArray([ -2.8423486, -29.52571  , -51.29997  ], dtype=float32), hess_inv=DeviceArray([[ 0.9984879 , -0.01572122, -0.03550449],
             [-0.01572122,  0.8365498 , -0.36913145],
             [-0.03550449, -0.36913145,  0.16713503]], dtype=float32), nfev=DeviceArray(6, dtype=int32, weak_type=True), njev=DeviceArray(6, dtype=int32, weak_type=True), nit=DeviceArray(2, dtype=int32, weak_type=True))


In [11]:
# Every element of res2.x except the last is a regression coefficient; printing
# them all keeps this cell correct for any K. These are only meaningful
# estimates when res2.success is True.
coef_hat = ', '.join(f'{b}' for b in res2.x[:-1])
print(f'Betas: {coef_hat}')
# The last element is log(sigma); exponentiate to report sigma on its original scale.
new_sigma = np.exp(res2.x[-1])
print(f'Original sigma: {new_sigma}')

Betas: 0.592819333076477, 6.137842655181885
Original sigma: 0.22140401601791382
