In [None]:
# Install Flatland (notebooks branch) in editable mode, plus imageio.
# `%pip` is used instead of `!pip` so the packages are installed into the
# environment of the running kernel rather than whatever `pip` is on the PATH.
!git clone https://gitlab.aicrowd.com/flatland/flatland.git/ --branch 223_UpdateEditor_55_notebooks
%cd flatland
%pip install -e .
%pip install imageio

Agent

In [1]:
import numpy as np
import random
from collections import deque

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Reshape
from tensorflow.python.keras import Input

"""
An agent for Deep Q-Learning models.

It encapsulates a neural network to predict Q-values and an experience-replay buffer to train the model on data batches.
"""


class SingleDQNAgent:
    """Deep Q-Learning agent.

    Holds an online Q-network, a target network with the same architecture and
    a bounded experience-replay buffer from which training minibatches are drawn.

    Args:
        env: OpenAI gym associated environment
        optimizer: Neural Network optimizer
        gamma: Discount
        epsilon: Exploration factor
    """

    REPLAY_BUFFER_MAX_LEN = 2000

    def __init__(self, env, optimizer, gamma=0.6, epsilon=0.1):
        self.env = env
        # Width of the network input / number of Q-values, taken from the env spaces.
        self._state_size = int(env.observation_space.n)
        self._action_size = env.action_space.n
        self._optimizer = optimizer

        # Oldest experiences are dropped automatically once the buffer is full.
        self.replay_buffer = deque(maxlen=self.REPLAY_BUFFER_MAX_LEN)

        self.gamma = gamma
        self.epsilon = epsilon

        # Build q and target networks, then start them from identical weights
        self.q_network = self._build_compile_model()
        self.target_network = self._build_compile_model()
        self.update_target_model()

    def store(self, state, action, reward, next_state, terminated):
        """Store the experience in the Replay Buffer."""
        self.replay_buffer.append((state, action, reward, next_state, terminated))

    def _build_compile_model(self):
        """Build and compile the neural network that estimates the Q-values.

        The input width follows the environment's observation space instead of
        a hard-coded constant, so the network stays in sync with the environment.

        Returns:
            The compiled Keras model (MSE loss, optimizer given at construction).
        """
        model = Sequential()
        model.add(Input(shape=(self._state_size,)))
        model.add(Reshape((self._state_size,)))
        model.add(Dense(22, activation='relu'))
        model.add(Dense(22, activation='relu'))
        model.add(Dense(self._action_size, activation='linear'))

        model.compile(loss='mse', optimizer=self._optimizer)
        return model

    def update_target_model(self):
        """Update the target network using the weights of the q one."""
        self.target_network.set_weights(self.q_network.get_weights())

    def act(self, state):
        """Decide the action to take, either from the model or by exploring.

        Args:
            state: the current state extracted from the observation

        Returns:
            A random sample of the action space with probability ``epsilon``,
            otherwise the index (plain ``int``) of the highest predicted Q-value.
        """
        # Exploration
        if np.random.rand() <= self.epsilon:
            return self.env.action_space.sample()

        # Exploitation
        q_values = self.q_network.predict(state)
        # Convert the NumPy integer to a built-in int so the action can be used
        # as a dictionary key or printed without NumPy-specific types leaking out.
        return int(np.argmax(q_values[0]))

    def retrain(self, batch_size):
        """Train the q network on a random batch of stored experiences.

        Args:
            batch_size: the number of samples extracted from the Replay Buffer and used to train the model.

        Raises:
            ValueError: if ``batch_size`` is larger than the number of stored experiences.
        """
        if batch_size > len(self.replay_buffer):
            raise ValueError("Replay Buffer length exceeded.")

        minibatch = random.sample(self.replay_buffer, batch_size)

        for state, action, reward, new_state, done in minibatch:
            # Start from the current predictions so that only the Q-value of
            # the action actually taken contributes to the loss.
            target = self.q_network.predict(state)

            if done:
                # No future reward after a terminal transition.
                target[0][action] = reward
            else:
                # Bootstrap from the target network's best next-state value.
                t = self.target_network.predict(new_state)
                target[0][action] = reward + self.gamma * np.amax(t)

            self.q_network.fit(state, target, epochs=1, verbose=0)


Environment

In [2]:
from PIL import Image
from IPython.core.display import display
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from gym import Env
from gym.spaces import Discrete

"""
An OpenAI gym environment that wraps a Flatland one
"""


