# In [None]:
import numpy as np
import tensorflow as tf
from garage import wrap_experiment
from garage.envs import GymEnv
from garage.experiment.deterministic import set_seed
from garage.sampler import LocalSampler
from garage.tf.models import GaussianMLPModel
from garage.tf.policies import CategoricalMLPPolicy
from garage.tf.algos import ModelBasedPolicyOptimization
from garage.tf.optimizers import FirstOrderOptimizer
from garage.replay_buffer import PathBuffer
from garage.sampler.default_worker import DefaultWorker
from garage.experiment.task_sampler import EnvPoolSampler


@wrap_experiment
def mb_rl_cartpole(ctxt=None, seed=42):
    """Train a model-based RL agent on CartPole-v1 using garage (TensorFlow).

    Args:
        ctxt: Experiment context supplied by ``wrap_experiment``; it is handed
            to the trainer so snapshots and logs go to the experiment dir.
        seed: Random seed for reproducibility.

    NOTE(review): ``ModelBasedPolicyOptimization`` (imported from
    ``garage.tf.algos`` at the top of this file) is not an algorithm garage's
    TF backend is known to ship, so its constructor arguments below could not
    be checked against a real signature. Confirm the import resolves to an
    implementation following garage's ``RLAlgorithm`` interface
    (``train(trainer)``) before relying on this script.
    """
    # Function-scope import keeps this block self-contained.
    from garage.trainer import TFTrainer

    # Set the random seed for reproducibility.
    set_seed(seed)

    # TFTrainer owns the TF session/graph that garage TF models and policies
    # are constructed in, and uses ``ctxt`` for snapshotting and logging.
    with TFTrainer(snapshot_config=ctxt) as trainer:
        # Pass the horizon to GymEnv instead of assigning to
        # ``env.spec.max_episode_length`` afterwards: mutating the spec does
        # not change the step limit the wrapper itself enforces.
        env = GymEnv("CartPole-v1", max_episode_length=200)

        # Dynamics model predicting the next observation. garage TF models
        # take no ``input_dim``; the input tensor (observation concatenated
        # with action) is supplied when the model is built.
        dynamics_model = GaussianMLPModel(
            output_dim=env.spec.observation_space.flat_dim,
            hidden_sizes=(64, 64),
            hidden_nonlinearity=tf.nn.relu,
            output_nonlinearity=None,
            name="DynamicsModel",
        )

        # Discrete-action policy for CartPole.
        policy = CategoricalMLPPolicy(
            env_spec=env.spec,
            hidden_sizes=(64, 64),
            hidden_nonlinearity=tf.nn.relu,
        )

        # Optimizer for fitting the dynamics model. garage's keyword for the
        # epoch count is ``max_optimization_epochs`` (``max_epochs`` is not a
        # parameter and would raise TypeError).
        dynamics_optimizer = FirstOrderOptimizer(
            learning_rate=1e-3,
            max_optimization_epochs=50,
        )

        # Replay buffer storing collected experience.
        replay_buffer = PathBuffer(capacity_in_transitions=100000)

        # ``is_tf_worker=True`` is required so sampler workers can run a TF
        # policy inside a TF session.
        sampler = LocalSampler(
            agents=policy,
            envs=env,
            max_episode_length=env.spec.max_episode_length,
            is_tf_worker=True,
        )

        algo = ModelBasedPolicyOptimization(
            env_spec=env.spec,
            dynamics_model=dynamics_model,
            policy=policy,
            dynamics_optimizer=dynamics_optimizer,
            buffer=replay_buffer,
            sampler=sampler,
            imagination_horizon=5,  # Steps simulated in the learned model.
            discount=0.99,
            entropy_regularization=1e-3,
        )

        # garage algorithms are driven by the trainer, which owns the epoch
        # loop and calls ``algo.train(trainer)``.
        trainer.setup(algo, env)
        trainer.train(n_epochs=100, batch_size=4000)


def _main():
    """Script entry point: launch the experiment with its default arguments."""
    # No arguments are passed: ``ctxt`` is left for ``wrap_experiment`` to
    # supply and the function's own default seed is used.
    mb_rl_cartpole()


if __name__ == "__main__":
    _main()
