In [1]:
# This notebook implements the A* algorithm for a single agent in the Flatland environment.
# The state space searched by A* is (location_x, location_y, orientation).

In [2]:
from flatland.envs.rail_generators import complex_rail_generator, rail_from_manual_specifications_generator, random_rail_generator , sparse_rail_generator
from flatland.envs.schedule_generators import complex_schedule_generator, sparse_schedule_generator
from flatland.envs.rail_env import RailEnv
from flatland.utils.rendertools import RenderTool
from flatland.envs.observations import TreeObsForRailEnv,GlobalObsForRailEnv,LocalObsForRailEnv
from flatland.envs.predictions import ShortestPathPredictorForRailEnv
from flatland.utils.ordered_set import OrderedSet
from flatland.core.grid.grid4_utils import get_new_position,get_direction
from flatland.envs.schedule_utils import Schedule
from flatland.core.grid.grid4 import Grid4TransitionsEnum

import math
import numpy as np
import time
import random
from MappedQueue import PriorityQueue

In [3]:
def environment1(width=20, height=20, number_of_agents=1, seed=0):
    """Build the Flatland rail environment used by this notebook.

    The defaults reproduce the notebook's original hard-coded setup
    (20x20 grid, one agent, rail generator seeded with 0), so calling
    ``environment1()`` with no arguments behaves exactly as before.

    Only the rail generator receives ``seed``; Python's ``random`` module and
    NumPy's global RNG are not seeded here.

    Args:
        width: Grid width passed to ``RailEnv``.
        height: Grid height passed to ``RailEnv``.
        number_of_agents: Number of agents passed to ``RailEnv``.
        seed: Seed forwarded to ``complex_rail_generator``.

    Returns:
        The constructed ``RailEnv``. ``reset()`` is not called here.
    """
    rail_generator = complex_rail_generator(
        nr_start_goal=20,
        nr_extra=1,
        min_dist=8,
        max_dist=99999,
        seed=seed,
    )
    return RailEnv(
        width=width,
        height=height,
        rail_generator=rail_generator,
        schedule_generator=complex_schedule_generator(),
        number_of_agents=number_of_agents,
    )


In [4]:
# Create the environment, attach a renderer to it, reset it, then draw the
# initial frame without the observation/prediction overlays.
env = environment1()
env_renderer = RenderTool(
    env,
    screen_width=1600,
    screen_height=750,
    show_debug=True,
)
obs, info = env.reset()
env_renderer.render_env(
    show=True,
    frames=True,
    show_observations=False,
    show_predictions=False,
)

In [5]:
def heuristic(a, b):
    """Return the Manhattan (L1) distance between two 2-D grid points.

    Args:
        a: First point, a pair of coordinates.
        b: Second point, a pair of coordinates.

    Returns:
        Sum of the absolute coordinate differences between ``a`` and ``b``.
    """
    first_a, second_a = a
    first_b, second_b = b
    delta_first = abs(first_a - first_b)
    delta_second = abs(second_a - second_b)
    return delta_first + delta_second

In [6]:
def decode_direction(direction):
    """Translate a ``Grid4TransitionsEnum`` heading into this notebook's index.

    Mapping: NORTH -> 0, WEST -> 1, SOUTH -> 2, EAST -> 3.

    Args:
        direction: Value compared (with ``==``) against the four enum members.

    Returns:
        The mapped index, or ``None`` when ``direction`` matches none of them.
    """
    heading_codes = (
        (Grid4TransitionsEnum.NORTH, 0),
        (Grid4TransitionsEnum.WEST, 1),
        (Grid4TransitionsEnum.SOUTH, 2),
        (Grid4TransitionsEnum.EAST, 3),
    )
    for heading, code in heading_codes:
        if direction == heading:
            return code
    return None

