# Comparing VBD Outputs: Waymax vs GPUDrive

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import numpy as np
import mediapy
import os
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
import torch
# Set working directory to the base directory 'gpudrive'.
# All relative paths used later in the notebook are resolved from there.
working_dir = Path.cwd()
while working_dir.name != 'gpudrive':
    previous_dir = working_dir
    # Climb two levels per iteration, exactly as before.
    # NOTE(review): the two-level stride looks intentional (it can step over a
    # nested directory that is also called 'gpudrive') — confirm before
    # changing it to a single `.parent`.
    working_dir = working_dir.parent.parent
    # Give up at the home directory, or when climbing no longer makes progress.
    # At the filesystem root `.parent` returns the path itself, so without the
    # second check the loop would spin forever whenever the two-level stride
    # steps over `Path.home()` or the notebook is started outside of it.
    if working_dir == Path.home() or working_dir == previous_dir:
        raise FileNotFoundError("Base directory 'gpudrive' not found")
os.chdir(working_dir)

# GPUDrive dependencies
import gpudrive
from gpudrive.env.config import EnvConfig, RenderConfig, SceneConfig
from gpudrive.env.env_torch import GPUDriveTorchEnv
from gpudrive.env.dataset import SceneDataLoader
from gpudrive.visualize.utils import img_from_fig

# Plotting
# Global seaborn/matplotlib styling for every figure created below.
# "notebook" is passed as the first positional argument of `sns.set`
# (seaborn's plotting context).
sns.set("notebook")
# "ticks" style; both figure and axes face colors are set to "none" so the
# figures are drawn without a background fill.
sns.set_style("ticks", rc={"figure.facecolor": "none", "axes.facecolor": "none"})
#%config InlineBackend.figure_format = 'svg'

# Ignore all warnings
# NOTE(review): this silences every warning category for the rest of the
# session, including deprecation warnings raised by the libraries used below.
import warnings
warnings.filterwarnings("ignore")

2025-03-25 15:24:52.394752: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742930692.413969   56624 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742930692.419632   56624 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


### Configurations

In [2]:
# Notebook-wide settings. Paths are relative to the repository base directory
# that the setup cell switches into with `os.chdir`.

# Base data path; the training split used below lives under '<DATA_DIR>/training'.
DATA_DIR = 'data/processed' # Base data path
# Checkpoint handed to the environment as `vbd_model_path`.
CKPT_PATH = 'gpudrive/integrations/vbd/weights/epoch=18.ckpt'

# NOTE(review): not referenced by any of the visible cells — confirm whether it
# is still needed.
SCENARIO_ID = 'efc5cbe01b4a526f'

# Frame rate used when the rendered frames are written out as a GIF.
FPS = 20
# Passed as `init_steps` to the env config and subtracted from `episode_len`
# to get the number of predicted/rolled-out steps.
INIT_STEPS = 11 # Warmup period
# Passed as `max_cont_agents` when the environment is constructed.
MAX_CONTROLLED_OBJECTS = 32

### Make Videos

In [3]:
# Init GPUDrive env
# Single source of truth for the dataset location (same value as the literal
# 'data/processed/training' that was previously repeated twice).
train_data_path = f"{DATA_DIR}/training"

env_config = EnvConfig(
    init_steps=INIT_STEPS, # Warmup period
    dynamics_model="state", # Use state-based dynamics model
    dist_to_goal_threshold=1e-5, # Trick to make sure the agents don't disappear when they reach the goal
    init_mode='all_non_trivial',
    use_vbd=True,
    # Use the notebook-level constant instead of a second hard-coded 32 so the
    # config and the `max_cont_agents` argument below cannot drift apart.
    max_controlled_agents=MAX_CONTROLLED_OBJECTS,
    vbd_model_path=CKPT_PATH,
)

