# Scenario Data Loading

Load in Waymax

In [None]:
!pip install git+https://github.com/waymo-research/waymax.git@main#egg=waymo-waymax

In [None]:
%%capture
import numpy as np
import mediapy
from tqdm import tqdm
import dataclasses

from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import visualization

In [None]:
import jax
from jax import numpy as jnp
import numpy as np
import mediapy
from tqdm import tqdm
import dataclasses

from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax import dynamics
from waymax import env as _env
from waymax import agents
from waymax import visualization
from waymax.metrics import overlap

Authenticate with Google, because this notebook runs on Colab. Note: running it requires approved access to the Waymo Open Dataset.

In [None]:
# Authenticate the current user for this Colab session.
# NOTE(review): presumably these credentials are what let the Waymax data
# loader read the Waymo Open Dataset files — confirm. The `google.colab`
# import is expected to fail outside of Colab.
from google.colab import auth
auth.authenticate_user()


-  `config.WOD_1_1_0_TRAINING` is a pre-defined configuration that points to version 1.1.0 of the Waymo Open Dataset.

By default, `WOD_1_1_0_TRAINING` loads up to 128 objects (e.g. vehicles, pedestrians) per scenario.

- The `dataloader.simulator_state_generator` function creates an iterator
through Open Motion Dataset scenarios. Calling `next` on the iterator retrieves the first scenario in the dataset.


In [None]:
#config = dataclasses.replace(_config.WOD_1_1_0_TRAINING, max_num_objects=64)
#data_iter = dataloader.simulator_state_generator(config=config)
#scenario = next(data_iter)

In [None]:
# Using logged trajectory
#img = visualization.plot_simulator_state(scenario, use_log_traj=True)
#mediapy.show_image(img)

In [None]:
#imgs = []

#state = scenario
#for _ in range(scenario.remaining_timesteps):
  #state = datatypes.update_state_by_log(state, num_steps=1)
  #imgs.append(visualization.plot_simulator_state(state, use_log_traj=True))

#mediapy.show_video(imgs, fps=10)

# Simulated Agents

### Data Processing


In [None]:
# --- Load scenarios and bucket them by which object types they contain ---
# Object type codes tested below: 2 = pedestrian, 3 = cyclist.
max_num_objects = 5  # Maximum number of objects loaded per scenario
num_scenerio = 600   # Number of scenarios pulled from the data iterator

# Scenario index -> loaded scenario, one dictionary per bucket.
people_dict = {}
cyclistPeople_dict = {}
other_dict = {}
# Scenario indices per bucket, in load order (same keys as the dictionaries).
nums_people = []
nums_peopleCyclist = []
nums_other = []

config = dataclasses.replace(
    _config.WOD_1_0_0_VALIDATION,
    max_num_objects=max_num_objects,
)
data_iter = dataloader.simulator_state_generator(config=config)

for scenario_idx in range(num_scenerio):
    loaded_scenario = next(data_iter)
    type_codes = loaded_scenario.object_metadata.object_types

    if 2 not in type_codes:
        # No pedestrian present. Scenarios that only have cyclists also end
        # up in this bucket.
        nums_other.append(scenario_idx)
        other_dict[scenario_idx] = loaded_scenario
        continue

    # At least one pedestrian present.
    nums_people.append(scenario_idx)
    people_dict[scenario_idx] = loaded_scenario

    if 3 in type_codes:
        # Pedestrian AND cyclist: recorded in addition to the people bucket.
        nums_peopleCyclist.append(scenario_idx)
        cyclistPeople_dict[scenario_idx] = loaded_scenario


In [None]:
# For each bucket, print the scenario indices followed by how many there are.
for bucket_indices, bucket_label in (
    (nums_peopleCyclist, "number of scenerios with people and cyclist = "),
    (nums_people, "number of scenerios with people = "),
    (nums_other, "number of scenerios with no people = "),
):
    print(bucket_indices)
    print(bucket_label, len(bucket_indices))

In [None]:
# Choose which pedestrian+cyclist scenario to simulate: take the entry at
# position 18 of the matching scenario indices and look it up in the
# corresponding dictionary.
chosen_scenario_idx = nums_peopleCyclist[18]
scenario = cyclistPeople_dict[chosen_scenario_idx]

In [None]:
# --- Multi-agent environment setup ---
# Defined here but not passed to the environment config in this cell.
init_steps = 11

# Dynamics model used by the environment. Every actor that interacts with the
# environment has to emit actions compatible with this model.
dynamics_model = dynamics.StateDynamics()

# Start from the default environment config, cap the object count to match the
# data loader setting, and expect the user to control every valid object.
env_config = dataclasses.replace(
    _config.EnvironmentConfig(),
    max_num_objects=max_num_objects,
    controlled_object=_config.ObjectType.VALID,
)
env = _env.MultiAgentEnvironment(dynamics_model=dynamics_model, config=env_config)

