## Training Agents

1. Load required libraries:

In [1]:
from stable_baselines3.ppo import CnnPolicy
from stable_baselines3 import PPO
from pettingzoo.butterfly import pistonball_v5
import supersuit as ss




2. Initialize the PettingZoo environment using the parallel API:

In [2]:
env = pistonball_v5.parallel_env(n_pistons=20, time_penalty=-0.1, 
                                 continuous=True, random_drop=True, 
                                 random_rotate=True, ball_mass=0.75, 
                                 ball_friction=0.3, ball_elasticity=1.5, 
                                 max_cycles=125)

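3. We wrap the environment with SuperSuit to reduce the 3 color channels to a single channel, since one channel is enough for this task. Note that mode='B' keeps only the blue channel rather than computing a true grayscale image: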

In [3]:
env = ss.color_reduction_v0(env, mode='B')
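
To make the effect of mode='B' concrete, here is a minimal pure-Python sketch of channel selection on a tiny 2x2 RGB frame. The helper name keep_blue_channel and the nested-list frame are illustrative assumptions, not part of SuperSuit, which operates on NumPy arrays.

```python
# Illustrative only: a 2x2 RGB frame as nested lists of (R, G, B) pixels.
frame = [
    [(255, 0, 10), (0, 255, 20)],
    [(30, 30, 30), (0, 0, 200)],
]


def keep_blue_channel(rgb_frame):
    """Return a single-channel frame holding only the blue value of each pixel."""
    return [[pixel[2] for pixel in row] for row in rgb_frame]


blue = keep_blue_channel(frame)
print(blue)  # [[10, 20], [30, 200]]
```

Selecting one channel is cheaper than averaging all three, at the cost of discarding whatever information lives only in the red and green channels.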

4. We resize each observation to 84x84 pixels:

In [4]:
env = ss.resize_v0(env, x_size=84, y_size=84)

5. We stack the 4 most recent frames together so that a single observation carries information about the ball's movement:

In [5]:
env = ss.frame_stack_v1(env, 4)
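
The idea behind frame stacking can be sketched with a fixed-length queue. The FrameStack class below is a hypothetical stand-in written for illustration. It assumes the stack starts out filled with blank frames and keeps the newest frame last, and SuperSuit's own implementation (which works on NumPy arrays) may differ in its details.

```python
from collections import deque


class FrameStack:
    """Hypothetical helper: keeps the most recent `size` frames, oldest first."""

    def __init__(self, size, blank_frame):
        # Start with blank frames so the stack always holds `size` entries.
        self.frames = deque([blank_frame] * size, maxlen=size)

    def push(self, frame):
        # Appending to a full deque drops the oldest frame automatically.
        self.frames.append(frame)
        return list(self.frames)


stack = FrameStack(size=4, blank_frame="blank")
for t in range(1, 6):
    stacked = stack.push(f"frame{t}")

print(stacked)  # ['frame2', 'frame3', 'frame4', 'frame5']
```

A single frame shows where the ball is, while several consecutive frames also reveal which way it is moving and how fast.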

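6. We convert the environment's API into a vector environment in which each agent is treated as a separate sub-environment. This causes Stable Baselines to share the parameters of a single policy network across all agents of the multi-agent environment: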

In [6]:
env = ss.pettingzoo_env_to_vec_env_v1(env)
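
Parameter sharing can be pictured as one policy function applied to every agent's observation. The sketch below is a toy illustration under that assumption: shared_policy and act_for_all_agents are hypothetical names, and the scalar observations stand in for the stacked image frames used in this tutorial.

```python
def shared_policy(observation):
    """Stand-in for the single network shared by all agents (hypothetical)."""
    return observation * 0.5  # any deterministic function of the observation


def act_for_all_agents(obs_by_agent):
    # Flatten the per-agent dict into a batch, as a vector environment would.
    agents = list(obs_by_agent)
    batch = [obs_by_agent[agent] for agent in agents]
    # The same policy (same parameters) is applied to every row of the batch.
    actions = [shared_policy(obs) for obs in batch]
    # Map the batched actions back to the agents they belong to.
    return dict(zip(agents, actions))


observations = {"piston_0": 0.2, "piston_1": -1.0, "piston_2": 0.8}
print(act_for_all_agents(observations))
# {'piston_0': 0.1, 'piston_1': -0.5, 'piston_2': 0.4}
```

Because every agent contributes experience to the same set of parameters, a single-agent algorithm such as PPO can be used unchanged.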

7. We concatenate 8 copies of the vector environment and run them across 4 CPUs, which enables parallel training:

In [7]:
# 8 copies of the environment, run on 4 CPUs
env = ss.concat_vec_envs_v1(env, 8, num_cpus=4, base_class='stable_baselines3')
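
As a quick sanity check, the number of sub-environments that PPO sees can be worked out by hand. This sketch assumes that, after the conversion in step 6, each of the 20 pistons counts as one sub-environment.

```python
n_pistons = 20   # agents per Pistonball environment (step 2)
num_copies = 8   # second argument to concat_vec_envs_v1 (step 7)

# Assumption: one sub-environment per agent after the vector-env conversion.
envs_per_copy = n_pistons
total_envs = envs_per_copy * num_copies
print(total_envs)  # 160
```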

8. Finally, we train the agent with PPO and save the resulting policy:

In [8]:
model = PPO(CnnPolicy, env, verbose=3, gamma=0.95, n_steps=256,
            ent_coef=0.0905168, learning_rate=0.0006221, vf_coef=0.042202,
            max_grad_norm=0.9, gae_lambda=0.99, n_epochs=5, clip_range=0.3,
            batch_size=256)

model.learn(total_timesteps=2_000_000)

model.save('policy')

Using cuda device
Wrapping the env in a VecTransposeImage.
------------------------------
| time/              |       |
|    fps             | 2877  |
|    iterations      | 1     |
|    time_elapsed    | 14    |
|    total_timesteps | 40960 |
------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1994        |
|    iterations           | 2           |
|    time_elapsed         | 41          |
|    total_timesteps      | 81920       |
| train/                  |             |
|    approx_kl            | 0.011979329 |
|    clip_fraction        | 0.0433      |
|    clip_range           | 0.3         |
|    entropy_loss         | -1.51       |
|    explained_variance   | 0.051       |
|    learning_rate        | 0.000622    |
|    loss                 | 0.11        |
|    n_updates            | 5           |
|    policy_gradient_loss | 0.00314     |
|    std                  | 1.12        |
...

-----------------------------------------
| time/                   |             |
|    fps                  | 1570        |
|    iterations           | 12          |
|    time_elapsed         | 313         |
|    total_timesteps      | 491520      |
| train/                  |             |
|    approx_kl            | 0.009718051 |
|    clip_fraction        | 0.0391      |
|    clip_range           | 0.3         |
|    entropy_loss         | -2.5        |
|    explained_variance   | -0.0135     |
|    learning_rate        | 0.000622    |
|    loss                 | 1.45        |
|    n_updates            | 55          |
|    policy_gradient_loss | 0.00121     |
|    std                  | 2.99        |
|    value_loss           | 44.7        |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1563        |
|    iterations           | 13          |
|    time_elapsed         | 340         |
...

-----------------------------------------
| time/                   |             |
|    fps                  | 1537        |
|    iterations           | 23          |
|    time_elapsed         | 612         |
|    total_timesteps      | 942080      |
| train/                  |             |
|    approx_kl            | 0.009103665 |
|    clip_fraction        | 0.0356      |
|    clip_range           | 0.3         |
|    entropy_loss         | -3.41       |
|    explained_variance   | 0.0444      |
|    learning_rate        | 0.000622    |
|    loss                 | 2.85        |
|    n_updates            | 110         |
|    policy_gradient_loss | 0.00222     |
|    std                  | 7.41        |
|    value_loss           | 71.6        |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1536        |
|    iterations           | 24          |
|    time_elapsed         | 639         |
...

------------------------------------------
| time/                   |              |
|    fps                  | 1526         |
|    iterations           | 34           |
|    time_elapsed         | 912          |
|    total_timesteps      | 1392640      |
| train/                  |              |
|    approx_kl            | 0.0072430028 |
|    clip_fraction        | 0.0311       |
|    clip_range           | 0.3          |
|    entropy_loss         | -4.24        |
|    explained_variance   | 0.0107       |
|    learning_rate        | 0.000622     |
|    loss                 | 2.4          |
|    n_updates            | 165          |
|    policy_gradient_loss | 0.00218      |
|    std                  | 16.9         |
|    value_loss           | 60.3         |
------------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1525        |
|    iterations           | 35          |
...

-----------------------------------------
| time/                   |             |
|    fps                  | 1519        |
|    iterations           | 45          |
|    time_elapsed         | 1212        |
|    total_timesteps      | 1843200     |
| train/                  |             |
|    approx_kl            | 0.007270542 |
|    clip_fraction        | 0.0298      |
|    clip_range           | 0.3         |
|    entropy_loss         | -5.12       |
|    explained_variance   | -0.00946    |
|    learning_rate        | 0.000622    |
|    loss                 | 2.02        |
|    n_updates            | 220         |
|    policy_gradient_loss | 0.00163     |
|    std                  | 41          |
|    value_loss           | 64.1        |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 1519        |
|    iterations           | 46          |
|    time_elapsed         | 1240        |
...
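
The numbers in the log can be cross-checked against the hyperparameters passed to PPO. The sketch below is plain arithmetic on values taken from the code and the output above. The helper n_updates_logged is a hypothetical name, and its formula is inferred from the logged values rather than taken from the Stable Baselines source.

```python
import math

# Values from the PPO call in step 8.
n_steps = 256
batch_size = 256
n_epochs = 5
total_timesteps = 2_000_000

# From the log: total_timesteps after the first iteration.
steps_per_iteration = 40960

# Each sub-environment contributes n_steps transitions per iteration.
num_envs = steps_per_iteration // n_steps
minibatches_per_epoch = steps_per_iteration // batch_size
iterations_needed = math.ceil(total_timesteps / steps_per_iteration)


def n_updates_logged(iteration):
    # Inferred from the log: the train/ block printed with iteration k
    # reflects the k - 1 updates completed so far, each with n_epochs epochs.
    return n_epochs * (iteration - 1)


print(num_envs)               # 160
print(minibatches_per_epoch)  # 160
print(iterations_needed)      # 49
print([n_updates_logged(i) for i in (2, 12, 23, 34, 45)])
# [5, 55, 110, 165, 220]
```

The 160 sub-environments are consistent with 20 pistons in each of 8 environment copies, and the computed n_updates values match those shown for iterations 2, 12, 23, 34 and 45.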

## Watching Agents Play

Now that the model is trained and saved, we can load the policy and watch it play!

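We reinstantiate the environment, this time with the AEC API (env() instead of parallel_env()), and apply the same preprocessing wrappers used during training so that the observations match what the policy was trained on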

In [10]:
env = pistonball_v5.env()
env = ss.color_reduction_v0(env, mode='B')
env = ss.resize_v0(env, x_size=84, y_size=84)
env = ss.frame_stack_v1(env, 4)
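
To see what the loaded policy will receive, the observation shape can be traced through the three wrappers. This is a sketch under stated assumptions: the raw Pistonball observation is taken to be 457x120 pixels with 3 channels, and the three functions are hypothetical shape-only stand-ins for the SuperSuit wrappers, not their real implementations.

```python
def color_reduction(shape):
    # Drop the channel axis: (H, W, 3) -> (H, W).
    height, width, _channels = shape
    return (height, width)


def resize(shape, x_size, y_size):
    # Assumption: the result has y_size rows and x_size columns.
    return (y_size, x_size)


def frame_stack(shape, num_frames):
    # Assumption: 2D frames are stacked along a new trailing axis.
    return shape + (num_frames,)


raw_shape = (457, 120, 3)  # assumed raw observation shape (H, W, C)
shape = color_reduction(raw_shape)
shape = resize(shape, x_size=84, y_size=84)
shape = frame_stack(shape, 4)
print(shape)  # (84, 84, 4)
```

The "Wrapping the env in a VecTransposeImage" line in the training output indicates that Stable Baselines moves the channel axis to the front before the observation reaches the CNN.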

We load the saved policy

In [11]:
model = PPO.load('policy')

Now we can run the policy and render each step. An agent that is done must be stepped with None instead of an action

In [12]:
env.reset()
for agent in env.agent_iter():
    obs, reward, done, info = env.last()
    act = model.predict(obs, deterministic=True)[0] if not done else None
    env.step(act)
    env.render()

Finally, we close the environment.

In [13]:
env.close()