@@ -31,9 +31,7 @@ Here we will exploit the automatic differentiation capabilities of JAX rather th
3131We'll require the following imports:
3232
3333``` {code-cell} ipython3
34- %matplotlib inline
3534import matplotlib.pyplot as plt
36- plt.rcParams["figure.figsize"] = (11, 5) # set default figure size
3735from collections import namedtuple
3836import jax.numpy as jnp
3937import jax
@@ -59,6 +57,10 @@ numerical methods to solve for parameter estimates.
5957
6058One such numerical method is the Newton-Raphson algorithm.
6159
60+ Let's start with a simple example to illustrate the algorithm.
61+
62+ ### A toy model
63+
6264Our goal is to find the maximum likelihood estimate $\hat{\boldsymbol{\beta}}$.
6365
6466At $\hat{\boldsymbol{\beta}}$, the first derivative of the log-likelihood
@@ -130,6 +132,8 @@ Then we use the updating rule involving gradient information to iterate the algo
130132
131133Please refer to [ this section] ( https://python.quantecon.org/mle.html#mle-with-numerical-methods ) for the detailed algorithm.
132134
135+ ### A Poisson model
136+
133137Let's have a go at implementing the Newton-Raphson algorithm to calculate the maximum likelihood estimations of a Poisson regression.
134138
135139The Poisson regression has a joint pmf:
145149 = \exp(\mathbf{x}_i' \boldsymbol{\beta})
146150 = \exp(\beta_0 + \beta_1 x_{i1} + \ldots + \beta_k x_{ik})
147151$$
148-
152+
149153We create a ` namedtuple ` to store the observed values
150154
151155``` {code-cell} ipython3
152- PoissonRegressionModel = namedtuple('PoissonRegressionModel ', ['X', 'y'])
156+ RegressionModel = namedtuple('RegressionModel ', ['X', 'y'])
153157
154- def create_poisson_model (X, y):
158+ def create_regression_model (X, y):
155159 n, k = X.shape
156160 # Reshape y as a n_by_1 column vector
157161 y = y.reshape(n, 1)
158162 X, y = jax.device_put((X, y))
159- return PoissonRegressionModel (X=X, y=y)
163+ return RegressionModel (X=X, y=y)
160164```
161165
162166The log likelihood function of the Poisson regression is
@@ -203,7 +207,6 @@ def poisson_logL(β, model):
203207 return jnp.sum(model.y * jnp.log(μ) - μ - jnp.log(jax_factorial(y)))
204208```
205209
206-
207210To find the gradient of the ` poisson_logL ` , we again use [ jax.grad] ( https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html ) .
208211
209212According to [ the documentation] ( https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev ) ,
@@ -220,7 +223,7 @@ G_poisson_logL = jax.grad(poisson_logL)
220223H_poisson_logL = jax.jacfwd(G_poisson_logL)
221224```
222225
223- Our function ` newton_raphson ` will take a ` PoissonRegressionModel ` object
226+ Our function ` newton_raphson ` will take a ` RegressionModel ` object
224227that has an initial guess of the parameter vector $\boldsymbol{\beta}_ 0$.
225228
226229The algorithm will update the parameter vector according to the updating
@@ -276,7 +279,7 @@ y = jnp.array([1, 0, 1, 1, 0])
276279init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)
277280
278281# Create an object with Poisson model values
279- poi = create_poisson_model (X, y)
282+ poi = create_regression_model (X, y)
280283
281284# Use newton_raphson to find the MLE
282285β_hat = newton_raphson(poi, init_β, display=True)
@@ -311,3 +314,107 @@ y_numpy = y.__array__()
311314stats_poisson = Poisson(y_numpy, X_numpy).fit()
312315print(stats_poisson.summary())
313316```
317+
318+ The benefit of writing our own procedure, relative to statsmodels is that
319+
320+ * we can exploit the power of the GPU and
321+ * we learn the underlying methodology, which can be extended to complex situations where no existing routines are available.
322+
323+ ``` {exercise-start}
324+ :label: newton_mle1
325+ ```
326+
327+ We define a quadratic model for a single explanatory variable by
328+
329+ $$
330+ \log(\lambda_t) = \beta_0 + \beta_1 x_t + \beta_2 x_{t}^2
331+ $$
332+
333+ We calculate the mean on the original scale instead of the log scale by exponentiating both sides of the above equation, which gives
334+
335+ ``` {math}
336+ :label: lambda_mle
337+ \lambda_t = \exp(\beta_0 + \beta_1 x_t + \beta_2 x_{t}^2)
338+ ```
339+
340+ Simulate the values of $x_t$ by sampling from a normal distribution and $\lambda_t$ by using {eq}` lambda_mle ` and the following constants:
341+
342+ $$
343+ \beta_0 = -2.5,
344+ \quad
345+ \beta_1 = 0.25,
346+ \quad
347+ \beta_2 = 0.5
348+ $$
349+
350+ Try to obtain the approximate values of $\beta_0,\beta_1,\beta_2$, by simulating a Poission Regression Model such that
351+
352+ $$
353+ y_t \sim {\rm Poisson}(\lambda_t)
354+ \quad \text{for all } t.
355+ $$
356+
357+ Using our ` newton_raphson ` function on the data set $X = [ 1, x_t, x_t^{2}] $ and
358+ $y$, obtain the maximum likelihood estimates of $\beta_0,\beta_1,\beta_2$.
359+
360+ With a sufficient large sample size, you should approximately
361+ recover the true values of of these parameters.
362+
363+
364+ ``` {exercise-end}
365+ ```
366+
367+ ``` {solution-start} newton_mle1
368+ :class: dropdown
369+ ```
370+
371+ Let's start by defining "true" parameter values.
372+
373+ ``` {code-cell} ipython3
374+ β_0 = -2.5
375+ β_1 = 0.25
376+ β_2 = 0.5
377+ ```
378+
379+ To simulate the model, we sample 500,000 values of $x_t$ from the standard normal distribution.
380+
381+ ``` {code-cell} ipython3
382+ seed = 32
383+ shape = (500_000, 1)
384+ key = jax.random.PRNGKey(seed)
385+ x = jax.random.normal(key, shape)
386+ ```
387+
388+ We compute $\lambda$ using {eq}` lambda_mle `
389+
390+ ``` {code-cell} ipython3
391+ λ = jnp.exp(β_0 + β_1 * x + β_2 * x**2)
392+ ```
393+
394+ Let's define $y_t$ by sampling from a Poission distribution with mean as $\lambda_t$.
395+
396+ ``` {code-cell} ipython3
397+ y = jax.random.poisson(key, λ, shape)
398+ ```
399+
400+ Now let's try to recover the true parameter values using the Newton-Raphson
401+ method described above.
402+
403+
404+ ``` {code-cell} ipython3
405+ X = jnp.hstack((jnp.ones(shape), x, x**2))
406+
407+ # Take a guess at initial βs
408+ init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)
409+
410+ # Create an object with Poisson model values
411+ poi = create_regression_model(X, y)
412+
413+ # Use newton_raphson to find the MLE
414+ β_hat = newton_raphson(poi, init_β, tol=1e-5, display=True)
415+ ```
416+
417+ The maximum likelihood estimates are similar to the true parameter values.
418+
419+ ``` {solution-end}
420+ ```
0 commit comments