# Location: notebooks/06_special_topics/18_probabilistic_gradients.ipynb

## Probabilistic Programming and Stochastic Gradients in JAX

This notebook explores probabilistic programming concepts in JAX, including variational inference, stochastic gradients, and gradient estimation techniques for discrete and continuous random variables.

## Basic Probability Distributions and Sampling

```python
import jax
import jax.numpy as jnp
from jax import random, vmap, grad
from jax.scipy import stats
import matplotlib.pyplot as plt
from functools import partial

# Basic sampling and probability computations
def demonstrate_basic_distributions():
    """Draw 1000 samples from several common distributions and score them.

    Uses a fixed key (``PRNGKey(42)``) so the results are reproducible.

    Returns:
        dict mapping a distribution name to either a ``(samples, logprob)``
        tuple (continuous distributions) or a bare sample array (discrete
        distributions).
    """
    key = random.PRNGKey(42)

    # One independent subkey per distribution.
    keys = random.split(key, 5)

    # Normal(0, 1)
    normal_samples = random.normal(keys[0], (1000,))
    normal_logprob = stats.norm.logpdf(normal_samples, 0.0, 1.0)

    # Beta(2, 5)
    beta_samples = random.beta(keys[1], 2.0, 5.0, (1000,))
    beta_logprob = stats.beta.logpdf(beta_samples, 2.0, 5.0)

    # Gamma(shape=2, unit scale)
    gamma_samples = random.gamma(keys[2], 2.0, (1000,))
    gamma_logprob = stats.gamma.logpdf(gamma_samples, 2.0)

    # Discrete distributions.
    # `random.categorical` takes *logits* (unnormalized log-probabilities), so
    # the probability vector is converted with a log. Its third positional
    # parameter is `axis`, so the sample shape must be passed as `shape=`.
    category_probs = jnp.array([0.1, 0.3, 0.4, 0.2])
    categorical_samples = random.categorical(
        keys[3], jnp.log(category_probs), shape=(1000,)
    )

    # Bernoulli(p=0.3) draws (boolean array)
    bernoulli_samples = random.bernoulli(keys[4], 0.3, (1000,))

    return {
        'normal': (normal_samples, normal_logprob),
        'beta': (beta_samples, beta_logprob),
        'gamma': (gamma_samples, gamma_logprob),
        'categorical': categorical_samples,
        'bernoulli': bernoulli_samples
    }

# Test basic distributions
distributions = demonstrate_basic_distributions()

print("Distribution Statistics:")
for dist_name, payload in distributions.items():
    # Continuous entries carry (samples, log-densities); discrete ones are bare samples.
    if not isinstance(payload, tuple):
        print(f"{dist_name:12}: mean={jnp.mean(payload):.3f}, unique_vals={len(jnp.unique(payload))}")
        continue
    draws, log_densities = payload
    print(f"{dist_name:12}: mean={jnp.mean(draws):.3f}, std={jnp.std(draws):.3f}, "
          f"logprob_mean={jnp.mean(log_densities):.3f}")
```

## Stochastic Computation Graphs

