In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gymnasium as gym
import highway_env

import wandb
import json
import warnings
import random
import time
# Suppress DeprecationWarning
warnings.filterwarnings('ignore', category=DeprecationWarning)

In [228]:
class DQNModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DQNModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, output_dim)
        )
    
    def forward(self, x):
        return self.fc(x)
    
class DQNAgent:
    def __init__(self, agent_config_dict):
        # Extract all configurations from the dictionary
        self.env = agent_config_dict["env"]
        self.model = agent_config_dict["model"]
        self.gamma = agent_config_dict.get("gamma", 0.99)  # Discount Factor
        self.epsilon = agent_config_dict.get("epsilon", 1.0)  # epsilon for epsilon-greedy
        self.epsilon_decay = agent_config_dict.get("epsilon_decay", 0.995)
        self.epsilon_min = agent_config_dict.get("epsilon_min", 0.1)
        self.batch_size = agent_config_dict.get("batch_size", 64)
        self.replay_capacity = agent_config_dict.get("replay_capacity", 10000)
        self.optimizer = optim.Adam(self.model.parameters(), lr=agent_config_dict.get("lr", 1e-3))

        self.render_scene = agent_config_dict.get("render_scene", False)
        self.logs = {}

        self.criterion = nn.MSELoss()
        self.replay_buffer = []
        self.replay_buffer_scenes = []

        self.wandb_log = agent_config_dict["wandb_log"]
        # Initialize wandb
        if self.wandb_log:
            wandb.init(project=agent_config_dict.get("wandb_project", "dqn_highway_gym"))

    def select_action(self, state):
        if np.random.rand() < self.epsilon:
            return self.env.action_space.sample()
        else:
            state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                return torch.argmax(self.model(state)).item()

    def store_transition(self, episode, step, state, action, reward, next_state, done,scene_arr,log):
        if len(self.replay_buffer) >= self.replay_capacity:
            self.replay_buffer.pop(0)
        self.replay_buffer.append((state, action, reward, next_state, done))


        if log == True:
            if episode in self.logs:
                self.logs[episode]["step"].append(step)
                self.logs[episode]["state"].append(state)
                self.logs[episode]["action"].append(action)
                self.logs[episode]["reward"].append(reward)
                self.logs[episode]["next_state"].append(next_state)
                self.logs[episode]["done"].append(done)
                self.logs[episode]["scene_arr"].append(scene_arr)

            else:
                self.logs[episode] = {"step":[],
                                    "state":[],
                                    "action":[],
                                    "reward":[],
                                    "next_state":[],
                                    "done":[],
                                    "scene_arr":[]}



    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return (None,None)

        batch = random.sample(self.replay_buffer, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.tensor(states, dtype=torch.float32)
        actions = torch.tensor(actions, dtype=torch.long).unsqueeze(1)
        rewards = torch.tensor(rewards, dtype=torch.float32)
        next_states = torch.tensor(next_states, dtype=torch.float32)
        dones = torch.tensor(dones, dtype=torch.float32)

        # Compute Q-values
        # print(states.shape)
        # print(actions.shape)
        q_values = self.model(states).gather(1, actions).squeeze(1)

        q_val_norm = 0
        # print("passed")
        # Compute target Q-values
        with torch.no_grad():
            next_q_values = self.model(next_states).max(1)[0]
            q_val_norm = (np.linalg.norm(next_q_values))
            targets = rewards + self.gamma * next_q_values * (1 - dones)

        # Loss
        loss = self.criterion(q_values, targets)

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay epsilon
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)
        return (q_val_norm,-1)

    def play_episode(self, episode_num,train=True,log=False):
        state, info = self.env.reset()
        state_flattened = state.reshape(-1)

        total_reward = 0
        done = False
        step = 0
        q_val_norm = None
        loss = None

        while True:
            action = self.select_action(state_flattened)
            next_state, reward, done, truncated, _ = self.env.step(action)

            if self.render_scene:
                scene_arr = self.env.render()
            else:
                scene_arr = None
                
            state_flattened = state.reshape(-1)
            next_state_flattened = next_state.reshape(-1)
            self.store_transition(episode_num,step, state_flattened, action, reward, next_state_flattened, done or truncated,scene_arr,log)
            if train:
                q_val_norm,loss = self.train()
            state = next_state
            total_reward += reward
            step += 1

            if done or truncated:
                break

        # Logging


        if train:
            episode_log = {
                "train/total_reward": total_reward,
                "train/last_step_reward": reward,
                "train/episode_length": step,
                "train/epsilon":self.epsilon,
                "train/Q_val_norm":q_val_norm,
                "train/loss":loss
            }
        else:
            episode_log = {
                "test/total_reward": total_reward,
                "test/last_step_reward": reward,
                "test/episode_length": step
            }


        print("Episode: {:<3} | Episode Length: {:<3} | Epsilon: {:<3.3f} | Total Reward: {:<4.3f} | Last Step Reward: {} | loss: {} | Q val norm: {}".format(episode_num, step, self.epsilon,total_reward, reward,loss,q_val_norm))
        if self.wandb_log:
            wandb.log(episode_log)

        return total_reward


