In [4]:
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from copy import deepcopy


def extend_shortest_paths(L, W):
    """Computes the min-plus matrix multiplication of L and W.

    This is EXTEND-SHORTEST-PATHS from Cormen et al.: entry ``(i, j)`` of the
    result is ``min over k of L[i, k] + W[k, j]``.  Missing edges are encoded
    as ``np.inf``; entries are expected to be finite or ``+inf``.

    Parameters
    ----------
    L : array_like
        n x n matrix of current shortest-path estimates.
    W : array_like
        n x n edge-weight matrix.

    Returns
    -------
    numpy.ndarray
        A new n x n float64 array; neither input is modified.
    """
    L = np.asarray(L, dtype=float)
    W = np.asarray(W, dtype=float)
    n = len(L)
    # Every entry starts at +inf (no path found yet).
    L_prime = np.full((n, n), np.inf)
    # Vectorize over all (i, j) pairs; only the reduction index k stays a
    # Python loop.  L[:, k, None] is column k as an (n, 1) array and
    # W[None, k, :] is row k as a (1, n) array, so their sum broadcasts to
    # the (n, n) matrix of candidates L[i, k] + W[k, j].
    for k in range(n):
        np.minimum(L_prime, L[:, k, None] + W[None, k, :], out=L_prime)
    return L_prime


def slow_all_pairs_shortest_paths(W):
    """Computes the shortest paths between all pairs of vertices in a graph.

    This is SLOW-ALL-PAIRS-SHORTEST-PATHS from Cormen et al.  Starting from
    L^(1) = W, it applies ``extend_shortest_paths(L, W)`` n - 2 times to reach
    L^(n-1).  Entry (i, j) is then the minimum weight of any i -> j path with
    at most n - 1 edges.  That is the true shortest-path weight provided the
    graph has no negative-weight cycle.

    Parameters
    ----------
    W : array_like
        n x n edge-weight matrix, with ``np.inf`` for missing edges.

    Returns
    -------
    numpy.ndarray
        n x n float matrix L^(n-1).  Always a new array, even when n <= 2 and
        no extension step runs, so callers may modify it without touching W.
    """
    n = len(W)
    # Copy (as float) so the n <= 2 case does not hand back the caller's own
    # array, and the return dtype does not depend on n.
    L = np.array(W, dtype=float)
    for _ in range(n - 2):
        L = extend_shortest_paths(L, W)
    return L


def faster_all_pairs_shortest_paths(W):
    """Uses the repeated squaring method to compute the shortest paths between all pairs of vertices in a graph.

    This is FASTER-ALL-PAIRS-SHORTEST-PATHS from Cormen et al.  Beginning with
    L^(1) = W, the current matrix is squared under min-plus multiplication
    (``extend_shortest_paths(L, L)``).  Each squaring doubles the number of
    edges covered, and the loop stops once at least n - 1 edges are covered.

    Parameters
    ----------
    W : array_like
        n x n edge-weight matrix, with ``np.inf`` for missing edges.

    Returns
    -------
    L : numpy.ndarray
        The last matrix computed, L^(m) for the first power of two m >= n - 1
        (a deep copy of W when n <= 2).
    history : list
        Every matrix produced, in order: L^(1), L^(2), L^(4), ...  The first
        element is the deep copy of W, and the last element is ``L`` itself.
    """
    target_edges = len(W) - 1
    current = deepcopy(W)
    snapshots = [current]
    edges_covered = 1
    while edges_covered < target_edges:
        current = extend_shortest_paths(current, current)
        snapshots.append(current)
        edges_covered *= 2
    return current, snapshots

In [5]:
# Initial weight matrix from Figure 23.2 of Cormen et al.
# np.inf marks a missing edge; W[i, i] = 0.
graph = np.array([
    [0, np.inf, np.inf, np.inf, -1, np.inf],
    [1, 0, np.inf, 2, np.inf, np.inf],
    [np.inf, 2, 0, np.inf, np.inf, -8],
    [-4, np.inf, np.inf, 0, 3, np.inf],
    [np.inf, 7, np.inf, np.inf, 0, np.inf],
    [np.inf, 5, 10, np.inf, np.inf, 0]
])

# Compute the shortest paths between all pairs of vertices
L_slow = slow_all_pairs_shortest_paths(graph)
L_fast, history = faster_all_pairs_shortest_paths(graph)

# Cross-check: without a negative-weight cycle, L^(n-1) (slow method) and the
# repeated-squaring result L^(m), m >= n - 1, must be identical.
np.testing.assert_array_equal(L_slow, L_fast)
# A negative-weight cycle through vertex i would push L[i, i] below 0.
assert np.all(np.diag(L_fast) == 0), "graph contains a negative-weight cycle"

# Print the history: history[i] is L^(2**i)
for i, L in enumerate(history):
    print(f"m = {2**i}")
    print(L)
    print()

m = 1
[[ 0. inf inf inf -1. inf]
 [ 1.  0. inf  2. inf inf]
 [inf  2.  0. inf inf -8.]
 [-4. inf inf  0.  3. inf]
 [inf  7. inf inf  0. inf]
 [inf  5. 10. inf inf  0.]]

m = 2
[[ 0.  6. inf inf -1. inf]
 [-2.  0. inf  2.  0. inf]
 [ 3. -3.  0.  4. inf -8.]
 [-4. 10. inf  0. -5. inf]
 [ 8.  7. inf  9.  0. inf]
 [ 6.  5. 10.  7. inf  0.]]

m = 4
[[ 0.  6. inf  8. -1. inf]
 [-2.  0. inf  2. -3. inf]
 [-5. -3.  0. -1. -3. -8.]
 [-4.  2. inf  0. -5. inf]
 [ 5.  7. inf  9.  0. inf]
 [ 3.  5. 10.  7.  2.  0.]]

m = 8
[[ 0.  6. inf  8. -1. inf]
 [-2.  0. inf  2. -3. inf]
 [-5. -3.  0. -1. -6. -8.]
 [-4.  2. inf  0. -5. inf]
 [ 5.  7. inf  9.  0. inf]
 [ 3.  5. 10.  7.  2.  0.]]

