In [None]:
import time
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter
from tqdm import tqdm

from nnx.data_structures.union_find import BetterUnionFind, NaiveUnionFind

# One shared, seeded generator so every experiment below is reproducible.
seed = 3
rng = np.random.default_rng(seed)

# Outline
-----
Initially, we compare two algorithms for Union-Find. As a starting point, we implement a naive version, also known as Quick-Find, which in the worst case needs $O(N^{2})$ time for $N$ union operations on $N$ elements. Afterwards, we optimize it in two ways. First, we weight the unions, which prevents trees of unnecessarily large depth. Second, we use path compression. Together these reduce the worst case to $O(N + M \log^{*}(N))$, where $M$ is the number of operations.

![union_find_complexity.png](images/union_find/union_find_complexity.png)

Finally, we move on to a practical application of this data structure, namely estimating the percolation threshold using Monte Carlo methods.

In [None]:
def generate_operations(num_elements: int) -> list[tuple[int, int]]:
    """Generate a sequence of random merge operations for Union Find testing.

    Args:
        num_elements: Number of elements for union find.

    Returns:
        List of (node1, node2) tuples representing merge operations.

    """
    operations = []
    for _ in range(num_elements):
        # Generate random pair of nodes to merge
        a = rng.integers(0, num_elements - 1)
        b = rng.integers(0, num_elements - 1)

        operations.append((a, b))

    return operations

# Calculate Differences in Timing
------
For the same configuration, we can sample different operations and measure the time needed to perform them. Since we are interested in asymptotic runtimes, we sample different constellations for the same hyperparameters (in this case `num_elements`) to reduce the effect of noise. With a large `num_trials`, we generate many runs for the same hyperparameter configuration and get close to the asymptotic complexity.

To not overcomplicate matters, we run it for `num_trials=50` and report the median.

In [None]:
def time_algo(
    algo: NaiveUnionFind | BetterUnionFind,
    operations: list,
) -> float:
    """Measure how long ``algo`` takes to apply a batch of union operations.

    Every entry of ``operations`` is unpacked into ``algo.union`` in order.
    The elapsed wall-clock time is taken with ``time.perf_counter``.

    Args:
        algo: the algorithm to benchmark
        operations: set of operations to perform

    Returns:
        The time needed to perform all operations, in seconds.

    """
    t_begin = time.perf_counter()
    for pair in operations:
        algo.union(*pair)
    elapsed = time.perf_counter() - t_begin
    return elapsed


times_naive = []
times_better = []

sizes = range(10, int(2.5e3), 10)
num_trials = 50  # lower this for faster but noisier measurements

for num_elements in tqdm(sizes):
    # Timing samples for the current size. Both variants replay the same
    # randomly generated operations, so they are compared like for like.
    sampling_naive, sampling_better = [], []
    for _ in range(num_trials):
        operations = generate_operations(num_elements=num_elements)
        naive_seconds = time_algo(NaiveUnionFind(num_elements), operations)
        sampling_naive.append(naive_seconds)
        better_seconds = time_algo(BetterUnionFind(num_elements), operations)
        sampling_better.append(better_seconds)

    # Median (0.5 quantile) instead of the mean, to damp the influence of
    # occasional outlier runs.
    times_naive.append(np.quantile(sampling_naive, 0.5))
    times_better.append(np.quantile(sampling_better, 0.5))


In [None]:
def _log_star(n: int) -> float:
    if n <= 1:
        return 0
    count = 0
    while n > 1:
        n = np.log(n)
        count += 1
    return count


def theoretical_complexity(
    n: int,
    complexity_fn: Callable[[int], float],
    scale_factor: float,
) -> float:
    """Evaluate a complexity function at ``n`` and rescale it to time units.

    Args:
        n: Input size
        complexity_fn: Function that calculates the complexity (e.g., n^2)
        scale_factor: Scaling factor to match with actual times

    Returns:
        Theoretical time, i.e. ``complexity_fn(n) * scale_factor``.

    """
    raw_cost = complexity_fn(n)
    return raw_cost * scale_factor


