In [1]:
import warnings; warnings.filterwarnings('ignore')

import itertools
import time
import numpy as np
from pprint import pprint

from collections import defaultdict

import gymnasium as gym
from gridworld.envs import SiblingGridWorldEnv
from gymnasium.envs.registration import register
from gymnasium.wrappers import TimeLimit
from tabulate import tabulate
import tqdm as tqdm

import torch
from torch.utils.tensorboard import SummaryWriter

import random
import matplotlib
import matplotlib.pyplot as plt
SEEDS = (12, 34, 56, 78, 90)

%matplotlib inline

from utils import *
from sibling_gw_agent import SiblingGWAgent

In [2]:
# Plot styling shared by every figure in this notebook.
plt.style.use('fivethirtyeight')
params = {
    'figure.figsize': (15, 8),
    'font.size': 24,
    'legend.fontsize': 20,
    'axes.titlesize': 28,
    'axes.labelsize': 24,
    'xtick.labelsize': 20,
    'ytick.labelsize': 20,
}
plt.rcParams.update(params)
np.set_printoptions(precision=3, suppress=True)

# TensorBoard custom layout: group the logged scalars into two multi-line
# charts under the "Training" category.
layout = {
    "Training": {
        "Q_gw": ["Multiline", ["Q_gw/center", "Q_gw/corner"]],
        "Q_bandit": ["Multiline", ["Q_bandit/zero", "Q_bandit/twenty", "Q_bandit/correct"]],
    },
}

# SummaryWriter ignores `comment` whenever an explicit log_dir is given, so the
# run is identified by LOG_DIR alone (the former comment="Sibling_GW" was a
# no-op).
LOG_DIR = 'runs/sibling_gw'
writer = SummaryWriter(LOG_DIR)
writer.add_custom_scalars(layout)

In [47]:
def sibling(P=None):
    """Gymnasium entry point that builds the sibling grid-world environment.

    Args:
        P: configuration forwarded to ``SiblingGridWorldEnv``. Defaults to
            ``P_gridworld``, which is expected to come from ``from utils import *``.

    Returns:
        A bare ``SiblingGridWorldEnv``. The 100-step episode limit is applied
        by ``gym.make`` through ``max_episode_steps`` in the registration
        below. Wrapping in ``TimeLimit`` here as well stacked two identical
        time limits.
    """
    return SiblingGridWorldEnv(P_gridworld if P is None else P)


register(
    id='SiblingGridWorld-v0',
    entry_point=sibling,
    max_episode_steps=100,
)

env = gym.make('SiblingGridWorld-v0')
# `.unwrapped` strips every wrapper added by gym.make, including the TimeLimit.
# Later cells need direct access to private attributes such as
# `_true_world_idx` and `_agent_location`. Any truncation they see therefore
# comes from the raw env only.
env = env.unwrapped
# Set env.render_mode = 'human' to watch episodes being played.
obs, info = env.reset(options={'randomize_world': True})
print(env._true_world_idx)

1


In [48]:
# Number of episodes the training cell below runs per execution.
n_episodes = 1

# Step-size (alpha) and exploration (epsilon) settings: initial value, floor,
# and decay ratio for each.
agent_hparams = dict(
    gamma=1.0,
    init_alpha=0.5,
    min_alpha=0.05,
    alpha_decay_ratio=0.5,
    init_epsilon=1.0,
    min_epsilon=0.1,
    epsilon_decay_ratio=0.9,
)
agent = SiblingGWAgent(env, n_episodes=n_episodes, **agent_hparams)

In [49]:
# Train for `n_episodes` more episodes starting at agent.episode, logging the
# greedy value of a few probe states/arms to TensorBoard.
LOG_EVERY = 1  # log the probes every LOG_EVERY episodes

# Linear indices of the grid-world probe cells. They do not change during
# training (assumes state_multi_to_lin is a pure index mapping), so compute
# them once instead of on every logging step.
gw_probe_states = {
    "Q_gw/center": agent.state_multi_to_lin(np.array([2, 2])),
    "Q_gw/corner": agent.state_multi_to_lin(np.array([0, 0])),
}

for episode in tqdm(range(agent.episode, agent.episode + n_episodes)):
    agent.episode = episode
    state, info = env.reset()
    done = False

    while not done:
        # Alternative policy: agent.select_action(state)
        action = agent.custom_action(state)[1]
        next_state, reward, terminated, truncated, info = env.step(action)

        # Only `terminated` is passed to the update, presumably so a time-out
        # is not treated as a true terminal state.
        agent.update(state, action, reward, terminated, next_state)

        # The episode ends on either termination or truncation.
        done = terminated or truncated
        state = next_state

    if episode % LOG_EVERY == 0:
        for tag, lin_state in gw_probe_states.items():
            writer.add_scalar(tag, np.max(agent.Q_gw[lin_state]), episode)
        # Read the true world index at log time, in case a reset changed it.
        bandit_probe_arms = {
            "Q_bandit/zero": 0,
            "Q_bandit/twenty": 20,
            "Q_bandit/correct": env._true_world_idx,
        }
        for tag, arm in bandit_probe_arms.items():
            writer.add_scalar(tag, np.max(agent.Q_bandit[arm]), episode)

