In [4]:
import jax
import jax.numpy as jnp
import numpy as np
import optax
from typing import Sequence, Callable, Tuple, Union, Dict, Any
import flax
from flax import linen as nn
from flax.training import train_state
import matplotlib.pyplot as plt

In [5]:
def gaussian_mixture_density(x):
    """Example: 2-component Gaussian mixture"""
    mu1 = jnp.array([-5.0])
    mu2 = jnp.array([5.0])
    sigma = 1.0
    
    log_p1 = -0.5 * jnp.sum(((x - mu1) / sigma) ** 2)
    log_p2 = -0.5 * jnp.sum(((x - mu2) / sigma) ** 2)
    return jnp.logaddexp(log_p1, log_p2)  # log(exp(log_p1) + exp(log_p2))

In [6]:
def amh(target_log_pdf, key: jax.random.PRNGKey, d: int, m: int, gamma_init: float, alpha_star: float = 0.234):
    """
    Adaptive Metropolis-Hastings algorithm for warm-up
    
    Args:
        target_log_pdf: Function that returns log probability of target distribution
        key: JAX random key
        d: Dimension of the parameter space
        m: Number of iterations
        gamma_init: Initial adaptation rate
        alpha_star: Target acceptance rate
        
    Returns:
        samples, means, covariances, scale factors, acceptance indicators
    """
    x = jnp.zeros((d,))
    mu = jnp.zeros((d,))
    sigma = jnp.eye(d)
    lam = 1.0

    samples = jnp.zeros((m, d))
    means = jnp.zeros((m, d))
    sigmas = jnp.zeros((m, d, d)).at[0].set(jnp.eye(d))
    lambdas = jnp.zeros((m,))
    accept_status = jnp.zeros((m,), dtype=bool)

    def gamma_(g, i):
        return g * (1 - (i / m))

    def step(carry, i):
        x, gamma, mu, sigma, lam, key = carry
        
        sample_key, uni_key, key = jax.random.split(key, 3)

        # Sample from the proposal distribution
        x_prop = x + jax.random.multivariate_normal(sample_key, mean=jnp.zeros(d), cov=lam * sigma)

        # Compute the acceptance probability (in log space)
        log_alpha = target_log_pdf(x_prop) - target_log_pdf(x)
        alpha = jnp.minimum(0.0, log_alpha)

        # Accept or reject the new sample
        accept = jnp.log(jax.random.uniform(uni_key)) < alpha
        x_new = jnp.where(accept, x_prop, x)

        # Update the parameters
        mu_new = mu + gamma * (x_new - mu)
        sigma_new = sigma + gamma * ((x_new - mu)[:, None] @ (x_new - mu)[None, :] - sigma)
        log_lam = jnp.log(lam) + gamma * (jnp.exp(alpha) - alpha_star)
        lam_new = jnp.exp(log_lam)

        gamma = gamma_(gamma, i)

        outputs = (x_new, mu_new, sigma_new, lam_new, accept)

        return (x_new, gamma, mu_new, sigma_new, lam_new, key), outputs

    carry = (x, gamma_init, mu, sigma, lam, key)
    _, (xs, mus, Sigmas, lambdas_, acc) = jax.lax.scan(
        step, carry, jnp.arange(1, m)
    )

    # Add initial values
    samples = samples.at[1:].set(xs)
    means = means.at[1:].set(mus)
    sigmas = sigmas.at[1:].set(Sigmas)
    lambdas = lambdas.at[1:].set(lambdas_)
    accept_status = accept_status.at[1:].set(acc)

    return samples, means, sigmas, lambdas, accept_status

