In [None]:
import torch
import torch_geometric
from torch_geometric.data import Data
import networkx as nx
import matplotlib.pyplot as plt
import pymetis
import numpy as np
from torch_sparse import SparseTensor
from poisson_disc import Bridson_sampling  # Use Bridson_sampling function
from sklearn.neighbors import NearestNeighbors
from sklearn.cluster import KMeans  # 添加 KMeans 相关导入

import metis

import networkx as nx

def sort_subgraphs(pos, node_mask):
    """Reorder subgraphs by the distance of their centroid from the origin.

    Non-empty subgraphs are moved to the front and ordered by increasing
    Euclidean norm of the mean position of their nodes (i.e. distance from the
    coordinate origin, not from the centre of the graph). Empty subgraphs
    become all-False rows at the end.

    Parameters
    ----------
    pos : torch.Tensor [num_nodes, 2]
        The node positions.
    node_mask : torch.BoolTensor [n_patches, num_nodes]
        The node mask for each subgraph.

    Returns
    -------
    node_mask : torch.BoolTensor [n_patches, num_nodes]
        The reordered node masks (a new tensor; the input is not modified).
    """
    non_empty = torch.sum(node_mask, dim=1) > 0
    compressed_mask = node_mask[non_empty]
    compressed_len = compressed_mask.size(0)
    sorted_node_mask = node_mask.new_full(node_mask.size(), False)
    if compressed_len == 0:
        # No subgraph has any node: nothing to sort (torch.stack would fail on []).
        return sorted_node_mask

    # Centroid of every non-empty subgraph and its distance from the origin.
    center = torch.stack([pos[row].mean(dim=0) for row in compressed_mask])
    radius = torch.norm(center, dim=1)
    _, indices = torch.sort(radius)
    sorted_node_mask[:compressed_len, :] = compressed_mask[indices]
    return sorted_node_mask


# Dense k-hop reachability mask for every node of the graph.
def k_hop_subgraph(edge_index, num_nodes, num_hops, is_directed=False):
    """Return the k-hop neighbourhood mask of every node in the graph.

    Parameters
    ----------
    edge_index : torch.LongTensor [2, num_edges]
        Edge list; edge ``e`` goes from ``edge_index[0, e]`` to
        ``edge_index[1, e]``.
    num_nodes : int
        Number of nodes in the graph.
    num_hops : int
        Maximum walk length to include (0 keeps only each node itself).
    is_directed : bool, optional
        If True, every edge is also added in the reverse direction before
        walking, so neighbourhoods ignore edge direction. If False,
        ``edge_index`` is used as given (an undirected graph is expected to
        already contain both directions).

    Returns
    -------
    node_mask : torch.BoolTensor [num_nodes, num_nodes]
        ``node_mask[i, j]`` is True when node ``j`` reaches node ``i`` by a
        walk of at most ``num_hops`` edges; the diagonal is always True.
    """
    row, col = edge_index
    if is_directed:
        # Symmetrise the edges. These symmetrised indices must be the ones the
        # adjacency is built from, otherwise the flag has no effect.
        row, col = torch.cat([row, col]), torch.cat([col, row])

    device = edge_index.device
    adj = torch.sparse_coo_tensor(
        torch.stack([row, col]),
        torch.ones(row.numel(), device=device),
        (num_nodes, num_nodes),
    )

    # frontier[r, c]: a walk r -> ... -> c of exactly `hop` edges exists.
    # reach[r, c]:    such a walk of at most `hop` edges exists.
    frontier = torch.eye(num_nodes, dtype=torch.bool, device=device)
    reach = frontier.clone()
    for _ in range(num_hops):
        # Extend every walk by one edge at its start: (A @ F)[r, c] > 0.
        frontier = torch.sparse.mm(adj, frontier.float()) > 0
        reach |= frontier

    # Transpose so that row i lists the nodes that reach node i.
    return reach.t()

