In [None]:
import time
from typing import Callable

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import ScalarFormatter
from tqdm import tqdm

from nnx.data_structures.union_find import BetterUnionFind, NaiveUnionFind

seed = 3

rng = np.random.default_rng(seed=seed)

In [None]:
def generate_operations(num_elements: int) -> list[tuple[int, int]]:
    """Generate a sequence of random merge operations for Union Find testing.

    Args:
        num_elements: Number of elements for union find.

    Returns:
        List of (node1, node2) tuples representing merge operations.

    """
    operations = []
    for _ in range(num_elements):
        # Generate random pair of nodes to merge
        a = rng.integers(0, num_elements - 1)
        b = rng.integers(0, num_elements - 1)

        operations.append((a, b))

    return operations

In [None]:
def time_algo(
    algo: NaiveUnionFind | BetterUnionFind,
    operations: list,
) -> float:
    """Measure how long ``algo`` needs to apply every entry of ``operations``.

    Args:
        algo: union-find instance to benchmark; only ``algo.union`` is called
        operations: argument tuples, each unpacked into ``algo.union``

    Returns:
        Elapsed seconds between two ``time.perf_counter`` readings taken
        immediately before and after the loop

    """
    began_at = time.perf_counter()
    for args in operations:
        algo.union(*args)
    finished_at = time.perf_counter()

    return finished_at - began_at


# Benchmark driver: for every problem size, time both union-find variants on
# the same random operations and keep one summary timing per size.
times_naive = []
times_better = []

# Problem sizes 10, 20, ..., 2490 and number of repeated measurements per size.
sizes = range(10, int(2.5e3), 10)
samples = 50

for num_elements in tqdm(sizes):
    sampling_naive = []
    sampling_better = []
    for _ in range(samples):
        # Both algorithms receive the identical operation list, each on a
        # freshly constructed structure, so the timings are comparable.
        operations = generate_operations(num_elements=num_elements)

        sampling_naive.append(time_algo(NaiveUnionFind(num_elements), operations))
        sampling_better.append(time_algo(BetterUnionFind(num_elements), operations))

    # Take the 0.9 quantile instead of the average to reduce the effect of
    # outliers.
    times_naive.append(np.quantile(sampling_naive, 0.9))
    times_better.append(np.quantile(sampling_better, 0.9))


In [None]:
def _log_star(n: int) -> float:
    if n <= 1:
        return 0
    count = 0
    while n > 1:
        n = np.log(n)
        count += 1
    return count


def theoretical_complexity(
    n: int,
    complexity_fn: Callable[[int], float],
    scale_factor: float,
) -> float:
    """Evaluate a complexity model at ``n`` and rescale it to time units.

    Args:
        n: Input size
        complexity_fn: Growth function evaluated at ``n`` (e.g., n^2)
        scale_factor: Multiplier that maps the model value onto measured times

    Returns:
        ``complexity_fn(n)`` multiplied by ``scale_factor``

    """
    unscaled = complexity_fn(n)
    return unscaled * scale_factor


def _fit_scale(times: np.ndarray, complexity: np.ndarray) -> float:
    """Return the mean ratio time / complexity, or 1e-10 when there is no data."""
    if times.size == 0:
        return 1e-10
    return float(np.mean(times / complexity))