In [13]:
class PhiNetwork(nn.Module):
    """Neural network for the proposal distribution"""
    features: Sequence[int]
    mean: jnp.ndarray
    sigma: jnp.ndarray
    l: float = 10.0
    
    @nn.compact
    def __call__(self, x):
        # Computing square root of the covariance matrix
        sigma_reg = self.sigma + 1e-6 * jnp.eye(self.sigma.shape[0])  # Regularization
        eigvals, eigvecs = jnp.linalg.eigh(sigma_reg)
        inv_sqrt_sigma = eigvecs @ jnp.diag(1.0 / jnp.sqrt(jnp.maximum(eigvals, 1e-6))) @ eigvecs.T
        sqrt_sigma = eigvecs @ jnp.diag(jnp.sqrt(jnp.maximum(eigvals, 1e-6))) @ eigvecs.T
        
        # Neural network for nu_theta
        h = x
        for feat in self.features:
            h = nn.relu(nn.Dense(feat, kernel_init=nn.initializers.xavier_uniform())(h))
        nu_x = nn.Dense(self.sigma.shape[0], kernel_init=nn.initializers.xavier_uniform())(h)
        
        # Define psi (mean + transformation from neural network)
        psi = self.mean + sqrt_sigma @ nu_x
        
        # Apply the regularity-ensuring transformation
        eta = jnp.sum(jnp.square(inv_sqrt_sigma @ (x - self.mean))) / self.l**2
        gamma = self.compute_gamma(eta)
        
        # Final phi output
        phi = psi + gamma * (x - psi)
        return phi
    
    def compute_gamma(self, eta):
        """Compute the gamma function for regularity"""
        return jnp.where(eta <= 0.5, 0.0,
               jnp.where(eta >= 1.0, 1.0,
               1.0 / (1.0 + jnp.exp(-(4.0*eta - 3.0)/(4.0*eta*eta - 6.0*eta + 2.0)))))


class ActorNetwork(nn.Module):
    """Actor network for RLMH"""
    features: Sequence[int]
    mean: jnp.ndarray
    sigma: jnp.ndarray
    l: float = 10.0
    
    def setup(self):
        self.phi_net = PhiNetwork(features=self.features, mean=self.mean, sigma=self.sigma, l=self.l)
    
    def __call__(self, state):
        d = self.mean.shape[0]
        x = state[:d]
        x_prop = state[d:]
        
        phi_x = self.phi_net(x)
        phi_prop = self.phi_net(x_prop)
        
        return jnp.concatenate([phi_x, phi_prop], axis=0)


class CriticNetwork(nn.Module):
    """Critic network for RLMH"""
    features: Sequence[int]
    
    @nn.compact
    def __call__(self, state, action):
        # Concatenate state and action
        x = jnp.concatenate([state, action], axis=0)
        
        # Forward pass through hidden layers
        for feat in self.features:
            x = nn.relu(nn.Dense(feat, kernel_init=nn.initializers.xavier_uniform())(x))
        
        # Output Q-value
        q_value = nn.Dense(1, kernel_init=nn.initializers.xavier_uniform())(x)
        return q_value.squeeze()

In [14]:
def log_laplace(y, sigma_rootinv, mean):
    """Log density of Laplace distribution
    
    Args:
        y: Proposal sample
        sigma_rootinv: Inverse square root of covariance
        mean: Mean of the distribution
    
    Returns:
        Log probability
    """
    return -jnp.linalg.norm(sigma_rootinv @ (y - mean), ord=1)

@jax.jit
def sample_laplace(key, mean, sigma):
    """Sample from Laplace distribution
    
    Args:
        key: JAX PRNG key
        mean: Mean of the distribution
        sigma: Covariance matrix
    
    Returns:
        Sample from Laplace distribution
    """
    d = mean.shape[0]
    
    # Compute the inverse square root of sigma
    eigvals, eigvecs = jnp.linalg.eigh(sigma)
    sigma_rootinv = eigvecs @ jnp.diag(1.0 / jnp.sqrt(jnp.maximum(eigvals, 1e-6))) @ eigvecs.T
    
    # Generate standard exponential random variables
    key1, key2 = jax.random.split(key)
    z1 = jax.random.exponential(key1, shape=(d,))
    z2 = jax.random.exponential(key2, shape=(d,))
    
    # Difference of exponentials gives Laplace
    z = z1 - z2
    
    # Transform to get the right mean and covariance
    return mean + jnp.linalg.solve(sigma_rootinv, z)

