In [None]:
# Install the highway-env simulator into this kernel's environment (%pip, not !pip).
# NOTE(review): unpinned — pin the highway-env version this notebook was developed against for reproducibility.
%pip install highway-env

In [None]:
import os

import gymnasium as gym
import highway_env
import numpy as np
import matplotlib.pyplot as plt

# Driving environment: discrete meta-actions, a 4-cell lidar observation and
# 4 other vehicles (presumably chosen to keep the tabular state space small).
# With render_mode="rgb_array", env.render() returns a frame instead of drawing.
env = gym.make(
    "highway-fast-v0",
    render_mode="rgb_array",
    config={
        "action": {"type": "DiscreteMetaAction"},
        "observation": {"type": "LidarObservation", "cells": 4},
        "vehicles_count": 4,
    },
)
obj, info = env.reset()

# --- Q-learning hyperparameters -------------------------------------------
learning_rate = 0.2  # alpha: step size of each Q-value update
discount = 0.95      # gamma: weight of the best next-state Q-value in the target
epochs = 10          # outer passes; each pass replays the same episode seeds
episodes = 100       # episodes per epoch (training seeds 0 .. episodes-1)
noise_eps = 1        # epsilon-greedy rate; 1 means every training action is random
accuracy = 1         # decimals kept when rounding observations into Q-table keys

# Per-episode training metric, appended to by the training loop.
rewards = []

obj

In [None]:
import pickle

# Load q-table from file else create a new one
# Resume from a saved Q-table if one exists; otherwise start with an empty one.
# NOTE: pickle.load can execute arbitrary code — only load files this notebook wrote.
qtable = dict()
try:
    with open("qtable.pickle", "rb") as f:
        qtable = pickle.load(f)
except FileNotFoundError:
    pass  # first run: nothing has been saved yet
except (EOFError, pickle.UnpicklingError) as exc:
    # A truncated or corrupt checkpoint (e.g. from an interrupted save): start
    # fresh, but say so instead of silently discarding it. Any other error is
    # a real problem and is left to propagate.
    print(f"Could not load qtable.pickle ({exc!r}); starting with an empty Q-table.")

# Show only the size; the full table can hold thousands of rows.
len(qtable)

In [None]:
def generalizer(state):
    """Coarsen an observation into a hashable-ish, order-insensitive form.

    Every entry is rounded to ``accuracy`` decimals; then each column is sorted
    independently along axis 0, and the result is split into a tuple over its
    first axis.

    NOTE(review): for a 2-D lidar observation (presumably shape (cells, 2)),
    sorting columns independently discards which cell (direction) a reading
    came from and decouples each distance from its velocity. Confirm this
    state aliasing is intended.
    """
    rounded = np.round(state, accuracy)
    return tuple(np.sort(rounded, axis=0))

def hash_state(state):
    """Return the string form of ``generalizer(state)``; this is the Q-table key."""
    key = generalizer(state)
    return str(key)

def update(state, action, nextState, reward):
    """Apply one tabular Q-learning update for (state, action) -> nextState.

    Q(s, a) <- Q(s, a) + learning_rate * (reward + discount * max_a' Q(s', a') - Q(s, a))

    Both keys are first given a zero row of length ``env.action_space.n`` if
    unseen. Reads the globals ``qtable``, ``learning_rate`` and ``discount``.

    NOTE(review): terminal transitions are not special-cased — the target
    always bootstraps from ``nextState``.
    """
    n_actions = env.action_space.n
    src_key = hash_state(state)
    dst_key = hash_state(nextState)
    for key in (src_key, dst_key):
        if key not in qtable:
            qtable[key] = np.zeros(n_actions)
    q_row = qtable[src_key]
    td_target = reward + discount * max(qtable[dst_key])
    q_row[action] = q_row[action] + learning_rate * (td_target - q_row[action])

