In [None]:
import os
from pathlib import Path
import mediapy
import matplotlib.pyplot as plt

# Walk up from the current directory to the repository root (a directory named
# 'gpudrive') and chdir there, so the relative paths used in later cells
# (e.g. "data/processed/examples") resolve regardless of where the kernel started.
# The search gives up at the user's home directory or at the filesystem root.
working_dir = Path.cwd()
while working_dir.name != 'gpudrive':
    # Check *before* stepping up: Path('/').parent == Path('/'), so without the
    # root check a kernel launched outside $HOME (e.g. /workspace, /tmp) would
    # spin forever; checking first also catches a kernel started in $HOME itself.
    if working_dir == Path.home() or working_dir == working_dir.parent:
        raise FileNotFoundError("Base directory 'gpudrive' not found")
    working_dir = working_dir.parent
os.chdir(working_dir)

from gpudrive.env.dataset import SceneDataLoader
from gpudrive.env.config import EnvConfig
from gpudrive.env.env_torch import GPUDriveTorchEnv
from gpudrive.visualize.utils import img_from_fig

import logging
# Emit log records at INFO level and above to stderr so they show up in the notebook.
logging.basicConfig(level=logging.INFO)

# Automatically reload imported modules before each cell executes, so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### Dataset

In [12]:
# Scene loader for evaluation. With shuffle=False and sampling without
# replacement, the scene order is presumably identical on every run, which
# keeps the rollout results below comparable across runs.
train_loader = SceneDataLoader(
    root="data/processed/examples",
    dataset_size=1000,
    batch_size=4,  # Number of worlds
    shuffle=False,
    sample_with_replacement=False,
)

### Model


In [None]:
from examples.experimental.eval_utils import (
    load_policy,
    rollout,
)

# Load the policy checkpoint from the models directory onto the CPU.
policy = load_policy(
    device='cpu',
    # NOTE(review): empty model name — confirm which checkpoint load_policy resolves to.
    model_name='',
    path_to_cpt='examples/experimental/models',
)

In [14]:
# Check that the model weights are not random
# for name, param in policy.state_dict().items():
#     print(f"{name} - Mean: {param.mean():.4f}, Std: {param.std():.4f}")

In [None]:
# Show the loaded policy's repr (presumably its module structure, if it is a torch module).
policy

### GPUDriveTorchEnv

In [None]:
# Build the simulator over the scenes supplied by `train_loader`, on the CPU.
# max_cont_agents caps how many agents are controlled (presumably per world).
env = GPUDriveTorchEnv(
    config=EnvConfig(),
    data_loader=train_loader,
    max_cont_agents=64,
    device="cpu",
)

print(env.data_batch)

# Reset and keep only the entries selected by cont_agent_mask
# (presumably the controlled agents — confirm against GPUDriveTorchEnv).
obs = env.reset()[env.cont_agent_mask]

print(f'observation_space: {env.observation_space}')
print(f'obs shape: {obs.shape}')
print(f'obs dtype: {obs.dtype} \n')

print(f'action_space: {env.action_space}')

# Distribution of every observation value after reset: a quick check for
# unnormalized or degenerate features before running the policy.
fig, ax = plt.subplots()
ax.hist(obs.flatten())
ax.set(
    title='Observation values after reset (masked agents)',
    xlabel='Feature value',
    ylabel='Count',
);

In [None]:
# Sanity check: render world 0 at time step 0 (5x5 figure, 100-unit zoom radius)
# to confirm the simulator's starting state before the rollout.
env.vis.figsize = (5, 5)
sim_states = env.vis.plot_simulator_state(
    time_steps=[0],
    env_indices=[0],
    zoom_radius=100,
)

sim_states[0]

In [None]:
# Run the policy deterministically in `env`, rendering simulator frames
# (render_sim_state=True) that are played back as videos further down.
(
    goal_achieved_count,
    frac_goal_achieved,
    collided_count,
    frac_collided,
    off_road_count,
    frac_off_road,
    not_goal_nor_crash_count,
    frac_not_goal_nor_crash_per_scene,
    controlled_agents_per_scene,
    sim_state_frames,
    agent_positions,
    episode_lengths,
) = rollout(
    env=env,
    policy=policy,
    deterministic=True,
    device='cpu',
    render_sim_state=True,
    zoom_radius=100,
)

# Summary of episode outcomes.
print('\n Results: \n')
for label, value in [
    ('Goal achieved', frac_goal_achieved),
    ('Collided', frac_collided),
    ('Off road', frac_off_road),
    ('Not goal nor crashed', frac_not_goal_nor_crash_per_scene),
]:
    print(f'{label}: {value}')

In [None]:
# Play back the simulator frames recorded by `rollout` as GIFs at 15 frames per second.
mediapy.show_videos(sim_state_frames, codec='gif', fps=15)