In [None]:
# Inspect the objects of the selected scenario.
# Object type codes, per the original author's notes:
#    1 = vehicle
#    2 = pedestrian
#    3 = cyclist
#   -1 = unclear (author's guess: non-moving pedestrian) — TODO confirm
metadata = scenario.object_metadata
for field_label, field_values in (
    ('All object IDS:', metadata.ids),
    ("Object types:", metadata.object_types),
):
    print(field_label, field_values)

In [None]:
# --- Actor setup ---
# Each actor receives a boolean mask over object slots saying which objects it
# controls. The masks are built once here and returned by the
# is_controlled_func callbacks.
obj_idx = jnp.arange(max_num_objects)

static_mask = obj_idx > 4                                   # slots 5 and above
idm_mask = (obj_idx == 0) | (obj_idx == 1) | (obj_idx == 4)  # slots 0, 1, 4
expert_mask = (obj_idx == 2) | (obj_idx == 3)                # slots 2, 3

# Actor that does not move (constant speed 0.0). With max_num_objects set to
# 5, obj_idx only spans 0..4, so this mask selects no object.
static_actor = agents.create_constant_speed_actor(
    speed=0.0,
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: static_mask,
)

# IDM route-following policy for objects 0, 1 and 4.
# Note the IDM policy is an actor hard-coded to use dynamics.StateDynamics().
actor_0 = agents.IDMRoutePolicy(
    is_controlled_func=lambda state: idm_mask,
)

# A constant-speed actor (speed=5.0) for object 1 was tried earlier as
# `actor_1` and is disabled; object 1 is under IDM control instead.

# Expert / log-playback actor for objects 2 and 3.
actor_2 = agents.create_expert_actor(
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: expert_mask,
)

actors = [static_actor, actor_0, actor_2]

In [None]:
# JIT-compile the environment step and each actor's action selection.
jit_step = jax.jit(env.step)

jit_select_action_list = []
for actor in actors:
    jit_select_action_list.append(jax.jit(actor.select_action))

In [None]:
# --- Roll the simulation forward, recording every state ---
initial_state = env.reset(scenario)
states = [initial_state]

# One iteration per remaining timestep of the freshly reset state.
for _ in range(initial_state.remaining_timesteps):
    current_state = states[-1]

    # Ask every actor for its action on the current state.
    # NOTE(review): the ({}, state, None, None) arguments are presumably
    # (params, state, actor_state, rng) — confirm against the actor API.
    outputs = []
    for jit_select_action in jit_select_action_list:
        outputs.append(jit_select_action({}, current_state, None, None))

    # Combine the per-actor outputs into one action and advance the env.
    action = agents.merge_actions(outputs)
    states.append(jit_step(current_state, action))

In [None]:
# TODO (original author): still need to figure out how to calculate a
# collision metric.

# `action` is the merged action left over from the final rollout step.
# NOTE(review): the original comment described action.data as the positions
# of all objects — confirm against the dynamics model's action layout.
last_action_data = action.data
print(last_action_data)

# Overlap metric for the initial state only (states[0]), not for every state.
first_state = states[0]
first_state_overlap = overlap.OverlapMetric().compute(first_state).value
print("Overlap metric", first_state_overlap)

# Object metadata carried by the initial state.
print(first_state.object_metadata)

In [None]:
# --- Count the timesteps in which any object overlaps another ---
# collision_dic maps a 1-based state counter to that state's overlap array.
collision_dic = {}
num_collisons = 0

# Build the metric once; it is reused for every state below.
overlap_metric = overlap.OverlapMetric()

# Read the object types from the simulated states themselves rather than from
# the notebook-global `metadata`, which the data-loading loop also rebinds and
# which may therefore describe a different scenario.
print("Object types:", states[0].object_metadata.object_types)

for count, state in enumerate(states, start=1):
    collision = overlap_metric.compute(state).value
    collision_dic[count] = collision
    if 1 in collision:  # at least one object is flagged as overlapping
        print(f"Collision: {collision}")
        num_collisons += 1

# This is the number of timesteps with a collision, not the number of
# distinct collision events.
print("Number of collisions: ", num_collisons)

In [None]:
# `utils` is only referenced by the disabled experiments summarized below.
from waymax.visualization import utils

# NOTE: earlier debugging in this cell tried to work around an error about an
# "invalid" center agent by building a custom visualization config
# (utils.VizConfig / dataclasses.replace with center_agent_idx set to the
# first entry of state.object_metadata.ids, falling back to 0). Those attempts
# are disabled; frames are rendered with the default visualization config.

# Render one frame per simulated state, then play them back as a video.
imgs = [
    visualization.plot_simulator_state(frame_state, use_log_traj=False)
    for frame_state in states
]
mediapy.show_video(imgs, fps=10)

