## Graph Partitioning

I made a graph partitioning algorithm to help create subgraphs that
(1) has the least connecting edges
(2) has a similar internal node label and graph structure

In [45]:
import torch
import torch_geometric
from torch_geometric.data import Data
import random
import networkx as nx
import matplotlib.pyplot as plt

def construct_adj(data):
    """ Construct an adjacency map that has the structure {src: [dst, weight]} *which has both directions* ; and 
        construct an adjacency list that has the structure [src, dst, weight]
    """
    adjacencyMap = {}
    adjacencyList = []
    oneHopeNodes = {}
    for i in range(data.edge_index.size(1)):
        src, dst = data.edge_index[:, i]
        adjacencyList.append([src.item(), dst.item(), data.edge_weight[i].item()])

        if src.item() not in adjacencyMap:
            adjacencyMap[src.item()] = [[dst.item(), data.edge_weight[i].item()]]
            oneHopeNodes[src.item()] = [dst.item()]
        else:
            adjacencyMap[src.item()].append([dst.item(), data.edge_weight[i].item()])
            oneHopeNodes[src.item()].append(dst.item())
        
        if dst.item() not in adjacencyMap:
            adjacencyMap[dst.item()] = [[src.item(), data.edge_weight[i].item()]]
            oneHopeNodes[dst.item()] = [src.item()]
        else:
            adjacencyMap[dst.item()].append([src.item(), data.edge_weight[i].item()])
            oneHopeNodes[dst.item()].append(src.item())

    return adjacencyMap, adjacencyList, oneHopeNodes

def initialize_root_nodes(data, K, oneHopNodes):
    """ Select K root nodes ensuring they are not 1-hop neighbors and have different labels (ignore for now). """
    rootNodes = []
    potential_rootNodes = []
    unavailableNodes = set()
    all_nodes = list(range(data.num_nodes))
    random.shuffle(all_nodes)

    for i in range(len(all_nodes)):
        selectedNode = all_nodes[i]
        if selectedNode in unavailableNodes:
            continue

        potential_rootNodes.append(selectedNode)

        # Exclude node and its 1-hop neighbours for next selection
        unavailableNodes.update(oneHopNodes[selectedNode])
        unavailableNodes.add(selectedNode) 

    print(potential_rootNodes)

    if len(potential_rootNodes) < K:  # if there are less than K disjoint nodes
        # TODO
        # Option 1: Randomly allocate (numSubGraphs - potentialRootNodes.Count) nodes which belong to unavailableNodes
        # Option 2: Allocate nodes from classes which have not been introduced into the set of potentialRootNodes
        return potential_rootNodes
    else:
        rootNodes = potential_rootNodes[:K]
    
    # for node in all_nodes:
    #     if len(root_nodes) >= K:
    #         break
        
    #     if all(not torch.any(data.edge_index == node) for root in root_nodes):
    #         root_nodes.append(node)
    
    return rootNodes

def graph_partitioning(data, K):
    """ Partition the graph into K subgraphs while minimizing inter-subgraph edges and considering node labels. """
    adjacencyMap, adjacencyList, oneHopNodes = construct_adj(data)
    print(adjacencyMap)
    # root_nodes = initialize_root_nodes(data, K, oneHopNodes)
    root_nodes = [0,7]
    partitions = {root: {root} for root in root_nodes}
    unassigned_nodes = set(range(data.num_nodes)) - set(root_nodes)
    node_labels = data.y if hasattr(data, 'y') else torch.randint(0, K, (data.num_nodes,))  # Random labels if missing
    
    while unassigned_nodes:
        for root in partitions:
            neighbors = set()
            for node in partitions[root]:
                # neighbors.update(data.edge_index[1][data.edge_index[0] == node].tolist())
                neighbors.update(oneHopNodes[node])
            
            neighbors.intersection_update(unassigned_nodes)
            if not neighbors:
                continue

            best_node = max(neighbors, key=lambda v: (sum(item[1] for item in adjacencyMap[v]), sum(node_labels[v] == node_labels[n] for n in partitions[root]))) # wrong
            
            # best_node = max(neighbors, key=lambda v: (torch.sum(data.edge_weight[data.edge_index[1] == v]),
            #                                            sum(node_labels[v] == node_labels[n] for n in partitions[root])))
            partitions[root].add(best_node)
            unassigned_nodes.remove(best_node)
    
    return partitions

In [6]:
def generate_example_graph():
    """ Generate a small example graph. """
    num_edges = 20
    edge_index = torch.zeros((2, num_edges), dtype=torch.long)
    for i in range(20):
        edge_index[:, i] = torch.randperm(10)[:2]  # Pick 2 unique numbers from 0-9
    edge_weight = torch.randint(0, 10, (num_edges,)) # Random edge weights
    num_nodes = 10
    node_labels = torch.randint(0, 3, (num_nodes,))  # Assign random labels
    return Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes, y=node_labels)

def visualize_graph(data):
    """ Convert PyG Data to NetworkX and visualize the graph. """
    G = nx.DiGraph()
    edge_index = data.edge_index.numpy()
    edge_weight = data.edge_weight.numpy()
    
    for i in range(edge_index.shape[1]):
        G.add_edge(edge_index[0, i], edge_index[1, i], weight=edge_weight[i])
    
    pos = nx.spring_layout(G)
    edge_labels = {(u, v): f'{d["weight"]}' for u, v, d in G.edges(data=True)}
    
    plt.figure(figsize=(8, 6))
    nx.draw(G, pos, with_labels=True, node_color='lightblue', edge_color='gray', node_size=700, font_size=10)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels)
    plt.title("Graph Visualization")
    plt.show()

In [47]:
# graph_data = generate_example_graph()
# visualize_graph(graph_data)
edge_index = torch.tensor([[0,0,0,0,1,1,2,2,2,2,3,4,4,5,5,7,8,9],[1,3,4,6,6,9,3,5,7,8,10,6,7,9,10,10,9,10]])
# edge_weight = torch.tensor([8,9,8,2,7,7,2,8,1,9,8,6,3,9,8,8,8,5])
edge_weight = torch.tensor([8,4,8,2,7,7,1,8,1,9,8,6,3,9,8,8,8,5])
num_nodes = 11
node_labels = torch.tensor([1,3,3,4,1,4,3,2,4,4,3])

graph_data = Data(edge_index=edge_index, edge_weight=edge_weight, num_nodes=num_nodes, y=node_labels)

In [48]:
partitions = graph_partitioning(graph_data, K=3)
print("Partitioned subgraphs:", partitions)

{0: [[1, 8], [3, 4], [4, 8], [6, 2]], 1: [[0, 8], [6, 7], [9, 7]], 3: [[0, 4], [2, 1], [10, 8]], 4: [[0, 8], [6, 6], [7, 3]], 6: [[0, 2], [1, 7], [4, 6]], 9: [[1, 7], [5, 9], [8, 8], [10, 5]], 2: [[3, 1], [5, 8], [7, 1], [8, 9]], 5: [[2, 8], [9, 9], [10, 8]], 7: [[2, 1], [4, 3], [10, 8]], 8: [[2, 9], [9, 8]], 10: [[3, 8], [5, 8], [7, 8], [9, 5]]}
Partitioned subgraphs: {0: {0, 1, 4, 6, 8, 9}, 7: {2, 3, 5, 7, 10}}