# KMeans-based subgraph partitioning: clusters nodes by their coordinates.
def kmeans_subgraph(g, n_patches, drop_rate=0.0, num_hops=1, is_directed=False):
    """Partition the graph into subgraphs by KMeans clustering of node coordinates.

    Parameters
    ----------
    g : torch_geometric.data.Data
        Input graph. Node coordinates are read from ``g.pos`` when it is set,
        otherwise from ``g.x``.
    n_patches : int
        Number of subgraphs (KMeans clusters).
    drop_rate : float, optional
        Unused; kept so the signature matches the METIS-based partitioners.
    num_hops : int, optional
        Each subgraph is expanded with the ``num_hops``-hop neighbourhood of
        its nodes (0 disables the expansion), by default 1.
    is_directed : bool, optional
        Forwarded to ``k_hop_subgraph``, by default False.

    Returns
    -------
    node_mask : torch.BoolTensor [n_patches, num_nodes]
        Node mask of each subgraph.
    edge_mask : torch.BoolTensor [n_patches, num_edges]
        Edge mask of each subgraph (an edge is kept when both endpoints are in it).
    """
    num_nodes = g.num_nodes

    if num_nodes < n_patches:
        # Fewer nodes than clusters: give every node a distinct random patch id.
        membership = torch.randperm(n_patches)[:num_nodes]
    else:
        coords = getattr(g, 'pos', None)
        if coords is None:
            coords = g.x
        coords = coords.detach().cpu().numpy()

        if coords.ndim == 2 and coords.shape[1] == 2:
            # Deterministic initial centers on a regular grid over the bounding box.
            (x_min, y_min), (x_max, y_max) = coords.min(axis=0), coords.max(axis=0)
            grid_size = int(np.ceil(np.sqrt(n_patches)))
            xv, yv = np.meshgrid(np.linspace(x_min, x_max, grid_size),
                                 np.linspace(y_min, y_max, grid_size))
            grid = np.column_stack([xv.ravel(), yv.ravel()])
            # The grid has grid_size**2 >= n_patches points. Pick n_patches of
            # them spread evenly over the whole grid; taking the first
            # n_patches would leave the largest-y rows without any initial
            # center whenever n_patches is not a perfect square.
            pick = np.round(np.linspace(0, len(grid) - 1, n_patches)).astype(int)
            init = grid[pick]
        else:
            # The grid initialisation is only defined for 2-D coordinates.
            init = 'k-means++'

        kmeans = KMeans(n_clusters=n_patches, init=init, n_init=1, random_state=42)
        membership = torch.as_tensor(kmeans.fit_predict(coords), dtype=torch.long)

    # One boolean row per cluster.
    node_mask = torch.stack([membership == i for i in range(n_patches)])  # [n_patches, num_nodes]

    if num_hops > 0:
        # (patch id, node id) for every node currently in a patch.
        subgraphs_batch, subgraphs_node_mapper = node_mask.nonzero(as_tuple=True)
        k_hop_node_mask = k_hop_subgraph(
            g.edge_index, num_nodes, num_hops, is_directed)
        # Add each member node's k-hop neighbourhood to its patch row.
        node_mask.index_add_(0, subgraphs_batch, k_hop_node_mask[subgraphs_node_mapper])
        node_mask = node_mask.bool()

    edge_mask = node_mask[:, g.edge_index[0]] & node_mask[:, g.edge_index[1]]
    return node_mask, edge_mask


# Recursive METIS bisection used by recursive_metis_subgraph.
def recursive_bisection(adjacency, num_partitions, current_partition, membership, start_id, node_ids=None):
    """
    Recursively bisect a graph with pymetis until ``num_partitions`` parts exist.

    Parameters:
        adjacency (list[list[int]]): Adjacency list of the current (sub)graph,
            indexed by local node id.
        num_partitions (int): Number of partitions to split this (sub)graph into.
        current_partition (int): First partition id owned by this (sub)graph; it
            uses ids ``current_partition .. current_partition + num_partitions - 1``.
        membership (list[int]): Partition id per ORIGINAL node; updated in place.
        start_id (int): Offset of this subgraph's first node in ``membership``.
            Only used to build ``node_ids`` when it is not given.
        node_ids (list[int], optional): Original node id of every local node.
            Defaults to ``range(start_id, start_id + len(adjacency))``, which is
            correct for the top-level call (``start_id=0``).

    Returns:
        None: Updates the membership list in-place.

    Note:
        pymetis is asked for 2 parts without target weights, so presumably the
        halves have similar sizes even when ``num_partitions`` is odd, which
        leaves the final partitions unbalanced in that case — TODO confirm.
    """
    if node_ids is None:
        node_ids = list(range(start_id, start_id + len(adjacency)))

    if not adjacency:
        # Empty subgraph (METIS put every node on the other side): nothing to assign.
        return
    if num_partitions <= 1 or len(adjacency) == 1:
        for node in node_ids:
            membership[node] = current_partition
        return

    # Bisect the current graph with pymetis.
    _, part = pymetis.part_graph(2, adjacency=adjacency)
    left = [v for v in range(len(adjacency)) if part[v] == 0]
    right = [v for v in range(len(adjacency)) if part[v] != 0]

    left_parts = num_partitions // 2
    left_adj, left_ids = _induced_subgraph(adjacency, node_ids, left)
    right_adj, right_ids = _induced_subgraph(adjacency, node_ids, right)

    # Recursively partition both halves, passing the original ids along so the
    # leaves write to the correct membership slots.
    recursive_bisection(left_adj, left_parts, current_partition, membership,
                        start_id, node_ids=left_ids)
    recursive_bisection(right_adj, num_partitions - left_parts,
                        current_partition + left_parts, membership,
                        start_id + len(left_adj), node_ids=right_ids)


