Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions graphs/johnson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import heapq
from collections.abc import Hashable

Node = Hashable
edge = tuple[Node, Node, float]
adjacency = dict[Node, list[tuple[Node, float]]]


def _collect_nodes_and_edges(graph: adjacency) -> tuple[list[Node], list[edge]]:
nodes = set()
edges: list[edge] = []
for u, neighbors in graph.items():
nodes.add(u)
for v, w in neighbors:
nodes.add(v)
edges.append((u, v, w))
return list(nodes), edges


def _bellman_ford(nodes: list[Node], edges: list[edge]) -> dict[Node, float]:
"""
Bellman-Ford relaxation to compute potentials h[v] for all vertices.
Raises ValueError if a negative weight cycle exists.
"""
dist: dict[Node, float] = dict.fromkeys(nodes, 0.0)
n = len(nodes)

for _ in range(n - 1):
updated = False
for u, v, w in edges:
if dist[u] + w < dist[v]:
dist[v] = dist[u] + w
updated = True
if not updated:
break
else:
for u, v, w in edges:
if dist[u] + w < dist[v]:
raise ValueError("Negative weight cycle detected")
return dist


def _dijkstra(
start: Node,
nodes: list[Node],
graph: adjacency,
potentials: dict[Node, float],
) -> dict[Node, float]:
"""
Dijkstra over reweighted graph, using potentials h to make weights non-negative.
Returns distances from start in the reweighted space.
"""
inf = float("inf")
dist: dict[Node, float] = dict.fromkeys(nodes, inf)
dist[start] = 0.0
heap: list[tuple[float, Node]] = [(0.0, start)]

while heap:
d_u, u = heapq.heappop(heap)
if d_u > dist[u]:
continue
for v, w in graph.get(u, []):
w_prime = w + potentials[u] - potentials[v]
if w_prime < 0:
raise ValueError(
"Negative edge weight after reweighting: numeric error"
)
new_dist = d_u + w_prime
if new_dist < dist[v]:
dist[v] = new_dist
heapq.heappush(heap, (new_dist, v))
return dist


def johnson(graph: adjacency) -> dict[Node, dict[Node, float]]:
"""
Compute all-pairs shortest paths using Johnson's algorithm.
Args:
graph: adjacency list {u: [(v, weight), ...], ...}
Returns:
dict of dicts: dist[u][v] = shortest distance from u to v
Raises:
ValueError: if a negative weight cycle is detected
Example:
>>> g = {
... 0: [(1, 3), (2, 8), (4, -4)],
... 1: [(3, 1), (4, 7)],
... 2: [(1, 4)],
... 3: [(0, 2), (2, -5)],
... 4: [(3, 6)],
... }
>>> round(johnson(g)[0][3], 2)
2.0
"""
nodes, edges = _collect_nodes_and_edges(graph)
potentials = _bellman_ford(nodes, edges)

all_pairs: dict[Node, dict[Node, float]] = {}
inf = float("inf")
for s in nodes:
dist_reweighted = _dijkstra(s, nodes, graph, potentials)
dists_orig: dict[Node, float] = {}
for v in nodes:
d_prime = dist_reweighted[v]
if d_prime < inf:
dists_orig[v] = d_prime - potentials[s] + potentials[v]
else:
dists_orig[v] = inf
all_pairs[s] = dists_orig

return all_pairs
24 changes: 24 additions & 0 deletions graphs/tests/test_johnson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import math

import pytest

from graphs.johnson import johnson


def test_johnson_basic():
g = {
0: [(1, 3), (2, 8), (4, -4)],
1: [(3, 1), (4, 7)],
2: [(1, 4)],
3: [(0, 2), (2, -5)],
4: [(3, 6)],
}
dist = johnson(g)
assert math.isclose(dist[0][3], 2.0, abs_tol=1e-9)
assert math.isclose(dist[3][2], -5.0, abs_tol=1e-9)


def test_johnson_negative_cycle():
g2 = {0: [(1, 1)], 1: [(0, -3)]}
with pytest.raises(ValueError):
johnson(g2)