```python
def create_stochastic_graph():
    """Build a small stochastic network plus a Monte Carlo expectation helper.

    Returns:
        (stochastic_forward, expected_output): the single-sample forward pass
        and a function estimating the mean/std of its output over many draws.
    """

    def stochastic_forward(key, x, noise_scale=0.1, keep_prob=0.8):
        """Run one stochastic forward pass.

        Args:
            key: PRNG key; split into three subkeys (input noise, weight
                noise, dropout mask).
            x: input array whose last axis is the feature dimension.
            noise_scale: std of the Gaussian noise added to ``x``.
            keep_prob: probability of keeping each hidden unit in the
                dropout-style mask.

        Returns:
            (output, intermediates): ``output`` has trailing dimension 10;
            ``intermediates`` holds the stochastic internal values.
        """
        k1, k2, k3 = random.split(key, 3)
        n_features = x.shape[-1]

        # Add Gaussian noise to input
        noisy_x = x + noise_scale * random.normal(k1, x.shape)

        # Stochastic linear transformation: fixed 0.5 weights plus N(0, 0.1^2) noise
        w_noise = random.normal(k2, (n_features, 64)) * 0.1
        w_deterministic = jnp.ones((n_features, 64)) * 0.5
        w = w_deterministic + w_noise

        h = jnp.tanh(noisy_x @ w)

        # Inverted dropout: rescale by 1/keep_prob so E[h_masked] == h.
        mask = random.bernoulli(k3, keep_prob, h.shape)
        h_masked = h * mask / keep_prob

        # Final linear layer (deterministic)
        output_w = jnp.ones((64, 10)) * 0.1
        output = h_masked @ output_w

        return output, {
            'noisy_input': noisy_x,
            'weights': w,
            'hidden': h,
            'mask': mask,
            'masked_hidden': h_masked
        }

    def expected_output(x, n_samples=100, key=None):
        """Estimate the mean and std of the forward output by Monte Carlo.

        Args:
            x: input passed to every forward pass.
            n_samples: number of independent forward passes.
            key: PRNG key for the draws. Defaults to ``PRNGKey(0)``, so calls
                without a key always return the same estimate.

        Returns:
            (mean, std) over the sample axis.
        """
        if key is None:
            key = random.PRNGKey(0)
        keys = random.split(key, n_samples)

        def single_forward(sample_key):
            output, _ = stochastic_forward(sample_key, x)
            return output

        outputs = vmap(single_forward)(keys)
        return jnp.mean(outputs, axis=0), jnp.std(outputs, axis=0)

    return stochastic_forward, expected_output

# Test stochastic computation graph
stoch_forward, expected_forward = create_stochastic_graph()

# A single 3-feature example (batch of one).
test_x = jnp.array([[1.0, 2.0, 3.0]])

# One stochastic pass with a fixed key.
single_output, intermediates = stoch_forward(random.PRNGKey(123), test_x)
first_row = single_output[0]
print(f"Single forward output shape: {single_output.shape}")
print(f"Output values: {first_row[:5]}")  # First 5 values

# Monte Carlo estimate of the output's mean and spread.
expected_out, std_out = expected_forward(test_x)
print(f"Expected output: {expected_out[0][:5]}")
print(f"Output std: {std_out[0][:5]}")

# Relative spread of a single pass compared with its expectation.
cv = jnp.mean(std_out / jnp.abs(expected_out))
print(f"Coefficient of variation: {cv:.3f}")
```

## Variational Inference with Reparameterization

