In [None]:
import cupy as cp
import numpy as np


# -------------------------------
# Mock Config Parameters
# -------------------------------
cost_e = 10.0  # upper bound on the per-edge cost cap used by both functions
c_const = 0.5  # cap is also limited to max(edge prize) * (1 - c_const / 2)

# -------------------------------
# Sample Data
# -------------------------------
# Effective cost cap = min(10.0, 12.0 * (1 - 0.5 / 2)) = 9.0.
# The edge prizes below exercise every branch:
#   * 3.0  -> kept edge (prize below the cap)
#   * 9.0  -> kept edge (prize exactly on the `<=` boundary, cost 0)
#   * 12.0, 10.0 -> prize exceeds the cap, each becomes a virtual node.
# Having MORE THAN ONE virtual node is what lets the comparison detect
# ordering differences among the generated virtual edges.
edge_index = cp.array([
    [0, 1, 2, 3],
    [1, 2, 3, 4],
])  # shape: [2, num_edges]
num_nodes = 5
prizes = {
    "nodes": cp.array([1.0, 2.0, 3.0, 4.0, 5.0]),
    "edges": cp.array([3.0, 12.0, 9.0, 10.0]),
}


# -------------------------------
# Original Function
# -------------------------------
def compute_subgraph_costs_original(edge_index, num_nodes, prizes):
    """Reference (loop-based) construction of edge costs and virtual nodes.

    The per-edge cost cap is
    ``min(cost_e, prizes["edges"].max() * (1 - c_const / 2))``.
    An edge whose prize is ``<=`` the cap is kept with cost ``cap - prize``.
    Every other edge is replaced by a new virtual node (ids assigned
    consecutively starting at ``num_nodes``), joined to the edge's two
    endpoints by two zero-cost edges; the virtual node's prize is
    ``prize - cap``.

    Returns:
        edges_dict: ``{"edges": kept edges followed by virtual edges,
            "num_prior_edges": number of kept edges}``.
        prizes_final: node prizes followed by the virtual-node prizes.
        costs_final: costs aligned row-for-row with ``edges_dict["edges"]``.
        mapping: ``{"nodes": virtual node id -> original edge index,
            "edges": kept-edge position -> original edge index}``.
    """
    edge_prizes = prizes["edges"]
    cap = min(cost_e, edge_prizes.max().item() * (1 - c_const / 2))

    kept_edges = []
    kept_costs = []
    virtual = {"n_prizes": [], "edges": [], "costs": []}
    mapping = {"nodes": {}, "edges": {}}

    for edge_id, (head, tail) in enumerate(edge_index.T):
        head, tail = int(head), int(tail)
        weight = edge_prizes[edge_id].item()

        if weight <= cap:
            # Cheap enough: keep the edge and charge what the prize doesn't cover.
            mapping["edges"][len(kept_edges)] = edge_id
            kept_edges.append((head, tail))
            kept_costs.append(cap - weight)
            continue

        # Prize exceeds the cap: route the edge through a fresh virtual node
        # that carries the surplus prize, attached by two free edges.
        hub = num_nodes + len(virtual["n_prizes"])
        mapping["nodes"][hub] = edge_id
        virtual["edges"] += [(head, hub), (hub, tail)]
        virtual["costs"] += [0, 0]
        virtual["n_prizes"].append(weight - cap)

    prizes_final = cp.concatenate([prizes["nodes"], cp.array(virtual["n_prizes"])])
    edges_dict = {
        "edges": cp.array(kept_edges + virtual["edges"]),
        "num_prior_edges": len(kept_edges),
    }
    costs_final = cp.array(kept_costs + virtual["costs"])
    return edges_dict, prizes_final, costs_final, mapping