def _induced_subgraph(adjacency, node_ids, nodes):
    """Return the adjacency induced on local ``nodes`` (relabelled 0..len-1) and their original ids."""
    relabel = {old: new for new, old in enumerate(nodes)}
    sub_adjacency = [[relabel[nb] for nb in adjacency[old] if nb in relabel] for old in nodes]
    return sub_adjacency, [node_ids[old] for old in nodes]


# Partition a graph with recursive METIS bisection (see recursive_bisection).
def recursive_metis_subgraph(g, n_patches, drop_rate=0.0, num_hops=1, is_directed=False):
    """
    Partition the graph into subgraphs using recursive bisection with pymetis.

    Parameters:
        g (torch_geometric.data.Data): Input graph.
        n_patches (int): Number of desired subgraphs.
        drop_rate (float): Proportion of edges to randomly drop for data
            augmentation. Only affects the partitioning; the returned edge mask
            is computed on all edges of ``g``.
        num_hops (int): Number of hops of neighbors added to each subgraph
            (0 keeps the partition disjoint).
        is_directed (bool): Whether the graph is directed. The partitioning
            always treats the graph as undirected; the flag is forwarded to
            ``k_hop_subgraph``.

    Returns:
        node_mask (torch.Tensor): Boolean mask [n_patches, num_nodes] for each subgraph's nodes.
        edge_mask (torch.Tensor): Boolean mask [n_patches, num_edges] for each subgraph's edges.
    """
    num_nodes = g.num_nodes

    if num_nodes < n_patches:
        # Fewer nodes than patches: give every node a distinct random patch id.
        membership = torch.randperm(n_patches)[:num_nodes]
    else:
        # Undirected NetworkX graph keeping a random subset of the edges.
        edges = g.edge_index.t().cpu().numpy()
        keep = np.random.rand(len(edges)) > drop_rate
        G = nx.Graph()
        G.add_nodes_from(np.arange(num_nodes))
        G.add_edges_from(edges[keep].tolist())
        adjacency = [list(G.adj[i]) for i in range(num_nodes)]

        membership = [-1] * num_nodes  # -1 marks "not assigned yet"
        recursive_bisection(adjacency, n_patches, current_partition=0, membership=membership, start_id=0)

    # as_tensor avoids torch.tensor's copy-construct warning when membership is already a tensor.
    membership = torch.as_tensor(membership, dtype=torch.long)
    assert membership.numel() == num_nodes
    assert bool((membership >= 0).all()), "some nodes were not assigned to a patch"

    node_mask = torch.stack([membership == i for i in range(n_patches)])  # [n_patches, num_nodes]

    if num_hops > 0:
        # (patch id, node id) for every node currently in a patch.
        subgraphs_batch, subgraphs_node_mapper = node_mask.nonzero(as_tuple=True)
        k_hop_node_mask = k_hop_subgraph(
            g.edge_index, num_nodes, num_hops, is_directed
        )
        # Add each member node's k-hop neighbourhood to its patch row.
        node_mask.index_add_(0, subgraphs_batch, k_hop_node_mask[subgraphs_node_mapper])
        node_mask = node_mask.bool()

    # An edge belongs to a subgraph when both endpoints do.
    edge_mask = node_mask[:, g.edge_index[0]] & node_mask[:, g.edge_index[1]]
    return node_mask, edge_mask


