In [2]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from wrappers.joint_MPD_env import JointMDPEnv
from policies.Q_learner import Q_learner

from environments.road_env import RoadEnvironment
from environments.config.environment_presets import small_environment_dict, smallest_environment_dict

### Define the environment

In [4]:
# Build the road environment from `smallest_environment_dict` and wrap it in
# JointMDPEnv in a single step. To try the other imported preset, pass
# `small_environment_dict` to RoadEnvironment instead.
env = JointMDPEnv(RoadEnvironment(**smallest_environment_dict))

num_joint_states: 336
num_joint_actions: 9


## Baseline: Do-nothing action

In [5]:
# Baseline rollout: play NUM_EPISODES episodes, always sending action 0
# (the do-nothing action), and store the summed reward of each episode.
NUM_EPISODES = 1_000

store_do_nothing_rewards = np.zeros(NUM_EPISODES)

for episode in range(NUM_EPISODES):
    _ = env.reset()
    action = 0  # do-nothing actions
    total_reward = 0

    # Keep stepping until the environment reports the episode is done.
    while True:
        _, reward, done, info = env.step(action)
        total_reward += reward
        if done:
            break

    store_do_nothing_rewards[episode] = total_reward

mean_do_nothing_reward = np.mean(store_do_nothing_rewards)
print(f'Mean reward over {NUM_EPISODES} episodes: {mean_do_nothing_reward}')

Mean reward over 1000 episodes: -606.6205120000001


## Q Learning

In [6]:
# Q-learning hyperparameters, gathered in one mapping so they are easy to tune.
q_learning_config = dict(
    num_episodes=20_000,
    discount_factor=1,
    lr_start=1,
    lr_end=0.3,
    epsilon_start=0.1,
    epsilon_end=0.01,
)

q_learning_agent = Q_learner(env, **q_learning_config)

# train() hands back a (Q_values, policy) pair.
Q_values, policy = q_learning_agent.train(verbose=False)

## Inference with optimal policy

In [7]:
# Evaluate the trained policy: play num_inference episodes, picking every
# action with policy(state), and store the summed reward of each episode.
num_inference = 1_000
store_rewards = np.zeros(num_inference)

for ep in range(num_inference):
    state = env.reset()
    total_reward = 0

    # Keep stepping until the environment reports the episode is done.
    while True:
        action = policy(state)
        state, reward, done, info = env.step(action)
        total_reward += reward
        if done:
            break

    store_rewards[ep] = total_reward

In [8]:
# Summarise the evaluation: mean episode reward under the trained policy and
# its difference from the do-nothing baseline mean.
mean_policy_reward = np.mean(store_rewards)
mean_baseline_reward = np.mean(store_do_nothing_rewards)

print(f'Mean reward over {num_inference} episodes: {mean_policy_reward:.3f}')
print(f'Improvement: {mean_policy_reward - mean_baseline_reward:.3f}')

Mean reward over 1000 episodes: -121.158
Improvement: 485.463


## Optimal Policy

In [9]:
# Print one row per entry of env.joint_state_space: its index, the state itself,
# and the decoded action that the policy returns for that index.
print("State ID: State | Optimal action(s) \n")

for idx, state in enumerate(env.joint_state_space):
    encoded_action = policy(idx)
    decoded_action = tuple(env.decode_action(encoded_action))
    print(f"{idx:3d}: {state} | {decoded_action}")

State ID: State | Optimal action(s) 

  0: (0, (0, 0)) | ([0], [0])
  1: (0, (0, 1)) | ([2], [0])
  2: (0, (0, 2)) | ([2], [3])
  3: (0, (0, 3)) | ([2], [0])
  4: (0, (1, 0)) | ([0], [0])
  5: (0, (1, 1)) | ([0], [3])
  6: (0, (1, 2)) | ([2], [3])
  7: (0, (1, 3)) | ([0], [3])
  8: (0, (2, 0)) | ([2], [2])
  9: (0, (2, 1)) | ([2], [0])
 10: (0, (2, 2)) | ([0], [0])
 11: (0, (2, 3)) | ([3], [3])
 12: (0, (3, 0)) | ([3], [0])
 13: (0, (3, 1)) | ([0], [0])
 14: (0, (3, 2)) | ([3], [0])
 15: (0, (3, 3)) | ([3], [3])
 16: (1, (0, 0)) | ([0], [0])
 17: (1, (0, 1)) | ([0], [2])
 18: (1, (0, 2)) | ([3], [0])
 19: (1, (0, 3)) | ([2], [0])
 20: (1, (1, 0)) | ([0], [0])
 21: (1, (1, 1)) | ([2], [2])
 22: (1, (1, 2)) | ([3], [2])
 23: (1, (1, 3)) | ([0], [0])
 24: (1, (2, 0)) | ([2], [0])
 25: (1, (2, 1)) | ([2], [2])
 26: (1, (2, 2)) | ([3], [3])
 27: (1, (2, 3)) | ([2], [3])
 28: (1, (3, 0)) | ([0], [3])
 29: (1, (3, 1)) | ([3], [3])
 30: (1, (3, 2)) | ([0], [3])
 31: (1, (3, 3)) | ([0], [3])
 3