In [None]:
import gymnasium as gym

# Side-effect import (name not referenced): presumably registers the custom
# "GridWorld-v0" environment, which must happen before gym.make() is called on it
# below. Without it this cell only works if the next cell was run first.
import src.gymnasium_env  # noqa: F401


def query_environment(name, **make_kwargs):
    """Print the main properties of a registered Gymnasium environment.

    The environment is instantiated to read its action and observation spaces,
    and is closed afterwards. Max episode steps, nondeterminism and reward
    threshold are read from the registry spec (``gym.spec(name)``), so they
    reflect the registered defaults rather than any ``make_kwargs`` override.

    Args:
        name: Registered environment id, e.g. ``"GridWorld-v0"``.
        **make_kwargs: Optional keyword arguments forwarded to ``gym.make``
            (e.g. ``size=5``) to inspect a non-default configuration.
    """
    env = gym.make(name, **make_kwargs)
    try:
        spec = gym.spec(name)
        print(f"Action Space: {env.action_space}")
        print(f"Observation Space: {env.observation_space}")
        print(f"Max Episode Steps: {spec.max_episode_steps}")
        print(f"Nondeterministic: {spec.nondeterministic}")
        print(f"Reward Threshold: {spec.reward_threshold}")
    finally:
        # Release any resources held by the environment (renderers, etc.).
        env.close()


query_environment("GridWorld-v0")

Action Space: Discrete(4)
Observation Space: Dict('agent': Box(0, 9, (2,), int64), 'goal': Box(0, 9, (2,), int64))
Max Episode Steps: None
Nondeterministic: False
Reward Threshold: None


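Bisogna scrivere una funzione chiamata `train_dqn_agent`, in cui:
- la funzione riceve due parametri, `num_episodes` e `grid_size`;
- inizialmente istanzia l'`env` e l'`agent`;
- a ogni episodio resetta l'`env`;
- in ogni episodio, finché non si raggiunge uno stato terminale:
    - l'`agent` seleziona un'azione tramite una politica epsilon-greedy;
    - viene chiamata la funzione `step(s, a, r, s', a')` dell'`agent`, che calcola la loss, ne calcola il gradiente e aggiorna la `q_network`.

Nota: nell'implementazione qui sotto la funzione si chiama `train_dql_agent`, legge `n_episodes` e `grid_size` dalla configurazione di wandb invece di riceverli come parametri, e il passo di apprendimento è `agent.learn(s, a, r, s', done)`.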

In [None]:
import gymnasium as gym
import jax.nn as nn
import jax.numpy as jnp
from loguru import logger
from tqdm.notebook import tqdm

import src.gymnasium_env
import wandb
from src.agent import DeepQLearningAgent
from src.config import init_wandb

%load_ext autoreload
%autoreload 2


def onehot_agent_goal_positions(agent, goal, grid_size=10):
    """Encode the agent and goal grid positions as one concatenated one-hot vector.

    Each position ``pos`` is flattened row-major with ``pos[1]`` as the row and
    ``pos[0]`` as the column, i.e. to index ``pos[1] * grid_size + pos[0]``.

    Args:
        agent: Agent position, a length-2 sequence/array of ints.
        goal: Goal position, same format as ``agent``.
        grid_size: Side length of the square grid (default 10).

    Returns:
        An int32 array of length ``2 * grid_size * grid_size``: the agent's
        one-hot block followed by the goal's.
    """
    cells_per_grid = grid_size * grid_size
    grid_shape = (grid_size, grid_size)

    def _one_hot_cell(pos):
        # pos[1] selects the row, pos[0] the column.
        flat_index = jnp.ravel_multi_index((pos[1], pos[0]), grid_shape)
        return nn.one_hot(flat_index, cells_per_grid, dtype=jnp.int32)

    encoded = [_one_hot_cell(pos) for pos in (agent, goal)]
    return jnp.concatenate(encoded)


def train_dql_agent():
    """Train a ``DeepQLearningAgent`` on ``GridWorld-v0`` and log the run to Weights & Biases.

    Hyper-parameters come from the wandb config returned by ``init_wandb()``.
    This function reads ``grid_size``, ``n_episodes``, ``max_steps_per_episode``
    and, if present, ``seed``. States are given to the agent as the concatenated
    one-hot encoding of the agent and goal cells (``onehot_agent_goal_positions``).

    Logged metrics:
        - ``epsilon`` after every environment step, plotted against
          ``global_step`` (the agent's ``step`` counter);
        - ``train/episode_reward`` after every episode, plotted against ``episode``.

    The environment is closed and the wandb run finished even when training
    raises. On failure the run is finished with ``exit_code=1``, so a crash no
    longer leaves a dangling wandb run in the kernel.
    """
    config = init_wandb()
    env = None
    failed = True
    try:
        # Give each metric its own x-axis because per-step and per-episode
        # wandb.log calls are interleaved.
        wandb.define_metric("train/episode_reward", step_metric="episode")
        # Must match the key passed to wandb.log below ("epsilon").
        wandb.define_metric("epsilon", step_metric="global_step")

        logger.info("Starting training with config: {}", config)

        grid_size = config.grid_size
        env = gym.make("GridWorld-v0", size=grid_size)

        # One one-hot block of grid_size * grid_size cells each for agent and goal.
        state_dim = 2 * grid_size * grid_size
        action_dim = env.action_space.n
        logger.debug("State dim: {}, Action dim: {}", state_dim, action_dim)

        # Upper bound on the total number of environment steps in the run;
        # presumably the horizon of the agent's epsilon decay (confirm in src.agent).
        decay_steps = int(config.n_episodes * config.max_steps_per_episode)
        # NOTE(review): the last recorded run raised
        # "AttributeError: ... no attribute 'max_n_steps'" after the debug log above.
        # This function never reads `max_n_steps` and the logged config has no such
        # key, so DeepQLearningAgent presumably does. Align the key name (e.g. with
        # `max_steps_per_episode`) in the config or in the agent. TODO confirm.
        agent = DeepQLearningAgent(config, state_dim, action_dim, decay_steps)

        # Seed only the first reset; later resets continue the same RNG stream.
        seed = getattr(config, "seed", None)

        for ep in tqdm(range(1, config.n_episodes + 1), desc="Episodes"):
            obs, _ = env.reset(seed=seed if ep == 1 else None)
            state = onehot_agent_goal_positions(obs["agent"], obs["goal"], grid_size)
            ep_reward = 0.0

            for _ in range(config.max_steps_per_episode):
                action = agent.act(state)
                next_obs, reward, terminated, truncated, _ = env.step(action)
                next_state = onehot_agent_goal_positions(
                    next_obs["agent"], next_obs["goal"], grid_size
                )

                # Pass `terminated` (not `truncated`) as the done flag: only a true
                # terminal state should cut off bootstrapping from next_state,
                # assuming the agent uses this flag that way.
                agent.learn(state, action, reward, next_state, terminated)
                ep_reward += reward
                state = next_state

                wandb.log({"global_step": agent.step, "epsilon": agent.eps})

                # The episode ends on either termination or truncation.
                if terminated or truncated:
                    break

            wandb.log(
                {
                    "episode": ep,
                    "train/episode_reward": ep_reward,
                }
            )
        failed = False
    finally:
        if env is not None:
            env.close()
        if failed:
            # Mark the wandb run as failed when unwinding from an exception.
            wandb.finish(exit_code=1)
        else:
            wandb.finish()


train_dql_agent()


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


[32m2025-05-03 14:37:42.711[0m | [1mINFO    [0m | [36m__main__[0m:[36mtrain_dql_agent[0m:[36m32[0m - [1mStarting training with config: {'batch_size': 32, 'lr': 0.001, 'gamma': 0.99, 'eps_start': 1.0, 'eps_end': 0.05, 'hidden_dim': 128, 'n_episodes': 500, 'max_steps_per_episode': 32, 'grid_size': 5, 'seed': 0, 'target_update_frequency': 4}[0m
[32m2025-05-03 14:37:42.711[0m | [34m[1mDEBUG   [0m | [36m__main__[0m:[36mtrain_dql_agent[0m:[36m39[0m - [34m[1mState dim: 50, Action dim: 4[0m


AttributeError: <class 'wandb.sdk.wandb_config.Config'> object has no attribute 'max_n_steps'