88from math import sqrt
99
1010# 1 for manhattan, 0 for euclidean
11+ from typing import Optional
12+
1113HEURISTIC = 0
1214
1315grid = [
2224
2325delta = [[- 1 , 0 ], [0 , - 1 ], [1 , 0 ], [0 , 1 ]] # up, left, down, right
2426
27+ TPosition = tuple [int , int ]
28+
2529
2630class Node :
2731 """
@@ -39,7 +43,15 @@ class Node:
3943 True
4044 """
4145
42- def __init__ (self , pos_x , pos_y , goal_x , goal_y , g_cost , parent ):
46+ def __init__ (
47+ self ,
48+ pos_x : int ,
49+ pos_y : int ,
50+ goal_x : int ,
51+ goal_y : int ,
52+ g_cost : int ,
53+ parent : Optional [Node ],
54+ ) -> None :
4355 self .pos_x = pos_x
4456 self .pos_y = pos_y
4557 self .pos = (pos_y , pos_x )
@@ -61,7 +73,7 @@ def calculate_heuristic(self) -> float:
6173 else :
6274 return sqrt (dy ** 2 + dx ** 2 )
6375
64- def __lt__ (self , other ) -> bool :
76+ def __lt__ (self , other : Node ) -> bool :
6577 return self .f_cost < other .f_cost
6678
6779
@@ -81,23 +93,22 @@ class AStar:
8193 (4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)]
8294 """
8395
84- def __init__ (self , start , goal ):
96+ def __init__ (self , start : TPosition , goal : TPosition ):
8597 self .start = Node (start [1 ], start [0 ], goal [1 ], goal [0 ], 0 , None )
8698 self .target = Node (goal [1 ], goal [0 ], goal [1 ], goal [0 ], 99999 , None )
8799
88100 self .open_nodes = [self .start ]
89- self .closed_nodes = []
101+ self .closed_nodes : list [ Node ] = []
90102
91103 self .reached = False
92104
93- def search (self ) -> list [tuple [ int ] ]:
105+ def search (self ) -> list [TPosition ]:
94106 while self .open_nodes :
95107 # Open Nodes are sorted using __lt__
96108 self .open_nodes .sort ()
97109 current_node = self .open_nodes .pop (0 )
98110
99111 if current_node .pos == self .target .pos :
100- self .reached = True
101112 return self .retrace_path (current_node )
102113
103114 self .closed_nodes .append (current_node )
@@ -118,8 +129,7 @@ def search(self) -> list[tuple[int]]:
118129 else :
119130 self .open_nodes .append (better_node )
120131
121- if not (self .reached ):
122- return [(self .start .pos )]
132+ return [self .start .pos ]
123133
124134 def get_successors (self , parent : Node ) -> list [Node ]:
125135 """
@@ -147,7 +157,7 @@ def get_successors(self, parent: Node) -> list[Node]:
147157 )
148158 return successors
149159
150- def retrace_path (self , node : Node ) -> list [tuple [ int ] ]:
160+ def retrace_path (self , node : Optional [ Node ] ) -> list [TPosition ]:
151161 """
152162 Retrace the path from parents to parents until start node
153163 """
@@ -173,20 +183,19 @@ class BidirectionalAStar:
173183 (2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)]
174184 """
175185
176- def __init__ (self , start , goal ) :
186+ def __init__ (self , start : TPosition , goal : TPosition ) -> None :
177187 self .fwd_astar = AStar (start , goal )
178188 self .bwd_astar = AStar (goal , start )
179189 self .reached = False
180190
181- def search (self ) -> list [tuple [ int ] ]:
191+ def search (self ) -> list [TPosition ]:
182192 while self .fwd_astar .open_nodes or self .bwd_astar .open_nodes :
183193 self .fwd_astar .open_nodes .sort ()
184194 self .bwd_astar .open_nodes .sort ()
185195 current_fwd_node = self .fwd_astar .open_nodes .pop (0 )
186196 current_bwd_node = self .bwd_astar .open_nodes .pop (0 )
187197
188198 if current_bwd_node .pos == current_fwd_node .pos :
189- self .reached = True
190199 return self .retrace_bidirectional_path (
191200 current_fwd_node , current_bwd_node
192201 )
@@ -220,12 +229,11 @@ def search(self) -> list[tuple[int]]:
220229 else :
221230 astar .open_nodes .append (better_node )
222231
223- if not self .reached :
224- return [self .fwd_astar .start .pos ]
232+ return [self .fwd_astar .start .pos ]
225233
226234 def retrace_bidirectional_path (
227235 self , fwd_node : Node , bwd_node : Node
228- ) -> list [tuple [ int ] ]:
236+ ) -> list [TPosition ]:
229237 fwd_path = self .fwd_astar .retrace_path (fwd_node )
230238 bwd_path = self .bwd_astar .retrace_path (bwd_node )
231239 bwd_path .pop ()
@@ -236,9 +244,6 @@ def retrace_bidirectional_path(
236244
237245if __name__ == "__main__" :
238246 # all coordinates are given in format [y,x]
239- import doctest
240-
241- doctest .testmod ()
242247 init = (0 , 0 )
243248 goal = (len (grid ) - 1 , len (grid [0 ]) - 1 )
244249 for elem in grid :
@@ -252,6 +257,5 @@ def retrace_bidirectional_path(
252257
253258 bd_start_time = time .time ()
254259 bidir_astar = BidirectionalAStar (init , goal )
255- path = bidir_astar .search ()
256260 bd_end_time = time .time () - bd_start_time
257261 print (f"BidirectionalAStar execution time = { bd_end_time :f} seconds" )
0 commit comments