In [15]:
class ReplayBuffer:
    """Experience replay buffer for DDPG"""
    
    def __init__(self, state_dim, action_dim, max_size=100000):
        self.max_size = max_size
        self.ptr = 0
        self.size = 0
        
        self.states = np.zeros((max_size, state_dim))
        self.actions = np.zeros((max_size, action_dim))
        self.rewards = np.zeros((max_size, 1))
        self.next_states = np.zeros((max_size, state_dim))
        
    def add(self, state, action, reward, next_state):
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.next_states[self.ptr] = next_state
        
        self.ptr = (self.ptr + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)
        
    def sample(self, batch_size):
        ind = np.random.randint(0, self.size, size=batch_size)
        return (
            jnp.array(self.states[ind]),
            jnp.array(self.actions[ind]),
            jnp.array(self.rewards[ind]),
            jnp.array(self.next_states[ind])
        )


class DDPGTrainState:
    """Training state for DDPG algorithm"""
    
    def __init__(self, actor_params, critic_params, target_actor_params, 
                 target_critic_params, actor_opt_state, critic_opt_state):
        self.actor_params = actor_params
        self.critic_params = critic_params
        self.target_actor_params = target_actor_params
        self.target_critic_params = target_critic_params
        self.actor_opt_state = actor_opt_state
        self.critic_opt_state = critic_opt_state