```python
def create_variational_autoencoder():
    """Assemble the pieces of a small MLP variational autoencoder.

    Returns:
        (init_vae_params, vae_loss, encoder, decoder, reparameterize)
    """

    def dense(params, inputs, weight_name, bias_name):
        # Affine layer whose weight/bias are looked up by name in `params`.
        return inputs @ params[weight_name] + params[bias_name]

    def encoder(params, x):
        """Map an input to the posterior parameters (mu, log_sigma)."""
        hidden = jax.nn.relu(dense(params, x, 'enc_w1', 'enc_b1'))
        hidden = jax.nn.relu(dense(params, hidden, 'enc_w2', 'enc_b2'))
        mu = dense(params, hidden, 'enc_mu_w', 'enc_mu_b')
        log_sigma = dense(params, hidden, 'enc_sigma_w', 'enc_sigma_b')
        return mu, log_sigma

    def decoder(params, z):
        """Map a latent code to a reconstruction squashed into (0, 1)."""
        hidden = jax.nn.relu(dense(params, z, 'dec_w1', 'dec_b1'))
        hidden = jax.nn.relu(dense(params, hidden, 'dec_w2', 'dec_b2'))
        return jax.nn.sigmoid(dense(params, hidden, 'dec_w3', 'dec_b3'))

    def reparameterize(key, mu, log_sigma):
        """Differentiable sample z = mu + exp(log_sigma) * eps with eps ~ N(0, I)."""
        noise = random.normal(key, mu.shape)
        return mu + jnp.exp(log_sigma) * noise

    def vae_loss(params, key, x):
        """Negative ELBO: binary cross-entropy plus KL(q(z|x) || N(0, I)).

        Returns:
            (total_loss, info) where info holds the loss terms and the
            intermediate mu, log_sigma, z and x_recon.
        """
        mu, log_sigma = encoder(params, x)
        z = reparameterize(key, mu, log_sigma)
        x_recon = decoder(params, z)

        # Small epsilon keeps the logs finite when the sigmoid saturates.
        eps = 1e-8
        log_p = jnp.log(x_recon + eps)
        log_not_p = jnp.log(1 - x_recon + eps)
        recon_loss = -jnp.sum(x * log_p + (1 - x) * log_not_p)

        # Closed-form Gaussian KL; sigma^2 written as exp(2 * log_sigma).
        kl_terms = jnp.exp(2 * log_sigma) + mu**2 - 1 - 2 * log_sigma
        kl_loss = 0.5 * jnp.sum(kl_terms)

        info = {
            'recon_loss': recon_loss,
            'kl_loss': kl_loss,
            'mu': mu,
            'log_sigma': log_sigma,
            'z': z,
            'x_recon': x_recon
        }
        return recon_loss + kl_loss, info

    def init_vae_params(key, input_dim=784, latent_dim=20, hidden_dim=400):
        """Create weights ~ N(0, 0.01^2) and zero biases for every layer."""
        layer_keys = random.split(key, 10)  # only the first 7 are consumed
        # (weight name, bias name, fan_in, fan_out) in key-consumption order.
        layout = [
            ('enc_w1', 'enc_b1', input_dim, hidden_dim),
            ('enc_w2', 'enc_b2', hidden_dim, hidden_dim),
            ('enc_mu_w', 'enc_mu_b', hidden_dim, latent_dim),
            ('enc_sigma_w', 'enc_sigma_b', hidden_dim, latent_dim),
            ('dec_w1', 'dec_b1', latent_dim, hidden_dim),
            ('dec_w2', 'dec_b2', hidden_dim, hidden_dim),
            ('dec_w3', 'dec_b3', hidden_dim, input_dim),
        ]
        params = {}
        for layer_key, (w_name, b_name, fan_in, fan_out) in zip(layer_keys, layout):
            params[w_name] = random.normal(layer_key, (fan_in, fan_out)) * 0.01
            params[b_name] = jnp.zeros(fan_out)
        return params

    return init_vae_params, vae_loss, encoder, decoder, reparameterize

# Create VAE
init_vae, vae_loss_fn, encode_fn, decode_fn, reparam_fn = create_variational_autoencoder()

# Small configuration: 28-dim inputs, 8-dim latent space, 64 hidden units.
vae_params = init_vae(random.PRNGKey(456), input_dim=28, latent_dim=8, hidden_dim=64)

print("VAE Parameters:")
for param_name, tensor in vae_params.items():
    print(f"  {param_name}: {tensor.shape}")

# One loss evaluation on a random input in [0, 1).
test_input = random.uniform(random.PRNGKey(789), (28,))  # Mock 28D input
loss, info = vae_loss_fn(vae_params, random.PRNGKey(101), test_input)

print(f"\nVAE Forward Pass:")
for label, value in (("Total loss", loss),
                     ("Reconstruction loss", info['recon_loss']),
                     ("KL divergence", info['kl_loss'])):
    print(f"{label}: {value:.4f}")
print(f"Latent mean: {info['mu'][:4]}")  # First 4 dimensions
print(f"Latent log_sigma: {info['log_sigma'][:4]}")
```

## Gradient Estimation for Discrete Variables

