In [1]:
!pip install -U jax jaxlib

Collecting jax
  Downloading jax-0.4.30-py3-none-any.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
Collecting jaxlib
  Downloading jaxlib-0.4.30-cp310-cp310-manylinux2014_x86_64.whl (79.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.6/79.6 MB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxlib, jax
  Attempting uninstall: jaxlib
    Found existing installation: jaxlib 0.4.26+cuda12.cudnn89
    Uninstalling jaxlib-0.4.26+cuda12.cudnn89:
      Successfully uninstalled jaxlib-0.4.26+cuda12.cudnn89
  Attempting uninstall: jax
    Found existing installation: jax 0.4.26
    Uninstalling jax-0.4.26:
      Successfully uninstalled jax-0.4.26
Successfully installed jax-0.4.30 jaxlib-0.4.30


In [15]:
#import libraries
import jax
import jax.numpy as jnp
from jax import grad
import numpy as np
import tqdm

# Fitting an Exponential Distribution

In this example, we will consider fitting an exponential distribution on some randomly generated data and then observing the output. Recall that the PDF of an exponential distribution is
$$
  f(x; \lambda) = λe^{-λx}
$$

Using this, we can code this PDF using the appropriate operations from Jax and use it in another function that computes the Negative Log-Liklihood for us to minimize using gradient descent. Note that the loss function must consume both the data and the parameters of the distribution, but that the paramter, in this case $λ$ **must** come first in the argument list since we want to take its gradient for optimization purposes

In [3]:
def pdf(lam, x):
  return lam * jnp.exp(-lam * x)

def loss_function(lam, x, eps=0.000000001):
  components_before_log = pdf(lam, x)
  # take log of each component. We add a small corrective factor to
  #prevent numerical underflow in the log function
  components_after_log = jnp.log(components_before_log + eps)
  nll = -jnp.sum(components_after_log)
  return nll

In [4]:
rate_parameter = 1.4  # You can adjust this value as needed

# Sample 100 values from the exponential distribution
sample_size = 100
exponential_samples = np.random.exponential(1 / rate_parameter, sample_size)
exponential_samples = jnp.array(exponential_samples)

In [5]:
loss_function_grad = grad(loss_function, 0)

In [6]:
learning_rate = 0.01 # set learning rate
num_iterations = 5000 # set maximum number of iterations
current_guess = np.random.random() # set initial guess
print('Initial guess ', current_guess)
for i in tqdm.tqdm(range(num_iterations)): #tqdm is used to help us monitor progress
  current_gradient = loss_function_grad(current_guess, exponential_samples)
  update = learning_rate * current_gradient
  current_guess = current_guess - update
print('Final guess ', current_guess)

Initial guess  0.6556326151788402


100%|██████████| 5000/5000 [00:35<00:00, 139.65it/s]

Final guess  1.4509127





Note that the above result is close to our value for lambda. Let us now consider a two-parameter case with the [Gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution)

In [7]:
def pdf(k, theta, x):
  factor = 1 / (jax.scipy.special.gamma(k) * jnp.power(theta, k))
  return factor * (jnp.power(x, k - 1) * jnp.exp(-x/theta))

def loss_function(k, theta, x, eps=0.000001):
  components_before_log = pdf(k, theta, x)
  components_after_log = jnp.log(components_before_log + eps)
  nll = -jnp.sum(components_after_log)
  return nll

loss_function_grad_k = grad(loss_function, 0)
loss_function_grad_theta = grad(loss_function, 1)

In [8]:
k = 2.0  # Shape parameter
theta = 2.0  # Scale parameter

# Sample size
sample_size = 1000

# Sample from the Gamma distribution
gamma_samples = np.random.gamma(k, theta, sample_size)
gamma_samples = jnp.array(gamma_samples)

In [9]:
learning_rate = 0.001 # set learning rate
num_iterations = 1000 # set maximum number of iterations
k_guess = np.random.random() # set initial guess
theta_guess = np.random.random() # set initial guess
print('Initial guess ', (k_guess, theta_guess))
for i in tqdm.tqdm(range(num_iterations)): #tqdm is used to help us monitor progress

  k_grad = loss_function_grad_k(k_guess, theta_guess, gamma_samples)
  theta_grad = loss_function_grad_theta(k_guess, theta_guess, gamma_samples)

  k_guess = k_guess - learning_rate * k_grad
  theta_guess = theta_guess - learning_rate * theta_grad

print('\nFinal guess ', (k_guess, theta_guess))

Initial guess  (0.8046370082095061, 0.6335587655633199)


100%|██████████| 1000/1000 [00:29<00:00, 34.44it/s]


Final guess  (Array(1.9457594, dtype=float32, weak_type=True), Array(2.0648713, dtype=float32, weak_type=True))





To make our code more well-abstracted and simplier, we can instead pass an array as our paramter, with one component containing the value for $k$ and another with our value for $\theta$

In [10]:
def pdf(params, x):
  k = params[0]
  theta = params[1]
  factor = 1 / (jax.scipy.special.gamma(k) * jnp.power(theta, k))
  return factor * (jnp.power(x, k - 1) * jnp.exp(-x/theta))

def loss_function(params, x, eps=0.000001):
  components_before_log = pdf(params, x)
  components_after_log = jnp.log(components_before_log + eps)
  nll = -jnp.sum(components_after_log)
  return nll

loss_function_grad_params = grad(loss_function, 0) # only need one gradient function

In [11]:
learning_rate = 0.001 # set learning rate
num_iterations = 1000 # set maximum number of iterations
param_guess = np.random.random(size=2)
print('Initial guess ', param_guess)
for i in tqdm.tqdm(range(num_iterations)): #tqdm is used to help us monitor progress
  param_grad = loss_function_grad_params(param_guess, gamma_samples)
  param_guess = param_guess - learning_rate * param_grad # exploit broadcasting to simplify operation

print('\nFinal guess ', (k_guess, theta_guess))

Initial guess  [0.70806647 0.78696594]


100%|██████████| 1000/1000 [00:19<00:00, 50.36it/s]


Final guess  (Array(1.9457594, dtype=float32, weak_type=True), Array(2.0648713, dtype=float32, weak_type=True))





With this simplification of operations, we can abstract away our SGD loop into a function. We can also even abstract out our gradient function as an argument

In [16]:
def sgd(param_guess, grad_func, samples, learning_rate=0.001, num_iterations=1000):
  for i in tqdm.tqdm(range(num_iterations)): #tqdm is used to help us monitor progress
    param_grad = grad_func(param_guess, samples)
    param_guess = param_guess - learning_rate * param_grad # exploit broadcasting to simplify operation
  return param_guess

Using the above, we can even go further and abstract the entire MLE pipeline by supplying a PDF, samples, and number of parameters!

In [17]:
def mle(num_params, pdf, samples, learning_rate=0.001, num_iterations=1000):
  # create a function within a function
  # see https://realpython.com/python-functional-programming/
  def loss_function(params, x, eps=0.000001):
    components_before_log = pdf(params, x)
    components_after_log = jnp.log(components_before_log + eps)
    nll = -jnp.sum(components_after_log)
    return nll

  loss_function_grad_params = grad(loss_function, 0)
  param_guess = np.random.random(size=num_params)
  param_guess = sgd(param_guess, loss_function_grad_params, samples,
                    learning_rate=0.001, num_iterations=1000)
  return param_guess

In [18]:
def exponential_pdf(lam, x):
  return lam * jnp.exp(-lam * x)

lambda_estimate = mle(num_params=1, pdf=exponential_pdf, samples=exponential_samples)

100%|██████████| 1000/1000 [00:08<00:00, 117.37it/s]


In [19]:
print(lambda_estimate)

[1.4510068]


In [21]:
def gamma_pdf(params, x):
  k = params[0]
  theta = params[1]
  factor = 1 / (jax.scipy.special.gamma(k) * jnp.power(theta, k))
  return factor * (jnp.power(x, k - 1) * jnp.exp(-x/theta))

param_estimates = mle(num_params=2, pdf=gamma_pdf, samples=gamma_samples)
print('\nEstimate of k ', param_estimates[0])
print('Estimate of theta ', param_estimates[1])

100%|██████████| 1000/1000 [00:17<00:00, 56.03it/s]


Estimate of k  1.9457594
Estimate of theta  2.0648715