# Partition a graph with a single k-way pymetis call.
def pymetis_subgraph(g, n_patches, drop_rate=0.0, num_hops=1, is_directed=False):
    """Partition the graph into ``n_patches`` subgraphs with pymetis.

    Parameters
    ----------
    g : torch_geometric.data.Data
        Input graph. If ``g.pos`` is set, subgraphs are reordered with
        ``sort_subgraphs``.
    n_patches : int
        Number of subgraphs.
    drop_rate : float, optional
        Fraction of edges randomly dropped before partitioning (undirected
        graphs only), by default 0.0.
    num_hops : int, optional
        Hops of neighbours added to each subgraph, by default 1.
    is_directed : bool, optional
        Whether the graph is directed, by default False.

    Returns
    -------
    node_mask : torch.BoolTensor [n_patches, num_nodes]
    edge_mask : torch.BoolTensor [n_patches, num_edges]
    """
    num_nodes = g.num_nodes

    if num_nodes < n_patches:
        # Too few nodes for METIS: one patch per node (deterministic for
        # directed graphs, random patch ids for undirected ones).
        if is_directed:
            membership = torch.arange(num_nodes)
        else:
            membership = torch.randperm(n_patches)[:num_nodes]
    else:
        if is_directed:
            # to_undirected=True merges both directions; "lower" would keep only
            # the lower triangle of the adjacency and drop every i -> j, i < j.
            G = torch_geometric.utils.to_networkx(g, to_undirected=True)
        else:
            # Data augmentation: randomly remove some edges.
            edges = g.edge_index.t().cpu().numpy()
            keep = np.random.rand(len(edges)) > drop_rate
            G = nx.Graph()
            G.add_nodes_from(np.arange(num_nodes))
            G.add_edges_from(edges[keep].tolist())
        adjacency = [list(G.adj[i]) for i in range(num_nodes)]
        _, membership = pymetis.part_graph(n_patches, adjacency=adjacency)

    assert len(membership) >= num_nodes
    membership = torch.as_tensor(np.asarray(membership[:num_nodes]), dtype=torch.long)  # [num_nodes]
    if membership.numel() > 0:
        # Shift ids so that the largest used id becomes n_patches - 1.
        membership = membership + (n_patches - (membership.max() + 1))

    node_mask = torch.stack([membership == i for i in range(n_patches)])  # [n_patches, num_nodes]
    if getattr(g, 'pos', None) is not None:
        node_mask = sort_subgraphs(g.pos, node_mask)
    if num_hops > 0:
        # (patch id, node id) for every node currently in a patch.
        subgraphs_batch, subgraphs_node_mapper = node_mask.nonzero(as_tuple=True)
        k_hop_node_mask = k_hop_subgraph(
            g.edge_index, num_nodes, num_hops, is_directed)
        # Add each member node's k-hop neighbourhood to its patch row.
        node_mask.index_add_(0, subgraphs_batch, k_hop_node_mask[subgraphs_node_mapper])
        node_mask = node_mask.bool()

    # An edge belongs to a subgraph when both endpoints do.
    edge_mask = node_mask[:, g.edge_index[0]] & node_mask[:, g.edge_index[1]]
    return node_mask, edge_mask

