# Q-Learning

## Q-Learning Agent

In [1]:
import gymnasium as gym
import json
import random
import highway_env

import numpy as np
from collections import defaultdict
import os
from tqdm import tqdm


import sys
sys.path.append(os.path.abspath('..'))
from metrics import Metrics


class QLearningAgent:
    """Tabular epsilon-greedy Q-learning agent for an environment with a discrete action space.

    Observations are converted to Q-table keys with ``str(obs)``. Actions are
    stored as decimal-string keys ``"0" .. str(n - 1)``, so the table can be
    round-tripped through JSON at ``q_table.json`` (relative to the working
    directory). Any existing table at that path is loaded on construction.

    Recognised ``params`` keys (all optional):
        exploration_rate (float): initial epsilon (default 0.3).
        exploration_decay (float): multiplicative epsilon decay per episode (default 0.995).
        min_exploration_rate (float): lower bound for epsilon (default 0.01).
        learning_rate (float): TD step size alpha (default 0.1).
        gamma (float): discount factor (default 0.9).
        episode_num (int): number of training episodes (default 100).
        use_metrics (bool): forwarded to ``Metrics`` (default False).
    """

    def __init__(self, env, params):
        self.env = env

        self.exploration_rate = params.get("exploration_rate", 0.3)
        self.exploration_decay = params.get("exploration_decay", 0.995)
        self.min_exploration_rate = params.get("min_exploration_rate", 0.01)
        # state key -> {action key -> Q-value}. Missing states are created
        # explicitly, so a plain dict is enough (no default factory is needed).
        self.q_table = {}
        self.q_table_path = "q_table.json"
        self.load_q_table()
        self.action_space = env.action_space.n
        use_metrics = params.get("use_metrics", False)
        self.learning_rate = params.get("learning_rate", 0.1)

        self.discount_factor = params.get("gamma", 0.9)  # Discount factor
        self.episode_num = params.get("episode_num", 100)
        # NOTE(review): "value_iteration" looks copied from another agent; it is
        # presumably the run/log name used by metrics.Metrics -- confirm before renaming.
        self.metrics = Metrics("value_iteration", "training_results", use_metrics)

    @property
    def leaning_rate(self):
        """Misspelled alias of ``learning_rate``, kept for backward compatibility."""
        return self.learning_rate

    @leaning_rate.setter
    def leaning_rate(self, value):
        self.learning_rate = value

    def _state_values(self, state):
        """Return the action-value dict for ``state``, creating an all-zero entry if missing."""
        values = self.q_table.get(state)
        if values is None:
            values = {str(a): 0 for a in range(self.action_space)}
            self.q_table[state] = values
        return values

    @staticmethod
    def _greedy_key(action_values):
        """Return a key with maximal value, breaking ties uniformly at random.

        Random tie-breaking avoids always picking action "0" in states whose
        Q-values are all equal, e.g. freshly initialised ones.
        Raises ValueError if ``action_values`` is empty.
        """
        best = max(action_values.values())
        return random.choice([a for a, q in action_values.items() if q == best])

    def _greedy_or_random(self, state):
        """Return the greedy action key for ``state``, or a random one if the state is unknown."""
        values = self.q_table.get(state)
        if not values:
            return str(self.env.action_space.sample())
        return self._greedy_key(values)

    def _td_update(self, state, action, reward, next_state, terminal):
        """Apply one Q-learning update to ``Q[state][action]`` and return the new value.

        Args:
            state: key of the state the action was taken in.
            action: action key (decimal string).
            reward: reward received for the transition.
            next_state: key of the resulting state.
            terminal: True if the episode *terminated* at ``next_state``.
                Truncation is not terminal and should be passed as False.
        """
        current = self._state_values(state)
        next_values = self._state_values(next_state)
        # A terminated transition has no successor, so there is nothing to bootstrap from.
        future = 0.0 if terminal else max(next_values.values())
        target = reward + self.discount_factor * future
        current[action] += self.learning_rate * (target - current[action])
        return current[action]

    def choose_action(self, state):
        """Select an action for ``state`` with an epsilon-greedy policy.

        With probability ``self.exploration_rate`` a random action is sampled.
        Otherwise the greedy action is used, with random ties. An unknown state
        also falls back to a random action.

        Returns:
            The action as a Q-table key (decimal string). Use ``int(...)`` for ``env.step``.
        """
        if random.random() < self.exploration_rate:
            return str(self.env.action_space.sample())  # Explore: random action
        return self._greedy_or_random(state)

    def train(self):
        """Run ``self.episode_num`` epsilon-greedy Q-learning episodes.

        After each episode, epsilon is decayed and the following are logged to
        ``self.metrics``:
          - the running mean episode reward;
          - the running mean episode length;
          - the current episode length.
        The Q-table is saved and the metrics are closed even if training is interrupted.
        """
        reward_sum = 0.0
        step_sum = 0
        try:
            for episode in tqdm(range(self.episode_num), desc="Training Agent"):
                obs, _info = self.env.reset()
                state = str(obs)  # Convert state to string for indexing
                done = truncated = False
                total_reward = 0
                steps = 0
                while not (done or truncated):
                    steps += 1
                    action = self.choose_action(state)
                    # The environment expects an integer action; the Q-table uses string keys.
                    next_obs, reward, done, truncated, _info = self.env.step(int(action))
                    next_state = str(next_obs)
                    # float() keeps the table JSON-serialisable if the env returns numpy scalars.
                    self._td_update(state, action, float(reward), next_state, done)
                    state = next_state
                    total_reward += reward

                reward_sum += total_reward
                step_sum += steps
                episodes_done = episode + 1

                self.exploration_rate = max(
                    self.min_exploration_rate,
                    self.exploration_rate * self.exploration_decay,
                )

                self.metrics.add("rollout/rewards", reward_sum / episodes_done, episode)
                self.metrics.add("rollout/steps", step_sum / episodes_done, episode)
                self.metrics.add("rollout/episode-length", steps, episode)
        finally:
            self.save_q_table()
            self.metrics.close()

    def evaluate(self, episodes=10, reward_cap=50):
        """Run greedy rollouts with the current Q-table, calling ``env.render()`` each step.

        Args:
            episodes: number of evaluation episodes.
            reward_cap: an episode is cut short once its cumulative reward exceeds this.

        Returns:
            List of per-episode total rewards.
        """
        episode_rewards = []
        for _ in tqdm(range(episodes), desc="Evaluating Agent"):
            obs, _info = self.env.reset()
            state = str(obs)
            done = truncated = False
            total_reward = 0

            while not (done or truncated):
                action = self._greedy_or_random(state)
                obs, reward, done, truncated, _info = self.env.step(int(action))
                state = str(obs)
                total_reward += reward
                if total_reward > reward_cap:
                    break
                # NOTE(review): with render_mode="rgb_array" the returned frame is discarded here.
                self.env.render()
            episode_rewards.append(total_reward)
        return episode_rewards

    def load_q_table(self):
        """Merge a previously saved Q-table from ``self.q_table_path`` into ``self.q_table``.

        Does nothing if the file does not exist. Read/parse errors and
        non-object JSON are reported and leave the table unchanged.
        """
        if not os.path.exists(self.q_table_path):
            return
        try:
            with open(self.q_table_path, "r", encoding="utf-8") as file:
                loaded = json.load(file)
        except (OSError, ValueError) as e:
            print(f"Error loading Q-table: {e}")
            return
        if not isinstance(loaded, dict):
            print(f"Error loading Q-table: expected a JSON object, got {type(loaded).__name__}")
            return
        self.q_table.update(loaded)
        print("Q-table loaded successfully.")

    def save_q_table(self):
        """Write ``self.q_table`` as indented JSON to ``self.q_table_path``.

        The data is written to a temporary sibling file and then moved into place
        with ``os.replace``, so an interrupted write cannot corrupt an existing
        table. Errors are reported rather than raised.
        """
        tmp_path = self.q_table_path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as file:
                json.dump(self.q_table, file, indent=4)
            os.replace(tmp_path, self.q_table_path)
            print("Q-table saved successfully.")
        except (OSError, TypeError, ValueError) as e:
            print(f"Error saving Q-table: {e}")