```python
def demonstrate_gradient_estimation():
    """Build three gradient estimators for a categorical choice.

    Every estimator uses the same reward: 1.0 for action 2, -0.1 otherwise.

    Returns:
        (reinforce_estimator, gumbel_softmax_estimator, reinforce_with_baseline)
    """

    def reward_for(action):
        # Shared reward: prefer action 2.
        return jnp.where(action == 2, 1.0, -0.1)

    def sample_actions(logits, key, n_samples):
        # One categorical draw per subkey, vectorized over the subkeys.
        keys = random.split(key, n_samples)
        return vmap(lambda k: random.categorical(k, logits))(keys)

    # REINFORCE (likelihood ratio) estimator
    def reinforce_estimator(params, key, n_samples=1000):
        """Score-function (REINFORCE) gradient estimate of E[-reward].

        Args:
            params: dict with a 'logits' array (other float leaves allowed).
            key: PRNG key for the action samples.
            n_samples: number of Monte Carlo samples.

        Returns:
            (mean_loss, grad_estimate, mean_reward). ``mean_loss`` is the
            average negative reward. ``grad_estimate`` is a pytree shaped like
            ``params`` holding mean(-r * grad log p(a)) over the samples.
        """
        actions = sample_actions(params['logits'], key, n_samples)
        rewards = reward_for(actions)

        def surrogate(p):
            log_probs = jax.nn.log_softmax(p['logits'])[actions]
            # Rewards are constants w.r.t. params; differentiating this
            # surrogate yields the REINFORCE estimator of grad E[-reward].
            return -jnp.mean(jax.lax.stop_gradient(rewards) * log_probs)

        grad_estimate = grad(surrogate)(params)
        return jnp.mean(-rewards), grad_estimate, jnp.mean(rewards)

    # Gumbel-Softmax (reparameterization for discrete)
    def gumbel_softmax_estimator(params, key, temperature=1.0):
        """Pathwise gradient through a Gumbel-Softmax relaxed sample.

        Returns:
            (loss, gradient) where ``loss`` is the negative relaxed reward and
            ``gradient`` is a pytree shaped like ``params``.
        """

        def gumbel_softmax_sample(logits, sample_key, temperature):
            # Gumbel(0, 1) noise via -log(-log(U)); epsilons avoid log(0).
            gumbel_noise = -jnp.log(-jnp.log(random.uniform(sample_key, logits.shape) + 1e-8) + 1e-8)
            return jax.nn.softmax((logits + gumbel_noise) / temperature)

        def differentiable_loss(params, sample_key):
            soft_sample = gumbel_softmax_sample(params['logits'], sample_key, temperature)

            # Differentiable reward (prefer index 2)
            target = jnp.array([0.0, 0.0, 1.0, 0.0])  # One-hot for action 2
            reward = jnp.dot(soft_sample, target) - 0.1 * jnp.dot(soft_sample, 1 - target)
            return -reward  # Negative for minimization

        loss, gradient = jax.value_and_grad(differentiable_loss)(params, key)
        return loss, gradient

    # Control variate for REINFORCE
    def reinforce_with_baseline(params, key, n_samples=1000):
        """REINFORCE with a scalar baseline as a control variate.

        Args:
            params: dict with 'logits' and a scalar 'baseline'.

        Returns:
            (policy_grad, baseline_loss, mean_reward). ``policy_grad`` is the
            scalar surrogate mean(advantage * log p(a)). Its gradient w.r.t.
            the logits is the variance-reduced policy-gradient estimate, and
            advantages are treated as constants. ``baseline_loss`` is the MSE
            between rewards and the baseline.
        """
        logits = params['logits']
        baseline = params['baseline']

        actions = sample_actions(logits, key, n_samples)
        rewards = reward_for(actions)

        # stop_gradient: the baseline must only reduce variance, not receive
        # gradient through the policy surrogate.
        advantages = jax.lax.stop_gradient(rewards - baseline)
        log_probs = jax.nn.log_softmax(logits)[actions]

        policy_grad = jnp.mean(advantages * log_probs)
        baseline_loss = jnp.mean((rewards - baseline) ** 2)

        return policy_grad, baseline_loss, jnp.mean(rewards)

    return reinforce_estimator, gumbel_softmax_estimator, reinforce_with_baseline

# Test gradient estimation techniques
reinforce_est, gumbel_est, reinforce_baseline = demonstrate_gradient_estimation()

# Four candidate actions; action 2 is the rewarded one.
discrete_params = dict(
    logits=jnp.array([0.1, 0.2, -0.3, 0.0]),
    baseline=0.0,
)

print("Gradient Estimation for Discrete Variables:")
key = random.PRNGKey(112)

# Plain REINFORCE
reinforce_loss, reinforce_grad, avg_reward = reinforce_est(discrete_params, key)
print(f"REINFORCE - Loss: {reinforce_loss:.4f}, Avg Reward: {avg_reward:.4f}")

# Gumbel-Softmax relaxation
gumbel_loss, gumbel_grad = gumbel_est(discrete_params, key)
print(f"Gumbel-Softmax - Loss: {gumbel_loss:.4f}")
print(f"Gumbel gradient: {gumbel_grad['logits']}")

# REINFORCE with a baseline control variate
policy_grad, baseline_loss, baseline_reward = reinforce_baseline(discrete_params, key)
print(
    f"REINFORCE + Baseline - Policy grad: {policy_grad:.4f}, "
    f"Baseline loss: {baseline_loss:.4f}, Reward: {baseline_reward:.4f}"
)
```

## Stochastic Optimization and Natural Gradients