def metis_subgraph(g, n_patches, drop_rate=0.0, num_hops=1, is_directed=False):
    """Partition the graph into subgraphs using METIS or random partitioning.

    Parameters
    ----------
    g : pytorch_geometric.data.Data
        The input graph. If ``g.pos`` is set, subgraphs are reordered with
        ``sort_subgraphs``.
    n_patches : int
        The number of subgraphs.
    drop_rate : float, optional
        The drop rate for edges to augment the graph (undirected graphs only),
        by default 0.0
    num_hops : int, optional
        The overlap of subgraphs, by default 1
    is_directed : bool, optional
        Whether the graph is directed, by default False

    Returns
    -------
    node_mask : torch.Tensor [n_patches, num_nodes]
        The node mask for each subgraph.
    edge_mask : torch.Tensor [n_patches, num_edges]
        The edge mask for each subgraph
    """
    num_nodes = g.num_nodes

    if num_nodes < n_patches:
        # Too few nodes for METIS: one patch per node.
        if is_directed:
            membership = torch.arange(num_nodes)
        else:
            membership = torch.randperm(n_patches)[:num_nodes]
    elif is_directed:
        # to_undirected=True merges both directions; "lower" would keep only the
        # lower triangle of the adjacency and drop every directed i -> j, i < j.
        G = torch_geometric.utils.to_networkx(g, to_undirected=True)
        _, membership = metis.part_graph(G, n_patches, recursive=True)
    else:
        # Data augmentation: keep each edge with probability 1 - drop_rate.
        edges = g.edge_index.t().cpu()
        keep = torch.rand(len(edges)) > drop_rate
        G = nx.Graph()
        G.add_nodes_from(np.arange(num_nodes))
        G.add_edges_from(edges[keep].tolist())
        _, membership = metis.part_graph(G, n_patches, recursive=True)

    # membership holds a subgraph partition id for each node.
    assert len(membership) >= num_nodes
    membership = torch.as_tensor(np.asarray(membership[:num_nodes]), dtype=torch.long)  # [num_nodes]
    if membership.numel() > 0:
        # Shift ids so that the largest used id becomes n_patches - 1.
        membership = membership + (n_patches - (membership.max() + 1))

    # One boolean mask per subgraph.
    node_mask = torch.stack([membership == i for i in range(n_patches)])  # [n_patches, num_nodes]
    if getattr(g, 'pos', None) is not None:
        node_mask = sort_subgraphs(g.pos, node_mask)
    if num_hops > 0:
        # subgraphs_batch[k] is the subgraph containing original node
        # subgraphs_node_mapper[k].
        subgraphs_batch, subgraphs_node_mapper = node_mask.nonzero().T

        # k_hop_node_mask is the k-hop mask of every node (diagonal True).
        k_hop_node_mask = k_hop_subgraph(
            g.edge_index, num_nodes, num_hops, is_directed)
        node_mask.index_add_(0, subgraphs_batch,
                             k_hop_node_mask[subgraphs_node_mapper])
        node_mask = node_mask.bool()

    # An edge is in a subgraph only when both of its nodes are.
    edge_mask = node_mask[:, g.edge_index[0]] & node_mask[:, g.edge_index[1]]  # [n_patches, num_edges]
    return node_mask, edge_mask

# Poisson-disk (Bridson) sampling of a rectangular area.
def generate_poisson_disk_points(width, height, radius, k=30, seed=42):
    """
    Sample 2-D points whose pairwise distance is at least ``radius``.

    Parameters:
        width (float): Width of the generation area.
        height (float): Height of the generation area.
        radius (float): Minimum distance between points.
        k (int): Maximum number of attempts to generate a new point for each active point.
        seed (int): Value passed to ``np.random.seed`` before sampling
            (note: this reseeds NumPy's global RNG as a side effect).

    Returns:
        points (numpy.ndarray): Generated points, shape (num_points, 2).
    """
    # NOTE(review): presumably Bridson_sampling draws from NumPy's global RNG,
    # which is why it is seeded here — confirm.
    np.random.seed(seed)
    area = np.array([width, height])
    return Bridson_sampling(dims=area, radius=radius, k=k)

# Compute nearest neighbors and construct adjacency
def build_knn_graph(points, k=5):
    """
    Build an undirected k-nearest-neighbour graph over a point cloud.

    Parameters:
        points (numpy.ndarray): Coordinates of points, shape (num_points, dim).
        k (int): Number of nearest neighbors to connect for each point. Clamped
            to num_points - 1 so small point sets do not make sklearn raise.

    Returns:
        edge_index (torch.LongTensor): Indices of edges, shape (2, num_edges).
            Every edge appears once per direction, without duplicates or
            self-loops; columns are sorted (torch.unique).
    """
    points = np.asarray(points)
    num_points = points.shape[0]
    k = min(k, num_points - 1)
    if k <= 0:
        # Zero or one point: no neighbours to connect.
        return torch.empty((2, 0), dtype=torch.long)

    nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(points)
    # Querying without X makes sklearn exclude each point from its own
    # neighbour list, which stays correct even for duplicate coordinates
    # (where "first neighbour == the point itself" does not hold).
    _, indices = nbrs.kneighbors()

    source = np.repeat(np.arange(num_points), k)
    target = indices.reshape(-1)

    # Add the reverse edges so the graph is undirected, then drop duplicates.
    edge_index = np.concatenate([np.stack([source, target]),
                                 np.stack([target, source])], axis=1)
    edge_index = torch.as_tensor(edge_index, dtype=torch.long)
    return torch.unique(edge_index, dim=1)

