In [None]:
import warnings
warnings.filterwarnings("ignore", message="networkx backend defined more than once")
import sys
import os
import gc
import csv
import copy
import json
import math
import shutil
import random
import pickle
import itertools
import matplotlib
matplotlib.use('Agg')  # 非交互式后端
import numpy as np
import networkx as nx
import tqdm as tqdm
import matplotlib.pyplot as plt
from pathlib import Path
from protocols import MPC_protocol, MPG_protocol, SP_protocol
from graph import network, set_p_edge

from joblib import Parallel, delayed
from networkx.algorithms.community import greedy_modularity_communities
from networkx.drawing.layout import *

notebook_path = os.path.abspath(os.curdir)  # absolute path of the current working directory

from config import DATA_PATHS



In [None]:
# Plot styling cycles, padded with a repeated default so indexing by a
# protocol/series index well past the explicit entries still works.
mkr = list('x+do12') + ['x'] * 100
dashs = ['-.', '--', ':'] + ['-'] * 101
cols = ['gray', 'g', 'b', 'orange', 'r', 'k', 'purple'] + ['k'] * 100
linewidth, mks, fontsize = 2.2, 5.5, 14

# NOTE(review): this runs after the project imports in the first cell, so it
# only affects imports made later — confirm whether earlier imports need it.
sys.path.append("..")

# Input graphs and checkpoint file locations.
root_path = DATA_PATHS["input_graphs"]
LOOP_STATE_PATH = "loop_state.pkl"
TEMP_DATA_PATH = "temp_data.pkl"

Compute the entanglement rate (ER) for the MPC, MPG, and SP protocols

In [None]:
import json
import networkx as nx

def load_data(filepath):
    """
    Load a network graph from a JSON file in node-link format.

    The JSON file is expected to contain:
      - "nodes": a list of nodes, each with fields "id", "latitude", "longitude", "location", and "country"
      - "links": a list of edges, each with fields "source", "target", and "length"

    Node ids and edge endpoints are both converted with ``int()`` so that the
    nodes created from "nodes" and from "links" are the same graph nodes even
    when the file stores ids as strings.

    This function:
      - Builds a NetworkX graph with node and edge attributes
      - Stores node positions as (longitude, latitude)
      - Collects the ID of the first listed node (assumed to be the fixed or user node)
      - Prints the degree of that node for verification

    Args:
        filepath (str or Path): Path to the input JSON file (read as UTF-8).

    Returns:
        tuple:
            - G (networkx.Graph): The constructed graph.
            - user (list): A list containing the first node's ID.
            - pos (dict): Mapping from node ID to (longitude, latitude) positions.

    Raises:
        ValueError: If the file contains no nodes.
    """
    # Explicit UTF-8 so non-ASCII location/country names decode identically
    # on every platform.
    with open(filepath, "r", encoding="utf-8") as f:
        data = json.load(f)

    nodes = data["nodes"]
    if not nodes:
        raise ValueError(f"{filepath}: graph contains no nodes")

    G = nx.Graph()
    pos = {}

    for node in nodes:
        # Edge endpoints are cast with int() below; cast node ids the same way,
        # otherwise a string id "3" and an int endpoint 3 become distinct nodes.
        node_id = int(node["id"])
        lat, lon = node["latitude"], node["longitude"]
        G.add_node(node_id, location=node["location"], country=node["country"])
        pos[node_id] = (lon, lat)

    for edge in data["links"]:
        G.add_edge(int(edge["source"]), int(edge["target"]), length=edge["length"])

    first_node = int(nodes[0]["id"])
    print(f"First node ID: {first_node}, Degree: {G.degree(first_node)}")

    return G, [first_node], pos


