# Part 2: Observations

One of the main objectives of the Flatland challenge is to find suitable observations to solve the task at hand. 

In this notebook, we will look at the existing observations that can be used, see how to create custom ones, and learn how to render them visually.

# Setup

In [0]:
# Install Flatland
%cd /content
!git clone https://gitlab.aicrowd.com/flatland/flatland.git/ --branch 223_UpdateEditor_55_notebooks
%cd flatland
!pip install -e .

In [0]:
import PIL
from flatland.utils.rendertools import RenderTool

def render_env(env):
    """
    Render the given environment and display the resulting picture inline.

    A new ``RenderTool`` with the "PILSVG" backend is created on every call.

    :param env: the environment handed to ``RenderTool``
    :return: the image object returned by ``RenderTool.get_image()``
    """
    # `import PIL` on its own does not guarantee that the `PIL.Image` submodule is
    # loaded; import it explicitly so this helper does not rely on another module
    # having imported it as a side effect.
    from PIL import Image

    env_renderer = RenderTool(env, gl="PILSVG")
    env_renderer.render_env()

    image = env_renderer.get_image()
    pil_image = Image.fromarray(image)
    # `display` is the rich-output helper available in the notebook namespace.
    display(pil_image)
    return image

# The big picture

In [0]:
from flatland.envs.rail_generators import random_rail_generator

# Relative proportion of every cell type, listed in cell-type order
# (the position in this sequence is the cell type number).
_CELL_TYPE_PROPORTIONS = (
    ("empty cell", 1.0),        # Type 0
    ("straight", 1.0),          # Type 1
    ("simple switch", 1.0),     # Type 2
    ("diamond crossing", 0.3),  # Type 3
    ("single slip", 0.5),       # Type 4
    ("double slip", 0.5),       # Type 5
    ("symmetrical", 0.2),       # Type 6
    ("dead end", 0.0),          # Type 7
    ("turn left", 0.2),         # Type 8
    ("turn right", 0.2),        # Type 9
    ("mirrored switch", 1.0),   # Type 10
)

# Plain list of proportions, in the same order as above
transition_prob = [proportion for _, proportion in _CELL_TYPE_PROPORTIONS]

# Rail generator that draws cell types according to these relative proportions
rail_generator = random_rail_generator(cell_type_relative_proportion=transition_prob)

In [0]:
from flatland.core.env_observation_builder import ObservationBuilder

# Observations are fully customizable
class SimpleObs(ObservationBuilder):
    """
    Minimal example of a custom observation builder.

    Each agent observes a vector of 5 entries, every one of them equal to the agent's handle (its ID).
    """

    def reset(self):
        # Nothing to re-initialise: this builder defines no attributes of its own.
        return None

    def get(self, handle):
        # A vector of five ones scaled by the agent's handle
        agent_id_vector = np.ones(5) * handle
        return agent_id_vector

In [0]:
import pprint
import numpy as np
from flatland.envs.rail_env import RailEnv

# Build a 15x15 environment with a single agent, observed through SimpleObs
env = RailEnv(
    number_of_agents=1,
    obs_builder_object=SimpleObs(),
    rail_generator=rail_generator,
    height=15,
    width=15,
)

# reset() hands back the initial observations together with an info object
observations, info = env.reset()

# Show what the observation builder produced
print('Observations:')
pprint.pprint(observations)

In [0]:
from flatland.envs.observations import TreeObsForRailEnv
from flatland.core.grid.grid4_utils import get_new_position
from typing import List

# Build observations which indicate the shortest path to the target
class SingleAgentNavigationObs(ObservationBuilder):
    """
    Observation builder that points each agent along the shortest path to its target.

    The observation is a vector with 3 binary components, one per direction relative to the agent's
    current orientation (Left, Forward, Right). The component belonging to the branch with the
    smallest distance-map value is set to 1, the others to 0.
    E.g., if taking the Left branch (if available) is the shortest route to the agent's target,
    the observation vector will be [1, 0, 0].
    """

    def __init__(self):
        super().__init__()

    def reset(self):
        # This builder defines no attributes of its own, so there is nothing to reset.
        pass

    def get(self, handle: int = 0) -> List[int]:
        """
        Build the observation for one agent.

        :param handle: index of the agent in ``self.env.agents``
        :return: list ``[left, forward, right]`` with exactly one entry set to 1
        """
        agent = self.env.agents[handle]

        # An agent may not have a position yet; in that case reason from its initial position.
        # The same resolved position must be used for BOTH the transition lookup and the
        # neighbour-cell computation below (using `agent.position` there would fail when unset).
        position = agent.position if agent.position else agent.initial_position

        possible_transitions = self.env.rail.get_transitions(*position, agent.direction)
        num_transitions = np.count_nonzero(possible_transitions)

        # Start from the current orientation, and see which transitions are available;
        # organize them as [left, forward, right], relative to the current orientation
        # If only one transition is possible, the forward branch is aligned with it.
        if num_transitions == 1:
            return [0, 1, 0]

        # Fetch the distance map once instead of once per candidate direction.
        distance_map = self.env.distance_map.get()

        min_distances = []
        for direction in [(agent.direction + i) % 4 for i in range(-1, 2)]:
            if possible_transitions[direction]:
                new_position = get_new_position(position, direction)
                min_distances.append(
                    distance_map[handle, new_position[0], new_position[1], direction])
            else:
                # Unavailable branches can never be the shortest one
                min_distances.append(np.inf)

        # NOTE: if no branch is available every entry is inf and argmin selects index 0 (left).
        observation = [0, 0, 0]
        observation[np.argmin(min_distances)] = 1

        return observation

In [0]:
# Build a 20x20 single-agent environment observed through SingleAgentNavigationObs,
# with rails produced by a random_rail_generator() called without arguments
env = RailEnv(
    obs_builder_object=SingleAgentNavigationObs(),
    number_of_agents=1,
    rail_generator=random_rail_generator(),
    height=20,
    width=20,
)

# Initial observations and info object for the new environment
obs, info = env.reset()

In [0]:
# Human-readable name for each action index
action_to_direction = {0: 'no-op', 1: 'left', 2: 'forward', 3: 'right', 4: 'halt'}

print("Directions of shortest paths")
for agent_handle in obs:
    # The observation is a [left, forward, right] indicator vector; adding 1 to the
    # index of its maximum skips action 0 ('no-op') and yields the matching action.
    # Use this agent's own observation rather than always reading agent 0's.
    action = np.argmax(obs[agent_handle]) + 1
    print('- Agent {}: {}'.format(agent_handle, action_to_direction[action]))

In [0]:
from IPython.display import clear_output

# Start a fresh episode; the actions below are derived from these observations
obs, info = env.reset()

# Upper bound on the number of simulated timesteps
max_steps = 150

# Move in a direction that is on a shortest path
# This results in an optimal policy if and only if there is a single agent!
for step in range(max_steps):
    # Choose the action from the CURRENT observation before stepping, so the first step
    # does not depend on an `action` value left over from an earlier cell.
    # +1 maps the [left, forward, right] index onto actions 1..3 (0 is 'no-op').
    action = np.argmax(obs[0]) + 1
    obs, all_rewards, done, _ = env.step({0: action})

    # Replace the previous frame instead of stacking images below each other
    clear_output(wait=True)
    render_env(env)

    print("Timestep: {}".format(step))
    print("Action: {} ({})".format(action, action_to_direction[action]))
    print("Rewards: {}".format(all_rewards))
    print("Done: {}".format(done))

    if done['__all__']:
        print("All done!")
        break