<a href="https://colab.research.google.com/github/9Tempest/motionLM-Serve/blob/main/waymax_demo_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Scenario Data Loading

This tutorial demonstrates how to load scenario data from the Waymo Open Motion Dataset (WOMD) using the Waymax dataloader.


In [None]:
%%capture
!pip install mediapy
!pip install git+https://github.com/waymo-research/waymax.git@main#egg=waymo-waymax
import dataclasses

import jax
import mediapy
import numpy as np
from jax import numpy as jnp
from tqdm import tqdm


from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import dynamics
from waymax import env as _env
from waymax import agents
from waymax import visualization
from google.colab import auth




We first create a dataset config. The `waymax.config` module provides predefined configurations such as `config.WOD_1_1_0_TRAINING`, which points to version 1.1.0 of the Waymo Open Motion Dataset. This notebook instead copies one training shard of version 1.2.0 to local disk and builds a `DatasetConfig` that reads it.

The data config contains a number of options that control how and where the dataset is loaded from. By default, `WOD_1_1_0_TRAINING` loads up to 128 objects (e.g. vehicles, pedestrians) per scenario. Here, we save memory and compute by setting `max_num_objects=32`, which loads only the first 32 objects stored in each scenario.
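
Overriding a single field of a predefined config is commonly done with `dataclasses.replace`, which returns a modified copy and leaves the original untouched. The sketch below uses a hypothetical stand-in class (`ToyDatasetConfig` and its example path are illustrative, not part of Waymax) so that it runs without the dataset; it is assumed that Waymax config objects can be copied the same way, as the environment cell later in this notebook does.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class ToyDatasetConfig:
    """Hypothetical stand-in for a dataset config; not the Waymax class."""
    path: str
    max_num_objects: int = 128


# A predefined default, analogous in spirit to a library-provided config.
DEFAULT_CONFIG = ToyDatasetConfig(path="/data/example.tfrecord")

# Copy the default while overriding a single field.
small_config = dataclasses.replace(DEFAULT_CONFIG, max_num_objects=32)

print(DEFAULT_CONFIG.max_num_objects)  # 128
print(small_config.max_num_objects)    # 32
```

Because the dataclass is frozen, the default cannot be modified by accident, which keeps shared configs safe to reuse across cells.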

We use the `dataloader.simulator_state_generator` function to create an iterator
over Open Motion Dataset scenarios. Calling `next` on the iterator retrieves the first scenario in the dataset.
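
As a rough illustration of the iterator protocol involved, the sketch below uses a hypothetical generator (`toy_scenario_generator` is a stand-in, not the Waymax dataloader) and shows how `next` with a default value avoids a `StopIteration` exception once the data runs out.

```python
def toy_scenario_generator(num_scenarios):
    """Hypothetical stand-in for a scenario iterator; yields plain dicts."""
    for index in range(num_scenarios):
        yield {"scenario_index": index}


toy_iter = toy_scenario_generator(2)
first = next(toy_iter)            # first scenario
second = next(toy_iter)           # second scenario
exhausted = next(toy_iter, None)  # default returned instead of StopIteration

print(first, second, exhausted)
```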


In [None]:
auth.authenticate_user()
!gsutil cp gs://waymo_open_dataset_motion_v_1_2_0/uncompressed/tf_example/training/training_tfexample.tfrecord-00000-of-01000 /content/training_tfexample.tfrecord


In [None]:

config = _config.DatasetConfig(
    path='/content/training_tfexample.tfrecord',
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=32,
)
data_iter = dataloader.simulator_state_generator(config=config)

# next() raises StopIteration if the dataset yields no scenarios.
try:
    scenario = next(data_iter)
    print(scenario)
except StopIteration:
    # Later cells need `scenario`, so check the path and config before continuing.
    print("The data iterator is empty.")

Next, we can plot the initial state of this scenario. We use a matplotlib-based visualization available in the `waymax.visualization` package.

In [None]:
# Using logged trajectory
img = visualization.plot_simulator_state(scenario, use_log_traj=True)
mediapy.show_image(img)

The Waymo Open Motion Dataset consists of 9-second trajectory snippets. We can visualize the entire logged trajectory as a video as follows:
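
The number of timesteps per scenario follows from the snippet length and the sampling rate. The arithmetic below assumes the 10 Hz rate and the split into 1 second of history, the current step, and 8 seconds of future that the model configuration later in this notebook also uses.

```python
SAMPLE_RATE_HZ = 10
HISTORY_SECONDS = 1
FUTURE_SECONDS = 8

past_steps = HISTORY_SECONDS * SAMPLE_RATE_HZ    # 10 past samples
current_steps = 1                                # the current sample
future_steps = FUTURE_SECONDS * SAMPLE_RATE_HZ   # 80 future samples

total_steps = past_steps + current_steps + future_steps
print(total_steps)  # 91
```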

In [None]:
imgs = []

state = scenario
for _ in range(scenario.remaining_timesteps):
  state = datatypes.update_state_by_log(state, num_steps=1)
  imgs.append(visualization.plot_simulator_state(state, use_log_traj=True))

mediapy.show_video(imgs, fps=10)

## Initializing and Running the Simulator

Waymax uses a Gym-like interface for running closed-loop simulation.

The `env.MultiAgentEnvironment` class defines a stateless simulation interface with two key methods:
- The `reset` method initializes and returns the first simulation state.
- The `step` method takes a state and an action as arguments and returns the next state.

Crucially, the `MultiAgentEnvironment` does not hold any simulation state itself, and the `reset` and `step` functions have no side effects. This allows us to use functional transforms from JAX, such as JIT compilation to optimize the computation. It also allows the user to arbitrarily branch and restart simulation from any state, or to save the simulation by simply serializing the state object.
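
The stateless pattern can be sketched with a toy example. The names `ToyState`, `reset`, and `step` below are hypothetical and far simpler than the Waymax environment; the point is only that each call returns a new state, so the same intermediate state can be branched more than once.

```python
from typing import NamedTuple


class ToyState(NamedTuple):
    """Hypothetical immutable simulation state."""
    timestep: int
    position: float


def reset(initial_position: float) -> ToyState:
    """Return the first state; nothing is stored on any environment object."""
    return ToyState(timestep=0, position=initial_position)


def step(state: ToyState, action: float) -> ToyState:
    """Return a new state; the input state is left unchanged."""
    return ToyState(timestep=state.timestep + 1, position=state.position + action)


start = reset(0.0)
after_one = step(start, 1.0)

# Branch twice from the same intermediate state.
branch_a = step(after_one, 2.0)
branch_b = step(after_one, -2.0)

print(start, after_one, branch_a, branch_b)
```

Since `after_one` is never mutated, both branches start from exactly the same point, which is what makes restarting or replaying from a saved state straightforward.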



In [None]:
# Configure the multi-agent environment.
init_steps = 11

# Set the dynamics model the environment is using.
# Note each actor interacting with the environment needs to provide action
# compatible with this dynamics model.
dynamics_model = dynamics.StateDynamics()

# Expect users to control all valid objects in the scene.
env = _env.MultiAgentEnvironment(
    dynamics_model=dynamics_model,
    config=dataclasses.replace(
        _config.EnvironmentConfig(),
        max_num_objects=32,
        controlled_object=_config.ObjectType.VALID,
    ),
)

We now create a set of sim agents to run in simulation. By default, the behavior of an object that is not controlled is to replay the behavior stored in the dataset (log playback).

For each sim agent, we define the algorithm (such as IDM), and specify which objects the agent controls via the `is_controlled_func`, which is required to return a boolean mask marking which objects are being controlled.
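
A boolean control mask has one entry per object, with `True` marking the objects an agent controls. The NumPy sketch below mirrors the index-based masks used in this notebook for a smaller, hypothetical scene of 8 objects and checks that every object is claimed by exactly one agent.

```python
import numpy as np

num_objects = 8
obj_idx = np.arange(num_objects)

# One boolean mask per hypothetical agent; True marks a controlled object.
idm_mask = (obj_idx == 0) | (obj_idx == 1)
constant_speed_mask = obj_idx == 2
expert_mask = (obj_idx == 3) | (obj_idx == 4)
static_mask = obj_idx > 4

masks = np.stack([idm_mask, constant_speed_mask, expert_mask, static_mask])

# Each object should be claimed by exactly one agent.
claims_per_object = masks.sum(axis=0)
print(claims_per_object)  # [1 1 1 1 1 1 1 1]
```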

The IDM agent we use in this example is the `IDMRoutePolicy`. It follows the spatial trajectory stored in the logs but adjusts the speed profile according to the IDM rule, slowing down or speeding up based on the distance between the vehicle and any objects in front of it. The remaining agents use either a constant-speed policy, which follows the logged route at a fixed speed (a speed of 0 keeps the object static), or an expert policy, which replays the logged behavior.
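
For intuition about the IDM rule, the sketch below implements the textbook Intelligent Driver Model acceleration. It is not the Waymax `IDMRoutePolicy` implementation, and all parameter values (desired speed, time headway, and so on) are illustrative assumptions.

```python
import math


def idm_acceleration(speed, gap, lead_speed, desired_speed=15.0, min_gap=2.0,
                     time_headway=1.5, max_accel=1.5, comfort_decel=2.0,
                     delta=4.0):
    """Textbook IDM acceleration in m/s^2; parameter values are illustrative."""
    closing_speed = speed - lead_speed
    # Gap the follower would like to keep at its current speed.
    desired_gap = min_gap + max(
        0.0,
        speed * time_headway
        + speed * closing_speed / (2.0 * math.sqrt(max_accel * comfort_decel)),
    )
    free_road_term = (speed / desired_speed) ** delta
    interaction_term = (desired_gap / gap) ** 2
    return max_accel * (1.0 - free_road_term - interaction_term)


# Large gap, same speed as the leader: the follower accelerates.
far = idm_acceleration(speed=10.0, gap=200.0, lead_speed=10.0)
# Small gap to a stopped leader: the follower brakes.
near = idm_acceleration(speed=10.0, gap=8.0, lead_speed=0.0)

print(far > 0, near < 0)
```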

In [None]:
# Setup a few actors, see visualization below for how each actor behaves.

# An actor that doesn't move, controlling all objects with index > 4
obj_idx = jnp.arange(32)
static_actor = agents.create_constant_speed_actor(
    speed=0.0,
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: obj_idx > 4,
)

# IDM actor/policy controlling both object 0 and 1.
# Note IDM policy is an actor hard-coded to use dynamics.StateDynamics().
actor_0 = agents.IDMRoutePolicy(
    is_controlled_func=lambda state: (obj_idx == 0) | (obj_idx == 1)
)

# Constant speed actor with predefined fixed speed controlling object 2.
actor_1 = agents.create_constant_speed_actor(
    speed=5.0,
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: obj_idx == 2,
)

# Expert/log actor controlling objects 3 and 4.
actor_2 = agents.create_expert_actor(
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: (obj_idx == 3) | (obj_idx == 4),
)

actors = [static_actor, actor_0, actor_1, actor_2]

We can optionally JIT-compile the `step` and `select_action` functions to speed up computation.

In [None]:
jit_step = jax.jit(env.step)
jit_select_action_list = [jax.jit(actor.select_action) for actor in actors]

We can now write a for loop to run all of these agents in simulation together.
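
Conceptually, merging combines the per-actor outputs so that each object receives the action of the actor that controls it. The NumPy sketch below illustrates that idea with a hypothetical helper, `merge_by_mask`; it is not the Waymax implementation of `agents.merge_actions`, and leaving uncontrolled objects at zero is an assumption of this sketch.

```python
import numpy as np


def merge_by_mask(action_values, control_masks):
    """Combine per-actor actions, taking each object's action from its controller.

    action_values: list of arrays shaped (num_objects, action_dim).
    control_masks: list of boolean arrays shaped (num_objects,).
    Objects no actor controls keep a zero action in this sketch.
    """
    merged = np.zeros_like(action_values[0])
    for values, mask in zip(action_values, control_masks):
        merged = np.where(mask[:, None], values, merged)
    return merged


num_objects, action_dim = 4, 2
obj_idx = np.arange(num_objects)

actor_a = np.full((num_objects, action_dim), 1.0)  # proposes 1.0 everywhere
actor_b = np.full((num_objects, action_dim), 5.0)  # proposes 5.0 everywhere

merged = merge_by_mask(
    [actor_a, actor_b],
    [obj_idx < 2, obj_idx == 2],
)
print(merged)
```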

In [None]:
states = [env.reset(scenario)]
for _ in range(states[0].remaining_timesteps):
  current_state = states[-1]

  outputs = [
      jit_select_action({}, current_state, None, None)
      for jit_select_action in jit_select_action_list
  ]
  action = agents.merge_actions(outputs)
  next_state = jit_step(current_state, action)

  states.append(next_state)

## Visualization of Simulation

We can now visualize the result of the simulation loop.

On the left side:
- Objects 5, 6, and 7 (controlled by `static_actor`) remain static.
- Objects 3 and 4 are controlled by log playback and collide with objects 5 and 6.

On the right side:
- Object 2, controlled by `actor_1`, moves at a constant speed of 5 m/s (i.e. slower than the log in this case).
- Objects 0 and 1, controlled by the IDM agent, follow the log in the beginning, but object 1 slows down when approaching object 2.

In [None]:
imgs = []
for state in states:
  imgs.append(visualization.plot_simulator_state(state, use_log_traj=False))
mediapy.show_video(imgs, fps=10)
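
The model cells that follow are built around attention layers. As a point of reference, the NumPy sketch below shows single-head scaled dot-product attention with a padding mask. It is a simplified illustration under assumed toy shapes, not the `MultiHeadAttention` layer defined in this notebook.

```python
import numpy as np


def scaled_dot_product_attention(q, k, v, pad_mask=None):
    """Single-head attention; pad_mask is True where a key should be ignored."""
    scores = q @ k.T / np.sqrt(q.shape[-1])               # (num_q, num_k)
    if pad_mask is not None:
        scores = np.where(pad_mask[None, :], -1e9, scores)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
q = rng.normal(size=(2, 4))
k = rng.normal(size=(3, 4))
v = rng.normal(size=(3, 4))
pad_mask = np.array([False, False, True])  # the last key is padding

output, weights = scaled_dot_product_attention(q, k, v, pad_mask)
print(weights.sum(axis=-1))  # each row sums to 1
print(weights[:, -1])        # the padded key receives no weight
```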

In [None]:
# Wayformer Model Setup
import tensorflow as tf
import numpy as np
from typing import Optional, Tuple, List
from dataclasses import dataclass

In [None]:
@dataclass
class WayformerTrainingConfig:
    """Configuration for Wayformer model training."""
    num_map_feature: int = 11  # Road feature dimensions
    num_agent_feature: int = 9  # Agent feature dimensions
    hidden_size: int = 256
    max_num_agents: int = 32
    num_modes: int = 6
    future_len: int = 80  # 8 seconds with 10Hz
    past_len: int = 11   # 1 second with 10Hz
    dropout: float = 0.1
    tx_num_heads: int = 8
    max_points_per_lane: int = 40
    max_num_roads: int = 50
    num_queries_enc: int = 128
    num_queries_dec: int = 64
    learning_rate: float = 1e-4
    batch_size: int = 32
    num_epochs: int = 10

@dataclass
class ModuleOutput:
    last_hidden_state: tf.Tensor
    kv_cache: Optional[Tuple[tf.Tensor, tf.Tensor]] = None

class TrainableQueryProvider(tf.keras.layers.Layer):
    def __init__(self, num_queries: int, num_query_channels: int, init_scale: float = 0.02):
        super().__init__()
        self.num_queries = num_queries
        self.num_query_channels = num_query_channels
        self.init_scale = init_scale

    def build(self, input_shape):
        self.query = self.add_weight(
            shape=(self.num_queries, self.num_query_channels),
            initializer=tf.keras.initializers.RandomNormal(stddev=self.init_scale),
            trainable=True,
            name='query'
        )

    def call(self, x=None):
        return tf.expand_dims(self.query, 0)  # Add batch dimension

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(
        self,
        num_heads: int,
        num_q_input_channels: int,
        num_kv_input_channels: int,
        num_qk_channels: Optional[int] = None,
        num_v_channels: Optional[int] = None,
        num_output_channels: Optional[int] = None,
        max_heads_parallel: Optional[int] = None,
        causal_attention: bool = False,
        dropout: float = 0.0,
        qkv_bias: bool = True,
        out_bias: bool = True,
    ):
        super().__init__()

        if num_qk_channels is None:
            num_qk_channels = num_q_input_channels

        if num_v_channels is None:
            num_v_channels = num_qk_channels

        if num_output_channels is None:
            num_output_channels = num_q_input_channels

        self.num_heads = num_heads
        self.dp_scale = (num_qk_channels // num_heads) ** -0.5
        self.causal_attention = causal_attention

        self.q_proj = tf.keras.layers.Dense(num_qk_channels, use_bias=qkv_bias)
        self.k_proj = tf.keras.layers.Dense(num_qk_channels, use_bias=qkv_bias)
        self.v_proj = tf.keras.layers.Dense(num_v_channels, use_bias=qkv_bias)
        self.o_proj = tf.keras.layers.Dense(num_output_channels, use_bias=out_bias)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, x_q, x_kv, pad_mask=None, training=False):
        batch_size = tf.shape(x_q)[0]

        # Linear projections and reshape for multi-head attention
        q = self.q_proj(x_q)  # [batch_size, seq_len_q, d_model]
        k = self.k_proj(x_kv)  # [batch_size, seq_len_k, d_model]
        v = self.v_proj(x_kv)  # [batch_size, seq_len_k, d_model]

        # Reshape to [batch_size, num_heads, seq_len, depth]
        q = self._reshape_for_heads(q)
        k = self._reshape_for_heads(k)
        v = self._reshape_for_heads(v)

        # Scale query
        q = q * self.dp_scale

        # Calculate attention scores
        attn = tf.matmul(q, k, transpose_b=True)

        if pad_mask is not None:
            pad_mask = tf.expand_dims(tf.expand_dims(pad_mask, 1), 1)
            attn = tf.where(pad_mask, tf.float32.min, attn)

        if self.causal_attention:
            causal_mask = self._create_causal_mask(tf.shape(q)[2], tf.shape(k)[2])
            attn = tf.where(causal_mask, tf.float32.min, attn)

        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.dropout(attn, training=training)

        # Calculate output
        output = tf.matmul(attn, v)
        output = self._reshape_from_heads(output)
        output = self.o_proj(output)

        return ModuleOutput(last_hidden_state=output)

    def _reshape_for_heads(self, x):
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[1]
        depth = tf.shape(x)[2] // self.num_heads

        x = tf.reshape(x, [batch_size, seq_len, self.num_heads, depth])
        return tf.transpose(x, [0, 2, 1, 3])

    def _reshape_from_heads(self, x):
        batch_size = tf.shape(x)[0]
        seq_len = tf.shape(x)[2]

        x = tf.transpose(x, [0, 2, 1, 3])
        return tf.reshape(x, [batch_size, seq_len, -1])

    def _create_causal_mask(self, seq_len_q, seq_len_k):
        # True marks positions to mask out: keys that come after the query position.
        lower = tf.linalg.band_part(tf.ones((seq_len_q, seq_len_k)), -1, 0)  # Lower triangular, incl. diagonal
        return tf.cast(1.0 - lower, tf.bool)

class PerceiverEncoder(tf.keras.layers.Layer):
    def __init__(
        self,
        num_latents: int,
        num_latent_channels: int,
        num_cross_attention_heads: int = 4,
        num_cross_attention_qk_channels: Optional[int] = None,
        num_cross_attention_v_channels: Optional[int] = None,
        num_cross_attention_layers: int = 1,
        dropout: float = 0.1,
        init_scale: float = 0.02,
    ):
        super().__init__()

        self.latent_provider = TrainableQueryProvider(
            num_latents,
            num_latent_channels,
            init_scale=init_scale
        )

        self.cross_attention = MultiHeadAttention(
            num_heads=num_cross_attention_heads,
            num_q_input_channels=num_latent_channels,
            num_kv_input_channels=num_latent_channels,
            num_qk_channels=num_cross_attention_qk_channels,
            num_v_channels=num_cross_attention_v_channels,
            dropout=dropout
        )

        self.self_attention = MultiHeadAttention(
            num_heads=num_cross_attention_heads,
            num_q_input_channels=num_latent_channels,
            num_kv_input_channels=num_latent_channels,
            dropout=dropout
        )

        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, x, pad_mask=None, training=False):
        # Pass an input so Keras can build the layer; the provider ignores its value.
        x_latent = self.latent_provider(x)

        # Cross attention
        residual = x_latent
        x_latent = self.layer_norm1(x_latent)
        cross_attn_output = self.cross_attention(x_latent, x, pad_mask=pad_mask, training=training)
        x_latent = residual + self.dropout(cross_attn_output.last_hidden_state, training=training)

        # Self attention
        residual = x_latent
        x_latent = self.layer_norm2(x_latent)
        self_attn_output = self.self_attention(x_latent, x_latent, training=training)
        x_latent = residual + self.dropout(self_attn_output.last_hidden_state, training=training)

        return x_latent

class Wayformer(tf.keras.Model):
    def __init__(self, config: WayformerTrainingConfig):
        super().__init__()
        # Initialize dimensions and parameters
        self.map_attr = config.num_map_feature
        self.k_attr = config.num_agent_feature
        self.d_k = config.hidden_size
        self._M = config.max_num_agents
        self.c = config.num_modes
        self.T = config.future_len
        self.dropout = config.dropout
        self.num_heads = config.tx_num_heads
        self.past_T = config.past_len
        self.max_points_per_lane = config.max_points_per_lane
        self.max_num_roads = config.max_num_roads
        self.num_queries_enc = config.num_queries_enc
        self.num_queries_dec = config.num_queries_dec

        # Input encoders
        self.road_pts_lin = tf.keras.layers.Dense(self.d_k)
        self.agents_dynamic_encoder = tf.keras.layers.Dense(self.d_k)

        # Positional embeddings
        self.agents_positional_embedding = self.add_weight(
            shape=(1, 1, self._M + 1, self.d_k),
            initializer='zeros',
            trainable=True,
            name='agents_pos_emb'
        )

        self.temporal_positional_embedding = self.add_weight(
            shape=(1, self.past_T, 1, self.d_k),
            initializer='zeros',
            trainable=True,
            name='temporal_pos_emb'
        )

        # Perceiver components
        self.perceiver_encoder = PerceiverEncoder(
            num_latents=self.num_queries_enc,
            num_latent_channels=self.d_k
        )

        self.output_query_provider = TrainableQueryProvider(
            num_queries=self.num_queries_dec,
            num_query_channels=self.d_k,
            init_scale=0.1
        )

        # Output heads
        self.prob_predictor = tf.keras.layers.Dense(1)
        self.output_model = tf.keras.layers.Dense(5 * self.T)

        self.selu = tf.keras.layers.Activation('selu')

    def call(self, inputs, training=False):
        ego_in = inputs['ego_in']
        agents_in = inputs['agents_in']
        roads = inputs['roads']

        # Process observations
        ego_tensor, opps_tensor, opps_masks, env_masks = self.process_observations(ego_in, agents_in)

        # Encode inputs
        agents_tensor = tf.concat([tf.expand_dims(ego_tensor, 2), opps_tensor], axis=2)
        agents_emb = self.selu(self.agents_dynamic_encoder(agents_tensor))

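        # Add positional embeddings, slicing the agent embedding to the number of
        # agents actually present (ego plus others)
        num_agents = tf.shape(agents_emb)[2]
        agents_emb = (agents_emb
                      + self.agents_positional_embedding[:, :, :num_agents]
                      + self.temporal_positional_embedding)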
        batch_size = tf.shape(agents_emb)[0]
        agents_emb = tf.reshape(agents_emb, [batch_size, -1, self.d_k])

        # Process road features
        road_pts_feats = self.selu(self.road_pts_lin(roads[:, :self.max_num_roads, :, :self.map_attr]))
        road_pts_feats = tf.reshape(road_pts_feats, [batch_size, -1, self.d_k])

        # Combine features
        mixed_input_features = tf.concat([agents_emb, road_pts_feats], axis=1)

        # Encode with Perceiver
        context = self.perceiver_encoder(mixed_input_features, training=training)

        # Generate outputs
        # Pass an input so Keras can build the layer; the provider ignores its value.
        output_query = self.output_query_provider(context)
        out_seq = tf.matmul(output_query, context, transpose_b=True)

        # Predict trajectories and probabilities
        out_dists = self.output_model(out_seq[:, :self.c])
        out_dists = tf.reshape(out_dists, [batch_size, self.c, self.T, -1])

        mode_probs = self.prob_predictor(out_seq[:, :self.c])
        mode_probs = tf.reshape(mode_probs, [batch_size, self.c])

        return {
            'predicted_probability': mode_probs,
            'predicted_trajectory': out_dists,
            'scene_emb': tf.reshape(out_seq[:, :self.num_queries_dec], [batch_size, -1])
        }

    def process_observations(self, ego, agents):
        # Process ego vehicle data
        ego_tensor = ego[:, :, :self.k_attr]
        env_masks_orig = ego[:, :, -1]
        env_masks = tf.cast(1.0 - env_masks_orig, tf.bool)
        env_masks = tf.expand_dims(env_masks, 1)
        env_masks = tf.tile(env_masks, [1, self.num_queries_dec, 1])
        batch_size = tf.shape(ego)[0]
        env_masks = tf.reshape(env_masks, [batch_size * self.num_queries_dec, -1])

        # Process other agents data
        temp_masks = tf.concat([
            tf.ones_like(tf.expand_dims(env_masks_orig, -1)),
            agents[:, :, :, -1]
        ], axis=-1)
        opps_masks = tf.cast(1.0 - temp_masks, tf.bool)
        opps_tensor = agents[:, :, :, :self.k_attr]

        return ego_tensor, opps_tensor, opps_masks, env_masks

    def train_step(self, data):
        with tf.GradientTape() as tape:
            predictions = self(data['input_dict'], training=True)
            loss = self.compute_loss(predictions, data)

        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return {'loss': loss}

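    def compute_loss(self, predictions, data):
        # Not implemented yet. It should match the PyTorch Criterion class
        # functionality. Raising here keeps train_step from silently using a
        # None loss.
        raise NotImplementedError("Wayformer.compute_loss is not implemented.")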


In [None]:
def process_waymax_data(state):
    """Convert waymax SimulatorState to model input format"""
    trajectory = state.sim_trajectory
    timestep = state.timestep

    # Print input shapes for debugging
    print("\nInput shapes:")
    print(f"xy shape: {trajectory.xy.shape}")
    print(f"vel_xy shape: {trajectory.vel_xy.shape}")
    print(f"yaw shape: {trajectory.yaw.shape}")
    print(f"length shape: {trajectory.length.shape}")
    print(f"width shape: {trajectory.width.shape}")
    print(f"height shape: {trajectory.height.shape}")
    print(f"is_sdc shape: {state.object_metadata.is_sdc.shape}")
    print(f"valid shape: {trajectory.valid.shape}")

    # Get current features, keeping all timesteps
    xy = trajectory.xy                          # Shape: (32, 91, 2)
    vel_xy = trajectory.vel_xy                  # Shape: (32, 91, 2)

    # Expand dimensions for all features to match xy shape
    yaw = jnp.expand_dims(trajectory.yaw, axis=-1)        # Shape: (32, 91, 1)
    length = jnp.expand_dims(trajectory.length, axis=-1)  # Shape: (32, 91, 1)
    width = jnp.expand_dims(trajectory.width, axis=-1)    # Shape: (32, 91, 1)
    height = jnp.expand_dims(trajectory.height, axis=-1)  # Shape: (32, 91, 1)
    valid = jnp.expand_dims(trajectory.valid, axis=-1)    # Shape: (32, 91, 1)

    # Fix broadcasting for is_sdc
    is_sdc = jnp.expand_dims(state.object_metadata.is_sdc, axis=(1, -1))  # Shape: (32, 1, 1)
    is_sdc = jnp.repeat(is_sdc, xy.shape[1], axis=1)  # Shape: (32, 91, 1)

    # Reshape xy and vel_xy to split the last dimension
    xy_list = [xy[..., i:i+1] for i in range(xy.shape[-1])]       # Split into list of (32, 91, 1)
    vel_xy_list = [vel_xy[..., i:i+1] for i in range(vel_xy.shape[-1])]  # Split into list of (32, 91, 1)

    # Print shapes after reshaping for debugging
    print("\nReshaped feature shapes:")
    print(f"xy_split_0: {xy_list[0].shape}")
    print(f"vel_xy_split_0: {vel_xy_list[0].shape}")
    print(f"yaw: {yaw.shape}")
    print(f"length: {length.shape}")
    print(f"width: {width.shape}")
    print(f"height: {height.shape}")
    print(f"is_sdc: {is_sdc.shape}")
    print(f"valid: {valid.shape}")

    # Concatenate all features along the last dimension
    current_features = jnp.concatenate(
        xy_list + vel_xy_list + [yaw, length, width, height, is_sdc, valid],
        axis=-1
    )

    # Get future trajectory features (for timesteps after current timestep)
    future_idx = slice(timestep + 1, trajectory.num_timesteps)
    future_xy = trajectory.xy[..., future_idx, :]
    future_vel_xy = trajectory.vel_xy[..., future_idx, :]
    future_yaw = jnp.expand_dims(trajectory.yaw[..., future_idx], axis=-1)
    future_valid = jnp.expand_dims(trajectory.valid[..., future_idx], axis=-1)

    # Concatenate future features
    future_features = jnp.concatenate([
        future_xy,
        future_vel_xy,
        future_yaw,
        future_valid
    ], axis=-1)

    # Process road features if available
    if hasattr(state, 'roadgraph_points') and state.roadgraph_points is not None:
        # Print roadgraph points attributes for debugging
        print("\nRoadgraph attributes:", dir(state.roadgraph_points))

        try:
            road_features = []

            # Add xyz coordinates
            if hasattr(state.roadgraph_points, 'xyz'):
                road_features.append(state.roadgraph_points.xyz)

            # Add direction if available (might be called heading or orientation)
            if hasattr(state.roadgraph_points, 'heading'):
                road_features.append(jnp.expand_dims(state.roadgraph_points.heading, -1))
            elif hasattr(state.roadgraph_points, 'orientation'):
                road_features.append(jnp.expand_dims(state.roadgraph_points.orientation, -1))

            # Add type if available
            if hasattr(state.roadgraph_points, 'type'):
                road_features.append(jnp.expand_dims(state.roadgraph_points.type, -1))

            # Add valid flag
            if hasattr(state.roadgraph_points, 'valid'):
                road_features.append(jnp.expand_dims(state.roadgraph_points.valid, -1))

            if road_features:
                road_features = jnp.concatenate(road_features, axis=-1)
            else:
                road_features = jnp.zeros(current_features.shape[:-1] + (11,))

        except Exception as e:
            print(f"\nError processing roadgraph features: {str(e)}")
            road_features = jnp.zeros(current_features.shape[:-1] + (11,))
    else:
        # Create dummy road features matching batch dimensions
        road_features = jnp.zeros(current_features.shape[:-1] + (11,))

    print("\nOutput shapes:")
    print(f"current_features: {current_features.shape}")
    print(f"future_features: {future_features.shape}")
    print(f"road_features: {road_features.shape}")

    return {
        'current_features': current_features,
        'future_features': future_features,
        'road_features': road_features,
        'object_metadata': {
            'is_sdc': state.object_metadata.is_sdc,
            'object_types': state.object_metadata.object_types,
            'is_valid': state.object_metadata.is_valid,
            # Read later by prepare_model_input.
            'is_controlled': state.object_metadata.is_controlled,
        }
    }

def create_training_dataset(config, num_scenarios=1000):
    """Create training dataset from waymax data"""
    data_iter = dataloader.simulator_state_generator(config=config)
    dataset = []

    for i in tqdm(range(num_scenarios)):
        try:
            scenario = next(data_iter)
            if jnp.any(scenario.object_metadata.is_valid):
                try:
                    processed_data = process_waymax_data(scenario)
                    dataset.append(processed_data)
                    if len(dataset) == 1:  # Print shapes for first successful scenario
                        print("\nFirst successful scenario shapes:")
                        print(f"current_features: {processed_data['current_features'].shape}")
                        print(f"future_features: {processed_data['future_features'].shape}")
                        print(f"road_features: {processed_data['road_features'].shape}")
                except Exception as e:
                    print(f"Failed to process scenario: {str(e)}")
                    continue
        except StopIteration:
            break

    print(f"\nSuccessfully processed {len(dataset)} scenarios")
    return dataset

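def prepare_model_input(batch_data, past_len=11, future_len=80,
                        max_num_roads=50, max_points_per_lane=40):
    """Stack a list of processed scenarios into batched, time-major model inputs.

    The defaults mirror WayformerTrainingConfig. All scenarios are assumed to
    share the same array shapes, which holds when they come from one DatasetConfig.
    """
    # Stack the per-scenario arrays along a new leading batch axis.
    current = jnp.stack([d['current_features'] for d in batch_data])  # (B, N, T, F)
    future = jnp.stack([d['future_features'] for d in batch_data])    # (B, N, T_f, 6)
    roads = jnp.stack([d['road_features'] for d in batch_data])
    meta = [d['object_metadata'] for d in batch_data]
    is_sdc = jnp.stack([m['is_sdc'] for m in meta])                   # (B, N)
    # Fall back to is_sdc when a scenario carries no is_controlled entry.
    is_controlled = jnp.stack([m.get('is_controlled', m['is_sdc']) for m in meta])

    # Find ego vehicle (SDC) in each scenario
    ego_indices = jnp.argmax(is_sdc, axis=-1)
    take_ego = jax.vmap(lambda x, i: x[i])

    # Extract other agent features (excluding ego). Boolean-mask indexing has a
    # data-dependent shape and fails under vmap, so gather N-1 indices instead.
    def get_other_agents(features, ego_idx):
        idx = jnp.arange(features.shape[0] - 1)
        return features[idx + (idx >= ego_idx)]  # skip the ego slot

    other_agents = jax.vmap(get_other_agents)(current, ego_indices)   # (B, N-1, T, F)

    # Keep the history window and move time ahead of the agent axis.
    ego_in = take_ego(current, ego_indices)[:, :past_len]             # (B, past_len, F)
    agents_in = jnp.transpose(other_agents[:, :, :past_len], (0, 2, 1, 3))

    # The last future_len steps of the ego track are the prediction target.
    future_trajectories = take_ego(future, ego_indices)[:, -future_len:]

    # Crude stand-in for lane polylines: chunk the first points into fixed groups,
    # zero-padding if fewer points are available.
    num_points = max_num_roads * max_points_per_lane
    roads = roads.reshape(roads.shape[0], -1, roads.shape[-1])[:, :num_points]
    roads = jnp.pad(roads, ((0, 0), (0, num_points - roads.shape[1]), (0, 0)))
    roads = roads.reshape(roads.shape[0], max_num_roads, max_points_per_lane, -1)

    def to_tf(x):
        # The Keras model expects TensorFlow tensors rather than JAX arrays.
        return tf.convert_to_tensor(np.asarray(x), dtype=tf.float32)

    return {
        'ego_in': to_tf(ego_in),
        'agents_in': to_tf(agents_in),
        'roads': to_tf(roads),
        'future_trajectory': to_tf(future_trajectories),
        'is_controlled': to_tf(take_ego(is_controlled, ego_indices)),  # (B,)
    }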

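def compute_trajectory_loss(pred_trajectories, pred_probabilities, target_trajectories, is_controlled):
    """Compute training loss considering trajectory validity and control mask.

    Expected shapes: pred_trajectories (B, modes, T, D), pred_probabilities
    (B, modes) as unnormalized scores, target_trajectories (B, T, C) with
    validity in the last channel, is_controlled (B,) for the predicted agent.
    """
    # Combine validity and control masks as floats; `&` is undefined for float tensors.
    valid_mask = tf.cast(target_trajectories[..., -1] > 0.5, tf.float32)  # (B, T)
    control_mask = tf.cast(is_controlled, tf.float32)                      # (B,)
    mask = valid_mask * control_mask[:, None]                              # (B, T)

    # Displacement per mode and timestep; the target is broadcast across modes.
    # The small epsilon keeps the sqrt gradient finite at zero displacement.
    diff = pred_trajectories[..., :2] - target_trajectories[:, None, :, :2]
    displacement = tf.sqrt(tf.reduce_sum(tf.square(diff), axis=-1) + 1e-6)  # (B, modes, T)

    # Compute displacement error only for valid and controlled positions
    displacement_error = tf.reduce_mean(displacement * mask[:, None, :])

    # Compute probability loss against a uniform target over modes. The model
    # emits raw scores, so they are treated as logits.
    num_modes = tf.cast(tf.shape(pred_probabilities)[-1], tf.float32)
    uniform_target = tf.ones_like(pred_probabilities) / num_modes
    probability_loss = tf.reduce_mean(
        tf.keras.losses.categorical_crossentropy(
            uniform_target, pred_probabilities, from_logits=True
        ) * control_mask
    )

    return displacement_error + 0.1 * probability_loss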

def train_wayformer(model_config, training_data):
    """Train Wayformer model on processed data"""
    model = Wayformer(model_config)
    optimizer = tf.keras.optimizers.Adam(learning_rate=model_config.learning_rate)
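    print(f"Number of training scenarios: {len(training_data)}")
    num_batches = len(training_data) // model_config.batch_size
    if num_batches == 0:
        raise ValueError("Need at least batch_size scenarios to form one batch.")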

    for epoch in range(model_config.num_epochs):
        total_loss = 0

        for batch_idx in tqdm(range(num_batches)):
            batch_data = training_data[batch_idx * model_config.batch_size:
                                     (batch_idx + 1) * model_config.batch_size]

            # Prepare input in the format expected by the model
            model_inputs = prepare_model_input(batch_data)

            with tf.GradientTape() as tape:
                predictions = model(model_inputs, training=True)
                loss = compute_trajectory_loss(
                    predictions['predicted_trajectory'],
                    predictions['predicted_probability'],
                    model_inputs['future_trajectory'],
                    model_inputs['is_controlled']
                )

            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

            total_loss += float(loss)  # accumulate as a Python float

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch + 1}/{model_config.num_epochs}, Average Loss: {avg_loss:.4f}")

    return model

In [None]:
# Authenticate and setup
auth.authenticate_user()

# Configure dataset
config = _config.DatasetConfig(
    path='/content/training_tfexample.tfrecord',
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=32
)

# Create model config
model_config = WayformerTrainingConfig()

# Create training dataset
print("Creating training dataset...")
training_data = create_training_dataset(config, num_scenarios=1000)

# Train model
print("Training Wayformer model...")
model = train_wayformer(model_config, training_data)

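# Run the model on one scenario. This reuses the first training scenario, so it
# is a smoke test rather than a held-out evaluation.
import matplotlib.pyplot as plt

test_scenario = next(dataloader.simulator_state_generator(config=config))
test_inputs = prepare_model_input([process_waymax_data(test_scenario)])

predictions = model({
    'ego_in': test_inputs['ego_in'],
    'agents_in': test_inputs['agents_in'],
    'roads': test_inputs['roads'],
}, training=False)

# Take the most likely trajectory for the single scenario in the batch
best_mode = int(tf.argmax(predictions['predicted_probability'][0]))
pred_xy = predictions['predicted_trajectory'][0, best_mode, :, :2].numpy()

# Compare the predicted ego path with the logged future in a top-down plot.
target = test_inputs['future_trajectory'][0].numpy()
valid = target[:, -1] > 0.5
plt.plot(target[valid, 0], target[valid, 1], label='logged future')
plt.plot(pred_xy[:, 0], pred_xy[:, 1], label='predicted (most likely mode)')
plt.axis('equal')
plt.legend()
plt.show()

# Replay the logged scenario as a video.
imgs = []
state = test_scenario
for _ in range(state.remaining_timesteps):
    state = datatypes.update_state_by_log(state, num_steps=1)
    imgs.append(visualization.plot_simulator_state(state, use_log_traj=True))

mediapy.show_video(imgs, fps=10)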