# Wrap points and edges in a PyTorch Geometric Data object.
def create_pyg_data(points, edge_index):
    """
    Build a ``torch_geometric.data.Data`` graph from coordinates and edges.

    Parameters:
        points (numpy.ndarray): Coordinates of points, shape (num_points, 2);
            stored as the float32 node feature matrix ``x``.
        edge_index (torch.LongTensor): Indices of edges, shape (2, num_edges).

    Returns:
        data (torch_geometric.data.Data): Graph with ``x`` and ``edge_index`` set.
    """
    node_features = torch.tensor(points, dtype=torch.float32)
    return Data(x=node_features, edge_index=edge_index)

# Visualize the generated graph
def visualize_graph(data, title="K-NN Graph"):
    """
    Draw the whole graph, using the node features ``data.x`` as 2-D coordinates.

    Parameters:
        data (torch_geometric.data.Data): PyTorch Geometric Data object.
        title (str): Title of the graph.
    """
    G = nx.Graph()
    # Add every node explicitly so isolated nodes (no edges) are still drawn.
    G.add_nodes_from(range(data.num_nodes))
    G.add_edges_from(data.edge_index.t().tolist())

    # Node coordinates keyed by node index.
    coords = data.x.detach().cpu().numpy()
    pos = dict(enumerate(coords))

    fig, ax = plt.subplots(figsize=(8, 6))
    nx.draw_networkx_nodes(G, pos, node_size=50, node_color='skyblue', ax=ax)
    nx.draw_networkx_edges(G, pos, alpha=0.5, ax=ax)
    ax.set_title(title)
    ax.axis('off')
    plt.show()

