In [1]:
%matplotlib inline

import jax
import jumanji
from jumanji.wrappers import AutoResetWrapper
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import pickle

## RL FUNCTIONS

In [2]:
def state_to_index(x, y, columns):
        return x * columns + y

def index_to_state(idx, columns):
    x = idx // columns
    y = idx % columns
    return x, y

def generate_transition_matrix(walls):
    rows = walls.shape[0]  # Number of rows
    columns = walls.shape[1]  # Number of columns
    num_states = rows * columns
    num_actions = 4  # up, right, down, left

    # Transition matrix P(s'|s,a)
    P = np.zeros((num_states, num_states, num_actions))

    # Actions: up (0), right (1), down (2), left (3)
    actions = {
        0: (-1, 0),  # Up
        1: (0, 1),   # Right
        2: (1, 0),   # Down
        3: (0, -1)   # Left
    }

    # Populate the transition matrix
    for x in range(rows):
        for y in range(columns):
            current_state = state_to_index(x, y, columns)
            for action, (dx, dy) in actions.items():
                new_x, new_y = x + dx, y + dy
                if 0 <= new_x < rows and 0 <= new_y < columns and not walls[new_x, new_y]:
                    next_state = state_to_index(new_x, new_y, columns)
                else:
                    next_state = current_state  # Stay in place on wall or out-of-bounds
                
                P[next_state, current_state, action] = 1

    return P

def generate_reward_function(target_position, walls):
    rows = walls.shape[0]  # Number of rows
    columns = walls.shape[1]  # Number of columns
    num_states = rows * columns
    num_actions = 4  # up, right, down, left
    rewards = np.zeros(rows * columns)
    target_index = state_to_index(target_position[0], target_position[1], columns)
    rewards[target_index] = 1
    return rewards

def value_iteration(P, reward, discount, precision=1e-5):
    state_size = P.shape[0]
    action_size = P.shape[2]
    value = np.zeros(state_size)
    prev_value = np.zeros(state_size)
    pi_vi = np.zeros((action_size, state_size))
    for i in range(300):
        prev_value = value.copy()
        for state in range(state_size):
            value[state] = np.max(
                [
                    reward[state] + discount * np.sum(P[:, state, action] * value)
                    for action in range(action_size)
                ]
            )

    for state in range(state_size):
        values = np.array(
            [
                reward[state] + discount * np.sum(P[:, state, action] * value)
                for action in range(action_size)
            ]
        )
        values = np.round(values, decimals=5)
        maximum = np.max(values)
        indices = np.argwhere(values == maximum).flatten()

        # chosen_index = np.random.choice(indices)
        # pi_vi[chosen_index, state] = 1

        # pi_vi[indices, state] = 1/(len(indices))

        best_action = np.argmax([reward[state] + discount * np.sum(P[:,state,action] * value) for action in range(action_size)])
        pi_vi[best_action, state] = 1

    return pi_vi

def plot_policy(policy, grid, title):
    n, m = grid.shape
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(grid, cmap='gray')

    for i in range(n):
        for j in range(m):
            if grid[i, j]:
                continue  # Skip walls
            state = i * m + j
            if policy[0, state] > 0:  # Up
                # print("up")
                ax.arrow(j, i, 0, -0.5, head_width=0.2, head_length=0.2, fc='red', ec='red')
            if policy[1, state] > 0:  # Right
                ax.arrow(j, i, 0.5, 0, head_width=0.2, head_length=0.2, fc='green', ec='green')
                # print("right")
            if policy[2, state] > 0:  # Down
                ax.arrow(j, i, 0, 0.5, head_width=0.2, head_length=0.2, fc='blue', ec='blue')
                # print("down")
            if policy[3, state] > 0:  # Left
                # print("left")
                ax.arrow(j, i, -0.5, 0, head_width=0.2, head_length=0.2, fc='yellow', ec='yellow')
    
    ax.set_xticks(np.arange(-0.5, m, 1))
    ax.set_yticks(np.arange(-0.5, n, 1))
    plt.savefig(title)

In [10]:
env = jumanji.make("Maze-v0")  # Create a Snake environment
# env = AutoResetWrapper(env)     # Automatically reset the environment when an episode terminates

In [11]:
batch_size = 50
rollout_length = 1000
num_actions = env.action_spec.num_values

