In [7]:
import warnings; warnings.filterwarnings('ignore')

import itertools
import time
import numpy as np
from pprint import pprint

from collections import defaultdict

import gymnasium as gym
from gridworld.envs import SiblingGridWorldEnv
from gymnasium.envs.registration import register
from tabulate import tabulate
import tqdm as tqdm

import random
import matplotlib
import matplotlib.pyplot as plt
SEEDS = (12, 34, 56, 78, 90)

%matplotlib inline

from utils import *
from sibling_gw_agent import SiblingGWAgent

In [2]:
plt.style.use('fivethirtyeight')
params = {
    'figure.figsize': (15, 8),
    'font.size': 24,
    'legend.fontsize': 20,
    'axes.titlesize': 28,
    'axes.labelsize': 24,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20
}
plt.rcParams.update(params)
np.set_printoptions(precision=3, suppress=True)

In [3]:
def sibling():
    env = SiblingGridWorldEnv(P_gridworld)
#     env = RelativePositionenv)
    return env

register(
    id='SiblingGridWorld-v0',
    entry_point=sibling,
    max_episode_steps=300,
)

env = gym.make('SiblingGridWorld-v0')
env = env.unwrapped
# env.render_mode = 'human'
obs, info = env.reset()

In [4]:
n_episodes = 30_000
agent = SiblingGWAgent(env, n_episodes=n_episodes)

In [5]:
for episode in tqdm(range(n_episodes)):
    agent.episode = episode
    state, info = env.reset()
    done = False
    while not done:
        action = agent.select_action(state)
        action = np.unravel_index(action, env.action_space.nvec, order='F')
        next_state, reward, terminated, truncated, info = agent.env.step(action)
        
        # update the agent
        agent.update(state, action, reward, terminated, next_state)

        # update if the environment is done or truncated
        done = terminated or truncated
        state = next_state

100%|██████████| 3000/3000 [00:52<00:00, 56.61it/s] 


In [25]:
env.render_mode = 'human'
obs, info = env.reset()
obs

array([ 1,  2, 17])

In [26]:
for i in range(300):
    obs_idx = np.ravel_multi_index(obs, env.observation_space.nvec, order='F')
    action = np.argmax(agent.Q[obs_idx])
    action = np.unravel_index(action, env.action_space.nvec, order='F')
    obs, rew, done, trunc, info = env.step(action)
    time.sleep(0.05)
    if done:
        print(f"Finished in {i} steps.")
        break

Finished in 10 steps.
