## Gymnasium environments

This section shows how to create and use the Gymnasium (`gym`) environments that interface with the simulator.

In [3]:
import os
from pathlib import Path
import numpy as np
import torch
import imageio
from IPython.display import HTML, Image

# Set working directory to the base directory 'gpudrive'
working_dir = Path.cwd()
while working_dir.name != 'gpudrive':
    # Stop at the filesystem root (its parent is itself) to avoid looping forever
    if working_dir == working_dir.parent:
        raise FileNotFoundError("Base directory 'gpudrive' not found")
    working_dir = working_dir.parent
os.chdir(working_dir)

from pygpudrive.env.config import EnvConfig, RenderConfig
from pygpudrive.env.env_torch import GPUDriveTorchEnv

### Helper functions

In [4]:
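def display_gif(filename, width=500, height=500):
    """Display a saved GIF inline in the notebook."""
    # read_bytes() opens and closes the file, so no handle is left open
    display(
        Image(data=Path(filename).read_bytes(), format="gif", width=width, height=height)
    )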

### Settings

In [6]:
EPISODE_LENGTH = 91  # Number of steps in each episode
MAX_NUM_OBJECTS = 128  # Maximum number of objects in the scene we control
NUM_WORLDS = 50  # Number of parallel environments

# Set the path to where you want to save the videos
VIDEO_PATH = "./videos"

SCENE_NAME = "example_scene"

FPS = 4  # Video frames per second

### Initializing environments

- We provide both a torch and a jax Gymnasium interface to the simulator. Most functionality lives in the `GPUDriveGymEnv` class in `base_env`. The torch and jax environments both inherit from `GPUDriveGymEnv`; the only difference is that one returns torch tensors and the other jax arrays.
- All environment settings are defined in the `EnvConfig` dataclass.
- All rendering settings are defined in the `RenderConfig` dataclass.
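
The inheritance layout described in the first bullet can be pictured with a toy example. This is only an illustrative sketch: `ToyBaseEnv`, `ToyListEnv`, and `ToyTupleEnv` are hypothetical stand-ins, not classes from `pygpudrive`, and plain Python containers take the place of torch tensors and jax arrays.

```python
from abc import ABC, abstractmethod


class ToyBaseEnv(ABC):
    """Hypothetical base class that holds the shared logic."""

    def __init__(self, num_worlds):
        self.num_worlds = num_worlds

    def _raw_obs(self):
        # Shared logic: one dummy observation value per world.
        return [float(i) for i in range(self.num_worlds)]

    @abstractmethod
    def _export(self, data):
        """Convert raw data to the backend's container type."""

    def get_obs(self):
        return self._export(self._raw_obs())


class ToyListEnv(ToyBaseEnv):
    # Stand-in for a backend that returns one array type.
    def _export(self, data):
        return list(data)


class ToyTupleEnv(ToyBaseEnv):
    # Stand-in for a backend that returns another array type.
    def _export(self, data):
        return tuple(data)
```

Both subclasses reuse `get_obs` unchanged and override only the export step, which is the kind of split the bullet describes.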


In [7]:
env_config = EnvConfig(
    steer_actions=torch.round(torch.linspace(-1.0, 1.0, 3), decimals=3),
    accel_actions=torch.round(torch.linspace(-3, 3, 3), decimals=3),
)
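
To see what these settings produce: `torch.linspace(-1.0, 1.0, 3)` yields three evenly spaced steering values, and `torch.linspace(-3, 3, 3)` yields three acceleration values. The sketch below reproduces them in plain Python and enumerates every (steer, accel) pair. Treating the discrete action space as the Cartesian product of the two lists is an assumption made for illustration; see `pygpudrive/env/README.md` for how the environment actually indexes actions. `linspace` here is a hypothetical helper, not the torch function.

```python
from itertools import product


def linspace(start, stop, num):
    """Evenly spaced values from start to stop inclusive (requires num >= 2)."""
    step = (stop - start) / (num - 1)
    return [round(start + i * step, 3) for i in range(num)]


steer_values = linspace(-1.0, 1.0, 3)  # [-1.0, 0.0, 1.0]
accel_values = linspace(-3.0, 3.0, 3)  # [-3.0, 0.0, 3.0]

# Assumed joint action table: one entry per (steer, accel) combination.
joint_actions = list(product(steer_values, accel_values))
print(len(joint_actions))  # 9
```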

In [8]:
render_config = RenderConfig(
    resolution=(256, 256), # Make sure to set the resolution
)

---

> **For more info about the environment configurations, see `pygpudrive/env/README.md`**

---

In [None]:
env = GPUDriveTorchEnv(
    config=env_config,
    num_worlds=NUM_WORLDS,
    max_cont_agents=MAX_NUM_OBJECTS,  # Number of agents to control
    data_dir="example_data",
    device="cuda",
    render_config=render_config,
)

### Run an episode with `num_worlds` parallel environments

A single rollout (one episode) is implemented as follows:
- We step through all `NUM_WORLDS` worlds (50 here) simultaneously.
- Use the `world_render_idx` argument of `render()` to select which world to render.

In [None]:
obs = env.reset()
frames = []

for _ in range(EPISODE_LENGTH):
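    # SELECT ACTIONS: sample one random action per agent slot in every world,
    # then arrange them as a (NUM_WORLDS, MAX_NUM_OBJECTS) tensor
    rand_action = torch.Tensor(
        [[env.action_space.sample() for _ in range(MAX_NUM_OBJECTS * NUM_WORLDS)]]
    ).reshape(NUM_WORLDS, MAX_NUM_OBJECTS)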

    # STEP
    env.step_dynamics(rand_action)

    obs = env.get_obs()
    reward = env.get_rewards()
    done = env.get_dones()

    # RENDER
    frame = env.render(world_render_idx=0)
    frames.append(frame)
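
The `reshape(NUM_WORLDS, MAX_NUM_OBJECTS)` call in the loop fills the action tensor in row-major order, so consecutive samples belong to the same world until that world's agent slots are used up. A minimal plain-Python sketch of that index mapping, using small stand-in sizes; `flat_to_world_agent`, `n_worlds`, and `n_agents` are hypothetical names introduced for illustration:

```python
n_worlds = 3  # small stand-in for NUM_WORLDS
n_agents = 4  # small stand-in for MAX_NUM_OBJECTS


def flat_to_world_agent(flat_idx, agents_per_world=n_agents):
    """Map a flat sample index to (world, agent) under a row-major reshape."""
    return divmod(flat_idx, agents_per_world)


flat = list(range(n_worlds * n_agents))
# Rebuild the 2D layout the same way a row-major reshape would.
grid = [flat[w * n_agents:(w + 1) * n_agents] for w in range(n_worlds)]

world, agent = flat_to_world_agent(6)
print(world, agent)  # 1 2
```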

### Display video

In [None]:
# Calculate duration between frames
duration = 1 / FPS

# Make sure the output directory exists
Path(VIDEO_PATH).mkdir(parents=True, exist_ok=True)

# Save the frames as a gif
imageio.mimsave(f"{VIDEO_PATH}/{SCENE_NAME}.gif", frames, duration=duration)
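
As a quick check of the timing: with `FPS = 4` each frame is shown for 0.25 s, so the 91 frames of one episode give a clip of just under 23 seconds. This assumes `duration` is interpreted in seconds; some `imageio` versions and plugins interpret it in milliseconds, so verify against the version you have installed.

```python
fps = 4              # same value as FPS in the settings
episode_length = 91  # same value as EPISODE_LENGTH in the settings

seconds_per_frame = 1 / fps
clip_seconds = episode_length * seconds_per_frame
print(seconds_per_frame, clip_seconds)  # 0.25 22.75
```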

In [None]:
display(HTML(f"<h3>{SCENE_NAME}</h3>"))
display_gif(f"{VIDEO_PATH}/{SCENE_NAME}.gif")