In [None]:
from stable_baselines3 import TD3,SAC
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env

import gymnasium as gym
from gymnasium.wrappers import RescaleAction
import torch
import nav2d        # Have to import the nav2d Python script, else we can't make env
import numpy as np
import os, re, json, time
from datetime import datetime
from tqdm import tqdm

In [None]:
import pyautogui

# Evaluate a trained TD3 policy on the Nav2D environment.
model_load = TD3.load('results/Nav2D_TD3_SB3_results/result_00006/run_60')

# testing parameters
n_test = 10          # number of evaluation episodes
success_count = 0    # episodes in which the goal reward was received

# environment options
width = 1920
height = 1080
default_camera_config = {"azimuth" : 90.0, "elevation" : -90.0, "distance" : 3, "lookat" : [0.0, 0.0, 0.0]}
# Only open an interactive viewer for short evaluations; longer runs render off-screen.
render_mode = "human" if n_test<=10 else "rgb_array"
camera_id = 2        # NOTE(review): overlaps with DEFAULT_CAMERA below — confirm which one the env honours

DEFAULT_CAMERA = "overhead_camera"
ENABLE_FRAME = True                     # enable the body frames
RENDER_EVERY_FRAME = True              # similar sim speed as MuJoCo rendering when set to False, else slower

test_env = gym.make("Nav2D-v0", render_mode=render_mode, 
                    width=width,height=height,
                    default_camera_config=default_camera_config,
                    camera_id=camera_id,
                    max_episode_steps=1_000,
                    is_eval=False
                    )
# Initial reset; presumably this also opens the viewer window in "human" mode so the
# key presses below have a target — TODO confirm for the Nav2D env.
obs, info = test_env.reset()

core_env = test_env.unwrapped
rew_goal = core_env.rew_goal_scale   # reward value used below to detect reaching the goal

# pyautogui sends real keystrokes to whichever window currently has focus. They are meant
# for the MuJoCo viewer, which only exists in "human" render mode; in "rgb_array" mode they
# would be typed into the notebook/editor instead, so skip them there.
if render_mode == "human":
    if DEFAULT_CAMERA == "overhead_camera":
        # NOTE(review): assumes a single Tab switches the viewer to the overhead camera.
        pyautogui.press('tab')
    if ENABLE_FRAME:
        pyautogui.press('e')    # toggle body-frame display
    if not RENDER_EVERY_FRAME:
        pyautogui.press('d')    # toggle render-every-frame

# Roll out the deterministic policy for n_test episodes and report the success rate.
# An episode counts as a success (at most once) if any step pays out the goal reward.
# NOTE(review): an explicit success flag from the env (e.g. in `info`) would be more robust
# than matching the reward value — confirm whether Nav2D provides one.
try:
    for eps in range(n_test):
        obs, _ = test_env.reset()
        done = False
        reached_goal = False

        while not done:
            action, _ = model_load.predict(obs, deterministic=True)
            obs, rew, term, trunc, info = test_env.step(action)
            print(f"action: {action} | rew_appr: {info.get('rew_approach',-10.0):10.6f}                      ", end="\r")
            done = term or trunc

            # --- detect the success; isclose instead of == so float round-off cannot hide it
            if np.isclose(rew, rew_goal):
                reached_goal = True

        # No reset here: the next iteration's reset() starts the new episode.
        success_count += int(reached_goal)

    success_rate = success_count / n_test * 100 if n_test else 0.0
    print(f"\rSuccess rate out of {n_test} runs is {success_rate:.2f}%             ")
finally:
    # Close the viewer/env even if the rollout raises or is interrupted.
    test_env.close()