In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils

In [2]:
import random
from collections import deque
from env import HighwayEnv, convert_highd_sample_to_gail_expert
import numpy as np
import pandas as pd
from tqdm import tqdm, trange


In [None]:
class HighwayEnvBuffer:
    def __init__(self, h_dim, v_dim, f_dim, act_dim, m_dim=10, b_dim=2, size=800, gamma=0.99, lam=0.95, device='cuda'):
        self.gamma, self.lam = gamma, lam
        self.kn_buf = torch.zeros((size, h_dim, v_dim, f_dim))
        self.lm_buf = torch.zeros((size, m_dim))
        self.bd_buf = torch.zeros((size, b_dim))
        self.mk_buf = torch.zeros((size, v_dim))
        self.act_buf = torch.zeros((size, v_dim, act_dim))
        self.logp_buf = torch.zeros((size, v_dim))
        self.rew_buf = torch.zeros(size)
        self.val_buf = torch.zeros((size, v_dim))
        self.adv_buf = torch.zeros(size)
        self.ret_buf = torch.zeros(size)
        self.ptr = 0
        self.path_start_idx = 0
        self.max_size = size

    def store(self, kn, lm, bd, mk, act, logp, rew, val):
        assert self.ptr < self.max_size
        self.kn_buf[self.ptr]  = kn
        self.lm_buf[self.ptr]  = lm
        self.bd_buf[self.ptr]  = bd
        self.mk_buf[self.ptr]   = mk
        self.act_buf[self.ptr]  = act
        self.logp_buf[self.ptr] = logp
        self.rew_buf[self.ptr]  = rew
        self.val_buf[self.ptr]  = val
        self.ptr += 1

    def finish_path(self, last_val):
        """
        last_val: (M,)  value estimates for step after the end
        """
        slice_ = slice(self.path_start_idx, self.ptr)
        # append last_val to compute deltas
        rews = torch.cat([self.rew_buf[slice_], last_val.unsqueeze(0)], dim=0)  # (L+1, M)
        vals = torch.cat([self.val_buf[slice_], last_val.unsqueeze(0)], dim=0)  # (L+1, M)
        # GAE deltas
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]                  # (L, M)

        # compute advantage
        adv = torch.zeros_like(deltas)
        gae = torch.zeros(self.M)
        for t in reversed(range(deltas.shape[0])):
            gae = deltas[t] + self.gamma * self.lam * gae
            adv[t] = gae
        self.adv_buf[slice_] = adv

        # rewards-to-go
        ret = torch.zeros_like(rews)
        ret[-1] = last_val
        for t in reversed(range(rews.shape[0]-1)):
            ret[t] = rews[t] + self.gamma * ret[t+1]
        self.ret_buf[slice_] = ret[:-1]

        self.path_start_idx = self.ptr

    def get(self):
        assert self.ptr == self.max_size  # buffer full
        # normalize advantages
        adv_mean, adv_std = self.adv_buf.mean(), self.adv_buf.std() + 1e-8
        self.adv_buf = (self.adv_buf - adv_mean) / adv_std
        data = dict(kn=self.kn_buf,
                    lm=self.lm_buf,
                    bd=self.bd_buf,
                    mk=self.mk_buf,
                    act=self.act_buf,
                    logp=self.logp_buf,
                    ret=self.ret_buf,
                    adv=self.adv_buf)
        # reset pointer
        self.ptr = 0
        self.path_start_idx = 0
        return {k: v for k,v in data.items()}

In [4]:
def compute_seq_lengths(time_dep):
    """
    Given time_dep of shape (B, T, F), return a tensor of sequence lengths,
    where a timestep is considered valid if not all features are NaN.
    """
    # valid if not all F values are NaN:
    valid_mask = ~torch.all(torch.isnan(time_dep), dim=-1)  # shape (B, T)
    seq_lengths = valid_mask.sum(dim=1)  # (B,)
    seq_lengths[seq_lengths == 0] = 1  # avoid zeros
    return seq_lengths