In [12]:
random_key = jax.random.PRNGKey(1)
key1, key2 = jax.random.split(random_key)

In [14]:
def step_fn(state, key):
  action = jax.random.randint(key=key, minval=0, maxval=num_actions, shape=())
  new_state, timestep = env.step(state, action)
  return new_state, {
                      "state": [state.agent_position],
                      "action": action, 
                      "next_state": [new_state.agent_position],
                      "reward": timestep.reward,
                      "whole_timestep": timestep
                    }

def run_n_steps(state, key, n):
  random_keys = jax.random.split(key, n)
  state, rollout = jax.lax.scan(step_fn, state, random_keys)
  return rollout

# Instantiate a batch of environment states
keys_init = jax.random.split(key1, batch_size)
state, timestep = jax.vmap(env.reset)(keys_init)

# Collect a batch of rollouts
keys_rollout = jax.random.split(key2, batch_size)
rollout = jax.vmap(run_n_steps, in_axes=(0, 0, None))(state, keys_rollout, rollout_length)

In [15]:
state

State(agent_position=Position(row=Array([0, 0, 0, 8, 4, 3, 2, 5, 3, 9, 8, 2, 3, 6, 4, 6, 2, 9, 4, 5, 0, 9,
       1, 9, 5, 7, 0, 4, 0, 1, 0, 2, 9, 0, 4, 7, 2, 8, 6, 0, 4, 9, 6, 3,
       0, 7, 0, 0, 6, 2], dtype=int32), col=Array([3, 6, 5, 1, 2, 0, 7, 2, 6, 0, 0, 0, 2, 5, 9, 7, 2, 6, 2, 2, 9, 0,
       8, 6, 2, 2, 0, 8, 4, 0, 7, 2, 2, 9, 3, 2, 6, 9, 8, 0, 0, 6, 1, 0,
       0, 4, 8, 7, 6, 3], dtype=int32)), target_position=Position(row=Array([4, 6, 4, 2, 2, 4, 3, 6, 2, 5, 2, 6, 0, 6, 1, 0, 2, 2, 0, 6, 3, 8,
       6, 2, 4, 0, 8, 9, 5, 0, 4, 1, 2, 2, 1, 0, 8, 0, 8, 6, 8, 7, 7, 0,
       2, 2, 0, 0, 7, 6], dtype=int32), col=Array([7, 3, 5, 9, 9, 0, 6, 8, 2, 2, 3, 0, 0, 8, 6, 2, 1, 8, 6, 4, 2, 7,
       6, 2, 2, 7, 4, 0, 6, 6, 8, 2, 9, 4, 2, 6, 6, 4, 0, 4, 0, 0, 6, 8,
       2, 0, 9, 6, 6, 4], dtype=int32)), walls=Array([[[False,  True, False, ...,  True, False,  True],
        [False,  True, False, ...,  True, False,  True],
        [False, False, False, ..., False, False, False],
      

In [9]:
data = []
for k in range(batch_size):
    walls = rollout["whole_timestep"].observation.walls[k][0]
    target_position = np.array([rollout["whole_timestep"].observation.target_position.row[k][0], rollout["whole_timestep"].observation.target_position.col[k][0]])

    P = generate_transition_matrix(walls)
    r = generate_reward_function(target_position, walls)
    pi_opt = value_iteration(P, r, 0.99999)
    plot_policy(pi_opt, grid=np.array(walls), title=f"policies/policy_{k}.png")

    data.append(
        {
        "optimal_policy": pi_opt,
        "context_actions": np.array(rollout["action"][k]),
        "context_states": np.array(jnp.vstack((rollout["state"][0].row[k], rollout["state"][0].col[k]))),
        "context_next_states": np.array(jnp.vstack((rollout["next_state"][0].row[k], rollout["next_state"][0].col[k]))),
        "context_rewards": np.array(rollout["reward"][k]),
        "env_key": keys_init[k],
        "rollout_key": keys_rollout[k],
        }
    )


<IPython.core.display.Javascript object>

  self.comm = Comm('matplotlib', data={'id': self.uuid})


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

  fig, ax = plt.subplots(figsize=(10, 10))


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

FileNotFoundError: [Errno 2] No such file or directory: '/data/data.pkl'