In [17]:
import os

import gymnasium as gym
import imageio
import matplotlib.pyplot as plt
from IPython import display
from ray.rllib.algorithms.dqn.dqn import DQNConfig
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.td3.td3 import TD3Config #TD3 only valid for continuous action spaces
from ray.tune.registry import register_env

import grid_world as _
from grid_world.envs.grid_world_env_v2 import GridWorldEnv_v2
from grid_world.envs.grid_world_env_v3 import GridWorldEnv_v3

In [2]:
# Experiment settings: edit these four values to change the run.
ENV_VERSION = 'grid_world/GridWorld-v3'  # must be one of the keys of env_classes
RENDER_MODE = None
ENV_SHAPE = (25, 25)
ALGORITHM = 'PPO'

# Environment id -> class implementing that version of the environment.
env_classes = {
    'grid_world/GridWorld-v2': GridWorldEnv_v2,
    'grid_world/GridWorld-v3': GridWorldEnv_v3,
}

# Everything the environment factory needs, bundled into a single dict.
env_config = dict(
    env_class=env_classes[ENV_VERSION],
    env_version=ENV_VERSION,
    render_mode=RENDER_MODE,
    shape=ENV_SHAPE,
)

In [3]:
# Sanity check: instantiate the environment with the configured grid shape
# (previously `shape` was omitted here, so this cell inspected a default-sized
# environment instead of the one described by env_config) and look at its
# observation space and the shape of an initial observation.
env = gym.make(env_config['env_version'], shape=env_config['shape'])
observation, info = env.reset()
print(env.observation_space)
env.close()
print(observation.shape)

Box(0, 10, (4,), int64)
(4,)


In [5]:
# def env_creator(env_config):

#     env = gym.make(
#         ENV_VERSION, 
#         env_config['render_mode'], 
#         env_config['shape']
#         )

#     return env

In [4]:
def env_creator(env_config):
    """Instantiate the environment class described by ``env_config``.

    Args:
        env_config: Mapping providing ``'env_class'`` (a callable),
            ``'render_mode'`` and ``'shape'``.

    Returns:
        The object produced by calling ``env_config['env_class']`` with the
        render mode and the shape as positional arguments, in that order.
    """
    env_class = env_config['env_class']
    render_mode, shape = env_config['render_mode'], env_config['shape']
    return env_class(render_mode, shape)

In [5]:
# Register the env_creator factory in Ray Tune's registry under the environment's
# version string; the same string is later handed to config.environment(...).
register_env(env_config['env_version'], env_creator)

In [6]:
# Supported algorithm names and the RLlib config class each one maps to.
# (TD3 is only valid for continuous action spaces.)
ALGORITHM_CONFIGS = {
    'DQN': DQNConfig,
    'PPO': PPOConfig,
    'TD3': TD3Config,
}

# Fail early with a clear message: an unrecognised name used to leave `config`
# undefined and only surfaced later as a NameError.
if ALGORITHM not in ALGORITHM_CONFIGS:
    raise ValueError(
        'Unsupported ALGORITHM {!r}; expected one of {}'.format(
            ALGORITHM, sorted(ALGORITHM_CONFIGS)
        )
    )

config = ALGORITHM_CONFIGS[ALGORITHM]()

# Define the environment
config = config.environment(env_config['env_version'])

# Set the environment configuration
config.env_config.update(env_config)

# Give access to the gpu
config = config.resources(num_gpus=1)

# Set the number of training roll out workers
config = config.rollouts(num_rollout_workers=4)

# Set the max number of episode steps
# NOTE(review): this is assigned as a plain attribute rather than through a
# config method -- confirm the installed RLlib version still honours it.
config.horizon = 30

# Set the number of evaluation workers
# config = config.evaluation(evaluation_num_workers=1)

# Set framework to pytorch
config = config.framework('torch')

# DQN-only settings: replay buffer parameters plus the noisy / double-Q /
# dueling training options.
if ALGORITHM == 'DQN':
    config.replay_buffer_config.update(
        {
            "capacity": 60000,
            "prioritized_replay_alpha": 0.5,
            "prioritized_replay_beta": 0.5,
            "prioritized_replay_eps": 3e-6,
            "prioritized_replay": True
        }
        )

    # Set the training configuration
    config = config.training(
        noisy=True,
        double_q=True,
        dueling=True
        )

In [7]:
# Echo the resolved settings; the replay buffer section is only printed for DQN.
if ALGORITHM == 'DQN':
    print(config.replay_buffer_config)
print(config.exploration_config, config.env_config, sep='\n')


{'type': 'StochasticSampling'}
{'env_class': <class 'grid_world.envs.grid_world_env_v3.GridWorldEnv_v3'>, 'env_version': 'grid_world/GridWorld-v3', 'render_mode': None, 'shape': (25, 25)}


In [8]:
# Build the algorithm object from the finished config (the captured log output
# of this cell shows a local Ray instance being started).
algo = config.build()

2023-06-20 14:24:59,815	INFO worker.py:1553 -- Started a local Ray instance.