In [86]:
#########################################
# Actor Network (Policy)
#########################################
class PPOActor(nn.Module):
    def __init__(self, h_dim, v_dim, f_dim, lstm_hidden=64, global_dim=12, combined_hidden=64, output_size=2):
        """
        Actor network that takes the structured observation and outputs
        a Gaussian distribution over acceleration commands.
        
        Observation:
           - time_dep: (N, T, M, F) 
           - lane_markers: (N, 10)
           - boundary_lines: (N, 2)
        Output:
           - accelerations: (N, M, 2)
        """
        super(PPOActor, self).__init__()
        self.T = h_dim
        self.M = v_dim
        self.F = f_dim
        self.lstm_hidden = lstm_hidden
        
        # LSTM for time-dependent kinematics (for each vehicle's sequence of length T and F features).
        self.lstm = nn.LSTM(input_size=self.F, hidden_size=lstm_hidden, batch_first=True)
        
        # Global information: process lane markers and boundaries separately.
        self.global_fc = nn.Linear(10, global_dim)  # lane markers of shape (N, 10)
        self.boundary_fc = nn.Linear(2, global_dim)   # boundaries of shape (N, 2)
        self.global_combine_fc = nn.Linear(2 * global_dim, global_dim)
        
        # Fully connected layers that combine per-vehicle features with global info.
        self.combine_fc = nn.Sequential(
            nn.Linear(lstm_hidden + global_dim, combined_hidden),
            nn.ReLU(),
            nn.Linear(combined_hidden, output_size)
        )
        
        # Learnable log standard deviation for the Gaussian distribution over actions.
        # We use a parameter of shape (1, 1, output_size) that will be broadcast.
        self.log_std = nn.Parameter(torch.zeros(1, 1, output_size))
        
    def forward(self, time_dep, lane_markers, boundaries):
        """
        Forward pass.
        
        Parameters:
           time_dep: (N, T, M, F)
           lane_markers: (N, 10)  (may contain NaNs)
           boundaries: (N, 2)
        
        Returns:
           mean: (N, M, 2) — the mean acceleration for each vehicle slot.
           log_std: (N, M, 2) — the log_std, broadcasted along N and M.
        """
        if time_dep.dim() == 3:
            time_dep = time_dep.unsqueeze(0) 
        if lane_markers.dim() == 1:
            lane_markers = lane_markers.unsqueeze(0) 
        if boundaries.dim() == 1:
            boundaries = boundaries.unsqueeze(0) 

        N, T, M, _F = time_dep.shape
        
        # Process time-dependent kinematics:
        # Permute to (N, M, T, F) then flatten N and M to get (N*M, T, F)
        time_dep = time_dep.permute(0, 2, 1, 3).contiguous().view(N * M, T, _F)
        # Replace NaNs with 0:
        time_dep_clean = torch.nan_to_num(time_dep, nan=0.0)
        # Compute sequence lengths per vehicle:
        seq_lengths = compute_seq_lengths(time_dep)
        assert not torch.isnan(time_dep_clean).any()
        # Pack padded sequence:
        packed = rnn_utils.pack_padded_sequence(time_dep_clean, lengths=seq_lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(packed)
        # hn: (1, N*M, lstm_hidden) -> squeeze to (N*M, lstm_hidden)
        hn = hn.squeeze(0)
        # Reshape back to (N, M, lstm_hidden)
        vehicle_repr = hn.view(N, M, self.lstm_hidden)
        
        # Process global inputs:
        # For lane markers, replace NaNs with 0 and create a mask.
        lane_markers_clean = torch.nan_to_num(lane_markers, nan=0.0)  # (N, 10)
        lane_mask = (~torch.isnan(lane_markers)).float()  # (N, 10)
        global_lane_features = F.relu(self.global_fc(lane_markers_clean))  # (N, global_dim)
        valid_ratio = lane_mask.mean(dim=1, keepdim=True)  # (N, 1)
        global_lane_features = global_lane_features * valid_ratio
        
        global_boundaries_features = F.relu(self.boundary_fc(boundaries))  # (N, global_dim)
        global_combined = torch.cat([global_lane_features, global_boundaries_features], dim=1)  # (N, 2*global_dim)
        global_info = F.relu(self.global_combine_fc(global_combined))  # (N, global_dim)
        global_info = global_info.unsqueeze(1).expand(-1, M, -1)  # (N, M, global_dim)
        
        # Combine vehicle representation with global info:
        combined = torch.cat([vehicle_repr, global_info], dim=-1)  # (N, M, lstm_hidden + global_dim)
        mean = self.combine_fc(combined)  # (N, M, output_size)
        
        # Broadcast log_std:
        log_std = self.log_std.expand(N, M, -1)  # (N, M, output_size)
        
        return mean, log_std
    
    def get_action(self, time_dep, lane_markers, boundaries, agent_mask):
        """
        Sample actions from the policy distribution.
        
        Returns:
           action: (N, M, output_size)
           log_prob: (N, M, output_size) or summed over output_size per vehicle.
        """
        _agent_mask = agent_mask.unsqueeze(0) if agent_mask.dim() == 1 else agent_mask
        mean, log_std = self.forward(time_dep, lane_markers, boundaries)
        # print (mean, log_std)
        std = torch.exp(log_std)
        # Create a normal distribution per vehicle slot.
        dist = torch.distributions.Normal(mean, std)
        # Sample actions using reparameterization (this allows for differentiable sampling).
        action = dist.rsample()  # shape (N, M, output_size)
        # Compute log probabilities.
        log_prob = dist.log_prob(action)  # shape (N, M, output_size)
        log_prob = log_prob.sum(dim=-1)  # aggregate over action dimensions, shape (N, M)
        # Mask out non-agent slots: set log_prob to 0 for non-agent vehicles.
        # This means that when computing the loss, only entries with agent_mask==1 will contribute.
        masked_log_prob = log_prob * _agent_mask  # agent_mask is assumed to be float, with 1 or 0.
        masked_action = action * _agent_mask.unsqueeze(-1) 
        
        return masked_action, masked_log_prob
    
#########################################
# Critic Network (Value Function)
#########################################
class PPOCritic(nn.Module):
    def __init__(self, h_dim, v_dim, f_dim, lstm_hidden=64, global_dim=12, combined_hidden=64):
        """
        Critic network that takes the same observation and outputs a scalar value for each sample.
        
        In many cases we want a state value per sample (N,1) that is derived from the entire observation.
        We can process each vehicle's time-dependent sequence similar to the actor,
        combine with global features, and then aggregate (for instance, by averaging over vehicles).
        """
        super(PPOCritic, self).__init__()
        self.T = h_dim
        self.M = v_dim
        self.F = f_dim
        self.lstm_hidden = lstm_hidden
        
        # LSTM for time-dependent kinematics (per vehicle):
        self.lstm = nn.LSTM(input_size=self.F, hidden_size=lstm_hidden, batch_first=True)
        
        # Global information processing:
        self.global_fc = nn.Linear(10, global_dim)  # for lane markers
        self.boundary_fc = nn.Linear(2, global_dim)   # for boundaries
        self.global_combine_fc = nn.Linear(2 * global_dim, global_dim)
        
        # Combine per-vehicle features with global info and aggregate:
        # We combine each vehicle's representation and then average over vehicles.
        self.combine_fc = nn.Sequential(
            nn.Linear(lstm_hidden + global_dim, combined_hidden),
            nn.ReLU(),
            nn.Linear(combined_hidden, 1)
        )
    
    def forward(self, time_dep, lane_markers, boundaries, agent_mask):
        """
        Forward pass.
        
        Parameters:
           time_dep: (N, T, M, F)
           lane_markers: (N, 10)
           boundaries: (N, 2)
           agent_mask: (N, M) binary mask for valid vehicles.
           
        Returns:
           values: (N, 1) scalar state-value estimates.
        """
        if time_dep.dim() == 3:
            time_dep = time_dep.unsqueeze(0) 
        if lane_markers.dim() == 1:
            lane_markers = lane_markers.unsqueeze(0) 
        if boundaries.dim() == 1:
            boundaries = boundaries.unsqueeze(0) 
        N, T, M, _F = time_dep.shape
        
        # Process time-dependent input as before:
        time_dep = time_dep.permute(0, 2, 1, 3).contiguous().view(N * M, T, _F)
        time_dep_clean = torch.nan_to_num(time_dep, nan=0.0)
        valid_mask = ~torch.all(torch.isnan(time_dep), dim=-1)  # (N*M, T)
        seq_lengths = valid_mask.sum(dim=1)
        seq_lengths[seq_lengths == 0] = 1
        packed = rnn_utils.pack_padded_sequence(time_dep_clean, lengths=seq_lengths.cpu(),
                                                  batch_first=True, enforce_sorted=False)
        packed_output, (hn, cn) = self.lstm(packed)
        hn = hn.squeeze(0)  # (N*M, lstm_hidden)
        vehicle_repr = hn.view(N, M, self.lstm_hidden)
        
        # Process global inputs:
        lane_markers_clean = torch.nan_to_num(lane_markers, nan=0.0)
        lane_mask = (~torch.isnan(lane_markers)).float()
        global_lane_features = F.relu(self.global_fc(lane_markers_clean))  # (N, global_dim)
        valid_ratio = lane_mask.mean(dim=1, keepdim=True)
        global_lane_features = global_lane_features * valid_ratio
        
        global_boundaries_features = F.relu(self.boundary_fc(boundaries))  # (N, global_dim)
        global_combined = torch.cat([global_lane_features, global_boundaries_features], dim=1)
        global_info = F.relu(self.global_combine_fc(global_combined))  # (N, global_dim)
        global_info = global_info.unsqueeze(1).expand(-1, M, -1)  # (N, M, global_dim)
        
        # Combine vehicle representation and global features.
        combined = torch.cat([vehicle_repr, global_info], dim=-1)  # (N, M, lstm_hidden+global_dim)
        # Process each vehicle:
        vehicle_values = self.combine_fc(combined).squeeze(-1)  # (N, M) after unsqueeze
        # Mask out invalid vehicles:
        vehicle_values = vehicle_values * agent_mask.float()
        return vehicle_values

In [82]:
torch.zeros((2,3,1)).squeeze(-1).shape

torch.Size([2, 3])

In [6]:
def ppo_update(actor, critic, buffer, 
               actor_lr=3e-4, critic_lr=1e-3, 
               clip_ratio=0.2, train_iters=80):
    data = buffer.get()
    kn, lm, bd, mk, act, logp_old, ret, adv = data.values()

    # optimizers
    a_optimizer = optim.Adam(actor.parameters(), lr=actor_lr)
    c_optimizer = optim.Adam(critic.parameters(), lr=critic_lr)

    for _ in range(train_iters):
        # Policy loss
        dist, logp = actor.get_dist(kn, lm, bd, mk)
        ratio = torch.exp(logp - logp_old)
        clipped_ratio = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio)
        policy_loss = -torch.mean(torch.min(ratio * adv, clipped_ratio * adv))

        # Value loss
        value = critic(kn, lm, bd, mk)
        value_loss = torch.mean((ret - value).pow(2))

        # Update actor
        a_optimizer.zero_grad()
        policy_loss.backward()
        a_optimizer.step()

        # Update critic
        c_optimizer.zero_grad()
        value_loss.backward()
        c_optimizer.step()