import os
def visualize_and_save_subgraphs(data, node_mask, edge_mask, pos, save_dir=None, prefix='Subgraph'):
    """
    Draw every subgraph (patch) on its own figure and either save or show it.

    Parameters:
        data (torch_geometric.data.Data): Graph whose ``edge_index`` is indexed by ``edge_mask``.
        node_mask (torch.Tensor): Boolean mask [n_patches, num_nodes]; row i selects the nodes of patch i.
        edge_mask (torch.Tensor): Boolean mask [n_patches, num_edges]; row i selects the edges of patch i.
        pos (dict): Node index -> coordinate array used for drawing.
        save_dir (str, optional): If not None, each figure is written to
            ``<save_dir>/<prefix>_<i>.png`` and closed; otherwise it is shown.
        prefix (str, optional): Title and file-name prefix. Default is 'Subgraph'.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for patch_idx in range(node_mask.size(0)):
        label = patch_idx + 1  # patches are numbered from 1 in titles and file names

        node_ids = node_mask[patch_idx].nonzero(as_tuple=True)[0].tolist()
        if not node_ids:
            print(f"Subgraph {label} is empty. Skipping...")
            continue

        edge_list = data.edge_index[:, edge_mask[patch_idx]].t().tolist()
        patch_graph = nx.Graph()
        patch_graph.add_nodes_from(node_ids)
        patch_graph.add_edges_from(edge_list)

        plt.figure(figsize=(6, 6))
        nx.draw_networkx_nodes(patch_graph, pos, node_color='lightblue', node_size=100)
        nx.draw_networkx_edges(patch_graph, pos, edge_color='gray')
        plt.title(f"{prefix} {label}")
        plt.axis('off')

        if save_dir is None:
            plt.show()
            continue

        filepath = os.path.join(save_dir, f"{prefix}_{label}.png")
        plt.savefig(filepath, bbox_inches='tight')
        print(f"Saved {filepath}")
        plt.close()  # free the figure's memory

In [None]:
# Parameters
width = 10.0    # width of the sampling area
height = 10.0   # height of the sampling area
radius = 0.5    # minimum distance between sampled points
k_knn = 5       # number of nearest neighbours per point in the k-NN graph
seed = 42       # random seed
bridson_k = 20  # max attempts to generate a new point for each active point

# 1. Generate Poisson-disk sampled points
points = generate_poisson_disk_points(width, height, radius, k=bridson_k, seed=seed)
pos_dict = dict(enumerate(points))  # node index -> coordinates, used for drawing
print(f"Generate Number: {len(points)}")

In [None]:
# 2. Build the k-NN graph (edge_index stores both directions of every edge)
edge_index = build_knn_graph(points, k=k_knn)
num_edges = edge_index.size(1)
print(f"Number of Edges: {num_edges}")


In [None]:
# 3. Create the PyTorch Geometric Data object. The coordinates are stored both
# as node features `x` and as `pos`; metis_subgraph/pymetis_subgraph sort their
# patches by position when `pos` is set.
node_positions = torch.tensor(points, dtype=torch.float32)
data = create_pyg_data(points, edge_index)
data.pos = node_positions

In [None]:
# 4. Visualize the generated graph
graph_title = f"{k_knn}-NN Graph with {len(points)} Poisson Disk Sampled Points"
visualize_graph(data, title=graph_title)

In [None]:
# 5. Partition the graph with every method, using identical settings
n_patches = 128  # number of subgraphs to create

partition_kwargs = dict(
    n_patches=n_patches,
    drop_rate=0.0,
    num_hops=1,
    is_directed=False,
)

metis_node_mask, metis_edge_mask = metis_subgraph(data, **partition_kwargs)
kmeans_node_mask, kmeans_edge_mask = kmeans_subgraph(data, **partition_kwargs)
recursive_node_mask, recursive_edge_mask = recursive_metis_subgraph(data, **partition_kwargs)
py_node_mask, py_edge_mask = pymetis_subgraph(data, **partition_kwargs)

print(metis_node_mask.shape)
print(metis_edge_mask.shape)

In [None]:
def to_sparse(node_mask, edge_mask):
    """Return the True entries of both masks as [2, nnz] index tensors.

    Row 0 holds the subgraph id, row 1 the node (resp. edge) id.
    """
    return node_mask.nonzero().T, edge_mask.nonzero().T

# NOTE: despite the `metis_` prefix, these are built from the pymetis partition.
metis_subgraphs_nodes, metis_subgraphs_edges = to_sparse(py_node_mask, py_edge_mask)

# Flag the patches that contain at least one node.
subgraphs_batch = metis_subgraphs_nodes[0]
mask = torch.zeros(n_patches, dtype=torch.bool)
mask[subgraphs_batch] = True
print(mask)

In [None]:
# Inspect the (subgraph, node) index pairs of a partition.
# NOTE(review): `origin_node_mask` is not produced by any earlier cell, so it is
# aliased to the METIS mask here to make the cell run on a fresh kernel —
# TODO confirm which mask was intended.
origin_node_mask = metis_node_mask
subgraphs_batch, subgraphs_node_mapper = origin_node_mask.nonzero().T
print(subgraphs_batch.shape)
print(subgraphs_node_mapper.shape)

In [None]:
# Save one image per METIS subgraph (os.path.join keeps the path portable).
save_directory = os.path.join("..", "data", "subgraphs", "test_metis")
visualize_and_save_subgraphs(data, metis_node_mask, metis_edge_mask, pos_dict, save_dir=save_directory, prefix='Metis')

In [None]:
# Save one image per recursive-bisection subgraph (portable path).
save_directory = os.path.join("..", "data", "subgraphs", "test_recursive")
visualize_and_save_subgraphs(
    data=data,
    node_mask=recursive_node_mask,
    edge_mask=recursive_edge_mask,
    pos=pos_dict,
    save_dir=save_directory,  # set to None to display the figures instead of saving them
    prefix='Recursive'  # file-name prefix of the saved images
)


In [None]:
# Save one image per KMeans subgraph (portable path).
save_directory = os.path.join("..", "data", "subgraphs", "small_kmeans")

visualize_and_save_subgraphs(
    data=data,
    node_mask=kmeans_node_mask,
    edge_mask=kmeans_edge_mask,
    pos=pos_dict,
    save_dir=save_directory,  # set to None to display the figures instead of saving them
    prefix='KMeans'  # file-name prefix of the saved images
)

In [None]:
# Save one image per pymetis subgraph (portable path).
save_directory = os.path.join("..", "data", "subgraphs", "test_pymetis")

visualize_and_save_subgraphs(
    data=data,
    node_mask=py_node_mask,
    edge_mask=py_edge_mask,
    pos=pos_dict,
    save_dir=save_directory,  # set to None to display the figures instead of saving them
    prefix='PyMetis'  # file-name prefix; these are the pymetis partitions, not KMeans
)