To clarify, MBPO (Model-Based Policy Optimization) is a model-based reinforcement learning algorithm that trains an off-policy, model-free learner on data generated by a learned dynamics model. It uses an ensemble of models to approximate the true environment dynamics. Disagreement within the ensemble reflects model uncertainty, which makes it harder for the policy to exploit the errors of any single model, and the generated data improves sample efficiency. The algorithm uses these dynamics models to fill a buffer with short imagined rollouts and updates the policy with a model-free method, typically Soft Actor-Critic (SAC).

In comparison, Dyna-Q is a model-based reinforcement learning algorithm that alternates between learning a model of the environment's dynamics and using that model to improve the policy through simulated experience. It does not leverage an ensemble of models, and it relies on Q-learning for the policy update step.
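The alternation that Dyna-Q performs (a Q-learning update from real experience, a model update, then several planning updates from simulated experience) can be sketched in a few lines. This is a minimal tabular sketch, not a reference implementation: the chain environment, the function name `dyna_q`, and all hyperparameter values are assumptions made for illustration.

```python
import random
from collections import defaultdict

def dyna_q(n_states=6, episodes=50, planning_steps=10,
           alpha=0.5, gamma=0.95, epsilon=0.1, seed=0):
    """Tabular Dyna-Q on a toy deterministic chain (hypothetical environment).

    States are 0..n_states-1; action 0 moves left, action 1 moves right.
    Reaching the right-most state gives reward 1 and ends the episode.
    """
    rng = random.Random(seed)
    q = defaultdict(float)  # Q[(state, action)], defaults to 0
    model = {}              # learned model: (state, action) -> (reward, next_state, done)

    def step(state, action):
        next_state = max(0, state - 1) if action == 0 else state + 1
        done = next_state == n_states - 1
        return (1.0 if done else 0.0), next_state, done

    def greedy(state):
        # Break ties randomly so the untrained agent does not always pick action 0.
        best = max(q[(state, 0)], q[(state, 1)])
        return rng.choice([a for a in (0, 1) if q[(state, a)] == best])

    def q_update(s, a, r, s2, done):
        target = r if done else r + gamma * max(q[(s2, 0)], q[(s2, 1)])
        q[(s, a)] += alpha * (target - q[(s, a)])

    for _ in range(episodes):
        state, done = 0, False
        while not done:
            action = rng.randrange(2) if rng.random() < epsilon else greedy(state)
            reward, next_state, done = step(state, action)
            q_update(state, action, reward, next_state, done)    # learn from real experience
            model[(state, action)] = (reward, next_state, done)  # update the learned model
            for _ in range(planning_steps):                      # learn from simulated experience
                (ps, pa), (pr, ps2, pdone) = rng.choice(list(model.items()))
                q_update(ps, pa, pr, ps2, pdone)
            state = next_state
    return q

q_table = dyna_q()
# Greedy action in each non-terminal state (1 means "move right").
greedy_policy = [max((0, 1), key=lambda a: q_table[(s, a)]) for s in range(5)]
```

Note that the planning updates replay single transitions from a lookup-table model, whereas MBPO rolls a learned neural network model forward for several steps.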

Here are the main components of MBPO:

- Model Learning: MBPO learns a probabilistic model of the environment's dynamics using collected data. This model is used to predict the next state and reward given the current state and action. It can be a neural network or any other function approximator.
- Model Rollouts: Using the learned model, MBPO generates short trajectories by rolling the model forward for a few steps, starting from states sampled from the real data. These rollouts provide additional data points for training the policy and updating the value function, while the short horizon limits how far model errors can accumulate.
- Policy Optimization: The policy is optimized with a model-free technique. MBPO is usually paired with Soft Actor-Critic (SAC), an off-policy method; other optimizers such as Proximal Policy Optimization (PPO) or Trust Region Policy Optimization (TRPO) can in principle be substituted. The objective is to optimize the policy using both real and model-generated data.
- Iterative Update: MBPO alternates between model learning, model rollouts, and policy optimization, gradually improving the model's accuracy and the policy's performance.
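The model rollout component above can be sketched as follows. This is an illustrative sketch under stated assumptions, not the notebook's implementation: `branched_rollouts` and `make_linear_model` are hypothetical names, toy linear functions stand in for learned neural networks, and one ensemble member is picked at random per step as a simplification.

```python
import numpy as np

def branched_rollouts(ensemble, policy, start_states, horizon, rng):
    """Roll a policy forward for `horizon` steps through an ensemble of models.

    ensemble:     list of callables (states, actions) -> (next_states, rewards)
    policy:       callable (states, rng) -> actions
    start_states: array of shape (batch, state_dim), sampled from real data
    Returns a list of (states, actions, rewards, next_states) batches, one per step.
    """
    transitions = []
    states = start_states
    for _ in range(horizon):
        actions = policy(states, rng)
        # Pick one ensemble member at random for this step.
        member = ensemble[rng.integers(len(ensemble))]
        next_states, rewards = member(states, actions)
        transitions.append((states, actions, rewards, next_states))
        states = next_states
    return transitions

def make_linear_model(scale):
    """Toy stand-in for a learned dynamics model (assumption for illustration)."""
    def model(states, actions):
        next_states = scale * states + 0.1 * actions[:, None]
        rewards = -np.sum(next_states ** 2, axis=-1)
        return next_states, rewards
    return model

def random_policy(states, rng):
    # Uniformly random binary actions, one per state in the batch.
    return rng.integers(0, 2, size=len(states)).astype(np.float64)

rng = np.random.default_rng(0)
ensemble = [make_linear_model(s) for s in (0.95, 1.0, 1.05)]
real_states = rng.normal(size=(8, 4))  # pretend these came from the environment
rollout = branched_rollouts(ensemble, random_policy, real_states, horizon=3, rng=rng)
```

Each element of `rollout` is one step of the imagined batch, and the start states of step `t + 1` are the predicted next states of step `t`.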

The main advantage of MBPO is that it reduces the number of interactions with the real environment by generating additional data with the learned model. This makes it more sample-efficient than purely model-free methods. However, the learned policy is only as good as the learned model: prediction errors compound over the length of a rollout, and the policy can learn to exploit inaccuracies or biases in the model instead of behaving well in the real environment.
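A small numeric example makes the compounding of model error concrete. The scalar system below is an assumption chosen purely for illustration: the true dynamics are `x' = a * x` and the model uses a slightly wrong coefficient, so the gap between true and predicted state grows with the rollout horizon.

```python
import numpy as np

def rollout_error(true_a, model_a, x0, horizon):
    """Absolute gap between the true and model-predicted state after each step."""
    steps = np.arange(1, horizon + 1)
    true_states = x0 * true_a ** steps
    model_states = x0 * model_a ** steps
    return np.abs(model_states - true_states)

# A 2% error in the one-step coefficient (hypothetical values).
errors = rollout_error(true_a=1.0, model_a=1.02, x0=1.0, horizon=20)
print(errors[[0, 4, 19]])  # error after 1, 5, and 20 steps
```

The one-step error is small, but it is applied again at every step, which is one reason to keep model rollouts short.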

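Here's the high-level pseudocode for the MBPO implementation in this notebook. The implementation is a simplified variant: it uses a single deterministic dynamics network instead of a probabilistic ensemble, starts model rollouts from environment reset states, and updates the policy with a policy-gradient loss instead of SAC.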

- Initialize environment, hyperparameters, experience buffer, dynamics model, policy model, and their respective optimizers.
- Collect initial data from the environment and store it in the experience buffer.
- For each epoch:
  - Collect real data from the environment using the current policy and store it in the experience buffer.
  - Train the dynamics model using data from the experience buffer.
  - Generate model rollouts using the current dynamics model and policy. Store the generated data in the experience buffer.
  - Optimize the policy using the combined data (real and model-generated) from the experience buffer.
- Repeat for a specified number of epochs.

The pseudocode covers the main components of the MBPO algorithm: data collection, dynamics model training, model rollout generation, and policy optimization.
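One consequence of the pseudocode above is that model-generated transitions land in the same buffer that trains the dynamics model, so the model can end up fitting its own predictions. A common alternative is to keep real and model-generated data in separate buffers, train the dynamics model on real data only, and mix the two when sampling policy batches. The sketch below shows only the mixing step; `sample_mixed_batch` and `real_ratio` are hypothetical names, and the ratio value is illustrative rather than tuned.

```python
import numpy as np

def sample_mixed_batch(real_data, model_data, batch_size, real_ratio, rng):
    """Draw a batch containing a fixed fraction of real transitions.

    real_data, model_data: arrays of shape (n, feature_dim), one transition per row
    real_ratio: fraction of the batch taken from real data (hypothetical knob)
    Returns the batch and a boolean mask marking which rows are real.
    """
    n_real = int(round(batch_size * real_ratio))
    n_model = batch_size - n_real
    real_idx = rng.integers(0, len(real_data), size=n_real)
    model_idx = rng.integers(0, len(model_data), size=n_model)
    batch = np.concatenate([real_data[real_idx], model_data[model_idx]], axis=0)
    is_real = np.concatenate([np.ones(n_real, dtype=bool), np.zeros(n_model, dtype=bool)])
    return batch, is_real

rng = np.random.default_rng(0)
real_data = rng.normal(size=(100, 6))   # stand-in for real transitions
model_data = rng.normal(size=(400, 6))  # stand-in for model-generated transitions
batch, is_real = sample_mixed_batch(real_data, model_data, batch_size=256, real_ratio=0.05, rng=rng)
```

With a batch size of 256 and a ratio of 0.05, 13 rows come from real data and the remaining 243 from the model.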

In [25]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import gymnasium as gym
import numpy as np
import random
from typing import Tuple


In [26]:
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

num_epochs = 1000
rollout_steps = 5
num_model_rollouts = 20
num_model_updates = 5
buffer_size = 10000
batch_size = 256


In [38]:
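class DynamicsModel(nn.Module):
    """Predicts the next state and the reward from a (state, action) pair."""

    @nn.compact
    def __call__(self, state, action) -> jnp.ndarray:
        # `action` needs a trailing axis of size 1 so it can be concatenated with `state`
        x = jnp.concatenate([state, action], axis=-1)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        # Output layout: the first `state_dim` entries are the next state, the last entry is the reward
        return nn.Dense(state_dim + 1)(x)

class PolicyModel(nn.Module):
    """Maps a state to a probability distribution over discrete actions."""
    action_dim: int

    @nn.compact
    def __call__(self, state) -> jnp.ndarray:
        x = nn.Dense(64)(state)
        x = nn.relu(x)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        logits = nn.Dense(self.action_dim)(x)
        # Returns action probabilities, not logits
        return nn.softmax(logits)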


In [39]:
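# Initialize the neural networks and optimizers
dynamics_model = DynamicsModel()
policy_model = PolicyModel(action_dim)

# Use separate PRNG keys so the two networks are not initialized from the same random stream
rng_key = jax.random.PRNGKey(0)
rng_key, dynamics_key, policy_key = jax.random.split(rng_key, 3)
# Actions are passed with a trailing axis of size 1 so they can be concatenated with the state
dynamics_params = dynamics_model.init(dynamics_key, jnp.ones((1, state_dim)), jnp.ones((1, 1)))
policy_params = policy_model.init(policy_key, jnp.ones((1, state_dim)))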

dynamics_optimizer = optax.adam(1e-3)
policy_optimizer = optax.adam(1e-3)

dynamics_opt_state = dynamics_optimizer.init(dynamics_params)
policy_opt_state = policy_optimizer.init(policy_params)


In [None]:
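@jax.jit
def dynamics_loss(params, state, action, next_state, reward) -> jnp.ndarray:
    # Give action and reward a trailing axis of size 1 so they line up with the state vectors
    action = jnp.reshape(action, (-1, 1)).astype(jnp.float32)
    reward = jnp.reshape(reward, (-1, 1)).astype(jnp.float32)
    predicted = dynamics_model.apply(params, state, action)
    target = jnp.concatenate([next_state, reward], axis=-1)
    # Mean squared error over the predicted next state and reward
    return jnp.mean(jnp.square(predicted - target))

@jax.jit
def policy_loss(params, state, action, advantage) -> jnp.ndarray:
    # The policy network returns probabilities, so take the log directly
    # (applying log_softmax to probabilities would give the wrong distribution)
    probs = policy_model.apply(params, state)
    log_probs = jnp.log(probs + 1e-8)
    action = jnp.reshape(action, (-1, 1)).astype(jnp.int32)
    # Squeeze to shape (batch,) so the product with `advantage` does not broadcast to (batch, batch)
    action_log_probs = jnp.take_along_axis(log_probs, action, axis=-1).squeeze(-1)
    return -jnp.mean(action_log_probs * jnp.reshape(advantage, (-1,)))

dynamics_grad_fn = jax.jit(jax.grad(dynamics_loss))
policy_grad_fn = jax.jit(jax.grad(policy_loss))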


In [None]:
# 1. Define the Experience Buffer class
class ExperienceBuffer:
    def __init__(self, buffer_size):
        self.buffer_size = buffer_size
        self.buffer = []
        self.position = 0

    def add(self, state, action, reward, next_state, done):
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.buffer_size:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
            self.position = (self.position + 1) % self.buffer_size

    def sample(self, batch_size):
        idx = np.random.randint(0, len(self.buffer), size=batch_size)
        states, actions, rewards, next_states, dones = zip(*[self.buffer[i] for i in idx])
        return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)

    def __len__(self):
        return len(self.buffer)

# 2. Initialize the experience buffer
buffer = ExperienceBuffer(buffer_size)

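# 3. Collect initial data from the environment and add it to the experience buffer
# Gymnasium's reset() returns (observation, info) and step() returns five values
state, _ = env.reset()
for _ in range(buffer_size):
    action = env.action_space.sample()
    next_state, reward, terminated, truncated, _ = env.step(action)
    # Store only true termination as `done`; a time-limit truncation is not a terminal state
    buffer.add(state, action, reward, next_state, terminated)
    state = next_state if not (terminated or truncated) else env.reset()[0]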

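# 4. Perform model rollouts
def rollout_model(params, state, policy_params, rollout_steps, rng_key):
    states, actions, rewards, next_states, dones = [], [], [], [], []
    state = jnp.asarray(state, dtype=jnp.float32)
    for _ in range(rollout_steps):
        # Split the key every step; reusing one key would repeat the same random draw
        rng_key, step_key = jax.random.split(rng_key)
        action_probs = policy_model.apply(policy_params, state)
        # jax.random.categorical expects logits, so pass log-probabilities
        action = jax.random.categorical(step_key, jnp.log(action_probs + 1e-8))

        # The dynamics model takes the action with a trailing axis of size 1
        model_action = jnp.reshape(action, (1,)).astype(jnp.float32)
        prediction = dynamics_model.apply(params, state, model_action)
        next_state, reward = prediction[:state_dim], prediction[state_dim]
        done = False  # Assume the model rollouts don't reach terminal states

        states.append(np.asarray(state))
        actions.append(int(action))
        rewards.append(float(reward))
        next_states.append(np.asarray(next_state))
        dones.append(done)

        state = next_state

    return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)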


In [None]:
# 5. Training loop
rng_key = jax.random.PRNGKey(42)
for epoch in range(num_epochs):
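    # Collect real data
    state, _ = env.reset()
    for _ in range(100):
        # Split the key so each step draws a fresh action
        rng_key, action_key = jax.random.split(rng_key)
        action_probs = policy_model.apply(policy_params, jnp.asarray(state))
        action = int(jax.random.choice(action_key, action_dim, p=action_probs))
        # Gymnasium's step() returns (obs, reward, terminated, truncated, info)
        next_state, reward, terminated, truncated, _ = env.step(action)
        buffer.add(state, action, reward, next_state, terminated)
        state = next_state if not (terminated or truncated) else env.reset()[0]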

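    # Train the dynamics model
    for _ in range(num_model_updates):
        states, actions, rewards, next_states, dones = buffer.sample(batch_size)
        # Give actions and rewards a trailing axis of size 1, matching the model's input and output layout
        actions = jnp.asarray(actions, dtype=jnp.float32).reshape(-1, 1)
        rewards = jnp.asarray(rewards, dtype=jnp.float32).reshape(-1, 1)
        grads = dynamics_grad_fn(dynamics_params, states, actions, next_states, rewards)
        updates, dynamics_opt_state = dynamics_optimizer.update(grads, dynamics_opt_state)
        dynamics_params = optax.apply_updates(dynamics_params, updates)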

    # Generate model rollouts
    state, _ = env.reset()
    for _ in range(num_model_rollouts):
        rollout_rng_key, rng_key = jax.random.split(rng_key)
        states, actions, rewards, next_states, dones = rollout_model(dynamics_params, state, policy_params, rollout_steps, rollout_rng_key)

        for s, a, r, ns, d in zip(states, actions, rewards, next_states, dones):
            buffer.add(s, a, r, ns, d)
            
        state, _ = env.reset()

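    # Optimize the policy using both real and model-generated data
    # No critic is trained here: the dynamics model's reward head (its last output)
    # serves as a crude value proxy, which is a simplification compared with SAC
    for _ in range(100):
        states, actions, rewards, next_states, dones = buffer.sample(batch_size)
        states = jnp.asarray(states, dtype=jnp.float32)
        next_states = jnp.asarray(next_states, dtype=jnp.float32)
        actions = jnp.asarray(actions, dtype=jnp.int32)
        rewards = jnp.asarray(rewards, dtype=jnp.float32)
        dones = jnp.asarray(dones, dtype=jnp.float32)

        next_action_probs = policy_model.apply(policy_params, next_states)  # (batch, action_dim)
        # Predicted reward for every candidate next action, stacked to (batch, action_dim)
        next_rewards = jnp.stack([
            dynamics_model.apply(dynamics_params, next_states,
                                 jnp.full((next_states.shape[0], 1), a, dtype=jnp.float32))[:, -1]
            for a in range(action_dim)
        ], axis=-1)
        next_action_values = jnp.sum(next_action_probs * next_rewards, axis=-1)
        target_values = rewards + 0.99 * (1.0 - dones) * next_action_values
        # Baseline: the predicted reward for the action that was actually taken
        baseline = dynamics_model.apply(dynamics_params, states, actions[:, None].astype(jnp.float32))[:, -1]
        advantages = target_values - baseline

        grads = policy_grad_fn(policy_params, states, actions, advantages)
        updates, policy_opt_state = policy_optimizer.update(grads, policy_opt_state)
        policy_params = optax.apply_updates(policy_params, updates)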

