In [None]:
import gymnasium as gym
from gym.wrappers import TimeLimit

import src.gymnasium_env  # noqa: F401
from src.config import build_cfg
from src.gymnasium_env.wrappers import OneHotFlatPosWrapper, OneHotGridPosWrapper
from src.train import train_dql_agent
from src.utils import query_environment

%load_ext autoreload
%autoreload 2


def make_wrapped_env(
    env: gym.Env, grid_size: int, max_steps_per_episode: int, flat: bool
) -> gym.Env:
    """Wrap ``env`` with a one-hot position wrapper and an episode step limit.

    Args:
        env: Environment to wrap.
        grid_size: Forwarded to the one-hot wrapper as ``grid_size``.
        max_steps_per_episode: Forwarded to ``TimeLimit`` as
            ``max_episode_steps``.
        flat: If true, ``OneHotFlatPosWrapper`` is applied; otherwise
            ``OneHotGridPosWrapper`` is applied.

    Returns:
        ``env`` wrapped first by the selected one-hot wrapper and then by
        ``TimeLimit``.
    """
    one_hot_wrapper = OneHotFlatPosWrapper if flat else OneHotGridPosWrapper
    env = one_hot_wrapper(env, grid_size=grid_size)
    # ``gym`` is the gymnasium alias in this notebook. Take TimeLimit from
    # gymnasium itself rather than the legacy ``gym`` package so the wrapper
    # chain stays within one library and the result is a gymnasium Env.
    return gym.wrappers.TimeLimit(env, max_episode_steps=max_steps_per_episode)

In [None]:
# Query the registered "GridWorld-v0" environment with the project helper
# (presumably reports its spaces/spec -- confirm in src.utils).
query_environment("GridWorld-v0")

In [None]:
# Build the experiment configuration from the project's GridWorld YAML file.
cfg = build_cfg("src/configs/gridworld.yaml")

# Training environment. ``reset_success_count`` is forwarded to the GridWorld
# constructor; its exact meaning is defined in src.gymnasium_env (TODO confirm).
env = gym.make("GridWorld-v0", size=cfg.grid_size, reset_success_count=500)

# Separate evaluation environment created with rgb_array rendering.
eval_env = gym.make(
    "GridWorld-v0",
    size=cfg.grid_size,
    render_mode="rgb_array",
)

In [None]:
# Train on the flat one-hot observation (OneHotFlatPosWrapper); the state
# dimension passed to the trainer is 2 * grid_size**2.
train_dql_agent(
    config=cfg,
    env=make_wrapped_env(
        env,
        grid_size=cfg.grid_size,
        max_steps_per_episode=cfg.max_steps_per_episode,
        flat=True,
    ),
    state_dim=2 * cfg.grid_size**2,
    eval_env=make_wrapped_env(
        eval_env,
        grid_size=cfg.grid_size,
        max_steps_per_episode=cfg.max_steps_per_episode,
        flat=True,
    ),
)

In [None]:
# Train on the grid-shaped one-hot observation (OneHotGridPosWrapper); the
# state dimension passed to the trainer is (grid_size, grid_size, 2).
train_dql_agent(
    config=cfg,
    env=make_wrapped_env(
        env,
        grid_size=cfg.grid_size,
        max_steps_per_episode=cfg.max_steps_per_episode,
        flat=False,
    ),
    state_dim=(cfg.grid_size, cfg.grid_size, 2),
    eval_env=make_wrapped_env(
        eval_env,
        grid_size=cfg.grid_size,
        max_steps_per_episode=cfg.max_steps_per_episode,
        flat=False,
    ),
)