In [1]:
import random
from collections import deque
import torch
from env import HighwayEnv, convert_highd_sample_to_gail_expert
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from buffer import HighwayEnvMemoryBuffer

In [9]:
tracks = pd.read_csv(r"D:\Productivity\Projects\High-D\highd-dataset-v1.0\data\26_tracks.csv")
samples = tracks[(tracks.frame>=300*25)&(tracks.frame<=480*25)]
samples.to_csv("./data/26_sample_tracks.csv")

In [10]:
buffer = HighwayEnvMemoryBuffer(300)

expert_data, df = convert_highd_sample_to_gail_expert(
    sample_csv=r"./data/26_sample_tracks.csv",
    meta_csv=r"D:\Productivity\Projects\High-D\highd-dataset-v1.0\data\26_recordingMeta.csv",
    forward=False,
    p_agent=0.90
)

In [11]:
NUM_STEPS = 300

In [12]:
# Create the environment and set expert data.
env = HighwayEnv(dt=0.2, T=50, generation_mode=False, demo_mode=True)
# Uncomment and update the following line when expert_data is available:
env.set_expert_data(expert_data)


In [13]:
# 
obs = env.reset() 
#
for step in trange(NUM_STEPS):
    # For demonstration, sample random actions for each vehicle slot.
    # Action shape: (N_max, 2). Since have set DEMO_MODE=TRUE, the actions will be overwritten by the agents
    action = torch.full((env.N_max, 2), 0.0)
    # Step the environment: we get new observation, reward, done, and info.
    next_obs, reward, done, info = env.step(action)

    buffer.push(obs, action, reward, next_obs, done)


100%|██████████| 300/300 [00:04<00:00, 65.83it/s]


In [14]:
obs, action, reward, next_obs, done = buffer.sample(64)

In [15]:
obs[0].shape

torch.Size([64, 50, 100, 7])

In [16]:
obs[1].shape

torch.Size([64, 10])

In [17]:
obs[2].shape

torch.Size([64, 2])

In [18]:
obs[3].shape

torch.Size([64, 100])

In [61]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import torch.optim as optim