writer.flush()
writer.close()

# `env` is deliberately not closed here: the evaluation and rollout cells
# below keep resetting, stepping and rendering it.
print(env.num_moves)

100%|██████████| 1/1 [00:00<00:00, 382.97it/s]

6





In [45]:
# Current bandit value estimates, indexed by candidate world index.
q_bandit = agent.Q_bandit
q_bandit

array([-1.  , -0.5 , -1.  , -0.5 , -1.  , -1.  , -0.4 ,  0.  , -1.  ,
        0.  , -1.  ,  0.  , -0.4 ,  0.  , -1.  ,  0.  , -1.  ,  0.  ,
       -0.25, -0.25, -1.  , -1.  , -1.  , -1.  ], dtype=float32)

In [18]:
# Indices of all bandit arms tied for the highest value (several can tie).
max_indices = np.where(agent.Q_bandit == np.max(agent.Q_bandit))[0]
# Contains the true world's index if it is among the maximisers, else empty.
# This reuses max_indices instead of recomputing the same np.where expression.
max_indices[max_indices == env._true_world_idx]

array([19])

In [None]:
# Estimate how many steps the custom_action policy needs to finish when every
# episode starts from the centre cell (2, 2).
env.render_mode = None

max_ep_len = 20          # step budget per evaluation episode
n_eval_episodes = 999    # episodes are numbered 1..n_eval_episodes
start_cell = np.array([2, 2])


def reset_at_start():
    """Reset `env`, move the agent to `start_cell`, and return the new observation."""
    env.reset()
    env._agent_location = start_cell.copy()  # fresh array each episode
    return env._agent_location


avg_steps = 100      # replaced by the first sample, since 1/e == 1 when e == 1
alpha = 0.1          # smoothing factor of the exponential moving average
running_avg = [10]   # seed value of the smoothed curve plotted below

obs = reset_at_start()
for e in tqdm(range(1, n_eval_episodes + 1)):
    # Allow exactly max_ep_len steps. The old range(1, max_ep_len) stopped one
    # step early, unlike the greedy rollout cell.
    for i in range(1, max_ep_len + 1):
        action = agent.custom_action(obs)[1]
        obs, rew, done, trunc, info = env.step(action)
        # Also stop on truncation: stepping a truncated episode is invalid.
        if done or trunc or i == max_ep_len:
            avg_steps += 1 / e * (i - avg_steps)  # incremental cumulative mean
            running_avg.append((1 - alpha) * running_avg[-1] + alpha * avg_steps)
            break
    obs = reset_at_start()

print(f"Cumulative average estimate: {avg_steps}")

In [None]:
# Smoothed (EMA) cumulative average of steps-to-finish during evaluation.
fig, ax = plt.subplots()
ax.plot(running_avg)
ax.set(
    title="Evaluation: smoothed average steps to finish",
    xlabel="Evaluation episode",
    ylabel="Steps",
)
plt.show()

In [None]:
# Single greedy rollout from the centre cell (2, 2) using agent.greedy_action.
# Set env.render_mode = 'human' before running this cell to watch it.
obs, info = env.reset()
env._agent_location = np.array([2, 2])
obs = env._agent_location
env._render_frame()
print(env._true_world_idx)

for i in range(1, max_ep_len + 1):
    action = agent.greedy_action(obs)
    obs, rew, done, trunc, info = env.step(action)
    if done or trunc:
        print(f"Finished in {i} steps.")
        break
else:
    # for/else: reached only when the step budget runs out without a break,
    # which previously went unreported.
    print(f"Did not finish within {max_ep_len} steps.")

In [None]:
# Best bandit value, the arm that attains it, and the true world index to compare.
q_best = np.max(agent.Q_bandit)
arm_best = np.argmax(agent.Q_bandit)
(q_best, arm_best, env._true_world_idx)

In [None]:
# Dump the raw bandit tables: value estimates first, then N_bandit
# (presumably per-arm visit counts).
for bandit_table in (agent.Q_bandit, agent.N_bandit):
    print(bandit_table)

In [None]:
# Release resources once the notebook is finished with them. Both calls are
# safe even if the writer or env was already closed earlier.
writer.close()
env.close()