<a href="https://colab.research.google.com/github/FatLads/Notebooks/blob/main/DQN_FlatLand.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install flatland-rl keras-rl2

Collecting flatland-rl
[?25l  Downloading https://files.pythonhosted.org/packages/61/30/e002f8b7d9075c88f2f00e870294e7896c92db8b3c94ae9c442ca0e42bc2/flatland-rl-2.2.2.tar.gz (3.3MB)
[K     |████████████████████████████████| 3.3MB 4.6MB/s 
[?25hCollecting keras-rl2
[?25l  Downloading https://files.pythonhosted.org/packages/dd/34/94ffeab44eef43e22a01d82aa0ca062a97392c2c2415ba8b210e72053285/keras_rl2-1.0.4-py3-none-any.whl (53kB)
[K     |████████████████████████████████| 61kB 5.6MB/s 
[?25hCollecting tox>=3.5.2
  Using cached https://files.pythonhosted.org/packages/ec/7e/4609fd0386d41f0b94fe952708970fb87cc1fb66e088758b1f0ab336802e/tox-3.23.0-py2.py3-none-any.whl
Collecting pytest<5,>=3.8.2
  Using cached https://files.pythonhosted.org/packages/70/c7/e8cb4a537ee4fc497ac80a606a667fd1832f28ad3ddbfa25bf30473eae13/pytest-4.6.11-py2.py3-none-any.whl
Collecting pytest-runner>=4.2
  Using cached https://files.pythonhosted.org/packages/40/96/9024a1c07bbe5e16bdcbcbd021b608e37b32df4301ae2090aa

In [2]:
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import sparse_rail_generator
from flatland.envs.schedule_generators import sparse_schedule_generator
from PIL import Image
import numpy as np
import gym

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation, Flatten, Convolution2D, Permute
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K

from rl.agents.dqn import DQNAgent
from rl.policy import LinearAnnealedPolicy, BoltzmannQPolicy, EpsGreedyQPolicy
from rl.memory import SequentialMemory
from rl.core import Processor
from rl.callbacks import FileLogger, ModelIntervalCheckpoint

In [19]:
class RailProcessor(Processor):
    def process_observation(self, observation):

        return np.array(observation[0][0][0]).flatten() # For now, we'll just keep the transition maps
    def process_reward(self, reward):
        reward_sum = 0
        for _, train_reward in reward:
            reward_sum += train_reward
        return reward_sum

In [4]:
from flatland.utils.rendertools import RenderTool
import matplotlib.pyplot as plt

def render_env(env, figsize=(8, 8)):
  """Show the environment using matplotlib"""
  env_renderer = RenderTool(env, gl="PILSVG")
  # img is a numpy array
  img = env_renderer.render_env(show=True, return_image=True)

  plt.figure(figsize=figsize)
  plt.imshow(img)
  plt.show()

In [5]:
from flatland.envs.rail_env import RailEnv
from flatland.envs.rail_generators import complex_rail_generator
from flatland.envs.observations import GlobalObsForRailEnv

seed = 69

width = 10 # @param{type: "integer"}
height = 10 # @param{type: "integer"}
agents =  2 # @param{type: "integer"}

random_rail_generator = complex_rail_generator(
    nr_start_goal=10, # @param{type:"integer"} number of start and end goals 
                      # connections, the higher the easier it should be for 
                      # the trains
    nr_extra=10, # @param{type:"integer"} extra connections 
                 # (useful for alternite paths), the higher the easier
    min_dist=10,
    max_dist=99999,
    seed=seed
)

env = RailEnv(
    width=width,
    height=height,
    rail_generator=random_rail_generator,
    obs_builder_object=GlobalObsForRailEnv(),
    number_of_agents=agents
)

# env.reset is needed to build the first step of the env
_ = env.reset() # assigned to _ just to suppress the output

In [7]:
nb_actions = env.action_space[0]
nb_actions

5

In [18]:
input_shape = np.array(env.reset()[0][0][0]).flatten().shape
input_shape

(1600,)

In [20]:
model = Sequential()
model.add(Dense(100, input_shape=input_shape, activation="relu"))
model.add(Dense(50, input_shape=input_shape, activation="relu"))
model.add(Dense(100, input_shape=input_shape, activation="relu"))
model.add(Dense(nb_actions, activation="linear"))
print(model.summary())

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_8 (Dense)              (None, 100)               160100    
_________________________________________________________________
dense_9 (Dense)              (None, 50)                5050      
_________________________________________________________________
dense_10 (Dense)             (None, 100)               5100      
_________________________________________________________________
dense_11 (Dense)             (None, 5)                 505       
Total params: 170,755
Trainable params: 170,755
Non-trainable params: 0
_________________________________________________________________
None


In [21]:
memory = SequentialMemory(limit=1000000, window_length=4)
processor = RailProcessor()
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.05,
                              nb_steps=1000000)
dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, memory=memory,
               processor=processor, nb_steps_warmup=50000, gamma=.99, target_model_update=10000,
               train_interval=4, delta_clip=1.)
dqn.compile(Adam(lr=.00025), metrics=['mae'])

In [24]:
# Okay, now it's time to learn something! We capture the interrupt exception so that training
# can be prematurely aborted. Notice that now you can use the built-in tensorflow.keras callbacks!
weights_filename = 'dqn_flatland_weights.h5f'
checkpoint_weights_filename = 'dqn_flatland_weights_{step}.h5f'
log_filename = 'dqn_flatland_log.json'
callbacks = [ModelIntervalCheckpoint(checkpoint_weights_filename, interval=250000)]
callbacks += [FileLogger(log_filename, interval=100)]
dqn.fit(env, callbacks=callbacks, nb_steps=1750000, log_interval=10000)

# After training is done, we save the final weights one more time.
dqn.save_weights(weights_filename, overwrite=True)

# Finally, evaluate our algorithm for 10 episodes.
dqn.test(env, nb_episodes=10, visualize=False)

Training for 1750000 steps ...
Interval 1 (0 steps performed)


ValueError: ignored