## Agent Initialization and Training

In [None]:
# Environment configuration: 3 lanes and a TimeToCollision observation with a 5-step horizon.
config = {
    "lanes_count": 3,
    "observation": {
        "type": "TimeToCollision",
        "horizon": 5,
    },
}

# Train on the "fast" variant; the evaluation cell below uses "highway-v0".
env = gym.make("highway-fast-v0", render_mode="rgb_array", config=config)

params = {
    "use_metrics": True,
    # NOTE(review): a single episode only smoke-tests the loop; raise this for real training.
    "episode_num": 1,
    "gamma": 0.9,  # Discount factor
    "exploration_rate": 0.3,
    "learning_rate": 0.1,
}

agent = QLearningAgent(env, params=params)
try:
    agent.train()
finally:
    # The evaluation cell creates its own environment, so release this one.
    env.close()

## Agent Evaluation

In [None]:
# Evaluate on the standard (non-"fast") highway environment with the same config.
env = gym.make("highway-v0", render_mode="rgb_array", config=config)
# The constructor reloads the Q-table that training saved to q_table.json.
# NOTE(review): with use_metrics=True this also creates a Metrics instance that
# evaluation never closes -- confirm whether that matters for metrics.Metrics.
agent = QLearningAgent(env, params=params)
try:
    eval_rewards = agent.evaluate(20)
finally:
    env.close()
eval_rewards

## Run TensorBoard

In [None]:
# Load (or reload) the TensorBoard notebook extension, then serve the logs
# under training_results/ at http://localhost:6011.
%reload_ext tensorboard

%tensorboard --logdir training_results --host localhost --port 6011