In [None]:
def _score_partition(G, fixed_node, alpha, beta, max_rounds, shuffle_nodes):
    """Partition G into [ {fixed_node}, one community per neighbour of fixed_node ]."""
    seeds = list(G.neighbors(fixed_node))
    communities = [{fixed_node}] + [{nb} for nb in seeds]

    # node -> index of the community that currently holds it.
    membership = {fixed_node: 0}
    for idx, nb in enumerate(seeds, start=1):
        membership[nb] = idx

    # Hop distances never change (G is not modified here), so compute them once
    # per seed instead of one shortest_path_length call per query. Distances
    # are measured node -> seed, hence the reversed view for directed graphs.
    search_graph = G.reverse(copy=False) if G.is_directed() else G
    seed_dists = [nx.single_source_shortest_path_length(search_graph, nb) for nb in seeds]

    def score(node, idx):
        # Score of `node` in community `idx` (>= 1); nodes that cannot reach the
        # seed are infinitely far away from it.
        dist = seed_dists[idx - 1].get(node)
        if dist is None:
            return math.inf
        return alpha * dist + beta * len(communities[idx])

    # Initial assignment in G's node order: join the lowest-scoring seed
    # community; ties go to the lowest index.
    for node in G.nodes():
        if node in membership:
            continue
        if not seeds:
            raise ValueError(
                f"fixed_node {fixed_node!r} has no neighbours; cannot assign node {node!r}"
            )
        best_idx, best_score = 1, score(node, 1)
        for idx in range(2, len(communities)):
            candidate = score(node, idx)
            if candidate < best_score:
                best_idx, best_score = idx, candidate
        communities[best_idx].add(node)
        membership[node] = best_idx

    # Migration rounds: move a node only if another community strictly lowers
    # its score; stop early when a whole round moves nothing.
    movable = [n for n in G.nodes() if n != fixed_node]
    for _ in range(max_rounds):
        order = list(movable)  # fresh copy in G's order each round, as before
        if shuffle_nodes:
            random.shuffle(order)

        moved = 0
        for node in order:
            current = membership[node]
            if current == 0:
                continue
            best_idx, best_score = current, score(node, current)
            for idx in range(1, len(communities)):
                if idx == current:
                    continue
                candidate = score(node, idx)
                if candidate < best_score:
                    best_idx, best_score = idx, candidate
            if best_idx != current:
                communities[current].remove(node)
                communities[best_idx].add(node)
                membership[node] = best_idx
                moved += 1

        if not moved:
            break

    return communities


def _draw_communities(G, communities, fixed_node, pos=None, output_path=None):
    """Draw G coloured by community; save to `output_path` (creating its folder) or show it."""
    if pos is None:
        pos = nx.spring_layout(G, seed=42)

    colors = ["red", "blue", "green", "orange", "purple", "cyan", "yellow", "pink"]

    fig = plt.figure(figsize=(8, 6))
    try:
        for i, community in enumerate(communities):
            nx.draw_networkx_nodes(
                G, pos, nodelist=list(community),
                node_color=colors[i % len(colors)],
                label=f"Community {i}",
                alpha=0.8,
                node_size=100
            )

        # Highlight the fixed node separately
        nx.draw_networkx_nodes(
            G, pos, nodelist=[fixed_node],
            node_color="red", node_shape="o",
            node_size=100, alpha=0.9, label="Fixed Node"
        )

        nx.draw_networkx_edges(G, pos, edge_color="gray", alpha=0.5)
        nx.draw_networkx_labels(G, pos, font_size=8, font_color="black")

        plt.legend(
            fontsize=6,
            borderaxespad=0.5,
            labelspacing=0.2,
            loc="upper left",
            bbox_to_anchor=(1.05, 1),
        )
        plt.title("Graph with Colored Communities (no farthest-node highlight)")
        plt.tight_layout(pad=2.0)

        if output_path:
            # savefig does not create missing folders.
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            print(f"Graph saved to {output_path}")
        else:
            plt.show()
    finally:
        # Always release the figure, even if drawing/saving fails.
        plt.close(fig)


