In [6]:
import jax
import jumanji
import chex


env = jumanji.environments.Maze()

In [36]:
key=jax.random.PRNGKey(10)
state, timestep = env.reset(key)
print(state)

State(agent_position=Position(row=Array(0, dtype=int32), col=Array(2, dtype=int32)), target_position=Position(row=Array(1, dtype=int32), col=Array(4, dtype=int32)), walls=Array([[False, False, False, False, False,  True, False, False, False,
        False],
       [False,  True, False,  True, False,  True, False,  True, False,
         True],
       [False,  True, False,  True, False,  True, False,  True, False,
        False],
       [False,  True,  True,  True, False,  True, False,  True,  True,
         True],
       [False,  True, False, False, False, False, False,  True, False,
         True],
       [False,  True, False,  True, False,  True, False,  True, False,
         True],
       [False,  True, False,  True, False,  True, False, False, False,
        False],
       [False,  True, False,  True,  True,  True,  True,  True,  True,
         True],
       [False,  True, False, False, False, False, False, False, False,
        False],
       [False,  True, False,  True, False,  Tr

In [37]:
target_position = state['target_position']
agent_position = state['agent_position']
print(target_position)
print(agent_position)

Position(row=Array(1, dtype=int32), col=Array(4, dtype=int32))
Position(row=Array(0, dtype=int32), col=Array(2, dtype=int32))


In [17]:
print(target_position[0])

1


In [40]:
new_state, timestep = env.step(state, 1)

In [41]:
print(new_state['agent_position'])

Position(row=Array(0, dtype=int32), col=Array(3, dtype=int32))


In [38]:
print(timestep)

TimeStep(step_type=Array(0, dtype=int8), reward=Array(0., dtype=float32), discount=Array(1., dtype=float32), observation=Observation(agent_position=Position(row=Array(0, dtype=int32), col=Array(2, dtype=int32)), target_position=Position(row=Array(1, dtype=int32), col=Array(4, dtype=int32)), walls=Array([[False, False, False, False, False,  True, False, False, False,
        False],
       [False,  True, False,  True, False,  True, False,  True, False,
         True],
       [False,  True, False,  True, False,  True, False,  True, False,
        False],
       [False,  True,  True,  True, False,  True, False,  True,  True,
         True],
       [False,  True, False, False, False, False, False,  True, False,
         True],
       [False,  True, False,  True, False,  True, False,  True, False,
         True],
       [False,  True, False,  True, False,  True, False, False, False,
        False],
       [False,  True, False,  True,  True,  True,  True,  True,  True,
         True],
      

In [39]:
print(timestep['reward'])

0.0


In [42]:
print(env)

Maze environment:
 - num_rows: 10
 - num_cols: 10
 - time_limit: 100
 - generator: <jumanji.environments.routing.maze.generator.RandomGenerator object at 0x000002CD856C9DC0>


In [44]:
walls = new_state['walls']

In [49]:
import jax.numpy as jnp

state_indices = jax.random.choice(key, jnp.arange(env.num_rows * env.num_cols), (1, ), replace=False, p=~walls.flatten())
print(state_indices)

[90]


In [50]:
(state_row, state_col) = jnp.divmod(state_indices, env.num_cols)
print(state_row[0], state_col[0])

9 0


In [98]:
import numpy as np
from jumanji.environments.routing.maze.types import Position, State


def rollin_mdp(env, rollin_type, optimal_actions, seed = 10):
    states = []
    actions = []
    next_states = []
    rewards = []

    key = jax.random.PRNGKey(seed)

    state, timestep = env.reset(key)

    goal_state = state['target_position']
    walls = state['walls']
    maze_key = state['key']

    for i in range(env.time_limit):
        if rollin_type == 'uniform':
            state = sample_state(env, walls, goal_state, maze_key, i)
            action = sample_action()  
        elif rollin_type == 'expert':
            action = optimal_actions[state] 
        else:
            raise NotImplementedError
        
        next_state, timestep = env.step(state, action)
        reward = timestep['reward']


        agent_position = state['agent_position']  # TODO this is a Position type. which kind of data structure do we want? (E.G. array/list)
        next_agent_position = next_state['agent_position']

        states.append([agent_position[0], agent_position[1]])
        actions.append(action)
        next_states.append([next_agent_position[0], next_agent_position[1]])
        rewards.append(reward)
        state = next_state

    states = np.array(states)
    actions = np.array(actions)
    next_states = np.array(next_states)
    rewards = np.array(rewards)

    return states, actions, next_states, rewards, goal_state, walls

def find_optimal_actions(env): #TODO

    #it should be a dictionary where the keys are 'Position' types and the values are the corresponding actions

    #raise NotImplementedError

    return [[sample_action() for j in range(env.num_cols)] for i in range(env.num_rows)]

def sample_state(env, walls, target_position = None, maze_key = None, step_count = None):

    seed = np.random.randint(low = 0, high = env.num_rows*env.num_cols)

    key = jax.random.PRNGKey(seed)
    state_indices = jax.random.choice(key, jnp.arange(env.num_rows * env.num_cols), (1,), replace=False, p=~walls.flatten())
    (state_row, state_col) = jnp.divmod(state_indices, env.num_cols)

    agent_position = Position(row = state_row[0], col = state_col[0])

    if target_position is None:
        return agent_position #in this case we only want to return the agent position and not the full state
    else:
        return State(agent_position=agent_position, target_position=target_position, walls=walls, action_mask=jnp.array([True, True, True, True]), key=maze_key, step_count=jnp.array(step_count+1, jnp.int32))


def sample_action():

    return np.random.choice([1, 2, 3, 4])

def generate_maze_histories_from_envs(envs, n_hists, n_samples, rollin_type):
    trajs = []
    for env in envs:

        optimal_actions = find_optimal_actions(env) #TODO: since in generate_mdp_histories the optimal actions are required and usually they are part

        for j in range(n_hists):
            (context_states, context_actions, context_next_states, context_rewards, goal_state, walls) = rollin_mdp(env, rollin_type=rollin_type, optimal_actions=optimal_actions)
            
            for k in range(n_samples):
                query_state = sample_state(env, walls) 
                print(query_state)
                optimal_action = optimal_actions(query_state)

                traj = {
                    'query_state': query_state,
                    'optimal_action': optimal_action,
                    'context_states': context_states,
                    'context_actions': context_actions,
                    'context_next_states': context_next_states,
                    'context_rewards': context_rewards,
                    'goal': goal_state,
                }

                trajs.append(traj)
    return trajs

    
def generate_maze_histories(horizon, n_envs, **kwargs):

    envs = [jumanji.environments.Maze(time_limit = horizon) for _ in range(n_envs)]
    trajs = generate_maze_histories_from_envs(envs, **kwargs)
                                              
    return trajs


In [84]:
optimal_actions = find_optimal_actions(env)
states, actions, next_states, rewards, goal_state, walls = rollin_mdp(env, 'uniform', optimal_actions)

In [85]:
print(states)

[[2 4]
 [3 0]
 [0 7]
 [9 2]
 [8 2]
 [9 0]
 [3 0]
 [2 4]
 [1 6]
 [2 2]
 [0 2]
 [5 4]
 [8 8]
 [1 8]
 [4 5]
 [8 7]
 [4 2]
 [9 8]
 [3 4]
 [6 6]
 [4 6]
 [0 7]
 [9 2]
 [8 9]
 [9 0]
 [9 0]
 [9 2]
 [2 8]
 [8 8]
 [0 7]
 [4 2]
 [4 0]
 [8 5]
 [6 2]
 [2 2]
 [8 5]
 [0 8]
 [4 8]
 [4 2]
 [8 2]
 [8 8]
 [8 5]
 [8 8]
 [2 4]
 [9 0]
 [2 2]
 [8 2]
 [6 0]
 [2 8]
 [6 6]
 [3 0]
 [3 0]
 [1 8]
 [3 4]
 [4 0]
 [6 6]
 [4 2]
 [6 6]
 [6 2]
 [2 2]
 [1 0]
 [5 8]
 [5 2]
 [1 8]
 [9 2]
 [0 7]
 [9 2]
 [2 9]
 [0 2]
 [1 0]
 [4 2]
 [8 2]
 [2 0]
 [5 2]
 [0 7]
 [1 2]
 [0 2]
 [2 2]
 [0 2]
 [4 5]
 [2 8]
 [4 8]
 [9 2]
 [4 8]
 [9 0]
 [1 2]
 [3 4]
 [1 8]
 [4 0]
 [0 7]
 [4 0]
 [2 8]
 [6 0]
 [4 4]
 [1 8]
 [9 8]
 [4 5]
 [6 6]
 [6 2]
 [0 7]]


In [99]:
trajectories = generate_maze_histories(10, 1, n_hists=1, n_samples=5, rollin_type = 'uniform')
print(trajectories)

TypeError: 'list' object is not callable