In [1]:
import os
import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.evaluation import evaluate_policy


In [2]:
environment_name = 'CartPole-v1'
env = gym.make(environment_name, render_mode="human")


In [3]:
episodes = 5
for episode in range(1,episodes+1):
    state = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action = env.action_space.sample()
        n_state, reward, terminated, truncated, info = env.step(action)
        score += reward

        done = terminated or truncated
    print('Episode:{} Score:{}'.format(episode, score))
env.close()

Episode:1 Score:11.0
Episode:2 Score:25.0
Episode:3 Score:28.0
Episode:4 Score:10.0
Episode:5 Score:12.0


In [4]:
env.action_space

Discrete(2)

In [5]:
env.action_space.sample()

0

In [6]:
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [7]:
env.observation_space.sample()


array([ 3.5075727e-01, -1.7252343e+38, -3.6350441e-01,  1.0343658e+38],
      dtype=float32)

In [8]:
log_path = os.path.join('Training', 'Logs')

In [9]:
log_path

'Training\\Logs'

In [10]:
env = gym.make(environment_name)
env = DummyVecEnv([lambda: env])
model = PPO('MlpPolicy', env, verbose=1, tensorboard_log=log_path)

Using cpu device


In [11]:
model.learn(total_timesteps=20000)

Logging to Training\Logs\PPO_24
-----------------------------
| time/              |      |
|    fps             | 1522 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------
------------------------------------------
| time/                   |              |
|    fps                  | 787          |
|    iterations           | 2            |
|    time_elapsed         | 5            |
|    total_timesteps      | 4096         |
| train/                  |              |
|    approx_kl            | 0.0074076606 |
|    clip_fraction        | 0.0844       |
|    clip_range           | 0.2          |
|    entropy_loss         | -0.687       |
|    explained_variance   | -0.000131    |
|    learning_rate        | 0.0003       |
|    loss                 | 8.32         |
|    n_updates            | 10           |
|    policy_gradient_loss | -0.0138      |
|    value_loss           | 59.8         |
---------------------------

<stable_baselines3.ppo.ppo.PPO at 0x286add4af50>

In [12]:
PPO_Path = os.path.join('Training', 'Saved Models', 'PPO_Model_CartPole')

In [13]:
model.save(PPO_Path)

In [14]:
PPO_Path

'Training\\Saved Models\\PPO_Model_CartPole'

In [15]:
model.learn(total_timesteps=1000)

Logging to Training\Logs\PPO_25
-----------------------------
| time/              |      |
|    fps             | 1257 |
|    iterations      | 1    |
|    time_elapsed    | 1    |
|    total_timesteps | 2048 |
-----------------------------


<stable_baselines3.ppo.ppo.PPO at 0x286add4af50>

In [16]:
model = PPO.load(PPO_Path, env=env)

## Evaluation ##

In [17]:
evaluate_policy(model, env, n_eval_episodes=10, render=True)



(500.0, 0.0)

In [18]:
env.close()

In [19]:
episodes = 5
for episode in range(1,episodes+1):
    obs = env.reset()
    done = False
    score = 0

    while not done:
        env.render()
        action, _ = model.predict(obs) #now using model
        obs, reward, done, info = env.step(action)
        score += reward

        # done = terminated or truncated
    print('Episode:{} Score:{}'.format(episode, score))
#env.close()

Episode:1 Score:[348.]
Episode:2 Score:[419.]
Episode:3 Score:[500.]
Episode:4 Score:[480.]
Episode:5 Score:[345.]


In [20]:
obs = env.reset()

In [21]:
model.predict(obs)

(array([1], dtype=int64), None)

In [22]:
action, _ = model.predict(obs)

In [23]:
env.action_space.sample()

1

In [24]:
env.step(action)

(array([[ 0.02470808, -0.21190622, -0.01354875,  0.26632926]],
       dtype=float32),
 array([1.], dtype=float32),
 array([False]),
 [{'TimeLimit.truncated': False}])