In [11]:
import numpy as np
from itertools import permutations
import copy
import time

import gymnasium as gym
from gymnasium import spaces
from gridworld.envs import GridWorldEnv, SiblingGridWorldEnv
from gymnasium.wrappers import FlattenObservation
from gridworld.wrappers import RelativePosition, InvertedReward
from gymnasium.envs.registration import register

In [2]:
# Define P[s][a] for the vanilla GridWorld
def _build_gridworld_P(n_rows=5, n_cols=5):
    """Build the deterministic transition table P[s][a] for an n_rows x n_cols grid.

    States are numbered row-major (state = row * n_cols + col) and the last
    state (n_rows * n_cols - 1) is the absorbing goal.

    Actions shift the state index as follows:
        0: +1       (next column)
        1: -n_cols  (previous row)
        2: -1       (previous column)
        3: +n_cols  (next row)
    A move that would leave the grid keeps the agent in place.

    Every entry is a one-element list ``[(prob, next_state, reward, done)]``:
    reward is -1 per step outside the goal and 0 inside it, and ``done`` is
    True exactly when ``next_state`` is the goal, so every transition that
    lands on the goal is flagged as terminal.

    Returns:
        dict[int, dict[int, list[tuple[float, int, int, bool]]]]
    """
    goal = n_rows * n_cols - 1
    # (d_row, d_col) indexed by action id.
    moves = ((0, 1), (-1, 0), (0, -1), (1, 0))

    table = {}
    for state in range(n_rows * n_cols):
        row, col = divmod(state, n_cols)
        table[state] = {}
        for action, (d_row, d_col) in enumerate(moves):
            if state == goal:
                # The goal is absorbing: zero-reward self-loop, already done.
                table[state][action] = [(1.0, goal, 0, True)]
                continue
            new_row, new_col = row + d_row, col + d_col
            if 0 <= new_row < n_rows and 0 <= new_col < n_cols:
                next_state = new_row * n_cols + new_col
            else:
                # Bumping into a wall leaves the agent where it is.
                next_state = state
            table[state][action] = [(1.0, next_state, -1, next_state == goal)]
    return table


P = _build_gridworld_P()

In [3]:
def value_iteration(P, gamma=1.0, theta=1e-10):
    """Run value iteration on a tabular MDP and return its greedy policy.

    Args:
        P: Transition table where ``P[s][a]`` is a list of
            ``(prob, next_state, reward, done)`` tuples. States must be keyed
            ``0..len(P)-1`` and actions ``0..len(P[s])-1``; the Q-table width
            is taken from ``len(P[0])``.
        gamma: Discount factor applied to the value of the next state.
        theta: Convergence threshold on the largest absolute change of V
            between two sweeps.

    Returns:
        A tuple ``(Q, V, pi)``:
            Q: float64 array of shape ``(len(P), len(P[0]))`` from the last sweep.
            V: float64 array of shape ``(len(P),)``, the values Q was computed from.
            pi: callable mapping a state index to its greedy action
                (``np.argmax`` over Q, so ties go to the lowest action index);
                raises ``KeyError`` for an unknown state.

    Note:
        The loop has no iteration cap; it only stops once the update is
        smaller than ``theta``.
    """
    n_states = len(P)
    n_actions = len(P[0])
    V = np.zeros(n_states, dtype=np.float64)
    while True:
        Q = np.zeros((n_states, n_actions), dtype=np.float64)
        for s in range(n_states):
            for a in range(len(P[s])):
                for prob, next_state, reward, done in P[s][a]:
                    # A terminal transition contributes no bootstrapped value.
                    Q[s][a] += prob * (reward + gamma * V[next_state] * (not done))
        V_next = np.max(Q, axis=1)
        if np.max(np.abs(V - V_next)) < theta:
            break
        V = V_next

    # Build the state -> action lookup once instead of on every pi(s) call.
    greedy_action = dict(enumerate(np.argmax(Q, axis=1)))

    def pi(s):
        return greedy_action[s]

    return Q, V, pi

In [4]:
def sibling():
    """Entry point for the registered env: a SiblingGridWorldEnv built from the notebook-level table P."""
    # NOTE: an extra RelativePosition wrapping step was left disabled here.
    return SiblingGridWorldEnv(P)

# Make the env available to gym.make under this id; `sibling` is the factory
# and episodes are capped at 300 steps via max_episode_steps.
register(id="SiblingGridWorld-v0", entry_point=sibling, max_episode_steps=300)

In [12]:
# Instantiate the registered env and keep only the innermost (unwrapped) env.
# NOTE(review): unwrapping presumably also drops the step-limit wrapper implied
# by max_episode_steps at registration — confirm that is intended.
env = gym.make('SiblingGridWorld-v0').unwrapped
# Render mode is set on the instance after construction.
env.render_mode = 'human'
obs, info = env.reset()
obs

array([ 3,  1, 14])

In [13]:
# Initial plan: Q-values, state values and greedy policy computed from the
# env's gw_P table (presumably the gridworld transition table — confirm).
Q, V, pi = value_iteration(env.gw_P)

In [14]:
# Roll out the greedy policy for at most 20 steps, re-planning after every step.
for _ in range(20):
    # The observation is mapped to a flat state index as 5*obs[1] + obs[0]
    # (presumably obs[0]=column, obs[1]=row on the 5-wide grid — confirm).
    # The second action component is held at 0 (presumably the sibling's
    # action — confirm against the env).
    obs, rew, done, trunc, info = env.step([pi(5*obs[1]+obs[0]), 0])
    time.sleep(0.2)  # pace the loop; render_mode was set to 'human' earlier
    # An episode is over when it is either terminated or truncated.
    if done or trunc:
        break
    # Re-plan on the env's current table only while the episode is still running.
    Q, V, pi = value_iteration(env.cur_P)

In [15]:
# Shut the environment down (releases whatever resources it holds, e.g. a render window).
env.close()

In [10]:
# Scratch check: draw two integers from [0, 5).
# The generator is created without a seed, so the values differ between runs.
rng = np.random.default_rng()
rng.integers(low=0, high=5, size=2, dtype=int)

array([1, 2])