<a href="https://colab.research.google.com/github/9Tempest/23DesignPatterns/blob/master/waymax_demo_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Scenario Data Loading

This tutorial demonstrates how to load scenario data from the Waymo Open Motion Dataset (WOMD) using the Waymax dataloader.

In [None]:
%%capture
!pip install mediapy
!pip install git+https://github.com/waymo-research/waymax.git@main#egg=waymo-waymax
import dataclasses

import jax
from jax import numpy as jnp
import mediapy
import numpy as np
from tqdm import tqdm

from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import dynamics
from waymax import env as _env
from waymax import agents
from waymax import visualization
from google.colab import auth




We first create a dataset config. The `waymax.config` module provides default configs; for example, `config.WOD_1_1_0_TRAINING` is a pre-defined configuration that points to version 1.1.0 of the Waymo Open Dataset. In this notebook we instead build a `DatasetConfig` by hand that points at a single training shard of version 1.2.0, which the next cell copies to local disk.

The data config contains a number of options that control how and where the dataset is loaded from. By default, `WOD_1_1_0_TRAINING` loads up to 128 objects (e.g. vehicles, pedestrians) per scenario. Here, we save memory and compute by loading only the first 32 objects stored in each scenario (`max_num_objects=32`).

We use the `dataloader.simulator_state_generator` function to create an iterator over Open Motion Dataset scenarios. Calling `next` on the iterator retrieves the first scenario.


In [None]:
auth.authenticate_user()
!gsutil cp gs://waymo_open_dataset_motion_v_1_2_0/uncompressed/tf_example/training/training_tfexample.tfrecord-00000-of-01000 /content/training_tfexample.tfrecord


In [None]:

config = _config.DatasetConfig(path ='/content/training_tfexample.tfrecord',
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=32)
data_iter = dataloader.simulator_state_generator(config=config)
# `next` raises StopIteration on an empty iterator, so guard the call.
try:
    scenario = next(data_iter)
    print(scenario)
except StopIteration:
    print("The data iterator is empty.")
    # Handle empty iterator (e.g., reload data, check config)

Next, we can plot the initial state of this scenario. We use a matplotlib-based visualization available in the `waymax.visualization` package.

In [None]:
# Using logged trajectory
img = visualization.plot_simulator_state(scenario, use_log_traj=True)
mediapy.show_image(img)

The Waymo Open Motion Dataset consists of 9-second trajectory snippets. We can visualize the entire logged trajectory as a video as follows:

In [None]:
imgs = []

state = scenario
for _ in range(scenario.remaining_timesteps):
  state = datatypes.update_state_by_log(state, num_steps=1)
  imgs.append(visualization.plot_simulator_state(state, use_log_traj=True))

mediapy.show_video(imgs, fps=10)
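
As a quick sanity check on the snippet length, the arithmetic below relates the 9-second duration to a step count. It is a minimal sketch that assumes the 10 Hz sampling rate and the 11-step history mentioned in the model config later in this notebook; the constant names are illustrative only.

```python
SAMPLE_RATE_HZ = 10        # assumed sampling rate (10 Hz)
SNIPPET_SECONDS = 9        # length of one logged snippet

# Sampling both endpoints of a 9-second window gives 9 * 10 + 1 frames.
total_steps = SNIPPET_SECONDS * SAMPLE_RATE_HZ + 1

history_steps = 11         # 1 second of past context, including the current frame
future_steps = total_steps - history_steps

print(total_steps, history_steps, future_steps)
```

The split into 11 history steps and 80 future steps is the same one used by `past_len` and `future_len` in the training config.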

## Initializing and Running the Simulator

Waymax uses a Gym-like interface for running closed-loop simulation.

The `env.MultiAgentEnvironment` class defines a stateless simulation interface with two key methods:
- The `reset` method initializes and returns the first simulation state.
- The `step` method takes a state and an action as arguments, advances the simulation, and returns the next state.

Crucially, the `MultiAgentEnvironment` does not hold any simulation state itself, and the `reset` and `step` functions have no side effects. This allows us to use functional transforms from JAX, such as jit compilation, to optimize the computation. It also allows the user to branch and restart the simulation from any state, or to save the simulation by serializing and saving the state object.
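
To make the stateless pattern concrete, here is a minimal pure-Python sketch. `ToyState`, `reset`, and `step` below are hypothetical stand-ins, not Waymax classes or functions; they only illustrate how side-effect-free functions let us branch several futures from one state.

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyState:
    """Hypothetical immutable state: one position and a step counter."""
    position: float
    timestep: int


def reset(initial_position: float) -> ToyState:
    # Pure function: builds the first state and stores nothing globally.
    return ToyState(position=initial_position, timestep=0)


def step(state: ToyState, velocity: float, dt: float = 0.1) -> ToyState:
    # Pure function: returns a new state and leaves the input untouched.
    return replace(
        state,
        position=state.position + velocity * dt,
        timestep=state.timestep + 1,
    )


start = reset(0.0)
# Branch two different futures from the same starting state.
slow = step(step(start, velocity=1.0), velocity=1.0)
fast = step(step(start, velocity=5.0), velocity=5.0)
print(start, slow, fast)
```

Because `start` is never mutated, it can be reused, stored, or serialized at any point, which is the property the paragraph above relies on.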



In [None]:
# Configure the multi-agent environment.
init_steps = 11

# Set the dynamics model the environment is using.
# Note that each actor interacting with the environment needs to provide
# actions compatible with this dynamics model.
dynamics_model = dynamics.StateDynamics()

# Expect users to control all valid objects in the scene.
env = _env.MultiAgentEnvironment(
    dynamics_model=dynamics_model,
    config=dataclasses.replace(
        _config.EnvironmentConfig(),
        max_num_objects=32,
        controlled_object=_config.ObjectType.VALID,
    ),
)

We now create a set of sim agents to run in simulation. By default, the behavior of an object that is not controlled is to replay the behavior stored in the dataset (log playback).

For each sim agent, we define the algorithm (such as IDM), and specify which objects the agent controls via the `is_controlled_func`, which is required to return a boolean mask marking which objects are being controlled.
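
The boolean-mask idea can be sketched with plain NumPy. This is an illustration under stated assumptions, not how `agents.merge_actions` is implemented: the helper `merge_by_mask` and the per-object "speed" actions are hypothetical, and each toy actor simply proposes one value per object.

```python
import numpy as np

num_objects = 6
obj_idx = np.arange(num_objects)

# Each hypothetical actor proposes an action for every object, plus a mask
# saying which of those proposals it actually owns.
idm_mask = (obj_idx == 0) | (obj_idx == 1)
const_mask = obj_idx == 2
static_mask = obj_idx > 2

idm_speed = np.full(num_objects, 8.0)
const_speed = np.full(num_objects, 5.0)
static_speed = np.zeros(num_objects)


def merge_by_mask(proposals, masks):
    """Keep each actor's proposal only where its mask is True."""
    merged = np.zeros_like(proposals[0])
    for proposal, mask in zip(proposals, masks):
        merged = np.where(mask, proposal, merged)
    return merged


merged_speed = merge_by_mask(
    [idm_speed, const_speed, static_speed],
    [idm_mask, const_mask, static_mask],
)
print(merged_speed)
```

Keeping the masks disjoint, as here, means every object is driven by exactly one actor.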

The IDM agent in this example is `IDMRoutePolicy`. It follows the spatial trajectory stored in the logs but adjusts the speed profile using the IDM rule, braking or accelerating according to the distance between the vehicle and any objects in front of it. The remaining agents either follow the logged route at a fixed, constant speed or replay the logged behavior.
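
For intuition about the IDM rule, the sketch below implements the textbook Intelligent Driver Model acceleration for a single follower. It is not the code inside `IDMRoutePolicy`; the function name and all parameter values are illustrative assumptions.

```python
import math


def idm_acceleration(
    speed: float,
    gap: float,
    closing_speed: float,
    desired_speed: float = 15.0,   # v0, m/s (illustrative value)
    time_headway: float = 1.5,     # T, s
    min_gap: float = 2.0,          # s0, m
    max_accel: float = 1.5,        # a, m/s^2
    comfort_decel: float = 2.0,    # b, m/s^2
    delta: float = 4.0,
) -> float:
    """Textbook IDM acceleration for one follower vehicle."""
    # Desired dynamic gap s*: grows with speed and with closing speed.
    desired_gap = min_gap + max(
        0.0,
        speed * time_headway
        + speed * closing_speed / (2.0 * math.sqrt(max_accel * comfort_decel)),
    )
    free_road = 1.0 - (speed / desired_speed) ** delta
    interaction = (desired_gap / gap) ** 2
    return max_accel * (free_road - interaction)


far = idm_acceleration(speed=10.0, gap=200.0, closing_speed=0.0)
near = idm_acceleration(speed=10.0, gap=8.0, closing_speed=3.0)
print(f"open road: {far:+.2f} m/s^2, close behind a slower car: {near:+.2f} m/s^2")
```

With a large gap the free-road term dominates and the vehicle accelerates toward its desired speed; with a small gap the interaction term dominates and the vehicle brakes.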

In [None]:
# Setup a few actors, see visualization below for how each actor behaves.

# An actor that doesn't move, controlling all objects with index > 4
obj_idx = jnp.arange(32)
static_actor = agents.create_constant_speed_actor(
    speed=0.0,
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: obj_idx > 4,
)

# IDM actor/policy controlling both object 0 and 1.
# Note IDM policy is an actor hard-coded to use dynamics.StateDynamics().
actor_0 = agents.IDMRoutePolicy(
    is_controlled_func=lambda state: (obj_idx == 0) | (obj_idx == 1)
)

# Constant speed actor with predefined fixed speed controlling object 2.
actor_1 = agents.create_constant_speed_actor(
    speed=5.0,
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: obj_idx == 2,
)

# Expert/log actor controlling objects 3 and 4.
actor_2 = agents.create_expert_actor(
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: (obj_idx == 3) | (obj_idx == 4),
)

actors = [static_actor, actor_0, actor_1, actor_2]

We can (optionally) jit the `step` and `select_action` functions to speed up computation.

In [None]:
jit_step = jax.jit(env.step)
jit_select_action_list = [jax.jit(actor.select_action) for actor in actors]

We can now write a for loop to run all of these agents in simulation together.

In [None]:
states = [env.reset(scenario)]
for _ in range(states[0].remaining_timesteps):
  current_state = states[-1]

  outputs = [
      jit_select_action({}, current_state, None, None)
      for jit_select_action in jit_select_action_list
  ]
  action = agents.merge_actions(outputs)
  next_state = jit_step(current_state, action)

  states.append(next_state)

## Visualization of Simulation

We can now visualize the result of the simulation loop.

On the left side:
- Objects 5, 6, and 7 (controlled by static_actor) remain static.
- Objects 3 and 4 are controlled by log playback and collide with objects 5 and 6.

On the right side:
- Object 2, controlled by actor_1, moves at a constant speed of 5 m/s (i.e. slower than the log in this case).
- Objects 0 and 1, controlled by the IDM agent, follow the log in the beginning, but object 1 slows down when approaching object 2.

In [None]:
imgs = []
for state in states:
  imgs.append(visualization.plot_simulator_state(state, use_log_traj=False))
mediapy.show_video(imgs, fps=10)

In [None]:
# Wayformer Model Setup
import tensorflow as tf
import numpy as np
from typing import Optional, Tuple, List
from dataclasses import dataclass

In [None]:
@dataclass
class WayformerTrainingConfig:
    """Configuration for Wayformer model training."""
    num_map_feature: int = 11  # Road feature dimensions
    num_agent_feature: int = 9  # Agent feature dimensions
    hidden_size: int = 256
    max_num_agents: int = 32
    num_modes: int = 6
    future_len: int = 80  # 8 seconds with 10Hz
    past_len: int = 11   # 1 second with 10Hz
    dropout: float = 0.1
    tx_num_heads: int = 8
    max_points_per_lane: int = 40
    max_num_roads: int = 50
    num_queries_enc: int = 128
    num_queries_dec: int = 64
    learning_rate: float = 1e-4
    batch_size: int = 32
    num_epochs: int = 10

@dataclass
class ModuleOutput:
    last_hidden_state: tf.Tensor
    kv_cache: Optional[Tuple[tf.Tensor, tf.Tensor]] = None

class TrainableQueryProvider(tf.keras.layers.Layer):
    def __init__(self, num_queries: int, num_query_channels: int, init_scale: float = 0.02):
        super().__init__()
        self.num_queries = num_queries
        self.num_query_channels = num_query_channels
        self.init_scale = init_scale

    def build(self, input_shape):
        self.query = self.add_weight(
            shape=(self.num_queries, self.num_query_channels),
            initializer=tf.keras.initializers.RandomNormal(stddev=self.init_scale),
            trainable=True,
            name='query'
        )

    def call(self, x=None):
        return tf.expand_dims(self.query, 0)  # Add batch dimension

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(
        self,
        num_heads: int,
        num_q_input_channels: int,
        num_kv_input_channels: int,
        num_qk_channels: Optional[int] = None,
        num_v_channels: Optional[int] = None,
        num_output_channels: Optional[int] = None,
        max_heads_parallel: Optional[int] = None,
        causal_attention: bool = False,
        dropout: float = 0.0,
        qkv_bias: bool = True,
        out_bias: bool = True,
    ):
        super().__init__()

        if num_qk_channels is None:
            num_qk_channels = num_q_input_channels

        if num_v_channels is None:
            num_v_channels = num_qk_channels

        if num_output_channels is None:
            num_output_channels = num_q_input_channels

        self.num_heads = num_heads
        self.dp_scale = (num_qk_channels // num_heads) ** -0.5
        self.causal_attention = causal_attention

        self.q_proj = tf.keras.layers.Dense(num_qk_channels, use_bias=qkv_bias)
        self.k_proj = tf.keras.layers.Dense(num_qk_channels, use_bias=qkv_bias)
        self.v_proj = tf.keras.layers.Dense(num_v_channels, use_bias=qkv_bias)
        self.o_proj = tf.keras.layers.Dense(num_output_channels, use_bias=out_bias)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, x_q, x_kv, pad_mask=None, training=False):
        batch_size = tf.shape(x_q)[0]

        # Linear projections and reshape for multi-head attention
        q = self.q_proj(x_q)  # [batch_size, seq_len_q, d_model]
        k = self.k_proj(x_kv)  # [batch_size, seq_len_k, d_model]
        v = self.v_proj(x_kv)  # [batch_size, seq_len_k, d_model]

        # Reshape to [batch_size, num_heads, seq_len, depth]
        q = self._reshape_for_heads(q)
        k = self._reshape_for_heads(k)
        v = self._reshape_for_heads(v)

        # Scale query
        q = q * self.dp_scale

        # Calculate attention scores
        attn = tf.matmul(q, k, transpose_b=True)

        if pad_mask is not None:
            pad_mask = tf.expand_dims(tf.expand_dims(pad_mask, 1), 1)
            attn = tf.where(pad_mask, tf.float32.min, attn)

        if self.causal_attention:
            causal_mask = self._create_causal_mask(tf.shape(q)[2], tf.shape(k)[2])
            attn = tf.where(causal_mask, tf.float32.min, attn)

        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.dropout(attn, training=training)

        # Calculate output
        output = tf.matmul(attn, v)
        output = self._reshape_from_heads(output)
        output = self.o_proj(output)

        return ModuleOutput(last_hidden_state=output)

    def _reshape_for_heads(self, x):
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]
        depth = tf.shape(x)[2] // self.num_heads

        x = tf.reshape(x, [batch_size, seq_len, self.num_heads, depth])
        return tf.transpose(x, [0, 2, 1, 3])

    def _reshape_from_heads(self, x):
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[2]

        x = tf.transpose(x, [0, 2, 1, 3])
        return tf.reshape(x, [batch_size, seq_len, -1])

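    def _create_causal_mask(self, seq_len_q, seq_len_k):
        # True marks future key positions, which `call` fills with a large
        # negative score. Queries are right-aligned with the keys.
        q_idx = tf.range(seq_len_q)[:, tf.newaxis]
        k_idx = tf.range(seq_len_k)[tf.newaxis, :]
        return k_idx > q_idx + (seq_len_k - seq_len_q)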

class PerceiverEncoder(tf.keras.layers.Layer):
    def __init__(
        self,
        num_latents: int,
        num_latent_channels: int,
        num_cross_attention_heads: int = 4,
        num_cross_attention_qk_channels: Optional[int] = None,
        num_cross_attention_v_channels: Optional[int] = None,
        num_cross_attention_layers: int = 1,
        dropout: float = 0.1,
        init_scale: float = 0.02,
    ):
        super().__init__()

        self.latent_provider = TrainableQueryProvider(
            num_latents,
            num_latent_channels,
            init_scale=init_scale
        )

        self.cross_attention = MultiHeadAttention(
            num_heads=num_cross_attention_heads,
            num_q_input_channels=num_latent_channels,
            num_kv_input_channels=num_latent_channels,
            num_qk_channels=num_cross_attention_qk_channels,
            num_v_channels=num_cross_attention_v_channels,
            dropout=dropout
        )

        self.self_attention = MultiHeadAttention(
            num_heads=num_cross_attention_heads,
            num_q_input_channels=num_latent_channels,
            num_kv_input_channels=num_latent_channels,
            dropout=dropout
        )

        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, x, pad_mask=None, training=False):
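        # Pass `x` so the layer is called with an input (some Keras versions
        # require one); the query provider ignores its value.
        x_latent = self.latent_provider(x)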

        # Cross attention
        residual = x_latent
        x_latent = self.layer_norm1(x_latent)
        cross_attn_output = self.cross_attention(x_latent, x, pad_mask=pad_mask, training=training)
        x_latent = residual + self.dropout(cross_attn_output.last_hidden_state, training=training)

        # Self attention
        residual = x_latent
        x_latent = self.layer_norm2(x_latent)
        self_attn_output = self.self_attention(x_latent, x_latent, training=training)
        x_latent = residual + self.dropout(self_attn_output.last_hidden_state, training=training)

        return x_latent


def train_wayformer(model_config, action_space, training_data):
    """Train Wayformer model with better error handling and debugging"""
    model = Wayformer(model_config, action_space)
    optimizer = tf.keras.optimizers.Adam(learning_rate=model_config.learning_rate)

    # Convert data to numpy arrays for TensorFlow
    training_data = [
        {k: (v.numpy() if hasattr(v, 'numpy') else v) for k, v in scenario.items()}
        for scenario in training_data
    ]

    num_samples = len(training_data)
    num_batches = num_samples // model_config.batch_size
    print(f"Training on {num_samples} samples with {num_batches} batches per epoch")

    for epoch in range(model_config.num_epochs):
        print(f"\nEpoch {epoch + 1}/{model_config.num_epochs}")
        total_loss = 0.0
        batch_losses = []

        # Shuffle the data
        np.random.shuffle(training_data)

        for batch_idx in tqdm(range(num_batches)):
            try:
                # Get batch
                start_idx = batch_idx * model_config.batch_size
                end_idx = start_idx + model_config.batch_size
                batch_data = training_data[start_idx:end_idx]

                # Prepare input
                model_inputs = prepare_model_input(batch_data)
                model_inputs = {k: tf.convert_to_tensor(v, dtype=tf.float32) for k, v in model_inputs.items()}

                # Debug shapes
                print(f"\nBatch {batch_idx} input shapes:")
                for k, v in model_inputs.items():
                    print(f"{k}: {v.shape}")

                with tf.GradientTape() as tape:
                    predictions = model(model_inputs, training=True)
                    loss = compute_trajectory_loss(
                        predictions['predicted_trajectory'],
                        predictions['predicted_probability'],
                        model_inputs['future_trajectory'],
                        tf.ones_like(model_inputs['future_trajectory'][..., 0])
                    )

                # Compute and apply gradients
                gradients = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(zip(gradients, model.trainable_variables))

                batch_losses.append(float(loss.numpy()))
                total_loss += float(loss.numpy())

                print(f"Batch {batch_idx} loss: {float(loss.numpy()):.4f}")

            except Exception as e:
                print(f"\nError processing batch {batch_idx}:")
                print(str(e))
                print("\nFull traceback:")
                import traceback
                traceback.print_exc()
                continue

        # Print epoch statistics
        if batch_losses:
            avg_loss = total_loss / len(batch_losses)
            print(f"\nEpoch {epoch + 1} statistics:")
            print(f"Average Loss: {avg_loss:.4f}")
            print(f"Min Batch Loss: {min(batch_losses):.4f}")
            print(f"Max Batch Loss: {max(batch_losses):.4f}")
        else:
            print(f"\nEpoch {epoch + 1}: No successful batches")

    return model
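
The attention layer above hides positions by overwriting their scores with a very negative value before the softmax. The NumPy sketch below illustrates that convention for a causal mask, where `True` marks a key the query must not see. It is a standalone illustration with hypothetical helper names, not the TensorFlow layer itself.

```python
import numpy as np


def causal_mask(seq_len_q: int, seq_len_k: int) -> np.ndarray:
    """True marks key positions a query must not see (its future)."""
    q_idx = np.arange(seq_len_q)[:, None]
    k_idx = np.arange(seq_len_k)[None, :]
    # Right-align queries with keys so the last query sees every key.
    return k_idx > q_idx + (seq_len_k - seq_len_q)


def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    # Hidden positions get a very negative score, so softmax sends them to 0.
    scores = np.where(mask, np.finfo(scores.dtype).min, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


scores = np.zeros((4, 4), dtype=np.float32)   # uniform raw scores
weights = masked_softmax(scores, causal_mask(4, 4))
print(np.round(weights, 2))
```

Each row attends only to itself and earlier positions, so the weights above the diagonal are zero and every row still sums to one.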

In [None]:
def process_waymax_data(state, action_space):
    """Convert waymax SimulatorState to model input format with consistent shapes"""
    trajectory = state.sim_trajectory
    timestep = state.timestep
    past_steps = 11  # Match the config.past_len
    future_steps = 80  # Match the config.future_len

    # Get past trajectory features (last past_steps timesteps before current)
    past_idx = slice(max(0, timestep - past_steps + 1), timestep + 1)

    # Extract and pad past features if necessary
    xy = trajectory.xy[..., past_idx, :]
    vel_xy = trajectory.vel_xy[..., past_idx, :]
    yaw = trajectory.yaw[..., past_idx]
    length = trajectory.length[..., past_idx]
    width = trajectory.width[..., past_idx]
    height = trajectory.height[..., past_idx]
    valid = trajectory.valid[..., past_idx]

    # Pad if we don't have enough past steps
    if xy.shape[-2] < past_steps:
        pad_length = past_steps - xy.shape[-2]
        xy = jnp.pad(xy, ((0, 0), (pad_length, 0), (0, 0)), mode='edge')
        vel_xy = jnp.pad(vel_xy, ((0, 0), (pad_length, 0), (0, 0)), mode='edge')
        yaw = jnp.pad(yaw, ((0, 0), (pad_length, 0)), mode='edge')
        length = jnp.pad(length, ((0, 0), (pad_length, 0)), mode='edge')
        width = jnp.pad(width, ((0, 0), (pad_length, 0)), mode='edge')
        height = jnp.pad(height, ((0, 0), (pad_length, 0)), mode='edge')
        valid = jnp.pad(valid, ((0, 0), (pad_length, 0)), mode='constant', constant_values=0)

    # Expand dimensions for single-value features
    yaw = jnp.expand_dims(yaw, axis=-1)
    length = jnp.expand_dims(length, axis=-1)
    width = jnp.expand_dims(width, axis=-1)
    height = jnp.expand_dims(height, axis=-1)
    valid = jnp.expand_dims(valid, axis=-1)

    # Broadcast is_sdc to match time dimension
    is_sdc = jnp.expand_dims(state.object_metadata.is_sdc, axis=(1, -1))
    is_sdc = jnp.repeat(is_sdc, past_steps, axis=1)

    # Concatenate all features
    current_features = jnp.concatenate([
        xy,                     # (..., past_steps, 2)
        vel_xy,                 # (..., past_steps, 2)
        yaw,                    # (..., past_steps, 1)
        length,                 # (..., past_steps, 1)
        width,                  # (..., past_steps, 1)
        height,                 # (..., past_steps, 1)
        is_sdc,                 # (..., past_steps, 1)
        valid                   # (..., past_steps, 1)
    ], axis=-1)

    # Get future trajectory features
    future_idx = slice(timestep + 1, min(timestep + 1 + future_steps, trajectory.num_timesteps))
    future_xy = trajectory.xy[..., future_idx, :]
    future_vel_xy = trajectory.vel_xy[..., future_idx, :]
    future_yaw = jnp.expand_dims(trajectory.yaw[..., future_idx], axis=-1)
    future_valid = jnp.expand_dims(trajectory.valid[..., future_idx], axis=-1)

    # Pad future features if necessary
    if future_xy.shape[-2] < future_steps:
        pad_length = future_steps - future_xy.shape[-2]
        future_xy = jnp.pad(future_xy, ((0, 0), (0, pad_length), (0, 0)), mode='edge')
        future_vel_xy = jnp.pad(future_vel_xy, ((0, 0), (0, pad_length), (0, 0)), mode='edge')
        future_yaw = jnp.pad(future_yaw, ((0, 0), (0, pad_length), (0, 0)), mode='edge')
        future_valid = jnp.pad(future_valid, ((0, 0), (0, pad_length), (0, 0)), mode='constant', constant_values=0)

    future_features = jnp.concatenate([
        future_xy,
        future_vel_xy,
        future_yaw,
        future_valid
    ], axis=-1)

    # Process road features
    if hasattr(state, 'roadgraph_points') and state.roadgraph_points is not None:
        road_xyz = state.roadgraph_points.xyz
        road_dir = jnp.stack([
            state.roadgraph_points.dir_x,
            state.roadgraph_points.dir_y,
            state.roadgraph_points.dir_z
        ], axis=-1)
        road_valid = jnp.expand_dims(state.roadgraph_points.valid, axis=-1)
        road_features = jnp.concatenate([road_xyz, road_dir, road_valid], axis=-1)

        # Truncate or pad to fixed size
        max_road_points = 1024
        if road_features.shape[0] > max_road_points:
            road_features = road_features[:max_road_points]
        else:
            pad_length = max_road_points - road_features.shape[0]
            road_features = jnp.pad(road_features, ((0, pad_length), (0, 0)), mode='constant')
    else:
        road_features = jnp.zeros((1024, 7))
    # Locate the ego vehicle (SDC) the same way `prepare_model_input` does.
    ego_index = int(jnp.argmax(state.object_metadata.is_sdc))
    # Extract ego's past and future data
    ego_past_positions = xy[ego_index]  # Shape: (past_steps, 2)
    ego_past_velocities = vel_xy[ego_index]  # Shape: (past_steps, 2)
    ego_future_positions = future_xy[ego_index]  # Shape: (future_steps, 2)
    ego_future_velocities = future_vel_xy[ego_index]  # Shape: (future_steps, 2)
    # Combine past and future positions and velocities
    positions = np.concatenate([ego_past_positions, ego_future_positions], axis=0)
    velocities = np.concatenate([ego_past_velocities, ego_future_velocities], axis=0)

    # Encode the future trajectory into motion tokens
    # Note: For encoding, we need positions and velocities at T+1 points to compute T accelerations
    token_indices = encode_trajectory(positions, velocities, action_space)

    return {
        'current_features': current_features,  # (32, past_steps=11, 10)
        'future_features': future_features,    # (32, future_steps=80, 6)
        'road_features': road_features,        # (1024, 7)
        'object_metadata': {
            'is_sdc': state.object_metadata.is_sdc,
            'object_types': state.object_metadata.object_types,
            'is_valid': state.object_metadata.is_valid,
        },
        'motion_tokens': token_indices[past_steps - 1:]  # Use future tokens only
    }

def prepare_model_input(batch_data):
    """Prepare batch data for model input with correct shapes"""
    batch_size = len(batch_data)

    # Stack all features
    batched_data = {
        'current_features': jnp.stack([d['current_features'] for d in batch_data]),
        'future_features': jnp.stack([d['future_features'] for d in batch_data]),
        'road_features': jnp.stack([d['road_features'] for d in batch_data]),
        'object_metadata': {
            'is_sdc': jnp.stack([d['object_metadata']['is_sdc'] for d in batch_data])
        },
        'motion_tokens': jnp.stack([d['motion_tokens'] for d in batch_data])
    }

    # Find ego vehicle (SDC) in each scenario
    ego_indices = jnp.argmax(batched_data['object_metadata']['is_sdc'], axis=-1)

    # Extract ego and other agent features
    ego_features = []
    ego_future_features = []
    other_agents = []

    for i in range(batch_size):
        # Get ego features
        ego_feat = batched_data['current_features'][i, ego_indices[i]]
        ego_features.append(ego_feat)

        # Get ego future features
        ego_future_feat = batched_data['future_features'][i, ego_indices[i]]
        ego_future_features.append(ego_future_feat)

        # Get other agent features (excluding ego). The list is named
        # `others` so it does not shadow the imported `waymax.agents` module.
        others = []
        for j in range(batched_data['current_features'].shape[1]):
            if j != ego_indices[i]:
                others.append(batched_data['current_features'][i, j])
        others = jnp.stack(others[:31])  # Limit to 31 other agents

        # Pad if we have fewer than 31 agents
        if others.shape[0] < 31:
            pad_length = 31 - others.shape[0]
            others = jnp.pad(others, ((0, pad_length), (0, 0), (0, 0)), mode='constant')

        other_agents.append(others)

    return {
        'ego_in': jnp.stack(ego_features),          # (batch_size, past_steps, features)
        'agents_in': jnp.stack(other_agents),       # (batch_size, 31, past_steps, features)
        'roads': batched_data['road_features'],     # (batch_size, max_road_points, features)
        'future_trajectory': jnp.stack(ego_future_features),  # (batch_size, future_steps, features)
        'motion_tokens': batched_data['motion_tokens']  # (batch_size, future_steps)
    }

def create_training_dataset(config, action_space, num_scenarios=1000):
    """Create training dataset from waymax data"""
    data_iter = dataloader.simulator_state_generator(config=config)
    dataset = []

    for i in tqdm(range(num_scenarios)):
        try:
            scenario = next(data_iter)
            if jnp.any(scenario.object_metadata.is_valid):
                try:
                    processed_data = process_waymax_data(scenario,action_space)
                    dataset.append(processed_data)
                    if len(dataset) == 1:  # Print shapes for first successful scenario
                        print("\nFirst successful scenario shapes:")
                        print(f"current_features: {processed_data['current_features'].shape}")
                        print(f"future_features: {processed_data['future_features'].shape}")
                        print(f"road_features: {processed_data['road_features'].shape}")
                except Exception as e:
                    print(f"Failed to process scenario: {str(e)}")
                    continue
        except StopIteration:
            break

    print(f"\nSuccessfully processed {len(dataset)} scenarios")
    return dataset
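
The windowing logic in `process_waymax_data` can be hard to follow inside the full feature pipeline. The NumPy sketch below isolates it on a toy track: slice a past and a future window around the current step, pad each to a fixed length by repeating the nearest real sample, and keep a validity flag for the padded rows. The helper `split_windows` and the toy sizes are assumptions made for illustration.

```python
import numpy as np


def split_windows(track: np.ndarray, t: int, past_steps: int, future_steps: int):
    """Split a (num_steps, 2) track into fixed-length past and future windows."""
    past = track[max(0, t - past_steps + 1): t + 1]
    future = track[t + 1: t + 1 + future_steps]

    # Pad at the front of the past and the back of the future by repeating
    # the nearest real sample, and record which rows are real.
    past_pad = past_steps - past.shape[0]
    future_pad = future_steps - future.shape[0]
    past_valid = np.concatenate([np.zeros(past_pad, bool), np.ones(past.shape[0], bool)])
    future_valid = np.concatenate([np.ones(future.shape[0], bool), np.zeros(future_pad, bool)])
    past = np.pad(past, ((past_pad, 0), (0, 0)), mode="edge")
    future = np.pad(future, ((0, future_pad), (0, 0)), mode="edge")
    return past, past_valid, future, future_valid


# Toy track with 10 steps; the current step is 2, so both windows need padding.
track = np.stack([np.arange(10.0), np.zeros(10)], axis=-1)
past, past_valid, future, future_valid = split_windows(track, t=2, past_steps=4, future_steps=8)
print(past.shape, future.shape, past_valid, future_valid.sum())
```

Carrying the validity flags alongside the padded windows lets a downstream loss ignore the repeated samples instead of treating them as real observations.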

In [None]:
def compute_trajectory_loss(pred_trajectories, pred_probabilities, target_trajectories, is_valid):
    """Compute training loss for trajectory prediction with proper shape handling"""
    # pred_trajectories shape: [batch, num_modes, time, 5]
    # target_trajectories shape: [batch, time, 6]
    # is_valid shape: [batch, time]

    batch_size = tf.shape(pred_trajectories)[0]
    num_modes = tf.shape(pred_trajectories)[1]

    # Expand target trajectories for comparison with each mode
    target_expanded = tf.expand_dims(target_trajectories[..., :5], axis=1)  # [batch, 1, time, 5]
    target_expanded = tf.tile(target_expanded, [1, num_modes, 1, 1])  # [batch, num_modes, time, 5]

    # Compute displacement error for each mode
    displacement_error = tf.reduce_mean(
        tf.sqrt(tf.reduce_sum(
            tf.square(pred_trajectories - target_expanded),
            axis=-1
        )),
        axis=-1  # Average over time
    )  # [batch, num_modes]

    # Find best matching mode for each sample
    min_mode_error = tf.reduce_min(displacement_error, axis=1)  # [batch]

    # Compute probability loss
    prob_loss = -tf.math.log(tf.clip_by_value(pred_probabilities, 1e-7, 1.0))
    prob_loss = tf.reduce_mean(prob_loss)  # Scalar

    # Combine losses
    total_loss = tf.reduce_mean(min_mode_error) + 0.1 * prob_loss

    return total_loss

class TransformerDecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super(TransformerDecoderLayer, self).__init__()

        # Self-attention (masked)
        self.mha1 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads, dropout=dropout_rate)

        # Cross-attention with encoder outputs
        self.mha2 = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=d_model // num_heads, dropout=dropout_rate)

        # Point-wise feedforward network
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
            tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
        ])

        # Layer normalizations
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        # Dropouts
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout3 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, enc_output, look_ahead_mask=None, padding_mask=None, training=False):
        # x.shape == (batch_size, target_seq_len, d_model)
        # enc_output.shape == (batch_size, input_seq_len, d_model)

        # Masked self-attention (decoder's self-attention)
        attn1 = self.mha1(
            query=x,
            key=x,
            value=x,
            attention_mask=look_ahead_mask,
            training=training
        )  # (batch_size, target_seq_len, d_model)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(x + attn1)  # Residual connection and layer norm

        # Cross-attention with encoder outputs
        attn2 = self.mha2(
            query=out1,
            key=enc_output,
            value=enc_output,
            attention_mask=padding_mask,
            training=training
        )  # (batch_size, target_seq_len, d_model)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(out1 + attn2)  # Residual connection and layer norm

        # Feedforward network
        ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(out2 + ffn_output)  # Residual connection and layer norm

        return out3

class Wayformer(tf.keras.Model):
    def __init__(self, config, action_space):
        super().__init__()
        # Initialize dimensions and parameters
        self.map_attr = config.num_map_feature
        self.k_attr = config.num_agent_feature
        self.d_k = config.hidden_size
        self._M = config.max_num_agents
        self.c = config.num_modes
        self.T = config.future_len
        self.dropout = config.dropout
        self.num_heads = config.tx_num_heads
        self.past_T = config.past_len

        # Input encoders
        self.road_pts_lin = tf.keras.layers.Dense(self.d_k)
        self.agents_dynamic_encoder = tf.keras.layers.Dense(self.d_k)

        # Positional embeddings
        self.temporal_embedding = self.add_weight(
            shape=(self.past_T, self.d_k),
            initializer='zeros',
            trainable=True,
            name='temporal_emb'
        )

        self.agent_embedding = self.add_weight(
            shape=(32, self.d_k),  # 32 = 1 ego + 31 other agents
            initializer='zeros',
            trainable=True,
            name='agent_emb'
        )

        # Agent processing layers
        self.agent_encoder = tf.keras.Sequential([
            tf.keras.layers.Dense(self.d_k, activation='relu'),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dropout(self.dropout)
        ])

        self.agent_transformer = tf.keras.layers.MultiHeadAttention(
            num_heads=8,
            key_dim=self.d_k//8,
            dropout=self.dropout
        )

        # Road processing layers
        self.road_encoder = tf.keras.Sequential([
            tf.keras.layers.Dense(self.d_k, activation='relu'),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dropout(self.dropout)
        ])

        # Output processing
        self.perceiver = PerceiverEncoder(
            num_latents=config.num_queries_enc,
            num_latent_channels=self.d_k
        )

        self.output_query = self.add_weight(
            shape=(self.c, self.d_k),
            initializer='random_normal',
            trainable=True,
            name='output_query'
        )

        self.trajectory_projection = tf.keras.layers.Dense(5 * self.T)
        self.mode_projection = tf.keras.layers.Dense(1)

        self.action_space = action_space
        self.vocab_size = action_space.vocab_size

        # Token embedding for motion tokens
        self.token_embedding = tf.keras.layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.d_k)

        # Positional encoding for tokens
        self.positional_encoding = self.add_weight(
            shape=(config.future_len, self.d_k),
            initializer='zeros',
            trainable=True,
            name='token_positional_encoding'
        )

        # Transformer decoder layers
        self.decoder_layers = [
            TransformerDecoderLayer(
                d_model=self.d_k,
                num_heads=config.tx_num_heads,
                dff=1024,
                dropout_rate=config.dropout
            ) for _ in range(4)
        ]

        # Output projection to predict token logits
        self.token_projection = tf.keras.layers.Dense(self.vocab_size)

    def build(self, input_shapes):
        """Mark the model as built.

        The input shapes are not needed here: every sublayer creates its
        weights lazily the first time the model is called.
        """
        self.built = True

    def encode_agents(self, ego_in, agents_in, training=False):
        """Process agent features with shape tracking"""
        batch_size = tf.shape(ego_in)[0]

        # Print shapes for debugging
        print(f"ego_in shape: {ego_in.shape}")
        print(f"agents_in shape: {agents_in.shape}")

        # Reshape ego to add agent dimension
        ego_expanded = tf.expand_dims(ego_in, axis=1)  # [batch, 1, time, features]

        # Concatenate ego with other agents
        all_agents = tf.concat([ego_expanded, agents_in], axis=1)  # [batch, 32, time, features]
        print(f"all_agents shape after concat: {all_agents.shape}")

        # Encode agent features
        encoded = self.agent_encoder(all_agents, training=training)
        print(f"encoded shape after agent_encoder: {encoded.shape}")

        # Add temporal embeddings: (past_T, d_k) -> (1, 1, past_T, d_k),
        # which broadcasts over the batch and agent axes
        temporal_emb = tf.expand_dims(tf.expand_dims(self.temporal_embedding, 0), 0)
        encoded = encoded + temporal_emb

        # Add agent embeddings: (32, d_k) -> (1, 32, 1, d_k),
        # which broadcasts over the batch and time axes
        agent_emb = tf.expand_dims(tf.expand_dims(self.agent_embedding, 0), 2)
        encoded = encoded + agent_emb

        print(f"encoded shape before transformer: {encoded.shape}")

        # Apply self-attention
        attended = self.agent_transformer(
            query=encoded,
            key=encoded,
            value=encoded,
            training=training
        )
        print(f"attended shape after transformer: {attended.shape}")

        # Final reshape
        final_encoded = tf.reshape(attended, [batch_size, -1, self.d_k])
        print(f"final_encoded shape: {final_encoded.shape}")

        return final_encoded

    def encode_roads(self, roads, training=False):
        # Process road features
        encoded = self.road_encoder(roads, training=training)
        return tf.reshape(encoded, [tf.shape(encoded)[0], -1, self.d_k])

    def call(self, inputs, training=False):
        """Forward pass with shape debugging"""
        print("\nInput shapes:")
        for k, v in inputs.items():
            print(f"{k}: {v.shape}")

        ego_in = inputs['ego_in']
        agents_in = inputs['agents_in']
        roads = inputs['roads']

        # Process agents
        agents_encoded = self.encode_agents(ego_in, agents_in, training)
        print(f"agents_encoded shape: {agents_encoded.shape}")

        # Process roads
        roads_encoded = self.encode_roads(roads, training)
        print(f"roads_encoded shape: {roads_encoded.shape}")

        # Combine features
        combined_features = tf.concat([agents_encoded, roads_encoded], axis=1)
        print(f"combined_features shape: {combined_features.shape}")

        # Apply perceiver encoding
        context = self.perceiver(combined_features, training=training)
        print(f"context shape: {context.shape}")

        # Generate outputs
        batch_size = tf.shape(ego_in)[0]
        query = tf.tile(tf.expand_dims(self.output_query, 0), [batch_size, 1, 1])
        print(f"query shape: {query.shape}")

        # Get trajectory predictions
        output_features = tf.matmul(query, context, transpose_b=True)
        print(f"output_features shape: {output_features.shape}")

        trajectories = self.trajectory_projection(output_features)
        trajectories = tf.reshape(trajectories, [batch_size, self.c, self.T, -1])
        print(f"trajectories shape: {trajectories.shape}")

        # Get mode probabilities
        mode_logits = tf.squeeze(self.mode_projection(output_features), axis=-1)
        mode_probs = tf.nn.softmax(mode_logits, axis=-1)
        print(f"mode_probs shape: {mode_probs.shape}")

        motion_tokens = inputs['motion_tokens']  # Shape: (batch_size, future_steps)

        # Embed motion tokens
        token_embeddings = self.token_embedding(motion_tokens)  # Shape: (batch_size, future_steps, d_k)
        token_embeddings += self.positional_encoding  # Add positional encoding

        # Prepare attention masks
        # Lower-triangular causal mask (1 = may attend): step i sees only steps <= i
        look_ahead_mask = tf.linalg.band_part(tf.ones((self.T, self.T)), -1, 0)

        # Pass through decoder layers
        decoder_output = token_embeddings
        for layer in self.decoder_layers:
            decoder_output = layer(
                decoder_output, context, training=training,
                look_ahead_mask=look_ahead_mask
            )

        # Predict logits for next tokens
        logits = self.token_projection(decoder_output)  # Shape: (batch_size, future_steps, vocab_size)

        return {
            'predicted_trajectory': trajectories,
            'predicted_probability': mode_probs,
            'scene_emb': tf.reshape(output_features, [batch_size, -1]),
            'logits': logits
        }
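
The causal (look-ahead) mask used by the decoder can be inspected in isolation. The following is a minimal NumPy sketch, not the model's actual attention code: `np.tril` plays the role of `tf.linalg.band_part(x, -1, 0)`, and `causal_mask` and `masked_softmax` are hypothetical helper names introduced only for this illustration.

```python
import numpy as np

def causal_mask(seq_len):
    """Lower-triangular mask: row i has ones in columns 0..i (1 = may attend)."""
    return np.tril(np.ones((seq_len, seq_len), dtype=np.float32))

def masked_softmax(scores, mask):
    """Softmax over the last axis; masked-out positions get (near) zero weight."""
    masked = np.where(mask > 0, scores, -1e9)
    masked = masked - masked.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(masked)
    return weights / weights.sum(axis=-1, keepdims=True)

mask = causal_mask(4)
scores = np.zeros((4, 4), dtype=np.float32)  # uniform scores, for illustration only
weights = masked_softmax(scores, mask)

print(mask)
print(weights.round(2))
```

With uniform scores, each row spreads its attention evenly over the current and earlier steps and assigns nothing to later steps, which is the property the decoder relies on during teacher-forced training.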

In [None]:
class VerletActionSpace:
    def __init__(self, delta_interval=(-18.0, 18.0), num_bins=13, step_frequency=2.0):
        self.delta_interval = delta_interval
        self.num_bins = num_bins
        self.step_frequency = step_frequency
        self.time_step = 1.0 / step_frequency
        self.accel_bin_edges = np.linspace(delta_interval[0], delta_interval[1], num_bins + 1)
        self.accel_bin_centers = (self.accel_bin_edges[:-1] + self.accel_bin_edges[1:]) / 2

        # Create the vocabulary as the Cartesian product of acceleration bins for x and y
        self.vocab_size = num_bins * num_bins
        self.token_to_accel = np.array([
            (ax, ay) for ax in self.accel_bin_centers for ay in self.accel_bin_centers
        ])  # Shape (num_bins**2, 2), i.e. (169, 2) with the default 13 bins

    def discretize_acceleration(self, accel):
        # accel: array of shape (..., 2)
        indices_x = np.digitize(accel[..., 0], self.accel_bin_edges) - 1
        indices_y = np.digitize(accel[..., 1], self.accel_bin_edges) - 1
        indices_x = np.clip(indices_x, 0, self.num_bins - 1)
        indices_y = np.clip(indices_y, 0, self.num_bins - 1)
        token_indices = indices_x * self.num_bins + indices_y
        return token_indices  # Shape (...)

    def continuous_acceleration(self, token_indices):
        # token_indices: array of shape (...)
        accel = self.token_to_accel[token_indices]
        return accel  # Shape (..., 2)

    def verlet_update(self, position, velocity, acceleration):
        # position, velocity, acceleration: arrays of shape (..., 2)
        # Constant-acceleration update over one time step of length self.time_step
        new_position = position + velocity * self.time_step + 0.5 * acceleration * self.time_step ** 2
        new_velocity = velocity + acceleration * self.time_step
        return new_position, new_velocity

def encode_trajectory(positions, velocities, action_space):
    """
    Encode continuous trajectories into motion tokens.
    positions: array of shape (T+1, 2)
    velocities: array of shape (T+1, 2)
    Returns:
        token_indices: array of shape (T,)
    """
    time_step = action_space.time_step
    # Invert the velocity update in verlet_update: a = (v[t+1] - v[t]) / dt.
    # Note that `positions` is not used by this finite-difference estimate.
    accelerations = (velocities[1:] - velocities[:-1]) / time_step
    # Discretize accelerations
    token_indices = action_space.discretize_acceleration(accelerations)
    return token_indices

def decode_tokens(start_position, start_velocity, token_indices, action_space):
    """
    Decode motion tokens back into continuous trajectories.
    start_position: array of shape (2,)
    start_velocity: array of shape (2,)
    token_indices: array of shape (T,)
    Returns:
        positions: array of shape (T+1, 2)
        velocities: array of shape (T+1, 2)
    """
    positions = [start_position]
    velocities = [start_velocity]
    for idx in token_indices:
        accel = action_space.continuous_acceleration(idx)
        new_position, new_velocity = action_space.verlet_update(
            positions[-1], velocities[-1], accel)
        positions.append(new_position)
        velocities.append(new_velocity)
    positions = np.stack(positions)
    velocities = np.stack(velocities)
    return positions, velocities
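
To see how much information the tokenization discards, it helps to round-trip a few accelerations through the bins. The sketch below is a standalone NumPy illustration that reuses the default settings above (13 bins on [-18, 18]); `to_token` and `from_token` are hypothetical names that mirror `discretize_acceleration` and `continuous_acceleration`, not part of the class.

```python
import numpy as np

NUM_BINS = 13
LOW, HIGH = -18.0, 18.0
edges = np.linspace(LOW, HIGH, NUM_BINS + 1)
centers = (edges[:-1] + edges[1:]) / 2
bin_width = (HIGH - LOW) / NUM_BINS

def to_token(accel):
    """Map (ax, ay) pairs of shape (..., 2) to token indices in [0, NUM_BINS**2)."""
    ix = np.clip(np.digitize(accel[..., 0], edges) - 1, 0, NUM_BINS - 1)
    iy = np.clip(np.digitize(accel[..., 1], edges) - 1, 0, NUM_BINS - 1)
    return ix * NUM_BINS + iy

def from_token(token):
    """Recover the bin-center (ax, ay) pair for each token index."""
    ix, iy = np.divmod(token, NUM_BINS)
    return np.stack([centers[ix], centers[iy]], axis=-1)

accel = np.array([[0.0, 0.0], [3.2, -7.5], [17.9, -17.9]])
tokens = to_token(accel)
recovered = from_token(tokens)
max_error = np.abs(recovered - accel).max()

print("tokens:", tokens)
print("max round-trip error:", max_error, "half bin width:", bin_width / 2)
```

For inputs inside the interval, the round-trip error per component is bounded by half a bin width (about 1.38 with these defaults). Values outside the interval are clipped to the outermost bins, so their error can be larger.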




In [None]:
# Authenticate and setup
auth.authenticate_user()

# Configure dataset
config = _config.DatasetConfig(
    path='/content/training_tfexample.tfrecord',
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=32
)

# Create model config
model_config = WayformerTrainingConfig()

# Create training dataset
print("Creating training dataset...")
action_space = VerletActionSpace()
training_data = create_training_dataset(config, action_space, num_scenarios=1000)


In [None]:

# Train model
print("Training Wayformer model...")
model = train_wayformer(model_config, action_space, training_data)

# Evaluate on a test scenario
test_scenario = next(dataloader.simulator_state_generator(config=config))
test_data = process_waymax_data(test_scenario, action_space)

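# Build the inputs with the same helper used in the autoregressive cell below.
# This assumes prepare_model_input returns every key that Wayformer.call reads,
# including 'motion_tokens', which the original three-key dict omitted.
model_input = prepare_model_input([test_data])
model_inputs = {k: tf.convert_to_tensor(v, dtype=tf.float32) for k, v in model_input.items()}
predictions = model(model_inputs, training=False)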

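# Visualize the logged scenario
imgs = []
state = test_scenario

# Select the most likely mode for the first batch element
mode_index = tf.argmax(predictions['predicted_probability'], axis=-1).numpy()[0]
pred_trajectory = predictions['predicted_trajectory'][0, mode_index].numpy()  # (future_len, 5)

for t in range(state.remaining_timesteps):
    state = datatypes.update_state_by_log(state, num_steps=1)
    # plot_simulator_state does not accept an `additional_trajectories` argument,
    # so only the scene is rendered here; pred_trajectory must be drawn separately.
    img = visualization.plot_simulator_state(state, use_log_traj=True)
    imgs.append(img)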

mediapy.show_video(imgs, fps=10)
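
The model returns one trajectory per mode together with a probability for each mode, so a single prediction has to be selected before it can be plotted or rolled out. A minimal NumPy sketch of that selection follows; the array sizes and the random data are illustrative assumptions, not values from the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_modes, horizon, state_dim = 2, 6, 8, 5  # illustrative sizes only

# Stand-ins for predictions['predicted_trajectory'] and the mode logits
trajectories = rng.normal(size=(batch, num_modes, horizon, state_dim))
logits = rng.normal(size=(batch, num_modes))

# Softmax over modes
probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs /= probs.sum(axis=-1, keepdims=True)

# Pick the highest-probability mode separately for each batch element
best_mode = probs.argmax(axis=-1)                      # (batch,)
best_traj = trajectories[np.arange(batch), best_mode]  # (batch, horizon, state_dim)

# The first two channels are treated as x/y positions, as in the rollout cell
xy = best_traj[..., :2]
print(best_mode.shape, best_traj.shape, xy.shape)
```

Indexing with `[0]` alone selects the first batch element and keeps every mode, which is why the mode index has to be applied explicitly.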

In [None]:
import time

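# Instantiate a fresh, untrained model: this cell measures inference latency only.
# To roll out learned predictions, reuse the `model` returned by train_wayformer instead.
model = Wayformer(model_config, action_space)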

# Get a test scenario
test_scenario = next(dataloader.simulator_state_generator(config=config))

# Number of autoregressive steps (e.g., 80 for 8 seconds at 10Hz)
num_autoregressive_steps = 80

# Initialize the state with the test scenario
state = test_scenario

# List to store images for video visualization
imgs = []

# Measure time before inference
start_time = time.time()

for step in range(num_autoregressive_steps):
    # Process the current state to get model input
    test_data = process_waymax_data(state, action_space)
    model_input = prepare_model_input([test_data])
    model_inputs = {k: tf.convert_to_tensor(v, dtype=tf.float32) for k, v in model_input.items()}

    # Perform inference
    predictions = model(model_inputs, training=False)

    # Get the most probable trajectory for the ego vehicle (mode with highest probability)
    mode_index = tf.argmax(predictions['predicted_probability'], axis=-1).numpy()[0]
    pred_trajectory = predictions['predicted_trajectory'][0, mode_index].numpy()

    # Extract the predicted position, velocity, and yaw for the next timestep
    next_position = pred_trajectory[0, :2]       # Position at next timestep
    next_velocity = pred_trajectory[0, 2:4]      # Velocity at next timestep
    next_yaw = pred_trajectory[0, 4]             # Yaw at next timestep

    # Update the state with the predicted values
    ego_index = 0

    # Use .at[].set() and replace() to update the JAX arrays
    state = state.replace(
        sim_trajectory=state.sim_trajectory.replace(
            x=state.sim_trajectory.x.at[ego_index, state.timestep + 1].set(next_position[0]),
            y=state.sim_trajectory.y.at[ego_index, state.timestep + 1].set(next_position[1]),
            vel_x=state.sim_trajectory.vel_x.at[ego_index, state.timestep + 1].set(next_velocity[0]),
            vel_y=state.sim_trajectory.vel_y.at[ego_index, state.timestep + 1].set(next_velocity[1]),
            yaw=state.sim_trajectory.yaw.at[ego_index, state.timestep + 1].set(next_yaw),
            valid=state.sim_trajectory.valid.at[ego_index, state.timestep + 1].set(True)
        ),
        timestep=state.timestep + 1  # Advance the timestep
    )

    # # Plot the current state
    # img = visualization.plot_simulator_state(
    #     state,
    #     use_log_traj=False  # We are using the predicted trajectory
    # )
    # imgs.append(img)

    # Break the loop if we've reached the end of available timesteps
    if state.timestep >= state.sim_trajectory.num_timesteps - 1:
        break

# Measure time after inference
end_time = time.time()

# Calculate elapsed time
elapsed_time = end_time - start_time

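# Print the timing result. `imgs` stays empty while plotting is commented out,
# so report `step + 1`, the number of loop iterations that actually ran.
print(f"Auto-regressive inference time for {step + 1} steps: {elapsed_time:.6f} seconds")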
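
Because the rollout loop can exit early, it is useful to count the executed steps independently of the image list and to report a per-step latency. The sketch below shows one way to do that with `time.perf_counter`; `time_rollout` and `fake_step` are hypothetical names, and `fake_step` merely stands in for one model call plus state update.

```python
import time

def time_rollout(step_fn, max_steps):
    """Call step_fn until it returns False or max_steps is reached.

    Returns (steps_run, elapsed_seconds).
    """
    steps = 0
    start = time.perf_counter()
    for _ in range(max_steps):
        steps += 1
        if not step_fn():
            break
    elapsed = time.perf_counter() - start
    return steps, elapsed

# Hypothetical stand-in for one inference step; it reports False when the
# scenario has no timesteps left.
remaining = [5]

def fake_step():
    remaining[0] -= 1
    return remaining[0] > 0

steps, elapsed = time_rollout(fake_step, max_steps=80)
print(f"{steps} steps in {elapsed:.6f} s ({elapsed / steps * 1e3:.3f} ms/step)")
```

`time.perf_counter` is a monotonic, high-resolution clock, which makes it a better fit for short benchmarks than `time.time`. Note that the debug `print` calls inside the model's forward pass are included in any such measurement.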