```python
def create_natural_gradient_optimizer():
    """Create a natural-gradient optimizer for a per-feature Gaussian model.

    The model scores data with ``stats.norm.logpdf(x, mu, exp(log_sigma))``,
    where ``params`` is a dict with 'mu' and 'log_sigma' arrays.

    Returns:
        natural_gradient_step(params, data_batch, lr=0.01, damping=1e-4)
    """
    # ravel_pytree flattens a pytree into one vector and returns its inverse.
    # It replaces the top-level jax.tree_flatten / jax.tree_unflatten helpers,
    # which are deprecated and removed in recent JAX releases.
    from jax.flatten_util import ravel_pytree

    def single_data_score(params, x):
        """Log-likelihood of one data point, summed over its features."""
        sigma = jnp.exp(params['log_sigma'])
        return jnp.sum(stats.norm.logpdf(x, params['mu'], sigma))

    def fisher_information_matrix(params, data_batch):
        """Approximate the Fisher information from per-point score vectors.

        Returns:
            (fisher_matrix, score_vectors): a (P, P) sample covariance of the
            flattened scores (P = total parameter count) and the (N, P) score
            vectors for the N rows of ``data_batch``.
        """
        scores = vmap(grad(single_data_score), in_axes=(None, 0))(params, data_batch)
        score_vectors = vmap(lambda tree: ravel_pytree(tree)[0])(scores)
        # Centred sample covariance of the scores. This approximates
        # E[s s^T] and matches it where E[s] = 0.
        # atleast_2d keeps a (1, 1) matrix when there is a single parameter.
        fisher_matrix = jnp.atleast_2d(jnp.cov(score_vectors, rowvar=False))
        return fisher_matrix, score_vectors

    def natural_gradient_step(params, data_batch, lr=0.01, damping=1e-4):
        """Take one natural-gradient descent step on the mean NLL.

        Args:
            params: dict with 'mu' and 'log_sigma' arrays.
            data_batch: 2-D array; rows are data points (the NLL sums over axis 1).
            lr: step size.
            damping: added to the Fisher diagonal before solving.

        Returns:
            (new_params, std_grad, nat_grad), all pytrees shaped like ``params``.
        """

        def neg_log_likelihood(p):
            sigma = jnp.exp(p['log_sigma'])
            log_liks = stats.norm.logpdf(data_batch, p['mu'], sigma)
            return -jnp.mean(jnp.sum(log_liks, axis=1))

        std_grad = grad(neg_log_likelihood)(params)

        fisher_matrix, _ = fisher_information_matrix(params, data_batch)

        # Natural gradient = (F + damping * I)^(-1) * standard gradient.
        # Both vectors use ravel_pytree, so their leaf ordering agrees.
        std_grad_flat, unravel = ravel_pytree(std_grad)
        damped_fisher = fisher_matrix + damping * jnp.eye(fisher_matrix.shape[0])
        nat_grad = unravel(jnp.linalg.solve(damped_fisher, std_grad_flat))

        new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, nat_grad)

        return new_params, std_grad, nat_grad

    return natural_gradient_step

# Test natural gradients
natural_grad_step = create_natural_gradient_optimizer()

# Synthetic data from N(true_mu, true_sigma^2), shaped (100, 1) for batching.
true_mu, true_sigma = 2.0, 1.5
key = random.PRNGKey(333)
data = (random.normal(key, (100,)) * true_sigma + true_mu)[:, None]

# Start from a standard normal.
prob_params = {
    'mu': jnp.array([0.0]),
    'log_sigma': jnp.array([0.0]),  # log(1.0) = 0.0
}

print("Natural Gradient Optimization:")
print(f"True parameters: mu={true_mu}, sigma={true_sigma}")
print(f"Initial parameters: mu={prob_params['mu'][0]:.3f}, sigma={jnp.exp(prob_params['log_sigma'][0]):.3f}")

# Optimization loop (report every other step)
for step in range(10):
    prob_params, std_g, nat_g = natural_grad_step(prob_params, data, lr=0.1)
    current_mu = prob_params['mu'][0]
    current_sigma = jnp.exp(prob_params['log_sigma'][0])
    if step % 2:
        continue
    print(f"Step {step}: mu={current_mu:.3f}, sigma={current_sigma:.3f}")

print(f"Final parameters: mu={current_mu:.3f}, sigma={current_sigma:.3f}")
print(f"Parameter errors: mu_err={abs(current_mu - true_mu):.3f}, "
      f"sigma_err={abs(current_sigma - true_sigma):.3f}")
```