def plot_union_find_complexity(
    sizes: list[int],
    times_naive: list[float],
    times_better: list[float],
) -> None:
    """Create a plot comparing actual times with theoretical complexity.

    The measured times of both Union-Find variants are drawn on a logarithmic
    y-axis together with their theoretical curves: ``n**2`` for the naive
    variant and ``n + n * log*(n)`` for the weighted variant with path
    compression. Each theoretical curve is scaled by the mean ratio of
    measured time to complexity value.

    When both timing series are non-empty, the speedup (naive / better) is
    drawn on a secondary y-axis and its maximum is annotated. The figure is
    saved to ``union_find_complexity.png`` and shown.

    Args:
        sizes: List of problem sizes
        times_naive: Execution times for naive algorithm
        times_better: Execution times for better algorithm

    """
    sizes_array = np.array(sizes)
    # The speedup curve and its annotation need both timing series.
    has_timings = len(times_naive) > 0 and len(times_better) > 0

    _, ax = plt.subplots(figsize=(12, 8))

    ax.plot(sizes, times_naive, "b-", label="Naive Union Find (Actual)", linewidth=2)
    ax.plot(
        sizes,
        times_better,
        "g-",
        label="Weighted Union Find with Path Compression (Actual)",
        linewidth=2,
    )

    def _naive_complexity_fn(n: int) -> float:
        return n**2

    def _better_complexity_fn(n: int) -> float:
        # The benchmark performs m = n unions, so n + m * log*(n) with m = n.
        return n + n * _log_star(n)

    naive_complexity_values = np.array([_naive_complexity_fn(n) for n in sizes_array])
    better_complexity_values = np.array([_better_complexity_fn(n) for n in sizes_array])

    # Fit each theoretical curve to the measurements with a single constant:
    # the mean ratio of measured time to complexity value.
    # For naive: O(n**2)
    naive_scale = (
        np.mean(np.array(times_naive) / naive_complexity_values)
        if len(times_naive) > 0
        else 1e-10
    )

    # For better: O(n + m log* n)
    better_scale = (
        np.mean(np.array(times_better) / better_complexity_values)
        if len(times_better) > 0
        else 1e-10
    )

    # Generate theoretical curves with the scaling factors
    theo_naive = naive_complexity_values * naive_scale
    theo_better = better_complexity_values * better_scale

    ax.plot(
        sizes,
        theo_naive,
        "b--",
        label="Naive Union Find (Theoretical O(n²))",
        alpha=0.7,
    )
    ax.plot(
        sizes,
        theo_better,
        "g--",
        label="Weighted UF with Path Compression (Theoretical O(n + m log* n))",
        alpha=0.7,
    )

    # Compute the speedup once; it is reused for the annotation and the twin axis.
    speedup = None
    if has_timings:
        speedup = np.array(times_naive) / np.array(times_better)
        max_speedup_idx = int(np.argmax(speedup))
        max_speedup = speedup[max_speedup_idx]
        max_speedup_size = sizes[max_speedup_idx]

        # Add annotation for maximum speedup
        ax.annotate(
            f"Max Speedup: {max_speedup:.2f}x at n={max_speedup_size}",
            xy=(max_speedup_size, times_naive[max_speedup_idx]),
            xytext=(max_speedup_size + 300, times_naive[max_speedup_idx] * 1.2),
            arrowprops={
                "facecolor": "black",
                "shrink": 0.05,
                "width": 1.5,
                "headwidth": 8,
            },
            fontsize=10,
        )

    # Add log scale for y-axis to better visualize differences
    ax.set_yscale("log")

    ax.grid(visible=True, linestyle="--", alpha=0.7)

    # Add labels and title
    ax.set_xlabel("Number of Elements (n)", fontsize=12)
    ax.set_ylabel("Time (seconds)", fontsize=12)
    ax.set_title("Union Find Algorithms: Time Complexity Comparison", fontsize=14)

    # Format y-axis to show actual time values
    ax.yaxis.set_major_formatter(ScalarFormatter())

    handles, labels = ax.get_legend_handles_labels()
    legend_ax = ax
    # Add a secondary y-axis for showing the speedup
    if speedup is not None:
        ax2 = ax.twinx()
        ax2.plot(sizes, speedup, "r-", label="Speedup Factor", alpha=0.6)
        ax2.set_ylabel("Speedup Factor (Naive / Better)", color="r", fontsize=12)
        ax2.tick_params(axis="y", labelcolor="r")
        twin_handles, twin_labels = ax2.get_legend_handles_labels()
        handles += twin_handles
        labels += twin_labels
        # The twin axes is created after ``ax``, so a legend attached to it
        # is not hidden behind the speedup curve.
        legend_ax = ax2

    # One combined legend: two independent legends on twinned axes are each
    # placed without knowledge of the other and can overlap.
    legend_ax.legend(handles, labels, fontsize=10)

    plt.tight_layout()
    plt.savefig("union_find_complexity.png", dpi=300)
    plt.show()


