# Chapter 3: Masking and BFS

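Masks control which output elements an operation is allowed to compute and write. A complement mask (e.g. `~visited.S`) selects exactly the positions *not* present in `visited`, so a BFS can exclude already-visited nodes from each new frontier without ever computing them.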

In [None]:
import graphblas as gb
from graphblas import Matrix, Vector, semiring, binary
import networkx as nx
import matplotlib.pyplot as plt

## Graph Setup

In [None]:
# Slightly larger graph for BFS: 6 nodes, 7 edges listed once each.
edges = [(0, 1), (0, 2), (1, 3), (2, 3), (2, 4), (3, 5), (4, 5)]
rows, cols = zip(*edges)

# One-directional boolean adjacency matrix built from the edge list.
A = Matrix.from_coo(
    rows,
    cols,
    [True for _ in edges],
    nrows=6,
    ncols=6,
    dtype=bool,
)
# Symmetrize (union with the transpose) so every edge runs both ways.
A = A.ewise_add(A.T, binary.any).new()
print(A)

In [None]:
# networkx mirror of the same edge list, used only for plotting.
G = nx.Graph()
G.add_edges_from(edges)
pos = nx.spring_layout(G, seed=42)  # fixed seed -> reproducible layout

nx.draw(
    G,
    pos,
    with_labels=True,
    node_color='lightblue',
    node_size=500,
    font_size=16,
)
plt.title("Undirected Graph")
plt.show()

## BFS Without Masking (Wrong)

In [None]:
# Starting from node 0
frontier = Vector.from_coo([0], [True], size=A.nrows, dtype=bool)

# .tolist() gives plain Python ints (list(ndarray) would print NumPy scalar reprs).
print("Level 0:", frontier.to_coo()[0].tolist())
for i in range(1, 4):
    # A Vector has no mxv; vxm multiplies the row vector by A, i.e. it
    # follows out-edges from every frontier node. Nothing masks out
    # already-seen nodes, which is the point of this demo.
    frontier = frontier.vxm(A, semiring.any_pair).new()
    print(f"Level {i}:", frontier.to_coo()[0].tolist())

print("\nProblem: nodes get revisited!")

## BFS With Complement Mask

In [None]:
def bfs_levels(A, source):
    """BFS returning level of each node from source.

    Prints the node indices of each frontier as it is reached.

    Args:
        A: Square adjacency Matrix; an entry A[i, j] is followed as the
            edge i -> j. Only its structure is used (``any_pair`` ignores
            the stored values).
        source: Index of the starting node.

    Returns:
        An int Vector of size ``A.nrows`` whose entry k is the hop distance
        from ``source`` to node k (``source`` itself is 0). Nodes that are
        not reachable have no entry.
    """
    n = A.nrows
    levels = Vector(int, size=n)
    frontier = Vector.from_coo([source], [True], size=n, dtype=bool)

    level = 0
    while frontier.nvals > 0:
        # Record level for nodes in current frontier
        levels(mask=frontier.S) << level
        print(f"Level {level}: nodes {frontier.to_coo()[0].tolist()}")

        # Expand to neighbours along out-edges (row vector times A). The
        # complement mask ~levels.S goes straight into the product, so nodes
        # that already have a level are never computed at all.
        frontier = frontier.vxm(A, semiring.any_pair).new(mask=~levels.S)
        level += 1

    return levels

# Run the level-tracking BFS from node 0, then show the resulting vector.
levels = bfs_levels(A, source=0)
print("\nFinal levels:", levels, sep="\n")

## BFS With Parent Tracking

In [None]:
def bfs_parents(A, source):
    """BFS returning parent of each node in BFS tree.

    Args:
        A: Square adjacency Matrix; an entry A[k, j] is followed as the
            edge k -> j. Only its structure is used.
        source: Index of the starting node.

    Returns:
        An int Vector of size ``A.nrows`` where entry j holds the index of
        j's parent in the BFS tree. ``source`` is its own parent, and
        unreachable nodes have no entry. When several frontier nodes could
        be the parent, the smallest index is chosen.
    """
    n = A.nrows
    parents = Vector(int, size=n)
    parents[source] = source  # Source is its own parent

    # Only the frontier's *structure* matters below: the positional
    # semiring reads each frontier entry's index, not its stored value.
    frontier = Vector.from_coo([source], [source], size=n, dtype=int)

    while frontier.nvals > 0:
        # For each edge k -> j with k in the frontier, secondi(frontier[k],
        # A[k, j]) is k, the row index of the A entry. So j receives the
        # index of the frontier node that reached it, and min keeps the
        # smallest one. The original min_second copied frontier[k]'s value,
        # i.e. k's own parent, which assigned grandparents instead.
        # The complement mask drops already-visited nodes during the product.
        frontier = frontier.vxm(A, semiring.min_secondi).new(mask=~parents.S)

        # Record parents
        parents(mask=frontier.S) << frontier

    return parents

# Build the BFS tree from node 0 and show the parent vector.
parents = bfs_parents(A, source=0)
print("BFS parents:", parents, sep="\n")

In [None]:
# Visualize BFS tree
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

# Original graph
ax1.set_title("Original Graph")
nx.draw(G, pos, with_labels=True, node_color='lightblue',
        node_size=500, font_size=16, ax=ax1)

# BFS tree: one edge parent -> child for every visited node except the source
ax2.set_title("BFS Tree (from node 0)")
tree = nx.DiGraph()
p_indices, p_values = parents.to_coo()
# Convert NumPy scalars (uint64 indices, int64 values) to plain ints so the
# tree's node keys match the int keys used in ``pos``.
children = p_indices.tolist()
parent_ids = p_values.tolist()
# Add every visited node up front, so nodes without tree edges
# (e.g. a source with no neighbours) are still drawn.
tree.add_nodes_from(children)
for child, parent in zip(children, parent_ids):
    if child != parent:
        tree.add_edge(parent, child)
nx.draw(tree, pos, with_labels=True, node_color='lightgreen',
        node_size=500, font_size=16, arrows=True, ax=ax2)
plt.show()