In [1]:
import getopt
import random
import sys
import time
from typing import List

import numpy as np

from flatland.core.env_observation_builder import ObservationBuilder
from flatland.core.grid.grid4_utils import get_new_position
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator,random_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator
from flatland.utils.misc import str2bool
from flatland.utils.rendertools import RenderTool

random.seed(100)
np.random.seed(100)



In [2]:
from flatland.envs.observations import TreeObsForRailEnv
class SingleAgentNavigationObs(ObservationBuilder):
    """
    We build a representation vector with 3 binary components, indicating which of the 3 available directions
    for each agent (Left, Forward, Right) lead to the shortest path to its target.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target, the observation vector
    will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        pass

    def get(self, handle: int = 0) -> List[int]:
        agent = self.env.agents[handle]

        if agent.position:
            possible_transitions = self.env.rail.get_transitions(*agent.position, agent.direction)
            print("Agent Postion " , *agent.position , "Agent direction" ,  agent.direction, "Transitions :" , possible_transitions )
        else:
            possible_transitions = self.env.rail.get_transitions(*agent.initial_position, agent.direction)

        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            observation = [0, 1, 0]
        else:
            min_distances = []
            for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
                if possible_transitions[direction]:
                    if agent.position:
                        new_position = get_new_position(agent.position, direction)
                    else:
                        new_position = get_new_position(agent.initial_position , direction)
                    min_distances.append(
                        self.env.distance_map.get()[handle, new_position[0], new_position[1], direction])
                    
                else:
                    min_distances.append(np.inf)

            observation = [0, 0, 0]
            observation[np.argmin(min_distances)] = 1

        return observation

In [3]:
env = RailEnv(width=10,
              height=10,
              rail_generator=random_rail_generator(),
              number_of_agents=1,
              obs_builder_object=SingleAgentNavigationObs())
env.reset()

obs, all_rewards, done, _ = env.step({0: 0})

env_renderer = RenderTool(env, gl="PILSVG", show_debug=True)
env_renderer.render_env(show=True, frames=True, show_observations=True)


Hi there
[[inf inf inf inf inf inf inf inf inf inf]
 [inf inf inf 16. inf inf inf inf  9. 14.]
 [inf inf inf 19. inf inf inf inf 12. 11.]
 [inf 14. 21. 20. inf 10.  9. inf  7.  6.]
 [inf inf 12. 21. inf inf inf  7.  6.  5.]
 [inf inf inf 12. 11.  0.  9.  8. inf  8.]
 [inf 17. inf 17. inf  9. 10.  9. inf inf]
 [inf 16. 15.  8.  7.  2. 11. inf inf inf]
 [inf 19. 24. inf  8.  7.  6.  5. inf inf]
 [inf inf inf inf inf inf inf inf inf inf]]
Hi there
[[inf inf inf inf inf inf inf inf inf inf]
 [inf inf inf 16. inf inf inf inf  9. 14.]
 [inf inf inf 19. inf inf inf inf 12. 11.]
 [inf 14. 21. 20. inf 10.  9. inf  7.  6.]
 [inf inf 12. 21. inf inf inf  7.  6.  5.]
 [inf inf inf 12. 11.  0.  9.  8. inf  8.]
 [inf 17. inf 17. inf  9. 10.  9. inf inf]
 [inf 16. 15.  8.  7.  2. 11. inf inf inf]
 [inf 19. 24. inf  8.  7.  6.  5. inf inf]
 [inf inf inf inf inf inf inf inf inf inf]]


  Observation builder needs to populate: env.dev_obs_dict")


In [4]:

for _ in range(1):
    action = np.argmax(obs[0])+1
#     action2 = np.argmax(obs[1]) + 1
    obs, all_rewards, done, _ = env.step({0:action})
    print("Rewards: ", all_rewards, "  [done=", done, "]")

    env_renderer.render_env(show=True, frames=True, show_observations=False)
    time.sleep(3.0)

Agent Postion  8 0 Agent direction 3 Transitions : (0, 1, 0, 0)
Hi there
[[inf inf inf inf inf inf inf inf inf inf]
 [inf inf inf 16. inf inf inf inf  9. 14.]
 [inf inf inf 19. inf inf inf inf 12. 11.]
 [inf 14. 21. 20. inf 10.  9. inf  7.  6.]
 [inf inf 12. 21. inf inf inf  7.  6.  5.]
 [inf inf inf 12. 11.  0.  9.  8. inf  8.]
 [inf 17. inf 17. inf  9. 10.  9. inf inf]
 [inf 16. 15.  8.  7.  2. 11. inf inf inf]
 [inf 19. 24. inf  8.  7.  6.  5. inf inf]
 [inf inf inf inf inf inf inf inf inf inf]]
Rewards:  {0: -1.0}   [done= {0: False, '__all__': False} ]


In [6]:
env_renderer.render_env(show=True, frames=True, show_observations=False)