In [16]:
class RLMH:
    """Reinforcement Learning Metropolis-Hastings (RLMH) algorithm
    
    This class implements the RLMH algorithm described in the paper.
    """
    
    def __init__(self, target_log_density, d, actor_features, critic_features, 
                 batch_size=64, gamma=0.99, tau=0.005, buffer_size=100000, 
                 actor_lr=1e-4, critic_lr=1e-3):
        """Initialize RLMH
        
        Args:
            target_log_density: Log density function to sample from
            d: Dimension of the parameter space
            actor_features: Hidden layer sizes for actor network
            critic_features: Hidden layer sizes for critic network
            batch_size: Batch size for training
            gamma: Discount factor for DDPG
            tau: Soft update coefficient
            buffer_size: Size of the replay buffer
            actor_lr: Learning rate for actor
            critic_lr: Learning rate for critic
        """
        self.target_log_density = target_log_density
        self.d = d
        self.actor_features = actor_features
        self.critic_features = critic_features
        self.batch_size = batch_size
        self.gamma = gamma
        self.tau = tau
        self.replay_buffer = ReplayBuffer(2*d, 2*d, buffer_size)
        self.actor_lr = actor_lr
        self.critic_lr = critic_lr
        
        # Will be initialized after warm-up
        self.mean = None
        self.sigma = None
        self.actor = None
        self.critic = None
        self.train_state = None
        self.current_state = None
        
    def warmup(self, key, m=10000, gamma_init=0.1, alpha_star=0.234):
        """Run AMH to get initial samples and adaptation parameters
        
        Args:
            key: JAX PRNG key
            m: Number of warm-up iterations
            gamma_init: Initial adaptation rate
            alpha_star: Target acceptance rate
            
        Returns:
            Samples from the warm-up phase
        """
        print("Running AMH warm-up...")
        
        samples, means, sigmas, lambdas, accept_status = amh(
            self.target_log_density, key, self.d, m, gamma_init, alpha_star
        )
        
        # Store the mean and covariance from warm-up
        self.mean = means[-1]
        self.sigma = sigmas[-1]
        
        print(f"Warm-up acceptance rate: {jnp.mean(accept_status):.3f}")
        print(f"Final lambda: {lambdas[-1]:.5f}")
        
        return samples
    
    def initialize_networks(self, key):
        """Initialize actor and critic networks after warm-up
        
        Args:
            key: JAX PRNG key
        """
        if self.mean is None or self.sigma is None:
            raise ValueError("Must run warm-up before initializing networks")
        
        key1, key2 = jax.random.split(key)
        
        # Initialize actor network
        self.actor = ActorNetwork(
            features=self.actor_features,
            mean=self.mean,
            sigma=self.sigma
        )
        
        # Initialize critic network
        self.critic = CriticNetwork(
            features=self.critic_features
        )
        
        # Initial state for both networks
        dummy_state = jnp.zeros(2 * self.d)
        dummy_action = jnp.zeros(2 * self.d)
        
        actor_params = self.actor.init(key1, dummy_state)
        critic_params = self.critic.init(key2, dummy_state, dummy_action)
        
        # Initialize optimizers
        actor_optimizer = optax.adam(learning_rate=self.actor_lr)
        critic_optimizer = optax.adam(learning_rate=self.critic_lr)
        
        actor_opt_state = actor_optimizer.init(actor_params)
        critic_opt_state = critic_optimizer.init(critic_params)
        
        # Initialize target networks with same parameters
        target_actor_params = actor_params
        target_critic_params = critic_params
        
        # Create training state
        self.train_state = DDPGTrainState(
            actor_params=actor_params,
            critic_params=critic_params,
            target_actor_params=target_actor_params,
            target_critic_params=target_critic_params,
            actor_opt_state=actor_opt_state,
            critic_opt_state=critic_opt_state
        )
        
        # Set up optimizers as functions
        self.actor_optimizer = actor_optimizer
        self.critic_optimizer = critic_optimizer
        
        # Initialize JIT-compiled functions
        self.init_jitted_functions()
    
    def initialize_state(self, key, x_init=None):
        """Initialize state for RLMH
        
        Args:
            key: JAX PRNG key
            x_init: Initial state (optional)
        """
        if x_init is None:
            x_init = jnp.zeros(self.d)
        
        # Generate proposal using Laplace distribution
        key, subkey = jax.random.split(key)
        x_prop = sample_laplace(subkey, self.mean, self.sigma)
        
        # Set current state
        self.current_state = jnp.concatenate([x_init, x_prop])
        
        return key
    
    #@jax.jit
    def _compute_acceptance_probability(self, x, x_prop, phi_x, phi_prop):
        """Compute Metropolis-Hastings acceptance probability
        
        Args:
            x: Current state
            x_prop: Proposed state
            phi_x: Mean of proposal for x
            phi_prop: Mean of proposal for x_prop
            
        Returns:
            Log acceptance probability
        """
        # Log probability ratio of target distribution
        log_prob_ratio = self.target_log_density(x_prop) - self.target_log_density(x)
        
        # Log proposal ratio for Laplace distribution
        log_prop_ratio = log_laplace(x, self.sigma, phi_prop) - log_laplace(x_prop, self.sigma, phi_x)
        
        # Log acceptance probability
        return jnp.minimum(0.0, log_prob_ratio + log_prop_ratio)
    
    #@jax.jit
    def _compute_reward(self, x, x_prop, log_alpha):
        """Compute reward for RLMH
        
        Args:
            x: Current state
            x_prop: Proposed state
            log_alpha: Log acceptance probability
            
        Returns:
            Reward
        """
        # Reward as defined in the paper
        return 2.0 * jnp.log(jnp.linalg.norm(x - x_prop) + 1e-10) + log_alpha
    
    # Create JIT-compiled versions separately
    def init_jitted_functions(self):
        """Initialize JIT-compiled functions"""
        # We need to use closures to capture self
        target_log_density = self.target_log_density
        sigma = self.sigma
        
        @jax.jit
        def _jit_compute_acceptance_probability(x, x_prop, phi_x, phi_prop):
            # Log probability ratio of target distribution
            log_prob_ratio = target_log_density(x_prop) - target_log_density(x)
            
            # Log proposal ratio for Laplace distribution
            log_prop_ratio = log_laplace(x, sigma, phi_prop) - log_laplace(x_prop, sigma, phi_x)
            
            # Log acceptance probability
            return jnp.minimum(0.0, log_prob_ratio + log_prop_ratio)
        
        @jax.jit
        def _jit_compute_reward(x, x_prop, log_alpha):
            # Reward as defined in the paper
            return 2.0 * jnp.log(jnp.linalg.norm(x - x_prop) + 1e-10) + log_alpha
        
        self._jit_compute_acceptance_probability = _jit_compute_acceptance_probability
        self._jit_compute_reward = _jit_compute_reward
    
    def step(self, key):
        """Take one step of RLMH"""
        key1, key2, key3 = jax.random.split(key, 3)
        
        # Get current state
        state = self.current_state
        x = state[:self.d]
        x_prop = state[self.d:]
        
        # Get action from actor
        action = self.actor.apply(self.train_state.actor_params, state)
        phi_x = action[:self.d]
        phi_prop = action[self.d:]
        
        # Check if we have JIT-compiled functions
        if not hasattr(self, '_jit_compute_acceptance_probability'):
            self.init_jitted_functions()
        
        # Use JIT-compiled functions
        log_alpha = self._jit_compute_acceptance_probability(x, x_prop, phi_x, phi_prop)
        reward = self._jit_compute_reward(x, x_prop, log_alpha)
        
        # Accept/reject step
        accept = jnp.log(jax.random.uniform(key1)) < log_alpha
        x_next = jnp.where(accept, x_prop, x)
        
        # Generate next proposal
        x_prop_next = sample_laplace(key2, phi_x, self.sigma)
        
        # Update current state
        next_state = jnp.concatenate([x_next, x_prop_next])
        
        # Store experience in replay buffer
        self.replay_buffer.add(state, action, reward, next_state)
        
        # Update current state
        self.current_state = next_state
        
        return key3, reward, accept
    
    def update_networks(self):
        """Update actor and critic networks using DDPG
        
        Returns:
            Updated train state, actor loss, critic loss
        """
        if self.replay_buffer.size < self.batch_size:
            return self.train_state, 0.0, 0.0
        
        # Sample a batch from replay buffer
        states, actions, rewards, next_states = self.replay_buffer.sample(self.batch_size)
        
        # Compute critic loss
        def critic_loss_fn(critic_params):
            # Q values for current state-action pairs
            q_values = jax.vmap(self.critic.apply, in_axes=(None, 0, 0))(
                critic_params, states, actions
            )
            
            # Next actions from target actor
            next_actions = jax.vmap(self.actor.apply, in_axes=(None, 0))(
                self.train_state.target_actor_params, next_states
            )
            
            # Q values for next state-action pairs from target critic
            next_q_values = jax.vmap(self.critic.apply, in_axes=(None, 0, 0))(
                self.train_state.target_critic_params, next_states, next_actions
            )
            
            # Target Q values
            target_q = rewards + self.gamma * next_q_values
            
            # Mean squared error loss
            return jnp.mean(jnp.square(q_values - target_q))
        
        # Compute gradients for critic
        critic_grad_fn = jax.value_and_grad(critic_loss_fn)
        critic_loss, critic_grads = critic_grad_fn(self.train_state.critic_params)
        
        # Update critic
        critic_updates, critic_opt_state = self.critic_optimizer.update(
            critic_grads, self.train_state.critic_opt_state
        )
        critic_params = optax.apply_updates(self.train_state.critic_params, critic_updates)
        
        # Compute actor loss
        def actor_loss_fn(actor_params):
            # Actions from actor
            pred_actions = jax.vmap(self.actor.apply, in_axes=(None, 0))(
                actor_params, states
            )
            
            # Q values from critic
            q_values = jax.vmap(self.critic.apply, in_axes=(None, 0, 0))(
                self.train_state.critic_params, states, pred_actions
            )
            
            # Negative mean Q value (we want to maximize Q)
            return -jnp.mean(q_values)
        
        # Compute gradients for actor
        actor_grad_fn = jax.value_and_grad(actor_loss_fn)
        actor_loss, actor_grads = actor_grad_fn(self.train_state.actor_params)
        
        # Update actor
        actor_updates, actor_opt_state = self.actor_optimizer.update(
            actor_grads, self.train_state.actor_opt_state
        )
        actor_params = optax.apply_updates(self.train_state.actor_params, actor_updates)
        
        # Soft update for target networks
        target_actor_params = jax.tree_map(
            lambda x, y: (1 - self.tau) * x + self.tau * y,
            self.train_state.target_actor_params,
            actor_params
        )
        
        target_critic_params = jax.tree_map(
            lambda x, y: (1 - self.tau) * x + self.tau * y,
            self.train_state.target_critic_params,
            critic_params
        )
        
        # Update training state
        self.train_state = DDPGTrainState(
            actor_params=actor_params,
            critic_params=critic_params,
            target_actor_params=target_actor_params,
            target_critic_params=target_critic_params,
            actor_opt_state=actor_opt_state,
            critic_opt_state=critic_opt_state
        )
        
        return self.train_state, actor_loss, critic_loss
    
    def train(self, key, num_iterations=10000, update_every=10, log_every=1000):
        """Train RLMH
        
        Args:
            key: JAX PRNG key
            num_iterations: Number of training iterations
            update_every: Update networks every N steps
            log_every: Log stats every N steps
            
        Returns:
            Samples, rewards, acceptance rates
        """
        if self.train_state is None:
            raise ValueError("Must initialize networks before training")
        
        samples = np.zeros((num_iterations, self.d))
        rewards = np.zeros(num_iterations)
        accepts = np.zeros(num_iterations, dtype=bool)
        
        for i in range(num_iterations):
            # Take a step
            key, reward, accept = self.step(key)
            
            # Store results
            samples[i] = self.current_state[:self.d]
            rewards[i] = reward
            accepts[i] = accept
            
            # Update networks
            if i % update_every == 0:
                _, actor_loss, critic_loss = self.update_networks()
            
            # Log progress
            if i % log_every == 0:
                recent_accept_rate = np.mean(accepts[max(0, i-log_every):i+1])
                print(f"Iteration {i}: Accept rate = {recent_accept_rate:.3f}, " +
                      f"Average reward = {np.mean(rewards[max(0, i-log_every):i+1]):.3f}")
        
        return samples, rewards, accepts

