Commit 1e3cc3b

Update MLE Lecture with More Details (#72)
* update mle lecture with more details
* remove capitalization
* revert user expressions
1 parent f1f16df commit 1e3cc3b

File tree

1 file changed: +33 -15 lines changed


lectures/mle.md

Lines changed: 33 additions & 15 deletions
@@ -11,7 +11,6 @@ kernelspec:
   name: python3
 ---
 
-
 # Maximum Likelihood Estimation
 
 ```{contents} Contents
@@ -47,15 +46,13 @@ Let's check the GPU we are running
 !nvidia-smi
 ```
 
-
 We will use 64 bit floats with JAX in order to increase the precision.
 
 ```{code-cell} ipython3
 jax.config.update("jax_enable_x64", True)
 ```
 
-
-
-## MLE with Numerical Methods (JAX)
+## MLE with numerical methods (JAX)
 
 Many distributions do not have nice, analytical solutions and therefore require
 numerical methods to solve for parameter estimates.
@@ -81,7 +78,6 @@ def logL(β):
     return -(β - 10) ** 2 - 10
 ```
 
-
 To find the value of $\frac{d \log \mathcal{L(\boldsymbol{\beta})}}{d \boldsymbol{\beta}}$, we can use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) which auto-differentiates the given function.
 
 We further use [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html) which vectorizes the given function i.e. the function acting upon scalar inputs can now be used with vector inputs.
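As a quick illustration of what these two transforms do with the toy log likelihood `logL` from the hunk above, the sketch below evaluates the auto-differentiated derivative on a small grid; the names `dlogL` and `β_grid` are illustrative, not taken from the lecture.

```python
import jax
import jax.numpy as jnp

def logL(β):
    # Toy log likelihood from the hunk above, maximised at β = 10
    return -(β - 10) ** 2 - 10

dlogL = jax.vmap(jax.grad(logL))     # derivative of logL, vectorised over β
β_grid = jnp.linspace(1.0, 20.0, 5)  # a few illustrative evaluation points

print(dlogL(β_grid))                 # -2(β - 10): positive left of 10, negative right
print(jax.grad(logL)(10.0))          # 0.0 at the maximiser, matching the plots below
```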
@@ -113,7 +109,6 @@ plt.axhline(c='black')
 plt.show()
 ```
 
-
 The plot shows that the maximum likelihood value (the top plot) occurs
 when $\frac{d \log \mathcal{L(\boldsymbol{\beta})}}{d \boldsymbol{\beta}} = 0$ (the bottom
 plot).
@@ -129,14 +124,29 @@ The Newton-Raphson algorithm finds a point where the first derivative is
 
 To use the algorithm, we take an initial guess at the maximum value,
 $\beta_0$ (the OLS parameter estimates might be a reasonable
-guess), then
+guess).
 
+Then we use the updating rule involving gradient information to iterate the algorithm until the error is sufficiently small or the algorithm reaches the maximum number of iterations.
 
 Please refer to [this section](https://python.quantecon.org/mle.html#mle-with-numerical-methods) for the detailed algorithm.
 
-Let's have a go at implementing the Newton-Raphson algorithm.
+Let's have a go at implementing the Newton-Raphson algorithm to calculate the maximum likelihood estimates of a Poisson regression.
 
-First, we'll create a `PoissonRegressionModel`.
+The Poisson regression has a joint pmf:
+
+$$
+f(y_1, y_2, \ldots, y_n \mid \mathbf{x}_1, \mathbf{x}_2, \ldots, \mathbf{x}_n; \boldsymbol{\beta})
+= \prod_{i=1}^{n} \frac{\mu_i^{y_i}}{y_i!} e^{-\mu_i}
+$$
+
+$$
+\text{where}\ \mu_i
+= \exp(\mathbf{x}_i' \boldsymbol{\beta})
+= \exp(\beta_0 + \beta_1 x_{i1} + \ldots + \beta_k x_{ik})
+$$
+
+We create a `namedtuple` to store the observed values
 
 ```{code-cell} ipython3
 PoissonRegressionModel = namedtuple('PoissonRegressionModel', ['X', 'y'])
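The added text above refers to "the updating rule involving gradient information" without spelling it out in this diff. For reference, the standard Newton-Raphson step on the log likelihood takes the form

$$
\boldsymbol{\beta}_{k+1} = \boldsymbol{\beta}_k - H^{-1}(\boldsymbol{\beta}_k) G(\boldsymbol{\beta}_k)
$$

where $G(\boldsymbol{\beta}_k)$ and $H(\boldsymbol{\beta}_k)$ are the gradient and Hessian of $\log \mathcal{L}(\boldsymbol{\beta})$ evaluated at $\boldsymbol{\beta}_k$.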
@@ -149,8 +159,18 @@ def create_poisson_model(X, y):
     return PoissonRegressionModel(X=X, y=y)
 ```
 
+The log likelihood function of the Poisson regression is
+
+$$
+\underset{\beta}{\max} \Big(
+\sum_{i=1}^{n} y_i \log{\mu_i} -
+\sum_{i=1}^{n} \mu_i -
+\sum_{i=1}^{n} \log y_i! \Big)
+$$
+
+The full derivation can be found [here](https://python.quantecon.org/mle.html#id2).
 
-At present, JAX doesn't have an implementation to compute factorial directly.
+The log likelihood function involves factorial, but JAX doesn't have a readily available implementation to compute factorial directly.
 
 In order to compute the factorial efficiently such that we can JIT it, we use

@@ -162,6 +182,8 @@ since [jax.lax.lgamma](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax
 
 The following function `jax_factorial` computes the factorial using this idea.
 
+Let's define this function in Python
+
 ```{code-cell} ipython3
 @jax.jit
 def _factorial(n):
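The body of `_factorial` is cut off by this hunk. A minimal sketch of the idea it relies on, assuming the standard relation $n! = \Gamma(n + 1)$ and the `jax.lax.lgamma` primitive mentioned in the hunk header; the name `_factorial_sketch` is mine, and the real cell may round or cast the result differently.

```python
import jax
import jax.numpy as jnp

@jax.jit
def _factorial_sketch(n):
    # n! = Γ(n + 1); jax.lax.lgamma returns log Γ, so exponentiate it back.
    # n must be a floating-point value for lgamma.
    return jnp.exp(jax.lax.lgamma(n + 1.0))

factorial_sketch = jax.vmap(_factorial_sketch)
print(factorial_sketch(jnp.array([3.0, 5.0, 7.0])))  # approximately [6, 120, 5040]
```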
@@ -171,7 +193,7 @@ jax_factorial = jax.vmap(_factorial)
 ```
 
 
-Let's define the Poisson Regression's log likelihood function.
+Now we can define the log likelihood function in Python
 
 ```{code-cell} ipython3
 @jax.jit
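The rest of this code cell lies outside the hunk. Below is a sketch of how the log likelihood could be written from the formula above; the `(β, model)` signature matches the later call `G_poisson_logL(β_hat, poi)`, but the function name, the body, and the use of `jax.lax.lgamma` in place of `jax_factorial` are my assumptions.

```python
import jax
import jax.numpy as jnp

@jax.jit
def poisson_logL_sketch(β, model):
    # μ_i = exp(x_i' β); log L = Σ_i [ y_i log μ_i - μ_i - log(y_i!) ]
    y = model.y
    μ = jnp.exp(model.X @ β)
    log_y_factorial = jax.lax.lgamma(y + 1.0)  # log(y_i!) via the lgamma identity
    return jnp.sum(y * jnp.log(μ) - μ - log_y_factorial)
```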
@@ -198,7 +220,6 @@ G_poisson_logL = jax.grad(poisson_logL)
 H_poisson_logL = jax.jacfwd(G_poisson_logL)
 ```
 
-
 Our function `newton_raphson` will take a `PoissonRegressionModel` object
 that has an initial guess of the parameter vector $\boldsymbol{\beta}_0$.
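Only the tail of `newton_raphson` appears in the next hunk, although its signature shows up in that hunk's header. A sketch of the loop it likely wraps around the update rule is given below; the body is an assumption (the lecture's actual cell may report progress differently), and it relies on `G_poisson_logL` and `H_poisson_logL` from the hunk above.

```python
import jax.numpy as jnp

def newton_raphson_sketch(model, β, tol=1e-3, max_iter=100, display=True):
    i, error = 0, tol + 1
    while i < max_iter and error > tol:
        G = G_poisson_logL(β, model)        # gradient at the current β
        H = H_poisson_logL(β, model)        # Hessian at the current β
        β_new = β - jnp.linalg.solve(H, G)  # Newton-Raphson step
        error = jnp.max(jnp.abs(β_new - β))
        β = β_new
        if display:
            print(f"iteration {i}: β = {β}")
        i += 1
    return β
```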

@@ -239,7 +260,6 @@ def newton_raphson(model, β, tol=1e-3, max_iter=100, display=True):
     return β
 ```
 
-
 Let's try out our algorithm with a small dataset of 5 observations and 3
 variables in $\mathbf{X}$.

@@ -262,7 +282,6 @@ poi = create_poisson_model(X, y)
 β_hat = newton_raphson(poi, init_β, display=True)
 ```
 
-
 As this was a simple model with few observations, the algorithm achieved
 convergence in only 7 iterations.

@@ -272,7 +291,6 @@ The gradient vector should be close to 0 at $\hat{\boldsymbol{\beta}}$
 G_poisson_logL(β_hat, poi)
 ```
 
-
 ## MLE with `statsmodels`
 
 We’ll use the Poisson regression model in `statsmodels` to verify the results
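The verification mentioned here usually amounts to fitting the same Poisson regression by maximum likelihood with `statsmodels` and comparing the coefficients with `β_hat`. A sketch, assuming `X` and `y` are the arrays built in the earlier cells:

```python
import numpy as np
import statsmodels.api as sm

# Fit the same Poisson regression with statsmodels' built-in MLE routine
# and compare the estimated coefficients with β_hat from newton_raphson.
stats_poisson = sm.Poisson(np.asarray(y), np.asarray(X)).fit()
print(stats_poisson.summary())
```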
