In [None]:
# Standard library
import os
import sys
from pathlib import Path

# Third-party
# NOTE(review): rc, np, tf and tqdm are not used in the cells shown in this notebook;
# consider dropping them (tensorflow in particular is slow to import).
from matplotlib import rc
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tqdm

# Resolve the PufferDrive checkout. The default is the original author's local path;
# set the PUFFERDRIVE_ROOT environment variable to point at a different checkout.
project_root = Path(os.environ.get("PUFFERDRIVE_ROOT", "/home/emerge/PufferDrive")).resolve()
if not project_root.is_dir():
    raise FileNotFoundError(
        f"PufferDrive checkout not found at {project_root}; "
        "set the PUFFERDRIVE_ROOT environment variable to its location."
    )

# Relative paths used below (e.g. resources/drive/binaries/...) resolve from the checkout root
os.chdir(project_root)

# Verify the working directory and that the map binaries are present
print(f"Working directory: {os.getcwd()}")
print(f"Resources exist: {os.path.exists('resources/drive/binaries/map_000.bin')}")

# Put the checkout first on sys.path so its pufferlib is imported rather than an installed copy
sys.path.insert(0, str(project_root))
# NOTE(review): argv is reset, presumably so pufferl's config/arg parsing does not see the
# Jupyter kernel's own command-line arguments — confirm.
sys.argv = [""]

import pufferlib.pufferl as pufferl

### Configs

In [None]:
# Notebook-wide settings, consumed when `args` is built in "Set up environment".
ENV_NAME: str = "puffer_drive"  # passed to pufferl.load_config / load_env / load_policy
NUM_ENVS: int = 1               # becomes args["vec"]["num_envs"]
MAX_AGENTS: int = 3             # becomes args["env"]["num_agents"]
BACKEND: str = "Serial"         # becomes args["vec"]["backend"]

### Dependencies

In [None]:
import pufferlib.ocean.wosac.evaluator as wosac_evaluator

### Set up environment


In [None]:
# Start from the environment's default config, then apply this notebook's overrides.
args = pufferl.load_config(ENV_NAME)
args["vec"] = dict(backend=BACKEND, num_envs=NUM_ENVS)

env_overrides = {
    "num_agents": MAX_AGENTS,
    "control_non_vehicles": True,
}
for key, value in env_overrides.items():
    args["env"][key] = value

# NOTE(review): the meaning of init_steps is defined by pufferl / the evaluator, not here;
# consider promoting 10 to a named constant in the config cell.
args["init_steps"] = 10

# The vectorized env is built first because load_policy takes it as an argument.
vecenv = pufferl.load_env(ENV_NAME, args)
policy = pufferl.load_policy(args, vecenv, ENV_NAME)

### Pipeline

In [None]:
from importlib import reload

# Re-execute these modules so source edits show up without restarting the kernel.
# Objects created before this cell (e.g. vecenv, policy) keep the previously loaded code.
for _module in (pufferl, wosac_evaluator):
    reload(_module)

In [None]:
evaluator = wosac_evaluator.WOSACEvaluator(args)

# Roll out the trained policy in the simulator to collect trajectories.
# The result is a dict whose per-field arrays (e.g. "x") are shaped
# [num_agents, num_rollouts, num_steps].
simulated_trajectories = evaluator.collect_simulated_trajectories(
    args,
    vecenv=vecenv,
    policy=policy,
)

# Quick summary of what came back
print(f"keys: {simulated_trajectories.keys()}")
x_shape = simulated_trajectories["x"].shape
print(f"shape of x: {x_shape}")
first_scenario_id = simulated_trajectories["scenario_id"][0]
print(f"scenario_id: {first_scenario_id}")

In [None]:
# Sanity check: Visualize rollouts
# Sanity check: scatter the (x, y) positions of the first agent (index 0 along the
# num_agents axis) across all rollouts and time steps. The low alpha makes regions
# visited by many rollouts appear darker.
fig, ax = plt.subplots()
ax.scatter(
    simulated_trajectories["x"][0, :, :],
    simulated_trajectories["y"][0, :, :],
    alpha=0.1,
)
ax.set(title="Simulated trajectories (agent 0)", xlabel="x", ylabel="y")
# Equal scaling so the spatial layout of the trajectories is not distorted
ax.set_aspect("equal");

In [None]:
# Inspect the "id" field returned by the rollout; it is displayed as this cell's output.
trajectory_ids = simulated_trajectories["id"]
trajectory_ids

In [None]:
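# TODO(2): Prepare ground-truth data, e.g.
# x_batch, y_batch, z_batch, heading_batch = evaluator.collect_ground_truth_data()

# TODO(3): Compute WOSAC metrics. The *_hat_batch inputs are not defined yet;
# presumably they are built from `simulated_trajectories` — confirm.
# results = evaluator.compute_metrics(x_hat_batch, y_hat_batch, z_hat_batch, heading_hat_batch)

# Show the results as the cell output (`return` is invalid outside a function):
# results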