In [51]:
class ActorNetwork(nn.Module):
    def __init__(self, T, M, F, lstm_hidden=64, global_dim=12, combined_hidden=64, output_size=2):
        """
        Actor network for generating vehicle acceleration decisions.
        
        Parameters:
          T (int): History length (number of time steps).
          M (int): Maximum number of vehicles per sample.
          F (int): Number of features per vehicle per timestep (e.g. [x, y, xVelocity, yVelocity]).
          lstm_hidden (int): Hidden dimension for the LSTM.
          global_dim (int): Dimension for processing the global inputs (lane markers + road boundaries).
          combined_hidden (int): Hidden dimension in the combined FC layers.
          output_size (int): Dimension of the acceleration output (typically 2).
        """
        super(ActorNetwork, self).__init__()
        self.T = T
        self.M = M
        self.F = F
        self.lstm_hidden = lstm_hidden
        
        # LSTM to process each vehicle's time-dependent sequence (shape: (T, F)).
        self.lstm = nn.LSTM(input_size=F, hidden_size=lstm_hidden, batch_first=True)
        
        # Global information comes from lane markers (10) and road boundaries (2) => 12 values.
        self.global_fc = nn.Linear(12, global_dim)
        
        # Combined fully connected layers mapping concatenated per-vehicle and global features to the output.
        self.combine_fc = nn.Sequential(
            nn.Linear(lstm_hidden + global_dim, combined_hidden),
            nn.ReLU(),
            nn.Linear(combined_hidden, output_size)
        )
    
    def forward(self, time_dep, lane_markers, boundaries):
        """
        Forward pass.
        
        Parameters:
          time_dep (torch.Tensor): Time-dependent kinematics, shape (N, T, M, F).
            (Missing entries are NaN; they are replaced by 0 before processing.)
          lane_markers (torch.Tensor): Global lane markers, shape (N, 10). May contain NaNs.
          boundaries (torch.Tensor): Global road boundaries, shape (N, 2). May contain NaNs.
          
        Returns:
          accelerations (torch.Tensor): Output accelerations of shape (N, M, 2).
        """
        N, T, M, F = time_dep.shape  # Unpack dimensions.
        
        # --- Process Time-Dependent Kinematics ---
        # Rearrange to shape (N, M, T, F), then flatten the N and M dimensions: (N*M, T, F).
        time_dep = time_dep.permute(0, 2, 1, 3).contiguous()
        time_dep = time_dep.view(N * M, T, F)
        
        # Compute valid mask for each sequence (valid if not all features are NaN)
        valid_mask = ~torch.all(torch.isnan(time_dep), dim=-1)  # shape: (N*M, T)
        seq_lengths = valid_mask.sum(dim=1)  # shape: (N*M,)
        # Ensure a minimum length of 1 for any sequence.
        seq_lengths[seq_lengths == 0] = 1
        
        # Replace NaN values in the kinematics with 0.
        time_dep_clean = torch.nan_to_num(time_dep, nan=0.0)
        
        # Pack the sequence so that the LSTM ignores padded time steps.
        packed = rnn_utils.pack_padded_sequence(time_dep_clean, lengths=seq_lengths.cpu(),
                                                  batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(packed)
        # For a single-layer LSTM, hn has shape (1, N*M, lstm_hidden)
        hn = hn.squeeze(0)  # Now shape: (N*M, lstm_hidden)
        
        # Reshape to get per-vehicle representation: (N, M, lstm_hidden)
        vehicle_repr = hn.view(N, M, self.lstm_hidden)
        
        # --- Process Global Information (lane markers and boundaries) ---
        # Before processing, replace NaN values in lane markers and boundaries with 0.
        # Alternatively, you might choose to use a meaningful default.
        lane_markers_clean = torch.nan_to_num(lane_markers, nan=0.0)  # shape: (N, 10)
        boundaries_clean = torch.nan_to_num(boundaries, nan=0.0)        # shape: (N, 2)
        # Concatenate lane markers and boundaries into a single tensor: shape: (N, 12)
        global_input = torch.cat([lane_markers_clean, boundaries_clean], dim=1)
        global_info = torch.relu(self.global_fc(global_input))  # shape: (N, global_dim)
        # Broadcast to each vehicle slot: shape becomes (N, M, global_dim)
        global_info = global_info.unsqueeze(1).expand(-1, M, -1)
        
        # --- Combine and Generate Accelerations ---
        # Concatenate per-vehicle representation with global info along the feature dimension.
        combined = torch.cat([vehicle_repr, global_info], dim=-1)  # shape: (N, M, lstm_hidden + global_dim)
        # Pass through FC layers to generate acceleration outputs.
        accelerations = self.combine_fc(combined)  # shape: (N, M, output_size)
        
        # Note: even though accelerations has shape (N, M, 2), downstream you should use the agent mask to choose only valid outputs.
        return accelerations

In [62]:
class ActorNetwork_(nn.Module):
    def __init__(self, T, M, F, lstm_hidden=64, global_dim=12, combined_hidden=64, output_size=2):
        """
        Actor network for generating vehicle acceleration decisions.
        
        Parameters:
          T (int): History length (number of time steps).
          M (int): Maximum number of vehicles per sample.
          F (int): Number of features per vehicle per timestep (e.g. [x, y, xVelocity, yVelocity]).
          lstm_hidden (int): Hidden dimension for the LSTM.
          global_dim (int): Dimension for processing the global inputs (lane markers + road boundaries).
          combined_hidden (int): Hidden dimension in the combined FC layers.
          output_size (int): Dimension of the acceleration output (typically 2).
        """
        super(ActorNetwork_, self).__init__()
        self.T = T
        self.M = M
        self.F = F
        self.lstm_hidden = lstm_hidden
        
        # LSTM to process each vehicle's time-dependent sequence (shape: (T, F)).
        self.lstm = nn.LSTM(input_size=F, hidden_size=lstm_hidden, batch_first=True)
        
        # Global information comes from lane markers (10) and road boundaries (2) => 12 values.
        # We'll process the lane markers after masking them.
        self.global_fc = nn.Linear(10, global_dim)  # process lane markers separately
        self.boundary_fc = nn.Linear(2, global_dim)   # process boundaries separately
        
        # Combine global features (concatenated) and map to final global representation.
        self.global_combine_fc = nn.Linear(2 * global_dim, global_dim)
        
        # Combined fully connected layers mapping concatenated per-vehicle and global features to the output.
        self.combine_fc = nn.Sequential(
            nn.Linear(lstm_hidden + global_dim, combined_hidden),
            nn.ReLU(),
            nn.Linear(combined_hidden, output_size)
        )
    
    def forward(self, time_dep, lane_markers, boundaries):
        """
        Forward pass.
        
        Parameters:
          time_dep (torch.Tensor): Time-dependent kinematics, shape (N, T, M, F).
            (Missing entries are NaN; they are replaced by 0 before processing.)
          lane_markers (torch.Tensor): Global lane markers, shape (N, 10). May contain NaNs.
          boundaries (torch.Tensor): Global road boundaries, shape (N, 2). (Assumed to be complete.)
          
        Returns:
          accelerations (torch.Tensor): Output accelerations of shape (N, M, 2).
        """
        N, T, M, F = time_dep.shape  # Unpack dimensions
        
        # --- Process Time-Dependent Kinematics ---
        # Rearrange to shape (N, M, T, F) and flatten to (N*M, T, F)
        time_dep = time_dep.permute(0, 2, 1, 3).contiguous().view(N * M, T, F)
        
        # Compute valid mask: a time step is valid if not all F values are NaN.
        valid_mask = ~torch.all(torch.isnan(time_dep), dim=-1)  # (N*M, T)
        seq_lengths = valid_mask.sum(dim=1)  # (N*M,)
        seq_lengths[seq_lengths == 0] = 1
        
        # Replace NaNs with 0 in the time-dependent input.
        time_dep_clean = torch.nan_to_num(time_dep, nan=0.0)
        
        # Pack padded sequence and process with LSTM.
        packed = rnn_utils.pack_padded_sequence(time_dep_clean, lengths=seq_lengths.cpu(),
                                                  batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(packed)
        # hn: (num_layers, N*M, lstm_hidden) → assume single layer and squeeze:
        hn = hn.squeeze(0)  # (N*M, lstm_hidden)
        
        # Reshape vehicle representation to (N, M, lstm_hidden)
        vehicle_repr = hn.view(N, M, self.lstm_hidden)
        
        # --- Process Global Information with Masking for Lane Markers ---
        # lane_markers: (N, 10)
        # Create a mask for valid lane markers: 1 where not NaN, 0 where NaN.
        lane_mask = (~torch.isnan(lane_markers)).float()  # (N, 10)
        # Replace NaNs with 0.
        lane_markers_clean = torch.nan_to_num(lane_markers, nan=0.0)
        # Process lane markers via a linear layer.
        global_lane_features = torch.relu(self.global_fc(lane_markers_clean))  # (N, global_dim)
        # Scale lane features by the fraction of valid markers.
        valid_ratio = lane_mask.mean(dim=1, keepdim=True)  # (N, 1)
        global_lane_features = global_lane_features * valid_ratio
        
        # Process boundaries via another linear layer.
        global_boundaries_features = torch.relu(self.boundary_fc(boundaries))  # (N, global_dim)
        
        # Combine the two global features.
        global_combined = torch.cat([global_lane_features, global_boundaries_features], dim=1)  # (N, 2*global_dim)
        global_info = torch.relu(self.global_combine_fc(global_combined))  # (N, global_dim)
        # Broadcast global_info to each vehicle slot: (N, M, global_dim)
        global_info = global_info.unsqueeze(1).expand(-1, M, -1)
        
        # --- Combine Per-Vehicle and Global Representations ---
        combined = torch.cat([vehicle_repr, global_info], dim=-1)  # (N, M, lstm_hidden+global_dim)
        accelerations = self.combine_fc(combined)  # (N, M, output_size) (typically 2: [xAccel, yAccel])
        
        return accelerations

In [60]:
actor = ActorNetwork_(50, 100, 7)

In [64]:
time_dep, lane_markers, boundaries = obs[:3]

In [65]:
actor.train()  # set to training mode

# Define an optimizer.
optimizer = optim.Adam(actor.parameters(), lr=1e-3)

# Training loop: We'll do a few iterations updating the network using a fake loss defined as the mean absolute acceleration.
num_iterations = 1000
for i in trange(num_iterations):
    optimizer.zero_grad()
    # Forward pass.
    accel_output = actor(time_dep, lane_markers, boundaries)  # shape (N, M, 2)
    # Compute a fake loss: mean of the absolute acceleration values.
    loss = torch.abs(accel_output).mean()
    loss.backward()
    optimizer.step()
    
    if (i+1) % 10 == 0:
        print(f"Iteration {i+1}/{num_iterations}, Loss: {loss.item():.4f}")

  2%|▎         | 25/1000 [00:00<00:12, 78.60it/s]

Iteration 10/1000, Loss: 0.0122
Iteration 20/1000, Loss: 0.0070


  4%|▍         | 43/1000 [00:00<00:11, 83.04it/s]

Iteration 30/1000, Loss: 0.0072
Iteration 40/1000, Loss: 0.0029


  6%|▌         | 62/1000 [00:00<00:11, 84.38it/s]

Iteration 50/1000, Loss: 0.0030
Iteration 60/1000, Loss: 0.0014


  8%|▊         | 80/1000 [00:00<00:10, 86.72it/s]

Iteration 70/1000, Loss: 0.0017
Iteration 80/1000, Loss: 0.0026


 10%|▉         | 98/1000 [00:01<00:10, 86.49it/s]

Iteration 90/1000, Loss: 0.0005
Iteration 100/1000, Loss: 0.0011


 13%|█▎        | 126/1000 [00:01<00:09, 89.16it/s]

Iteration 110/1000, Loss: 0.0012
Iteration 120/1000, Loss: 0.0013


 14%|█▍        | 144/1000 [00:01<00:09, 88.34it/s]

Iteration 130/1000, Loss: 0.0005
Iteration 140/1000, Loss: 0.0013


 16%|█▋        | 164/1000 [00:01<00:09, 89.76it/s]

Iteration 150/1000, Loss: 0.0019
Iteration 160/1000, Loss: 0.0007


 18%|█▊        | 183/1000 [00:02<00:08, 91.33it/s]

Iteration 170/1000, Loss: 0.0016
Iteration 180/1000, Loss: 0.0011
Iteration 190/1000, Loss: 0.0006


 21%|██▏       | 213/1000 [00:02<00:08, 90.83it/s]

Iteration 200/1000, Loss: 0.0020
Iteration 210/1000, Loss: 0.0023


 23%|██▎       | 233/1000 [00:02<00:08, 90.39it/s]

Iteration 220/1000, Loss: 0.0003
Iteration 230/1000, Loss: 0.0011


 25%|██▌       | 253/1000 [00:02<00:08, 89.92it/s]

Iteration 240/1000, Loss: 0.0002
Iteration 250/1000, Loss: 0.0022
Iteration 260/1000, Loss: 0.0021


 28%|██▊       | 283/1000 [00:03<00:07, 90.51it/s]

Iteration 270/1000, Loss: 0.0025
Iteration 280/1000, Loss: 0.0019


 30%|███       | 302/1000 [00:03<00:07, 88.97it/s]

Iteration 290/1000, Loss: 0.0004
Iteration 300/1000, Loss: 0.0013


 32%|███▏      | 322/1000 [00:03<00:07, 91.75it/s]

Iteration 310/1000, Loss: 0.0015
Iteration 320/1000, Loss: 0.0012


 34%|███▍      | 342/1000 [00:03<00:07, 91.46it/s]

Iteration 330/1000, Loss: 0.0030
Iteration 340/1000, Loss: 0.0010


 36%|███▌      | 362/1000 [00:04<00:06, 93.81it/s]

Iteration 350/1000, Loss: 0.0027
Iteration 360/1000, Loss: 0.0036


 38%|███▊      | 382/1000 [00:04<00:06, 94.54it/s]

Iteration 370/1000, Loss: 0.0007
Iteration 380/1000, Loss: 0.0023


 40%|████      | 402/1000 [00:04<00:06, 92.31it/s]

Iteration 390/1000, Loss: 0.0036
Iteration 400/1000, Loss: 0.0022


 42%|████▏     | 422/1000 [00:04<00:06, 90.83it/s]

Iteration 410/1000, Loss: 0.0015
Iteration 420/1000, Loss: 0.0012


 44%|████▍     | 442/1000 [00:04<00:05, 93.09it/s]

Iteration 430/1000, Loss: 0.0009
Iteration 440/1000, Loss: 0.0032


 46%|████▌     | 462/1000 [00:05<00:05, 91.65it/s]

Iteration 450/1000, Loss: 0.0021
Iteration 460/1000, Loss: 0.0021


 48%|████▊     | 482/1000 [00:05<00:05, 90.70it/s]

Iteration 470/1000, Loss: 0.0019
Iteration 480/1000, Loss: 0.0021


 50%|█████     | 502/1000 [00:05<00:05, 91.22it/s]

Iteration 490/1000, Loss: 0.0017
Iteration 500/1000, Loss: 0.0017


 52%|█████▏    | 522/1000 [00:05<00:05, 92.39it/s]

Iteration 510/1000, Loss: 0.0025
Iteration 520/1000, Loss: 0.0013
Iteration 530/1000, Loss: 0.0013


 55%|█████▌    | 552/1000 [00:06<00:04, 93.05it/s]

Iteration 540/1000, Loss: 0.0012
Iteration 550/1000, Loss: 0.0019


 57%|█████▋    | 572/1000 [00:06<00:04, 87.97it/s]

Iteration 560/1000, Loss: 0.0003
Iteration 570/1000, Loss: 0.0007


 59%|█████▉    | 593/1000 [00:06<00:04, 92.47it/s]

Iteration 580/1000, Loss: 0.0012
Iteration 590/1000, Loss: 0.0011


 61%|██████▏   | 613/1000 [00:06<00:04, 93.22it/s]

Iteration 600/1000, Loss: 0.0002
Iteration 610/1000, Loss: 0.0009


 63%|██████▎   | 633/1000 [00:07<00:04, 91.09it/s]

Iteration 620/1000, Loss: 0.0020
Iteration 630/1000, Loss: 0.0003


 65%|██████▌   | 653/1000 [00:07<00:03, 90.90it/s]

Iteration 640/1000, Loss: 0.0003
Iteration 650/1000, Loss: 0.0005


 67%|██████▋   | 673/1000 [00:07<00:03, 90.57it/s]

Iteration 660/1000, Loss: 0.0006
Iteration 670/1000, Loss: 0.0011


 69%|██████▉   | 693/1000 [00:07<00:03, 91.72it/s]

Iteration 680/1000, Loss: 0.0015
Iteration 690/1000, Loss: 0.0009
Iteration 700/1000, Loss: 0.0020


 72%|███████▏  | 723/1000 [00:08<00:03, 91.49it/s]

Iteration 710/1000, Loss: 0.0013
Iteration 720/1000, Loss: 0.0033


 74%|███████▍  | 743/1000 [00:08<00:02, 91.99it/s]

Iteration 730/1000, Loss: 0.0015
Iteration 740/1000, Loss: 0.0008


 76%|███████▌  | 762/1000 [00:08<00:02, 88.52it/s]

Iteration 750/1000, Loss: 0.0011
Iteration 760/1000, Loss: 0.0008


 78%|███████▊  | 781/1000 [00:08<00:02, 89.35it/s]

Iteration 770/1000, Loss: 0.0007
Iteration 780/1000, Loss: 0.0013


 80%|████████  | 801/1000 [00:08<00:02, 91.36it/s]

Iteration 790/1000, Loss: 0.0015
Iteration 800/1000, Loss: 0.0008
Iteration 810/1000, Loss: 0.0002


 83%|████████▎ | 831/1000 [00:09<00:01, 92.02it/s]

Iteration 820/1000, Loss: 0.0006
Iteration 830/1000, Loss: 0.0009


 85%|████████▌ | 851/1000 [00:09<00:01, 93.54it/s]

Iteration 840/1000, Loss: 0.0010
Iteration 850/1000, Loss: 0.0016


 87%|████████▋ | 871/1000 [00:09<00:01, 90.50it/s]

Iteration 860/1000, Loss: 0.0013
Iteration 870/1000, Loss: 0.0009


 89%|████████▉ | 891/1000 [00:09<00:01, 89.43it/s]

Iteration 880/1000, Loss: 0.0014
Iteration 890/1000, Loss: 0.0003


 91%|█████████ | 911/1000 [00:10<00:00, 90.18it/s]

Iteration 900/1000, Loss: 0.0012
Iteration 910/1000, Loss: 0.0003
Iteration 920/1000, Loss: 0.0012


 94%|█████████▍| 941/1000 [00:10<00:00, 89.83it/s]

Iteration 930/1000, Loss: 0.0009
Iteration 940/1000, Loss: 0.0019


 96%|█████████▌| 961/1000 [00:10<00:00, 89.86it/s]

Iteration 950/1000, Loss: 0.0023
Iteration 960/1000, Loss: 0.0024


 98%|█████████▊| 981/1000 [00:10<00:00, 91.13it/s]

Iteration 970/1000, Loss: 0.0006
Iteration 980/1000, Loss: 0.0012


100%|██████████| 1000/1000 [00:11<00:00, 90.38it/s]

Iteration 990/1000, Loss: 0.0003
Iteration 1000/1000, Loss: 0.0014





In [58]:
actor(obs[0], obs[1], obs[2])

tensor([[[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        [[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        [[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        ...,

        [[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        [[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        [[-0.0430, -0.0210],
       

In [7]:
batch = random.sample(buffer.buffer, 64)
states, actions, rewards, next_states, dones = zip(*batch)

In [11]:
state_keys=['time_dependent','time_independent', 'lane_markers', 'boundary_lines', 'agent_mask']

In [17]:
[ torch.stack([obs[key] for obs in states]) for key in state_keys ][2].shape

torch.Size([64, 2])

In [None]:
actions

AttributeError: 'tuple' object has no attribute 'shape'