In [5]:
import torch
from torch_geometric.utils.random import barabasi_albert_graph, erdos_renyi_graph
import torch_geometric

In [6]:
import numpy as np
import time
from collections import Counter

In [7]:
class Graph():
    mapping = dict()
    
    def __init__(self):
        self.nodes = []
        
    def add_node(self, u):
        self.nodes.append(u)
        Graph.mapping[u.id] = len(self.nodes) - 1 
    
    def get_nodes(self):
        list_nodes_id = []
        for i in Graph.mapping.values():
            if i != None:
                list_nodes_id.append(i)
        return list_nodes_id
    
    def get_mapping(self):
        mapping = dict()
        super_nodes = self.get_nodes()
        for i in range(len(super_nodes)):
            mapping[i] = graph.nodes[super_nodes[i]].merged_nodes
        return mapping
            
    def add_edge(self, u_id, v_id):
        u = self.nodes[Graph.mapping[u_id]]
        v = self.nodes[Graph.mapping[v_id]]
        u.add_node(v)
        v.add_node(u)
    
    def remove_edge(self, u_id, v_id):
        u = self.nodes[Graph.mapping[u_id]]
        v = self.nodes[Graph.mapping[v_id]]
        u.remove_node(v)
        v.remove_node(u)
    
    def join_nodes(self, u_id, v_id):
        u = self.nodes[Graph.mapping[u_id]]
        v = self.nodes[Graph.mapping[v_id]]
              
        u.join(v)    
        self.nodes[Graph.mapping[v_id]] = None
        Graph.mapping[v_id] = None  
            
    def get_heavy_edge(self, u_id):
        u = self.nodes[Graph.mapping[u_id]]
        heavy_node_id = None
        heavy_edge = float('inf')
        for v in u.neigh:
            if  v.degree < heavy_edge:
                heavy_edge = v.degree
                heavy_node_id = v.id
        return heavy_node_id
          

In [8]:
class Node():
    def __init__(self, i):
        self.neigh = [] 
        self.degree = 0
        self.id = i
        self.merged_nodes = [i]
    
    def join(self, v):
        self.merged_nodes.append(v.id)
        for node in v.neigh:
            node.remove_node(v)
            node.add_node(self)
            self.add_node(node)
    
    def add_node(self, u):
        if not(u in self.neigh) and (u.id != self.id):
            self.neigh.append(u)
            self.degree += 1
             
    def add_nodes(self, nodes):
        for node in nodes:
            self.add_node(node)
        
    def remove_node(self, u):
        try:
            self.neigh.remove(u)
            self.degree -= 1
        except:
            pass        

In [9]:
n_nodes = 3000
edge_index = barabasi_albert_graph(n_nodes,4)
graph = Graph()

In [10]:
for i in range(n_nodes):
    graph.add_node(Node(i))

In [11]:
for i in range(edge_index.shape[1]):
    graph.add_edge(edge_index[0,i].item(), edge_index[1,i].item())

In [12]:
start = time.time()
n_supernodes = int(0.1 * n_nodes)
while len(graph.get_nodes()) > n_supernodes:
    u_id = np.random.choice(graph.get_nodes())
    v_id = graph.get_heavy_edge(u_id)
    graph.join_nodes(u_id, v_id)
print(time.time() - start)

1.3993408679962158


In [13]:
for n in graph.get_nodes():
    print(graph.nodes[n].merged_nodes)

[0, 1253, 1489]
[1, 2009, 2928, 1450, 1862, 973]
[3, 320, 2239, 203, 1455]
[4, 384, 1482, 2816, 2043]
[5, 143, 1644, 2070, 1792]
[6, 1535, 1716]
[7]
[8, 1905]
[9, 2, 1486]
[10]
[11, 1529, 2137, 558, 2079, 1149, 1674, 119]
[12]
[13, 786]
[14, 385, 2361, 2052]
[15, 1427, 579]
[16, 1244, 2638, 345]
[17, 1835, 307]
[18, 1989, 670]
[19, 2567, 900, 2584, 51]
[20, 167]
[21, 1525]
[22, 344, 928]
[23]
[24, 2200, 2982, 1348, 658]
[25, 1612, 209]
[26, 677, 2725]
[28]
[29, 1404, 871, 427]
[30, 2563, 2745, 2007, 2944]
[31, 2240, 2981, 184, 2173, 50]
[32, 1809, 394, 710, 762]
[33, 702]
[34, 2427]
[38, 984, 823]
[39, 2612, 2700, 124]
[40, 599, 467, 1309, 68]
[41, 1176, 2931, 2029, 215]
[42, 2510, 1118]
[43, 2033, 2447, 2164, 743, 141, 597]
[44, 2663, 1672, 1636]
[46, 2266, 1807, 2479]
[48, 1425, 2084, 174]
[49, 1449]
[52, 1135, 1285, 173, 482, 2641]
[53, 1597, 2523, 2188]
[54, 1182, 1691, 1592, 296]
[56, 956, 1385, 159, 2606, 506]
[57, 1207, 649, 848]
[60, 2314, 1888, 1812, 2177, 214, 1646]
[61, 1304

In [15]:
mapping = graph.get_mapping()

In [16]:
mapping

{0: [0, 1253, 1489],
 1: [1, 2009, 2928, 1450, 1862, 973],
 2: [3, 320, 2239, 203, 1455],
 3: [4, 384, 1482, 2816, 2043],
 4: [5, 143, 1644, 2070, 1792],
 5: [6, 1535, 1716],
 6: [7],
 7: [8, 1905],
 8: [9, 2, 1486],
 9: [10],
 10: [11, 1529, 2137, 558, 2079, 1149, 1674, 119],
 11: [12],
 12: [13, 786],
 13: [14, 385, 2361, 2052],
 14: [15, 1427, 579],
 15: [16, 1244, 2638, 345],
 16: [17, 1835, 307],
 17: [18, 1989, 670],
 18: [19, 2567, 900, 2584, 51],
 19: [20, 167],
 20: [21, 1525],
 21: [22, 344, 928],
 22: [23],
 23: [24, 2200, 2982, 1348, 658],
 24: [25, 1612, 209],
 25: [26, 677, 2725],
 26: [28],
 27: [29, 1404, 871, 427],
 28: [30, 2563, 2745, 2007, 2944],
 29: [31, 2240, 2981, 184, 2173, 50],
 30: [32, 1809, 394, 710, 762],
 31: [33, 702],
 32: [34, 2427],
 33: [38, 984, 823],
 34: [39, 2612, 2700, 124],
 35: [40, 599, 467, 1309, 68],
 36: [41, 1176, 2931, 2029, 215],
 37: [42, 2510, 1118],
 38: [43, 2033, 2447, 2164, 743, 141, 597],
 39: [44, 2663, 1672, 1636],
 40: [46, 22