<a href="https://colab.research.google.com/github/9Tempest/motionLM-Serve/blob/main/waymax_demo_kvcache.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Scenario Data Loading

This tutorial demonstrates how to load scenario data from the Waymo Open Motion Dataset (WOMD) using the Waymax dataloader.

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/9Tempest/motionLM-Serve/blob/main/waymax_demo.ipynb"><img src="https://quantumai.google/site-assets/images/buttons/colab_logo_1x.png" />Run in Google Colab</a>
  </td>
</table>

In [None]:
%%capture
!pip install mediapy
!pip install git+https://github.com/waymo-research/waymax.git@main#egg=waymo-waymax
import numpy as np
import mediapy
from tqdm import tqdm
import dataclasses
import jax
from jax import numpy as jnp
import numpy as np
import mediapy


from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import dynamics
from waymax import env as _env
from waymax import agents
from waymax import visualization
from google.colab import auth




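We first create a dataset config. The `waymax.config` module provides default configs; in particular, `config.WOD_1_1_0_TRAINING` is a pre-defined configuration that points to version 1.1.0 of the Waymo Open Dataset. In this notebook we instead build a `DatasetConfig` by hand that reads a single training shard of version 1.2.0 copied to local disk.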

The data config contains a number of options that control how and where the dataset is loaded from. By default, `WOD_1_1_0_TRAINING` loads up to 128 objects (e.g. vehicles, pedestrians) per scenario. Here, we save memory and compute by loading only the first 32 objects stored in each scenario.

We use the `dataloader.simulator_state_generator` function to create an iterator
over Open Motion Dataset scenarios. Calling `next` on the iterator retrieves the first scenario in the dataset; if the iterator is empty, `next` raises `StopIteration`, which the cell below catches and reports.


In [None]:
# Authenticate the Colab user first; presumably required for the bucket read below — confirm.
auth.authenticate_user()
# Copy one training shard (index 00000 of 01000) of the v1.2.0 tf.Example data to local disk.
!gsutil cp gs://waymo_open_dataset_motion_v_1_2_0/uncompressed/tf_example/training/training_tfexample.tfrecord-00000-of-01000 /content/training_tfexample.tfrecord


In [None]:

# Point the dataloader at the locally downloaded shard and keep at most
# 32 objects per scenario.
config = _config.DatasetConfig(
    path='/content/training_tfexample.tfrecord',
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=32,
)
data_iter = dataloader.simulator_state_generator(config=config)

# Fetch the first scenario. An exhausted generator raises StopIteration,
# which here means the dataset is empty (check the path / config).
try:
    scenario = next(data_iter)
except StopIteration:
    print("The data iterator is empty.")
else:
    print(scenario)

In [None]:
# Wayformer Model Setup
import tensorflow as tf
import numpy as np
from typing import Optional, Tuple, List
from dataclasses import dataclass

In [None]:
@dataclass
class WayformerTrainingConfig:
    """Configuration for Wayformer model training.

    Plain hyperparameter container. Of the fields below, the model
    constructor visible in this file (WayformerCache) reads
    num_map_feature, num_agent_feature, hidden_size, max_num_agents,
    num_modes, future_len, past_len, dropout, tx_num_heads and
    num_queries_enc. The remaining fields are not read by any code visible
    in this part of the file.
    """
    # NOTE(review): process_waymax_data in this file builds 7 road channels
    # and 10 agent channels, which does not match the two values below —
    # confirm whether these are still meaningful.
    num_map_feature: int = 11  # Road feature dimensions
    num_agent_feature: int = 9  # Agent feature dimensions
    hidden_size: int = 256  # Width of embeddings / transformer layers
    max_num_agents: int = 32
    num_modes: int = 6  # Number of predicted trajectory modes
    future_len: int = 80  # 8 seconds with 10Hz
    past_len: int = 11   # 11 steps at 10Hz (presumably 1 second of history plus the current step)
    dropout: float = 0.1
    tx_num_heads: int = 8  # Attention heads in the decoder layers
    max_points_per_lane: int = 40
    max_num_roads: int = 50
    num_queries_enc: int = 128  # Number of latent queries in the perceiver encoder
    num_queries_dec: int = 64
    learning_rate: float = 1e-4
    batch_size: int = 32
    num_epochs: int = 10

@dataclass
class ModuleOutput:
    """Result container for an attention-style module.

    PerceiverEncoder in this file reads `.last_hidden_state` from the value
    returned by its attention sub-layers.
    """
    # Output activations of the module.
    last_hidden_state: tf.Tensor
    # Optional (key, value) tensor pair; defaults to None.
    kv_cache: Optional[Tuple[tf.Tensor, tf.Tensor]] = None

class TrainableQueryProvider(tf.keras.layers.Layer):
    """Layer owning a learnable query matrix.

    The matrix has shape (num_queries, num_query_channels), is initialised
    from a zero-mean normal with standard deviation `init_scale`, and is
    returned by `call` with an extra leading axis of size 1.
    """

    def __init__(self, num_queries: int, num_query_channels: int, init_scale: float = 0.02):
        super().__init__()
        self.num_queries, self.num_query_channels = num_queries, num_query_channels
        self.init_scale = init_scale

    def build(self, input_shape):
        # The weight is created here (not in __init__), i.e. when Keras
        # builds the layer.
        normal_init = tf.keras.initializers.RandomNormal(stddev=self.init_scale)
        self.query = self.add_weight(
            name='query',
            shape=(self.num_queries, self.num_query_channels),
            initializer=normal_init,
            trainable=True,
        )

    def call(self, x=None):
        # `x` is ignored; the output depends only on the learned weight.
        # Resulting shape: (1, num_queries, num_query_channels).
        return tf.expand_dims(self.query, axis=0)

class PerceiverEncoder(tf.keras.layers.Layer):
    """Compress an input sequence into a fixed set of learned latents.

    One pre-norm cross-attention step (latents attend to the input) is
    followed by one pre-norm self-attention step over the latents, each
    wrapped in a residual connection with dropout.

    NOTE(review): `MultiHeadAttention` is not defined in the part of the
    file visible here; it is assumed to accept the keyword arguments used
    below and to return an object exposing `.last_hidden_state`.
    `num_cross_attention_layers` is accepted but not used.
    """

    def __init__(
        self,
        num_latents: int,
        num_latent_channels: int,
        num_cross_attention_heads: int = 4,
        num_cross_attention_qk_channels: Optional[int] = None,
        num_cross_attention_v_channels: Optional[int] = None,
        num_cross_attention_layers: int = 1,
        dropout: float = 0.1,
        init_scale: float = 0.02,
    ):
        super().__init__()

        # Learned latent array that serves as the attention queries.
        self.latent_provider = TrainableQueryProvider(
            num_latents, num_latent_channels, init_scale=init_scale
        )

        # Latents (queries) attend to the encoder input (keys/values).
        self.cross_attention = MultiHeadAttention(
            num_heads=num_cross_attention_heads,
            num_q_input_channels=num_latent_channels,
            num_kv_input_channels=num_latent_channels,
            num_qk_channels=num_cross_attention_qk_channels,
            num_v_channels=num_cross_attention_v_channels,
            dropout=dropout,
        )

        # Latents attend to themselves; reuses the cross-attention head count.
        self.self_attention = MultiHeadAttention(
            num_heads=num_cross_attention_heads,
            num_q_input_channels=num_latent_channels,
            num_kv_input_channels=num_latent_channels,
            dropout=dropout,
        )

        self.layer_norm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layer_norm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, x, pad_mask=None, training=False):
        """Return the latents after cross- and self-attention.

        Args:
            x: Input sequence the latents attend to.
            pad_mask: Optional mask forwarded to the cross-attention layer.
            training: Enables dropout when True.
        """
        latents = self.latent_provider()

        # Cross-attention block: normalise, attend to the input, add back.
        cross_out = self.cross_attention(
            self.layer_norm1(latents), x, pad_mask=pad_mask, training=training
        )
        latents = latents + self.dropout(cross_out.last_hidden_state, training=training)

        # Self-attention block over the latents, same pre-norm/residual shape.
        normed = self.layer_norm2(latents)
        self_out = self.self_attention(normed, normed, training=training)
        return latents + self.dropout(self_out.last_hidden_state, training=training)

In [None]:
def process_waymax_data(state, action_space, past_steps=11, future_steps=80,
                        max_road_points=1024):
    """Convert a waymax SimulatorState into fixed-shape model inputs.

    Args:
        state: Scenario state. Reads `sim_trajectory`, `timestep`,
            `object_metadata` and (if present) `roadgraph_points`.
            Trajectory arrays are assumed to be (num_objects, num_timesteps[, 2]).
        action_space: Action space forwarded to `encode_trajectory`.
        past_steps: Number of history steps, ending at the current timestep.
            Should match the model config's `past_len`.
        future_steps: Number of future steps after the current timestep.
            Should match the model config's `future_len`.
        max_road_points: Road points are truncated / zero-padded to this count.

    Returns:
        dict with
          'current_features': (num_objects, past_steps, 10) — xy, vel_xy, yaw,
              length, width, height, is_sdc, valid.
          'future_features': (num_objects, future_steps, 6) — xy, vel_xy, yaw, valid.
          'road_features': (max_road_points, 7) — xyz, direction, valid.
          'object_metadata': is_sdc / object_types / is_valid from the state.
          'motion_tokens': (future_steps,) tokens of the ego (SDC) vehicle.
    """
    trajectory = state.sim_trajectory
    timestep = state.timestep
    num_timesteps = trajectory.x.shape[1]

    # ---- Past window: the last `past_steps` steps up to and including now.
    past_idx = slice(max(0, timestep - past_steps + 1), timestep + 1)

    xy = trajectory.xy[..., past_idx, :]
    vel_xy = trajectory.vel_xy[..., past_idx, :]
    yaw = trajectory.yaw[..., past_idx]
    length = trajectory.length[..., past_idx]
    width = trajectory.width[..., past_idx]
    height = trajectory.height[..., past_idx]
    valid = trajectory.valid[..., past_idx]

    # Left-pad when the scenario has less history than requested. Values
    # repeat the earliest sample, but the padded steps are marked invalid.
    if xy.shape[-2] < past_steps:
        pad_length = past_steps - xy.shape[-2]
        pad_3d = ((0, 0), (pad_length, 0), (0, 0))
        pad_2d = ((0, 0), (pad_length, 0))
        xy = jnp.pad(xy, pad_3d, mode='edge')
        vel_xy = jnp.pad(vel_xy, pad_3d, mode='edge')
        yaw = jnp.pad(yaw, pad_2d, mode='edge')
        length = jnp.pad(length, pad_2d, mode='edge')
        width = jnp.pad(width, pad_2d, mode='edge')
        height = jnp.pad(height, pad_2d, mode='edge')
        valid = jnp.pad(valid, pad_2d, mode='constant', constant_values=0)

    # Give scalar-per-step features a trailing channel axis.
    yaw = jnp.expand_dims(yaw, axis=-1)
    length = jnp.expand_dims(length, axis=-1)
    width = jnp.expand_dims(width, axis=-1)
    height = jnp.expand_dims(height, axis=-1)
    valid = jnp.expand_dims(valid, axis=-1)

    # is_sdc is per object; repeat it along time -> (num_objects, past_steps, 1).
    is_sdc = jnp.expand_dims(state.object_metadata.is_sdc, axis=(1, -1))
    is_sdc = jnp.repeat(is_sdc, past_steps, axis=1)

    current_features = jnp.concatenate([
        xy,                     # (..., past_steps, 2)
        vel_xy,                 # (..., past_steps, 2)
        yaw,                    # (..., past_steps, 1)
        length,                 # (..., past_steps, 1)
        width,                  # (..., past_steps, 1)
        height,                 # (..., past_steps, 1)
        is_sdc,                 # (..., past_steps, 1)
        valid                   # (..., past_steps, 1)
    ], axis=-1)

    # ---- Future window: the `future_steps` steps after the current one.
    future_idx = slice(timestep + 1, min(timestep + 1 + future_steps, num_timesteps))
    future_xy = trajectory.xy[..., future_idx, :]
    future_vel_xy = trajectory.vel_xy[..., future_idx, :]
    future_yaw = jnp.expand_dims(trajectory.yaw[..., future_idx], axis=-1)
    future_valid = jnp.expand_dims(trajectory.valid[..., future_idx], axis=-1)

    # Right-pad when the scenario ends early; padded steps are invalid.
    if future_xy.shape[-2] < future_steps:
        pad_length = future_steps - future_xy.shape[-2]
        pad_3d = ((0, 0), (0, pad_length), (0, 0))
        future_xy = jnp.pad(future_xy, pad_3d, mode='edge')
        future_vel_xy = jnp.pad(future_vel_xy, pad_3d, mode='edge')
        future_yaw = jnp.pad(future_yaw, pad_3d, mode='edge')
        future_valid = jnp.pad(future_valid, pad_3d, mode='constant', constant_values=0)

    future_features = jnp.concatenate([
        future_xy,
        future_vel_xy,
        future_yaw,
        future_valid
    ], axis=-1)

    # ---- Road graph: xyz (3) + direction (3) + valid (1) per point.
    roadgraph = getattr(state, 'roadgraph_points', None)
    if roadgraph is not None:
        road_dir = jnp.stack([roadgraph.dir_x, roadgraph.dir_y, roadgraph.dir_z], axis=-1)
        road_valid = jnp.expand_dims(roadgraph.valid, axis=-1)
        road_features = jnp.concatenate([roadgraph.xyz, road_dir, road_valid], axis=-1)

        # Truncate or zero-pad to a fixed number of points.
        if road_features.shape[0] > max_road_points:
            road_features = road_features[:max_road_points]
        else:
            pad_length = max_road_points - road_features.shape[0]
            road_features = jnp.pad(road_features, ((0, pad_length), (0, 0)), mode='constant')
    else:
        road_features = jnp.zeros((max_road_points, 7))

    # ---- Motion tokens for the ego vehicle.
    # Locate the SDC via is_sdc rather than assuming index 0, so the tokens
    # describe the same agent that prepare_model_input selects as ego.
    # argmax yields 0 when no object is flagged, matching the old default.
    ego_index = int(jnp.argmax(state.object_metadata.is_sdc))

    positions = np.concatenate([xy[ego_index], future_xy[ego_index]], axis=0)
    velocities = np.concatenate([vel_xy[ego_index], future_vel_xy[ego_index]], axis=0)

    # One token per consecutive pair of steps: past_steps + future_steps - 1 total.
    token_indices = encode_trajectory(positions, velocities, action_space)

    return {
        'current_features': current_features,  # (num_objects, past_steps, 10)
        'future_features': future_features,    # (num_objects, future_steps, 6)
        'road_features': road_features,        # (max_road_points, 7)
        'object_metadata': {
            'is_sdc': state.object_metadata.is_sdc,
            'object_types': state.object_metadata.object_types,
            'is_valid': state.object_metadata.is_valid,
        },
        # Drop the past-only transitions; the first kept token is the step
        # from the current timestep to the first future timestep.
        'motion_tokens': token_indices[past_steps - 1:]
    }

def prepare_model_input(batch_data):
    """Batch per-scenario feature dicts and split ego from other agents.

    Args:
        batch_data: Non-empty sequence of dicts as produced by
            process_waymax_data. Reads 'current_features',
            'future_features', 'road_features', 'motion_tokens' and
            'object_metadata'['is_sdc'].

    Returns:
        dict with
          'ego_in': (batch_size, past_steps, features)
          'agents_in': (batch_size, 31, past_steps, features), zero-padded
          'roads': (batch_size, max_road_points, features)
          'future_trajectory': (batch_size, future_steps, features) of the ego
          'motion_tokens': (batch_size, future_steps)
    """
    max_other_agents = 31

    current = jnp.stack([d['current_features'] for d in batch_data])
    future = jnp.stack([d['future_features'] for d in batch_data])
    roads = jnp.stack([d['road_features'] for d in batch_data])
    is_sdc = jnp.stack([d['object_metadata']['is_sdc'] for d in batch_data])
    motion_tokens = jnp.stack([d['motion_tokens'] for d in batch_data])

    # The ego vehicle (SDC) is the first object flagged is_sdc in each scenario.
    ego_indices = jnp.argmax(is_sdc, axis=-1)

    ego_features = []
    ego_future_features = []
    other_agents = []

    num_objects = current.shape[1]
    for i in range(len(batch_data)):
        # Convert once to a Python int so the comparisons below are cheap.
        ego_idx = int(ego_indices[i])

        ego_feat = current[i, ego_idx]
        ego_features.append(ego_feat)
        ego_future_features.append(future[i, ego_idx])

        # Every object except the ego, capped at max_other_agents.
        others = [current[i, j] for j in range(num_objects) if j != ego_idx]
        others = others[:max_other_agents]

        # Zero-pad up to max_other_agents. Padding the list (not a stacked
        # array) also covers scenes with no other agent at all, where
        # stacking an empty list would raise.
        others.extend([jnp.zeros_like(ego_feat)] * (max_other_agents - len(others)))

        other_agents.append(jnp.stack(others))

    return {
        'ego_in': jnp.stack(ego_features),          # (batch_size, past_steps, features)
        'agents_in': jnp.stack(other_agents),       # (batch_size, 31, past_steps, features)
        'roads': roads,                             # (batch_size, max_road_points, features)
        'future_trajectory': jnp.stack(ego_future_features),  # (batch_size, future_steps, features)
        'motion_tokens': motion_tokens              # (batch_size, future_steps)
    }

def prepare_model_input_with_history(batch_data, history_motion_tokens, T=80):
    """Batch scenario features and attach a fixed-length motion-token history.

    Unlike prepare_model_input, the 'future_features' and 'motion_tokens'
    entries of `batch_data` are ignored; the token input comes from
    `history_motion_tokens` instead.

    Args:
        batch_data: Non-empty sequence of dicts as produced by
            process_waymax_data. Reads 'current_features', 'road_features'
            and 'object_metadata'['is_sdc'].
        history_motion_tokens: Sequence of token ids (oldest first) for a
            single rollout.
        T: Length of the token window. Shorter histories are left-padded
            with 0; longer ones keep only the last T tokens.

    Returns:
        dict with
          'ego_in': (batch_size, past_steps, features)
          'agents_in': (batch_size, 31, past_steps, features), zero-padded
          'roads': (batch_size, max_road_points, features)
          'motion_tokens': (1, T) — a single history row, NOT repeated per
              batch element.
    """
    max_other_agents = 31

    current = jnp.stack([d['current_features'] for d in batch_data])
    roads = jnp.stack([d['road_features'] for d in batch_data])
    is_sdc = jnp.stack([d['object_metadata']['is_sdc'] for d in batch_data])

    # The ego vehicle (SDC) is the first object flagged is_sdc in each scenario.
    ego_indices = jnp.argmax(is_sdc, axis=-1)

    ego_features = []
    other_agents = []

    num_objects = current.shape[1]
    for i in range(len(batch_data)):
        # Convert once to a Python int so the comparisons below are cheap.
        ego_idx = int(ego_indices[i])

        ego_feat = current[i, ego_idx]
        ego_features.append(ego_feat)

        # Every object except the ego, capped at max_other_agents.
        others = [current[i, j] for j in range(num_objects) if j != ego_idx]
        others = others[:max_other_agents]

        # Zero-pad with one (past_steps, features) array per missing agent,
        # so every list element has the same rank and jnp.stack succeeds.
        others.extend([jnp.zeros_like(ego_feat)] * (max_other_agents - len(others)))

        other_agents.append(jnp.stack(others))

    # Normalise to a plain list so list concatenation below cannot turn
    # into element-wise addition for array inputs.
    history = list(history_motion_tokens)
    if len(history) < T:
        # Left-pad; 0 is assumed to be an acceptable padding token.
        history = [0] * (T - len(history)) + history
    else:
        history = history[-T:]  # Keep the last T tokens

    # Add a leading batch axis -> (1, T).
    motion_tokens_array = np.array(history)[np.newaxis, :]

    return {
        'ego_in': jnp.stack(ego_features),          # (batch_size, past_steps, features)
        'agents_in': jnp.stack(other_agents),       # (batch_size, 31, past_steps, features)
        'roads': roads,                             # (batch_size, max_road_points, features)
        'motion_tokens': motion_tokens_array        # (1, T)
    }

def create_training_dataset(config, action_space, num_scenarios=1000):
    """Create a training dataset from waymax data.

    Pulls up to `num_scenarios` scenarios from the dataloader, skips those
    with no valid object, and converts the rest with process_waymax_data.
    Scenarios that fail to convert are reported and skipped (best effort).

    Args:
        config: Dataset config passed to dataloader.simulator_state_generator.
        action_space: Action space forwarded to process_waymax_data.
        num_scenarios: Maximum number of scenarios to read.

    Returns:
        List of processed scenario dicts; may be shorter than `num_scenarios`.
    """
    data_iter = dataloader.simulator_state_generator(config=config)
    dataset = []

    for _ in tqdm(range(num_scenarios)):
        try:
            scenario = next(data_iter)
        except StopIteration:
            # Dataset exhausted before num_scenarios were read.
            break

        if not jnp.any(scenario.object_metadata.is_valid):
            continue

        try:
            processed_data = process_waymax_data(scenario, action_space)
        except Exception as e:
            # Deliberately broad: one bad scenario must not abort the run.
            print(f"Failed to process scenario: {str(e)}")
            continue

        dataset.append(processed_data)
        # Report shapes for the first scenario that actually succeeded,
        # even when earlier scenarios were skipped or failed.
        if len(dataset) == 1:
            print("\nFirst successful scenario shapes:")
            print(f"current_features: {processed_data['current_features'].shape}")
            print(f"future_features: {processed_data['future_features'].shape}")
            print(f"road_features: {processed_data['road_features'].shape}")

    print(f"\nSuccessfully processed {len(dataset)} scenarios")
    return dataset

In [None]:
class VerletActionSpace:
    """Discrete 2-D acceleration vocabulary with a matching integrator.

    Each axis is split into `num_bins` equal-width bins over
    `delta_interval`; a token is the pair (x-bin, y-bin) flattened as
    ``x_bin * num_bins + y_bin``, giving ``num_bins ** 2`` tokens.
    """

    def __init__(self, delta_interval=(-18.0, 18.0), num_bins=13, step_frequency=2.0):
        self.delta_interval = delta_interval
        self.num_bins = num_bins
        self.step_frequency = step_frequency
        # Duration of one step, in the time unit implied by step_frequency.
        self.time_step = 1.0 / step_frequency

        low, high = delta_interval[0], delta_interval[1]
        self.accel_bin_edges = np.linspace(low, high, num_bins + 1)
        lower_edges = self.accel_bin_edges[:-1]
        upper_edges = self.accel_bin_edges[1:]
        self.accel_bin_centers = (lower_edges + upper_edges) / 2

        # Lookup table token -> (ax, ay): Cartesian product of the bin
        # centres with the x index varying slowest.
        self.vocab_size = num_bins * num_bins
        grid_x, grid_y = np.meshgrid(
            self.accel_bin_centers, self.accel_bin_centers, indexing='ij'
        )
        self.token_to_accel = np.stack([grid_x, grid_y], axis=-1).reshape(-1, 2)

    def discretize_acceleration(self, accel):
        """Map accelerations of shape (..., 2) to token ids of shape (...)."""
        last_bin = self.num_bins - 1

        def to_bin(component):
            # digitize is 1-based relative to the edges; values outside the
            # interval are clamped into the first / last bin.
            return np.clip(np.digitize(component, self.accel_bin_edges) - 1, 0, last_bin)

        return to_bin(accel[..., 0]) * self.num_bins + to_bin(accel[..., 1])

    def continuous_acceleration(self, token_indices):
        """Map token ids of shape (...) to bin-centre accelerations (..., 2)."""
        return self.token_to_accel[token_indices]

    def verlet_update(self, position, velocity, acceleration):
        """Advance position and velocity (shape (..., 2)) by one time step
        under constant acceleration; returns (new_position, new_velocity)."""
        dt = self.time_step
        moved = position + velocity * dt + 0.5 * acceleration * dt ** 2
        sped = velocity + acceleration * dt
        return moved, sped

def encode_trajectory(positions, velocities, action_space):
    """
    Turn a continuous trajectory into a sequence of motion tokens.

    Accelerations are the finite differences of consecutive velocities
    divided by the action space's time step; `positions` is accepted for
    interface symmetry but is not used in the computation.

    positions: array of shape (T+1, 2)
    velocities: array of shape (T+1, 2)
    Returns:
        token_indices: array of shape (T,)
    """
    dt = action_space.time_step
    later, earlier = velocities[1:], velocities[:-1]
    accelerations = (later - earlier) / dt
    return action_space.discretize_acceleration(accelerations)

def decode_tokens(start_position, start_velocity, token_indices, action_space):
    """
    Roll motion tokens forward into a continuous trajectory.

    Each token is converted to an acceleration and integrated one step
    with the action space's `verlet_update`, starting from the given state.

    start_position: array of shape (2,)
    start_velocity: array of shape (2,)
    token_indices: array of shape (T,)
    Returns:
        positions: array of shape (T+1, 2)
        velocities: array of shape (T+1, 2)
    """
    pos, vel = start_position, start_velocity
    pos_trace, vel_trace = [pos], [vel]
    for token in token_indices:
        pos, vel = action_space.verlet_update(
            pos, vel, action_space.continuous_acceleration(token)
        )
        pos_trace.append(pos)
        vel_trace.append(vel)
    # Index 0 of each output is the start state itself.
    return np.stack(pos_trace), np.stack(vel_trace)




In [None]:
# Authenticate and setup
auth.authenticate_user()

# Configure dataset: read the locally copied TFRecord shard and keep at most
# 32 objects per scenario.
config = _config.DatasetConfig(
    path='/content/training_tfexample.tfrecord',
    data_format=_config.DataFormat.TFRECORD,
    max_num_objects=32
)

# Create model config
# NOTE(review): in file order this instantiates the WayformerTrainingConfig
# defined above (future_len=80). A later cell redefines the class with
# future_len=512; that does not change this already-created instance.
model_config = WayformerTrainingConfig()

# Create training dataset
print("Creating training dataset...")
# Default action space: 13 bins per axis over (-18.0, 18.0), step_frequency=2.0.
action_space = VerletActionSpace()
# Reads up to 1000 scenarios; fewer are returned if the shard is exhausted
# or scenarios are skipped.
training_data = create_training_dataset(config, action_space, num_scenarios=1000)


In [None]:
@dataclass
class WayformerTrainingConfig:
    """Configuration for Wayformer model training.

    NOTE(review): this is a second definition of WayformerTrainingConfig in
    the same file; it replaces the earlier one for code executed afterwards.
    The only value that differs from the earlier definition is
    future_len (512 here, 80 there).
    """
    num_map_feature: int = 11  # Road feature dimensions
    num_agent_feature: int = 9  # Agent feature dimensions
    hidden_size: int = 256
    max_num_agents: int = 32
    num_modes: int = 6
    # 512 steps. WayformerCache uses this as the length of its token
    # positional-encoding table and of the predicted trajectories, so it is
    # no longer "8 seconds at 10Hz" as in the earlier definition.
    future_len: int = 512
    past_len: int = 11   # 11 steps at 10Hz (presumably 1 second of history plus the current step)
    dropout: float = 0.1
    tx_num_heads: int = 8
    max_points_per_lane: int = 40
    max_num_roads: int = 50
    num_queries_enc: int = 128
    num_queries_dec: int = 64
    learning_rate: float = 1e-4
    batch_size: int = 32
    num_epochs: int = 10

In [None]:
class MultiHeadAttentionCache(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention with an optional KV cache.

    When both `cache` and `cache_key` are given to `call`, the freshly
    projected keys/values are appended to ``cache[cache_key]`` and attention
    runs over everything that cache entry returns; otherwise attention runs
    over the keys/values of the current call only.
    """

    def __init__(
        self,
        num_heads: int,
        d_model: int,
        dropout_rate: float = 0.1
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.num_heads = num_heads
        self.d_model = d_model
        # Channels per head.
        self.depth = d_model // num_heads

        # Input projections for queries, keys and values.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        # Output projection applied after the heads are merged.
        self.dense = tf.keras.layers.Dense(d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq_len, d_model) into (batch, num_heads, seq_len, depth)."""
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask, cache=None, cache_key=None, training=False):
        """Attend from `q` to `k`/`v`.

        Args:
            v, k, q: Value, key and query inputs, (batch, seq_len, channels).
            mask: Optional tensor broadcastable to the attention logits;
                entries equal to 1 are suppressed (scaled by -1e9).
            cache: Optional mapping of cache objects exposing
                `update(k, v)` and `get_key_value()`.
            cache_key: Key selecting the entry of `cache` to use.
            training: Enables dropout on the attention weights when True.

        Returns:
            (output, attention_weights): output is (batch, seq_len_q, d_model).
        """
        batch_size = tf.shape(q)[0]

        # Project, then split into heads: (batch, num_heads, seq_len, depth).
        q_heads = self.split_heads(self.wq(q), batch_size)
        k_heads = self.split_heads(self.wk(k), batch_size)
        v_heads = self.split_heads(self.wv(v), batch_size)

        if cache is not None and cache_key is not None:
            # Append only the new keys/values, then attend over all cached ones.
            entry = cache[cache_key]
            entry.update(k_heads, v_heads)
            keys, values = entry.get_key_value()
        else:
            keys, values = k_heads, v_heads

        # Scaled dot-product logits: (..., seq_len_q, seq_len_k).
        scale = tf.math.sqrt(tf.cast(tf.shape(keys)[-1], tf.float32))
        logits = tf.matmul(q_heads, keys, transpose_b=True) / scale

        if mask is not None:
            logits += (mask * -1e9)

        # Normalise over the key axis.
        attention_weights = tf.nn.softmax(logits, axis=-1)
        attention_weights = self.dropout(attention_weights, training=training)

        # Weighted sum of values, then merge heads back to d_model channels.
        context = tf.matmul(attention_weights, values)          # (..., seq_len_q, depth)
        context = tf.transpose(context, perm=[0, 2, 1, 3])      # (batch, seq_len_q, num_heads, depth)
        merged = tf.reshape(context, (batch_size, -1, self.d_model))

        return self.dense(merged), attention_weights

class TransformerDecoderLayerCache(tf.keras.layers.Layer):
    """Post-norm transformer decoder layer whose self-attention can use a KV cache.

    Sub-blocks, each followed by dropout, a residual add and layer norm:
    masked self-attention, cross-attention over the encoder output, and a
    two-layer feed-forward network. Only the self-attention receives the
    cache; cross-attention always recomputes its keys/values.
    """

    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        super().__init__()

        # Masked self-attention, then cross-attention with encoder outputs.
        self.mha1 = MultiHeadAttentionCache(num_heads, d_model, dropout_rate)
        self.mha2 = MultiHeadAttentionCache(num_heads, d_model, dropout_rate)

        # Position-wise feed-forward: d_model -> dff (ReLU) -> d_model.
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model),
        ])

        # One layer norm and one dropout per sub-block.
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout3 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, enc_output, look_ahead_mask=None, padding_mask=None, training=False, cache=None, layer_idx=0):
        """Run one decoder layer.

        Args:
            x: Decoder input, (batch, seq_len, d_model).
            enc_output: Encoder context used as cross-attention keys/values.
            look_ahead_mask: Optional mask for the self-attention.
            padding_mask: Optional mask for the cross-attention.
            training: Enables dropout when True.
            cache: Optional mapping of KV caches; this layer uses the entry
                named ``decoder_layer_{layer_idx}_self_attn``.
            layer_idx: Index of this layer, used to build the cache key.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        self_attn_key = f'decoder_layer_{layer_idx}_self_attn'

        # 1) Masked self-attention (optionally cached).
        self_attended, _ = self.mha1(
            q=x, k=x, v=x,
            mask=look_ahead_mask,
            cache=cache,
            cache_key=self_attn_key,
            training=training,
        )
        hidden = self.layernorm1(x + self.dropout1(self_attended, training=training))

        # 2) Cross-attention over the encoder output (never cached).
        cross_attended, _ = self.mha2(
            q=hidden, k=enc_output, v=enc_output,
            mask=padding_mask,
            training=training,
        )
        hidden2 = self.layernorm2(hidden + self.dropout2(cross_attended, training=training))

        # 3) Feed-forward network.
        fed_forward = self.dropout3(self.ffn(hidden2), training=training)
        return self.layernorm3(hidden2 + fed_forward)

class WayformerCache(tf.keras.Model):
    """Scene encoder plus token decoder with optional KV caching.

    The forward pass encodes agents and road points, compresses them with a
    PerceiverEncoder into a latent context, and from that context produces
    (a) multi-modal trajectory predictions with mode probabilities and
    (b) next-token logits from a stack of TransformerDecoderLayerCache
    layers that attend to the context.

    Args:
        config: Object providing the hyperparameters read in __init__
            (e.g. WayformerTrainingConfig).
        action_space: Object providing `vocab_size` (e.g. VerletActionSpace).
    """

    def __init__(self, config, action_space):
        super().__init__()
        # Initialize dimensions and parameters
        # NOTE(review): map_attr, k_attr, _M and num_heads are stored but not
        # read by any method visible in this class.
        self.map_attr = config.num_map_feature
        self.k_attr = config.num_agent_feature
        self.d_k = config.hidden_size
        self._M = config.max_num_agents
        self.c = config.num_modes
        self.T = config.future_len
        self.dropout = config.dropout
        self.num_heads = config.tx_num_heads
        self.past_T = config.past_len

        # Input encoders
        # NOTE(review): these two layers are not called by any method visible
        # in this class.
        self.road_pts_lin = tf.keras.layers.Dense(self.d_k)
        self.agents_dynamic_encoder = tf.keras.layers.Dense(self.d_k)

        # Positional embeddings
        # One learned row per past timestep, added along the time axis.
        self.temporal_embedding = self.add_weight(
            shape=(self.past_T, self.d_k),
            initializer='zeros',
            trainable=True,
            name='temporal_emb'
        )

        # One learned row per agent slot, added along the agent axis.
        self.agent_embedding = self.add_weight(
            shape=(32, self.d_k),  # 32 = 1 ego + 31 other agents
            initializer='zeros',
            trainable=True,
            name='agent_emb'
        )

        # Agent processing layers
        self.agent_encoder = tf.keras.Sequential([
            tf.keras.layers.Dense(self.d_k, activation='relu'),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dropout(self.dropout)
        ])

        # Head count is fixed at 8 here rather than taken from the config.
        self.agent_transformer = tf.keras.layers.MultiHeadAttention(
            num_heads=8,
            key_dim=self.d_k // 8,
            dropout=self.dropout
        )

        # Road processing layers
        self.road_encoder = tf.keras.Sequential([
            tf.keras.layers.Dense(self.d_k, activation='relu'),
            tf.keras.layers.LayerNormalization(),
            tf.keras.layers.Dropout(self.dropout)
        ])

        # Output processing
        self.perceiver = PerceiverEncoder(
            num_latents=config.num_queries_enc,
            num_latent_channels=self.d_k
        )

        # One learned query vector per prediction mode.
        self.output_query = self.add_weight(
            shape=(self.c, self.d_k),
            initializer='random_normal',
            trainable=True,
            name='output_query'
        )

        # 5 output values per future step; reshaped to (..., T, 5) in call().
        self.trajectory_projection = tf.keras.layers.Dense(5 * self.T)
        self.mode_projection = tf.keras.layers.Dense(1)

        self.action_space = action_space
        self.vocab_size = action_space.vocab_size

        # Token embedding for motion tokens
        self.token_embedding = tf.keras.layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.d_k)

        # Positional encoding for tokens
        # Learned table with one row per token position, up to T positions.
        self.positional_encoding = self.add_weight(
            shape=(self.T, self.d_k),
            initializer='zeros',
            trainable=True,
            name='token_positional_encoding'
        )

        # Transformer decoder layers
        # Fixed at 4 layers with a 1024-wide feed-forward network.
        self.decoder_layers = [
            TransformerDecoderLayerCache(
                d_model=self.d_k,
                num_heads=config.tx_num_heads,
                dff=1024,
                dropout_rate=config.dropout
            ) for _ in range(4)
        ]

        # Output projection to predict token logits
        self.token_projection = tf.keras.layers.Dense(self.vocab_size)

    def build(self, input_shapes):
        """Build the model based on input shapes"""
        # Only marks the model as built; no weights are created here.
        self.built = True

    def encode_agents(self, ego_in, agents_in, training=False):
        """Process agent features

        Args:
            ego_in: Ego features, [batch, time, features].
            agents_in: Other-agent features, [batch, agents, time, features].
            training: Enables dropout when True.

        Returns:
            Tensor of shape [batch, agents_total * time, d_k].
        """
        batch_size = tf.shape(ego_in)[0]

        # Reshape ego to add agent dimension
        ego_expanded = tf.expand_dims(ego_in, axis=1)  # [batch, 1, time, features]

        # Concatenate ego with other agents
        all_agents = tf.concat([ego_expanded, agents_in], axis=1)  # [batch, 32, time, features]

        # Encode agent features
        encoded = self.agent_encoder(all_agents, training=training)

        # Add temporal embeddings
        # (past_T, d_k) -> (1, 1, past_T, d_k), tiled over batch and agents.
        temporal_emb = tf.expand_dims(tf.expand_dims(self.temporal_embedding, 0), 0)
        temporal_emb = tf.tile(temporal_emb, [batch_size, encoded.shape[1], 1, 1])
        encoded = encoded + temporal_emb

        # Add agent embeddings
        # (32, d_k) -> (1, 32, 1, d_k), tiled over batch and time; this
        # requires exactly 32 agent slots in `encoded`.
        agent_emb = tf.expand_dims(tf.expand_dims(self.agent_embedding, 0), 2)
        agent_emb = tf.tile(agent_emb, [batch_size, 1, encoded.shape[2], 1])
        encoded = encoded + agent_emb

        # Apply self-attention
        attended = self.agent_transformer(
            query=encoded,
            key=encoded,
            value=encoded,
            training=training
        )

        # Final reshape
        # Flatten the agent and time axes into one sequence axis.
        final_encoded = tf.reshape(attended, [batch_size, -1, self.d_k])

        return final_encoded

    def encode_roads(self, roads, training=False):
        """Encode road features and flatten them to [batch, -1, d_k]."""
        # Process road features
        encoded = self.road_encoder(roads, training=training)
        return tf.reshape(encoded, [tf.shape(encoded)[0], -1, self.d_k])

    def call(self, inputs, cache=None, training=False):
        """Forward pass with KV cache support

        Args:
            inputs: dict with 'ego_in', 'agents_in', 'roads' and
                'motion_tokens' (integer token ids, (batch_size, T)).
            cache: Optional mapping forwarded to every decoder layer; each
                layer looks up the key ``decoder_layer_{idx}_self_attn``.
            training: Enables dropout when True.

        Returns:
            dict with 'predicted_trajectory', 'predicted_probability',
            'scene_emb' and 'logits'.
        """
        ego_in = inputs['ego_in']
        agents_in = inputs['agents_in']
        roads = inputs['roads']
        motion_tokens = inputs['motion_tokens']  # Shape: (batch_size, T)

        # Process agents
        agents_encoded = self.encode_agents(ego_in, agents_in, training)
        # Process roads
        roads_encoded = self.encode_roads(roads, training)
        # Combine features
        # Agent and road sequences are concatenated along the sequence axis.
        combined_features = tf.concat([agents_encoded, roads_encoded], axis=1)
        # Apply perceiver encoding
        context = self.perceiver(combined_features, training=training)
        # Generate outputs
        batch_size = tf.shape(ego_in)[0]
        query = tf.tile(tf.expand_dims(self.output_query, 0), [batch_size, 1, 1])

        # Get trajectory predictions
        # Dot product of each mode query with every context vector; the
        # result has one row per mode (assumes context is
        # (batch, num_latents, d_k) — TODO confirm).
        output_features = tf.matmul(query, context, transpose_b=True)
        trajectories = self.trajectory_projection(output_features)
        trajectories = tf.reshape(trajectories, [batch_size, self.c, self.T, -1])

        # Get mode probabilities
        mode_logits = tf.squeeze(self.mode_projection(output_features), axis=-1)
        mode_probs = tf.nn.softmax(mode_logits, axis=-1)

        # Embed motion tokens
        token_embeddings = self.token_embedding(motion_tokens)  # Shape: (batch_size, T, d_k)

        # Add positional encoding
        # NOTE(review): positions always start at row 0 of the table,
        # regardless of how many steps the cache already holds — confirm
        # this is intended for incremental decoding.
        positional_encoding = tf.expand_dims(self.positional_encoding[:tf.shape(motion_tokens)[1], :], axis=0)
        token_embeddings += positional_encoding  # Broadcasting over batch_size

        # Prepare attention masks
        # batch_size is re-derived here from the token batch, not from ego_in.
        batch_size = tf.shape(token_embeddings)[0]
        seq_len = tf.shape(token_embeddings)[1]
        look_ahead_mask = self.create_look_ahead_mask(seq_len)  # Shape: (1, seq_len, seq_len)

        # Pass through decoder layers
        decoder_output = token_embeddings
        for idx, layer in enumerate(self.decoder_layers):
            decoder_output = layer(
                decoder_output,
                context,
                training=training,
                look_ahead_mask=look_ahead_mask,
                cache=cache,
                layer_idx=idx
            )

        # Predict logits for next tokens
        logits = self.token_projection(decoder_output)  # Shape: (batch_size, seq_len, vocab_size)

        return {
            'predicted_trajectory': trajectories,
            'predicted_probability': mode_probs,
            'scene_emb': tf.reshape(output_features, [batch_size, -1]),
            'logits': logits
        }

    def create_look_ahead_mask(self, size):
        """Return a (1, size, size) float mask with 1 strictly above the
        diagonal and 0 elsewhere."""
        # band_part(..., -1, 0) keeps the lower triangle including the
        # diagonal; subtracting from 1 leaves ones only above the diagonal.
        mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
        mask = tf.cast(mask, tf.float32)  # Shape: (size, size)
        return mask[tf.newaxis, :, :]  # Shape: (1, size, size)

In [None]:
""" This cache is a baseline KV cache, storing all past keys and vals"""
from abc import ABC, abstractmethod

class BaseCache(ABC):
    """Abstract interface implemented by the key/value caches below.

    A subclass has to provide all three methods before it can be
    instantiated.
    """

    @abstractmethod
    def update(self, key, value):
        """Store a new key and its matching value in the cache."""

    @abstractmethod
    def get_key_value(self):
        """Return the key and value currently held by the cache."""

    @abstractmethod
    def reset(self):
        """Put the cache back into its initial state."""

class FullCache(BaseCache):
    """KV cache that keeps every key/value written since the last reset.

    The backing variables are created on the first ``update`` call with
    shape ``[batch, num_heads, max_cache_length, depth]``, where ``batch``
    is taken from that first key tensor.
    """

    def __init__(self, max_cache_length, num_heads, depth):
        self.max_cache_length = max_cache_length
        self.num_heads = num_heads
        self.depth = depth
        # Number of positions along axis 2 that have been written so far.
        self.current_length = 0

        # Backing storage; stays None until the first update().
        self.key = None
        self.value = None

    def update(self, key, value):
        """Append ``key`` and ``value`` after the entries already stored.

        Both tensors are written along axis 2 of the buffers, starting at
        ``current_length``.

        Raises:
            ValueError: If the write would go past ``max_cache_length``.
        """
        if self.key is None:
            # Lazy allocation: the batch size is only known from the data.
            buffer_shape = [
                tf.shape(key)[0],
                self.num_heads,
                self.max_cache_length,
                self.depth,
            ]
            self.key = tf.Variable(tf.zeros(buffer_shape, dtype=key.dtype), trainable=False)
            self.value = tf.Variable(tf.zeros(buffer_shape, dtype=value.dtype), trainable=False)
            self.current_length = 0

        write_start = self.current_length
        write_end = write_start + tf.shape(key)[2]
        if write_end > self.max_cache_length:
            raise ValueError("Cache overflow: Exceeded maximum cache length")

        # Copy the new entries into the slots directly after the filled region.
        self.key[:, :, write_start:write_end, :].assign(key)
        self.value[:, :, write_start:write_end, :].assign(value)

        self.current_length = write_end

    def get_key_value(self):
        """Return the filled part of the buffers, or ``(None, None)`` if empty."""
        if self.key is None or self.current_length == 0:
            return None, None
        filled = self.current_length
        return self.key[:, :, :filled, :], self.value[:, :, :filled, :]

    def reset(self):
        """Mark the cache as empty and zero any buffers already allocated."""
        self.current_length = 0
        if self.key is None:
            return
        for buffer in (self.key, self.value):
            buffer.assign(tf.zeros_like(buffer))

class SlidingWindowCache(BaseCache):
    """Ring-buffer KV cache that keeps the most recent ``window_size`` entries.

    Buffers have shape ``[batch, num_heads, window_size, head_dim]`` and are
    allocated on the first ``update`` call. ``start_idx`` points at the oldest
    stored entry, ``end_idx`` at the slot the next entry will be written to,
    and ``current_size`` counts the stored entries (never above
    ``window_size``).
    """

    def __init__(self, window_size, num_heads, head_dim):
        self.window_size = window_size  # Fixed size of the window (W)
        self.num_heads = num_heads
        self.head_dim = head_dim  # Dimension per head

        # Ring-buffer pointers, kept as variables so they can be assigned in place
        self.current_size = tf.Variable(0, trainable=False)
        self.start_idx = tf.Variable(0, trainable=False)
        self.end_idx = tf.Variable(0, trainable=False)

        # Buffers are created on the first update(), once the batch size is known
        self.key_buffer = None
        self.value_buffer = None

    def update(self, key, value):
        """Write the entries of ``key``/``value`` (axis 2) into the ring buffer.

        Entries are written one position at a time starting at ``end_idx``,
        wrapping around at ``window_size``; once the window is full the oldest
        entries are overwritten. Updates carrying more than one position are
        supported: the pointers are advanced by the number of entries that
        actually fell out of the window.
        """
        batch_size = tf.shape(key)[0]
        seq_len_kv = tf.shape(key)[2]  # Typically 1 during autoregressive inference

        # Initialize buffers upon the first update when batch_size is known
        if self.key_buffer is None:
            buffer_shape = [batch_size, self.num_heads, self.window_size, self.head_dim]
            self.key_buffer = tf.Variable(
                tf.zeros(buffer_shape, dtype=key.dtype),
                trainable=False
            )
            self.value_buffer = tf.Variable(
                tf.zeros(buffer_shape, dtype=value.dtype),
                trainable=False
            )

        for i in range(seq_len_kv):
            # Slot for the i-th new entry, wrapping around the ring
            idx = (self.end_idx + i) % self.window_size

            self.key_buffer[:, :, idx, :].assign(key[:, :, i, :])
            self.value_buffer[:, :, idx, :].assign(value[:, :, i, :])

        # Pointer update. `overflow` is how many of the oldest entries were
        # overwritten by this call; the start pointer must skip exactly those,
        # and the size is clamped to the window. For single-entry updates this
        # matches "grow until full, then advance start by one".
        new_size = self.current_size + seq_len_kv
        overflow = tf.maximum(new_size - self.window_size, 0)
        self.end_idx.assign((self.end_idx + seq_len_kv) % self.window_size)
        self.start_idx.assign((self.start_idx + overflow) % self.window_size)
        self.current_size.assign(tf.minimum(new_size, self.window_size))

    def get_key_value(self):
        """Return the stored keys and values ordered from oldest to newest.

        Returns ``(None, None)`` while the cache is empty.
        """
        if self.current_size == 0:
            return None, None  # Cache is empty

        if self.start_idx < self.end_idx:
            # Stored entries form one contiguous run
            key = self.key_buffer[:, :, self.start_idx:self.end_idx, :]
            value = self.value_buffer[:, :, self.start_idx:self.end_idx, :]
        else:
            # Stored entries wrap around the end of the buffer (this also
            # covers a full window, where start_idx == end_idx)
            key = tf.concat([
                self.key_buffer[:, :, self.start_idx:, :],
                self.key_buffer[:, :, :self.end_idx, :]
            ], axis=2)
            value = tf.concat([
                self.value_buffer[:, :, self.start_idx:, :],
                self.value_buffer[:, :, :self.end_idx, :]
            ], axis=2)

        return key, value

    def reset(self):
        """Empty the cache: zero the pointers and any allocated buffers."""
        self.current_size.assign(0)
        self.start_idx.assign(0)
        self.end_idx.assign(0)
        if self.key_buffer is not None:
            self.key_buffer.assign(tf.zeros_like(self.key_buffer))
            self.value_buffer.assign(tf.zeros_like(self.value_buffer))

In [None]:
def pad_state_to_num_timesteps(state, required_timesteps):
    """Extend ``state.sim_trajectory`` along axis 1 to ``required_timesteps``.

    Args:
        state: Object exposing ``sim_trajectory`` and ``replace``; the
            trajectory must expose the fields padded below and ``replace``.
        required_timesteps: Minimum length wanted for axis 1 of the
            trajectory arrays. Converted with ``int()``, so a scalar array
            is accepted as well as a Python int.

    Returns:
        ``state`` itself when its trajectory is already long enough,
        otherwise a copy whose trajectory arrays are extended with zeros
        (``False`` for ``valid``), each keeping its original dtype.
    """
    import jax.numpy as jnp

    trajectory = state.sim_trajectory
    existing_timesteps = trajectory.x.shape[1]
    # Plain int so the padding shape below is built from Python ints only.
    timesteps_to_add = int(required_timesteps) - existing_timesteps

    if timesteps_to_add <= 0:
        # No padding needed
        return state

    def pad_array(array, pad_length, pad_value=0):
        """Append ``pad_length`` entries of ``pad_value`` along axis 1."""
        pad_shape = list(array.shape)
        pad_shape[1] = pad_length
        padding = jnp.full(pad_shape, pad_value, dtype=array.dtype)
        return jnp.concatenate([array, padding], axis=1)

    # Every field here is padded with 0; `valid` is handled separately.
    zero_padded_fields = (
        "x", "y", "z",
        "vel_x", "vel_y",
        "yaw",
        "length", "width", "height",
        "timestamp_micros",
    )
    padded_fields = {
        name: pad_array(getattr(trajectory, name), timesteps_to_add)
        for name in zero_padded_fields
    }
    # Added timesteps carry no data, so they are marked invalid.
    padded_fields["valid"] = pad_array(
        trajectory.valid, timesteps_to_add, pad_value=False
    )

    return state.replace(sim_trajectory=trajectory.replace(**padded_fields))

import sys

def get_cache_memory_usage(cache):
    """Report the size in bytes of the entries held by each layer cache.

    Only the tensors returned by ``get_key_value()`` are counted; caches
    that return ``None`` for either tensor are skipped. A per-layer line in
    MB is printed as a side effect.

    Args:
        cache: Mapping from a layer key to a cache object exposing
            ``get_key_value()``.

    Returns:
        Total number of bytes over all layers, as a Python int.
    """
    total_size = 0
    for layer_key, cache_instance in cache.items():
        k_cache, v_cache = cache_instance.get_key_value()
        if k_cache is None or v_cache is None:
            continue
        # Convert to Python ints before multiplying so the byte count cannot
        # wrap around in int32, and size each tensor with its own dtype.
        layer_memory = (
            int(tf.size(k_cache)) * k_cache.dtype.size
            + int(tf.size(v_cache)) * v_cache.dtype.size
        )  # Bytes
        total_size += layer_memory
        layer_memory_mb = layer_memory / (1024 ** 2)
        print(f"Layer {layer_key} cache memory usage: {layer_memory_mb:.2f} MB")
    return total_size

In [None]:
"""This block implements inference without KV cache"""
import time
import numpy as np
import tensorflow as tf

# Initialize the model
model_config = WayformerTrainingConfig()
model = WayformerCache(model_config, action_space)

# Get a test scenario
test_scenario = next(dataloader.simulator_state_generator(config=config))

# Number of autoregressive steps for inference benchmarking
num_autoregressive_steps = 512  # Or any desired number of steps

# Initialize the state with the test scenario
state = test_scenario

# Pad the state so that every step has a trajectory slot to write into
initial_timestep = state.timestep
required_timesteps = initial_timestep + num_autoregressive_steps + 1  # +1 for zero-based indexing
state = pad_state_to_num_timesteps(state, required_timesteps)

# Initialize the history buffer for motion tokens
history_motion_tokens = []

total_time = 0.0

# Baseline run: no KV cache is handed to the model
cache = None
elapsed_time1 = 0.0
for step in range(num_autoregressive_steps):
    # Process the current state to get model input
    test_data = process_waymax_data(state, action_space)

    # Update history buffers
    if step == 0:
        # For the first step, initialize with zeros or appropriate start tokens
        initial_motion_token = 0  # Replace with your start token if applicable
        history_motion_tokens.append(initial_motion_token)
    else:
        start_time = time.time()
        # Append the predicted motion token from the previous step
        history_motion_tokens.append(predicted_motion_token)
        end_time = time.time()
        elapsed_time1 = end_time - start_time
        total_time += elapsed_time1

    # Prepare the model inputs; the sequence length T grows with the history
    # NOTE(review): only the last token is passed while T is the full history
    # length. Without a cache the decoder has no other access to earlier
    # tokens - confirm how prepare_model_input_with_history fills the
    # remaining T-1 positions.
    model_input = prepare_model_input_with_history(
        [test_data],
        history_motion_tokens[-1:],
        T=len(history_motion_tokens)
    )

    # Convert model inputs to tensors
    model_inputs = {k: tf.convert_to_tensor(v, dtype=tf.float32) for k, v in model_input.items()}
    # Print input shape
    print(f"Model input shape for step {step}: {model_inputs['motion_tokens'].shape}")

    # Perform inference
    # Measure time before inference
    start_time = time.time()

    # cache is None here, so the model runs without a KV cache
    predictions = model(model_inputs, cache=cache, training=False)

    # Measure time after inference
    end_time = time.time()
    # Calculate elapsed time
    elapsed_time2 = end_time - start_time
    total_time += elapsed_time2
    # Print the timing result
    print(f"Auto-regressive inference time for step {step}: {elapsed_time2+elapsed_time1:.6f} seconds")

    # Get the predicted motion token from logits
    # 'logits' has one entry per input position, so the prediction for the
    # next token is at the last position (identical to index 0 when the
    # sequence length is 1).
    predicted_motion_token = tf.argmax(predictions['logits'], axis=-1).numpy()[0, -1]

    # Convert motion token to acceleration
    acceleration = action_space.continuous_acceleration(predicted_motion_token)

    # Update the position and velocity using Verlet integration
    ego_index = 0  # Assuming ego vehicle is at index 0

    # Get current position and velocity from the state
    current_position = np.array([
        state.sim_trajectory.x[ego_index, state.timestep],
        state.sim_trajectory.y[ego_index, state.timestep]
    ])

    current_velocity = np.array([
        state.sim_trajectory.vel_x[ego_index, state.timestep],
        state.sim_trajectory.vel_y[ego_index, state.timestep]
    ])

    # Use the verlet_update function to compute the next position and velocity
    new_position, new_velocity = action_space.verlet_update(current_position, current_velocity, acceleration)

    # Derive the new yaw angle from the direction of the velocity vector
    next_yaw = np.arctan2(new_velocity[1], new_velocity[0])

    # Write the predicted values into the next timestep and advance the state
    state = state.replace(
        sim_trajectory=state.sim_trajectory.replace(
            x=state.sim_trajectory.x.at[ego_index, state.timestep + 1].set(new_position[0]),
            y=state.sim_trajectory.y.at[ego_index, state.timestep + 1].set(new_position[1]),
            vel_x=state.sim_trajectory.vel_x.at[ego_index, state.timestep + 1].set(new_velocity[0]),
            vel_y=state.sim_trajectory.vel_y.at[ego_index, state.timestep + 1].set(new_velocity[1]),
            yaw=state.sim_trajectory.yaw.at[ego_index, state.timestep + 1].set(next_yaw),
            valid=state.sim_trajectory.valid.at[ego_index, state.timestep + 1].set(True)
        ),
        timestep=state.timestep + 1  # Advance the timestep
    )

    # Break the loop if we've reached the end of available timesteps
    if state.timestep >= state.sim_trajectory.x.shape[1] - 1:
        break

# Print total time
print(f"Total auto-regressive inference time without KV cache: {total_time:.6f} seconds")

In [None]:
"""This block implements inference with FullCache (baseline KV cache)"""
import time
import numpy as np
import tensorflow as tf

# Build a freshly initialised model for this benchmark run
model_config = WayformerTrainingConfig()
model = WayformerCache(model_config, action_space)

# Take the first scenario produced by the dataloader
test_scenario = next(dataloader.simulator_state_generator(config=config))

# How many tokens to decode autoregressively
num_autoregressive_steps = 512  # Or any desired number of steps

# Start from the loaded scenario and extend its trajectory arrays so that
# every decoding step has a slot to write into
state = test_scenario
initial_timestep = state.timestep
required_timesteps = initial_timestep + num_autoregressive_steps + 1  # +1 for zero-based indexing
state = pad_state_to_num_timesteps(state, required_timesteps)

history_motion_tokens = []
total_time = 0.0

# One FullCache per decoder layer, keyed by the layer's self-attention name
cache = {
    f'decoder_layer_{layer_idx}_self_attn': FullCache(
        max_cache_length=num_autoregressive_steps,
        num_heads=model_config.tx_num_heads,
        depth=model_config.hidden_size // model_config.tx_num_heads
    )
    for layer_idx, _ in enumerate(model.decoder_layers)
}

initial_motion_token = 0  # Replace with your start token if applicable
ego_index = 0  # Assuming ego vehicle is at index 0

for step in range(num_autoregressive_steps):
    # Turn the current simulator state into model features
    test_data = process_waymax_data(state, action_space)

    # Seed the history with the start token, afterwards with the last prediction
    history_motion_tokens.append(
        initial_motion_token if step == 0 else predicted_motion_token
    )

    # Only the newest token is fed to the model (sequence length 1)
    model_input = prepare_model_input_with_history(
        [test_data],
        history_motion_tokens[-1:],
        T=1
    )

    model_inputs = {
        name: tf.convert_to_tensor(values, dtype=tf.float32)
        for name, values in model_input.items()
    }
    print(f"Model input shape for step {step}: {model_inputs['motion_tokens'].shape}")

    # Time only the forward pass
    start_time = time.time()
    predictions = model(model_inputs, cache=cache, training=False)
    elapsed_time = time.time() - start_time
    total_time += elapsed_time
    print(f"Auto-regressive inference time for step {step}: {elapsed_time:.6f} seconds")

    # Greedy choice: arg-max over the vocabulary, batch entry 0, position 0
    predicted_token_ids = tf.argmax(predictions['logits'], axis=-1).numpy()
    predicted_motion_token = predicted_token_ids[0, 0]

    # Map the token back to an acceleration
    acceleration = action_space.continuous_acceleration(predicted_motion_token)

    # Read the ego vehicle's current position and velocity
    trajectory = state.sim_trajectory
    now = state.timestep
    current_position = np.array([
        trajectory.x[ego_index, now],
        trajectory.y[ego_index, now]
    ])
    current_velocity = np.array([
        trajectory.vel_x[ego_index, now],
        trajectory.vel_y[ego_index, now]
    ])

    # Integrate one step and derive the heading from the new velocity
    new_position, new_velocity = action_space.verlet_update(current_position, current_velocity, acceleration)
    next_yaw = np.arctan2(new_velocity[1], new_velocity[0])

    # Write the result into the next timestep and advance the state
    upcoming = now + 1
    state = state.replace(
        sim_trajectory=trajectory.replace(
            x=trajectory.x.at[ego_index, upcoming].set(new_position[0]),
            y=trajectory.y.at[ego_index, upcoming].set(new_position[1]),
            vel_x=trajectory.vel_x.at[ego_index, upcoming].set(new_velocity[0]),
            vel_y=trajectory.vel_y.at[ego_index, upcoming].set(new_velocity[1]),
            yaw=trajectory.yaw.at[ego_index, upcoming].set(next_yaw),
            valid=trajectory.valid.at[ego_index, upcoming].set(True)
        ),
        timestep=upcoming
    )

    # Stop once the last available trajectory slot has been reached
    if state.timestep >= state.sim_trajectory.x.shape[1] - 1:
        break

# Summary: total forward-pass time and size of the cached entries
print(f"Total auto-regressive inference time with FullCache: {total_time:.6f} seconds")
memory_usage_bytes = get_cache_memory_usage(cache)
memory_usage_mb = memory_usage_bytes / (1024 ** 2)
print(f"Total KV cache memory usage: {memory_usage_mb:.2f} MB")

In [None]:
"""This block implements inference with SlidingWindowCache"""
import time
import numpy as np
import tensorflow as tf

# Build a freshly initialised model for this benchmark run
model_config = WayformerTrainingConfig()
model = WayformerCache(model_config, action_space)

# Take the first scenario produced by the dataloader
test_scenario = next(dataloader.simulator_state_generator(config=config))

# How many tokens to decode autoregressively
num_autoregressive_steps = 512  # Or any desired number of steps

# Start from the loaded scenario and extend its trajectory arrays so that
# every decoding step has a slot to write into
state = test_scenario
initial_timestep = state.timestep
required_timesteps = initial_timestep + num_autoregressive_steps + 1  # +1 for zero-based indexing
state = pad_state_to_num_timesteps(state, required_timesteps)

history_motion_tokens = []
total_time = 0.0

# One SlidingWindowCache per decoder layer, keyed by the layer's self-attention name
window_size = 10  # Set the desired window size
cache = {
    f'decoder_layer_{layer_idx}_self_attn': SlidingWindowCache(
        window_size=window_size,
        num_heads=model_config.tx_num_heads,
        head_dim=model_config.hidden_size // model_config.tx_num_heads
    )
    for layer_idx, _ in enumerate(model.decoder_layers)
}

initial_motion_token = 0  # Replace with your start token if applicable
ego_index = 0  # Assuming ego vehicle is at index 0

for step in range(num_autoregressive_steps):
    # Turn the current simulator state into model features
    test_data = process_waymax_data(state, action_space)

    # Seed the history with the start token, afterwards with the last prediction
    history_motion_tokens.append(
        initial_motion_token if step == 0 else predicted_motion_token
    )

    # Only the newest token is fed to the model (sequence length 1)
    model_input = prepare_model_input_with_history(
        [test_data],
        history_motion_tokens[-1:],
        T=1
    )

    model_inputs = {
        name: tf.convert_to_tensor(values, dtype=tf.float32)
        for name, values in model_input.items()
    }
    print(f"Model input shape for step {step}: {model_inputs['motion_tokens'].shape}")

    # Time only the forward pass
    start_time = time.time()
    predictions = model(model_inputs, cache=cache, training=False)
    elapsed_time = time.time() - start_time
    total_time += elapsed_time
    print(f"Auto-regressive inference time for step {step}: {elapsed_time:.6f} seconds")

    # Greedy choice: arg-max over the vocabulary, batch entry 0, position 0
    predicted_token_ids = tf.argmax(predictions['logits'], axis=-1).numpy()
    predicted_motion_token = predicted_token_ids[0, 0]

    # Map the token back to an acceleration
    acceleration = action_space.continuous_acceleration(predicted_motion_token)

    # Read the ego vehicle's current position and velocity
    trajectory = state.sim_trajectory
    now = state.timestep
    current_position = np.array([
        trajectory.x[ego_index, now],
        trajectory.y[ego_index, now]
    ])
    current_velocity = np.array([
        trajectory.vel_x[ego_index, now],
        trajectory.vel_y[ego_index, now]
    ])

    # Integrate one step and derive the heading from the new velocity
    new_position, new_velocity = action_space.verlet_update(current_position, current_velocity, acceleration)
    next_yaw = np.arctan2(new_velocity[1], new_velocity[0])

    # Write the result into the next timestep and advance the state
    upcoming = now + 1
    state = state.replace(
        sim_trajectory=trajectory.replace(
            x=trajectory.x.at[ego_index, upcoming].set(new_position[0]),
            y=trajectory.y.at[ego_index, upcoming].set(new_position[1]),
            vel_x=trajectory.vel_x.at[ego_index, upcoming].set(new_velocity[0]),
            vel_y=trajectory.vel_y.at[ego_index, upcoming].set(new_velocity[1]),
            yaw=trajectory.yaw.at[ego_index, upcoming].set(next_yaw),
            valid=trajectory.valid.at[ego_index, upcoming].set(True)
        ),
        timestep=upcoming
    )

    # Stop once the last available trajectory slot has been reached
    if state.timestep >= state.sim_trajectory.x.shape[1] - 1:
        break

# Summary: total forward-pass time and size of the cached entries
print(f"Total auto-regressive inference time with SlidingWindowCache: {total_time:.6f} seconds")
memory_usage_bytes = get_cache_memory_usage(cache)
memory_usage_mb = memory_usage_bytes / (1024 ** 2)
print(f"Total KV cache memory usage: {memory_usage_mb:.2f} MB")