def multi_iterative_score_partition_with_drawing(
    G,
    fixed_node,
    alpha=1.0,
    beta=1.0,
    max_rounds=10,
    shuffle_nodes=True,
    pos=None,
    output_path=None
):
    """
    Multi-round iterative score-based community partitioning with visualization.

      - Number of communities = degree(fixed_node) + 1.
      - Initialization:
          - `fixed_node` forms community 0.
          - Each neighbour ("seed") starts its own community.
          - Every other node joins the seed community with the lowest
            ``alpha * hop_distance(node, seed) + beta * community_size``.
      - Iteration:
          - Each round visits every node except `fixed_node`, in shuffled
            order when `shuffle_nodes` is set.
          - A node moves to another seed community if that strictly lowers
            its score.
          - Stops when a round moves nothing, or after `max_rounds` rounds.
      - Draws the partition and saves it to `output_path`, creating parent
        folders as needed; shows it when `output_path` is falsy.

    A node that cannot reach a seed is treated as infinitely far from it. If it
    reaches no seed at all, it is placed in community 1.

    Args:
        G (networkx.Graph): Graph to partition (not modified).
        fixed_node: The user node; its neighbours seed the communities.
        alpha (float): Weight of the hop distance in the score.
        beta (float): Weight of the community size in the score.
        max_rounds (int): Maximum number of migration rounds.
        shuffle_nodes (bool): Visit nodes in a random order (uses `random`) each round.
        pos (dict, optional): Node positions; a seeded spring layout is used if None.
        output_path (str or Path, optional): Where to save the figure.

    Returns:
        communities (list of sets): Final partitioned communities.
        all_key_nodes_combos (list of lists): All combinations by picking one node
            from each non-zero community, with fixed_node prepended to each.
            Empty if any seed community ended up empty.

    Raises:
        ValueError: If `fixed_node` has no neighbours but G has other nodes.
    """
    communities = _score_partition(G, fixed_node, alpha, beta, max_rounds, shuffle_nodes)

    # One node from every seed community is required. With no seeds, all() is
    # vacuously true and the single combination is [fixed_node].
    seed_communities = communities[1:]
    if all(seed_communities):
        all_key_nodes_combos = [
            [fixed_node, *combo] for combo in itertools.product(*seed_communities)
        ]
    else:
        all_key_nodes_combos = []

    _draw_communities(G, communities, fixed_node, pos, output_path)

    return communities, all_key_nodes_combos


In [None]:
def plot_er_vs_p(p_range, ER, funcs, cols, output_path, fontsize=12, figsize=(10, 6), dpi=600,
                 xlim=(0.2, 1), ylim=(0.0001, 1)):
    """
    Plot the relationship between ER and link generation probability p, and save the figure to file.

    Args:
        p_range (array-like): Range of p values (link generation probabilities).
        ER (sequence of sequences): ER[i] holds the ER values of funcs[i] over p_range.
        funcs (list): List of protocol functions; their ``__name__`` is used as the legend label.
        cols (list): List of colors corresponding to each function's plot.
        output_path (str or Path): File path where the plot will be saved.
        fontsize (int, optional): Font size for axis labels and ticks. Default is 12.
        figsize (tuple, optional): Figure size in inches. Default is (10, 6).
        dpi (int, optional): Resolution (dots per inch) of the output image. Default is 600.
        xlim (tuple, optional): x-axis limits. Default is (0.2, 1).
        ylim (tuple, optional): y-axis limits (log scale). Default is (0.0001, 1).
    """
    # Legend labels from the function names. Parsing str(func)
    # ("<function NAME at 0x...>") breaks for bound methods and other callables.
    labels = [getattr(f, "__name__", str(f)) for f in funcs]

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
    try:
        ax.grid(linewidth=0.5)

        # One marker-only series per protocol.
        for i, label in enumerate(labels):
            ax.plot(
                p_range, ER[i],
                color=cols[i],
                marker="x",
                linestyle='None',
                markersize=3,
                label=label
            )

        ax.set_yscale('log')
        ax.legend(fontsize=10)
        ax.tick_params(labelsize=fontsize)

        ax.set_xlabel('Link generation probability p', fontsize=fontsize)
        # Raw string: mathtext needs literal backslashes (the old literal relied
        # on invalid escape sequences such as "\m", a SyntaxWarning on 3.12+).
        ax.set_ylabel(r'ER ($\mathregular{GHZ}_5/\ \mathregular{T_{slot}}$)', fontsize=fontsize)

        ax.set_xlim(list(xlim))
        ax.set_ylim(list(ylim))

        fig.savefig(output_path, dpi=dpi)
    finally:
        # Release the (large, high-dpi) figure even if plotting or saving fails.
        plt.close(fig)
    print(f"Plot saved to {output_path}")


In [None]:
def process_single_p(G_original, combo, p, funcs, timesteps, reps):
    """
    Evaluate every protocol's entanglement rate (ER) at one link generation probability.

    Intended as a parallel worker: the input graph is deep-copied before the
    edge probability is applied, so the caller's graph is never mutated.

    Args:
        G_original (networkx.Graph): Topology to evaluate (left untouched).
        combo (list): Key nodes passed to each protocol.
        p (float): Link generation probability applied to the copied graph.
        funcs (list): Protocol callables; each must return a 3-tuple whose first item is the ER.
        timesteps (int): Timesteps per protocol run.
        reps (int): Repetitions averaged by each protocol.

    Returns:
        list: One ER value per protocol, in the order of `funcs`.
    """
    graph = copy.deepcopy(G_original)
    set_p_edge(graph, p_op=p)

    # Unpacking (rather than indexing) keeps the check that each protocol
    # returns exactly three values.
    return [
        er
        for er, _, _ in (
            func(graph, combo, timesteps=timesteps, reps=reps) for func in funcs
        )
    ]