class SingleAgentEnvironment(Env):
    """OpenAI gym environment that delegates to a wrapped Flatland environment.

    Args:
        flatland_env: The Flatland environment
        renderer: The renderer
    """

    flatland_env = None
    renderer = None

    def __init__(self, flatland_env, renderer=None):
        self.flatland_env = flatland_env
        self.renderer = renderer

        self.reward_range = (-1, 1)
        self.action_space = Discrete(5)
        self.observation_space = Discrete(11)

    def step(self, action_dict):
        """Execute an action.

        Args:
            action_dict: the dictionary agent -> action to perform
        Return:
            new_observation: The new observation for each agent
            reward: The reward for each agent
            done: True if an agent has concluded
            info: Some info for each agent
        """
        return self.flatland_env.step(action_dict)

    def reset(self):
        """Reset the wrapped environment, keeping the current rail and schedule.

        Returns:
            Whatever the Flatland environment's ``reset`` returns.
        """
        return self.flatland_env.reset(regenerate_rail=False,
                                       regenerate_schedule=False,
                                       random_seed=True)

    def render(self, mode='human'):
        """Render the environment inline and return the rendered image.

        Args:
            mode: unused; kept for compatibility with the gym ``render`` signature.

        Returns:
            The image obtained from the render tool.
        """
        # TODO: Merge both strategies (Jupyter vs .py)
        # In .py files
        # self.renderer.render_env(show=False, show_observations=False, show_predictions=False)
        # In Jupyter Notebooks
        env_renderer = RenderTool(self.flatland_env, gl="PILSVG")
        env_renderer.render_env()

        image = env_renderer.get_image()
        pil_image = Image.fromarray(image)
        display(pil_image)
        return image

    def reset_renderer(self):
        """Replace ``self.renderer`` with a freshly built render tool."""
        self.renderer = RenderTool(
            self.flatland_env,
            gl="PILSVG",
            agent_render_variant=AgentRenderVariant.AGENT_SHOWS_OPTIONS_AND_BOX,
            show_debug=True,
            screen_height=700,
            screen_width=1300)

    def close_window(self):
        """Close the renderer window, if a renderer has been set."""
        # `renderer` defaults to None, so there may be nothing to close.
        if self.renderer is not None:
            self.renderer.close_window()


Observation

In [3]:
import numpy as np

from flatland.core.env_observation_builder import ObservationBuilder


"""
Observation for a DQN based single agent.

An observation is a dictionary of the form:
{"observations": List[List[int], List[int], List[int]], "state": tuple}

Similar to the Q-learning version, but it always returns the same structure: impossible transitions are encoded as
[0, 0, 0].
"""


class SingleDQNAgentObs(ObservationBuilder):
    """Observation builder producing a fixed-size observation for one agent."""

    def __init__(self):
        super().__init__()

    def reset(self):
        """Nothing is reset at the moment."""
        # TODO
        pass

    def get(self, handle=0):
        """Build the observation of a single agent.

        Args:
            handle: the agent index

        Returns:
            A dictionary with two keys: ``"observations"``, a list of three
            three-element rows, and ``"state"``, the agent's position (its
            initial position while it has no current one).
        """
        agent = self.env.agents[handle]

        # Fall back to the initial position while the agent has no position.
        position = agent.position if agent.position else agent.initial_position
        possible_transitions = self.env.rail.get_transitions(*position, agent.direction)

        if np.count_nonzero(possible_transitions) == 1:
            # A single possible transition is always encoded in the middle row.
            observations = [[0, 0, 0], [0, 1, 0], [0, 0, 0]]
        else:
            observations = []
            # Rows correspond to direction - 1, direction and direction + 1
            # (modulo 4) -- presumably left, forward and right; each row marks
            # its own slot when that transition is possible.
            for slot, offset in enumerate((-1, 0, 1)):
                heading = (agent.direction + offset) % 4
                row = [0, 0, 0]
                if possible_transitions[heading]:
                    row[slot] = 1
                observations.append(row)

        return {"observations": observations, "state": position}

DQN

In [4]:
import numpy as np

from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.utils.rendertools import RenderTool, AgentRenderVariant
from tensorflow.python.keras.optimizer_v2.adam import Adam

from IPython.display import clear_output
%matplotlib inline

"""
    Execution of the Deep Q-Learning algorithm for a single agent navigation
"""

# Render the environment
render = True
renderer = None
# Print stats within the episode
print_stats = False
# Print stats at the end of each episode
print_episode_stats = True
# Frequency of episodes to print
print_episode_stats_freq = 1

random_seed = 42
# Seed every random generator the notebook relies on: NumPy drives the
# epsilon-greedy choice in the agent, while the standard-library `random`
# module samples the minibatches from the replay buffer.
np.random.seed(random_seed)
random.seed(random_seed)

action_to_direction = {0: 'no-op', 1: 'left', 2: 'forward', 3: 'right', 4: 'halt'}

# Number of training episodes and maximum number of steps per episode
EPISODES = 100
TIMESTEPS = 2000