def getQ(state, action):
    """Return Q(state, action), inserting a zero row for an unseen state first."""
    key = hash_state(state)
    row = qtable.setdefault(key, np.zeros(env.action_space.n))
    return row[action]

def _bestVA(state):
    """Return ``(best_action, best_value)`` for ``state`` under the current Q-table.

    An unseen state first gets a zero row of length ``env.action_space.n``.
    All actions ``0 .. n-1`` are candidates. Ties go to the lowest action
    index, so an unseen (all-zero) state yields action 0.
    """
    hstate = hash_state(state)
    if hstate not in qtable:
        qtable[hstate] = np.zeros(env.action_space.n)
    q_row = qtable[hstate]
    # Consider every action, including 0. Iterating range(1, n) made action 0
    # (presumably LANE_LEFT for DiscreteMetaAction — confirm) impossible to
    # choose greedily, however high its Q-value.
    best_action = int(np.argmax(q_row[: env.action_space.n]))
    return best_action, q_row[best_action]

def getValue(state):
    """Return the value half of ``_bestVA(state)``: the best Q-value for ``state``."""
    _, value = _bestVA(state)
    return value

def getPolicy(state):
    """Return the greedy action for ``state`` (the action half of ``_bestVA``)."""
    action, _ = _bestVA(state)
    return action

def getAction(state):
    """Epsilon-greedy action selection.

    With probability ``noise_eps``, return ``env.action_space.sample()``;
    otherwise return the greedy ``getPolicy(state)``.
    """
    explore = np.random.rand() < noise_eps
    return env.action_space.sample() if explore else getPolicy(state)

In [None]:
# Train the agent with tabular Q-learning. Actions come from getAction
# (epsilon-greedy with noise_eps), and every transition updates the Q-table.
for epoch in range(epochs):
    for episode in range(episodes):
        # Seeding by episode index makes each epoch replay the same scenarios.
        obs, info = env.reset(seed=episode)
        done = False
        truncated = False
        episode_return = 0.0
        while not (done or truncated):
            action = getAction(obs)
            next_obs, reward, done, truncated, info = env.step(action)
            # Learn from the observation the action was actually taken in.
            # Using the setup cell's `obj` here would credit every transition
            # to the initial reset state.
            update(obs, action, next_obs, reward)
            episode_return += reward
            obs = next_obs

        # Track the total reward of the episode, not just its final step.
        rewards.append(episode_return)
        print("Epoch: ", epoch, "\tEpisode: ", episode, "\tReturn: ", episode_return)

        # Checkpoint atomically: write a temp file, then swap it in, so an
        # interrupted save never leaves a truncated qtable.pickle behind.
        with open("qtable.pickle.tmp", "wb") as f:
            pickle.dump(qtable, f)
        os.replace("qtable.pickle.tmp", "qtable.pickle")

fig, ax = plt.subplots()
ax.plot(rewards)
ax.set(title="Training return per episode", xlabel="Episode (all epochs)", ylabel="Return")
plt.show()

In [None]:
# Test the agent: run the greedy policy (getPolicy, no exploration) and record
# the total reward of each episode.
test_runs = 50
# Training seeds 0 .. episodes-1; evaluate on seeds it never saw so the score
# reflects generalisation rather than replay of training scenarios.
test_seed_offset = episodes
test_rewards = []
for i in range(test_runs):
    obs, info = env.reset(seed=test_seed_offset + i)
    done = False
    truncated = False
    episode_return = 0.0
    while not (done or truncated):
        action = getPolicy(obs)
        obs, reward, done, truncated, info = env.step(action)
        episode_return += reward
    # env.render() is deliberately not called: with render_mode="rgb_array"
    # it returns a frame rather than displaying one, so per-step calls only cost time.
    test_rewards.append(episode_return)

fig, ax = plt.subplots()
ax.plot(test_rewards)
ax.set(title="Greedy-policy return per test episode", xlabel="Test episode", ylabel="Return")
plt.show()

np.mean(test_rewards)