scene_config = SceneConfig(batch_size=1, dataset_size=1, path=train_data_path, num_scenes=1)
# Make env
gpudrive_env = GPUDriveTorchEnv(
    config=env_config,
    data_loader=SceneDataLoader(
        root=train_data_path,
        batch_size=scene_config.batch_size,
        dataset_size=scene_config.dataset_size,
    ),
    render_config=RenderConfig(resolution=(400, 400)),
    max_cont_agents=MAX_CONTROLLED_OBJECTS, # Maximum number of agents to control per scene
    device="cpu",
)
gpudrive_sample_batch = gpudrive_env._generate_sample_batch()

# Number of simulated steps after the warmup period; shared by the prediction
# tensor and the rollout loop so both always agree.
num_rollout_steps = env_config.episode_len - INIT_STEPS

# Reset predictions tensor: (world, agent, step, feature) with 10 feature slots.
pred_trajs = torch.zeros(
    (gpudrive_env.num_worlds, gpudrive_env.max_agent_count, num_rollout_steps, 10)
)

# Fill pred_trajs correctly for each world
for i in range(gpudrive_env.num_worlds):
    world_agent_indices = gpudrive_sample_batch['agents_id'][i]

    # Filter out negative indices (padding values)
    valid_mask = world_agent_indices >= 0  # Boolean mask of valid indices
    # The ids are int32; cast to int64, the index dtype every PyTorch version
    # accepts for advanced indexing.
    valid_agent_indices = world_agent_indices[valid_mask].long()

    # Copy the VBD outputs into the feature slots used by `step_dynamics`.
    # Slot 2 of pred_trajs is intentionally left at zero, as before.
    vbd_world = gpudrive_env.vbd_trajectories[i, valid_agent_indices]
    pred_trajs[i, valid_agent_indices, :, :2] = vbd_world[:, :, :2]    # pos x, y
    pred_trajs[i, valid_agent_indices, :, 3] = vbd_world[:, :, 2]      # yaw
    pred_trajs[i, valid_agent_indices, :, 4:6] = vbd_world[:, :, 3:5]  # vel x, y

# Now step through the simulation, rendering world 0 after every step
gpudrive_frames = []
for t in range(num_rollout_steps):
    gpudrive_env.step_dynamics(pred_trajs[:, :, t, :])
    fig = gpudrive_env.vis.plot_simulator_state(
        time_steps=[t],
        env_indices=[0],
        zoom_radius=70,
    )[0]
    gpudrive_frames.append(img_from_fig(fig))

Diffusion: 100%|██████████| 50/50 [00:17<00:00,  2.94it/s]


In [4]:
# Write the rendered rollout frames to disk as an animated GIF.
gif_path = 'gpudrive/integrations/vbd/viz/train.gif'
# The output folder may not exist (e.g. on a fresh checkout); create it so the
# export does not fail after the expensive rollout has already run.
Path(gif_path).parent.mkdir(parents=True, exist_ok=True)
mediapy.write_video(gif_path, gpudrive_frames, fps=FPS, codec="gif")
# Report the path from the same variable so the message cannot go stale.
print(f"GIF saved at {gif_path}")

GIF saved at gpudrive/integrations/vbd/viz/train.gif


In [5]:
# Sanity check: show the shape of the filled prediction tensor next to the
# shape of the raw VBD trajectories it was built from.
print(pred_trajs.size(), gpudrive_env.vbd_trajectories.size())

torch.Size([1, 64, 80, 10]) torch.Size([1, 64, 80, 5])


In [6]:
# Inspect the agent ids of world 0; negative entries are padding.
world_agent_indices = gpudrive_sample_batch["agents_id"][0]
valid_mask = world_agent_indices.ge(0)  # True where the slot holds a real agent id
valid_agent_indices = world_agent_indices[valid_mask]  # Padding removed

# Raw ids first, then the filtered ids, one per line.
print(world_agent_indices, valid_agent_indices, sep="\n")

