In [2]:
 # Load the environment and the configuration

import gymnasium as gym
import highway_env
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt

# Load the environment configuration from the pickle file.
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only load a config.pkl that comes from a trusted source.
with open("config.pkl", "rb") as f:
    config = pickle.load(f)

# Create the environment, then apply the loaded configuration to the
# unwrapped (innermost) environment object.
# NOTE(review): presumably the new configuration only takes full effect on the
# next env.reset() — confirm against the highway-env documentation.
env = gym.make("highway-fast-v0", render_mode="rgb_array")
env.unwrapped.configure(config)

# Seed Python's `random`, NumPy's global RNG and PyTorch's default generator.
# The environment itself is not seeded here (no seed is passed to env.reset).
# torch.manual_seed returns the generator, so as the last expression of the
# cell its repr is what the notebook displays as the cell output.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


<torch._C.Generator at 0x2f84fe928d0>

In [3]:
# Define the Q-Network

class QNetwork(nn.Module):
    """Fully connected network with two 64-unit hidden layers and ReLU activations.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Number of values produced per input row.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_units = 64
        # Layers are created in fc1 -> fc2 -> fc3 order so parameter
        # initialisation draws from the RNG in the same sequence.
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, output_dim)

    def forward(self, x):
        """Return the output of the final linear layer (no activation applied to it)."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = nn.functional.relu(layer(hidden))
        return self.fc3(hidden)


In [4]:
# Implement Experience Replay