### Creating plots and statistics

**Scenario 0:** 0 collisions, 1 vehicle, 2 pedestrians, 2 cyclists <br>
**Scenario 1:** 23 timesteps of collision (btwn pedestrian and cyclist), 3 vehicles, 1 pedestrian, 1 cyclist <br>
**Scenario 2:** 0 collisions, 3 vehicles, 1 pedestrian, 1 cyclist <br>
**Scenario 3:** 0 collisions, 2 vehicles, 2 pedestrians, 1 cyclist <br>
**Scenario 4:** 0 collisions, 3 vehicles, 1 pedestrians, 1 cyclist <br>
**Scenario 5:** 0 collisions, 3 vehicles, 1 pedestrians, 1 cyclist <br>
**Scenario 6:** 0 collisions, 2 vehicles, 1 pedestrians, 2 cyclists <br>
**Scenario 7:** 12 timesteps of collision (btwn pedestrian and vehicle), 2 vehicles, 2 pedestrians, 1 cyclist <br>
**Scenario 8:** 0 collisions, 3 vehicles, 1 pedestrian, 1 cyclist <br>
**Scenario 9:** 0 collisions, 3 vehicles, 1 pedestrian, 1 cyclist <br>
**Scenario 10:** 8 timesteps of collisions (btwn vehicle and cyclist), 2 vehicles, 1 pedestrian, 2 cyclists [interesting case of log playback versus IDM making a serious difference] <br>
**Scenario 11:** 0 collisions, 3 vehicles, 1 pedestrians, 1 cyclist <br>
**Scenario 12:** 21 timesteps of collisions (btwn pedestrian and vehicle), 3 vehicles, 1 pedestrian, 1 cyclist [another interesting case of collision avoided using log playback] <br>
**Scenario 13:** 0 collisions, 3 vehicles, 1 pedestrians, 1 cyclist <br>
**Scenario 14:** 0 collisions, 3 vehicles, 1 pedestrians, 1 cyclist <br>
**Scenario 15:** 1 timestep of collision (btwn 2 pedestrians), 2 vehicles, 2 pedestrians, 1 cyclist <br>
**Scenario 16:** 0 collisions, 3 vehicles, 1 pedestrian, 1 cyclist <br>
**Scenario 17:** 0 collisions, 3 vehicles, 1 pedestrians, 1 cyclist <br>
**Scenario 18:** 0 collisions, 3 vehicles, 1 pedestrians, 1 cyclist <br>


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

sns.set(style="whitegrid")

# Hand-recorded results: number of timesteps with a collision for each of the
# 19 simulated scenarios (values match the scenario summary in the markdown).
data = {
    'Scenario': list(range(19)),
    'Collision_Timesteps': [0, 23, 0, 0, 0, 0, 0, 12, 0, 0, 8, 0, 21, 0, 0, 1, 0, 0, 0],
}
df = pd.DataFrame(data)

# Bar chart of collision timesteps per scenario, drawn on an explicit axes.
fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(x='Scenario', y='Collision_Timesteps', data=df, palette='viridis', ax=ax)
ax.set_title('Number of Collision Timesteps per Scenario')
ax.set_xlabel('Scenario')
ax.set_ylabel('Collision Timesteps')
ax.axhline(0, color='black', linewidth=0.8)  # baseline at zero
fig.savefig('collision_timesteps_per_scenario.png', dpi=300)
plt.show()

In [None]:
# Hand-recorded object counts per scenario, by entity type.
data_extended = {
    'Scenario': list(range(19)),
    'Vehicles': [1,3,3,2,3,3,2,2,3,3,2,3,3,3,3,2,3,3,3],
    'Pedestrians': [2,1,1,2,1,1,1,2,1,1,1,1,1,1,1,2,1,1,1],
    'Cyclists': [2,1,1,1,1,1,2,1,1,1,2,1,1,1,1,1,1,1,1],
}
df_extended = pd.DataFrame(data_extended)

# Long format: one row per (scenario, entity type) pair for grouped plotting.
entity_columns = ['Vehicles', 'Pedestrians', 'Cyclists']
df_melted = df_extended.melt(
    id_vars='Scenario',
    value_vars=entity_columns,
    var_name='Entity',
    value_name='Count',
)

# Grouped bar chart: one bar per entity type within each scenario.
fig, ax = plt.subplots(figsize=(14, 7))
sns.barplot(x='Scenario', y='Count', hue='Entity', data=df_melted, palette='Set2', ax=ax)
ax.set_title('Distribution of Vehicles, Pedestrians, and Cyclists per Scenario')
ax.set_xlabel('Scenario')
ax.set_ylabel('Count')
ax.legend(title='Entity')
fig.savefig('entities_per_scenario.png', dpi=300)
plt.show()

