---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.14.5
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---


# Maximum Likelihood Estimation

```{contents} Contents
:depth: 2
```

```{include} _admonition/gpu.md
```

## Overview

This lecture is an extended JAX implementation of [this section](https://python.quantecon.org/mle.html#mle-with-numerical-methods) of [this lecture](https://python.quantecon.org/mle.html).

Please refer to that lecture for background and notation.

Here we will exploit the automatic differentiation capabilities of JAX rather than calculating derivatives by hand.

We'll require the following imports:

```{code-cell} ipython3
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (11, 5)  # set default figure size
from collections import namedtuple
import jax.numpy as jnp
import jax
from statsmodels.api import Poisson
```

Let's check the GPU we are running on

```{code-cell} ipython3
!nvidia-smi
```


We will use 64-bit floats with JAX in order to increase precision.

```{code-cell} ipython3
jax.config.update("jax_enable_x64", True)
```
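
As a quick check, arrays created after this setting should now default to 64-bit precision:

```{code-cell} ipython3
# With x64 enabled, newly created floating-point arrays default to float64
jnp.ones(3).dtype
```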


## MLE with Numerical Methods (JAX)

Many distributions do not have nice, analytical solutions and therefore require
numerical methods to solve for parameter estimates.

One such numerical method is the Newton-Raphson algorithm.

Our goal is to find the maximum likelihood estimate $\hat{\boldsymbol{\beta}}$.

At $\hat{\boldsymbol{\beta}}$, the first derivative of the log-likelihood
function will be equal to 0.

Let's illustrate this by supposing

$$
\log \mathcal{L}(\beta) = - (\beta - 10)^2 - 10
$$

Define the function `logL`.

```{code-cell} ipython3
@jax.jit
def logL(β):
    return -(β - 10) ** 2 - 10
```


To find the value of $\frac{d \log \mathcal{L}(\boldsymbol{\beta})}{d \boldsymbol{\beta}}$, we can use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which auto-differentiates the given function.

We further use [jax.vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html), which vectorizes the given function, i.e., a function written for scalar inputs can then be applied to vectors of inputs.

```{code-cell} ipython3
dlogL = jax.vmap(jax.grad(logL))
```
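
For example, since the derivative of $-(\beta - 10)^2 - 10$ is $-2(\beta - 10)$, evaluating `dlogL` at a few points should return $10$, $0$ and $-10$:

```{code-cell} ipython3
# The vmapped gradient accepts a vector of β values
dlogL(jnp.array([5.0, 10.0, 15.0]))
```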

```{code-cell} ipython3
β = jnp.linspace(1, 20)

fig, (ax1, ax2) = plt.subplots(2, sharex=True, figsize=(12, 8))

ax1.plot(β, logL(β), lw=2)
ax2.plot(β, dlogL(β), lw=2)

ax1.set_ylabel(r'$\log \mathcal{L}(\beta)$',
               rotation=0,
               labelpad=35,
               fontsize=15)
ax2.set_ylabel(r'$\frac{d \log \mathcal{L}(\beta)}{d \beta}$',
               rotation=0,
               labelpad=35,
               fontsize=19)

ax2.set_xlabel(r'$\beta$', fontsize=15)
ax1.grid(), ax2.grid()
plt.axhline(c='black')
plt.show()
```


The plot shows that the maximum likelihood value (the top plot) occurs
when $\frac{d \log \mathcal{L}(\boldsymbol{\beta})}{d \boldsymbol{\beta}} = 0$ (the bottom
plot).

Therefore, the likelihood is maximized when $\beta = 10$.

We can also ensure that this value is a *maximum* (as opposed to a
minimum) by checking that the second derivative (slope of the bottom
plot) is negative.
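
In this example we can verify that directly by differentiating twice with `jax.grad`: for $\log \mathcal{L}(\beta) = -(\beta - 10)^2 - 10$ the second derivative is constant and equal to $-2$.

```{code-cell} ipython3
# Second derivative of the log likelihood via nested automatic differentiation
d2logL = jax.grad(jax.grad(logL))
d2logL(10.0)
```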

The Newton-Raphson algorithm finds a point where the first derivative is
0.

To use the algorithm, we take an initial guess at the maximum value,
$\boldsymbol{\beta}_0$ (the OLS parameter estimates might be a reasonable
guess), and then iterate the updating rule until the estimates stop changing (up to a tolerance).

Please refer to [this section](https://python.quantecon.org/mle.html#mle-with-numerical-methods) for the detailed algorithm.
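
For reference, the updating rule implemented in the code below is

$$
\boldsymbol{\beta}_{(k+1)} = \boldsymbol{\beta}_{(k)} - H^{-1}(\boldsymbol{\beta}_{(k)}) \, G(\boldsymbol{\beta}_{(k)})
$$

where $G$ and $H$ are the gradient and Hessian of the log likelihood, both evaluated at the current estimate $\boldsymbol{\beta}_{(k)}$, and iteration stops once the change in the estimates falls below the tolerance `tol`.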

Let's have a go at implementing the Newton-Raphson algorithm.

First, we'll create a `PoissonRegressionModel` to hold the data.

```{code-cell} ipython3
PoissonRegressionModel = namedtuple('PoissonRegressionModel', ['X', 'y'])

def create_poisson_model(X, y):
    n, k = X.shape
    # Reshape y as an n-by-1 column vector
    y = y.reshape(n, 1)
    # Move the data to the default JAX device (e.g. the GPU)
    X, y = jax.device_put((X, y))
    return PoissonRegressionModel(X=X, y=y)
```


At present, JAX doesn't have an implementation to compute the factorial directly.

In order to compute the factorial efficiently in a way that can be JIT-compiled, we use the identity

$$
    n! = \Gamma(n+1) = e^{\log(\Gamma(n+1))}
$$

since [jax.lax.lgamma](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.lgamma.html) and [jax.lax.exp](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.exp.html) are available.

The following function `jax_factorial` computes the factorial using this idea.

```{code-cell} ipython3
@jax.jit
def _factorial(n):
    return jax.lax.exp(jax.lax.lgamma(n + 1.0)).astype(int)

jax_factorial = jax.vmap(_factorial)
```
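
As a quick sanity check, we can compare the result with Python's exact `math.factorial`; the two should agree, though note that the final integer cast truncates, so any floating-point error in the `exp`/`lgamma` computation would surface here.

```{code-cell} ipython3
import math

# Cross-check the lgamma-based factorial at n = 10
jax_factorial(jnp.array([10.0])), math.factorial(10)
```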


Let's define the Poisson regression's log likelihood function.
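
Recall that for the Poisson regression model the log likelihood is

$$
\log \mathcal{L}(\boldsymbol{\beta}) =
\sum_{i=1}^{n} \left( y_i \log \mu_i - \mu_i - \log y_i! \right),
\qquad
\mu_i = \exp(\mathbf{x}_i' \boldsymbol{\beta})
$$

which is exactly the sum computed by the function below.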

```{code-cell} ipython3
@jax.jit
def poisson_logL(β, model):
    y = model.y
    μ = jnp.exp(model.X @ β)
    return jnp.sum(y * jnp.log(μ) - μ - jnp.log(jax_factorial(y)))
```


To find the gradient of `poisson_logL`, we again use [jax.grad](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html).

According to [the documentation](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#jacobians-and-hessians-using-jacfwd-and-jacrev),

* `jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for “tall” Jacobian matrices, while
* `jax.jacrev` uses reverse-mode, which is more efficient for “wide” Jacobian matrices.

(The documentation also states that for matrices that are near-square, `jax.jacfwd` probably has an edge over `jax.jacrev`.)

Therefore, to find the Hessian, we apply `jax.jacfwd` to the gradient function.

```{code-cell} ipython3
G_poisson_logL = jax.grad(poisson_logL)
H_poisson_logL = jax.jacfwd(G_poisson_logL)
```
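
As an aside, JAX also provides the convenience wrapper `jax.hessian` (forward-over-reverse differentiation under the hood), which would compute the same matrix; a minimal sketch:

```{code-cell} ipython3
# Equivalent Hessian using the built-in convenience wrapper
H_alt = jax.hessian(poisson_logL)
```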


Our function `newton_raphson` will take a `PoissonRegressionModel` object
and an initial guess of the parameter vector $\boldsymbol{\beta}_0$.

The algorithm will update the parameter vector according to the updating
rule, and recalculate the gradient and Hessian matrices at the new
parameter estimates.

```{code-cell} ipython3
def newton_raphson(model, β, tol=1e-3, max_iter=100, display=True):

    i = 0
    error = 100  # Initial error value

    # Print the header of the output
    if display:
        header = f'{"Iteration_k":<13}{"Log-likelihood":<16}{"β":<60}'
        print(header)
        print("-" * len(header))

    # The loop runs while any value in error is greater than the tolerance
    # and the maximum number of iterations has not been reached
    while jnp.any(error > tol) and i < max_iter:
        H, G = jnp.squeeze(H_poisson_logL(β, model)), G_poisson_logL(β, model)
        β_new = β - (jnp.dot(jnp.linalg.inv(H), G))
        error = jnp.abs(β_new - β)
        β = β_new

        if display:
            β_list = [f'{t:.3}' for t in list(β.flatten())]
            update = f'{i:<13}{poisson_logL(β, model):<16.8}{β_list}'
            print(update)

        i += 1

    print(f'Number of iterations: {i}')
    print(f'β_hat = {β.flatten()}')

    return β
```
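
As an aside on the update step, forming `jnp.linalg.inv(H)` explicitly is fine at this scale, but solving the linear system directly is generally more numerically stable; a minimal sketch of the same step using `jnp.linalg.solve`:

```{code-cell} ipython3
# One Newton step written with a linear solve instead of an explicit inverse
def newton_step(β, model):
    H = jnp.squeeze(H_poisson_logL(β, model))
    G = G_poisson_logL(β, model)
    return β - jnp.linalg.solve(H, G)
```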


Let's try out our algorithm with a small dataset of 5 observations and 3
variables in $\mathbf{X}$.

```{code-cell} ipython3
X = jnp.array([[1, 2, 5],
               [1, 1, 3],
               [1, 4, 2],
               [1, 5, 2],
               [1, 3, 1]])

y = jnp.array([1, 0, 1, 1, 0])

# Take a guess at initial βs
init_β = jnp.array([0.1, 0.1, 0.1]).reshape(X.shape[1], 1)

# Create an object with Poisson model values
poi = create_poisson_model(X, y)

# Use newton_raphson to find the MLE
β_hat = newton_raphson(poi, init_β, display=True)
```


As this was a simple model with few observations, the algorithm achieved
convergence in only 7 iterations.

The gradient vector should be close to 0 at $\hat{\boldsymbol{\beta}}$:

```{code-cell} ipython3
G_poisson_logL(β_hat, poi)
```
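
In the same spirit as the second-derivative check earlier, we can also confirm that the Hessian at $\hat{\boldsymbol{\beta}}$ is negative definite, so that we have indeed found a maximum; its eigenvalues should all be negative:

```{code-cell} ipython3
# Eigenvalues of the Hessian at the MLE (all should be negative)
H_hat = jnp.squeeze(H_poisson_logL(β_hat, poi))
jnp.linalg.eigvalsh(H_hat)
```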


## MLE with `statsmodels`

We’ll use the Poisson regression model in `statsmodels` to verify the results
obtained using JAX.

`statsmodels` uses the same algorithm as above to find the maximum
likelihood estimates.

Now, as `statsmodels` accepts only NumPy arrays, we can use the `__array__` method
of JAX arrays to convert them to NumPy arrays.

```{code-cell} ipython3
X_numpy = X.__array__()
y_numpy = y.__array__()
```
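
Equivalently, `np.asarray` performs the same conversion (it calls `__array__` under the hood):

```{code-cell} ipython3
import numpy as np

# The same conversion via np.asarray
X_numpy = np.asarray(X)
y_numpy = np.asarray(y)
```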

```{code-cell} ipython3
stats_poisson = Poisson(y_numpy, X_numpy).fit()
print(stats_poisson.summary())
```