<a href="https://colab.research.google.com/github/Purushotham-Mani/CS238/blob/main/HighwayAttention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install highway-env
!pip install tensorboardx gym pyvirtualdisplay
!apt-get install -y xvfb ffmpeg
# Environment
import gymnasium as gym
import highway_env

gym.register_envs(highway_env)

# Models and computation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from collections import namedtuple, deque

# Visualization
import matplotlib.pyplot as plt
%matplotlib inline

from tqdm.notebook import trange
import base64
from pathlib import Path

from gymnasium.wrappers import RecordVideo
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display

Collecting highway-env
  Downloading highway_env-1.10.1-py3-none-any.whl.metadata (16 kB)
Collecting gymnasium>=1.0.0a2 (from highway-env)
  Downloading gymnasium-1.0.0-py3-none-any.whl.metadata (9.5 kB)
Collecting farama-notifications>=0.0.1 (from highway-env)
  Downloading Farama_Notifications-0.0.4-py3-none-any.whl.metadata (558 bytes)
Downloading highway_env-1.10.1-py3-none-any.whl (104 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m105.0/105.0 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading Farama_Notifications-0.0.4-py3-none-any.whl (2.5 kB)
Downloading gymnasium-1.0.0-py3-none-any.whl (958 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m958.1/958.1 kB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: farama-notifications, gymnasium, highway-env
Successfully installed farama-notifications-0.0.4 gymnasium-1.0.0 highway-env-1.10.1
Collecting tensorboardx
  Downloading tensorboardX-2.6.2.2-py2.py3-no

In [2]:
# Start a headless virtual display so the environments can render frames for RecordVideo.
# Named `virtual_display` (not `display`) so IPython's built-in `display()` is not shadowed.
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

def show_videos(path="videos"):
    """Embed every ``*.mp4`` file directly inside ``path`` as an autoplaying, looping HTML5 video.

    Files are sorted by path so the display order is stable across runs and filesystems.
    Each file path is HTML-escaped before it is placed in the ``alt`` attribute. A missing
    or empty folder renders an empty HTML element rather than raising.

    Args:
        path: folder to scan (non-recursively) for ``.mp4`` files.
    """
    import html  # local import: only needed here, keeps the notebook's import cell untouched

    snippets = []
    for mp4 in sorted(Path(path).glob("*.mp4")):
        video_b64 = base64.b64encode(mp4.read_bytes()).decode("ascii")
        snippets.append(
            """<video alt="{}" autoplay
                      loop controls style="height: 400px;">
                      <source src="data:video/mp4;base64,{}" type="video/mp4" />
                 </video>""".format(
                html.escape(str(mp4)), video_b64
            )
        )
    ipythondisplay.display(ipythondisplay.HTML(data="<br>".join(snippets)))

In [3]:
class Encoder(nn.Module):
  """Entity encoder: one linear projection followed by a ReLU.

  Because nn.Linear acts on the last axis, the same module embeds a single ego
  row (batch, input_dim) and a stack of vehicle rows (batch, n, input_dim).
  """
  def __init__(self, input_dim, embed_dim):
    super().__init__()
    # The attribute stays `fc1` so previously saved state_dicts keep loading.
    self.fc1 = nn.Linear(input_dim, embed_dim)

  def forward(self, x):
    projected = self.fc1(x)
    return torch.relu(projected)

In [8]:
class EgoAttention(nn.Module):
  """Multi-head attention in which a single ego query attends over a set of vehicle embeddings.

  Args:
    embed_dim: size of the input embeddings and of the output. Must be divisible by
      num_heads.
    num_heads: number of heads; each head works on embed_dim // num_heads channels.

  Raises:
    ValueError: if embed_dim is not divisible by num_heads (the per-head `view` in
      `forward` could not split the embedding otherwise).
  """
  def __init__(self, embed_dim, num_heads):
    super(EgoAttention, self).__init__()
    if embed_dim % num_heads != 0:
      raise ValueError(
          f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
    self.Q = nn.Linear(embed_dim, embed_dim)
    self.K = nn.Linear(embed_dim, embed_dim)
    self.V = nn.Linear(embed_dim, embed_dim)
    self.fe = nn.Linear(embed_dim, embed_dim)  # output projection after concatenating heads

    self.embed_dim = embed_dim
    self.num_heads = num_heads
    self.head_dim = embed_dim // num_heads

  def forward(self, ego_embed, vehicle_embeds, mask=None):
    """Attend from the ego embedding to the vehicle embeddings.

    Args:
      ego_embed: (batch, embed_dim) query embedding.
      vehicle_embeds: (batch, num_vehicles, embed_dim) key/value embeddings.
      mask: optional boolean tensor of shape (batch, num_vehicles). True marks entries
        that may be attended to; masked entries get exactly zero attention weight. A
        row whose entries are all masked falls back to uniform weights instead of NaN.

    Returns:
      output: (batch, embed_dim) attended embedding after the output projection.
      attention_weights: (batch, num_heads, 1, num_vehicles); sums to 1 over the last axis.
    """
    batch_size = ego_embed.size(0)

    # Project and split into heads: (batch, seq, num_heads, head_dim)
    Q = self.Q(ego_embed).view(batch_size, 1, self.num_heads, self.head_dim)
    K = self.K(vehicle_embeds).view(batch_size, -1, self.num_heads, self.head_dim)
    V = self.V(vehicle_embeds).view(batch_size, -1, self.num_heads, self.head_dim)

    # Move heads in front: (batch, num_heads, seq, head_dim)
    Q = Q.permute(0, 2, 1, 3)
    K = K.permute(0, 2, 1, 3)
    V = V.permute(0, 2, 1, 3)

    # Scaled dot-product attention scores: (batch, num_heads, 1, num_vehicles)
    attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
    if mask is not None:
      keep = mask.to(device=attention_scores.device, dtype=torch.bool)[:, None, None, :]
      # finfo.min rather than -inf: masked weights still underflow to exactly 0, but a
      # fully masked row yields uniform weights instead of NaN.
      attention_scores = attention_scores.masked_fill(
          ~keep, torch.finfo(attention_scores.dtype).min)
    attention_weights = F.softmax(attention_scores, dim=-1)
    attention_output = torch.matmul(attention_weights, V)  # (batch, num_heads, 1, head_dim)

    # Concatenate heads back to (batch, 1, embed_dim), drop the query axis, and project
    attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
    attention_output = attention_output.view(batch_size, -1, self.embed_dim)
    output = self.fe(attention_output.squeeze(1))

    return output, attention_weights

In [5]:
class AttentionDQN(nn.Module):
  """Q-network made of a shared Encoder, an EgoAttention layer and a linear Q head.

  The module also owns its training utilities: an Adam optimizer over all of its
  parameters, an MSE loss, and the device it lives on. On construction it prints
  the chosen device and moves itself there.
  """
  def __init__(self, input_dim, embed_dim, num_heads, output_dim, lr):
    super().__init__()
    self.encoder = Encoder(input_dim, embed_dim)
    self.ego_attention = EgoAttention(embed_dim, num_heads)
    self.f = nn.Linear(embed_dim, output_dim)

    self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
    self.loss = nn.MSELoss()
    has_gpu = torch.cuda.is_available()
    self.device = torch.device('cuda:0' if has_gpu else 'cpu')
    print(self.device)
    self.to(self.device)

  def forward(self, ego_state, vehicle_states):
    """Return (q_values, attention_weights) for the given ego and vehicle features.

    The same encoder embeds the ego row and the vehicle rows (ego first, then vehicles).
    """
    attended, weights = self.ego_attention(self.encoder(ego_state),
                                           self.encoder(vehicle_states))
    return self.f(attended), weights

In [6]:
class Agent():
  """Epsilon-greedy DQN agent with a uniform replay buffer and an AttentionDQN Q-network.

  Each stored observation is a 2-D array of shape `input_dim`. Row 0 is fed to the
  network as the ego vehicle and the whole array is fed as the set of vehicles.

  Args:
    input_dim: observation shape (rows, features); `input_dim[1]` is the per-row feature size.
    embed_dim: embedding size passed to AttentionDQN.
    output_dim: number of discrete actions.
    num_heads: attention heads passed to AttentionDQN.
    lr: Adam learning rate passed to AttentionDQN.
    gamma: discount factor of the TD target.
    epsilon: initial exploration rate.
    max_mem_size: replay-buffer capacity; once full, the oldest transitions are overwritten.
    eps_end: lower bound for epsilon.
    eps_dec: amount subtracted from epsilon after every training step.
    batch_size: minibatch size (previously hard-coded to 64).
    q_network: optional pre-built Q-network used instead of constructing an AttentionDQN.
      It must be callable as `q(ego, vehicles) -> (q_values, extra)` and expose
      `.device`, `.optimizer` and `.loss` the way AttentionDQN does.
  """
  def __init__(self, input_dim, embed_dim, output_dim, num_heads = 2, lr = 1e-4, gamma = 0.95, epsilon = 1, max_mem_size=100000, eps_end=0.01, eps_dec=5e-4, batch_size=64, q_network=None):
    self.gamma = gamma
    self.epsilon = epsilon
    self.eps = 1  # unused; kept for backward compatibility. `epsilon` is the live exploration rate.
    self.eps_min = eps_end
    self.eps_dec = eps_dec
    self.mem_size = max_mem_size
    self.batch_size = batch_size
    self.action_space = list(range(output_dim))

    if q_network is None:
      q_network = AttentionDQN(input_dim[1], embed_dim, num_heads, output_dim, lr)
    self.Q_fn = q_network

    # Circular replay buffer; mem_cntr counts all transitions ever stored.
    self.mem_cntr = 0
    self.state_memory = np.zeros((self.mem_size, *input_dim), dtype=np.float32)
    self.new_state_memory = np.zeros((self.mem_size, *input_dim), dtype=np.float32)
    self.action_memory = np.zeros(self.mem_size, dtype=np.int32)
    self.reward_memory = np.zeros(self.mem_size, dtype=np.float32)
    self.terminal_memory = np.zeros(self.mem_size, dtype=bool)

  def store_transition(self, state, action, reward, state_, done):
    """Write one transition into the circular buffer, overwriting the oldest slot when full.

    `done` should be True only for true terminations, because it zeroes the bootstrap term.
    """
    index = self.mem_cntr % self.mem_size

    self.state_memory[index] = state
    self.new_state_memory[index] = state_
    self.reward_memory[index] = reward
    self.action_memory[index] = action
    self.terminal_memory[index] = done

    self.mem_cntr += 1

  def choose_action(self, ego_features, vehicle_features):
    """Pick an action epsilon-greedily.

    With probability `epsilon` a uniformly random action is returned. Otherwise the
    result is the argmax of the Q-network for a batch containing one observation.
    Inference runs under `torch.no_grad()`, so no autograd graph is built.
    """
    if np.random.random() > self.epsilon:
      device = self.Q_fn.device
      ego = torch.as_tensor(ego_features, dtype=torch.float32, device=device).unsqueeze(0)
      vehicles = torch.as_tensor(vehicle_features, dtype=torch.float32, device=device).unsqueeze(0)
      with torch.no_grad():
        q_values, _ = self.Q_fn(ego, vehicles)
      action = torch.argmax(q_values).item()
    else:
      action = np.random.choice(self.action_space)
    return action

  def train(self):
    """Run one DQN gradient step on a uniformly sampled minibatch, then decay epsilon.

    This is a no-op until at least `batch_size` transitions have been stored.
    NOTE(review): no separate target network is used; the bootstrap comes from the
    online network (without gradients). A frozen target copy would further stabilise training.
    """
    if self.mem_cntr < self.batch_size:
      return

    device = self.Q_fn.device
    max_mem = min(self.mem_cntr, self.mem_size)
    batch = np.random.choice(max_mem, self.batch_size, replace=False)

    states = torch.as_tensor(self.state_memory[batch], device=device)
    new_states = torch.as_tensor(self.new_state_memory[batch], device=device)
    rewards = torch.as_tensor(self.reward_memory[batch], device=device)
    terminals = torch.as_tensor(self.terminal_memory[batch], device=device)
    actions = torch.as_tensor(self.action_memory[batch], dtype=torch.long, device=device)

    # Q(s, a) for the actions that were actually taken
    q_all, _ = self.Q_fn(states[:, 0, :], states)
    q_eval = q_all.gather(1, actions.unsqueeze(1)).squeeze(1)

    # The TD target is a constant w.r.t. the parameters: without no_grad, gradients
    # would also flow through max_a' Q(s', a').
    with torch.no_grad():
      q_next, _ = self.Q_fn(new_states[:, 0, :], new_states)
      q_next_max = q_next.max(dim=1).values.masked_fill(terminals, 0.0)
      q_target = rewards + self.gamma * q_next_max

    self.Q_fn.optimizer.zero_grad()
    loss = self.Q_fn.loss(q_eval, q_target)
    loss.backward()
    self.Q_fn.optimizer.step()

    # Linear decay, clamped so epsilon never drops below eps_min
    self.epsilon = max(self.epsilon - self.eps_dec, self.eps_min)

# Highway Environment

### Train

In [9]:
# Train the attention DQN on highway-v0.
# Default config: Observation: Kinematics, Actions: DiscreteMetaAction
# ACTIONS_ALL = {0: 'LANE_LEFT',1: 'IDLE',2: 'LANE_RIGHT',3: 'FASTER',4: 'SLOWER'}
HIGHWAY_WEIGHTS = Path("./highway_weights")
env = gym.make("highway-v0")

agent = Agent(input_dim=env.observation_space.shape, embed_dim = 128, output_dim=env.action_space.n)
scores, eps_history = [], []
n_games = 50
if HIGHWAY_WEIGHTS.exists():
  # map_location lets weights saved on a GPU runtime load on a CPU-only one (and vice versa)
  agent.Q_fn.load_state_dict(torch.load(HIGHWAY_WEIGHTS, map_location=agent.Q_fn.device, weights_only=True))
for i in range(n_games):
  done = truncated = False
  score = 0
  observation, info = env.reset()
  # Gymnasium's step() reports termination and truncation (time limit) separately;
  # stopping only on `done` lets a crash-free episode run past its time limit.
  while not (done or truncated):
    action = agent.choose_action(observation[0,:],observation)
    observation_, reward, done, truncated, info = env.step(action)
    score += reward
    # Only a true termination zeroes the bootstrap; a truncated step still bootstraps.
    agent.store_transition(observation, action, reward, observation_, done)
    agent.train()
    observation = observation_
  scores.append(score)
  eps_history.append(agent.epsilon)

  avg_score = np.mean(scores[-100:])
  print('episode ', i, 'score %.2f' % score, 'average score %.2f' % avg_score,'epsilon %.2f' % agent.epsilon)

# Save once after training (not every step) so the load guard above can resume next time.
torch.save(agent.Q_fn.state_dict(), HIGHWAY_WEIGHTS)
env.close()


cpu
episode  0 score 12.03 average score 12.03 epsilon 1.00
episode  1 score 5.49 average score 8.76 epsilon 1.00
episode  2 score 1.67 average score 6.40 epsilon 1.00
episode  3 score 5.67 average score 6.21 epsilon 1.00
episode  4 score 30.27 average score 11.02 epsilon 1.00
episode  5 score 14.26 average score 11.56 epsilon 0.99
episode  6 score 6.02 average score 10.77 epsilon 0.98


KeyboardInterrupt: 

### Visualize

In [None]:
# wont run, have to edit

env = gym.make("highway-v0",render_mode="rgb_array")
# Default Config: Observation: Kinematics, Actions: DiscreteMetaAction
# ACTIONS_ALL = {0: 'LANE_LEFT',1: 'IDLE',2: 'LANE_RIGHT',3: 'FASTER',4: 'SLOWER'}
# Obeservations : Vehicle x y vx vy / first row is always ego vehicle
env = RecordVideo(env, video_folder="./videos", episode_trigger=lambda e: True)
observation, info = env.reset()
done = False
score = 0
while not done:
    state = torch.tensor(np.array(observation.reshape((1,-1)))).to(agent.Q_fn.device) ## may need to edit this
    actions = agent.Q_fn.forward(state)
    action = torch.argmax(actions).item()
    obs, reward, done, truncated, info = env.step(action)
    score += reward
    observation = obs
print('score %.2f' % score)
env.close()
show_videos()

# Intersection Environment

### Train turning left

In [None]:
# Train the attention DQN on intersection-v0, turning left (destination "o3").
INTERSECTION_WEIGHTS = Path("./intersection_weights")
env = gym.make("intersection-v0")
env.unwrapped.config["destination"] = "o3"   ## to decide where to turn; applied on the next reset()

# NOTE(review): intersection-v0 ships its own action config, so the highway-v0
# ACTIONS_ALL mapping does not necessarily apply. Inspect env.unwrapped.action_type
# (e.g. its `.actions` mapping, TODO confirm) for the actual index -> action names.

agent = Agent(input_dim=env.observation_space.shape, embed_dim = 128, output_dim=env.action_space.n)
scores, eps_history = [], []
n_games = 500
if INTERSECTION_WEIGHTS.exists():
  # map_location lets weights saved on a GPU runtime load on a CPU-only one (and vice versa)
  agent.Q_fn.load_state_dict(torch.load(INTERSECTION_WEIGHTS, map_location=agent.Q_fn.device, weights_only=True))
for i in range(n_games):
  done = truncated = False
  score = 0
  observation, info = env.reset()
  # Stop on termination OR truncation (Gymnasium step API), otherwise an episode can
  # run past its time limit.
  while not (done or truncated):
    action = agent.choose_action(observation[0,:],observation)
    observation_, reward, done, truncated, info = env.step(action)
    score += reward
    # Only a true termination zeroes the bootstrap; a truncated step still bootstraps.
    agent.store_transition(observation, action, reward, observation_, done)
    agent.train()
    observation = observation_
  scores.append(score)
  eps_history.append(agent.epsilon)

  avg_score = np.mean(scores[-100:])
  print('episode ', i, 'score %.2f' % score, 'average score %.2f' % avg_score,'epsilon %.2f' % agent.epsilon)

# Save once after training so the load guard above can resume next time.
torch.save(agent.Q_fn.state_dict(), INTERSECTION_WEIGHTS)
env.close()


### Visualize

In [None]:
# Roll out one greedy episode of the trained intersection agent and record it to video.
# Observations: one row per vehicle, and the first row is the ego vehicle.
env = gym.make("intersection-v0", render_mode="rgb_array")
# Use the same destination as in training so the agent faces the task it learned.
env.unwrapped.config["destination"] = "o3"
# A dedicated folder keeps this run's videos separate from the highway ones.
env = RecordVideo(env, video_folder="./videos/intersection", episode_trigger=lambda e: True)
observation, info = env.reset()
done = truncated = False
score = 0
while not (done or truncated):
  # AttentionDQN.forward takes (ego, vehicles), each with a leading batch axis of 1.
  ego = torch.as_tensor(observation[0, :], dtype=torch.float32, device=agent.Q_fn.device).unsqueeze(0)
  vehicles = torch.as_tensor(observation, dtype=torch.float32, device=agent.Q_fn.device).unsqueeze(0)
  with torch.no_grad():
    q_values, _ = agent.Q_fn(ego, vehicles)
  action = torch.argmax(q_values).item()
  observation, reward, done, truncated, info = env.step(action)
  score += reward
print('score %.2f' % score)
env.close()
show_videos("./videos/intersection")