<a href="https://colab.research.google.com/github/YanickSchraner/rl-on-trains-workshop/blob/main/01_Intro_to_environment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Flatland
Flatland documentation: http://flatland-rl-docs.s3-website.eu-central-1.amazonaws.com/


## File structure


```
├── Notebooks, README, requirements.txt, …
├── agents: RL agents implementation
│   ├── curiosity.py
│   ├── dqn.py
│   ├── qlearning.py
│   └── random.py
├── helpers: Helpers to train, test, inspect agents
│   └── rl_helpers.py
└── videos: Save videos of your best agents here!
    └── video.mp4
```



In [1]:
#@title << Setup Google Colab by running this cell {display-mode: "form"}
import sys
if 'google.colab' in sys.modules:
    # Clone GitHub repository
    !git clone --single-branch --branch evaluation_setup https://github.com/YanickSchraner/rl-on-trains-workshop
        
    # Copy files required to run the code
    !cp -r "rl-workshop/agents" "rl-workshop/rl_helpers" .
    
    # Install packages via pip
    !pip install -r "rl-workshop/requirements.txt"
    
    # Restart Runtime
    import os
    os.kill(os.getpid(), 9)

ModuleNotFoundError: ignored

In [None]:
import os

import numpy as np
from PIL import Image

from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.observations import TreeObsForRailEnv, GlobalObsForRailEnv
from flatland.envs.schedule_generators import sparse_schedule_generator

In [None]:
# Environment configuration used by the cells below.
n_agents = 5                    # number of trains in the environment
x_dim, y_dim = 25, 25           # grid width and height
n_cities = 4
max_rails_between_cities = 2
max_rails_in_city = 3
seed = 42                       # random seed

# Maps a train speed to the share of trains with that speed
# (different agent types / trains run at different speeds).
speed_ration_map = {
    1.0: 1.0,    # 100% of fast passenger train
    1 / 2: 0.0,  # 0% of fast freight train
    1 / 3: 0.0,  # 0% of slow commuter train
    1 / 4: 0.0,  # 0% of slow freight train
}

In [None]:
# We are training an Agent using the Tree Observation with depth 2.
# The environment that consumes this observation builder is constructed in the
# next cell from the configuration constants, so no throw-away `RailEnv`
# (with a different observation type and hard-coded sizes) is built here.
observation_tree_depth = 2

tree_observation = TreeObsForRailEnv(max_depth=observation_tree_depth)

In [None]:
# Build the training environment from the configuration constants: a sparse
# rail network with `n_cities` cities, `n_agents` trains and tree observations.
env = RailEnv(
    width=x_dim,
    height=y_dim,
    rail_generator=sparse_rail_generator(
        max_num_cities=n_cities,
        seed=seed,
        grid_mode=False,
        max_rails_between_cities=max_rails_between_cities,
        max_rails_in_city=max_rails_in_city
    ),
    # Pass the speed profile from the config cell; previously it was defined
    # but never used, so editing it had no effect.
    schedule_generator=sparse_schedule_generator(speed_ration_map),
    number_of_agents=n_agents,
    obs_builder_object=tree_observation
)

# In flatland 2.x `reset` returns an (observations, info) tuple; unpack it so
# `obs` holds the observations rather than the whole tuple.
obs, info = env.reset()

In [None]:
# Renderer for `env`, using the "PGL" backend on a 512 x 512 screen.
env_renderer = RenderTool(
    env,
    gl="PGL",
    screen_height=512,
    screen_width=512,
)

In [None]:
# Take one random step with every agent and save a snapshot of the rendered
# environment. A seeded generator keeps the sampled actions reproducible
# under Restart & Run All.
rng = np.random.default_rng(seed)

# Sample an action in [0, 5) for every agent handle, not just agent 0, so all
# `n_agents` trains are driven by the random policy.
actions = {handle: int(rng.integers(0, 5)) for handle in env.get_agent_handles()}
obs, rew, done, info = env.step(actions)

img = env_renderer.render_env(show=False, frames=False, show_observations=False, return_image=True)

# Create the output directory so the save does not fail on a fresh runtime.
os.makedirs("imgs", exist_ok=True)
Image.fromarray(img).save("imgs/env.png", "PNG")
