<a href="https://colab.research.google.com/github/Purushotham-Mani/CS238/blob/main/HighwayAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!pip install highway-env
!pip install tensorboardx gym pyvirtualdisplay
!apt-get install -y xvfb ffmpeg
# Environment
import gymnasium as gym
import highway_env

gym.register_envs(highway_env)

# Models and computation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import namedtuple, deque

# Visualization
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import trange
import base64
from pathlib import Path

from gymnasium.wrappers import RecordVideo
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display

Collecting highway-env
  Downloading highway_env-1.10.1-py3-none-any.whl.metadata (16 kB)
Collecting gymnasium>=1.0.0a2 (from highway-env)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from highway-env)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading highway_env-1.10.1-py3-none-any.whl (104 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 kB[0m [31m932.6 kB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m15.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: farama-notifications, gymnasium, highway-env
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0 highway-env-1.10.1
Collecting tensorboardx
  Downloading tensorboardX-2.6.2.2-py2.py3-

In [3]:
display = Display(visible=0, size=(1400, 900))
display.start()

def show_videos(path="videos"):
    html = []
    for mp4 in Path(path).glob("*.mp4"):
        video_b64 = base64.b64encode(mp4.read_bytes())
        html.append(
            """<video alt="{}" autoplay
                      loop controls style="height: 400px;">
                      <source src="data:video/mp4;base64,{}" type="video/mp4" />
                 </video>""".format(
                mp4, video_b64.decode("ascii")
            )
        )
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(html)))

In [8]:
class Encoder(nn.Module):
  def __init__(self, input_dim, embed_dim):
    super(Encoder, self).__init__()
    self.fc1 = nn.Linear(input_dim, embed_dim)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    return x

In [9]:
class EgoAttention(nn.Module):
  def __init__(self, embed_dim):
    super(EgoAttention, self).__init__()
    self.Q = nn.Linear(embed_dim, embed_dim)
    self.K = nn.Linear(embed_dim, embed_dim)
    self.V = nn.Linear(embed_dim, embed_dim)

  def forward(self, ego_embed, vehicle_embeds):
    Q = self.Q(ego_embed).unsqueeze(1)
    K = self.K(vehicle_embeds)
    V = self.V(vehicle_embeds)

    scores = torch.matmul(Q, K.transpose(-2, -1))/np.sqrt(K.size(-1))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.matmul(attention_weights, V)

    return output.squeeze(1), attention_weights

In [10]:
class AttentionDQN(nn.Module):
  def __init__(self, input_dim, embed_dim, output_dim, lr):
    super(AttentionDQN, self).__init__()
    self.encoder = Encoder(input_dim, embed_dim)
    self.ego_attention = EgoAttention(embed_dim)
    self.f = nn.Linear(embed_dim, output_dim)

    self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
    self.loss = nn.MSELoss()
    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print(self.device)
    self.to(self.device)

  def forward(self, ego_state, vehicle_states):
    ego_embed = self.encoder(ego_state)
    vehicle_embeds = self.encoder(vehicle_states)
    ego_attention, attention_weights = self.ego_attention(ego_embed, vehicle_embeds)
    q_values = self.f(ego_attention)
    return q_values, attention_weights

In [28]:
class Agent():
  def __init__(self, input_dim, embed_dim, output_dim, lr = 1e-4, gamma = 0.95, epsilon = 1, max_mem_size=100000, eps_end=0.01, eps_dec=5e-4):
    self.gamma = gamma
    self.epsilon = epsilon
    self.eps = 1
    self.eps_min = eps_end
    self.eps_dec = eps_dec
    self.mem_size = max_mem_size
    self.batch_size = 64
    self.action_space = [i for i in range(output_dim)]

    self.Q_fn = AttentionDQN(input_dim[1], embed_dim, output_dim, lr)

    self.mem_cntr = 0
    self.state_memory = np.zeros((self.mem_size, *input_dim), dtype=np.float32) #have to adapt this
    self.new_state_memory = np.zeros((self.mem_size, *input_dim), dtype=np.float32)
    self.action_memory = np.zeros(self.mem_size, dtype=np.int32)
    self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
    self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

  def store_transition(self, state, action, reward, state_, done):
    index = self.mem_cntr % self.mem_size

    self.state_memory[index] = state
    self.new_state_memory[index] = state_
    self.reward_memory[index] = reward
    self.action_memory[index] = action
    self.terminal_memory[index] = done

    self.mem_cntr += 1

  def choose_action(self, ego_features, vehicle_features):
    if np.random.random() > self.epsilon:
      ego = torch.tensor(ego_features).to(self.Q_fn.device).unsqueeze(0)
      vehicles = torch.tensor(vehicle_features).to(self.Q_fn.device).unsqueeze(0)
      q_values, _ = self.Q_fn(ego, vehicles)
      action = torch.argmax(q_values).item()
    else:
      action = np.random.choice(self.action_space)
    return action

  def train(self):
    if self.mem_cntr < self.batch_size:
      return

    self.Q_fn.optimizer.zero_grad()
    max_mem = min(self.mem_cntr, self.mem_size)
    batch = np.random.choice(max_mem, self.batch_size, replace=False)

    batch_index = np.arange(self.batch_size, dtype=np.int32)

    state_batch = torch.tensor(self.state_memory[batch]).to(self.Q_fn.device)
    new_state_batch = torch.tensor(self.new_state_memory[batch]).to(self.Q_fn.device)
    reward_batch = torch.tensor(self.reward_memory[batch]).to(self.Q_fn.device)
    terminal_batch = torch.tensor(self.terminal_memory[batch]).to(self.Q_fn.device)
    action_batch = self.action_memory[batch]

    q_eval, _ = self.Q_fn.forward(state_batch[:,0,:],state_batch) ##adapt
    q_eval = q_eval[batch_index, action_batch]
    q_next, _ = self.Q_fn.forward(new_state_batch[:,0,:],new_state_batch)
    q_next[terminal_batch] = 0.0

    q_target = reward_batch + self.gamma * torch.max(q_next, dim=1)[0]

    loss = self.Q_fn.loss(q_target, q_eval).to(self.Q_fn.device)
    loss.backward()
    self.Q_fn.optimizer.step()

    self.epsilon = self.epsilon - self.eps_dec if self.epsilon > self.eps_min else self.eps_min

In [None]:
env = gym.make("highway-v0")
# Default Config: Observation: Kinematics, Actions: DiscreteMetaAction
# ACTIONS_ALL = {0: 'LANE_LEFT',1: 'IDLE',2: 'LANE_RIGHT',3: 'FASTER',4: 'SLOWER'}

agent = Agent(input_dim=env.observation_space.shape, embed_dim = 128, output_dim=env.action_space.n)
scores, eps_history = [], []
n_games = 50
if Path("./weights").exists():
  agent.Q_fn.load_state_dict(torch.load("./weights", weights_only=True))
for i in range(n_games):
  done = False
  score = 0
  observation, info = env.reset()
  while not done:
    action = agent.choose_action(observation[0,:],observation)
    observation_, reward, done, truncated, info = env.step(action)
    score += reward
    agent.store_transition(observation, action, reward, observation_, done)
    agent.train()
    # torch.save(agent.Q_fn.state_dict(),"./weights" )
    observation = observation_
  scores.append(score)
  eps_history.append(agent.epsilon)

  avg_score = np.mean(scores[-100:])
  print('episode ', i, 'score %.2f' % score, 'average score %.2f' % avg_score,'epsilon %.2f' % agent.epsilon)


cpu
episode  0 score 4.57 average score 4.57 epsilon 1.00
episode  1 score 18.93 average score 11.75 epsilon 1.00
episode  2 score 25.34 average score 16.28 epsilon 1.00
episode  3 score 2.79 average score 12.91 epsilon 1.00
episode  4 score 41.73 average score 18.67 epsilon 0.97
episode  5 score 3.73 average score 16.18 epsilon 0.97
episode  6 score 1.91 average score 14.14 epsilon 0.97
episode  7 score 13.35 average score 14.04 epsilon 0.96
episode  8 score 13.48 average score 13.98 epsilon 0.95
episode  9 score 15.37 average score 14.12 epsilon 0.94
episode  10 score 18.77 average score 14.54 epsilon 0.93
episode  11 score 2.60 average score 13.55 epsilon 0.93
episode  12 score 1.89 average score 12.65 epsilon 0.93
episode  13 score 9.85 average score 12.45 epsilon 0.92
episode  14 score 6.97 average score 12.08 epsilon 0.91
episode  15 score 1.66 average score 11.43 epsilon 0.91
episode  16 score 8.93 average score 11.29 epsilon 0.91
episode  17 score 7.24 average score 11.06 epsil