# Chapter 4: Matrix-Matrix Multiply

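Entry (i, j) of A² counts the 2-hop paths from node i to node j, and entry (i, j) of A³ counts the 3-hop paths. The transitive closure records every pair (i, j) for which j is reachable from i.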

In [None]:
import graphblas as gb
from graphblas import Matrix, semiring, binary
import networkx as nx
import matplotlib.pyplot as plt

## Graph Setup

In [None]:
# Directed edges (source, target) of the 4-node example graph.
edges = [(0, 1), (1, 2), (2, 3), (0, 2)]
rows = tuple(src for src, _ in edges)
cols = tuple(dst for _, dst in edges)

# Every edge is stored with value 1, so A[i, j] == 1 exactly when i -> j
# is an edge; all other positions are left empty.
A = Matrix.from_coo(
    rows,
    cols,
    [1] * len(edges),
    dtype=int,
    nrows=4,
    ncols=4,
)
print("Adjacency matrix A:", A, sep="\n")

In [None]:
# Build the directed graph from the edge list and pin every node to a
# fixed (x, y) position so the layout is the same on each run.
G = nx.DiGraph(edges)
pos = {
    0: (0, 1),
    1: (1, 1),
    2: (1, 0),
    3: (2, 0),
}
nx.draw(
    G,
    pos,
    arrows=True,
    with_labels=True,
    node_color='lightblue',
    node_size=500,
    font_size=16,
)
plt.show()

## A² = 2-Hop Paths

In [None]:
# Multiply A by itself over the ordinary (+, *) semiring: entry (i, j) of
# the product sums A[i, k] * A[k, j] over k, i.e. it counts the paths
# i -> k -> j.
A2 = A.mxm(A, semiring.plus_times).new()
print("A² (2-hop path counts):", A2, sep="\n")
print(
    "\nA²[0,2] = 1: one 2-hop path from 0 to 2 (via 1)",
    "A²[0,3] = 1: one 2-hop path from 0 to 3 (via 2)",
    sep="\n",
)

## A³ = 3-Hop Paths

In [None]:
# A³ = A² · A: extending every 2-hop path by one more edge yields the
# 3-hop path counts.
A3 = A2.mxm(A, semiring.plus_times).new()
print("A³ (3-hop path counts):")
print(A3)
# With edges 0->1, 1->2, 2->3, 0->2 the only 3-hop path is 0 -> 1 -> 2 -> 3,
# so A³ has the single entry A³[0,3] = 1.
print("\nA³[0,3] = 1: one 3-hop path from 0 to 3")
print("  Path: 0 -> 1 -> 2 -> 3")
print("  (0 -> 2 -> 3 has only 2 hops, so it is counted in A², not A³)")

## Transitive Closure

In [None]:
def transitive_closure(A):
    """Compute the reflexive-transitive closure by repeated squaring.

    Args:
        A: Square adjacency ``Matrix``. A stored entry at ``(i, j)`` is
            treated as an edge ``i -> j`` regardless of its value.

    Returns:
        A new boolean ``Matrix`` with an entry at ``(i, j)`` whenever ``j``
        is reachable from ``i`` along zero or more edges, so every node
        reaches itself.

    Raises:
        ValueError: If ``A`` is not square.
    """
    # Local import keeps the function self-contained; the file's top-level
    # imports do not bring ``unary`` into scope.
    from graphblas import unary

    n = A.nrows
    if A.ncols != n:
        raise ValueError(
            f"transitive_closure requires a square matrix, got {n}x{A.ncols}"
        )

    # Keep only the sparsity pattern: ``unary.one`` maps every stored entry
    # to 1, which becomes True in the boolean result.  (``apply`` with a
    # BinaryOp such as ``binary.pair`` needs a ``left``/``right`` scalar.)
    TC = A.apply(unary.one).new(dtype=bool)

    # Add self-loops (identity) so paths of length zero are included.
    for i in range(n):
        TC[i, i] = True

    # Repeated squaring until convergence.  Because TC contains the
    # diagonal, squaring can only add entries, never drop them, so an
    # unchanged entry count means the pattern itself is unchanged.
    prev_nvals = 0
    while TC.nvals != prev_nvals:
        prev_nvals = TC.nvals
        TC = TC.mxm(TC, semiring.any_pair).new()

    return TC

# Run the closure on the example graph and display the resulting matrix.
TC = transitive_closure(A)
print("Transitive closure:", TC, sep="\n")
print("\nTC[i,j] = True means j is reachable from i")

In [None]:
# Visualize reachability from a source node next to the original graph.
source = 0
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.set_title("Original Graph")
nx.draw(G, pos, with_labels=True, node_color='lightblue',
        node_size=500, font_size=16, arrows=True, ax=ax1)

ax2.set_title(f"Reachability from Node {source}")
# Scan the source row of TC; the width comes from TC itself rather than a
# hard-coded node count.  A missing entry reads as False (not reachable).
reachable = [j for j in range(TC.ncols) if TC[source, j].get(False)]
# nx.draw pairs node_color entries with nodes in G's iteration order, so
# the color list is built in that same order.
colors = ['lightgreen' if node in reachable else 'lightgray'
          for node in G.nodes]
nx.draw(G, pos, with_labels=True, node_color=colors,
        node_size=500, font_size=16, arrows=True, ax=ax2)
plt.show()