In [9]:
def plot_metrics(rewards, lengths):
    """Redraw the two-panel metrics figure in the notebook output.

    Figure number 1 is reused and cleared on every call. The left panel plots
    ``rewards`` and the right panel plots ``lengths``; the figure is then
    displayed and the cell output is marked to be cleared when the next
    output arrives.

    Args:
        rewards: Sequence of values plotted under 'Mean Rewards'.
        lengths: Sequence of values plotted under 'Mean Episode Length'.
    """
    figure = plt.figure(1, figsize=(16, 8))
    plt.clf()

    # (subplot position, title, y-axis label, data) for each panel.
    panels = (
        (1, 'Mean Rewards', 'Mean Reward', rewards),
        (2, 'Mean Episode Length', 'Episode Length', lengths),
    )
    for position, title, y_label, series in panels:
        axes = figure.add_subplot(1, 2, position)
        axes.set_title(title)
        axes.set_xlabel('Evaluation Interval')
        axes.set_ylabel(y_label)
        axes.plot(series)

    display.display(plt.gcf())
    display.clear_output(wait=True)

In [10]:
# Two independent, initially empty lists for the reward and episode-length series.
rewards, lengths = [], []

In [11]:
# List every weight entry of the policy together with its shape.
policy = algo.get_policy()
weights = policy.get_weights()
for name, tensor in weights.items():
    print(name, tensor.shape, sep='\n')

_logits._model.0.weight
(4, 256)
_logits._model.0.bias
(4,)
_hidden_layers.0._model.0.weight
(256, 4)
_hidden_layers.0._model.0.bias
(256,)
_hidden_layers.1._model.0.weight
(256, 256)
_hidden_layers.1._model.0.bias
(256,)
_value_branch_separate.0._model.0.weight
(256, 4)
_value_branch_separate.0._model.0.bias
(256,)
_value_branch_separate.1._model.0.weight
(256, 256)
_value_branch_separate.1._model.0.bias
(256,)
_value_branch._model.0.weight
(1, 256)
_value_branch._model.0.bias
(1,)


In [12]:
TRAINING_ITERATIONS = 100
EVAL_INTERVAL = 5


# Count from 1 so the progress check and the message read straight off the counter.
for iteration in range(1, TRAINING_ITERATIONS + 1):

    algo.train()

    if iteration % EVAL_INTERVAL != 0:
        continue

    print('Completed {} training intervals'.format(iteration))
    # Evaluation and live plotting are currently disabled:
    # metrics = algo.evaluate()['evaluation']
    # rewards.append(metrics['episode_reward_mean'])
    # lengths.append(metrics['episode_len_mean'])
    # plot_metrics(rewards, lengths)

Completed 5 training intervals
Completed 10 training intervals
Completed 15 training intervals
Completed 20 training intervals
Completed 25 training intervals
Completed 30 training intervals
Completed 35 training intervals
Completed 40 training intervals
Completed 45 training intervals
Completed 50 training intervals
Completed 55 training intervals
Completed 60 training intervals
Completed 65 training intervals
Completed 70 training intervals
Completed 75 training intervals
Completed 80 training intervals
Completed 85 training intervals
Completed 90 training intervals
Completed 95 training intervals
Completed 100 training intervals


In [13]:
# Create an environment with the configured shape and keep the initial
# observation from reset(); the environment itself is closed right away.
env = gym.make(env_config['env_version'], shape=env_config['shape'])
observation, info = env.reset()
env.close()

In [14]:
# Show the shape of the observation obtained from env.reset().
observation.shape

(4,)

In [15]:
# Ask the algorithm for a single action for the current observation and show it.
action = algo.compute_single_action(observation)
print(action)

2


In [16]:
# Watch the policy act for 50 steps with on-screen ("human") rendering.
env = gym.make(env_config['env_version'], render_mode="human", shape=env_config['shape'])
try:
    observation, info = env.reset(seed=42)

    for _ in range(50):
        action = algo.compute_single_action(observation)
        observation, reward, terminated, truncated, info = env.step(action)

        # Start a new episode as soon as the current one ends.
        if terminated or truncated:
            observation, info = env.reset()
finally:
    # Always release the environment, even if a step raised part-way through.
    env.close()

In [18]:
# Roll the policy out for 50 steps with "rgb_array" rendering, collecting one
# frame per step (plus the initial frame), then save the frames as a GIF.
env = gym.make(env_config['env_version'], render_mode="rgb_array", shape=env_config['shape'])
images = []
try:
    observation, info = env.reset(seed=42)
    images.append(env.render())

    for _ in range(50):
        action = algo.compute_single_action(observation)
        observation, reward, terminated, truncated, info = env.step(action)
        images.append(env.render())

        # Start a new episode as soon as the current one ends.
        if terminated or truncated:
            observation, info = env.reset()
finally:
    # Always release the environment, even if a step raised part-way through.
    env.close()

# Derive the file name from the run settings so it cannot drift from them
# (for shape (25, 25) and 'PPO' this is './gifs/25-25-PPO.gif'), and make sure
# the output directory exists before writing.
gif_dir = './gifs'
os.makedirs(gif_dir, exist_ok=True)
gif_name = '{}-{}.gif'.format('-'.join(str(dim) for dim in env_config['shape']), ALGORITHM)
imageio.mimsave('{}/{}'.format(gif_dir, gif_name), images, fps = 5)