In [None]:
# ============== 2) Global Parameters ==============
funcs = [MPC_protocol, MPG_protocol, SP_protocol]
p_range = np.linspace(1, 0.2, 50)

timesteps = 100
reps = 200
alpha = 1.4
beta = 0.105
max_rounds = 10
shuffle_nodes = True
n_samples_per_file = 100  # key-node combinations sampled per topology file

sr_results = []  # SR results for each file will be appended here

# Chunked exit: exit after processing a certain number of combinations
chunk_size = 300
state_file = LOOP_STATE_PATH  # "loop_state.pkl"


def _save_progress(state):
    """Write the resume checkpoint dict to `state_file`."""
    with open(state_file, "wb") as pf:
        pickle.dump(state, pf)


# ============== 3) Load or Initialize Progress ==============
# The checkpoint is a local file written by this script. Never point
# `state_file` at untrusted data: unpickling can execute arbitrary code.
try:
    with open(state_file, "rb") as f:
        progress = pickle.load(f)
    print("Progress restored:", progress)
except FileNotFoundError:
    progress = {
        "subfolder_idx": 0,        # Index of current subfolder
        "file_idx": 0,             # Index of current file in subfolder
        "combo_idx": 0,            # Index of next combination (within the file's sample)
        "global_combo_count": 0    # Total number of combinations processed
    }
    print("No progress file found, starting from scratch.")

# Where the previous run stopped. These offsets apply only to the subfolder and
# file that were interrupted; every later subfolder/file starts at index 0.
resume_subfolder_idx = progress["subfolder_idx"]
resume_file_idx = progress["file_idx"]
resume_combo_idx = progress["combo_idx"]

# Sorted so that saved indices refer to the same entries on every run.
subfolders = sorted(sf for sf in Path(root_path).iterdir() if sf.is_dir())

