In [4]:
import os
from time import perf_counter
import jax.numpy as jnp
import jax
from waymax import config as _config
import dataclasses
from waymax import agents
from waymax import dataloader
from waymax import env as _env
from waymax import config as _config
from waymax import dynamics, datatypes
import warnings
warnings.filterwarnings("ignore")

from functools import partial
from typing import Callable, Optional

import jax
import jax.numpy as jnp
from jax.lib import xla_bridge

from waymax import datatypes
from waymax import dynamics
from waymax.agents import actor_core

### Constants

In [5]:
# Number of object slots per scenario; passed to both the data config and the
# environment config, and used to size the per-object index array.
MAX_NUM_OBJECTS = 128
# Size of the single batch dimension requested from the data loader.
BATCH_SIZE = 2
# Number of environment steps executed by each benchmarked loop.
NUM_STEPS = 500
# NOTE(review): not referenced by any cell visible here — confirm it is still needed.
EPISODE_LENGTH = 90

### Configurations

In [6]:
# Environment settings: start from the library defaults and override the
# object count and the set of controlled objects.
env_config = dataclasses.replace(
    _config.EnvironmentConfig(),
    controlled_object=_config.ObjectType.VALID,
    max_num_objects=MAX_NUM_OBJECTS,
)

# Data settings: start from the WOD 1.0.0 training config and override the
# object count and the batch dimensions.
data_config = dataclasses.replace(
    _config.WOD_1_0_0_TRAINING,
    batch_dims=(BATCH_SIZE,),
    max_num_objects=MAX_NUM_OBJECTS,
)

# Build the scenario iterator (takes ~ 80 s).
data_iter = dataloader.simulator_state_generator(config=data_config)
# Consume and discard the first state (jit compiling / warmup); the trailing
# semicolon keeps the notebook from echoing the returned state.
next(data_iter);

2024-05-10 16:51:26.882144: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
2024-05-10 16:52:40.954408: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignor

### Make environment and populate it with actors

In [7]:
# Use ONE dynamics model for both the environment and the expert actor.
# The environment interprets actions according to its dynamics model, and the
# expert actor derives its actions through the model it is given; if one side
# normalizes actions and the other does not, the actor's actions are on a
# different scale from what the environment expects.
dynamics_model = dynamics.InvertibleBicycleModel(
    normalize_actions=True,
)

# Define environment
env = _env.MultiAgentEnvironment(
    dynamics_model=dynamics_model,
    config=env_config,
)

# One index per object slot, used below to pick which slots the actor controls.
obj_idx = jnp.arange(MAX_NUM_OBJECTS)

# Expert actor. The control mask ignores `state` and marks every object slot
# with index > 1 (i.e. all but the first two slots) as controlled.
actor = agents.create_expert_actor(
    dynamics_model=dynamics_model,
    is_controlled_func=lambda state: obj_idx > 1,
)

actors = [actor]

# Jit the step and select action functions
jit_step = jax.jit(env.step)
jit_select_action_list = [jax.jit(a.select_action) for a in actors]

# Fixed PRNG key that is passed to every `select_action` call in the rollouts.
rng = jax.random.PRNGKey(0)

'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)


In [8]:
# Check that we're on the GPU: list the devices JAX can see. This only
# displays them (it does not assert); inspect the output for a CUDA device.
jax.devices()

[cuda(id=0)]

In [9]:
# Platform name of JAX's default backend. Uses the public
# `jax.default_backend()` API rather than the internal, deprecated
# `jax.lib.xla_bridge` module.
jax.default_backend()

'gpu'

### Step environment | **for loop**

In [10]:
# Start from a freshly loaded scenario.
state = env.reset(next(data_iter))

start_py_loop = perf_counter()

for _ in range(NUM_STEPS):

    # Ask every jitted actor for its action given the current state, then
    # merge the per-actor outputs into a single action for the environment.
    outputs = [
        jit_select_action({}, state, None, rng)
        for jit_select_action in jit_select_action_list
    ]
    action = agents.merge_actions(outputs)

    # The reward is computed as part of the benchmarked workload; its value is
    # not used further in this cell.
    step_return = env.reward(state, action)
    state = jit_step(state, action)

    if state.is_done: # Reset environment
        state = env.reset(next(data_iter))

# JAX dispatches work asynchronously: wait for the last state to be fully
# computed before stopping the timer so the measurement covers all the work.
jax.block_until_ready(state)

end_py_loop = perf_counter()

print(f'Python for loop: {(end_py_loop - start_py_loop):.2f} s')

'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring f

Python for loop: 42.21 s


### Step environment | **scan loop**

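We write the same rollout as above using a JAX `lax.scan` loop. One difference: the Python data iterator cannot be advanced inside a traced function, so the scan version resets to the same pre-fetched `scenario` instead of drawing a new one.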

In [11]:
# Fetch one scenario up front; the scan loop uses it for the initial reset
# and for every reset performed inside the scanned step function.
scenario = next(data_iter)

In [15]:
def scan_env_step(state, _):
    """Advance the simulation by one step; meant to be the body of ``jax.lax.scan``.

    Args:
        state: Scan carry, the current simulator state.
        _: Per-iteration scan input; ignored.

    Returns:
        A ``(carry, output)`` pair whose two entries are both the next state:
        the environment reset to the module-level ``scenario`` when ``state``
        is done, otherwise the result of stepping ``state`` with the merged
        actor actions.
    """
    # Collect one output per actor, then merge them into a single action.
    actor_outputs = []
    for select_action in jit_select_action_list:
        actor_outputs.append(select_action({}, state, None, rng))
    merged_action = agents.merge_actions(actor_outputs)

    # Reward is evaluated to mirror the Python loop; the value is not used.
    step_return = env.reward(state, merged_action)

    def _restart():
        # Episode finished: start over from the pre-fetched scenario.
        return env.reset(scenario)

    def _advance():
        # Episode still running: apply the merged action.
        return jit_step(state, merged_action)

    next_state = jax.lax.cond(state.is_done, _restart, _advance)
    return next_state, next_state

In [16]:
# Start the scanned rollout from the pre-fetched scenario.
state = env.reset(scenario)

start_jax_scan = perf_counter()

final_state, trajectory = jax.lax.scan(
    f=scan_env_step,
    init=state,
    xs=None,
    length=NUM_STEPS,
)
# JAX dispatches work asynchronously, so `scan` can return before the device
# has finished. Block on the results so the timer measures the computation
# itself and not just the time needed to enqueue it.
jax.block_until_ready((final_state, trajectory))

end_jax_scan = perf_counter()

print(f'Jax scan loop: {(end_jax_scan - start_jax_scan):.2f} s')

Jax scan loop: 0.44 s


'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
'+ptx84' is not a recognized feature for this target (ignoring feature)