def plot_union_find_complexity(
    sizes: list[int],
    times_naive: list[float],
    times_better: list[float],
    *,
    output_path: str | None = "union_find_complexity.png",
) -> "plt.Figure":
    """Create a plot comparing actual times with theoretical complexity.

    Measured timings are drawn as solid lines and the fitted theoretical
    curves as dashed lines on a log-scaled y-axis. When timings are present,
    the speedup (naive / better) is drawn on a secondary y-axis and its
    maximum is annotated.

    Args:
        sizes: Problem sizes (any sequence of ints, e.g. a list or ``range``)
        times_naive: Execution times for naive algorithm, one per size
        times_better: Execution times for better algorithm, one per size
        output_path: File the figure is saved to; ``None`` skips saving

    Returns:
        The matplotlib figure that was created

    Raises:
        ValueError: If the three inputs do not have the same length

    """
    sizes_array = np.asarray(sizes)
    naive_times = np.asarray(times_naive, dtype=float)
    better_times = np.asarray(times_better, dtype=float)

    # Fail early with a clear message instead of deep inside matplotlib/numpy.
    if not (len(sizes_array) == len(naive_times) == len(better_times)):
        msg = (
            "sizes, times_naive and times_better must have the same length, "
            f"got {len(sizes_array)}, {len(naive_times)} and {len(better_times)}"
        )
        raise ValueError(msg)

    has_timings = naive_times.size > 0

    fig, ax = plt.subplots(figsize=(12, 8))

    ax.plot(
        sizes_array,
        naive_times,
        "b-",
        label="Naive Union Find (Actual)",
        linewidth=2,
    )
    ax.plot(
        sizes_array,
        better_times,
        "g-",
        label="Weighted Union Find with Path Compression (Actual)",
        linewidth=2,
    )

    def _naive_complexity_fn(n: int) -> float:
        return n**2

    def _better_complexity_fn(n: int) -> float:
        return n + n * _log_star(n)

    naive_complexity_values = np.array([_naive_complexity_fn(n) for n in sizes_array])
    better_complexity_values = np.array([_better_complexity_fn(n) for n in sizes_array])

    # Scale each model by the mean measured-time / model-value ratio so the
    # theoretical curves sit on the same axis as the measurements.
    # Naive: O(n**2); better: O(n + m log* n).
    theo_naive = naive_complexity_values * _fit_scale(
        naive_times, naive_complexity_values
    )
    theo_better = better_complexity_values * _fit_scale(
        better_times, better_complexity_values
    )

    ax.plot(
        sizes_array,
        theo_naive,
        "b--",
        label="Naive Union Find (Theoretical O(n²))",
        alpha=0.7,
    )
    ax.plot(
        sizes_array,
        theo_better,
        "g--",
        label="Weighted UF with Path Compression (Theoretical O(n + m log* n))",
        alpha=0.7,
    )

    speedup = None
    if has_timings:
        # Computed once and reused for the annotation and the secondary axis.
        speedup = naive_times / better_times
        max_speedup_idx = int(np.argmax(speedup))
        max_speedup = speedup[max_speedup_idx]
        max_speedup_size = int(sizes_array[max_speedup_idx])
        max_speedup_time = naive_times[max_speedup_idx]

        # Add annotation for maximum speedup
        ax.annotate(
            f"Max Speedup: {max_speedup:.2f}x at n={max_speedup_size}",
            xy=(max_speedup_size, max_speedup_time),
            xytext=(max_speedup_size + 300, max_speedup_time * 1.2),
            arrowprops={
                "facecolor": "black",
                "shrink": 0.05,
                "width": 1.5,
                "headwidth": 8,
            },
            fontsize=10,
        )

    # Add log scale for y-axis to better visualize differences
    ax.set_yscale("log")

    ax.grid(visible=True, linestyle="--", alpha=0.7)

    # Add labels and title
    ax.set_xlabel("Number of Elements (n)", fontsize=12)
    ax.set_ylabel("Time (seconds)", fontsize=12)
    ax.set_title("Union Find Algorithms: Time Complexity Comparison", fontsize=14)

    ax.legend(fontsize=10)

    # Format y-axis to show actual time values
    ax.yaxis.set_major_formatter(ScalarFormatter())

    # Add a secondary y-axis for showing the speedup
    if speedup is not None:
        ax2 = ax.twinx()
        ax2.plot(sizes_array, speedup, "r-", label="Speedup Factor", alpha=0.6)
        ax2.set_ylabel("Speedup Factor (Naive / Better)", color="r", fontsize=12)
        ax2.tick_params(axis="y", labelcolor="r")
        ax2.legend(loc="lower right", fontsize=10)

    plt.tight_layout()
    if output_path is not None:
        plt.savefig(output_path, dpi=300)
    plt.show()

    return fig


fig = plot_union_find_complexity(sizes, times_naive, times_better)