# In [None]:
import random
import time
import warnings
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, TimeoutError

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from tqdm import tqdm  # For progress bars

# Meant to quiet networkx chatter, but note the scope: this ignores EVERY
# UserWarning raised by any module for the rest of the process.
warnings.simplefilter('ignore', UserWarning)

def generate_topology(topology, n_nodes=10000, n_edges=50000, fill_edges=True):
    """Build an undirected graph of the requested topology with integer node labels.

    Args:
        topology: 'RN' (G(n, m) random graph), 'BAN' (Barabasi-Albert),
            'SLN' (square lattice), 'TLN' (triangular lattice) or
            'HLN' (hexagonal lattice).
        n_nodes: target node count. Exact for 'RN' and 'BAN'. The lattices
            use int(sqrt(n_nodes)) for both lattice dimensions, so their node
            count only approximates n_nodes.
        n_edges: target edge count. With ``fill_edges`` the result has AT
            LEAST this many edges (a lattice or BA graph may already exceed it).
        fill_edges: when True (the original behaviour), random edges between
            uniformly chosen distinct nodes are added until the graph has
            ``n_edges`` edges. For the lattices these are long-range shortcuts
            that change the topology's distance structure.

    Returns:
        networkx.Graph

    Raises:
        ValueError: unknown ``topology``, or ``n_edges`` exceeds the number of
            possible edges among the generated nodes. The fill loop would
            otherwise never terminate.
    """
    if topology == 'RN':
        G = nx.gnm_random_graph(n_nodes, n_edges)
    elif topology == 'BAN':
        m = max(1, n_edges // n_nodes)  # edges per new node; BA requires m >= 1
        G = nx.barabasi_albert_graph(n_nodes, m)
    elif topology == 'SLN':
        side = int(np.sqrt(n_nodes))
        G = nx.convert_node_labels_to_integers(nx.grid_2d_graph(side, side))
    elif topology == 'TLN':
        side = int(np.sqrt(n_nodes))
        G = nx.convert_node_labels_to_integers(nx.triangular_lattice_graph(side, side))
    elif topology == 'HLN':
        side = int(np.sqrt(n_nodes))
        G = nx.convert_node_labels_to_integers(nx.hexagonal_lattice_graph(side, side))
    else:
        raise ValueError(f"Unsupported topology: {topology}")

    if fill_edges and G.number_of_edges() < n_edges:
        # Materialise the node list once: a NodeView is not a positional
        # sequence, and rebuilding/permuting it per iteration is O(n) each time.
        nodes = list(G.nodes())
        n = len(nodes)
        max_edges = n * (n - 1) // 2
        if n_edges > max_edges:
            raise ValueError(
                f"Cannot reach {n_edges} edges with {n} nodes "
                f"(at most {max_edges} possible)")
        while G.number_of_edges() < n_edges:
            i, j = np.random.randint(n, size=2)
            # i != j reproduces choice(..., replace=False): distinct endpoints.
            if i != j and not G.has_edge(nodes[i], nodes[j]):
                G.add_edge(nodes[i], nodes[j])
    return G

def purification(C1, C2):
    """Return the concurrence after one Deutsch purification round of two pairs.

    Implements (13 + 14*C1 + 14*C2 + 40*C1*C2) / (41 + 4*C1 + 4*C2 + 32*C1*C2).
    C1 = C2 = 1 maps to 1.
    """
    numerator = 13 + 14*C1 + 14*C2 + 40*C1*C2
    denominator = 41 + 4*C1 + 4*C2 + 32*C1*C2
    return numerator / denominator

def get_alternate_paths(G, u, v, main_path, max_path_length=5):
    """Return hop counts of the u-v paths available to purify the link (u, v).

    Candidates are simple paths of at most ``max_path_length`` hops that avoid
    every node of ``main_path`` other than u and v themselves. The direct edge
    (u, v), when present in G, is included exactly once as a length-1 path.

    Args:
        G: the network graph.
        u, v: endpoints of one hop of ``main_path``.
        main_path: node sequence of the S-D path that the hop belongs to.
        max_path_length: cutoff passed to ``nx.all_simple_paths``.

    Returns:
        list of int hop counts, in enumeration order (callers sort it).
    """
    forbidden = set(main_path) - {u, v}
    # A read-only view hides the forbidden nodes without copying G. Copying
    # the whole graph per hop per pair dominated the simulation's runtime.
    allowed = nx.restricted_view(G, forbidden, [])

    # all_simple_paths yields nothing (it does not raise) when no path exists.
    paths = [len(path) - 1
             for path in nx.all_simple_paths(allowed, u, v, cutoff=max_path_length)]

    # u and v stay in the view, so the direct edge is normally enumerated
    # above already. Add it only if the cutoff excluded it; appending it
    # unconditionally counted the same physical link twice.
    if G.has_edge(u, v) and 1 not in paths:
        paths.append(1)
    return paths

def simulate_strategy(G, pairs, delta, strategy='SPF'):
    """Average end-to-end concurrence over S-D pairs for one purification order.

    For each (s, d, path) triple, every hop (u, v) of ``path`` is purified with
    the paths returned by ``get_alternate_paths``. A path of l hops contributes
    an initial concurrence of max(0, 1 - l*delta), and these are folded with
    ``purification`` in strategy order. The per-hop results are multiplied
    to give the pair's concurrence.

    Args:
        G: the network graph.
        pairs: iterable of (source, destination, path) with ``path`` a node list.
        delta: concurrence lost per hop.
        strategy: 'SPF' purifies shortest paths first, 'SPL' longest first.

    Returns:
        Mean concurrence over ``pairs``, or 0.0 when ``pairs`` is empty.

    Raises:
        ValueError: if ``strategy`` is not 'SPF' or 'SPL'. Previously any other
            value silently ran SPL.
    """
    if strategy not in ('SPF', 'SPL'):
        raise ValueError(f"Unsupported strategy: {strategy}")
    longest_first = strategy == 'SPL'

    concurrences = []
    for s, d, path in tqdm(pairs, desc=f"Processing {strategy}"):
        C_total = 1.0

        for u, v in zip(path, path[1:]):
            path_lengths = get_alternate_paths(G, u, v, path)

            if not path_lengths:
                # Defensive: a hop with no usable path breaks the whole chain.
                C_total = 0
                break

            path_lengths = sorted(path_lengths, reverse=longest_first)

            # Fold the per-path concurrences with repeated purification.
            Cs = [max(0, 1 - l * delta) for l in path_lengths]
            current_C = Cs[0]
            for c in Cs[1:]:
                current_C = purification(current_C, c)

            C_total *= current_C

        concurrences.append(C_total)

    return np.mean(concurrences) if concurrences else 0.0

def _sample_pairs_by_length(G, l0_values, num_samples):
    """Uniformly sample up to ``num_samples`` ordered (s, d) pairs per distance l0.

    Shortest-path lengths are streamed one source at a time (BFS truncated at
    the largest requested l0), and each l0 keeps its own reservoir
    (Algorithm R). Memory is therefore O(len(l0_values) * num_samples) instead
    of the O(n_nodes**2) needed to hold all-pairs distances.

    Returns:
        dict mapping each requested l0 to a list of (s, d) pairs with s != d
        whose shortest-path distance is exactly l0.
    """
    wanted = set(l0_values)
    reservoirs = {l0: [] for l0 in wanted}
    if not wanted or num_samples <= 0:
        return reservoirs

    seen = dict.fromkeys(wanted, 0)
    for s, lengths in nx.all_pairs_shortest_path_length(G, cutoff=max(wanted)):
        for d, l0 in lengths.items():
            if d == s or l0 not in wanted:
                continue
            seen[l0] += 1
            bucket = reservoirs[l0]
            if len(bucket) < num_samples:
                bucket.append((s, d))
            else:
                # Keep the new pair with probability num_samples / seen[l0].
                slot = random.randrange(seen[l0])
                if slot < num_samples:
                    bucket[slot] = (s, d)
    return reservoirs


def _run_strategy_with_timeout(G, pairs, delta, strategy, timeout, chunk_size=1000):
    """Run ``simulate_strategy`` over ``pairs`` in chunks, stopping after ``timeout`` s.

    Python threads cannot be cancelled, so the deadline is enforced
    cooperatively. It is checked before each chunk, and a chunk that is
    already running is allowed to finish.

    Returns:
        Mean concurrence over all pairs, or None if the deadline passed
        before every chunk was processed.
    """
    if not pairs:
        return 0.0
    deadline = time.monotonic() + timeout
    weighted_sum = 0.0
    for start in range(0, len(pairs), chunk_size):
        if time.monotonic() > deadline:
            return None
        chunk = pairs[start:start + chunk_size]
        # Weight each chunk's mean by its size so the result is the overall mean.
        weighted_sum += simulate_strategy(G, chunk, delta, strategy) * len(chunk)
    return weighted_sum / len(pairs)


def main_simulation(topologies=('RN', 'BAN', 'SLN', 'TLN', 'HLN'),
                   delta=0.02,
                   l0_values=range(1, 7),
                   num_samples=10000,
                   timeout=600):
    """Run SPF and SPL on every topology and shortest-path length, then plot.

    For each topology, up to ``num_samples`` S-D pairs are drawn uniformly
    for each shortest-path length l0 in ``l0_values``. Each pair is routed
    along one shortest path, and both strategies are evaluated on the same
    pairs.

    Args:
        topologies: topology codes accepted by ``generate_topology``.
        delta: per-hop concurrence loss passed to ``simulate_strategy``.
        l0_values: shortest-path lengths to evaluate (any iterable).
        num_samples: maximum S-D pairs per (topology, l0).
        timeout: seconds allowed per strategy per l0. The limit is enforced
            between chunks of pairs; a timed-out run is recorded as 0.0.

    Returns:
        dict: results[topology][strategy][l0] -> mean concurrence.
    """
    l0_values = list(l0_values)  # traversed several times below and in plotting
    results = {topo: {'SPF': {}, 'SPL': {}} for topo in topologies}

    for topo in topologies:
        print(f"\n=== Processing {topo} topology ===")
        G = generate_topology(topo)
        print(f"Generated {topo} with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

        print("Precomputing path relationships...")
        l0_to_pairs = _sample_pairs_by_length(G, l0_values, num_samples)

        for l0 in l0_values:
            print(f"\nProcessing l0 = {l0}")
            selected_pairs = l0_to_pairs.get(l0, [])

            if not selected_pairs:
                print(f"No S-D pairs found for l0={l0}")
                results[topo]['SPF'][l0] = 0.0
                results[topo]['SPL'][l0] = 0.0
                continue

            print(f"Found {len(selected_pairs)} candidate pairs, converting to paths...")
            # Every sampled pair was reached by BFS, so a shortest path exists.
            valid_pairs = [(s, d, nx.shortest_path(G, s, d))
                           for s, d in tqdm(selected_pairs, desc="Path conversion")]
            print(f"Processing {len(valid_pairs)} valid pairs")

            start_time = time.time()
            for strategy in ('SPF', 'SPL'):
                value = _run_strategy_with_timeout(G, valid_pairs, delta, strategy, timeout)
                if value is None:
                    print(f"{strategy} strategy timed out for l0={l0}")
                    value = 0.0
                results[topo][strategy][l0] = value
            print(f"Completed in {time.time()-start_time:.2f}s")

    plot_results(results, l0_values)
    return results

def plot_results(results, l0_values):
    """Draw mean concurrence versus l0 for every topology and both strategies.

    Args:
        results: mapping topology -> {'SPF': {l0: value}, 'SPL': {l0: value}}.
            An l0 missing from a strategy's dict is drawn as 0.
        l0_values: x-axis values.
    """
    plt.figure(figsize=(12, 6))

    # SPF as solid lines with circles, SPL as dashed lines with crosses.
    line_formats = (('SPF', 'o-'), ('SPL', 'x--'))
    for topo, by_strategy in results.items():
        for strategy, fmt in line_formats:
            ys = [by_strategy[strategy].get(l0, 0) for l0 in l0_values]
            plt.plot(l0_values, ys, fmt, label=f'{topo} {strategy}')

    plt.xlabel('Shortest Path Length (l0)')
    plt.ylabel('Average Concurrence')
    plt.title('Localized Multipath Entanglement Purification')
    plt.legend()
    plt.grid(True)
    plt.show()

# Run the full simulation when this file is executed as a script (not on import).
if __name__ == "__main__":
    main_simulation(
        topologies=['RN', 'BAN', 'SLN', 'TLN', 'HLN'],  # all five supported topologies
        l0_values=range(1, 7),  # shortest-path lengths l0 = 1..6
        num_samples=10000,  # max sampled S-D pairs per (topology, l0)
    )