In [7]:
def get_new_position_direction(position, direction):
    """List the successor states reachable from ``(position, direction)``.

    Reads the rail grid from the notebook-level ``env``.

    Args:
        position: Grid cell the agent currently occupies.
        direction: Agent's current heading (0-3).

    Returns:
        A list of exactly three entries, one per candidate heading in the
        order ``direction - 1``, ``direction``, ``direction + 1`` (mod 4).
        Each entry is ``(new_cell, new_direction)`` when that move is allowed,
        or ``None`` when it is not. In a dead end, a candidate whose opposite
        heading is a permitted transition yields the reversed move instead.
    """
    rail = env.rail
    possible_transitions = rail.get_transitions(*position, direction)

    # Dead end: exactly one exit for the current heading AND exactly one
    # transition bit set for the whole cell. (A crossing has four bits set,
    # so it can never satisfy this test.)
    full_transition_bits = bin(rail.get_full_transitions(*position)).count("1")
    is_dead_end = (
        np.count_nonzero(possible_transitions) == 1
        and full_transition_bits == 1
    )

    transitions = []
    for offset in (-1, 0, 1):
        # "+ 4" keeps the operand non-negative before taking the modulus.
        branch_direction = (direction + 4 + offset) % 4
        reverse_direction = (branch_direction + 2) % 4

        if is_dead_end and rail.get_transition((*position, direction),
                                               reverse_direction):
            # In a dead end the only way on is back the way we came.
            new_cell = get_new_position(position, reverse_direction)
            transitions.append((new_cell, reverse_direction))
        elif possible_transitions[branch_direction]:
            new_cell = get_new_position(position, branch_direction)
            transitions.append((new_cell, branch_direction))
        else:
            transitions.append(None)

    return transitions

In [8]:
# A* search for agent 0.
# A search state is ((x, y), orientation).
agent = env.agents[0]
start = (agent.initial_position, agent.initial_direction)
goal = agent.target

frontier = PriorityQueue()
frontier.put(start, 0)
# came_from: state -> (parent state, index of the transition taken from the parent)
came_from = {start: (None, None)}
# cost_so_far: state -> number of steps on the cheapest path found so far
cost_so_far = {start: 0}
explored = set()
# State in which the goal cell was reached; stays None if the frontier is
# exhausted without ever reaching the goal.
goal_state = None

while not frontier.empty():
    # presumably get() returns the entry with the lowest priority value;
    # verify against MappedQueue.PriorityQueue
    current = frontier.get()
    position, direction = current

    if position == goal:
        goal_state = current
        break

    explored.add(current)
    # Every move costs one step.
    step_cost = cost_so_far[current] + 1

    successors = get_new_position_direction(position, direction)
    for move_index, successor in enumerate(successors):
        # Skip moves the rail does not allow and states already expanded.
        if successor is None or successor in explored:
            continue

        if successor in cost_so_far and step_cost >= cost_so_far[successor]:
            continue  # a path at least as cheap is already known

        new_position = successor[0]
        cost_so_far[successor] = step_cost
        frontier.put(successor, step_cost + heuristic(goal, new_position))
        came_from[successor] = (current, move_index)

In [9]:
# Display the cost table built by the search: (position, direction) -> steps from the start state.
cost_so_far

{((1, 6), 3): 0,
 ((1, 7), 1): 1,
 ((1, 8), 1): 2,
 ((1, 9), 1): 3,
 ((1, 10), 1): 4,
 ((2, 10), 2): 5,
 ((2, 11), 1): 6,
 ((3, 11), 2): 7,
 ((3, 12), 1): 8,
 ((2, 12), 0): 9,
 ((4, 12), 2): 9,
 ((4, 13), 1): 10,
 ((5, 13), 2): 11,
 ((5, 14), 1): 12,
 ((6, 14), 2): 13,
 ((6, 15), 1): 14,
 ((6, 13), 3): 14,
 ((7, 15), 2): 15,
 ((7, 16), 1): 16,
 ((7, 17), 1): 17,
 ((8, 16), 2): 17}

In [14]:
# Two empty ordered sets of grid cells, filled from the search results below
# and later handed to the environment's debug overlay dictionaries.
visited, predicted = OrderedSet(), OrderedSet()

In [15]:
# Cells of the states A* expanded. Each state is a (cell, heading) pair; only
# the cell is kept. The loop variable is deliberately not named `_`, which
# IPython uses for the last cell output.
for cell, _heading in explored:
    visited.add(cell)

# Cells of every state A* reached (expanded or only queued). Iterating the
# dict directly yields its keys; the cost values are not needed here.
for cell, _heading in cost_so_far:
    predicted.add(cell)

In [12]:
# Attach the collected cells to agent 0's debug overlays, then redraw with the
# observation and prediction layers switched on.
# NOTE(review): presumably the renderer draws dev_obs_dict / dev_pred_dict as
# those two layers; confirm against RenderTool.
env.dev_obs_dict[0], env.dev_pred_dict[0] = visited, predicted
env_renderer.render_env(
    show=True,
    frames=False,
    show_observations=True,
    show_predictions=True,
)