In [63]:
expert_data, df = convert_highd_sample_to_gail_expert(
    sample_csv=r"./data/26_sample_tracks.csv",
    meta_csv=r"E:\Data\highd-dataset-v1.0\data\26_recordingMeta.csv",
    forward=False,
    p_agent=0.90
)

In [64]:
# Create your environment instance.
env = HighwayEnv(generation_mode=True, demo_mode=False, T=50)
# Optionally, set expert data:
# expert_data, df = convert_highd_sample_to_gail_expert(...); env.set_expert_data(expert_data)

# Create PPO actor (policy) and critic networks.
# Let T = history length, M = max number of vehicles, F = features (e.g., 4).

# Uncomment and update the following line when expert_data is available:
env.set_expert_data(expert_data)

In [89]:
steps_per_epoch = 300
epochs = 50

actor  = PPOActor(50, 100, 7)
critic = PPOCritic(50, 100, 7)
buf = HighwayEnvBuffer(50, 100, 7, 2, size=steps_per_epoch)

for epoch in trange(epochs):
    _obs, ep_ret, ep_len = env.reset(), 0, 0
    obs = _obs.values()
    for t in range(steps_per_epoch):
        action, logp = actor.get_action(*obs)
        value = critic(*obs)
        next_obs, rew, done, _ = env.step(action.squeeze(0).detach().numpy())
        ep_ret += rew; ep_len += 1
        buf.store(*obs, action.squeeze(0), logp.squeeze(0), rew, value.squeeze(0))

        obs = next_obs.values()
        terminal = done or (t == steps_per_epoch - 1)
        if terminal:
            last_val = 0 if done else critic(*obs)
            buf.finish_path(last_val.squeeze(0))
            obs, ep_ret, ep_len = env.reset(), 0, 0

    # after collecting data, perform PPO update
    ppo_update(actor, critic, buf)

    print(f"Epoch {epoch+1}/{epochs} complete")

env.close()

  0%|          | 0/50 [00:04<?, ?it/s]


AttributeError: 'HighwayEnvBuffer' object has no attribute 'path_start'

In [36]:
action.squeeze(0).shape

torch.Size([100, 2])