class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Once ``capacity`` transitions are held, appending a new one drops the
    oldest (behaviour of ``deque(maxlen=capacity)``).
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` tuple."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.buffer, k=batch_size)

    def size(self):
        """Return the number of transitions currently stored."""
        stored = len(self.buffer)
        return stored


In [5]:
# Define the DQN Agent
# The agent will use the Q-Network and the ReplayBuffer to learn from experiences


class DQNAgent:
    """Epsilon-greedy DQN agent with a target network and experience replay.

    Observations of any shape are flattened to one row of features before they
    are fed to the networks, so ``state_dim`` must equal the total number of
    elements in a single observation.

    Args:
        state_dim: Number of input features of the Q-network.
        action_dim: Number of discrete actions (size of the Q-network output).
        replay_buffer: Object exposing ``size()`` and ``sample(batch_size)``;
            samples are ``(state, action, reward, next_state, done)`` tuples.
        batch_size: Number of transitions used per gradient step.
        gamma: Discount factor.
        epsilon: Initial exploration rate.
        epsilon_min: Lower bound that the decayed exploration rate is clamped to.
        epsilon_decay: Multiplicative decay applied to epsilon after each
            gradient step.
        lr: Learning rate of the Adam optimizer.
    """

    def __init__(self, state_dim, action_dim, replay_buffer, batch_size, gamma=0.99, epsilon=1.0, epsilon_min=0.01, epsilon_decay=0.995, lr=1e-3):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma  # Discount factor
        self.epsilon = epsilon  # Exploration rate
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.batch_size = batch_size

        # Online network and a target network that starts as an exact copy of it
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())

        # Only the online network is optimized; the target is refreshed by copy
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)
        self.replay_buffer = replay_buffer

    @staticmethod
    def _as_batch(observations):
        """Stack observations into a float32 tensor with one flattened row each."""
        stacked = np.asarray(observations, dtype=np.float32)
        return torch.as_tensor(stacked.reshape(len(observations), -1))

    def select_action(self, state):
        """Pick an action epsilon-greedily for a single observation.

        Returns:
            A Python ``int`` in ``[0, action_dim)``.
        """
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)  # Random action

        # Flatten so multi-dimensional observations yield exactly one row of
        # Q-values; otherwise argmax could index past the action space.
        features = self._as_batch([state])
        with torch.no_grad():  # inference only, no graph needed
            q_values = self.q_network(features)
        return int(torch.argmax(q_values, dim=1).item())

    def update_q_values(self):
        """Run one gradient step on a sampled batch and decay epsilon.

        Does nothing until the replay buffer holds at least ``batch_size``
        transitions.
        """
        if self.replay_buffer.size() < self.batch_size:
            return

        batch = self.replay_buffer.sample(self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert through NumPy first: building tensors straight from a tuple
        # of ndarrays is very slow.
        states = self._as_batch(states)
        next_states = self._as_batch(next_states)
        actions = torch.as_tensor(np.asarray(actions, dtype=np.int64))
        rewards = torch.as_tensor(np.asarray(rewards, dtype=np.float32))
        dones = torch.as_tensor(np.asarray(dones, dtype=np.float32))

        # Q(s, a) for the actions that were actually taken
        current_q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # The TD target is a constant w.r.t. the loss: build it without
        # autograd so no gradients accumulate on the frozen target network.
        with torch.no_grad():
            next_q_values = self.target_network(next_states).max(dim=1).values
            target_q_values = rewards + self.gamma * next_q_values * (1.0 - dones)

        loss = nn.functional.mse_loss(current_q_values, target_q_values)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Decay exploration, clamping so epsilon never drops below epsilon_min.
        # An epsilon already at or below the floor is left untouched.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_network.load_state_dict(self.q_network.state_dict())


In [6]:
# Training the DQN Agent
def train_dqn(episodes=1000, max_steps=200, target_update_every=10):
    """Train a DQN agent on the module-level ``env`` and return episode rewards.

    Args:
        episodes: Number of episodes to run.
        max_steps: Maximum number of environment steps per episode.
        target_update_every: Number of environment steps (counted across all
            episodes) between two target-network synchronisations.

    Returns:
        A list holding the cumulative reward of each episode, in order.

    Raises:
        ValueError: If ``target_update_every`` is smaller than 1.
    """
    if target_update_every < 1:
        raise ValueError("target_update_every must be >= 1")

    replay_buffer = ReplayBuffer(10000)
    # Observations may be multi-dimensional; the network consumes them
    # flattened, so its input size is the total number of elements.
    state_dim = int(np.prod(env.observation_space.shape))
    agent = DQNAgent(state_dim=state_dim, action_dim=env.action_space.n, replay_buffer=replay_buffer, batch_size=64)

    rewards = []
    total_steps = 0  # global step counter driving the target-network sync
    for episode in range(episodes):
        state, info = env.reset()
        state = np.asarray(state, dtype=np.float32).reshape(-1)
        episode_reward = 0

        for step in range(max_steps):
            action = agent.select_action(state)
            next_state, reward, terminated, truncated, info = env.step(action)
            next_state = np.asarray(next_state, dtype=np.float32).reshape(-1)

            # Store only `terminated` as the done flag: a truncated episode
            # (e.g. time limit) did not reach a terminal state, so the target
            # must still bootstrap from next_state.
            replay_buffer.push(state, action, reward, next_state, terminated)

            agent.update_q_values()
            total_steps += 1
            # Sync on the global step count; a per-episode counter would
            # re-sync at the first step of every episode.
            if total_steps % target_update_every == 0:
                agent.update_target_network()

            state = next_state
            episode_reward += reward

            if terminated or truncated:
                break

        rewards.append(episode_reward)
        # Report 1-based episode numbers so the final line reads N/N
        print(f"Episode {episode + 1}/{episodes}, Reward: {episode_reward}, Epsilon: {agent.epsilon:.4f}")

    return rewards

# Train the DQN agent for 100 episodes; `rewards` receives the list of
# per-episode cumulative rewards returned by train_dqn.
rewards = train_dqn(episodes=100)


: 

In [None]:
# Visualize the training performance: cumulative reward of each episode
plt.plot(rewards)
plt.gca().set(
    xlabel='Episodes',
    ylabel='Cumulative Reward',
    title='DQN Training Performance',
)
plt.show()