<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_12_Gibbs_Sampling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Introduction

Gibbs Sampling is a Markov chain Monte Carlo (MCMC) algorithm; it can be viewed as a special case of Metropolis–Hastings in which every proposal is accepted. It is widely used to sample from complex joint distributions when direct sampling is intractable but sampling from each conditional distribution is easy. This makes it especially useful in Bayesian inference, where the goal is to generate samples from a posterior distribution.

### 3.1 Basic Idea

Suppose we wish to sample from a joint probability distribution $ p(x_1, x_2, \ldots, x_n) $. Direct sampling may be infeasible, but if we can easily sample from each full conditional distribution (the distribution of one variable given all the others):
- $ p(x_1 \mid x_2, \dots, x_n) $
- $ p(x_2 \mid x_1, x_3, \dots, x_n) $
- $ \cdots $
- $ p(x_n \mid x_1, \dots, x_{n-1}) $

then Gibbs Sampling offers a practical solution.

### 3.2 Algorithm Steps

1. **Initialization:** Choose initial values $ x_1^{(0)}, x_2^{(0)}, \dots, x_n^{(0)} $ (often arbitrarily).

2. **Iterative Sampling:** For iteration $ t = 1, 2, \dots, T $:
   - Sample
     $$
     x_1^{(t)} \sim p\Big(x_1 \mid x_2^{(t-1)}, x_3^{(t-1)}, \dots, x_n^{(t-1)}\Big)
     $$
   - Sample
     $$
     x_2^{(t)} \sim p\Big(x_2 \mid x_1^{(t)}, x_3^{(t-1)}, \dots, x_n^{(t-1)}\Big)
     $$
   - Continue in this fashion until
     $$
     x_n^{(t)} \sim p\Big(x_n \mid x_1^{(t)}, x_2^{(t)}, \dots, x_{n-1}^{(t)}\Big)
     $$

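3. **Convergence:** Under regularity conditions (the chain must be irreducible and aperiodic; see Section 4.2), the distribution of $(x_1^{(t)}, \dots, x_n^{(t)})$ converges to the target distribution $ p(x_1, \dots, x_n) $ as $ t \to \infty $. Draws taken after an initial burn-in can therefore be treated as (correlated) samples from $ p $.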
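To see step 3 in action, the sketch below runs many independent chains for a standardised version of the bivariate normal from Section 6 ($\mu_x = \mu_y = 0$, $\sigma_x = \sigma_y = 1$, $\rho = 0.8$; these values are assumptions chosen to keep the arithmetic simple), all started from the same deliberately poor point. For this target each sweep sets $x^{(t)} = \rho\, y^{(t-1)} + \text{noise}$ and then $y^{(t)} = \rho\, x^{(t)} + \text{noise}$, so the average of $x^{(t)}$ across chains should shrink like $\rho^{2t-1} y^{(0)}$: the memory of the starting value decays geometrically. NumPy is used for brevity, and the helper `gibbs_sweep` is a hypothetical name.

```python
import numpy as np

def gibbs_sweep(x, y, rho, rng):
    """One systematic-scan sweep for a standard bivariate normal with correlation rho.

    Operates elementwise, so x and y can hold many independent chains at once.
    """
    s = np.sqrt(1.0 - rho**2)                          # conditional standard deviation
    x = rho * y + s * rng.standard_normal(x.shape)     # draw x | y
    y = rho * x + s * rng.standard_normal(y.shape)     # draw y | the new x
    return x, y

rng = np.random.default_rng(0)
rho, n_chains = 0.8, 200_000          # illustrative assumptions
x = np.full(n_chains, 10.0)           # arbitrary, far-from-typical starting point
y = np.full(n_chains, -10.0)

for t in range(1, 11):
    x, y = gibbs_sweep(x, y, rho, rng)
    predicted = rho ** (2 * t - 1) * (-10.0)           # E[x^(t)] given y^(0) = -10
    print(f"t={t:2d}  mean of x={x.mean():+.3f}  predicted={predicted:+.3f}")
```

The printed means should track the predicted values closely, illustrating why early draws are usually discarded as burn-in.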

### 3.3 Pseudocode

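```python
# Generic Gibbs Sampling loop (0-based indices: x[0] plays the role of x_1)
import numpy as np

def gibbs(sample_conditional, x0, T, rng=None):
    """Systematic-scan Gibbs sampler.

    sample_conditional(i, x, rng) must return one draw of x[i] from
    p(x[i] | x[j] for all j != i), reading the other components from x.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x0, dtype=float)             # x^(0)
    samples = np.empty((T, x.size))
    for t in range(T):
        for i in range(x.size):
            # x[:i] already hold iteration-t values; x[i+1:] still hold t-1 values
            x[i] = sample_conditional(i, x, rng)
        samples[t] = x                        # store x^(t)
    return samples
```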


## Requirements and Considerations

### 4.1 Markov Chains and Stationarity

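- **Markov Property:** The next state depends only on the current state, not on the earlier history of the chain.
- **Stationary Distribution:** The target distribution $ p(x_1, \dots, x_n) $ is invariant under the Gibbs transition: if the current state is distributed according to $ p $, the state after one update (or one full sweep) is also distributed according to $ p $. Consequently, once the chain has converged, every subsequent draw is marginally a draw from $ p $, although successive draws remain correlated.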
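Invariance can be checked empirically: draw exact samples from the target, apply one Gibbs sweep, and confirm that the sample moments do not move. The sketch below assumes the bivariate normal of Section 6 with illustrative, non-standardised parameters (the numbers and the helper `sweep` are hypothetical) and uses NumPy for brevity.

```python
import numpy as np

# Illustrative target parameters (assumptions, not taken from the text)
mu = np.array([1.0, -2.0])
sx, sy, rho = 2.0, 0.5, 0.8
cov = np.array([[sx**2, rho * sx * sy],
                [rho * sx * sy, sy**2]])

def sweep(x, y, rng):
    """One systematic-scan Gibbs sweep using the bivariate normal full conditionals."""
    x = mu[0] + rho * sx / sy * (y - mu[1]) + np.sqrt(1 - rho**2) * sx * rng.standard_normal(x.shape)
    y = mu[1] + rho * sy / sx * (x - mu[0]) + np.sqrt(1 - rho**2) * sy * rng.standard_normal(y.shape)
    return x, y

rng = np.random.default_rng(1)
start = rng.multivariate_normal(mu, cov, size=500_000)   # exact draws from p
x1, y1 = sweep(start[:, 0], start[:, 1], rng)           # apply one transition
after = np.stack([x1, y1], axis=1)

print("mean before:", start.mean(axis=0), " after:", after.mean(axis=0))
print("cov before:\n", np.cov(start.T), "\ncov after:\n", np.cov(after.T))
```

Up to Monte Carlo error, the means and covariances before and after the sweep should agree with `mu` and `cov`.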

### 4.2 Detailed Balance and Ergodicity

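- **Detailed Balance:** A transition kernel $ P $ satisfies detailed balance with respect to $ p $ if, for any two states $ X $ and $ Y $,
  $$
  p(X) \, P(X \to Y) = p(Y) \, P(Y \to X).
  $$
  Detailed balance implies that $ p $ is stationary and that the chain is reversible. Each single-coordinate Gibbs update satisfies it, because it redraws one coordinate from its exact full conditional. A fixed-order sweep through all coordinates still leaves $ p $ invariant (each of its steps does), but the composed kernel is generally not reversible; random-scan Gibbs, which updates a randomly chosen coordinate at each step, is reversible.
- **Ergodicity:** The chain must also be irreducible (able to reach any region of positive probability from any starting point) and aperiodic, so that it does not get trapped in part of the state space and time averages $ \frac{1}{T}\sum_{t} f(x^{(t)}) $ converge to expectations $ \mathbb{E}_p[f(x)] $.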
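The difference between a single-coordinate update and a full sweep is easy to verify on a small discrete example. The sketch below assumes an arbitrary joint table for two binary variables $x_1, x_2$ (the numbers are illustrative), builds the transition matrix $K_1$ for redrawing $x_1$, $K_2$ for redrawing $x_2$, and the systematic sweep $K = K_1 K_2$. Detailed balance holds exactly when the flow matrix $F_{ij} = p_i K_{ij}$ is symmetric. The helper names `kernel` and `satisfies_detailed_balance` are hypothetical.

```python
import numpy as np

# Illustrative joint p(x1, x2): rows index x1 in {0, 1}, columns index x2 in {0, 1}
P = np.array([[0.4, 0.1],
              [0.2, 0.3]])
states = [(a, b) for a in range(2) for b in range(2)]   # flattened state order
pi = np.array([P[a, b] for a, b in states])

def kernel(update):
    """Transition matrix for redrawing one coordinate (0 -> x1, 1 -> x2) from its full conditional."""
    K = np.zeros((4, 4))
    for i, (a, b) in enumerate(states):
        for j, (a2, b2) in enumerate(states):
            if update == 0 and b2 == b:      # x2 fixed, x1 redrawn from p(x1 | x2 = b)
                K[i, j] = P[a2, b] / P[:, b].sum()
            elif update == 1 and a2 == a:    # x1 fixed, x2 redrawn from p(x2 | x1 = a)
                K[i, j] = P[a, b2] / P[a, :].sum()
    return K

def satisfies_detailed_balance(K):
    F = pi[:, None] * K                      # F[i, j] = p(i) K(i -> j)
    return np.allclose(F, F.T)

K1, K2 = kernel(0), kernel(1)
K = K1 @ K2                                  # systematic sweep: x1, then x2

print("K1 reversible:", satisfies_detailed_balance(K1))        # True
print("K2 reversible:", satisfies_detailed_balance(K2))        # True
print("sweep reversible:", satisfies_detailed_balance(K))      # False
print("sweep leaves p invariant:", np.allclose(pi @ K, pi))    # True
```

For this table the sweep fails detailed balance (for example, $p(0,0)\,K\big((0,0)\to(0,1)\big) \approx 0.053$ while $p(0,1)\,K\big((0,1)\to(0,0)\big) = 0.02$), yet it still leaves $p$ invariant, which is all that convergence requires.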

### 4.3 Convergence Considerations

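- **Burn-in Period:** The first samples, which still depend strongly on the initial values, may be discarded to give the chain time to approach its stationary distribution.
- **Thinning:** To reduce autocorrelation among the stored samples, one may retain only every $ k $-th sample. Thinning mainly saves memory; because it discards draws, it generally does not make estimates from a chain of fixed length more precise.
- **Diagnostics:** Use trace plots, autocorrelation functions, or convergence tests (e.g., the Gelman–Rubin diagnostic computed from several chains) to look for signs of non-convergence. Diagnostics can reveal problems but cannot prove that a chain has converged.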
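The sketch below applies these ideas to a standardised version of the bivariate normal from Section 6 ($\rho = 0.8$; the parameter values and the helper names `run_chain`, `acf`, and `gelman_rubin` are assumptions for illustration). It runs four chains from dispersed starting points, discards a burn-in, thins, estimates the lag-1 autocorrelation of $x$, and computes the basic (non-split) Gelman–Rubin factor $\hat R$, where values close to 1 are consistent with convergence. For this target the $x$-chain is an AR(1) process with coefficient $\rho^2$, so the lag-1 autocorrelation should be near $0.64$ for all draws and near $\rho^{2k}$ after keeping every $k$-th draw.

```python
import numpy as np

def run_chain(x0, y0, rho, n, rng):
    """Systematic-scan Gibbs sampler for a standard bivariate normal; returns the x draws."""
    s = np.sqrt(1.0 - rho**2)
    x, y, xs = x0, y0, np.empty(n)
    for t in range(n):
        x = rho * y + s * rng.standard_normal()
        y = rho * x + s * rng.standard_normal()
        xs[t] = x
    return xs

def acf(z, lag):
    """Sample autocorrelation of a 1-D series at the given lag."""
    z = z - z.mean()
    return np.dot(z[:-lag], z[lag:]) / np.dot(z, z)

def gelman_rubin(chains):
    """Basic potential scale reduction factor for an (m chains, n draws) array."""
    m, n = chains.shape
    W = chains.var(axis=1, ddof=1).mean()      # mean within-chain variance
    B = n * chains.mean(axis=1).var(ddof=1)    # between-chain variance
    var_hat = (n - 1) / n * W + B / n          # pooled variance estimate
    return np.sqrt(var_hat / W)

rng = np.random.default_rng(2)
rho, n, burn_in, k = 0.8, 20_000, 500, 5       # illustrative settings
starts = [(-10.0, -10.0), (10.0, 10.0), (-10.0, 10.0), (10.0, -10.0)]
chains = np.stack([run_chain(x0, y0, rho, n, rng) for x0, y0 in starts])

kept = chains[:, burn_in:]                     # discard burn-in
thinned = kept[:, ::k]                         # keep every k-th draw

print("lag-1 ACF, all draws:", acf(kept[0], 1))      # theory: rho**2
print("lag-1 ACF, thinned:  ", acf(thinned[0], 1))   # theory: rho**(2*k)
print("R-hat:", gelman_rubin(kept))
```

Libraries such as ArviZ provide more robust versions of these diagnostics (for example split-$\hat R$); the hand-written functions here only show the underlying arithmetic.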

## 6. Example: Gibbs Sampling for a Bivariate Normal Distribution

To illustrate the method, we consider a simple bivariate normal distribution defined by
$$
\begin{pmatrix} x \\ y \end{pmatrix} \sim \mathcal{N} \Bigg(\begin{pmatrix} \mu_x \\ \mu_y \end{pmatrix}, \begin{pmatrix} \sigma_x^2 & \rho \sigma_x \sigma_y \\ \rho \sigma_x \sigma_y & \sigma_y^2 \end{pmatrix}\Bigg).
$$

### 6.1 Deriving the Conditional Distributions

For a bivariate normal:
- The conditional distribution for $ x $ given $ y $ is:
  $$
  x \mid y \sim \mathcal{N}\left(\mu_x + \rho \frac{\sigma_x}{\sigma_y}(y - \mu_y),\; (1-\rho^2)\sigma_x^2\right)
  $$
- Likewise, the conditional distribution for $ y $ given $ x $ is:
  $$
  y \mid x \sim \mathcal{N}\left(\mu_y + \rho \frac{\sigma_y}{\sigma_x}(x - \mu_x),\; (1-\rho^2)\sigma_y^2\right)
  $$
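Both results follow from the standard formula for conditioning a multivariate normal. Writing the covariance blocks as $\Sigma_{xx} = \sigma_x^2$, $\Sigma_{xy} = \Sigma_{yx} = \rho\sigma_x\sigma_y$, and $\Sigma_{yy} = \sigma_y^2$:

$$
\begin{aligned}
\mathbb{E}[x \mid y] &= \mu_x + \Sigma_{xy}\Sigma_{yy}^{-1}(y - \mu_y) = \mu_x + \frac{\rho\sigma_x\sigma_y}{\sigma_y^2}(y - \mu_y) = \mu_x + \rho\frac{\sigma_x}{\sigma_y}(y - \mu_y), \\
\operatorname{Var}(x \mid y) &= \Sigma_{xx} - \Sigma_{xy}\Sigma_{yy}^{-1}\Sigma_{yx} = \sigma_x^2 - \frac{\rho^2\sigma_x^2\sigma_y^2}{\sigma_y^2} = (1 - \rho^2)\sigma_x^2.
\end{aligned}
$$

Swapping the roles of $x$ and $y$ gives the conditional for $y \mid x$.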

```python
import jax
import jax.numpy as jnp
from jax import random, lax, jit
from functools import partial
import matplotlib.pyplot as plt

# Set parameters for the bivariate normal distribution
mu_x, mu_y = 0.0, 0.0
sigma_x, sigma_y = 1.0, 1.0
rho = 0.8
n_samples = 5000

def gibbs_step(state, key):
    """
    Performs one Gibbs update step (x given y, then y given the new x).

    Parameters:
        state: tuple (x, y)
        key: PRNGKey for randomness
    Returns:
        new_state: updated (x, y)
    """
    x, y = state
    # Split the key for independent randomness for each conditional sample
    key_x, key_y = random.split(key)

    # Sample x given y
    mu_cond_x = mu_x + rho * (sigma_x / sigma_y) * (y - mu_y)
    sigma_cond_x = jnp.sqrt((1 - rho**2) * sigma_x**2)
    x_new = mu_cond_x + sigma_cond_x * random.normal(key_x)

    # Sample y given the newly sampled x
    mu_cond_y = mu_y + rho * (sigma_y / sigma_x) * (x_new - mu_x)
    sigma_cond_y = jnp.sqrt((1 - rho**2) * sigma_y**2)
    y_new = mu_cond_y + sigma_cond_y * random.normal(key_y)

    return (x_new, y_new)

# n_samples sets array shapes, so it must be static (a compile-time constant) under jit
@partial(jit, static_argnums=2)
def run_gibbs(key, initial_state, n_samples):
    """
    Runs the Gibbs sampler for n_samples iterations using jax.lax.scan.

    Parameters:
        key: PRNGKey for randomness.
        initial_state: tuple (x, y) for initial state.
        n_samples: number of samples to generate (static under jit).
    Returns:
        states: tuple of arrays (x_samples, y_samples) with shape (n_samples,)
    """
    # Prepare one key per iteration for the scan
    keys = random.split(key, n_samples)

    def scan_body(state, key):
        new_state = gibbs_step(state, key)
        # scan expects (carry, output): carry the new state forward and also record it
        return new_state, new_state

    final_state, states = lax.scan(scan_body, initial_state, keys)
    return states

# Set up the PRNG key and initial state
seed = 42
key = random.PRNGKey(seed)
initial_state = (0.0, 0.0)

# Run the Gibbs sampler
states = run_gibbs(key, initial_state, n_samples)
x_samples, y_samples = states  # each is a JAX array of shape (n_samples,)

# Transfer arrays from device to host for plotting
x_samples = jax.device_get(x_samples)
y_samples = jax.device_get(y_samples)

# Plot the samples
plt.figure(figsize=(8, 6))
plt.scatter(x_samples, y_samples, alpha=0.3, s=10)
plt.xlabel("x")
plt.ylabel("y")
plt.title("Gibbs Sampling: Bivariate Normal Samples (JAX)")
plt.show()
```


## Advantages and Limitations

### Advantages
- **Simplicity:** Straightforward to implement when the full conditional distributions are known.
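- **Efficiency:** Can be very effective in high-dimensional models with conditional independence structure (e.g., graphical models), where each full conditional depends on only a few other variables and is therefore cheap to sample.
- **Theoretical Guarantees:** If the chain is irreducible and aperiodic, it converges to the target distribution, although the theory says little about how many iterations this takes in practice.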

### Limitations
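- **Slow Mixing:** The chain may mix slowly when variables are strongly correlated, because each update moves along only one coordinate at a time and therefore takes small steps along the narrow ridge of the joint distribution.
- **Dependency on Conditionals:** Requires that the conditional distributions are easy to sample from.
- **Curse of Dimensionality:** When the conditionals are themselves high-dimensional or not analytically tractable, alternatives such as Metropolis-within-Gibbs or Hamiltonian Monte Carlo may be preferred.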
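The slow-mixing limitation can be quantified for the bivariate normal example. With standardised margins, one sweep gives $x^{(t)} = \rho^2 x^{(t-1)} + \varepsilon_t$ with $\varepsilon_t$ independent of $x^{(t-1)}$, so the $x$-chain is AR(1) with lag-1 autocorrelation $\rho^2$, and a rough effective sample size is $n(1-\rho^2)/(1+\rho^2)$. The sketch below (NumPy, with the hypothetical helper `x_chain`; the $\rho$ values and chain length are illustrative assumptions) estimates the lag-1 autocorrelation as $\rho$ grows.

```python
import numpy as np

def x_chain(rho, n, rng):
    """x draws from a systematic-scan Gibbs sampler on a standard bivariate normal."""
    s = np.sqrt(1.0 - rho**2)
    x, y, xs = 0.0, 0.0, np.empty(n)
    for t in range(n):
        x = rho * y + s * rng.standard_normal()    # x | y
        y = rho * x + s * rng.standard_normal()    # y | new x
        xs[t] = x
    return xs

rng = np.random.default_rng(3)
n = 40_000
results = {}
for rho in (0.0, 0.5, 0.9, 0.99):                  # illustrative correlations
    z = x_chain(rho, n, rng)
    z = z - z.mean()
    lag1 = np.dot(z[:-1], z[1:]) / np.dot(z, z)    # lag-1 sample autocorrelation
    ess = n * (1 - rho**2) / (1 + rho**2)          # AR(1) approximation to ESS
    results[rho] = lag1
    print(f"rho={rho:4.2f}  lag-1 ACF={lag1:.3f} (theory {rho**2:.3f})  ESS~{ess:,.0f}")
```

As $\rho \to 1$ the autocorrelation approaches 1 and the effective sample size collapses, even though every draw is individually valid.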