In [19]:
def run_rlmh_experiment(target_log_density, d=1, num_warmup=10000, num_train=50000):
    """Run a complete RLMH experiment
    
    Args:
        target_log_density: Log density function to sample from
        d: Dimension of the parameter space
        num_warmup: Number of warm-up iterations
        num_train: Number of training iterations
        
    Returns:
        RLMH instance, samples, rewards, acceptance rates
    """
    # Initialize
    key = jax.random.PRNGKey(0)
    key, subkey = jax.random.split(key)
    
    # Create RLMH instance
    rlmh = RLMH(
        target_log_density=target_log_density,
        d=d,
        actor_features=[64, 64],
        critic_features=[64, 64]
    )
    
    # Run warm-up phase
    warmup_samples = rlmh.warmup(subkey, m=num_warmup)
    
    # Initialize networks
    key, subkey = jax.random.split(key)
    rlmh.initialize_networks(subkey)
    
    # Initialize state with last sample from warm-up
    key, subkey = jax.random.split(key)
    rlmh.initialize_state(subkey, x_init=warmup_samples[-1])
    
    # Train RLMH
    key, subkey = jax.random.split(key)
    samples, rewards, accepts = rlmh.train(subkey, num_iterations=num_train)
    
    print(f"Overall acceptance rate: {np.mean(accepts):.3f}")
    
    return rlmh, samples, rewards, accepts

In [20]:
# Define target density (e.g., Gaussian mixture)
def target_log_density(x):
    # 2-component Gaussian mixture
    mu1 = jnp.array([-5.0])
    mu2 = jnp.array([5.0])
    sigma = 1.0
    
    log_p1 = -0.5 * jnp.sum(((x - mu1) / sigma) ** 2)
    log_p2 = -0.5 * jnp.sum(((x - mu2) / sigma) ** 2)
    return jnp.logaddexp(log_p1, log_p2)

# Run RLMH experiment
rlmh, samples, rewards, accepts = run_rlmh_experiment(
    target_log_density, 
    d=1,
    num_warmup=5000,
    num_train=20000
)

# Plot results
plot_results(rlmh.warmup(jax.random.PRNGKey(1), m=5000), samples, target_log_density)

Running AMH warm-up...
Warm-up acceptance rate: 0.411
Final lambda: 9.24790


TypeError: Cannot interpret value of type <class '__main__.RLMH'> as an abstract array; it does not have a dtype attribute