tensor([32, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
        -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
       dtype=torch.int32)
tensor([32], dtype=torch.int32)


## Plotting the VBD Trajectory as Part of the Agent Observation

In [7]:
def transform_to_ego_frame(trajectory: torch.Tensor, ego_pos: torch.Tensor, ego_yaw: torch.Tensor) -> torch.Tensor:
    """
    Transform trajectory from global coordinates to ego-centric frame.

    The trajectory is first translated so the ego position becomes the origin
    and then rotated by ``-ego_yaw`` so the ego heading points along +x.

    Args:
    trajectory: Shape (time_steps, 2) containing x,y coordinates in global frame
    ego_pos: Shape (2,) containing ego x,y position
    ego_yaw: Single-element tensor (shape () or (1,)) with the ego yaw angle in radians
    Returns:
    transformed_trajectory: Shape (time_steps, 2) in ego-centric frame, on the
        same device as ``trajectory`` and in its floating-point dtype
    Raises:
    RuntimeError: If ``ego_yaw`` holds more than one element.
    """
    # Work in the trajectory's own floating dtype and device. Integer inputs
    # fall back to the default float dtype so the yaw is never truncated.
    dtype = trajectory.dtype if trajectory.is_floating_point() else torch.get_default_dtype()
    device = trajectory.device
    trajectory = trajectory.to(dtype)
    ego_pos = torch.as_tensor(ego_pos, dtype=dtype, device=device)
    yaw = torch.as_tensor(ego_yaw, dtype=dtype, device=device).reshape(())

    # Step 1: Translate trajectory to be relative to ego position
    translated = trajectory - ego_pos

    # Step 2: Rotate trajectory to align with ego orientation
    # Build the matrix with torch.stack rather than torch.tensor([...]): the
    # latter always creates a CPU tensor in the default dtype, which breaks the
    # matmul below for float64 or non-CPU trajectories.
    cos_yaw = torch.cos(yaw)
    sin_yaw = torch.sin(yaw)
    rotation_matrix = torch.stack([
        torch.stack([cos_yaw, sin_yaw]),
        torch.stack([-sin_yaw, cos_yaw]),
    ])

    # Points are stored as rows, so right-multiplying by the transpose applies
    # `rotation_matrix @ point` to every time step at once.
    transformed_trajectory = torch.matmul(translated, rotation_matrix.T)

    return transformed_trajectory

In [8]:
# plotting vbd trajectory as part of observation
from gpudrive.datatypes.observation import GlobalEgoState

# Put the simulator back at its initial state before reading agent poses.
init_state = gpudrive_env.reset()
# Get global agent observations
global_agent_obs = GlobalEgoState.from_tensor(
    abs_self_obs_tensor=gpudrive_env.sim.absolute_self_observation_tensor(),
    backend=gpudrive_env.backend,
    device=gpudrive_env.device,
)

# Make sure the output folder exists before any figure is saved.
viz_dir = Path('gpudrive/integrations/vbd/viz')
viz_dir.mkdir(parents=True, exist_ok=True)

for agent_index in valid_agent_indices:
    # Plain Python int: usable as an index, a keyword argument and a filename
    # component without relying on 0-dim tensor formatting.
    agent_idx = int(agent_index)

    # x,y positions of the VBD trajectory for this agent in world 0
    vbd_trajectory = gpudrive_env.vbd_trajectories[0, agent_idx, :, :2]
    # Assemble the ego pose without torch.tensor(<tensor>), which copy-constructs
    # and emits a warning; as_tensor passes existing tensors through untouched.
    pos_xy = torch.stack([
        torch.as_tensor(global_agent_obs.pos_x[0, agent_idx]),
        torch.as_tensor(global_agent_obs.pos_y[0, agent_idx]),
    ])
    yaw = torch.as_tensor(global_agent_obs.rotation_angle[0, agent_idx])

    transformed_trajectory = transform_to_ego_frame(vbd_trajectory, pos_xy, yaw)

    fig = gpudrive_env.vis.plot_agent_observation(
        agent_idx=agent_idx,
        env_idx=0,
        trajectory=transformed_trajectory,
        figsize=(4, 4),
    )
    # Figures are left open on purpose so the inline backend still shows them.
    fig.savefig(viz_dir / f'{agent_idx}_vbd_trajectory.png', facecolor='white', transparent=False)