## Bayesian Neural Networks

```python
def create_bayesian_neural_network():
    """Build a mean-field Gaussian Bayesian MLP trained by variational inference.

    Returns:
        (init_bnn_params, bnn_loss, bnn_predict)
    """

    def init_bnn_params(key, layer_sizes):
        """Create one dict of variational parameters per layer.

        Weight means start at N(0, 0.1^2) draws, bias means at zero, and every
        log-std at -2.0 (a small initial spread).
        """
        layer_keys = random.split(key, len(layer_sizes) - 1)

        params = []
        for layer_key, in_size, out_size in zip(layer_keys, layer_sizes[:-1], layer_sizes[1:]):
            # The second subkey is unused because bias means start at zero.
            w_key, _ = random.split(layer_key)
            params.append({
                'w_mu': random.normal(w_key, (in_size, out_size)) * 0.1,
                'w_log_std': jnp.full((in_size, out_size), -2.0),
                'b_mu': jnp.zeros(out_size),
                'b_log_std': jnp.full(out_size, -2.0),
            })
        return params

    def sample_bnn_weights(params, key):
        """Draw one concrete weight set via w = mu + exp(log_std) * eps."""

        def draw(sub_key, mu, log_std):
            return mu + jnp.exp(log_std) * random.normal(sub_key, mu.shape)

        sampled = []
        for layer, layer_key in zip(params, random.split(key, len(params))):
            w_key, b_key = random.split(layer_key)
            sampled.append({
                'w': draw(w_key, layer['w_mu'], layer['w_log_std']),
                'b': draw(b_key, layer['b_mu'], layer['b_log_std']),
            })
        return sampled

    def bnn_forward(sampled_params, x):
        """ReLU MLP forward pass; the last layer is linear."""
        h = x
        for layer in sampled_params[:-1]:
            h = jax.nn.relu(h @ layer['w'] + layer['b'])
        final = sampled_params[-1]
        return h @ final['w'] + final['b']

    def gaussian_kl(mu, log_std):
        # KL(N(mu, sigma^2) || N(0, 1)) summed over all elements.
        return 0.5 * jnp.sum(mu**2 + jnp.exp(2 * log_std) - 1 - 2*log_std)

    def bnn_loss(params, key, x_batch, y_batch, n_samples=5):
        """Monte Carlo MSE plus KL to a N(0, 1) prior scaled by 1/batch_size.

        Returns:
            (total_loss, info) with the averaged likelihood term, the raw KL,
            and the per-sample MSE values.
        """
        kl_loss = 0.0
        for layer in params:
            kl_loss += (gaussian_kl(layer['w_mu'], layer['w_log_std'])
                        + gaussian_kl(layer['b_mu'], layer['b_log_std']))

        def sample_mse(sample_key):
            pred = bnn_forward(sample_bnn_weights(params, sample_key), x_batch)
            return jnp.mean((pred - y_batch) ** 2)

        per_sample = vmap(sample_mse)(random.split(key, n_samples))
        likelihood = jnp.mean(per_sample)

        total_loss = likelihood + kl_loss / x_batch.shape[0]
        return total_loss, {
            'likelihood_loss': likelihood,
            'kl_loss': kl_loss,
            'per_sample_losses': per_sample
        }

    def bnn_predict(params, key, x, n_samples=50):
        """Predictive mean, std and raw draws over sampled weight sets."""
        draws = vmap(
            lambda sample_key: bnn_forward(sample_bnn_weights(params, sample_key), x)
        )(random.split(key, n_samples))
        return jnp.mean(draws, axis=0), jnp.std(draws, axis=0), draws

    return init_bnn_params, bnn_loss, bnn_predict

# Create and test Bayesian Neural Network
init_bnn, bnn_loss_fn, bnn_predict_fn = create_bayesian_neural_network()

# 1 -> 32 -> 32 -> 1 architecture
layer_sizes = [1, 32, 32, 1]
bnn_params = init_bnn(random.PRNGKey(444), layer_sizes)

print("Bayesian Neural Network:")
print(f"Layer sizes: {layer_sizes}")
print(f"Number of variational parameters per layer:")
for layer_idx, layer in enumerate(bnn_params):
    n_params = sum(arr.size for arr in layer.values())
    print(f"  Layer {layer_idx}: {n_params} parameters")

# Synthetic regression data: y = x^2 + noise, x in [-2, 2].
key = random.PRNGKey(555)
noise_key = random.split(key)[0]
x_train = random.uniform(key, (50, 1)) * 4 - 2  # [-2, 2]
y_train = x_train**2 + 0.1 * random.normal(noise_key, x_train.shape)

# Evaluate the loss once
loss, info = bnn_loss_fn(bnn_params, random.PRNGKey(666), x_train, y_train)
print(f"\nBNN Loss Components:")
print(f"Total loss: {loss:.4f}")
print(f"Likelihood loss: {info['likelihood_loss']:.4f}")
print(f"KL loss: {info['kl_loss']:.4f}")

# Predictive mean and spread at three probe points
x_test = jnp.array([[-1.5], [0.0], [1.5]])
mean_pred, std_pred, all_preds = bnn_predict_fn(bnn_params, random.PRNGKey(777), x_test)

print(f"\nBNN Predictions (with uncertainty):")
rows = zip(x_test.flatten(), mean_pred.flatten(), std_pred.flatten())
for x_val, mean_val, std_val in rows:
    true_val = x_val**2
    print(f"x={x_val:4.1f}: pred={mean_val:6.3f}±{std_val:5.3f}, true={true_val:6.3f}")

# Training step for BNN
def bnn_train_step(params, key, x_batch, y_batch, lr=0.01):
    """Take one plain gradient-descent step on the BNN loss.

    Args:
        params: list of per-layer dicts of variational parameters.
        key: PRNG key forwarded to ``bnn_loss_fn`` for its Monte Carlo samples.
        x_batch, y_batch: training inputs and targets.
        lr: step size.

    Returns:
        (new_params, loss): updated parameters with the same structure as
        ``params``, and the loss evaluated before the update.
    """

    def objective(p):
        total_loss, _ = bnn_loss_fn(p, key, x_batch, y_batch)
        return total_loss

    loss, grads = jax.value_and_grad(objective)(params)
    # jax.tree_map is deprecated/removed in recent JAX; tree_util.tree_map is
    # the stable spelling.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

# Quick training demo: five steps, each with its own fixed seed.
print("\nBNN Training:")
for epoch, seed in enumerate(range(1000, 1005)):
    key = random.PRNGKey(seed)
    bnn_params, loss = bnn_train_step(bnn_params, key, x_train, y_train, lr=0.1)
    print(f"Epoch {epoch}: Loss = {loss:.4f}")
```

