<a href="https://colab.research.google.com/github/FatLads/Notebooks/blob/main/FlatLand_DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!sudo apt install -y xvfb ffmpeg
!pip install -q 'imageio==2.4.0'
!pip install -q pyvirtualdisplay
!pip install -q tf-agents
!pip install -q flatland-rl

Reading package lists... Done
Building dependency tree       
Reading state information... Done
ffmpeg is already the newest version (7:3.4.8-0ubuntu0.2).
xvfb is already the newest version (2:1.19.6-1ubuntu4.8).
0 upgraded, 0 newly installed, 0 to remove and 29 not upgraded.
[31mERROR: flatland-rl 2.2.2 has requirement gym==0.14.0, but you'll have gym 0.18.0 which is incompatible.[0m
[31mERROR: tf-agents 0.7.1 has requirement cloudpickle>=1.3, but you'll have cloudpickle 1.2.2 which is incompatible.[0m
[31mERROR: tf-agents 0.7.1 has requirement gym>=0.17.0, but you'll have gym 0.14.0 which is incompatible.[0m
[31mERROR: tensorflow-probability 0.12.1 has requirement cloudpickle>=1.3, but you'll have cloudpickle 1.2.2 which is incompatible.[0m


In [None]:
from __future__ import absolute_import, division, print_function

import base64
import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import PIL.Image
import pyvirtualdisplay

import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment, py_environment
from tf_agents.eval import metric_utils
from tf_agents.networks import sequential
from tf_agents.policies import random_tf_policy
from tf_agents.environments import utils
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory, time_step
from tf_agents.specs import tensor_spec, BoundedArraySpec
from tf_agents.utils import common
from flatland.envs.rail_env import RailEnv, RailEnvActions

# Start an off-screen display so rendering works without a physical screen.
display = pyvirtualdisplay.Display(
    size=(1400, 900),
    visible=0,
).start()

In [None]:
# Optimisation settings.
learning_rate = 1e-4 # @param {type:"number"}
batch_size = 64 # @param {type:"integer"}

# Experience collection and replay-buffer capacity.
initial_collect_steps = 100 # @param {type:"integer"}
collect_steps_per_iteration = 1 # @param {type:"integer"}
replay_buffer_max_length = 100000 # @param {type:"integer"}

# Run length and how often progress is logged / evaluated.
num_iterations = 30000 # @param {type:"integer"}
log_interval = 200 # @param {type:"integer"}
eval_interval = 1000 # @param {type:"integer"}
num_eval_episodes = 10 # @param {type:"integer"}

Next, we re-implement the Flatland rail environment (`RailEnv`) as a TF-Agents `PyEnvironment`, providing everything TF-Agents needs, as explained [here](https://www.tensorflow.org/agents/tutorials/2_environments_tutorial).

In [None]:
class FatRails(py_environment.PyEnvironment):
    """TF-Agents ``PyEnvironment`` wrapper around Flatland's ``RailEnv``.

    All constructor arguments are forwarded unchanged to ``RailEnv``. Only the
    agent with handle 0 is exposed: each time step carries the first component
    of that agent's observation and that agent's reward.
    """

    def __init__(self, *args, **kwargs):
        # Run the base initialiser so the bookkeeping used by the public
        # ``step()`` / ``reset()`` / ``current_time_step()`` methods exists.
        super().__init__()
        self.env = RailEnv(*args, **kwargs)
        self._episode_ended = False

    def action_spec(self):
        """Single discrete action in 0..4; ``_step`` converts it to RailEnv's action dict."""
        return BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=4, name='action')

    def observation_spec(self):
        """Spec of the binary (rows, columns, 16) grid returned as the observation."""
        # NOTE(review): the grid is assumed to be laid out (height, width, 16), as
        # produced by Flatland's default global observation builder -- confirm if
        # a custom ``obs_builder_object`` or a non-square grid is used.
        return BoundedArraySpec(shape=(self.env.height, self.env.width, 16), dtype=np.int32, minimum=0, maximum=1, name='observation') #TODO

    @staticmethod
    def _first_agent_observation(observations):
        """Return the first component of agent 0's observation, cast to the spec dtype."""
        return np.asarray(observations[0][0], dtype=np.int32)

    def _step(self, action):
        """Apply ``action`` and return the resulting transition or termination.

        Args:
            action: a scalar action for agent 0, or a sequence indexed by agent
                handle (index = agent, value = action).
        """
        if self._episode_ended:
            # The previous step ended the episode; start a new one instead of
            # stepping a finished environment.
            return self.reset()

        # ``np.atleast_1d`` lets the scalar promised by ``action_spec`` and an
        # explicit per-agent list go through the same conversion.
        action_dict = {handle: int(agent_action)
                       for handle, agent_action in enumerate(np.atleast_1d(action))}
        observations, rewards, dones, info = self.env.step(action_dict)

        observation = self._first_agent_observation(observations)
        reward = rewards[0]

        # The status values are 0..3 (check the docs); an agent below 3 has not
        # finished yet.
        agents_pending = any(status < 3 for status in info['status'].values())
        # ``dones['__all__']`` also covers episodes that RailEnv itself declares
        # over (e.g. a step limit), which the status check alone never detects.
        if agents_pending and not dones['__all__']:
            return time_step.transition(observation, reward)

        self._episode_ended = True
        return time_step.termination(observation, reward)  # If no one is moving/has to depart, we're finished

    def _reset(self):
        """Reset the wrapped ``RailEnv`` and return the initial time step."""
        reset_result = self.env.reset()
        self._episode_ended = False
        return time_step.restart(self._first_agent_observation(reset_result[0]))

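Note the `status<3` check: it relies on the numeric values of the agents' statuses, which run from 0 to 3 and are defined [here](https://gitlab.aicrowd.com/flatland/flatland/-/blob/master/flatland/envs/agent_utils.py). The episode is treated as finished only once no agent has a status below 3.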

In [None]:
# 16x16 rail grid; every other RailEnv setting keeps its default.
env = FatRails(width=16, height=16)

In [None]:
# Start an episode explicitly so this cell shows a real transition on a fresh
# kernel instead of depending on an earlier, unrecorded reset.
env.reset()
env.step([RailEnvActions.MOVE_FORWARD])

TimeStep(step_type=array(1, dtype=int32), reward=array(-1., dtype=float32), discount=array(1., dtype=float32), observation=array([[[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 1., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 1.],
        [1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]],

       [[0., 0., 0., ..., 1., 0., 0.],
        [1., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 1., 0., ..., 0., 1., 0.],
        [1., 0., 0., ..., 0., 0., 1.],
        [0., 0., 0., ..., 0., 0., 0.]],

       ...,

       [[0., 0., 0., ..., 1., 0., 0.],
        [1., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],


In [None]:
# Training and evaluation each get their own Python environment; wrapping the
# same instance twice would let evaluation episodes reset and step the
# environment that training is in the middle of using.
# NOTE(review): the evaluation env is rebuilt from width/height only -- keep it
# in sync if `env` is ever constructed with additional RailEnv arguments.
train_env = tf_py_environment.TFPyEnvironment(env)
test_env = tf_py_environment.TFPyEnvironment(FatRails(env.env.width, env.env.height))

In [None]:
# Display the time-step spec (step type, reward, discount, observation) of the wrapped environment.
train_env.time_step_spec()

TimeStep(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), observation=BoundedTensorSpec(shape=(16, 16, 16), dtype=tf.int32, name='observation', minimum=array(0, dtype=int32), maximum=array(1, dtype=int32)))

In [None]:
# Number of discrete actions: the spec bounds are inclusive, hence the +1.
action_tensor_spec = tensor_spec.from_spec(env.action_spec())
1 + action_tensor_spec.maximum - action_tensor_spec.minimum


5

In [None]:
# Widths of the hidden fully connected layers, in order.
fc_layer_params = (100, 50)

# Size of the discrete action set, from the inclusive bounds of the action spec.
action_tensor_spec = tensor_spec.from_spec(env.action_spec())
num_actions = 1 + action_tensor_spec.maximum - action_tensor_spec.minimum



def dense_layer(num_units):
    """Build a ReLU ``Dense`` layer with ``num_units`` outputs.

    Kernels use variance-scaling initialisation (scale 2.0, fan-in mode,
    truncated normal distribution).
    """
    kernel_init = tf.keras.initializers.VarianceScaling(
        distribution='truncated_normal', mode='fan_in', scale=2.0
    )
    layer = tf.keras.layers.Dense(
        num_units,
        kernel_initializer=kernel_init,
        activation=tf.keras.activations.relu,
    )
    return layer

# The observation is a rank-3 integer grid. Dense layers only act on the last
# axis and need floating-point input, so cast and flatten first; otherwise the
# network emits one Q-vector per grid cell instead of one per observation.
preprocessing_layers = [
    tf.keras.layers.Lambda(lambda observation: tf.cast(observation, tf.float32)),
    tf.keras.layers.Flatten(),
]

# Hidden layers, one per entry of fc_layer_params.
dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]

# Linear output layer: one Q-value per action.
q_values_layer = tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03, maxval=0.03
    ),
    bias_initializer=tf.keras.initializers.Constant(-0.2)
)
q_net = sequential.Sequential(preprocessing_layers + dense_layers + [q_values_layer])

In [None]:
# Step counter and optimiser owned by the agent.
train_step_counter = tf.Variable(0)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# DQN agent built on the Q-network, trained with an element-wise squared TD loss.
agent = dqn_agent.DqnAgent(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter,
)
agent.initialize()

ValueError: ignored

In [None]:
# Keep handles to the agent's evaluation policy and its data-collection policy.
eval_policy, collect_policy = agent.policy, agent.collect_policy

In [None]:
def compute_avg_return(environment, policy, num_episodes):
    """Run ``policy`` for ``num_episodes`` episodes and return the mean episode return.

    Args:
        environment: environment exposing ``reset()`` and ``step(action)``, both
            returning a time step with ``is_last()`` and ``reward``.
        policy: object whose ``action(time_step)`` returns a policy step with an
            ``action`` attribute.
        num_episodes: number of episodes to average over.

    Returns:
        The first element of the averaged (batched) return.
    """
    total_return = 0.0
    for _ in range(num_episodes):
        current_step = environment.reset()
        episode_return = 0.0
        while not current_step.is_last():
            action_step = policy.action(current_step)
            # Step with the bare action, not the whole policy step, exactly as
            # the data-collection code does.
            current_step = environment.step(action_step.action)
            episode_return += current_step.reward
        total_return += episode_return

    avg_return = total_return / num_episodes
    return avg_return.numpy()[0]

In [None]:
# Uniform replay buffer sized for the agent's collect data and the env batch size.
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    max_length=replay_buffer_max_length,
    batch_size=train_env.batch_size,
    data_spec=agent.collect_data_spec,
)

In [None]:
def collect_step(environment, policy, replay):
    """Take one step in ``environment`` with ``policy`` and add the transition to ``replay``."""
    step_before = environment.current_time_step()
    policy_step = policy.action(step_before)
    step_after = environment.step(policy_step.action)
    replay.add_batch(
        trajectory.from_transition(step_before, policy_step, step_after)
    )

def collect_data(env, policy, replay, steps):
    """Call ``collect_step`` ``steps`` times, filling ``replay`` with experience from ``env``."""
    for _unused_index in range(steps):
        collect_step(env, policy, replay)

# Uniform-random policy, used only to seed the replay buffer before training.
random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
                                                train_env.action_spec())
collect_data(train_env, random_policy, replay_buffer, initial_collect_steps)

In [None]:
# Sample mini-batches of two consecutive steps from the replay buffer.
data = replay_buffer.as_dataset(
    sample_batch_size=batch_size,
    num_steps=2,
    num_parallel_calls=3,
)
data = data.prefetch(3)

# Iterator the training loop pulls experience batches from.
iterator = iter(data)

In [None]:
try:
    %%time
except:
    pass

agent.train = common.function(agent.train)

agent.train_step_counter.assign(0)
avg_return = compute_avg_return(test_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(5000):
    old_data = data
    collect_data(train_env, agent.collect_policy, replay_buffer, collect_steps_per_iteration)
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    step = agent.train_step_counter.numpy()

    if step % log_interval == 0:
        print(f"step={step}: loss = {train_loss}")
    if step % eval_interval ==0:
        avg_return = compute_avg_return(test_env, agent.policy, num_eval_episodes)
        print(f"Step {step}, avg_ret: {avg_return}")
        returns.append(avg_return)

In [None]:
# One x position per recorded evaluation (the baseline at step 0 plus one every
# `eval_interval` steps), so x and y always have the same length regardless of
# how long training actually ran.
iterations = range(0, len(returns) * eval_interval, eval_interval)

fig, ax = plt.subplots()
ax.plot(iterations, returns)
ax.set_ylabel('Average Return')
ax.set_xlabel('Iterations')
# No fixed upper y-limit: the per-step reward recorded above is negative, so the
# axis is left to autoscale around the actual returns.
plt.show()