In [None]:
!pip install -q swig
!pip install -q gymnasium[box2d]
!pip install -q stable-baselines3[extra]

In [None]:
import base64
import gymnasium as gym
import json
import numpy as np

from io import BytesIO
from PIL import Image
from stable_baselines3 import PPO
from itertools import count

In [None]:
def render_env_base64(env):
    """Render the current frame of ``env`` and return it as a base64 PNG string.

    Parameters
    ----------
    env : gymnasium.Env
        Environment whose ``render()`` returns an image array, i.e. one created
        with ``render_mode="rgb_array"``.

    Returns
    -------
    str
        The frame encoded as PNG, then base64-encoded and decoded to text so it
        can be embedded directly in the JSON log.

    Raises
    ------
    ValueError
        If ``env.render()`` returns ``None`` (typically because the environment
        was created without ``render_mode="rgb_array"``).
    """
    frame = env.render()
    if frame is None:
        # Fail with an actionable message instead of an opaque PIL error.
        raise ValueError(
            "env.render() returned None; create the environment with "
            "render_mode='rgb_array' to capture frames."
        )
    with BytesIO() as buf:
        Image.fromarray(frame).save(buf, format='PNG')
        png_bytes = buf.getvalue()
    return base64.b64encode(png_bytes).decode('utf-8')

In [None]:
# Default experiment settings. The training budget below mirrors the default
# values that MIXTAPE uses, so both systems are configured alike.
train_batch_size = 64   # environment steps per training iteration (MIXTAPE default)
iterations = 100        # number of training iterations (MIXTAPE default)
env_name = 'LunarLander-v3'

In [None]:
# Build output: the skeleton of the JSON log. `training.config` is filled in
# after training, and `inference.steps` receives one record per inference step.
output = {}
output['action_mapping'] = {
    '0': 'None',
    '1': 'Left engine',
    '2': 'Main engine',
    '3': 'Right engine',
}
output['training'] = {
    'environment': env_name,
    'algorithm': 'PPO',
    'iterations': iterations,
    'config': {},
}
output['inference'] = {
    'parallel': False,
    'config': {},
    'steps': [],
}

In [None]:
# Train the model.
# MIXTAPE trains for `iterations` iterations, each consuming `train_batch_size`
# environment steps. Stable Baselines3 has no "iteration" setting: every
# `learn()` call keeps collecting rollouts of `n_steps` transitions (2048 by
# default) until `total_timesteps` is reached, so requesting only
# `train_batch_size` timesteps would still gather a full default-size rollout.
# Setting `n_steps=train_batch_size` makes each `learn()` call below perform
# exactly one rollout of `train_batch_size` steps followed by one policy
# update, i.e. one MIXTAPE-style iteration.
model_path = f"ppo_{env_name.split('-')[0].lower()}_sb3"
train_env = gym.make(env_name)
model = PPO("MlpPolicy", train_env, n_steps=train_batch_size, verbose=0)
for _ in range(iterations):
    # With reset_num_timesteps=False, SB3 counts total_timesteps on top of the
    # steps already taken, so each call trains for train_batch_size more steps.
    model.learn(total_timesteps=train_batch_size, reset_num_timesteps=False)
model.save(model_path)
train_env.close()

# Log the policy optimizer's hyperparameters (first param group of its
# state_dict, e.g. learning rate) as the training config.
output['training']['config'] = model.get_parameters()['policy.optimizer']['param_groups'][0]

In [None]:
# Run inference with the trained policy and log one record per environment step.
# NOTE(review): `image` is rendered after env.step(), so it shows the state
# reached by `action`, while `observation_space` holds the observation the
# action was chosen from — confirm this pairing is what the log consumer expects.
env = gym.make(env_name, render_mode="rgb_array")

# Start from a fresh list so re-running this cell does not duplicate steps.
inference_steps = []
output['inference']['steps'] = inference_steps

try:
    obs, info = env.reset()
    for step_num in count():
        action, _ = model.predict(obs, deterministic=True)
        next_obs, reward, terminated, truncated, info = env.step(action)
        frame_b64 = render_env_base64(env)
        inference_steps.append({
            "number": step_num,
            "image": frame_b64,
            "agent_steps": [
                {
                    "agent": "agent_0",
                    "action": int(action),
                    "reward": float(reward),
                    "observation_space": np.array(obs).tolist()
                }
            ]
        })
        obs = next_obs
        # Stop when the episode ends (terminated) or hits its time limit (truncated).
        if terminated or truncated:
            break
finally:
    # Release rendering resources even if inference fails part-way.
    env.close()

In [None]:
# Persist the complete log (training config plus every inference step) as
# pretty-printed JSON.
with open("stable_baselines_example.json", "w") as json_file:
    json_file.write(json.dumps(output, indent=2))