Skip to content
Merged
95 changes: 52 additions & 43 deletions graphs/bi_directional_dijkstra.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,32 @@
import numpy as np


def pass_and_relaxation(
graph: dict,
v: str,
visited_forward: set,
visited_backward: set,
cst_fwd: dict,
cst_bwd: dict,
queue: PriorityQueue,
parent: dict,
shortest_distance: float | int,
) -> float | int:
for nxt, d in graph[v]:
if nxt in visited_forward:
continue
old_cost_f = cst_fwd.get(nxt, np.inf)
new_cost_f = cst_fwd[v] + d
if new_cost_f < old_cost_f:
queue.put((new_cost_f, nxt))
cst_fwd[nxt] = new_cost_f
parent[nxt] = v
if nxt in visited_backward:
if cst_fwd[v] + d + cst_bwd[nxt] < shortest_distance:
shortest_distance = cst_fwd[v] + d + cst_bwd[nxt]
return shortest_distance


def bidirectional_dij(
source: str, destination: str, graph_forward: dict, graph_backward: dict
) -> int:
Expand Down Expand Up @@ -51,53 +77,36 @@ def bidirectional_dij(
if source == destination:
return 0

while queue_forward and queue_backward:
while not queue_forward.empty():
_, v_fwd = queue_forward.get()

if v_fwd not in visited_forward:
break
else:
break
while not queue_forward.empty() and not queue_backward.empty():
_, v_fwd = queue_forward.get()
visited_forward.add(v_fwd)

while not queue_backward.empty():
_, v_bwd = queue_backward.get()

if v_bwd not in visited_backward:
break
else:
break
_, v_bwd = queue_backward.get()
visited_backward.add(v_bwd)

# forward pass and relaxation
for nxt_fwd, d_forward in graph_forward[v_fwd]:
if nxt_fwd in visited_forward:
continue
old_cost_f = cst_fwd.get(nxt_fwd, np.inf)
new_cost_f = cst_fwd[v_fwd] + d_forward
if new_cost_f < old_cost_f:
queue_forward.put((new_cost_f, nxt_fwd))
cst_fwd[nxt_fwd] = new_cost_f
parent_forward[nxt_fwd] = v_fwd
if nxt_fwd in visited_backward:
if cst_fwd[v_fwd] + d_forward + cst_bwd[nxt_fwd] < shortest_distance:
shortest_distance = cst_fwd[v_fwd] + d_forward + cst_bwd[nxt_fwd]

# backward pass and relaxation
for nxt_bwd, d_backward in graph_backward[v_bwd]:
if nxt_bwd in visited_backward:
continue
old_cost_b = cst_bwd.get(nxt_bwd, np.inf)
new_cost_b = cst_bwd[v_bwd] + d_backward
if new_cost_b < old_cost_b:
queue_backward.put((new_cost_b, nxt_bwd))
cst_bwd[nxt_bwd] = new_cost_b
parent_backward[nxt_bwd] = v_bwd

if nxt_bwd in visited_forward:
if cst_bwd[v_bwd] + d_backward + cst_fwd[nxt_bwd] < shortest_distance:
shortest_distance = cst_bwd[v_bwd] + d_backward + cst_fwd[nxt_bwd]
shortest_distance = pass_and_relaxation(
graph_forward,
v_fwd,
visited_forward,
visited_backward,
cst_fwd,
cst_bwd,
queue_forward,
parent_forward,
shortest_distance,
)

shortest_distance = pass_and_relaxation(
graph_backward,
v_bwd,
visited_backward,
visited_forward,
cst_bwd,
cst_fwd,
queue_backward,
parent_backward,
shortest_distance,
)

if cst_fwd[v_fwd] + cst_bwd[v_bwd] >= shortest_distance:
break
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ show-source = true
target-version = "py311"

[tool.ruff.mccabe] # DO NOT INCREASE THIS VALUE
max-complexity = 20 # default: 10
max-complexity = 17 # default: 10

[tool.ruff.pylint] # DO NOT INCREASE THESE VALUES
max-args = 10 # default: 5
Expand Down