# PPO Forward Pass Wall-Clock Benchmark

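Benchmarks the wall-clock time of the forward pass (policy + value) of `DiscretePPOAgent` on a dummy ProcGen-like observation batch. The batch size is the config's `minibatch_size` (expected to be 2048 for the configs used here).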

Steps:
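1. Load the Hydra config and derive batch sizes, the run name, seeds, and the device.
2. Create the ProcGen environments.
3. Create a `DiscretePPOAgent`.
4. Generate a dummy uint8 observation batch with ProcGen-like shapes (`minibatch_size` observations).
5. Run warm-up forward passes.
6. Benchmark with `%%timeit`.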

The benchmark measures the wall-clock time of a single `get_action_and_value` call on the dummy batch.

In [None]:
import logging
import time

import torch
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf

from src.rl.agents.ppo_discrete import DiscretePPOAgent
from src.rl.environments.make_functions import make_procgen
from src.rl.utils.train import set_cuda_configuration, set_seeds

In [None]:
# Hydra config to benchmark. Alternative CONFIG_NAME values noted by the
# author: "hyper_paper", "hyperpp".
CONFIG_DIR = "config/procgen_paper"
CONFIG_NAME = "euclidean_baseline"

# Benchmark knobs.
GPU = 1  # GPU index handed to set_cuda_configuration
WARMUP_STEPS = 10  # untimed forward passes executed before %%timeit

In [None]:
# Hydra keeps process-wide global state that survives between notebook cell
# runs; reset it so re-executing this cell does not fail on re-initialization.
hydra_state = GlobalHydra.instance()
if hydra_state.is_initialized():
    hydra_state.clear()

initialize(version_base=None, config_path=CONFIG_DIR, job_name="agent_timing")

# Compose the chosen config with a pinned experiment seed and `config` added to
# Hydra's config search path, then print the resolved YAML for reference.
cfg = compose(
    config_name=CONFIG_NAME,
    overrides=["experiment.seed=23", "hydra.searchpath=[config]"],
)
print(OmegaConf.to_yaml(cfg))

In [None]:
# Root logging: level taken from the config, timestamped line format.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=cfg.logging_level,
)

# Sizes derived from the config (written back onto cfg):
#   batch_size     = num_envs * num_steps
#   minibatch_size = batch_size // num_minibatches
#   num_iterations = total_timesteps // batch_size
cfg.batch_size = int(cfg.num_envs * cfg.num_steps)
cfg.minibatch_size = int(cfg.batch_size // cfg.num_minibatches)
cfg.num_iterations = cfg.total_timesteps // cfg.batch_size

# Run name: <env_id>__<exp_name>__<seed>__<unix timestamp>.
run_name = f"{cfg.env_id}__{cfg.experiment.exp_name}__{cfg.experiment.seed}__{int(time.time())}"
cfg.experiment.run_name = run_name

# Seed RNGs first, then select the compute device from the GPU index.
set_seeds(cfg.experiment.seed, torch_deterministic=cfg.experiment.torch_deterministic)
device = set_cuda_configuration(GPU)

In [None]:
# Build `cfg.num_envs` ProcGen environments for the configured game, starting
# at level 0; the agent below is constructed from these envs.
envs = make_procgen(
    env_id=cfg.env_id,
    run_name=run_name,
    num_envs=cfg.num_envs,
    # level selection
    level_distribution=cfg.level_distribution,
    num_levels=cfg.num_levels,
    start_level=0,
    # extras forwarded from the config
    gamma=cfg.gamma,
    capture_video=cfg.experiment.capture_video,
)

In [None]:
# Construct the PPO agent entirely from the composed config and move its
# parameters to the benchmark device.
agent = DiscretePPOAgent(
    # environment and device
    env_type=cfg.env_type,
    envs=envs,
    device=device,
    # PPO rollout / update hyperparameters
    gamma=cfg.gamma,
    gae_lambda=cfg.gae_lambda,
    num_steps=cfg.num_steps,
    batch_size=cfg.batch_size,
    minibatch_size=cfg.minibatch_size,
    update_epochs=cfg.update_epochs,
    clip_coef=cfg.clip_coef,
    ent_coef=cfg.ent_coef,
    vf_coef=cfg.vf_coef,
    max_grad_norm=cfg.max_grad_norm,
    target_kl=cfg.target_kl,
    norm_adv=cfg.norm_adv,
    feat_reg_coef=cfg.feat_reg_coef,
    # network settings
    embedding_dim=cfg.embedding_dim,
    shared_encoder=cfg.shared_encoder,
    last_layer_tanh=cfg.last_layer_tanh,
    actor_cfg=cfg.policy,
    critic_cfg=cfg.value_fn,
    optim_cfg=cfg.optimizer,
    # diagnostics
    compute_embedding_metrics=cfg.compute_embedding_metrics,
).to(device)

In [None]:
# Shape of a single "rgb" observation, read straight from the observation space
# rather than by drawing a random sample just to inspect its shape.
# NOTE(review): assumes envs.observation_space is a Dict-like space with an
# "rgb" sub-space (the previous `.sample()["rgb"]` access relied on the same key).
obs_shape = envs.observation_space["rgb"].shape

In [None]:
# Dummy uint8 image batch of `minibatch_size` observations shaped like obs_shape.
# torch.randint's upper bound is exclusive, so 256 covers the full byte range
# [0, 255]; allocating on `device` directly avoids a CPU tensor plus host-to-device copy.
test = torch.randint(0, 256, (cfg.minibatch_size, *obs_shape), dtype=torch.uint8, device=device)

In [None]:
# Untimed warm-up calls so one-time costs (e.g. lazy CUDA initialization,
# allocator growth) are paid before anything is measured.
for _warmup_step in range(WARMUP_STEPS):
    _ = agent.get_action_and_value(test)

In [None]:
# Wait for all queued warm-up kernels on the benchmark device to finish before
# timing starts. Passing `device` explicitly matters: with no argument only the
# *current* CUDA device is synchronized, which need not be the one selected by GPU.
torch.cuda.synchronize(device)

In [None]:
%%timeit
_ = agent.get_action_and_value(test)