## Summary

In this notebook, we explored probabilistic programming and stochastic gradients in JAX:

**Core Concepts:**
- **Probability Distributions**: Sampling and density evaluation
- **Stochastic Computation**: Graphs with randomness and noise
- **Reparameterization Trick**: Making sampling differentiable  
- **Gradient Estimation**: REINFORCE, Gumbel-Softmax, control variates

**Variational Methods:**
- **Variational Autoencoders**: Latent variable models with reparameterization
- **Variational Inference**: Approximate posterior inference
- **Natural Gradients**: Using Fisher information for efficient optimization
- **Bayesian Neural Networks**: Weight uncertainty quantification

**Gradient Estimation Techniques:**
- **REINFORCE**: Likelihood ratio estimator for discrete variables
- **Control Variates**: Variance reduction with baselines
- **Gumbel-Softmax**: Continuous relaxation of discrete distributions
- **Pathwise Derivatives**: Direct differentiation through continuous paths

**Applications:**
- **Generative Modeling**: VAEs (normalizing flows build on the same reparameterization ideas but are not covered here)
- **Reinforcement Learning**: Policy gradient methods
- **Uncertainty Quantification**: Bayesian neural networks
- **Stochastic Optimization**: Natural gradient methods

**Key Benefits:**
- **End-to-End Differentiability**: Through stochastic computation graphs
- **Variance Control**: Multiple techniques for reducing gradient variance  
- **Uncertainty Quantification**: Principled handling of model uncertainty
- **Flexible Modeling**: Support for complex probabilistic models

Probabilistic programming in JAX enables sophisticated Bayesian modeling and inference while maintaining the efficiency and flexibility of automatic differentiation and just-in-time compilation.