# -------------------------------
# Optimized Function
# -------------------------------
def compute_subgraph_costs_optimized(edge_index, num_nodes, prizes):
    """Vectorized equivalent of ``compute_subgraph_costs_original``.

    The per-edge cost cap is
    ``min(cost_e, prizes["edges"].max() * (1 - c_const / 2))``.
    Edges with prize ``<=`` cap are kept with cost ``cap - prize``. Every
    other edge becomes a virtual node (ids ``num_nodes, num_nodes + 1, ...``
    in original edge order) connected to the edge's endpoints by two
    zero-cost edges ``(src, v)`` and ``(v, dst)``; the virtual node's prize
    is ``prize - cap``.

    Output layout matches the loop-based reference exactly: kept edges first
    (in original order), then the virtual edge pairs interleaved per virtual
    node.

    Returns:
        edges_dict: ``{"edges": (E', 2) int64 array,
            "num_prior_edges": number of kept edges}``.
        all_prizes: node prizes followed by the virtual-node prizes.
        all_costs: float64 costs aligned row-for-row with ``edges_dict["edges"]``.
        mapping: ``{"edges": kept-edge position -> original edge index,
            "nodes": virtual node id -> original edge index}``.
    """
    # The reference compares/subtracts Python floats (``.item()``) and builds
    # edges from Python ints, so its results are float64 / int64. Cast here so
    # float32 / int32 inputs round and compare the same way.
    edge_prizes = prizes["edges"].astype(cp.float64, copy=False)
    src, dst = edge_index.astype(cp.int64, copy=False)

    cap = min(cost_e, edge_prizes.max().item() * (1 - c_const / 2))

    # ``<=`` (not ``>``) for the kept set so NaN prizes go to the virtual
    # branch, exactly as in the loop version.
    real_mask = edge_prizes <= cap
    real_indices = cp.nonzero(real_mask)[0]
    virtual_indices = cp.nonzero(~real_mask)[0]
    n_real = int(real_indices.size)
    n_virtual = int(virtual_indices.size)

    # Kept edges and their residual costs.
    real_edges = cp.stack([src[real_indices], dst[real_indices]], axis=1)
    real_costs = cap - edge_prizes[real_indices]

    # Virtual nodes: one per over-cap edge, numbered after the real nodes.
    virtual_src = src[virtual_indices]
    virtual_dst = dst[virtual_indices]
    virtual_node_ids = cp.arange(num_nodes, num_nodes + n_virtual, dtype=cp.int64)

    # Stack the (src, v) and (v, dst) rows on axis=1 -> shape (n, 2, 2), so the
    # flattened order is src_0->v_0, v_0->dst_0, src_1->v_1, ... (interleaved
    # per virtual node, like the loop). Stacking on axis=0 would instead emit
    # all (src, v) rows before all (v, dst) rows.
    virtual_edges = cp.stack(
        [
            cp.stack([virtual_src, virtual_node_ids], axis=1),
            cp.stack([virtual_node_ids, virtual_dst], axis=1),
        ],
        axis=1,
    ).reshape(-1, 2)
    virtual_costs = cp.zeros(2 * n_virtual, dtype=cp.float64)
    virtual_n_prizes = edge_prizes[virtual_indices] - cap

    all_edges = cp.concatenate([real_edges, virtual_edges], axis=0)
    all_costs = cp.concatenate([real_costs, virtual_costs])
    all_prizes = cp.concatenate([prizes["nodes"], virtual_n_prizes])

    # One device->host transfer per index array instead of one sync per element.
    real_ids = cp.asnumpy(real_indices).tolist()
    virtual_edge_ids = cp.asnumpy(virtual_indices).tolist()
    first_virtual = int(num_nodes)
    mapping = {
        "edges": dict(enumerate(real_ids)),
        "nodes": {first_virtual + k: e for k, e in enumerate(virtual_edge_ids)},
    }

    edges_dict = {
        "edges": all_edges,
        "num_prior_edges": n_real,
    }

    return edges_dict, all_prizes, all_costs, mapping


# -------------------------------
# Run Comparison
# -------------------------------
def _report(label, ok):
    """Print one comparison line with a pass/fail marker; return the bool."""
    ok = bool(ok)
    print(f"{'✅' if ok else '❌'} {label}:", ok)
    return ok


def _allclose_same_shape(a, b):
    """``cp.allclose`` that reports False (instead of raising) on shape mismatch."""
    return a.shape == b.shape and bool(cp.allclose(a, b))


orig_result = compute_subgraph_costs_original(edge_index, num_nodes, prizes)
opt_result = compute_subgraph_costs_optimized(edge_index, num_nodes, prizes)

orig_edges, orig_prizes, orig_costs, orig_mapping = orig_result
opt_edges, opt_prizes, opt_costs, opt_mapping = opt_result

# Compare outputs. cp.array_equal also checks shapes, so a different number of
# edges is reported as a mismatch instead of broadcasting or raising.
_report("Edges identical", cp.array_equal(orig_edges["edges"], opt_edges["edges"]))
_report("Num prior edges", orig_edges["num_prior_edges"] == opt_edges["num_prior_edges"])
_report("Prizes equal", _allclose_same_shape(orig_prizes, opt_prizes))
_report("Costs equal", _allclose_same_shape(orig_costs, opt_costs))
_report("Edge mappings equal", orig_mapping["edges"] == opt_mapping["edges"])
_report("Node mappings equal", orig_mapping["nodes"] == opt_mapping["nodes"])