# Render the benchmark results gathered above.
plot_union_find_complexity(sizes=sizes, times_naive=times_naive, times_better=times_better)

# Monte Carlo Simulation for Percolation
------
## Overview
Percolation theory studies how connectivity emerges in random networks. This simulation implements site percolation on a 2D square lattice using Monte Carlo methods.

## Methodology
The simulation works by:
1. Starting with an empty grid where all nodes are inactive
2. Randomly opening nodes one by one
3. Using a Union-Find data structure to efficiently track connected clusters
4. Determining the critical threshold when a connected path forms from top to bottom
5. Repeating this process multiple times to obtain statistical estimates

## Percolation Threshold
The percolation threshold is the fraction of active (open) nodes at which a connected path from top to bottom appears. In the implementation, this is measured as the ratio of active nodes to total nodes at the moment of percolation. Increasing the grid size leads to a sharper transition. For an infinite 2D square lattice, the theoretical threshold is approximately 0.593.

![percolation_probabilities](images/union_find/percolation_probabilities.png)

## Implementation Notes
This code uses a weighted quick-union algorithm with path compression for efficient cluster identification. Virtual top and bottom nodes are used to simplify the detection of a percolating path.

In [None]:
def translate_index(row: int, col: int, grid_size: int) -> int:
    """Translate row and col from grid to flattened (row-major) index.

    Negative indices count from the end, as in Python sequence indexing.
    Valid values for ``row`` and ``col`` therefore lie in
    ``[-grid_size, grid_size - 1]``.

    Args:
        row: Row index of the site.
        col: Column index of the site.
        grid_size: Side length of the square grid.

    Raises:
        RuntimeError: if ``row`` or ``col`` lies outside the allowed range.

    Returns:
        Translated index in ``[0, grid_size**2 - 1]``.

    """
    # ``abs(x) > grid_size`` would wrongly accept ``x == grid_size``, which
    # maps into the next row (or past the end of the flattened grid).
    if not (-grid_size <= row < grid_size and -grid_size <= col < grid_size):
        msg = (
            f"({row}, {col}) exceeds the allowed range of "
            f"[{-grid_size}, {grid_size - 1}]."
        )
        raise RuntimeError(msg)
    row = row if row >= 0 else grid_size + row
    col = col if col >= 0 else grid_size + col

    return grid_size * row + col


def simulate_experiment(grid_size: int) -> float:
    """Simulate a single trial.

    Returns:
        The percolation threshold which is the ratio of black and white.

    """
    grid = np.zeros(shape=(grid_size, grid_size))
    algo = BetterUnionFind(grid_size * grid_size + 2)  # rember the virtual nodes

    entry_node = grid_size * grid_size
    exit_node = grid_size * grid_size + 1

    # Connect virtual top to first row and virtual bottom to last row
    for col_idx in range(grid_size):
        algo.union(entry_node, translate_index(0, col_idx, grid_size))
        algo.union(exit_node, translate_index(grid_size - 1, col_idx, grid_size))

    ops = [(0, 1), (0, -1), (-1, 0), (1, 0)]
    while not algo.connected(entry_node, exit_node):
        while True:
            idx = rng.integers(0, grid_size)
            idy = rng.integers(0, grid_size)
            if not grid[idx, idy]:
                break

        grid[idx, idy] = True
        grid_idx = translate_index(idx, idy, grid_size)

        for cx, cy in ops:
            idcx = idx + cx
            idcy = idy + cy
            # only take valid and open neighbours
            if (0 <= idcx < grid_size) and (0 <= idcy < grid_size) and grid[idcx, idcy]:
                neighbor_idx = translate_index(idcx, idcy, grid_size)
                algo.union(grid_idx, neighbor_idx)

    return np.sum(grid) / grid_size**2


