In [1]:
import random
from collections import deque
import torch
from env import HighwayEnv, convert_highd_sample_to_gail_expert
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from buffer import HighwayEnvMemoryBuffer

In [2]:
buffer = HighwayEnvMemoryBuffer(300)

expert_data, df = convert_highd_sample_to_gail_expert(
    sample_csv=r"./data/26_sample_tracks.csv",
    meta_csv=r"E:\Data\highd-dataset-v1.0\data\26_recordingMeta.csv",
    forward=False,
    p_agent=0.90
)

In [3]:
NUM_STEPS = 300

In [4]:
# Create the environment and set expert data.
env = HighwayEnv(dt=0.2, T=50, generation_mode=False, demo_mode=True)
# Uncomment and update the following line when expert_data is available:
env.set_expert_data(expert_data)


In [12]:
# 
obs = env.reset() 
#
for step in trange(NUM_STEPS):
    # For demonstration, sample random actions for each vehicle slot.
    # Action shape: (N_max, 2). Since have set DEMO_MODE=TRUE, the actions will be overwritten by the agents
    action = torch.full((env.N_max, 2), 0.0)
    log_prob = torch.full((env.N_max, 2), 1.0)
    # Step the environment: we get new observation, reward, done, and info.
    next_obs, reward, done, info = env.step(action)
    buffer.push(obs, action, reward, log_prob, next_obs, done)


100%|██████████| 300/300 [00:04<00:00, 69.18it/s]


In [13]:
obs, action, reward, log_prob, next_obs, done = buffer.sample(64)

In [9]:
obs[0].shape

torch.Size([64, 50, 100, 7])

In [10]:
obs[1].shape

torch.Size([64, 10])

In [11]:
obs[2].shape

torch.Size([64, 2])

In [12]:
obs[3].shape

torch.Size([64, 100])

In [13]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import torch.optim as optim