# Size of the grid passed to RailEnv
WIDTH = 40
HEIGHT = 40

# Number of experiences sampled from the replay buffer at each retraining
BATCH_SIZE = 32

# Flatland rail environment with a single agent, observed through SingleDQNAgentObs
env = RailEnv(
    width=WIDTH,
    height=HEIGHT,
    rail_generator=sparse_rail_generator(
        # Number of cities (= train stations)
        max_num_cities=3,
        # Distribute the cities evenly in a grid
        grid_mode=False,
        # Max number of rails connecting to a city
        max_rails_between_cities=1,
        # Number of parallel tracks in cities
        max_rails_in_city=1,
        seed=random_seed),
    obs_builder_object=SingleDQNAgentObs(),
    number_of_agents=1,
    random_seed=random_seed)

# Gym-style wrapper consumed by the agents
environment = SingleAgentEnvironment(env)
# One agent (with its own optimizer) per train; `learning_rate` is the
# documented argument name, `lr` being only a deprecated alias.
agents = [SingleDQNAgent(environment, Adam(learning_rate=0.01)) for _ in range(env.number_of_agents)]

agents[0].q_network.summary()

"""
Transform observation dictionary to neural network input (numpy column)
Args:
    observation: the observation to change
"""


def reshape_observation(observations):
    """Turn each agent's observation dictionary into a neural-network input row.

    The rows of ``"observations"`` are flattened and the first two coordinates
    of ``"state"`` are appended, giving an array of shape ``(1, n_features)``
    (``(1, 11)`` for three rows of three values plus the two coordinates).

    The input mapping is left untouched and a new dictionary is returned.
    Entries that are already NumPy arrays are passed through unchanged, so
    calling the function again on its own output is harmless.

    Args:
        observations: dictionary agent handle -> observation dictionary with
            the keys ``"observations"`` and ``"state"``.

    Returns:
        A new dictionary agent handle -> NumPy array of shape ``(1, n_features)``.
    """
    reshaped = {}
    for handle, observation in observations.items():
        if isinstance(observation, np.ndarray):
            # Already converted by a previous call.
            reshaped[handle] = observation
            continue

        features = [value for row in observation["observations"] for value in row]
        features.extend(observation["state"][:2])
        reshaped[handle] = np.array(features).reshape((1, -1))

    return reshaped


# Dictionary agent -> action used in step
action_dict = dict()

# Stats
stats = []

for e in range(0, EPISODES):
    # Reset the environment exactly once per episode. No separate renderer
    # reset is needed: environment.render() builds its own render tool.
    old_observations, info = environment.reset()
    old_observations = reshape_observation(old_observations)

    # Initialize variables
    episode_reward = 0
    terminated = False

    # Episode stats
    action_counter = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0}

    for time_step in range(TIMESTEPS):

        # Initially False, remains False if no agent updates it
        update_values = False

        # Choose actions
        for a in range(env.number_of_agents):
            if info["action_required"][a]:
                update_values = True
                action = agents[a].act(old_observations[a])
                action_counter[action] += 1
                if print_stats:
                    print("Agent " + str(a) + " performs action: " + str(action))
            else:
                action = 0
            action_dict.update({a: action})

        # Apply the chosen actions
        new_observations, reward, terminated, info = environment.step(action_dict)

        if print_stats:
            print("Step (obs, reward, terminated, info): ")
            print(new_observations)
            print(reward)
            print(terminated)
            print(info)
            print("_______________________________")

        # Reshape the observations to feed the network. This is done once per
        # step, outside the per-agent loop, so every agent reads from the same
        # converted observations.
        new_observations = reshape_observation(new_observations)

        for a in range(env.number_of_agents):
            # Episode reward is the mean
            episode_reward += reward[a] / env.number_of_agents

            if update_values or terminated[a]:
                # Store S A R S' for each agent
                agents[a].store(old_observations[a], action_dict[a], reward[a], new_observations[a], terminated[a])

                # Only this agent's observation advances
                old_observations[a] = new_observations[a]

        if render:
            clear_output(wait=True)
            environment.render()

        # Termination causes the end of the episode
        if terminated["__all__"]:
            for a in range(env.number_of_agents):
                agents[a].update_target_model()
            break

        # Retrain when the batch is ready
        for a in range(env.number_of_agents):
            if len(agents[a].replay_buffer) > BATCH_SIZE:
                agents[a].retrain(BATCH_SIZE)

    if (e + 1) % print_episode_stats_freq == 0:
        if print_episode_stats:
            print("**********************************")
            print("Episode: {}".format(e + 1))
            print("Action counter: " + str(action_counter))
            print("Final reward: " + str(episode_reward))
            print("**********************************")
        stats.append({"action_counter": action_counter, "episode_reward": episode_reward})


KeyboardInterrupt: 