def monte_carlo_simulation(
    num_trials: int,
    grid_sizes: tuple[int, ...] | list[int] = (5, 25, 50, 100),
) -> dict[int, list[float]]:
    """Execute the Monte Carlo simulation.

    Args:
        num_trials: amount of runs for a single configuration. Must be positive.
        grid_sizes: side lengths of the square grids to simulate. Defaults to
            ``(5, 25, 50, 100)``.

    Returns:
        Dict mapping each grid size to the list of its ``num_trials``
        percolation threshold estimates.

    Raises:
        ValueError: if ``num_trials`` is not positive; empty histories would
            make downstream statistics divide by zero.

    """
    if num_trials < 1:
        msg = f"num_trials must be positive, got {num_trials}."
        raise ValueError(msg)

    return {
        grid_size: [simulate_experiment(grid_size=grid_size) for _ in range(num_trials)]
        for grid_size in grid_sizes
    }

In [None]:
def plot_percolation_probability_curve(
    results: dict[int, list[float]],
    num_thresholds: int = 100,
) -> None:
    """Plot percolation probability vs site occupation probability.

    For each grid size, the percolation probability at occupation ``p`` is the
    empirical CDF of the measured thresholds, i.e. the fraction of trials whose
    threshold is ``<= p``. The figure is saved to
    ``percolation_probabilities.png`` and shown.

    Args:
        results: dictionary containing the histories for each grid size.
        num_thresholds: number of evenly spaced occupation probabilities in
            ``[0, 1]`` at which the curves are evaluated.

    Raises:
        ValueError: if a grid size has no recorded trials.

    """
    probability_thresholds = np.linspace(0, 1, num_thresholds)

    percolation_probabilities = {}
    for size, thresholds in results.items():
        if len(thresholds) == 0:
            msg = f"No trials recorded for grid size {size}."
            raise ValueError(msg)
        # searchsorted(side="right") on the sorted thresholds counts, for all
        # p at once, how many trials had a threshold <= p.
        sorted_thresholds = np.sort(thresholds)
        num_percolating = np.searchsorted(
            sorted_thresholds, probability_thresholds, side="right"
        )
        percolation_probabilities[size] = num_percolating / len(thresholds)

    _, ax = plt.subplots(figsize=(10, 7))

    theoretical_threshold = 0.592746  # Theoretical threshold for 2D site percolation

    colors = plt.cm.viridis(np.linspace(0, 1, len(results)))

    for i, size in enumerate(results):
        ax.plot(
            probability_thresholds,
            percolation_probabilities[size],
            "o-",
            color=colors[i],
            label=f"L = {size}",
            markersize=4,
            linewidth=2,
            alpha=0.8,
        )

    ax.axvline(
        x=theoretical_threshold,
        color="red",
        linestyle="--",
        label=f"Theoretical threshold ≈ {theoretical_threshold:.4f}",
    )

    ax.axhline(y=0, color="gray", linestyle="-", alpha=0.3)
    ax.axhline(y=1, color="gray", linestyle="-", alpha=0.3)

    ax.set_xlabel("Site Occupation Probability (p)", fontsize=12)
    ax.set_ylabel("Percolation Probability", fontsize=12)
    ax.set_title("Percolation Probability vs. Site Occupation Probability", fontsize=14)
    ax.grid(visible=True, alpha=0.3)
    ax.legend(fontsize=10)

    plt.tight_layout()
    plt.savefig("percolation_probabilities.png", dpi=300, bbox_inches="tight")
    # Display the figure, consistent with plot_union_find_complexity.
    plt.show()


In [None]:
# 100 independent trials per grid size, then plot the empirical curves.
results = monte_carlo_simulation(num_trials=100)
plot_percolation_probability_curve(results=results)