Commit b76ca08

Smit-create, jstac, HumphreyYang, and HengchengZhang authored
Add exercise in MLE (#74)
* Add exercise in MLE
* fix
* Fix exercises
* misc
* update distribution and text
* Remove figure settings
* Fix a typo
* update sections and variable names

---------

Co-authored-by: John Stachurski <john.stachurski@gmail.com>
Co-authored-by: Humphrey Yang <humphrey.yang@anu.edu.au>
Co-authored-by: HengCheng <79777246+2789372130@users.noreply.github.com>
1 parent 1e3cc3b commit b76ca08


lectures/mle.md

Lines changed: 116 additions & 9 deletions
@@ -31,9 +31,7 @@ Here we will exploit the automatic differentiation capabilities of JAX rather th
 We'll require the following imports:
 
 ```{code-cell} ipython3
-%matplotlib inline
 import matplotlib.pyplot as plt
-plt.rcParams["figure.figsize"] = (11, 5) # set default figure size
 from collections import namedtuple
 import jax.numpy as jnp
 import jax
@@ -59,6 +57,10 @@ numerical methods to solve for parameter estimates.
 
 One such numerical method is the Newton-Raphson algorithm.
 
+Let's start with a simple example to illustrate the algorithm.
+
+### A toy model
+
 Our goal is to find the maximum likelihood estimate $\hat{\boldsymbol{\beta}}$.
 
 At $\hat{\boldsymbol{\beta}}$, the first derivative of the log-likelihood
@@ -130,6 +132,8 @@ Then we use the updating rule involving gradient information to iterate the algo
 
 Please refer to [this section](https://python.quantecon.org/mle.html#mle-with-numerical-methods) for the detailed algorithm.
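As a rough illustration of the updating rule $\boldsymbol{\beta}_{(k+1)} = \boldsymbol{\beta}_{(k)} - H^{-1} G$, the sketch below applies Newton-Raphson to an intercept-only Poisson log-likelihood, for which the MLE has the closed form $\log(\bar{y})$. The names `toy_logL` and `newton_sketch` and the data are hypothetical and not part of the lecture; the lecture's own `newton_raphson` may differ in its details.

```python
import jax
import jax.numpy as jnp

def toy_logL(β, y):
    # Intercept-only Poisson log-likelihood; the log(y!) term is dropped
    # because it does not depend on β
    return jnp.sum(y * β[0] - jnp.exp(β[0]))

G = jax.grad(toy_logL)   # gradient with respect to β
H = jax.jacfwd(G)        # Hessian, computed as the Jacobian of the gradient

def newton_sketch(β, y, tol=1e-5, max_iter=50):
    # Hypothetical helper: iterate β <- β - H^{-1} G until the step is small
    for _ in range(max_iter):
        step = jnp.linalg.solve(H(β, y), G(β, y))
        β = β - step
        if jnp.max(jnp.abs(step)) < tol:
            break
    return β

y = jnp.array([1., 0., 2., 1., 3.])          # made-up counts
β_hat = newton_sketch(jnp.array([0.1]), y)
print(β_hat, jnp.log(jnp.mean(y)))           # the two values should agree closely
```

Solving the linear system with `jnp.linalg.solve` avoids forming the inverse Hessian explicitly, which is a common choice but only one of several reasonable ones.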
 
+### A Poisson model
+
 Let's have a go at implementing the Newton-Raphson algorithm to calculate the maximum likelihood estimations of a Poisson regression.
 
 The Poisson regression has a joint pmf:
@@ -145,18 +149,18 @@ $$
 = \exp(\mathbf{x}_i' \boldsymbol{\beta})
 = \exp(\beta_0 + \beta_1 x_{i1} + \ldots + \beta_k x_{ik})
 $$
-
+
 We create a `namedtuple` to store the observed values
 
 ```{code-cell} ipython3
-PoissonRegressionModel = namedtuple('PoissonRegressionModel', ['X', 'y'])
+RegressionModel = namedtuple('RegressionModel', ['X', 'y'])
 
-def create_poisson_model(X, y):
+def create_regression_model(X, y):
     n, k = X.shape
     # Reshape y as a n_by_1 column vector
     y = y.reshape(n, 1)
     X, y = jax.device_put((X, y))
-    return PoissonRegressionModel(X=X, y=y)
+    return RegressionModel(X=X, y=y)
 ```
 
 The log likelihood function of the Poisson regression is
@@ -203,7 +207,6 @@ def poisson_logL(β, model):
     return jnp.sum(model.y * jnp.log(μ) - μ - jnp.log(jax_factorial(y)))
 ```
 
-
 To find the gradient of the `poisson_logL`, we again use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html).
 
 According to [the documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev),
@@ -220,7 +223,7 @@ G_poisson_logL = jax.grad(poisson_logL)
 H_poisson_logL = jax.jacfwd(G_poisson_logL)
 ```
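One way to sanity-check derivatives built this way is to compare them with the closed-form Poisson-regression score $X'(\mathbf{y} - \boldsymbol{\mu})$ and Hessian $-X' \operatorname{diag}(\boldsymbol{\mu}) X$. The sketch below uses a hypothetical `logL_sketch` that omits the $\log(y!)$ term (it is constant in $\boldsymbol{\beta}$) and made-up data. Note that with $\boldsymbol{\beta}$ stored as a $k \times 1$ column, `jax.jacfwd` applied to the gradient returns an array of shape `(k, 1, k, 1)`, which is reshaped to $k \times k$ here.

```python
import jax
import jax.numpy as jnp

def logL_sketch(β, X, y):
    # Poisson log-likelihood without the log(y!) term, which is constant in β
    μ = jnp.exp(X @ β)
    return jnp.sum(y * jnp.log(μ) - μ)

G = jax.grad(logL_sketch)   # gradient with respect to the first argument, β
H = jax.jacfwd(G)           # Hessian as the forward-mode Jacobian of the gradient

# Made-up data for illustration only
X = jnp.array([[1., 2., 5.],
               [1., 1., 3.],
               [1., 4., 2.],
               [1., 5., 2.],
               [1., 3., 1.]])
y = jnp.array([1., 0., 1., 1., 0.]).reshape(5, 1)
β = jnp.full((3, 1), 0.1)

k = β.shape[0]
grad_auto = G(β, X, y)                  # shape (k, 1)
hess_auto = H(β, X, y).reshape(k, k)    # (k, 1, k, 1) -> (k, k)

# Closed-form counterparts
μ = jnp.exp(X @ β)
grad_exact = X.T @ (y - μ)              # X'(y - μ)
hess_exact = -(X * μ).T @ X             # -X' diag(μ) X

print(jnp.allclose(grad_auto, grad_exact, rtol=1e-4, atol=1e-3))
print(jnp.allclose(hess_auto, hess_exact, rtol=1e-4, atol=1e-3))
```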
 
-Our function `newton_raphson` will take a `PoissonRegressionModel` object
+Our function `newton_raphson` will take a `RegressionModel` object
 that has an initial guess of the parameter vector $\boldsymbol{\beta}_0$.
 
 The algorithm will update the parameter vector according to the updating
@@ -276,7 +279,7 @@ y = jnp.array([1, 0, 1, 1, 0])
 init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)
 
 # Create an object with Poisson model values
-poi = create_poisson_model(X, y)
+poi = create_regression_model(X, y)
 
 # Use newton_raphson to find the MLE
 β_hat = newton_raphson(poi, init_β, display=True)
@@ -311,3 +314,107 @@ y_numpy = y.__array__()
 stats_poisson = Poisson(y_numpy, X_numpy).fit()
 print(stats_poisson.summary())
 ```
+
+The benefit of writing our own procedure, relative to statsmodels, is that
+
+* we can exploit the power of the GPU and
+* we learn the underlying methodology, which can be extended to complex situations where no existing routines are available.
+
+```{exercise-start}
+:label: newton_mle1
+```
+
+We define a quadratic model for a single explanatory variable by
+
+$$
+\log(\lambda_t) = \beta_0 + \beta_1 x_t + \beta_2 x_{t}^2
+$$
+
+We calculate the mean on the original scale instead of the log scale by exponentiating both sides of the above equation, which gives
+
+```{math}
+:label: lambda_mle
+\lambda_t = \exp(\beta_0 + \beta_1 x_t + \beta_2 x_{t}^2)
+```
+
+Simulate the values of $x_t$ by sampling from a normal distribution and $\lambda_t$ by using {eq}`lambda_mle` and the following constants:
+
+$$
+\beta_0 = -2.5,
+\quad
+\beta_1 = 0.25,
+\quad
+\beta_2 = 0.5
+$$
+
+Try to obtain the approximate values of $\beta_0,\beta_1,\beta_2$ by simulating a Poisson regression model such that
+
+$$
+y_t \sim {\rm Poisson}(\lambda_t)
+\quad \text{for all } t.
+$$
+
+Using our `newton_raphson` function on the data set $X = [1, x_t, x_t^{2}]$ and
+$y$, obtain the maximum likelihood estimates of $\beta_0,\beta_1,\beta_2$.
+
+With a sufficiently large sample size, you should approximately
+recover the true values of these parameters.
+
+
+```{exercise-end}
+```
+
+```{solution-start} newton_mle1
+:class: dropdown
+```
+
+Let's start by defining "true" parameter values.
+
+```{code-cell} ipython3
+β_0 = -2.5
+β_1 = 0.25
+β_2 = 0.5
+```
+
+To simulate the model, we sample 500,000 values of $x_t$ from the standard normal distribution.
+
+```{code-cell} ipython3
+seed = 32
+shape = (500_000, 1)
+key = jax.random.PRNGKey(seed)
+x = jax.random.normal(key, shape)
+```
+
+We compute $\lambda$ using {eq}`lambda_mle`
+
+```{code-cell} ipython3
+λ = jnp.exp(β_0 + β_1 * x + β_2 * x**2)
+```
+
+Let's define $y_t$ by sampling from a Poisson distribution with mean $\lambda_t$.
+
+```{code-cell} ipython3
+y = jax.random.poisson(key, λ, shape)
+```
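The two sampling cells in this solution pass the same `key` to both `jax.random.normal` and `jax.random.poisson`. JAX's documentation advises against reusing a PRNG key for more than one draw, so a variant that splits the key first might look like the sketch below. The names `key_x` and `key_y` are illustrative, and a smaller sample is used here only to keep the example light.

```python
import jax
import jax.numpy as jnp

β_0, β_1, β_2 = -2.5, 0.25, 0.5

key = jax.random.PRNGKey(32)
key_x, key_y = jax.random.split(key)      # one independent subkey per draw

shape = (1_000, 1)                        # smaller than the 500,000 used in the solution
x = jax.random.normal(key_x, shape)       # explanatory variable
λ = jnp.exp(β_0 + β_1 * x + β_2 * x**2)   # Poisson means on the original scale
y = jax.random.poisson(key_y, λ, shape)   # counts drawn with the second subkey
```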
+
+Now let's try to recover the true parameter values using the Newton-Raphson
+method described above.
+
+
+```{code-cell} ipython3
+X = jnp.hstack((jnp.ones(shape), x, x**2))
+
+# Take a guess at initial βs
+init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)
+
+# Create an object with Poisson model values
+poi = create_regression_model(X, y)
+
+# Use newton_raphson to find the MLE
+β_hat = newton_raphson(poi, init_β, tol=1e-5, display=True)
+```
+
+The maximum likelihood estimates are similar to the true parameter values.
+
+```{solution-end}
+```