In [None]:
# Count plot: x is the per-scenario entity count, bar height is the number of
# scenarios with that count, split by entity type.
plt.figure(figsize=(10, 6))

sns.countplot(data=df_melted, x='Count', hue='Entity', edgecolor='black')

plt.title('Distribution of Vehicles, Pedestrians, and Cyclists per Scenario', fontsize=16) #, fontweight='bold')
plt.xlabel('Number of Entities', fontsize=14)
plt.ylabel('Number of Scenarios', fontsize=14)

plt.xticks(range(0,5))
plt.yticks(range(0, max(df_melted['Count'].value_counts()) + 2))
plt.legend(title='Entity', title_fontsize='13', fontsize='12')

# Apply the final axis limits BEFORE laying out and saving, so that the PNG
# written to disk matches the figure displayed in the notebook. (Previously
# the file was saved first and therefore lacked these limits.)
plt.ylim(0, 19)
plt.xlim(-1, 3)

plt.tight_layout()
plt.savefig('distribution_entities_countplot.png', dpi=300)
plt.show()


In [None]:
# Hand-recorded tally: number of scenarios in which each collision type occurred.
collision_types = ['Pedestrian ↔ Cyclist', 'Pedestrian ↔ Vehicle', 'Vehicle ↔ Cyclist', 'Pedestrians']
counts = [1, 2, 1, 1]

# Pie chart of collision types with percentage labels.
fig, ax = plt.subplots(figsize=(8, 8))
ax.pie(
    counts,
    labels=collision_types,
    autopct='%1.1f%%',
    startangle=140,
    colors=sns.color_palette('Pastel1'),
)
ax.set_title('Distribution of Collision Types')
fig.savefig('collision_types_pie_chart.png', dpi=300)
plt.show()

In [None]:

# Hand-recorded per-scenario results: collision timesteps plus the type of
# collision observed ('None' where no collision occurred).
data = {
    'Scenario': list(range(19)),
    'Collision_Timesteps': [0, 23, 0, 0, 0, 0, 0, 12, 0, 0, 8, 0, 21, 0, 0, 1, 0, 0, 0],
    'Collision_Type': ['None', 'Pedestrian ↔ Cyclist', 'None', 'None', 'None', 'None', 'None',
                       'Pedestrian ↔ Vehicle', 'None', 'None', 'Vehicle ↔ Cyclist', 'None',
                       'Pedestrian ↔ Vehicle', 'None', 'None', 'Pedestrians', 'None', 'None', 'None']
}
df_collisions = pd.DataFrame(data)

# Box plot of collision timesteps, grouped by collision type.
fig, ax = plt.subplots(figsize=(10, 6))
sns.boxplot(x='Collision_Type', y='Collision_Timesteps', data=df_collisions, palette='Set2', ax=ax)
ax.set_title('Distribution of Collision Timesteps by Collision Type', fontsize=16)
ax.set_xlabel('Collision Type', fontsize=14)
ax.set_ylabel('Number of Collision Timesteps', fontsize=14)
plt.xticks(rotation=25)  # tilt the category labels
fig.tight_layout()
fig.savefig('boxplot_collision_timesteps.png', dpi=300)
plt.show()


In [None]:
# Swarm plot: one point per scenario, showing its collision timesteps under
# its collision type.
fig, ax = plt.subplots(figsize=(10, 6))
sns.swarmplot(
    x='Collision_Type',
    y='Collision_Timesteps',
    data=df_collisions,
    palette='Set1',
    size=10,
    ax=ax,
)
ax.set_title('Collision Timesteps by Collision Type', fontsize=16)
ax.set_xlabel('Collision Type', fontsize=14)
ax.set_ylabel('Number of Collision Timesteps', fontsize=14)
plt.xticks(rotation=15)  # tilt the category labels
fig.tight_layout()
fig.savefig('swarmplot_collision_timesteps.png', dpi=300)
plt.show()


In [None]:

# Hand-recorded tally: number of scenarios per collision type.
collision_data = {
    'Collision_Type': ['Pedestrian ↔ Cyclist', 'Pedestrian ↔ Vehicle', 'Vehicle ↔ Cyclist', 'Pedestrians'],
    'Count': [1, 2, 1, 1]
}
df_collision_types = pd.DataFrame(collision_data)

sns.set(style="whitegrid")

# Bar chart of how often each collision type occurred.
fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(x='Collision_Type', y='Count', data=df_collision_types, palette='Set2', ax=ax)
ax.set_title('Distribution of Collision Types', fontsize=16)
ax.set_xlabel('Collision Type', fontsize=14)
ax.set_ylabel('Number of Occurrences', fontsize=14)
plt.xticks(rotation=45, ha='right')  # tilt and right-align the long labels
fig.tight_layout()

fig.savefig('collision_types_distribution.png', dpi=300)
plt.show()
