## Scenario loading and structure 

`GPUDrive` is a multi-agent driving simulator built on top of the [Waymo Open Motion Dataset (WOMD)](https://waymo.com/open/) (See also [Ettinger et al., 2021](https://arxiv.org/abs/2104.10133)). 

In this tutorial, we explain the structure of a traffic scenario and show how to use processed scenario data with `GPUDrive`.

**Useful links to learn more**:
- [`waymo-open-dataset`](https://github.com/waymo-research/waymo-open-dataset): Official dataset repo
- [tf.Example proto format](https://waymo.com/open/data/motion/tfexample): Data dictionary for a raw WOMD scenario
- [GPUDrive `data_utils`](https://github.com/Emerge-Lab/gpudrive/tree/main/data_utils): Docs and code we use to process the WOMD scenarios

In [13]:
# Dependencies
import json
import matplotlib.pyplot as plt
import os
from pathlib import Path
import seaborn as sns
import pandas as pd

# Set working directory to the base directory 'gpudrive'
working_dir = Path.cwd()
while working_dir.name != 'gpudrive':
    if working_dir == working_dir.parent:  # Reached the filesystem root without finding it
        raise FileNotFoundError("Base directory 'gpudrive' not found")
    working_dir = working_dir.parent
os.chdir(working_dir)

cmap = ["r", "g", "b", "y", "c"]
%config InlineBackend.figure_format = 'svg'
sns.set("notebook", font_scale=1.1, rc={"figure.figsize": (8, 3)})
sns.set_style("ticks", rc={"figure.facecolor": "none", "axes.facecolor": "none"})

### Iterating through the WOMD dataset

We provide a folder with a few example scenarios in the `data/processed/examples` directory that you can work with. The full dataset can be downloaded [here](https://github.com/Emerge-Lab/gpudrive/tree/main?tab=readme-ov-file#dataset).


Notice that the data folder is structured as follows:

```bash
data/processed/examples/
    - tfrecord-xxxxx-of-xxxxx_xxx.json
    - ...
    - tfrecord-xxxxx-of-xxxxx_xxx.json
```

Every `.json` file whose name begins with `tfrecord` is a single traffic scenario.
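If you want to enumerate the scenarios in a folder yourself, for example to check which files a dataloader can see, a few lines of `pathlib` are enough. This is a minimal sketch that assumes the naming convention above (`tfrecord...json`); `list_scenario_files` is a hypothetical helper, not part of `GPUDrive`.

```python
from pathlib import Path


def list_scenario_files(root: str) -> list[Path]:
    """Return the processed scenario files in `root`, sorted by name.

    Assumes each scenario is a JSON file whose name starts with 'tfrecord'
    (hypothetical helper for illustration).
    """
    return sorted(
        path
        for path in Path(root).iterdir()
        if path.is_file()
        and path.name.startswith("tfrecord")
        and path.suffix == ".json"
    )
```

For example, `list_scenario_files("data/processed/examples")` should list the example scenes used in the rest of this tutorial.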

To use the dataset with the simulator, we use the conventions from [PyTorch dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). 
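To make the dataloader parameters more concrete, here is a toy, hypothetical sketch of the idea: draw a pool of `dataset_size` scenes once, then fill each batch with `batch_size` scene paths, one per simulated world. It is not `SceneDataLoader`'s implementation; the function name and its exact sampling rules are assumptions for illustration.

```python
import random


def toy_scene_batches(paths, batch_size, dataset_size, sample_with_replacement=True, seed=42):
    """Yield batches of scene paths forever (hypothetical sketch, not SceneDataLoader).

    A pool of `dataset_size` scenes is drawn once; each batch then holds
    `batch_size` paths from that pool, one per simulated world.
    """
    rng = random.Random(seed)
    pool = list(paths)
    rng.shuffle(pool)
    pool = pool[:dataset_size]
    if not sample_with_replacement and batch_size > len(pool):
        raise ValueError("batch_size exceeds the number of unique scenes in the pool")
    while True:
        if sample_with_replacement:
            # Scenes may repeat within a batch.
            yield rng.choices(pool, k=batch_size)
        else:
            # Every scene in the batch is distinct.
            yield rng.sample(pool, k=batch_size)
```

With `sample_with_replacement=True`, a batch of 10 drawn from 4 unique scenes necessarily repeats scenes, similar to the repeated entries in `data_loader.dataset` below.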


Here is an example of how to set up a dataloader:

In [14]:
from gpudrive.env.dataset import SceneDataLoader

data_loader = SceneDataLoader(
    root="data/processed/examples",  # Path to the dataset
    # Batch size, you want this to be equal to the number of worlds (envs) so that every world receives a different scene
    batch_size=10,
    dataset_size=4,  # Total number of different scenes we want to use
    sample_with_replacement=True,
    seed=42,
    shuffle=True,
)

In [15]:
# The full dataset that we will be using
data_loader.dataset

['data/processed/examples/tfrecord-00002-of-01000_407.json',
 'data/processed/examples/tfrecord-00002-of-01000_407.json',
 'data/processed/examples/tfrecord-00000-of-01000_402.json',
 'data/processed/examples/tfrecord-00000-of-01000_325.json',
 'data/processed/examples/tfrecord-00000-of-01000_4.json',
 'data/processed/examples/tfrecord-00000-of-01000_402.json',
 'data/processed/examples/tfrecord-00000-of-01000_4.json',
 'data/processed/examples/tfrecord-00000-of-01000_325.json',
 'data/processed/examples/tfrecord-00000-of-01000_325.json',
 'data/processed/examples/tfrecord-00000-of-01000_4.json']

In [16]:
# Notice that it only has 4 unique scenes, since we set the dataset_size to 4
set(data_loader.dataset)

{'data/processed/examples/tfrecord-00000-of-01000_325.json',
 'data/processed/examples/tfrecord-00000-of-01000_4.json',
 'data/processed/examples/tfrecord-00000-of-01000_402.json',
 'data/processed/examples/tfrecord-00002-of-01000_407.json'}

In [17]:
data_files = next(iter(data_loader))

data_files[0]

'data/processed/examples/tfrecord-00002-of-01000_407.json'

In [18]:
from gpudrive.env.env_torch import GPUDriveTorchEnv
from gpudrive.env.config import EnvConfig

In [19]:
# Pass the data_loader to the environment
env = GPUDriveTorchEnv(
    config=EnvConfig(),
    data_loader=data_loader,
    max_cont_agents=64,
    device="cpu",
)

### Deep dive: What is inside a traffic scenario? 🤔🔬

Though every scenario in the WOMD is unique, they all share the same basic data structure. Processed traffic scenarios are stored as JSON files, so each one loads as a nested Python dictionary; you can also inspect them with tools like [JSON Formatter](https://jsonformatter.org/json-viewer). We'll look at one in this notebook. In a nutshell, traffic scenarios contain a few key elements:

- **Road map**: The layout and structure of the roads.
- **Human driving (expert) demonstrations**: Examples of human driving behavior.
- **Traffic controls**: Elements such as stop signs and other traffic signals.

In [20]:
# Take an example scene
data_path = "data/processed/examples/tfrecord-00000-of-01000_325.json"

with open(data_path) as file:
    traffic_scene = json.load(file)

traffic_scene.keys()

dict_keys(['name', 'scenario_id', 'objects', 'roads', 'tl_states', 'metadata'])


We will show you how to render a scene in ⏭️ tutorial `03`, which introduces the gym environment wrapper. Let's first take a closer look at the data structure.

### Global Overview

A traffic scene includes the following key elements:

- **`name`**: The name of the traffic scenario.  
- **`scenario_id`**: Unique identifier for every scenario.
- **`objects`**: Dynamic entities such as vehicles or other moving elements in the scene.  
- **`roads`**: Stationary elements, including road points and fixed objects.  
- **`tl_states`**: Traffic light states (currently not included in processing).  
- **`metadata`**: Additional details about the traffic scenario, such as the index of the self-driving car (SDC) and details for the WOSAC Challenge.
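A small helper can collect these top-level elements into a single overview. This sketch relies only on the keys listed above; `summarize_scene` is a hypothetical name, not a `GPUDrive` function.

```python
def summarize_scene(scene: dict) -> dict:
    """Summarize the top-level elements of a processed traffic scene dict."""
    return {
        "name": scene["name"],
        "scenario_id": scene["scenario_id"],
        "num_objects": len(scene["objects"]),       # dynamic entities
        "num_road_elements": len(scene["roads"]),   # static map elements
        "has_traffic_lights": bool(scene["tl_states"]),  # empty dict -> False
        "sdc_track_index": scene["metadata"].get("sdc_track_index"),
    }
```

Calling `summarize_scene(traffic_scene)` gives a quick per-scene overview, which is handy when scanning many files.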

In [21]:
traffic_scene["tl_states"]

{}

In [22]:
traffic_scene["name"]

'tfrecord-00000-of-01000_325.json'

In [23]:
traffic_scene["metadata"]

{'sdc_track_index': 61,
 'objects_of_interest': [],
 'tracks_to_predict': [{'track_index': 3, 'difficulty': 0},
  {'track_index': 32, 'difficulty': 0},
  {'track_index': 1, 'difficulty': 0}]}
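The metadata stores indices rather than objects. In the raw WOMD `Scenario` proto, `sdc_track_index` and each `tracks_to_predict` entry's `track_index` index into the list of tracks. The sketch below assumes they index into `traffic_scene["objects"]` in the same way; verify this on your processed data (for example, by checking that the SDC entry has type `vehicle`). `objects_of_note` is a hypothetical helper.

```python
def objects_of_note(scene: dict) -> dict:
    """Look up the SDC and the tracks-to-predict in a processed scene.

    ASSUMPTION: metadata indices are positions in scene["objects"], mirroring
    how they index the tracks list in the raw WOMD proto.
    """
    objects = scene["objects"]
    metadata = scene["metadata"]
    sdc = objects[metadata["sdc_track_index"]]
    to_predict = [
        objects[entry["track_index"]]
        for entry in metadata.get("tracks_to_predict", [])
    ]
    return {"sdc": sdc, "tracks_to_predict": to_predict}
```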

In [24]:
traffic_scene["scenario_id"]

'ef3a8f65142f41ac'

In [25]:
pd.Series(
    [obj["type"] for obj in traffic_scene["objects"]]
).value_counts().plot(kind="bar", rot=45, color=cmap)
plt.title(
    f'Distribution of road objects in traffic scene. Total # objects: {len(traffic_scene["objects"])}'
)
plt.show()


This traffic scenario contains only vehicles and pedestrians; some scenes also include cyclists.

In [26]:
pd.Series(
    [road["type"] for road in traffic_scene["roads"]]
).value_counts().plot(kind="bar", rot=45, color=cmap)
plt.title(
    f'Distribution of road points in traffic scene. Total # points: {len(traffic_scene["roads"])}'
)
plt.show()


### In-Depth: Road Objects

`traffic_scene["objects"]` is a list of the road objects (vehicles, pedestrians, and cyclists) in the scene. For each object, we have its position, velocity, and heading at every time step, its size, a per-step validity flag, its type, and its goal position (the final position of the vehicle).
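Most of these fields are per-time-step lists, so it can help to view one object as a table with one row per step. This is a sketch that assumes `position`, `velocity`, `heading`, and `valid` all have the same length; `object_timesteps` is a hypothetical helper, and it keeps only the x/y components.

```python
def object_timesteps(obj: dict) -> list[dict]:
    """Flatten one object's per-time-step fields into a list of rows.

    Each row holds the x/y position, x/y velocity, heading and validity flag
    for one time index. zip() stops at the shortest list, so this assumes the
    per-step lists have equal length. Static fields (size, type) are omitted.
    """
    rows = []
    steps = zip(obj["position"], obj["velocity"], obj["heading"], obj["valid"])
    for t, (pos, vel, heading, valid) in enumerate(steps):
        rows.append(
            {
                "t": t,
                "x": pos["x"],
                "y": pos["y"],
                "vx": vel["x"],
                "vy": vel["y"],
                "heading": heading,
                "valid": bool(valid),
            }
        )
    return rows
```

`pd.DataFrame(object_timesteps(traffic_scene["objects"][0]))` turns the rows into a DataFrame for quick inspection.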

In [27]:
# Take the first object
idx = 0

# For each object, we have this information:
traffic_scene["objects"][idx].keys()

dict_keys(['position', 'width', 'length', 'height', 'heading', 'velocity', 'valid', 'goalPosition', 'type', 'id', 'mark_as_expert'])

In [28]:
# Position contains the (x, y, z) coordinates of the vehicle at every time step
print(json.dumps(traffic_scene["objects"][idx]["position"][:10], indent=4))

[
    {
        "x": -8330.69921875,
        "y": 8096.3125,
        "z": -38.52962875366211
    },
    {
        "x": -8329.8330078125,
        "y": 8095.95556640625,
        "z": -38.500244512748395
    },
    {
        "x": -8328.87109375,
        "y": 8095.533203125,
        "z": -38.49168714732872
    },
    {
        "x": -8327.91796875,
        "y": 8095.09716796875,
        "z": -38.482875825391815
    },
    {
        "x": -8326.9345703125,
        "y": 8094.64990234375,
        "z": -38.4840157041386
    },
    {
        "x": -8325.93359375,
        "y": 8094.1865234375,
        "z": -38.47635243273912
    },
    {
        "x": -8324.958984375,
        "y": 8093.74658203125,
        "z": -38.46117327593842
    },
    {
        "x": -8323.9189453125,
        "y": 8093.279296875,
        "z": -38.463948981567015
    },
    {
        "x": -8322.89453125,
        "y": 8092.822265625,
        "z": -38.461732818135395
    },
    {
        "x": -8321.853515625,
        "y": 8092.355

In [29]:
# Width and length together define the object's footprint, which is used to check for collisions
traffic_scene["objects"][idx]["width"], traffic_scene["objects"][idx]["length"]

(2.066138744354248, 4.621184349060059)
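One common way to turn width, length, and heading into a collision check is to build each object's oriented bounding box and run a separating-axis test. The sketch below shows that generic technique; it is not necessarily how `GPUDrive` checks collisions, and it assumes that `position` is the box centre, that `heading` is in radians, and that `length` runs along the heading. The helper names are hypothetical.

```python
import math


def box_corners(x, y, heading, length, width):
    """Corners of an oriented rectangle centred at (x, y).

    ASSUMPTIONS: heading in radians, counter-clockwise from +x; `length`
    runs along the heading and `width` perpendicular to it.
    """
    c, s = math.cos(heading), math.sin(heading)
    hl, hw = length / 2.0, width / 2.0
    corners = []
    for dl, dw in ((hl, hw), (-hl, hw), (-hl, -hw), (hl, -hw)):
        # Rotate the local offset (dl, dw) by `heading`, then translate.
        corners.append((x + dl * c - dw * s, y + dl * s + dw * c))
    return corners


def boxes_overlap(a, b):
    """Separating-axis test for two convex polygons given as corner lists."""
    for poly in (a, b):
        n = len(poly)
        for i in range(n):
            x1, y1 = poly[i]
            x2, y2 = poly[(i + 1) % n]
            # Axis perpendicular to this edge.
            ax, ay = y1 - y2, x2 - x1
            proj_a = [px * ax + py * ay for px, py in a]
            proj_b = [px * ax + py * ay for px, py in b]
            if max(proj_a) < min(proj_b) or max(proj_b) < min(proj_a):
                return False  # Found a separating axis: no overlap.
    return True  # No separating axis: the boxes overlap (or touch).
```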

An object's heading is the direction in which it is pointing. The default coordinate system in Nocturne is right-handed: the positive x-axis points to the right and the positive y-axis points up. A heading of 0 lies along the positive x-axis, and the angle increases counter-clockwise.
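Under this convention, the direction of travel can be recovered from a velocity vector with `atan2`. This is a sketch for intuition only: a vehicle's heading can differ from its direction of travel (for example, when reversing), and the units of the stored `heading` field are not stated here, so compare with care. `heading_from_velocity` is a hypothetical helper.

```python
import math


def heading_from_velocity(vx: float, vy: float) -> float:
    """Direction of travel in radians, counter-clockwise from the +x axis.

    math.atan2 returns a value in [-pi, pi]: 0 means moving along +x and
    pi/2 means moving along +y (right-handed convention).
    """
    return math.atan2(vy, vx)
```

For the first logged velocity of object 0, `(8.662109375, -3.5693359375)`, the result is negative, i.e. slightly clockwise of the +x axis; `math.degrees` converts it to degrees.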

Because the scene is recorded from the perspective of an ego vehicle (the self-driving car), some objects are not observed at every time step, so their heading is unavailable at those steps. Such steps are marked with the value `-10_000` to indicate that they are invalid and should be filtered out.
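Before plotting or computing statistics, it is worth dropping those steps. A minimal sketch, assuming the sentinel is exactly `-10_000` and optionally combining it with the per-step `valid` mask; `valid_headings` is a hypothetical helper.

```python
INVALID_VALUE = -10_000  # Sentinel marking unobserved steps (per the text above)


def valid_headings(headings, valid=None):
    """Return (time_step, heading) pairs for observed steps only.

    A step is dropped if its heading equals the sentinel or, when a per-step
    `valid` mask is given, if the mask is False at that step.
    """
    if valid is None:
        valid = [True] * len(headings)
    return [
        (t, h)
        for t, (h, ok) in enumerate(zip(headings, valid))
        if ok and h != INVALID_VALUE
    ]
```

With `obj = traffic_scene["objects"][idx]`, `ts, hs = zip(*valid_headings(obj["heading"], obj["valid"]))` followed by `plt.plot(ts, hs)` plots only the observed steps.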

In [30]:
# Heading is the direction in which the vehicle is pointing
plt.plot(traffic_scene["objects"][idx]["heading"])
plt.xlabel("Time step")
plt.ylabel("Heading")
plt.show()


In [31]:
# Velocity contains the x- and y-components of the velocity at every time step
print(json.dumps(traffic_scene["objects"][idx]["velocity"][:10], indent=4))

[
    {
        "x": 8.662109375,
        "y": -3.5693359375
    },
    {
        "x": 8.662109375,
        "y": -3.5693359375
    },
    {
        "x": 9.619140625,
        "y": -4.2236328125
    },
    {
        "x": 9.53125,
        "y": -4.3603515625
    },
    {
        "x": 9.833984375,
        "y": -4.47265625
    },
    {
        "x": 10.009765625,
        "y": -4.6337890625
    },
    {
        "x": 9.74609375,
        "y": -4.3994140625
    },
    {
        "x": 10.400390625,
        "y": -4.6728515625
    },
    {
        "x": 10.244140625,
        "y": -4.5703125
    },
    {
        "x": 10.41015625,
        "y": -4.66796875
    }
]
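Positions and velocities are consistent with each other: for object 0, the first two positions differ by (0.8662109375, -0.35693359375), exactly 0.1 times the first logged velocity, which matches the 10 Hz sampling rate of WOMD. The sketch below uses that 0.1 s step to estimate velocity from positions by forward differences, and computes speed from a velocity entry; the helper names are hypothetical.

```python
import math

WOMD_DT = 0.1  # Seconds between steps; WOMD trajectories are sampled at 10 Hz


def finite_difference_velocity(positions, dt=WOMD_DT):
    """Approximate (vx, vy) between consecutive positions by forward differences."""
    return [
        ((p1["x"] - p0["x"]) / dt, (p1["y"] - p0["y"]) / dt)
        for p0, p1 in zip(positions, positions[1:])
    ]


def speed(velocity: dict) -> float:
    """Magnitude of a {'x': vx, 'y': vy} velocity entry."""
    return math.hypot(velocity["x"], velocity["y"])
```

Comparing `finite_difference_velocity(traffic_scene["objects"][idx]["position"])` with the logged `velocity` list is a quick sanity check on a scene; keep in mind that invalid steps should be masked out first.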


In [32]:
# Valid indicates if the state of the vehicle was observed for each timepoint
plt.xlabel("Time step")
plt.ylabel("IS VALID")
# "_" draws markers only (no connecting line), so thicken them with markeredgewidth
plt.plot(traffic_scene["objects"][idx]["valid"], "_", markeredgewidth=5)
plt.show()


In [33]:
# Each object has a goalPosition: an (x, y, z) point within the scene
traffic_scene["objects"][idx]["goalPosition"]

{'x': -8275.5390625, 'y': 8071.49462890625, 'z': -38.305127086351085}
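A natural use of `goalPosition` is measuring how far an object is from its goal. This sketch ignores `z` and uses a hypothetical `tolerance` to decide whether the goal is reached; the threshold `GPUDrive` actually uses is not covered here, and both helper names are hypothetical.

```python
import math


def distance_to_goal(position: dict, goal: dict) -> float:
    """Planar Euclidean distance between a position and a goal (z is ignored)."""
    return math.hypot(goal["x"] - position["x"], goal["y"] - position["y"])


def reached_goal(position: dict, goal: dict, tolerance: float = 2.0) -> bool:
    """True if `position` lies within `tolerance` of `goal`.

    The 2.0 default is a hypothetical value chosen for illustration.
    """
    return distance_to_goal(position, goal) <= tolerance
```

For example, `distance_to_goal(traffic_scene["objects"][idx]["position"][0], traffic_scene["objects"][idx]["goalPosition"])` gives the straight-line distance the object covers from its first logged position to its goal.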

In [34]:
# Finally, we have the type of the object
traffic_scene["objects"][idx]["type"]

'vehicle'

### In-Depth: Road Points

Road points (`traffic_scene["roads"]`) are the static elements of the scene, such as lanes, road edges, and stop signs.

In [35]:
traffic_scene["roads"][idx].keys()

dict_keys(['geometry', 'type', 'map_element_id', 'id'])

In [36]:
# This point represents the edge of a road
traffic_scene["roads"][idx]["type"]

'road_edge'

In [37]:
# Geometry contains the (x, y, z) position(s) of a road point
# Note that this is a list of points for road lanes and edges but a single point for stop signs and the like
print(json.dumps(traffic_scene["roads"][idx]["geometry"][:10], indent=4))

[
    {
        "x": -8541.524547445766,
        "y": 8167.38744960166,
        "z": -39.31986597933531
    },
    {
        "x": -8541.052105238197,
        "y": 8167.254393662127,
        "z": -39.32161462356804
    },
    {
        "x": -8540.579632931813,
        "y": 8167.121444633392,
        "z": -39.32336326780076
    },
    {
        "x": -8540.107127841891,
        "y": 8166.9886121632835,
        "z": -39.32336326780076
    },
    {
        "x": -8539.634587292505,
        "y": 8166.855905899633,
        "z": -39.32511191203348
    },
    {
        "x": -8539.162008608988,
        "y": 8166.723335492638,
        "z": -39.326860556266205
    },
    {
        "x": -8538.689389127983,
        "y": 8166.590910594858,
        "z": -39.32860920049893
    },
    {
        "x": -8538.216726198703,
        "y": 8166.458640862011,
        "z": -39.33035784473165
    },
    {
        "x": -8537.744017176647,
        "y": 8166.326535951388,
        "z": -39.33210648896437
    },
    {
 

In [38]:
road_types = {road["type"] for road in traffic_scene["roads"]}

In [39]:
road_types

{'crosswalk', 'driveway', 'lane', 'road_edge', 'road_line', 'stop_sign'}
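Because `geometry` is a list of points for lanes and edges but a single point for stop signs and similar elements, code that walks over every road element has to handle both shapes. A minimal sketch, assuming a single point is stored as one `{"x": ..., "y": ...}` dict; the helper names are hypothetical.

```python
import math


def road_points(road: dict) -> list[dict]:
    """Return a road element's geometry as a list of points.

    ASSUMPTION: single-point elements (e.g. stop signs) store one point dict,
    which is wrapped in a one-item list; polylines are returned as a list.
    """
    geometry = road["geometry"]
    return [geometry] if isinstance(geometry, dict) else list(geometry)


def polyline_length(points: list[dict]) -> float:
    """Sum of planar segment lengths along a polyline (0.0 for a single point)."""
    segments = zip(points, points[1:])
    return sum((math.hypot(b["x"] - a["x"], b["y"] - a["y"]) for a, b in segments), 0.0)
```

For example, summing `polyline_length(road_points(road))` per `road["type"]` gives the total length of lanes, edges, and lines in the scene.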