In [1]:
import getopt
import random
import sys
import time
from collections import deque
# make sure the root path is in system path
from pathlib import Path

from flatland.envs.malfunction_generators import malfunction_from_params
# base_dir = Path(__file__).resolve().parent.parent
# sys.path.append(str(base_dir))

import matplotlib.pyplot as plt
import numpy as np
import torch
from double_duelling_dqn import Agent
from observation_utils import normalize_observation

from flatland.envs.rail_generators import complex_rail_generator, rail_from_manual_specifications_generator, random_rail_generator , sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv,LocalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.utils.ordered_set import OrderedSet
from flatland.core.grid.grid4_utils import get_new_position

cpu


In [9]:
Global_observations = GlobalObsForRailEnv()
Tree_observations = TreeObsForRailEnv(max_depth=2)

In [3]:
# random.seed(1000)
# np.random.seed(1000)

In [4]:
def environment1 ():
    transition_probability = [1.0,  # empty cell - Case 0
                          1.0,  # Case 1 - straight
                          1.0,  # Case 2 - simple switch
                          0.3,  # Case 3 - diamond drossing
                          0.5,  # Case 4 - single slip
                          0.5,  # Case 5 - double slip
                          0.2,  # Case 6 - symmetrical
                          0.0,  # Case 7 - dead end
                          0.2,  # Case 8 - turn left
                          0.2,  # Case 9 - turn right
                          1.0]  # Case 10 - mirrored switch

    # Example generate a random rail
    env = RailEnv(width=10,
                  height=10,
                  rail_generator=random_rail_generator(
                            cell_type_relative_proportion=transition_probability
                            ),
                  number_of_agents=1,
                  obs_builder_object= Tree_observations )
    return env

In [5]:
def environment2 ():
 #fix the randomness

#     random.seed(10)
#     np.random.seed(10)

    x_dim = 35
    y_dim = 35
    n_agents = 1

    stochastic_data = {'malfunction_rate': 8000,  # Rate of malfunction occurence of single agent
                       'min_duration': 15,  # Minimal duration of malfunction
                       'max_duration': 50  # Max duration of malfunction
                       }

    TreeObservation = TreeObsForRailEnv(max_depth=2)
    speed_ration_map = {1.: 0.,  # Fast passenger train
                        1. / 2.: 1.0,  # Fast freight train
                        1. / 3.: 0.0,  # Slow commuter train
                        1. / 4.: 0.0}  # Slow freight train
    
    env = RailEnv(width=x_dim,
                  height=y_dim,
                  rail_generator=sparse_rail_generator(max_num_cities=5,
                                                       # Number of cities in map (where train stations are)
                                                       seed=1,  # Random seed
                                                       grid_mode=False,
                                                       max_rails_between_cities=2,
                                                       max_rails_in_city=3),
                  schedule_generator=sparse_schedule_generator(),
                  number_of_agents=n_agents,
                  malfunction_generator_and_process_data=malfunction_from_params(stochastic_data),
                  # Malfunction data generator
                  obs_builder_object=TreeObservation)

    return env    

In [6]:
env = environment2()
env_renderer = RenderTool(env, gl="PILSVG", )
obs , info = env.reset()

num_features_per_node = env.obs_builder.observation_dim
tree_depth = 2
nr_nodes = 0
for i in range(tree_depth + 1):
    nr_nodes += np.power(4, i)
state_size = num_features_per_node * nr_nodes

# The action space of flatland is 5 discrete actions
action_size = 5

# And the max number of steps we want to take per episode
max_steps = int(3 * (env.height + env.width))
eps = 0.01

agent = Agent(state_size, action_size)

In [7]:
agent.load("Nets/checkpoint_env1")

In [8]:
_done = False
total_reward =0 
for _ in range(max_steps):
    
    _action = {}
    for a in range(env.get_num_agents()):
        if ( info['action_required'][a]):
            state = normalize_observation(obs[a], tree_depth, observation_radius=10)
            _action[a] =  agent.act(state, eps=eps)
            
        else:
            _action[a] = 0
    obs, all_rewards, done, info = env.step(_action)
    total_reward += all_rewards[0]
    _done = done['__all__']
    env_renderer.render_env(show=True , frames = True , show_observations = True , show_predictions=True )
    time.sleep(0.3)
    
    if (_done == True):
        break

  Predictors builder needs to populate: env.dev_pred_dict")


In [9]:
env_renderer.close_window()