## Using Versatile Behavior Diffusion (VBD) with GPUDrive

---

> [VBD project page](https://sites.google.com/view/versatile-behavior-diffusion?pli=1) | [ArXiv](https://arxiv.org/abs/2404.02524)

---

In this notebook we demonstrate how you can generate realistic vehicle trajectories with VBD.

In [1]:
# Dependencies
import os
from pathlib import Path
import torch
import warnings
import mediapy
warnings.filterwarnings("ignore")

# Set working directory to the base directory 'gpudrive'.
# Walk up from the current directory until a folder named 'gpudrive' is found.
working_dir = Path.cwd()
while working_dir.name != 'gpudrive':
    parent_dir = working_dir.parent
    # Give up at the home directory, or at the filesystem root: the root's
    # parent is the root itself, so without this guard the walk would never
    # end when the notebook is started outside the home directory.
    if parent_dir == Path.home() or parent_dir == working_dir:
        raise FileNotFoundError("Base directory 'gpudrive' not found")
    working_dir = parent_dir
os.chdir(working_dir)

from pygpudrive.env.config import EnvConfig, RenderConfig, SceneConfig, SelectionDiscipline
from pygpudrive.env.env_torch import GPUDriveTorchEnv

## Configuration

- We only control valid agents, up to a maximum of 32 per scene
- We use 10 initialization (warmup) steps
- We use the `StateDynamics` model
    - This model has a 5D action `(x, y, yaw, velocity x, velocity y)`

In [4]:
# Scene selection: a single scene taken from the bundled example data
scene_config = SceneConfig(
    discipline=SelectionDiscipline.K_UNIQUE_N,
    k_unique_scenes=1,
    num_scenes=1,
    path="data/examples",
)

# Environment settings
warmup_steps = 10  # Warmup period
env_config = EnvConfig(
    dynamics_model="state",  # Use state-based dynamics model
    enable_vbd=True,  # Use VBD
    init_steps=warmup_steps,
)

# Rendering uses the default settings
render_config = RenderConfig()

## Make environment


In [5]:
# Collect the constructor arguments first, then build the simulator from them
env_kwargs = dict(
    config=env_config,
    scene_config=scene_config,
    render_config=render_config,
    max_cont_agents=32,  # Maximum number of agents to control per scene
    device="cpu",
)
env = GPUDriveTorchEnv(**env_kwargs)


--- Ratio unique scenes / number of worls =         1 / 1 ---



In [7]:
# Sanity check: we have a warmup period of 10 steps, so right after a reset the
# episode step counter should already be past the warmup (the original note
# expects 11 -- TODO confirm). The value is printed explicitly because a bare
# expression in the middle of a cell is never displayed.
env.reset()
print(f"Step in episode after reset: {env.step_in_episode}")

# Indices of the controlled agents in world 0, taken from the controlled-agent mask
selected_agents = torch.nonzero(env.cont_agent_mask[0, :]).flatten().tolist()

print(f"Selected agents: {selected_agents}")

Selected agents: [0, 1, 5]


## Load trained VBD model

In [8]:
# Load model
from vbd.sim_agent.sim_actor import VBDTest, sample_to_action

ckpt_path = 'vbd/weights/epoch=18.ckpt'

# Map the checkpoint onto the CPU while loading
model = VBDTest.load_from_checkpoint(ckpt_path, map_location=torch.device('cpu'))

# Move the model to the GPU only when one is present; an unconditional
# `.cuda()` raises on CPU-only machines.
model_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_ = model.to(model_device)
_ = model.eval()

### Sanity check: sample_batch shapes

In [9]:
# Print the shape of every entry in the warmup batch
sample = env.warmup_trajectory

for name in sample.keys():
    entry = sample[name]
    print(name, entry.shape)

agents_history torch.Size([1, 32, 11, 8])
agents_interested torch.Size([1, 32])
agents_type torch.Size([1, 32])
agents_future torch.Size([1, 32, 80, 5])
traffic_light_points torch.Size([1, 16, 3])
polylines torch.Size([1, 256, 30, 5])
polylines_valid torch.Size([1, 256])
relations torch.Size([1, 304, 304, 3])
agents_id torch.Size([1, 3])
anchors torch.Size([1, 32, 64, 2])


## Rollout without goal guidance

In [10]:
# Rollout settings (meanings as given by the original inline notes)
replan_freq = 80  # Re-plan every `replan_freq` steps; 80 means open loop
model.early_stop = 0  # Stop diffusion early: from 100 down to this value
model.skip = 1  # Skip alpha
model.reward_func = None

# Reset the environment and keep a handle on the initial state
current_state = init_state = env.reset()

# Obtain all info for the diffusion model (warmup)
sample_batch = env.warmup_trajectory

# Make a prediction and keep the denoised trajectories of the first batch element
pred = model.sample_denoiser(sample_batch)
denoised_trajs = pred['denoised_trajs']
traj_pred = denoised_trajs.cpu().numpy()[0]

# Boolean mask built from the positive entries of `agents_interested`
is_controlled = sample_batch['agents_interested'].gt(0)

Diffusion: 100%|██████████| 50/50 [00:01<00:00, 36.67it/s]


In [13]:
# Roll the environment forward for 80 steps, collecting one rendered frame per step
frames = []

for step in range(80):
    plan_offset = step % replan_freq
    if plan_offset == 0:
        print("Replan at ", step)

        # Obtain all info for the diffusion model (warmup)
        sample_batch = env.warmup_trajectory

        # Make a fresh prediction; keep the first batch element
        pred = model.sample_denoiser(sample_batch)
        denoised_trajs = pred['denoised_trajs']
        traj_pred = denoised_trajs.cpu().numpy()[0]

    # Convert the predicted sample at the current offset into an action.
    # NOTE(review): in the saved run this cell stopped with
    # "ValueError: Invalid agents_id size" -- confirm the `agents_id` format
    # that `sample_to_action` expects.
    action_sample = traj_pred[:, plan_offset, :]
    action = sample_to_action(action_sample, is_controlled, agents_id=selected_agents)

    # Step the environment with the predicted actions
    env.step_dynamics(action)

    # Render world 0 and store the frame
    frames.append(env.render(world_render_idx=0))

Replan at  0


Diffusion:   0%|          | 0/50 [00:00<?, ?it/s]

Diffusion: 100%|██████████| 50/50 [00:01<00:00, 37.10it/s]


ValueError: Invalid agents_id size

### TODO: Show generated trajectories without goal guidance

In [None]:
# Play the collected frames (each of shape (H, W, 3)) back as a video
playback_fps = 10
mediapy.show_video(frames, fps=playback_fps)

## Generate goal positions