# ============== 4) Main Loop ==============
for s_idx in range(resume_subfolder_idx, len(subfolders)):
    subfolder = subfolders[s_idx]
    print(f"Processing subfolder: {subfolder}")

    files = sorted(f for f in subfolder.iterdir() if f.is_file())
    # Previously the stale file_idx of the prior subfolder leaked in here and
    # skipped files; only the interrupted subfolder resumes mid-way.
    first_file_idx = resume_file_idx if s_idx == resume_subfolder_idx else 0

    for f_idx in range(first_file_idx, len(files)):
        file = files[f_idx]
        print(f"  Processing file: {file} ...")

        # Initialize counters for this file
        failure_counts = {func.__name__: 0 for func in funcs}
        combination_counter = 0

        # Load graph
        G, users, pos = load_data(file)
        G = network(G)

        # Output paths (folders are created with parents, since e.g.
        # "new_result" / "communitie" may not exist yet).
        class_folder = subfolder.name
        file_path = file.with_suffix(".png")
        file_name = file_path.name

        er_folder_path = Path.cwd().parent.joinpath("new_result", class_folder)
        er_topology_folder_path = er_folder_path.joinpath(file_name)
        er_topology_folder_path.mkdir(parents=True, exist_ok=True)
        communities_output_path = Path.cwd().parent.joinpath("communitie", class_folder, file_name)
        communities_output_path.parent.mkdir(parents=True, exist_ok=True)

        # Seed the randomness from the file's identity so that a restarted run
        # rebuilds the same partition and draws the same sample. Otherwise the
        # saved combo_idx would index into a different sample.
        # NOTE(review): assumes node ids hash deterministically (e.g. ints).
        file_seed = f"{class_folder}/{file.name}"
        random.seed(file_seed)

        # Run community detection and get combinations
        communities, users_node_combination = multi_iterative_score_partition_with_drawing(
            G, users[0], alpha, beta, max_rounds, shuffle_nodes, pos, communities_output_path
        )

        # Sample up to n_samples_per_file combination indices without replacement.
        # random.sample over a range avoids materialising a permutation of
        # every combination.
        n_combos = len(users_node_combination)
        sampled_combinations = random.Random(file_seed).sample(
            range(n_combos), min(n_samples_per_file, n_combos)
        )

        first_combo_idx = (
            resume_combo_idx
            if s_idx == resume_subfolder_idx and f_idx == resume_file_idx
            else 0
        )

        output_subfolder_csv_path = er_topology_folder_path.joinpath(f"{file.stem}_sr_details.csv")
        fieldnames = ["Combination_ID", "Nodes"] + [func.__name__ for func in funcs]

        # Process the sampled combinations, skipping those finished before a restart.
        for sampled_idx in range(first_combo_idx, len(sampled_combinations)):
            original_idx = sampled_combinations[sampled_idx]
            combo = users_node_combination[original_idx]

            # =========== Combination Metadata ===========
            combination_sr = {
                "Combination_ID": f"combo_{sampled_idx}",
                "Nodes": str(combo),  # Convert list of nodes to string, e.g., "[2,5,9]"
            }

            # Compute ER matrix: rows = protocols, columns = p values.
            ER = np.zeros((len(funcs), len(p_range)))
            results = Parallel(n_jobs=-1, verbose=10)(
                delayed(process_single_p)(G, combo, p, funcs, timesteps, reps)
                for p in p_range
            )
            for i, p_ers in enumerate(results):
                ER[:, i] = p_ers

            plot_er_vs_p(p_range, ER, funcs, cols, er_topology_folder_path.joinpath(f'result_for_{str(combo)}'))

            # =========== Calculate Success Ratios ===========
            # Success ratio = fraction of p values with a non-zero ER.
            for func_idx, func in enumerate(funcs):
                zero_count = np.sum(ER[func_idx, :] < 1e-10)
                success_ratio = 1 - (zero_count / len(p_range))
                combination_sr[func.__name__] = round(success_ratio, 3)
            del results, ER
            gc.collect()

            # =========== Write to CSV ===========
            # Written BEFORE checkpointing, so the checkpoint never marks a
            # combination as done when its row was not saved.
            write_header = not output_subfolder_csv_path.exists()
            with open(output_subfolder_csv_path, mode="a", newline="", encoding="utf-8") as subfile:
                csv_writer = csv.DictWriter(subfile, fieldnames=fieldnames)
                if write_header:
                    csv_writer.writeheader()
                csv_writer.writerow(combination_sr)

            # =========== Checkpoint ===========
            combination_counter += 1
            progress["global_combo_count"] += 1
            progress["subfolder_idx"] = s_idx
            progress["file_idx"] = f_idx
            progress["combo_idx"] = sampled_idx + 1  # Resume from next combination
            _save_progress(progress)

            # Chunked exit: stop every `chunk_size` combinations. sys.exit()
            # raises SystemExit immediately (IPython's exit() only schedules a
            # kernel shutdown and lets the loop keep running).
            if progress["global_combo_count"] % chunk_size == 0:
                print(f"Processed {progress['global_combo_count']} combinations, exiting for checkpoint.")
                sys.exit()

        # =========== File-Level SR Summary ===========
        # NOTE(review): failure_counts is never incremented anywhere. Every
        # protocol therefore reports 1.0 whenever at least one combination ran.
        # Define what counts as a failure and update the counts in the loop above.
        # After a resume, only the combinations processed in this run are counted.
        sr_for_protocols = {}
        if combination_counter > 0:
            for protocol_name, failures in failure_counts.items():
                sr_for_protocols[protocol_name] = (combination_counter - failures) / combination_counter
        else:
            for protocol_name in failure_counts:
                sr_for_protocols[protocol_name] = 0

        # Add metadata (subfolder and file name)
        sr_for_protocols["Subfolder"] = subfolder.name
        sr_for_protocols["File"] = file.name
        sr_results.append(sr_for_protocols)

        # File complete => reset combo_idx and advance file_idx. subfolder_idx
        # is set too, in case this file had no combinations to process.
        progress["subfolder_idx"] = s_idx
        progress["file_idx"] = f_idx + 1
        progress["combo_idx"] = 0
        _save_progress(progress)

    # Subfolder complete => the next one starts from its first file.
    progress["subfolder_idx"] = s_idx + 1
    progress["file_idx"] = 0
    progress["combo_idx"] = 0
    _save_progress(progress)