In [None]:

# Load configuration from file
with open('env_config_dict.txt', 'r') as f:
    env_config_dict = json.load(f)


# env = gym.make("intersection-v0", render_mode="rgb_array", config=env_config_dict)
env = gym.make("highway-v0", render_mode="rgb_array")

input_dim = env.observation_space.shape[0]*env.observation_space.shape[1] # Flattening the matrix
output_dim = env.action_space.n

# Create the model
model = DQNModel(input_dim, output_dim)

agent_config_dict = {
    "env": env,
    "model": model,
    "gamma": 0.99,
    "epsilon": 1.0,
    "epsilon_decay": 0.99,
    "epsilon_min": 0.1,
    "batch_size": 64,
    "replay_capacity": 10000,
    "lr": 1e-3,
    "wandb_log": True,
    "wandb_project": "RoadSense Project",
    "render_scene":False
}

agent = DQNAgent(agent_config_dict)

# # Training loop
num_train_episodes = 100
for episode in range(num_train_episodes):
    reward = agent.play_episode(episode_num = episode, train=True,log=False)

print()
print("Evaluation")

# Evaluation loop
num_eval_episodes = 10
for episode in range(num_eval_episodes):
    reward = agent.play_episode(episode_num = episode, train=False,log=False)
    # print(f"Evaluation Episode {episode + 1}: Total Reward: {reward}")

wandb.finish()


Episode: 0   | Episode Length: 40  | Epsilon: 1.000 | Total Reward: 28.656 | Last Step Reward: 0.7951727812171602 | loss: None | Q val norm: None
Episode: 1   | Episode Length: 8   | Epsilon: 1.000 | Total Reward: 6.390 | Last Step Reward: 0.06666666666666665 | loss: None | Q val norm: None
Episode: 2   | Episode Length: 6   | Epsilon: 1.000 | Total Reward: 4.515 | Last Step Reward: 0.09496034658703383 | loss: None | Q val norm: None
Episode: 3   | Episode Length: 4   | Epsilon: 1.000 | Total Reward: 2.757 | Last Step Reward: 0.09440737566895548 | loss: None | Q val norm: None
Episode: 4   | Episode Length: 12  | Epsilon: 0.932 | Total Reward: 8.617 | Last Step Reward: 0.04444444444444443 | loss: -1 | Q val norm: 2.7980899810791016
Episode: 5   | Episode Length: 26  | Epsilon: 0.718 | Total Reward: 19.016 | Last Step Reward: 0.0 | loss: -1 | Q val norm: 32.89784240722656
Episode: 6   | Episode Length: 40  | Epsilon: 0.480 | Total Reward: 30.337 | Last Step Reward: 0.7771992679041103 | 

# Visualisation 

In [None]:
# Reset the environment
def run_eps(env, num_eps):

    ep = 0
    step = 0
    observation, info = env.reset()
    rewards_eps = {}
    obs_eps = {}
    actions_eps = {}
    rendered_eps = {}
    length_eps = {}

    # Run for a fixed number of steps
    while ep<num_eps:
        reward_ep_list = []
        obs_ep_list = []
        action_ep_list = []
        rendered_ep_list = []

        while True:
            # Choose a random action
            # action =  env.action_space.sample()
            action = agent.get_action(observation.reshape(-1))
            
            # Take a step in the environment
            present_time = time.time()
            observation, reward, terminated, truncated, info = env.step(action)
            after_time = time.time()

            step_time_duration = after_time - present_time


            # Render the environment
            scene_arr = env.render()

            reward_ep_list.append(reward)
            obs_ep_list.append(observation)
            action_ep_list.append(action)
            rendered_ep_list.append(scene_arr)

            # Check if the episode is done
            if terminated or truncated:
                observation, info = env.reset()
                rewards_eps[ep] = reward_ep_list
                obs_eps[ep] = obs_ep_list
                actions_eps[ep] = action_ep_list
                rendered_eps[ep] = rendered_ep_list
                length_eps[ep] = step

                print("Episode: {} | Num Time Steps: {} | Terminated".format(ep,step))
                step = 0
                ep += 1
                break



            step +=1 
        # time.sleep(0.1)  # Delay for 0.1 seconds

    info = {"rewards_eps":rewards_eps,
            "obs_eps":obs_eps,
            "actions_eps":actions_eps,
            "rendered_eps":rendered_eps,
            "length_eps":length_eps}
    # Close the environment
    env.close()
    return info