In [14]:
class ActorNetwork(nn.Module):
    def __init__(self, T, M, F, lstm_hidden=64, global_dim=12, combined_hidden=64, output_size=2):
        """
        Actor network for generating vehicle acceleration decisions.
        
        Parameters:
          T (int): History length (number of time steps).
          M (int): Maximum number of vehicles per sample.
          F (int): Number of features per vehicle per timestep (e.g. [x, y, xVelocity, yVelocity]).
          lstm_hidden (int): Hidden dimension for the LSTM.
          global_dim (int): Dimension for processing the global inputs (lane markers + road boundaries).
          combined_hidden (int): Hidden dimension in the combined FC layers.
          output_size (int): Dimension of the acceleration output (typically 2).
        """
        super(ActorNetwork, self).__init__()
        self.T = T
        self.M = M
        self.F = F
        self.lstm_hidden = lstm_hidden
        
        # LSTM to process each vehicle's time-dependent sequence (shape: (T, F)).
        self.lstm = nn.LSTM(input_size=F, hidden_size=lstm_hidden, batch_first=True)
        
        # Global information comes from lane markers (10) and road boundaries (2) => 12 values.
        self.global_fc = nn.Linear(12, global_dim)
        
        # Combined fully connected layers mapping concatenated per-vehicle and global features to the output.
        self.combine_fc = nn.Sequential(
            nn.Linear(lstm_hidden + global_dim, combined_hidden),
            nn.ReLU(),
            nn.Linear(combined_hidden, output_size)
        )
    
    def forward(self, time_dep, lane_markers, boundaries):
        """
        Forward pass.
        
        Parameters:
          time_dep (torch.Tensor): Time-dependent kinematics, shape (N, T, M, F).
            (Missing entries are NaN; they are replaced by 0 before processing.)
          lane_markers (torch.Tensor): Global lane markers, shape (N, 10). May contain NaNs.
          boundaries (torch.Tensor): Global road boundaries, shape (N, 2). May contain NaNs.
          
        Returns:
          accelerations (torch.Tensor): Output accelerations of shape (N, M, 2).
        """
        N, T, M, F = time_dep.shape  # Unpack dimensions.
        
        # --- Process Time-Dependent Kinematics ---
        # Rearrange to shape (N, M, T, F), then flatten the N and M dimensions: (N*M, T, F).
        time_dep = time_dep.permute(0, 2, 1, 3).contiguous()
        time_dep = time_dep.view(N * M, T, F)
        
        # Compute valid mask for each sequence (valid if not all features are NaN)
        valid_mask = ~torch.all(torch.isnan(time_dep), dim=-1)  # shape: (N*M, T)
        seq_lengths = valid_mask.sum(dim=1)  # shape: (N*M,)
        # Ensure a minimum length of 1 for any sequence.
        seq_lengths[seq_lengths == 0] = 1
        
        # Replace NaN values in the kinematics with 0.
        time_dep_clean = torch.nan_to_num(time_dep, nan=0.0)
        
        # Pack the sequence so that the LSTM ignores padded time steps.
        packed = rnn_utils.pack_padded_sequence(time_dep_clean, lengths=seq_lengths.cpu(),
                                                  batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(packed)
        # For a single-layer LSTM, hn has shape (1, N*M, lstm_hidden)
        hn = hn.squeeze(0)  # Now shape: (N*M, lstm_hidden)
        
        # Reshape to get per-vehicle representation: (N, M, lstm_hidden)
        vehicle_repr = hn.view(N, M, self.lstm_hidden)
        
        # --- Process Global Information (lane markers and boundaries) ---
        # Before processing, replace NaN values in lane markers and boundaries with 0.
        # Alternatively, you might choose to use a meaningful default.
        lane_markers_clean = torch.nan_to_num(lane_markers, nan=0.0)  # shape: (N, 10)
        boundaries_clean = torch.nan_to_num(boundaries, nan=0.0)        # shape: (N, 2)
        # Concatenate lane markers and boundaries into a single tensor: shape: (N, 12)
        global_input = torch.cat([lane_markers_clean, boundaries_clean], dim=1)
        global_info = torch.relu(self.global_fc(global_input))  # shape: (N, global_dim)
        # Broadcast to each vehicle slot: shape becomes (N, M, global_dim)
        global_info = global_info.unsqueeze(1).expand(-1, M, -1)
        
        # --- Combine and Generate Accelerations ---
        # Concatenate per-vehicle representation with global info along the feature dimension.
        combined = torch.cat([vehicle_repr, global_info], dim=-1)  # shape: (N, M, lstm_hidden + global_dim)
        # Pass through FC layers to generate acceleration outputs.
        accelerations = self.combine_fc(combined)  # shape: (N, M, output_size)
        
        # Note: even though accelerations has shape (N, M, 2), downstream you should use the agent mask to choose only valid outputs.
        return accelerations

In [15]:
class ActorNetwork_(nn.Module):
    def __init__(self, T, M, F, lstm_hidden=64, global_dim=12, combined_hidden=64, output_size=2):
        """
        Actor network for generating vehicle acceleration decisions.
        
        Parameters:
          T (int): History length (number of time steps).
          M (int): Maximum number of vehicles per sample.
          F (int): Number of features per vehicle per timestep (e.g. [x, y, xVelocity, yVelocity]).
          lstm_hidden (int): Hidden dimension for the LSTM.
          global_dim (int): Dimension for processing the global inputs (lane markers + road boundaries).
          combined_hidden (int): Hidden dimension in the combined FC layers.
          output_size (int): Dimension of the acceleration output (typically 2).
        """
        super(ActorNetwork_, self).__init__()
        self.T = T
        self.M = M
        self.F = F
        self.lstm_hidden = lstm_hidden
        
        # LSTM to process each vehicle's time-dependent sequence (shape: (T, F)).
        self.lstm = nn.LSTM(input_size=F, hidden_size=lstm_hidden, batch_first=True)
        
        # Global information comes from lane markers (10) and road boundaries (2) => 12 values.
        # We'll process the lane markers after masking them.
        self.global_fc = nn.Linear(10, global_dim)  # process lane markers separately
        self.boundary_fc = nn.Linear(2, global_dim)   # process boundaries separately
        
        # Combine global features (concatenated) and map to final global representation.
        self.global_combine_fc = nn.Linear(2 * global_dim, global_dim)
        
        # Combined fully connected layers mapping concatenated per-vehicle and global features to the output.
        self.combine_fc = nn.Sequential(
            nn.Linear(lstm_hidden + global_dim, combined_hidden),
            nn.ReLU(),
            nn.Linear(combined_hidden, output_size)
        )
    
    def forward(self, time_dep, lane_markers, boundaries):
        """
        Forward pass.
        
        Parameters:
          time_dep (torch.Tensor): Time-dependent kinematics, shape (N, T, M, F).
            (Missing entries are NaN; they are replaced by 0 before processing.)
          lane_markers (torch.Tensor): Global lane markers, shape (N, 10). May contain NaNs.
          boundaries (torch.Tensor): Global road boundaries, shape (N, 2). (Assumed to be complete.)
          
        Returns:
          accelerations (torch.Tensor): Output accelerations of shape (N, M, 2).
        """
        N, T, M, F = time_dep.shape  # Unpack dimensions
        
        # --- Process Time-Dependent Kinematics ---
        # Rearrange to shape (N, M, T, F) and flatten to (N*M, T, F)
        time_dep = time_dep.permute(0, 2, 1, 3).contiguous().view(N * M, T, F)
        
        # Compute valid mask: a time step is valid if not all F values are NaN.
        valid_mask = ~torch.all(torch.isnan(time_dep), dim=-1)  # (N*M, T)
        seq_lengths = valid_mask.sum(dim=1)  # (N*M,)
        seq_lengths[seq_lengths == 0] = 1
        
        # Replace NaNs with 0 in the time-dependent input.
        time_dep_clean = torch.nan_to_num(time_dep, nan=0.0)
        
        # Pack padded sequence and process with LSTM.
        packed = rnn_utils.pack_padded_sequence(time_dep_clean, lengths=seq_lengths.cpu(),
                                                  batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(packed)
        # hn: (num_layers, N*M, lstm_hidden) → assume single layer and squeeze:
        hn = hn.squeeze(0)  # (N*M, lstm_hidden)
        
        # Reshape vehicle representation to (N, M, lstm_hidden)
        vehicle_repr = hn.view(N, M, self.lstm_hidden)
        
        # --- Process Global Information with Masking for Lane Markers ---
        # lane_markers: (N, 10)
        # Create a mask for valid lane markers: 1 where not NaN, 0 where NaN.
        lane_mask = (~torch.isnan(lane_markers)).float()  # (N, 10)
        # Replace NaNs with 0.
        lane_markers_clean = torch.nan_to_num(lane_markers, nan=0.0)
        # Process lane markers via a linear layer.
        global_lane_features = torch.relu(self.global_fc(lane_markers_clean))  # (N, global_dim)
        # Scale lane features by the fraction of valid markers.
        valid_ratio = lane_mask.mean(dim=1, keepdim=True)  # (N, 1)
        global_lane_features = global_lane_features * valid_ratio
        
        # Process boundaries via another linear layer.
        global_boundaries_features = torch.relu(self.boundary_fc(boundaries))  # (N, global_dim)
        
        # Combine the two global features.
        global_combined = torch.cat([global_lane_features, global_boundaries_features], dim=1)  # (N, 2*global_dim)
        global_info = torch.relu(self.global_combine_fc(global_combined))  # (N, global_dim)
        # Broadcast global_info to each vehicle slot: (N, M, global_dim)
        global_info = global_info.unsqueeze(1).expand(-1, M, -1)
        
        # --- Combine Per-Vehicle and Global Representations ---
        combined = torch.cat([vehicle_repr, global_info], dim=-1)  # (N, M, lstm_hidden+global_dim)
        accelerations = self.combine_fc(combined)  # (N, M, output_size) (typically 2: [xAccel, yAccel])
        
        return accelerations

In [16]:
actor = ActorNetwork_(50, 100, 7)

In [17]:
time_dep, lane_markers, boundaries = obs[:3]

In [18]:
actor.train()  # set to training mode

# Define an optimizer.
optimizer = optim.Adam(actor.parameters(), lr=1e-3)

# Training loop: We'll do a few iterations updating the network using a fake loss defined as the mean absolute acceleration.
num_iterations = 1000
for i in trange(num_iterations):
    optimizer.zero_grad()
    # Forward pass.
    accel_output = actor(time_dep, lane_markers, boundaries)  # shape (N, M, 2)
    # Compute a fake loss: mean of the absolute acceleration values.
    loss = torch.abs(accel_output).mean()
    loss.backward()
    optimizer.step()
    
    if (i+1) % 10 == 0:
        print(f"Iteration {i+1}/{num_iterations}, Loss: {loss.item():.4f}")

  2%|▏         | 16/1000 [00:00<00:17, 55.69it/s]

Iteration 10/1000, Loss: 0.0218
Iteration 20/1000, Loss: 0.0060


  4%|▎         | 37/1000 [00:00<00:15, 61.32it/s]

Iteration 30/1000, Loss: 0.0033
Iteration 40/1000, Loss: 0.0017


  6%|▌         | 58/1000 [00:00<00:15, 61.94it/s]

Iteration 50/1000, Loss: 0.0019
Iteration 60/1000, Loss: 0.0014


  8%|▊         | 79/1000 [00:01<00:14, 62.62it/s]

Iteration 70/1000, Loss: 0.0022
Iteration 80/1000, Loss: 0.0010


 10%|█         | 100/1000 [00:01<00:14, 63.21it/s]

Iteration 90/1000, Loss: 0.0013
Iteration 100/1000, Loss: 0.0015


 12%|█▏        | 121/1000 [00:01<00:13, 64.59it/s]

Iteration 110/1000, Loss: 0.0001
Iteration 120/1000, Loss: 0.0022


 14%|█▎        | 135/1000 [00:02<00:13, 62.25it/s]

Iteration 130/1000, Loss: 0.0014
Iteration 140/1000, Loss: 0.0025


 16%|█▌        | 156/1000 [00:02<00:14, 59.90it/s]

Iteration 150/1000, Loss: 0.0008
Iteration 160/1000, Loss: 0.0008


 18%|█▊        | 180/1000 [00:02<00:13, 59.28it/s]

Iteration 170/1000, Loss: 0.0007
Iteration 180/1000, Loss: 0.0006


 20%|██        | 200/1000 [00:03<00:13, 57.64it/s]

Iteration 190/1000, Loss: 0.0012
Iteration 200/1000, Loss: 0.0004


 22%|██▏       | 218/1000 [00:03<00:13, 57.01it/s]

Iteration 210/1000, Loss: 0.0016
Iteration 220/1000, Loss: 0.0006


 24%|██▎       | 236/1000 [00:03<00:14, 54.25it/s]

Iteration 230/1000, Loss: 0.0018
Iteration 240/1000, Loss: 0.0011


 26%|██▌       | 256/1000 [00:04<00:12, 58.20it/s]

Iteration 250/1000, Loss: 0.0002
Iteration 260/1000, Loss: 0.0002


 28%|██▊       | 276/1000 [00:04<00:12, 59.60it/s]

Iteration 270/1000, Loss: 0.0014
Iteration 280/1000, Loss: 0.0007


 30%|██▉       | 296/1000 [00:04<00:11, 59.38it/s]

Iteration 290/1000, Loss: 0.0011
Iteration 300/1000, Loss: 0.0013


 32%|███▏      | 316/1000 [00:05<00:11, 59.41it/s]

Iteration 310/1000, Loss: 0.0013
Iteration 320/1000, Loss: 0.0023


 34%|███▎      | 337/1000 [00:05<00:10, 61.94it/s]

Iteration 330/1000, Loss: 0.0003
Iteration 340/1000, Loss: 0.0015


 36%|███▌      | 358/1000 [00:05<00:10, 61.90it/s]

Iteration 350/1000, Loss: 0.0007
Iteration 360/1000, Loss: 0.0014


 38%|███▊      | 379/1000 [00:06<00:09, 62.72it/s]

Iteration 370/1000, Loss: 0.0009
Iteration 380/1000, Loss: 0.0013


 40%|████      | 400/1000 [00:06<00:09, 62.43it/s]

Iteration 390/1000, Loss: 0.0014
Iteration 400/1000, Loss: 0.0014


 42%|████▏     | 421/1000 [00:06<00:09, 62.38it/s]

Iteration 410/1000, Loss: 0.0016
Iteration 420/1000, Loss: 0.0009


 44%|████▍     | 442/1000 [00:07<00:08, 62.62it/s]

Iteration 430/1000, Loss: 0.0011
Iteration 440/1000, Loss: 0.0011


 46%|████▌     | 456/1000 [00:07<00:08, 62.89it/s]

Iteration 450/1000, Loss: 0.0014
Iteration 460/1000, Loss: 0.0007


 48%|████▊     | 477/1000 [00:07<00:08, 63.03it/s]

Iteration 470/1000, Loss: 0.0013
Iteration 480/1000, Loss: 0.0011


 50%|████▉     | 498/1000 [00:08<00:08, 61.64it/s]

Iteration 490/1000, Loss: 0.0010
Iteration 500/1000, Loss: 0.0008


 52%|█████▏    | 519/1000 [00:08<00:07, 61.76it/s]

Iteration 510/1000, Loss: 0.0008
Iteration 520/1000, Loss: 0.0009


 54%|█████▍    | 540/1000 [00:08<00:07, 63.46it/s]

Iteration 530/1000, Loss: 0.0013
Iteration 540/1000, Loss: 0.0008


 56%|█████▌    | 561/1000 [00:09<00:06, 63.52it/s]

Iteration 550/1000, Loss: 0.0009
Iteration 560/1000, Loss: 0.0018


 57%|█████▊    | 575/1000 [00:09<00:06, 63.80it/s]

Iteration 570/1000, Loss: 0.0015
Iteration 580/1000, Loss: 0.0007


 60%|█████▉    | 596/1000 [00:09<00:06, 63.63it/s]

Iteration 590/1000, Loss: 0.0007
Iteration 600/1000, Loss: 0.0010


 62%|██████▏   | 617/1000 [00:10<00:05, 64.18it/s]

Iteration 610/1000, Loss: 0.0018
Iteration 620/1000, Loss: 0.0007


 64%|██████▍   | 638/1000 [00:10<00:05, 64.10it/s]

Iteration 630/1000, Loss: 0.0012
Iteration 640/1000, Loss: 0.0008


 66%|██████▌   | 659/1000 [00:10<00:05, 62.59it/s]

Iteration 650/1000, Loss: 0.0009
Iteration 660/1000, Loss: 0.0002


 68%|██████▊   | 680/1000 [00:11<00:05, 63.37it/s]

Iteration 670/1000, Loss: 0.0001
Iteration 680/1000, Loss: 0.0009


 70%|███████   | 701/1000 [00:11<00:04, 62.62it/s]

Iteration 690/1000, Loss: 0.0012
Iteration 700/1000, Loss: 0.0003


 72%|███████▏  | 722/1000 [00:11<00:04, 62.41it/s]

Iteration 710/1000, Loss: 0.0019
Iteration 720/1000, Loss: 0.0023


 74%|███████▎  | 736/1000 [00:11<00:04, 61.10it/s]

Iteration 730/1000, Loss: 0.0004
Iteration 740/1000, Loss: 0.0012


 76%|███████▌  | 757/1000 [00:12<00:03, 62.32it/s]

Iteration 750/1000, Loss: 0.0009
Iteration 760/1000, Loss: 0.0010


 78%|███████▊  | 778/1000 [00:12<00:03, 62.43it/s]

Iteration 770/1000, Loss: 0.0012
Iteration 780/1000, Loss: 0.0016


 80%|███████▉  | 799/1000 [00:13<00:03, 60.76it/s]

Iteration 790/1000, Loss: 0.0015
Iteration 800/1000, Loss: 0.0016


 82%|████████▏ | 820/1000 [00:13<00:02, 60.70it/s]

Iteration 810/1000, Loss: 0.0012
Iteration 820/1000, Loss: 0.0003


 84%|████████▍ | 841/1000 [00:13<00:02, 60.95it/s]

Iteration 830/1000, Loss: 0.0008
Iteration 840/1000, Loss: 0.0007


 86%|████████▌ | 862/1000 [00:14<00:02, 61.91it/s]

Iteration 850/1000, Loss: 0.0004
Iteration 860/1000, Loss: 0.0011


 88%|████████▊ | 876/1000 [00:14<00:02, 61.33it/s]

Iteration 870/1000, Loss: 0.0008
Iteration 880/1000, Loss: 0.0006


 90%|████████▉ | 897/1000 [00:14<00:01, 61.56it/s]

Iteration 890/1000, Loss: 0.0012
Iteration 900/1000, Loss: 0.0008


 92%|█████████▏| 918/1000 [00:14<00:01, 62.33it/s]

Iteration 910/1000, Loss: 0.0009
Iteration 920/1000, Loss: 0.0008


 94%|█████████▍| 939/1000 [00:15<00:00, 62.34it/s]

Iteration 930/1000, Loss: 0.0013
Iteration 940/1000, Loss: 0.0011


 96%|█████████▌| 960/1000 [00:15<00:00, 61.67it/s]

Iteration 950/1000, Loss: 0.0017
Iteration 960/1000, Loss: 0.0007


 98%|█████████▊| 981/1000 [00:15<00:00, 62.86it/s]

Iteration 970/1000, Loss: 0.0005
Iteration 980/1000, Loss: 0.0004


100%|██████████| 1000/1000 [00:16<00:00, 61.46it/s]

Iteration 990/1000, Loss: 0.0006
Iteration 1000/1000, Loss: 0.0011





In [58]:
actor(obs[0], obs[1], obs[2])

tensor([[[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        [[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        [[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        ...,

        [[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        [[-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         ...,
         [-0.0430, -0.0210],
         [-0.0430, -0.0210],
         [-0.0430, -0.0210]],

        [[-0.0430, -0.0210],
       

In [7]:
batch = random.sample(buffer.buffer, 64)
states, actions, rewards, next_states, dones = zip(*batch)

In [11]:
state_keys=['time_dependent','time_independent', 'lane_markers', 'boundary_lines', 'agent_mask']

In [17]:
[ torch.stack([obs[key] for obs in states]) for key in state_keys ][2].shape

torch.Size([64, 2])

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils.rnn as rnn_utils
import numpy as np

#########################################
# Utility function for packing sequences
#########################################
def compute_seq_lengths(time_dep):
    """
    Given time_dep of shape (B, T, F), return a tensor of sequence lengths,
    where a timestep is considered valid if not all features are NaN.
    """
    # valid if not all F values are NaN:
    valid_mask = ~torch.all(torch.isnan(time_dep), dim=-1)  # shape (B, T)
    seq_lengths = valid_mask.sum(dim=1)  # (B,)
    seq_lengths[seq_lengths == 0] = 1  # avoid zeros
    return seq_lengths


In [77]:
#########################################
# Actor Network (Policy)
#########################################
class PPOActor(nn.Module):
    def __init__(self, T, M, F, lstm_hidden=64, global_dim=12, combined_hidden=64, output_size=2):
        """
        Actor network that takes the structured observation and outputs
        a Gaussian distribution over acceleration commands.
        
        Observation:
           - time_dep: (N, T, M, F) 
           - lane_markers: (N, 10)
           - boundary_lines: (N, 2)
        Output:
           - accelerations: (N, M, 2)
        """
        super(PPOActor, self).__init__()
        self.T = T
        self.M = M
        self.F = F
        self.lstm_hidden = lstm_hidden
        
        # LSTM for time-dependent kinematics (for each vehicle's sequence of length T and F features).
        self.lstm = nn.LSTM(input_size=F, hidden_size=lstm_hidden, batch_first=True)
        
        # Global information: process lane markers and boundaries separately.
        self.global_fc = nn.Linear(10, global_dim)  # lane markers of shape (N, 10)
        self.boundary_fc = nn.Linear(2, global_dim)   # boundaries of shape (N, 2)
        self.global_combine_fc = nn.Linear(2 * global_dim, global_dim)
        
        # Fully connected layers that combine per-vehicle features with global info.
        self.combine_fc = nn.Sequential(
            nn.Linear(lstm_hidden + global_dim, combined_hidden),
            nn.ReLU(),
            nn.Linear(combined_hidden, output_size)
        )
        
        # Learnable log standard deviation for the Gaussian distribution over actions.
        # We use a parameter of shape (1, 1, output_size) that will be broadcast.
        self.log_std = nn.Parameter(torch.zeros(1, 1, output_size))
        
    def forward(self, time_dep, lane_markers, boundaries):
        """
        Forward pass.
        
        Parameters:
           time_dep: (N, T, M, F)
           lane_markers: (N, 10)  (may contain NaNs)
           boundaries: (N, 2)
        
        Returns:
           mean: (N, M, 2) — the mean acceleration for each vehicle slot.
           log_std: (N, M, 2) — the log_std, broadcasted along N and M.
        """
        if time_dep.dim() == 3:
            time_dep = time_dep.unsqueeze(0) 
        if lane_markers.dim() == 1:
            lane_markers = lane_markers.unsqueeze(0) 
        if boundaries.dim() == 1:
            boundaries = boundaries.unsqueeze(0) 

        N, T, M, _F = time_dep.shape
        
        # Process time-dependent kinematics:
        # Permute to (N, M, T, F) then flatten N and M to get (N*M, T, F)
        time_dep = time_dep.permute(0, 2, 1, 3).contiguous().view(N * M, T, _F)
        # Replace NaNs with 0:
        time_dep_clean = torch.nan_to_num(time_dep, nan=0.0)
        # Compute sequence lengths per vehicle:
        seq_lengths = compute_seq_lengths(time_dep)
        assert not torch.isnan(time_dep_clean).any()
        # Pack padded sequence:
        packed = rnn_utils.pack_padded_sequence(time_dep_clean, lengths=seq_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(packed)
        # hn: (1, N*M, lstm_hidden) -> squeeze to (N*M, lstm_hidden)
        hn = hn.squeeze(0)
        # Reshape back to (N, M, lstm_hidden)
        vehicle_repr = hn.view(N, M, self.lstm_hidden)
        
        # Process global inputs:
        # For lane markers, replace NaNs with 0 and create a mask.
        lane_markers_clean = torch.nan_to_num(lane_markers, nan=0.0)  # (N, 10)
        lane_mask = (~torch.isnan(lane_markers)).float()  # (N, 10)
        global_lane_features = F.relu(self.global_fc(lane_markers_clean))  # (N, global_dim)
        valid_ratio = lane_mask.mean(dim=1, keepdim=True)  # (N, 1)
        global_lane_features = global_lane_features * valid_ratio
        
        global_boundaries_features = F.relu(self.boundary_fc(boundaries))  # (N, global_dim)
        global_combined = torch.cat([global_lane_features, global_boundaries_features], dim=1)  # (N, 2*global_dim)
        global_info = F.relu(self.global_combine_fc(global_combined))  # (N, global_dim)
        global_info = global_info.unsqueeze(1).expand(-1, M, -1)  # (N, M, global_dim)
        
        # Combine vehicle representation with global info:
        combined = torch.cat([vehicle_repr, global_info], dim=-1)  # (N, M, lstm_hidden + global_dim)
        mean = self.combine_fc(combined)  # (N, M, output_size)
        
        # Broadcast log_std:
        log_std = self.log_std.expand(N, M, -1)  # (N, M, output_size)
        
        return mean, log_std
    
    def get_action(self, time_dep, lane_markers, boundaries, agent_mask):
        """
        Sample actions from the policy distribution.
        
        Returns:
           action: (N, M, output_size)
           log_prob: (N, M, output_size) or summed over output_size per vehicle.
        """
        _agent_mask = agent_mask.unsqueeze(0) if agent_mask.dim() == 1 else agent_mask
        mean, log_std = self.forward(time_dep, lane_markers, boundaries)
        # print (mean, log_std)
        std = torch.exp(log_std)
        # Create a normal distribution per vehicle slot.
        dist = torch.distributions.Normal(mean, std)
        # Sample actions using reparameterization (this allows for differentiable sampling).
        action = dist.rsample()  # shape (N, M, output_size)
        # Compute log probabilities.
        log_prob = dist.log_prob(action)  # shape (N, M, output_size)
        log_prob = log_prob.sum(dim=-1)  # aggregate over action dimensions, shape (N, M)
        # Mask out non-agent slots: set log_prob to 0 for non-agent vehicles.
        # This means that when computing the loss, only entries with agent_mask==1 will contribute.
        masked_log_prob = log_prob * _agent_mask  # agent_mask is assumed to be float, with 1 or 0.
        masked_action = action * _agent_mask.unsqueeze(-1) 
        
        return masked_action, masked_log_prob


In [57]:
#########################################
# Critic Network (Value Function)
#########################################
class PPOCritic(nn.Module):
    def __init__(self, T, M, F, lstm_hidden=64, global_dim=12, combined_hidden=64):
        """
        Critic network that takes the same observation and outputs a scalar value for each sample.
        
        In many cases we want a state value per sample (N,1) that is derived from the entire observation.
        We can process each vehicle's time-dependent sequence similar to the actor,
        combine with global features, and then aggregate (for instance, by averaging over vehicles).
        """
        super(PPOCritic, self).__init__()
        self.T = T
        self.M = M
        self.F = F
        self.lstm_hidden = lstm_hidden
        
        # LSTM for time-dependent kinematics (per vehicle):
        self.lstm = nn.LSTM(input_size=F, hidden_size=lstm_hidden, batch_first=True)
        
        # Global information processing:
        self.global_fc = nn.Linear(10, global_dim)  # for lane markers
        self.boundary_fc = nn.Linear(2, global_dim)   # for boundaries
        self.global_combine_fc = nn.Linear(2 * global_dim, global_dim)
        
        # Combine per-vehicle features with global info and aggregate:
        # We combine each vehicle's representation and then average over vehicles.
        self.combine_fc = nn.Sequential(
            nn.Linear(lstm_hidden + global_dim, combined_hidden),
            nn.ReLU(),
            nn.Linear(combined_hidden, 1)
        )
    
    def forward(self, time_dep, lane_markers, boundaries, agent_mask):
        """
        Forward pass.
        
        Parameters:
           time_dep: (N, T, M, F)
           lane_markers: (N, 10)
           boundaries: (N, 2)
           agent_mask: (N, M) binary mask for valid vehicles.
           
        Returns:
           values: (N, 1) scalar state-value estimates.
        """
        if time_dep.dim() == 3:
            time_dep = time_dep.unsqueeze(0) 
        if lane_markers.dim() == 1:
            lane_markers = lane_markers.unsqueeze(0) 
        if boundaries.dim() == 1:
            boundaries = boundaries.unsqueeze(0) 
        N, T, M, _F = time_dep.shape
        
        # Process time-dependent input as before:
        time_dep = time_dep.permute(0, 2, 1, 3).contiguous().view(N * M, T, _F)
        time_dep_clean = torch.nan_to_num(time_dep, nan=0.0)
        valid_mask = ~torch.all(torch.isnan(time_dep), dim=-1)  # (N*M, T)
        seq_lengths = valid_mask.sum(dim=1)
        seq_lengths[seq_lengths == 0] = 1
        packed = rnn_utils.pack_padded_sequence(time_dep_clean, lengths=seq_lengths.cpu(),
                                                  batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(packed)
        hn = hn.squeeze(0)  # (N*M, lstm_hidden)
        vehicle_repr = hn.view(N, M, self.lstm_hidden)
        
        # Process global inputs:
        lane_markers_clean = torch.nan_to_num(lane_markers, nan=0.0)
        lane_mask = (~torch.isnan(lane_markers)).float()
        global_lane_features = F.relu(self.global_fc(lane_markers_clean))  # (N, global_dim)
        valid_ratio = lane_mask.mean(dim=1, keepdim=True)
        global_lane_features = global_lane_features * valid_ratio
        
        global_boundaries_features = F.relu(self.boundary_fc(boundaries))  # (N, global_dim)
        global_combined = torch.cat([global_lane_features, global_boundaries_features], dim=1)
        global_info = F.relu(self.global_combine_fc(global_combined))  # (N, global_dim)
        global_info = global_info.unsqueeze(1).expand(-1, M, -1)  # (N, M, global_dim)
        
        # Combine vehicle representation and global features.
        combined = torch.cat([vehicle_repr, global_info], dim=-1)  # (N, M, lstm_hidden+global_dim)
        # Process each vehicle:
        vehicle_values = self.combine_fc(combined)  # (N, M, 1)
        # Mask out invalid vehicles:
        mask = agent_mask.unsqueeze(-1).float()  # (N, M, 1)
        vehicle_values = vehicle_values * mask
        # Aggregate values from each vehicle (e.g., average only over valid vehicles):
        sum_values = vehicle_values.sum(dim=1)  # (N, 1)
        count = mask.sum(dim=1)  # (N, 1)
        # To avoid division by zero:
        count = torch.clamp(count, min=1.0)
        state_values = sum_values / count  # (N, 1)
        return state_values

In [84]:
#########################################
# PPO Update and Training Loop (Outline)
#########################################
def ppo_update(policy_net, critic_net, optimizer_policy, optimizer_critic,
               observations, actions, log_probs_old, returns, advantages,
               clip_epsilon=0.2, value_coef=0.5, entropy_coef=0.01):
    """
    Perform one update step of PPO given a batch of data.
    
    observations: a dict containing keys: 'time_dependent', 'lane_markers', 'boundary_lines', 'mask'
    actions: tensor of shape (N, M, 2)
    log_probs_old: tensor of shape (N, M), the log probability of actions under the old policy.
    returns: tensor of shape (N,), estimated returns.
    advantages: tensor of shape (N,), estimated advantages.
    
    Note: In this simple outline, we assume that the policy is applied per sample (with N samples).
    """
    # Unpack observations:
    time_dep, lane_markers, boundaries, mask = observations.values()   

    # Forward pass through the actor network:
    mean, log_std = policy_net.forward(time_dep, lane_markers, boundaries)
    std = torch.exp(log_std)
    
    # Create the current policy distribution:
    dist = torch.distributions.Normal(mean, std)
    log_probs = dist.log_prob(actions)  # (N, M, 2)
    log_probs = log_probs.sum(dim=-1)    # Sum over action dimensions to get (N, M)
    
    # For simplicity, average log_probs over valid vehicles per sample.
    mask_float = mask.float().unsqueeze(-1)
    # print (mask.shape, log_probs.shape, mask_float.shape)
    log_probs_sample = (log_probs * mask_float).sum(dim=1) / torch.clamp(mask_float.sum(dim=1), min=1.0)
    log_probs_old_sample = (log_probs_old * mask_float).sum(dim=1) / torch.clamp(mask_float.sum(dim=1), min=1.0)
    
    # Compute probability ratio:
    ratio = torch.exp(log_probs_sample - log_probs_old_sample)  # (N,)

    # Compute PPO actor loss with clipping:
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    actor_loss = -torch.mean(torch.min(surr1, surr2))
    
    # Compute critic loss (MSE between value estimate and returns):
    values = critic_net.forward(time_dep, lane_markers, boundaries, mask)  # (N, 1)
    values = values.squeeze(-1)
    critic_loss = F.mse_loss(values, returns)
    
    # Compute entropy bonus:
    entropy = dist.entropy().sum(dim=-1)  # (N, M) → sum over action dimensions.
    entropy = (entropy * mask_float).sum(dim=1) / torch.clamp(mask_float.sum(dim=1), min=1.0)
    entropy_bonus = torch.mean(entropy)
    
    loss = actor_loss + value_coef * critic_loss - entropy_coef * entropy_bonus
    
    optimizer_policy.zero_grad()
    optimizer_critic.zero_grad()
    loss.backward()
    optimizer_policy.step()
    optimizer_critic.step()
    
    return loss.item(), actor_loss.item(), critic_loss.item(), entropy_bonus.item()

In [85]:
# For completeness, here is a simple discounting function.
def compute_returns(rewards, gamma):
    """
    Compute discounted returns from a list of rewards.
    Args:
        rewards (list[float]): List of rewards over a rollout (length L).
        gamma (float): Discount factor.
    Returns:
        torch.Tensor: Returns as a tensor of shape (L,)
    """
    returns = []
    R = 0.0
    for r in reversed(rewards):
        R = r + gamma * R
        returns.insert(0, R)
    return torch.tensor(returns, dtype=torch.float32)

In [86]:
# Hyperparameters for training:
num_epochs = 1000          # number of PPO iterations
rollout_steps = 300        # number of steps in each rollout (or use env.total_steps if defined)
gamma = 0.99               # discount factor
clip_epsilon = 0.2         # PPO clipping epsilon
value_coef = 0.5           # coefficient for value loss
entropy_coef = 0.01        # coefficient for entropy bonus

# Create your environment instance.
env = HighwayEnv(generation_mode=True, demo_mode=False, T=50)
# Optionally, set expert data:
# expert_data, df = convert_highd_sample_to_gail_expert(...); env.set_expert_data(expert_data)

# Create PPO actor (policy) and critic networks.
# Let T = history length, M = max number of vehicles, F = features (e.g., 4).

# Uncomment and update the following line when expert_data is available:
env.set_expert_data(expert_data)

T = env.T
M = env.N_max
_F = 7  

In [87]:
actor = PPOActor(T=T, M=M, F=_F, lstm_hidden=64, global_dim=12, combined_hidden=64, output_size=2)
critic = PPOCritic(T=T, M=M, F=_F, lstm_hidden=64, global_dim=12, combined_hidden=64)

actor.train()
critic.train()

# Create optimizers.
optimizer_policy = optim.Adam(actor.parameters(), lr=1e-3)
optimizer_critic = optim.Adam(critic.parameters(), lr=1e-3)
# PPO training loop:
for epoch in trange(num_epochs):
    # Rollout: collect one trajectory from the environment.
    obs = env.reset()
    done = False
    rollout_data = []  # list to store transitions: (obs, action, reward, log_prob)
    total_reward = 0.0
    step = 0
    # Roll out for a fixed number of steps or until done.
    while not done and step < rollout_steps:
        # Get action and log probability from the actor network.
        with torch.no_grad():
            # actor.get_action returns: action (N, M, 2) and log_prob (N, M)
            action, log_prob = actor.get_action(obs['time_dependent'], obs['lane_markers'], obs['boundary_lines'], obs['agent_mask'])
        next_obs, reward, done, info = env.step(action[0])
        # print (next_obs)
        rollout_data.append((obs, action[0], reward, log_prob[0]))
        total_reward += reward  # reward is scalar (from environment) per step
        obs = next_obs
        step += 1

    # Now, assume that the rollout_data is a list of length L (the number of steps collected).
    # We need to compute returns and advantages per rollout step.
    # For simplicity, we treat the scalar reward per step as common across the batch dimension (N).
    rewards = [transition[2] for transition in rollout_data]  # list of rewards, length L
    returns = compute_returns(rewards, gamma)  # tensor shape: (L,)
    # Compute critic values for each step (we use the "time_dependent" observation and associated global inputs)
    critic_values = []
    for (obs, _, _, _) in rollout_data:
        # Critic forward pass: returns shape (N, 1)
        value = critic.forward(obs['time_dependent'], obs['lane_markers'], obs['boundary_lines'], obs['agent_mask'])
        # For simplicity, we average over the vehicle (M) dimension and then take the mean over the batch (N).
        # (In practice, you might compute a more refined advantage estimate.)
        critic_values.append(value.mean())
    critic_values = torch.stack(critic_values)  # shape (L,)
    advantages = returns - critic_values  # shape (L,)
    # For the PPO update, we need to build a mini-batch. In this simplified version, we average across the rollout
    # and use the final observation's data as representative.
    final_obs, final_action, final_log_prob = rollout_data[-1][0], rollout_data[-1][1], rollout_data[-1][3] 
    returns_final = returns[-1]   # scalar
    advantage_final = advantages[-1]  # scalar
    # Call PPO update function with the collected final transition data.
    # (A full implementation would sample multiple mini-batches over all rollout data.)
    loss, actor_loss, critic_loss, entropy_bonus = ppo_update(actor, critic,
                                                              optimizer_policy, optimizer_critic,
                                                              observations=final_obs,
                                                              actions=final_action,
                                                              log_probs_old=final_log_prob,
                                                              returns=returns_final,
                                                              advantages=advantage_final,
                                                              clip_epsilon=clip_epsilon,
                                                              value_coef=value_coef,
                                                              entropy_coef=entropy_coef)
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{num_epochs}: Total Reward = {total_reward:.2f}, Loss = {loss:.4f}, Actor Loss = {actor_loss:.4f}, Critic Loss = {critic_loss:.4f}, Entropy = {entropy_bonus:.4f}")

print("PPO training completed.")

  0%|          | 0/1000 [00:00<?, ?it/s]

tensor(-0.6012, grad_fn=<SelectBackward0>)


  critic_loss = F.mse_loss(values, returns)
  0%|          | 1/1000 [00:06<1:45:52,  6.36s/it]

tensor(-0.3147, grad_fn=<SelectBackward0>)


  0%|          | 2/1000 [00:11<1:35:57,  5.77s/it]

tensor(0.1036, grad_fn=<SelectBackward0>)


  0%|          | 3/1000 [00:17<1:36:42,  5.82s/it]

tensor(0.4542, grad_fn=<SelectBackward0>)


  0%|          | 4/1000 [00:23<1:36:36,  5.82s/it]

tensor(0.5779, grad_fn=<SelectBackward0>)


  0%|          | 5/1000 [00:29<1:36:39,  5.83s/it]

tensor(1.1665, grad_fn=<SelectBackward0>)


  1%|          | 6/1000 [00:35<1:37:57,  5.91s/it]

tensor(1.6875, grad_fn=<SelectBackward0>)


  1%|          | 7/1000 [00:41<1:37:13,  5.87s/it]

tensor(0.4623, grad_fn=<SelectBackward0>)


  1%|          | 8/1000 [00:47<1:39:09,  6.00s/it]

tensor(1.4445, grad_fn=<SelectBackward0>)


  1%|          | 9/1000 [00:53<1:40:23,  6.08s/it]

tensor(2.1270, grad_fn=<SelectBackward0>)


  1%|          | 10/1000 [00:59<1:38:32,  5.97s/it]

Epoch 10/1000: Total Reward = 0.00, Loss = 0.0622, Actor Loss = -1.9143, Critic Loss = 4.5241, Entropy = 28.5562
tensor(1.4430, grad_fn=<SelectBackward0>)


  1%|          | 11/1000 [01:05<1:37:44,  5.93s/it]

tensor(1.5593, grad_fn=<SelectBackward0>)


  1%|          | 11/1000 [01:11<1:46:57,  6.49s/it]


KeyboardInterrupt: 

In [76]:
torch.zeros((1,100)).unsqueeze(-1).shape

torch.Size([1, 100, 1])

In [83]:
rollout_data[-1][0]

{'time_dependent': tensor([[[338.5528,  12.2227, -16.0645,  ...,   4.6378,   1.8017,  12.9250],
          [328.8494,  11.1394, -16.0059,  ...,   5.1959,   1.7971,  12.9250],
          [246.1113,  12.1972, -14.1193,  ...,   5.0208,   1.8035,  12.9250],
          ...,
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan]],
 
         [[335.3517,  12.0744, -15.9469,  ...,   4.6378,   1.8017,  12.9250],
          [325.6366,  10.9188, -16.1223,  ...,   5.1959,   1.7971,  12.9250],
          [243.2651,  12.4896, -14.3428,  ...,   5.0208,   1.8035,  12.9250],
          ...,
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan],
          [     nan,      nan,      nan,  ...,      nan,      nan,      nan]],
 
         [

In [62]:
obs['time_dependent'][-1][...,0]

tensor([ 27.0900,  50.7800,  78.8900, 112.4550, 172.2250,  40.3900,  60.4650,
         30.9850, 223.9500,  87.7250,  66.6500, 112.8450, 154.7500, 107.4150,
        255.3000, 132.0350, 185.8650, 163.2550, 201.2500, 294.2700, 191.2100,
        243.3400, 211.5400, 225.6150, 316.4600, 280.2050, 247.4550, 341.1450,
        294.3150, 320.9800, 356.1300, 311.7950, 334.7300, 365.3650, 376.6300,
             nan,      nan,      nan,      nan,      nan,      nan,      nan,
             nan,      nan,      nan,      nan,      nan,      nan,      nan,
             nan,      nan,      nan,      nan,      nan,      nan,      nan,
             nan,      nan,      nan,      nan,      nan,      nan,      nan,
             nan,      nan,      nan,      nan,      nan,      nan,      nan,
             nan,      nan,      nan,      nan,      nan,      nan,      nan,
             nan,      nan,      nan,      nan,      nan,      nan,      nan,
             nan,      nan,      nan,      nan,      nan,      n

In [87]:
obs['agent_mask']

tensor([0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 1., 0.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])

In [88]:
actor.get_action(obs['time_dependent'], obs['lane_markers'], obs['boundary_lines'], obs['agent_mask'])

ValueError: Expected parameter loc (Tensor of shape (1, 100, 2)) of distribution Normal(loc: torch.Size([1, 100, 2]), scale: torch.Size([1, 100, 2])) to satisfy the constraint Real(), but found invalid values:
tensor([[[nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan],
         [nan, nan]]], grad_fn=<ViewBackward0>)

In [None]:
rollout_data[34][1]

IndexError: list index out of range

In [128]:
rollout_data[-1][3]

tensor([-2.4855, -2.8273, -2.2401, -2.5778, -2.4378, -1.8631, -1.8576, -2.0488,
        -2.8671, -2.3174, -2.7676, -4.3424, -1.8963, -2.0925, -3.1784, -3.7279,
        -2.8142, -3.4707, -2.1737, -3.4811, -2.1687, -5.7620, -2.6282, -2.1297,
        -2.7912, -5.1209, -1.8950, -2.3487, -2.2773, -1.8647, -2.2185, -2.8119,
        -2.2143, -2.5488, -3.1125, -2.0399, -4.0850, -3.8517, -4.0146, -2.8722,
        -2.6194, -2.1434, -2.1509, -2.5953, -2.4833, -3.5731, -2.3614, -3.6410,
        -2.1411, -1.9383, -2.3136, -3.3060, -3.8306, -2.2841, -2.1443, -1.8807,
        -2.3448, -3.4627, -3.1537, -1.9227, -2.2948, -2.8104, -2.1649, -2.8447,
        -3.6454, -1.8410, -4.7004, -3.4866, -2.0449, -2.4044, -2.2190, -3.1293,
        -3.4302, -2.6643, -3.8905, -4.3101, -6.5863, -3.2350, -2.4898, -2.6519,
        -1.9388, -3.7525, -2.2636, -2.1696, -2.3912, -3.1848, -5.9453, -2.2931,
        -2.3512, -6.4200, -2.8845, -3.5569, -1.9417, -2.1279, -4.1467, -2.1907,
        -2.2525, -2.2613, -2.5749, -2.72

In [36]:
actor = PPOActor(50, 100, 7)

In [37]:
time_dep, lane_markers, boundaries = obs[:3]

In [38]:
actor.train()  # set to training mode

# Define an optimizer.
optimizer = optim.Adam(actor.parameters(), lr=1e-3)

# Training loop: We'll do a few iterations updating the network using a fake loss defined as the mean absolute acceleration.
num_iterations = 1000
for i in trange(num_iterations):
    optimizer.zero_grad()
    # Forward pass.
    accel_output, log_prob = actor.get_action(time_dep, lane_markers, boundaries)  # shape (N, M, 2)
    # Compute a fake loss: mean of the absolute acceleration values.
    loss = torch.abs(accel_output).mean()
    loss.backward()
    optimizer.step()

    if (i+1) % 10 == 0:
        print(f"Iteration {i+1}/{num_iterations}, Loss: {loss.item():.4f}")

  2%|▏         | 19/1000 [00:00<00:16, 60.09it/s]

Iteration 10/1000, Loss: 0.7866
Iteration 20/1000, Loss: 0.7771


  4%|▍         | 38/1000 [00:00<00:16, 58.16it/s]

Iteration 30/1000, Loss: 0.7641
Iteration 40/1000, Loss: 0.7755


  6%|▋         | 63/1000 [00:01<00:15, 58.97it/s]

Iteration 50/1000, Loss: 0.7645
Iteration 60/1000, Loss: 0.7536


  8%|▊         | 81/1000 [00:01<00:15, 58.33it/s]

Iteration 70/1000, Loss: 0.7524
Iteration 80/1000, Loss: 0.7307


 10%|█         | 100/1000 [00:01<00:15, 58.50it/s]

Iteration 90/1000, Loss: 0.7258
Iteration 100/1000, Loss: 0.7249


 12%|█▏        | 118/1000 [00:02<00:15, 57.25it/s]

Iteration 110/1000, Loss: 0.7112
Iteration 120/1000, Loss: 0.7127


 14%|█▍        | 142/1000 [00:02<00:14, 58.15it/s]

Iteration 130/1000, Loss: 0.7019
Iteration 140/1000, Loss: 0.6957


 16%|█▌        | 160/1000 [00:02<00:14, 57.18it/s]

Iteration 150/1000, Loss: 0.6891
Iteration 160/1000, Loss: 0.6854


 18%|█▊        | 178/1000 [00:03<00:14, 57.49it/s]

Iteration 170/1000, Loss: 0.6790
Iteration 180/1000, Loss: 0.6705


 20%|█▉        | 196/1000 [00:03<00:14, 57.12it/s]

Iteration 190/1000, Loss: 0.6658
Iteration 200/1000, Loss: 0.6657


 22%|██▏       | 215/1000 [00:03<00:13, 57.68it/s]

Iteration 210/1000, Loss: 0.6478
Iteration 220/1000, Loss: 0.6487


 24%|██▍       | 240/1000 [00:04<00:12, 59.12it/s]

Iteration 230/1000, Loss: 0.6370
Iteration 240/1000, Loss: 0.6309


 26%|██▌       | 259/1000 [00:04<00:12, 59.60it/s]

Iteration 250/1000, Loss: 0.6264
Iteration 260/1000, Loss: 0.6292


 28%|██▊       | 283/1000 [00:04<00:12, 58.44it/s]

Iteration 270/1000, Loss: 0.6203
Iteration 280/1000, Loss: 0.6084


 30%|██▉       | 296/1000 [00:05<00:11, 58.82it/s]

Iteration 290/1000, Loss: 0.6010
Iteration 300/1000, Loss: 0.6053


 32%|███▏      | 316/1000 [00:05<00:11, 59.02it/s]

Iteration 310/1000, Loss: 0.5939
Iteration 320/1000, Loss: 0.5909


 34%|███▍      | 340/1000 [00:05<00:11, 57.54it/s]

Iteration 330/1000, Loss: 0.5861
Iteration 340/1000, Loss: 0.5811


 36%|███▌      | 358/1000 [00:06<00:11, 57.80it/s]

Iteration 350/1000, Loss: 0.5730
Iteration 360/1000, Loss: 0.5726


 38%|███▊      | 382/1000 [00:06<00:10, 58.75it/s]

Iteration 370/1000, Loss: 0.5686
Iteration 380/1000, Loss: 0.5560


 40%|████      | 400/1000 [00:06<00:10, 58.79it/s]

Iteration 390/1000, Loss: 0.5554
Iteration 400/1000, Loss: 0.5557


 42%|████▏     | 418/1000 [00:07<00:10, 56.85it/s]

Iteration 410/1000, Loss: 0.5499
Iteration 420/1000, Loss: 0.5441


 44%|████▎     | 436/1000 [00:07<00:09, 57.27it/s]

Iteration 430/1000, Loss: 0.5428
Iteration 440/1000, Loss: 0.5305


 46%|████▌     | 460/1000 [00:07<00:09, 57.88it/s]

Iteration 450/1000, Loss: 0.5323
Iteration 460/1000, Loss: 0.5319


 48%|████▊     | 479/1000 [00:08<00:08, 58.08it/s]

Iteration 470/1000, Loss: 0.5218
Iteration 480/1000, Loss: 0.5175


 50%|████▉     | 497/1000 [00:08<00:08, 58.18it/s]

Iteration 490/1000, Loss: 0.5127
Iteration 500/1000, Loss: 0.5110


 52%|█████▏    | 516/1000 [00:08<00:08, 59.72it/s]

Iteration 510/1000, Loss: 0.5076
Iteration 520/1000, Loss: 0.5038


 54%|█████▎    | 535/1000 [00:09<00:07, 59.36it/s]

Iteration 530/1000, Loss: 0.4944
Iteration 540/1000, Loss: 0.4925


 56%|█████▌    | 559/1000 [00:09<00:07, 58.47it/s]

Iteration 550/1000, Loss: 0.4950
Iteration 560/1000, Loss: 0.4836


 58%|█████▊    | 578/1000 [00:09<00:07, 58.51it/s]

Iteration 570/1000, Loss: 0.4817
Iteration 580/1000, Loss: 0.4744


 60%|█████▉    | 596/1000 [00:10<00:06, 58.19it/s]

Iteration 590/1000, Loss: 0.4725
Iteration 600/1000, Loss: 0.4676


 62%|██████▏   | 615/1000 [00:10<00:06, 58.18it/s]

Iteration 610/1000, Loss: 0.4691
Iteration 620/1000, Loss: 0.4608


 64%|██████▍   | 640/1000 [00:11<00:06, 58.10it/s]

Iteration 630/1000, Loss: 0.4635
Iteration 640/1000, Loss: 0.4498


 66%|██████▌   | 658/1000 [00:11<00:06, 56.39it/s]

Iteration 650/1000, Loss: 0.4526
Iteration 660/1000, Loss: 0.4487


 68%|██████▊   | 676/1000 [00:11<00:05, 57.01it/s]

Iteration 670/1000, Loss: 0.4426
Iteration 680/1000, Loss: 0.4435


 70%|███████   | 700/1000 [00:12<00:05, 57.41it/s]

Iteration 690/1000, Loss: 0.4412
Iteration 700/1000, Loss: 0.4385


 72%|███████▏  | 718/1000 [00:12<00:04, 57.70it/s]

Iteration 710/1000, Loss: 0.4330
Iteration 720/1000, Loss: 0.4279


 74%|███████▍  | 742/1000 [00:12<00:04, 57.05it/s]

Iteration 730/1000, Loss: 0.4317
Iteration 740/1000, Loss: 0.4246


 76%|███████▌  | 761/1000 [00:13<00:04, 57.88it/s]

Iteration 750/1000, Loss: 0.4170
Iteration 760/1000, Loss: 0.4091


 78%|███████▊  | 779/1000 [00:13<00:03, 56.97it/s]

Iteration 770/1000, Loss: 0.4118
Iteration 780/1000, Loss: 0.4148


 80%|███████▉  | 797/1000 [00:13<00:03, 55.02it/s]

Iteration 790/1000, Loss: 0.4068
Iteration 800/1000, Loss: 0.4030


 82%|████████▏ | 815/1000 [00:14<00:03, 55.26it/s]

Iteration 810/1000, Loss: 0.4026
Iteration 820/1000, Loss: 0.3949


 84%|████████▍ | 839/1000 [00:14<00:02, 56.34it/s]

Iteration 830/1000, Loss: 0.3896
Iteration 840/1000, Loss: 0.3904


 86%|████████▌ | 857/1000 [00:14<00:02, 56.77it/s]

Iteration 850/1000, Loss: 0.3859
Iteration 860/1000, Loss: 0.3863


 88%|████████▊ | 875/1000 [00:15<00:02, 56.13it/s]

Iteration 870/1000, Loss: 0.3897
Iteration 880/1000, Loss: 0.3806


 90%|████████▉ | 899/1000 [00:15<00:01, 54.09it/s]

Iteration 890/1000, Loss: 0.3750
Iteration 900/1000, Loss: 0.3714


 92%|█████████▏| 917/1000 [00:15<00:01, 55.03it/s]

Iteration 910/1000, Loss: 0.3772
Iteration 920/1000, Loss: 0.3711


 94%|█████████▍| 941/1000 [00:16<00:01, 55.60it/s]

Iteration 930/1000, Loss: 0.3661
Iteration 940/1000, Loss: 0.3634


 96%|█████████▌| 959/1000 [00:16<00:00, 56.44it/s]

Iteration 950/1000, Loss: 0.3622
Iteration 960/1000, Loss: 0.3586


 98%|█████████▊| 977/1000 [00:17<00:00, 56.47it/s]

Iteration 970/1000, Loss: 0.3560
Iteration 980/1000, Loss: 0.3540


100%|██████████| 1000/1000 [00:17<00:00, 57.42it/s]

Iteration 990/1000, Loss: 0.3478
Iteration 1000/1000, Loss: 0.3475



