
Commit f1f16df

Add MLE lecture (#71)

* add initial jax section
* complete MLE section
* fixes
* fix bug in gradient cell
* use \log

1 parent 837d7a7 commit f1f16df

File tree

2 files changed: +299 -0 lines changed


lectures/_toc.yml

Lines changed: 4 additions & 0 deletions
@@ -14,6 +14,10 @@ parts:
   - file: inventory_dynamics
   - file: kesten_processes
   - file: wealth_dynamics
+- caption: Data and Empirics
+  numbered: true
+  chapters:
+  - file: mle
 - caption: Dynamic Programming
   numbered: true
   chapters:

lectures/mle.md

Lines changed: 295 additions & 0 deletions
@@ -0,0 +1,295 @@
---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.14.5
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

# Maximum Likelihood Estimation

```{contents} Contents
:depth: 2
```

```{include} _admonition/gpu.md
```

## Overview

This lecture is the extended JAX implementation of [this section](https://python.quantecon.org/mle.html#mle-with-numerical-methods) of [this lecture](https://python.quantecon.org/mle.html).

Please refer to that lecture for all background and notation.

Here we will exploit the automatic differentiation capabilities of JAX rather than calculating derivatives by hand.

We'll require the following imports:

```{code-cell} ipython3
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (11, 5)  # set default figure size
from collections import namedtuple
import jax.numpy as jnp
import jax
from statsmodels.api import Poisson
```

Let's check the GPU we are running on:

```{code-cell} ipython3
!nvidia-smi
```

We will use 64-bit floats with JAX in order to increase precision.

```{code-cell} ipython3
jax.config.update("jax_enable_x64", True)
```

## MLE with Numerical Methods (JAX)

Many distributions do not have nice, analytical solutions and therefore require
numerical methods to solve for parameter estimates.

One such numerical method is the Newton-Raphson algorithm.

Our goal is to find the maximum likelihood estimate $\hat{\boldsymbol{\beta}}$.

At $\hat{\boldsymbol{\beta}}$, the first derivative of the log-likelihood
function will be equal to 0.

Let's illustrate this by supposing

$$
\log \mathcal{L}(\beta) = -(\beta - 10)^2 - 10
$$

Define the function `logL`.

```{code-cell} ipython3
@jax.jit
def logL(β):
    return -(β - 10) ** 2 - 10
```

To find the value of $\frac{d \log \mathcal{L}(\boldsymbol{\beta})}{d \boldsymbol{\beta}}$, we can use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which auto-differentiates the given function.

We further use [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html), which vectorizes the given function, i.e., a function acting on scalar inputs can now be used with vector inputs.

```{code-cell} ipython3
dlogL = jax.vmap(jax.grad(logL))
```

```{code-cell} ipython3
β = jnp.linspace(1, 20)

fig, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(12, 8))

ax1.plot(β, logL(β), lw=2)
ax2.plot(β, dlogL(β), lw=2)

ax1.set_ylabel(r'$\log \mathcal{L}(\beta)$',
               rotation=0,
               labelpad=35,
               fontsize=15)
ax2.set_ylabel(r'$\frac{d \log \mathcal{L}(\beta)}{d \beta}$ ',
               rotation=0,
               labelpad=35,
               fontsize=19)

ax2.set_xlabel(r'$\beta$', fontsize=15)
ax1.grid(), ax2.grid()
plt.axhline(c='black')
plt.show()
```

The plot shows that the maximum likelihood value (the top plot) occurs
when $\frac{d \log \mathcal{L}(\boldsymbol{\beta})}{d \boldsymbol{\beta}} = 0$ (the bottom
plot).

Therefore, the likelihood is maximized when $\beta = 10$.

We can also ensure that this value is a *maximum* (as opposed to a
minimum) by checking that the second derivative (slope of the bottom
plot) is negative.
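
As an illustrative check (this cell is not part of the original lecture), we can evaluate the second derivative directly by composing `jax.grad` with itself; at $\beta = 10$ it equals $-2$:

```{code-cell} ipython3
# Illustrative check (not in the original lecture): the second derivative
# of logL at β = 10 is -2 < 0, confirming a maximum
d2logL = jax.grad(jax.grad(logL))
d2logL(10.0)
```
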
The Newton-Raphson algorithm finds a point where the first derivative is 0.

To use the algorithm, we take an initial guess at the maximum value,
$\beta_0$ (the OLS parameter estimates might be a reasonable
guess), and then iterate on an updating rule until the estimates converge.

Please refer to [this section](https://python.quantecon.org/mle.html#mle-with-numerical-methods) for the detailed algorithm.
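
In outline, each Newton-Raphson step updates the current estimate according to

$$
\boldsymbol{\beta}_{(k+1)} = \boldsymbol{\beta}_{(k)} - H^{-1}(\boldsymbol{\beta}_{(k)}) \, G(\boldsymbol{\beta}_{(k)})
$$

where $G$ and $H$ denote the gradient and Hessian of the log-likelihood, and iteration stops once the change in the estimates falls below the tolerance level.
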
Let's have a go at implementing the Newton-Raphson algorithm.

First, we'll create a `PoissonRegressionModel` to hold the data.

```{code-cell} ipython3
PoissonRegressionModel = namedtuple('PoissonRegressionModel', ['X', 'y'])

def create_poisson_model(X, y):
    n, k = X.shape
    # Reshape y as an n-by-1 column vector
    y = y.reshape(n, 1)
    X, y = jax.device_put((X, y))
    return PoissonRegressionModel(X=X, y=y)
```

At present, JAX doesn't have an implementation that computes the factorial directly.

To compute the factorial efficiently in a form we can JIT-compile, we use

$$
n! = e^{\log(\Gamma(n+1))}
$$

since [jax.lax.lgamma](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.lgamma.html) and [jax.lax.exp](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.exp.html) are available.

The following function `jax_factorial` computes the factorial using this idea.

```{code-cell} ipython3
@jax.jit
def _factorial(n):
    return jax.lax.exp(jax.lax.lgamma(n + 1.0)).astype(int)

jax_factorial = jax.vmap(_factorial)
```

Let's define the Poisson regression's log-likelihood function.
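
For reference, this is the Poisson log-likelihood from the referenced lecture,

$$
\log \mathcal{L}(\boldsymbol{\beta}) = \sum_{i=1}^{n} \left( y_i \log \mu_i - \mu_i - \log (y_i!) \right),
\qquad \mu_i = \exp(\mathbf{x}_i' \boldsymbol{\beta})
$$

which the code below evaluates term by term, using `jax_factorial` for the $y_i!$ term.
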
```{code-cell} ipython3
@jax.jit
def poisson_logL(β, model):
    y = model.y
    μ = jnp.exp(model.X @ β)
    return jnp.sum(model.y * jnp.log(μ) - μ - jnp.log(jax_factorial(y)))
```

To find the gradient of `poisson_logL`, we again use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html).

According to [the documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev),

* `jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while
* `jax.jacrev` uses reverse-mode, which is more efficient for “wide” Jacobian matrices.

(The documentation also states that for matrices that are near-square, `jax.jacfwd` probably has an edge over `jax.jacrev`.)

Therefore, to find the Hessian, we can directly use `jax.jacfwd`.

```{code-cell} ipython3
G_poisson_logL = jax.grad(poisson_logL)
H_poisson_logL = jax.jacfwd(G_poisson_logL)
```

Our function `newton_raphson` will take a `PoissonRegressionModel` object
and an initial guess of the parameter vector $\boldsymbol{\beta}_0$.

The algorithm will update the parameter vector according to the updating
rule, and recalculate the gradient and Hessian matrices at the new
parameter estimates.

```{code-cell} ipython3
def newton_raphson(model, β, tol=1e-3, max_iter=100, display=True):

    i = 0
    error = 100  # Initial error value

    # Print header of output
    if display:
        header = f'{"Iteration_k":<13}{"Log-likelihood":<16}{"θ":<60}'
        print(header)
        print("-" * len(header))

    # While loop runs while any value in error is greater
    # than the tolerance until max iterations are reached
    while jnp.any(error > tol) and i < max_iter:
        H, G = jnp.squeeze(H_poisson_logL(β, model)), G_poisson_logL(β, model)
        β_new = β - (jnp.dot(jnp.linalg.inv(H), G))
        error = jnp.abs(β_new - β)
        β = β_new

        if display:
            β_list = [f'{t:.3}' for t in list(β.flatten())]
            update = f'{i:<13}{poisson_logL(β, model):<16.8}{β_list}'
            print(update)

        i += 1

    print(f'Number of iterations: {i}')
    print(f'β_hat = {β.flatten()}')

    return β
```

Let's try out our algorithm with a small dataset of 5 observations and 3
variables in $\mathbf{X}$.

```{code-cell} ipython3
X = jnp.array([[1, 2, 5],
               [1, 1, 3],
               [1, 4, 2],
               [1, 5, 2],
               [1, 3, 1]])

y = jnp.array([1, 0, 1, 1, 0])

# Take a guess at initial βs
init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)

# Create an object with Poisson model values
poi = create_poisson_model(X, y)

# Use newton_raphson to find the MLE
β_hat = newton_raphson(poi, init_β, display=True)
```

As this was a simple model with few observations, the algorithm achieved
convergence in only 7 iterations.

The gradient vector should be close to 0 at $\hat{\boldsymbol{\beta}}$:

```{code-cell} ipython3
G_poisson_logL(β_hat, poi)
```
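
As a further illustrative check (not part of the original lecture), the Hessian at $\hat{\boldsymbol{\beta}}$ should be negative definite, i.e., all of its eigenvalues should be negative, confirming that we have found a maximum:

```{code-cell} ipython3
# Illustrative check (not in the original lecture): the eigenvalues of
# the Hessian at β_hat should all be negative
H_hat = jnp.squeeze(H_poisson_logL(β_hat, poi))
jnp.linalg.eigvalsh(H_hat)
```
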
## MLE with `statsmodels`

We'll use the Poisson regression model in `statsmodels` to verify the results
obtained using JAX.

`statsmodels` uses the same algorithm as above to find the maximum
likelihood estimates.

Now, as `statsmodels` accepts only NumPy arrays, we can use the `__array__` method
of JAX arrays to convert them to NumPy arrays.

```{code-cell} ipython3
X_numpy = X.__array__()
y_numpy = y.__array__()
```

```{code-cell} ipython3
stats_poisson = Poisson(y_numpy, X_numpy).fit()
print(stats_poisson.summary())
```
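
As a final illustrative comparison (not part of the original lecture), the coefficient estimates reported by `statsmodels` should match the Newton-Raphson estimates computed above:

```{code-cell} ipython3
# Illustrative check (not in the original lecture): print the statsmodels
# estimates alongside our Newton-Raphson estimates
print(stats_poisson.params)
print(β_hat.flatten())
```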
