# Speeding up Edge Contraction Algorithm

In [24]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys

import numpy as np
import torch
from torch_scatter import scatter_add
from torch_sparse import coalesce
sys.path.append("..")


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Roadmap

- [X] Create toy graph
- [X] Get timings of original algorithm
- [X] Implement CuGraph connected components
- [X] Get CC timings
- [ ] Implement CC into the PyGeometric function
- [ ] Explore a vectorized version of original idea - only one edge contracted per node
- [ ] Explore sorting vs. random choice of edge

### Toy Graph

In [3]:
num_nodes = 10000
num_edges = 100000
x = torch.rand((num_nodes, 3), device=device).float()
e = torch.randint(0, len(x), (2, num_edges), device=device).long()
edge_score = torch.cat([
    torch.rand(int(num_edges*0.9), device=device).float()*0.4,
    torch.rand(int(num_edges*0.1), device=device).float()
])

### Original Algorithm

In [4]:
from torch_scatter import scatter_add
from torch_sparse import coalesce

In [5]:
def __merge_edges_original__(x, edge_index, batch, edge_score):
        
    nodes_remaining = set(range(x.size(0)))

    cluster = torch.empty_like(batch, device=x.device).long()
    edge_argsort = torch.argsort(edge_score, descending=True)

    # Iterate through all edges, selecting it if it is not incident to
    # another already chosen edge.
    i = 0
    new_edge_indices = []
   # edge_index_cpu = edge_index.cpu()
    for edge_idx in edge_argsort.tolist():
        source = edge_index[0, edge_idx].item()
        if source not in nodes_remaining:
            continue

        target = edge_index[1, edge_idx].item()
        if target not in nodes_remaining:
            continue

        new_edge_indices.append(edge_idx)

        cluster[source] = i
        nodes_remaining.remove(source)

        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)

        i += 1

    # The remaining nodes are simply kept.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1
#     cluster = cluster.to(x.device)

    # We compute the new features as an addition of the old ones.
    new_x = scatter_add(x, cluster, dim=0, dim_size=i)
    new_edge_score = edge_score[new_edge_indices]
    if len(nodes_remaining) > 0:
        remaining_score = x.new_ones(
            (new_x.size(0) - len(new_edge_indices), ))
        new_edge_score = torch.cat([new_edge_score, remaining_score])
    new_x = new_x * new_edge_score.view(-1, 1)

    N = new_x.size(0)
    new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

    new_batch = x.new_empty(new_x.size(0), dtype=torch.long, device=device)
#     batch = batch.to(x.device)
    new_batch = new_batch.scatter_(0, cluster, batch)

#     unpool_info = self.unpool_description(edge_index=edge_index,
#                                           cluster=cluster, batch=batch,
#                                           new_edge_score=new_edge_score)

#     return new_x, new_edge_index, new_batch, unpool_info
    return new_x, new_edge_index, new_batch

In [47]:
def __merge_edges__(x, edge_index, batch, edge_score):
        
    nodes_remaining = set(range(x.size(0)))

    cluster = torch.empty_like(batch, device=x.device).long()
    edge_argsort = torch.argsort(edge_score, descending=True)

    # Iterate through all edges, selecting it if it is not incident to
    # another already chosen edge.
    i = 0
    new_edge_indices = []
   # edge_index_cpu = edge_index.cpu()
    for edge_idx in edge_argsort.tolist():
        source = edge_index[0, edge_idx]
        if source not in nodes_remaining:
            continue

        target = edge_index[1, edge_idx]
        if target not in nodes_remaining:
            continue

        new_edge_indices.append(edge_idx)

        cluster[source] = i
        nodes_remaining.remove(source)

        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)

        i += 1

    # The remaining nodes are simply kept.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1
#     cluster = cluster.to(x.device)
    print(len(nodes_remaining)/x.size(0))
    # We compute the new features as an addition of the old ones.
    new_x = scatter_add(x, cluster, dim=0, dim_size=i)
    new_edge_score = edge_score[new_edge_indices]
    if len(nodes_remaining) > 0:
        remaining_score = x.new_ones(
            (new_x.size(0) - len(new_edge_indices), ))
        new_edge_score = torch.cat([new_edge_score, remaining_score])
    new_x = new_x * new_edge_score.view(-1, 1)

    N = new_x.size(0)
    new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

    new_batch = x.new_empty(new_x.size(0), dtype=torch.long, device=device)
#     batch = batch.to(x.device)
    new_batch = new_batch.scatter_(0, cluster, batch)

#     unpool_info = self.unpool_description(edge_index=edge_index,
#                                           cluster=cluster, batch=batch,
#                                           new_edge_score=new_edge_score)

#     return new_x, new_edge_index, new_batch, unpool_info
    return new_x, new_edge_index, new_batch


In [48]:
%%time
new_x, new_edge_index, new_batch = __merge_edges__(x, e, torch.zeros(x.shape[0], device=device).long(), edge_score)

1.0
CPU times: user 1.15 s, sys: 49.2 ms, total: 1.2 s
Wall time: 1.2 s


### CuGraph Connected Components

In [8]:
import cugraph
import cudf
import pandas as pd
import cupy as cp
from torch.utils.dlpack import from_dlpack, to_dlpack

ModuleNotFoundError: No module named 'cugraph'

In [None]:
passing_edges = e[:, edge_score > 0.5]

In [None]:
%%time
passing_edges = cudf.from_dlpack(to_dlpack(passing_edges.T))

In [None]:
%%time
G = cugraph.Graph()
G.from_cudf_edgelist(passing_edges, source=0, destination=1, edge_attr=None)

In [None]:
%%time
labels = cugraph.components.connectivity.weakly_connected_components(G)

This all seems to work fine, so let's build it into a new method

#### TODO

In [None]:
def __merge_edges__(x, edge_index, batch, edge_score):
        
    nodes_remaining = set(range(x.size(0)))

    cluster = torch.empty_like(batch, device=x.device).long()
    edge_argsort = torch.argsort(edge_score, descending=True)

    # Iterate through all edges, selecting it if it is not incident to
    # another already chosen edge.
    i = 0
    new_edge_indices = []
   # edge_index_cpu = edge_index.cpu()
    for edge_idx in edge_argsort.tolist():
        source = edge_index[0, edge_idx]
        if source not in nodes_remaining:
            continue

        target = edge_index[1, edge_idx]
        if target not in nodes_remaining:
            continue

        new_edge_indices.append(edge_idx)

        cluster[source] = i
        nodes_remaining.remove(source)

        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)

        i += 1

    # The remaining nodes are simply kept.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1
#     cluster = cluster.to(x.device)

    # We compute the new features as an addition of the old ones.
    new_x = scatter_add(x, cluster, dim=0, dim_size=i)
    new_edge_score = edge_score[new_edge_indices]
    if len(nodes_remaining) > 0:
        remaining_score = x.new_ones(
            (new_x.size(0) - len(new_edge_indices), ))
        new_edge_score = torch.cat([new_edge_score, remaining_score])
    new_x = new_x * new_edge_score.view(-1, 1)

    N = new_x.size(0)
    new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

    new_batch = x.new_empty(new_x.size(0), dtype=torch.long, device=device)
#     batch = batch.to(x.device)
    new_batch = new_batch.scatter_(0, cluster, batch)

#     unpool_info = self.unpool_description(edge_index=edge_index,
#                                           cluster=cluster, batch=batch,
#                                           new_edge_score=new_edge_score)

#     return new_x, new_edge_index, new_batch, unpool_info
    return new_x, new_edge_index, new_batch


## Stable Roommates Approach

In [4]:
from torch_scatter import scatter_max

In [5]:
%%time
max_score_0, max_indices_0 = scatter_max(edge_score, e[0], dim=0, dim_size=x.shape[0])
max_score_1, max_indices_1 = scatter_max(edge_score, e[1], dim=0, dim_size=x.shape[0])

CPU times: user 63.2 ms, sys: 7.57 ms, total: 70.7 ms
Wall time: 71.1 ms


In [6]:
stacked_score, stacked_indices = torch.stack([max_score_0, max_score_1]), torch.stack([max_indices_0, max_indices_1]).T
top_score = torch.argmax(stacked_score, dim=0)

In [7]:
max_indices = torch.zeros(len(top_score), dtype=torch.long, device=device)

In [8]:
max_indices[max_score_0 > max_score_1] = max_indices_0[max_score_0 > max_score_1]
max_indices[max_score_1 > max_score_0] = max_indices_1[max_score_1 > max_score_0]

In [9]:
max_indices

tensor([919493, 990488, 718536,  ..., 208696,  13784, 941206], device='cuda:0')

Get timing for comparison

In [5]:
%%time

max_score_0, max_indices_0 = scatter_max(edge_score, e[0], dim=0, dim_size=x.shape[0])
max_score_1, max_indices_1 = scatter_max(edge_score, e[1], dim=0, dim_size=x.shape[0])

stacked_score, stacked_indices = torch.stack([max_score_0, max_score_1]), torch.stack([max_indices_0, max_indices_1]).T
top_score = torch.argmax(stacked_score, dim=0)

max_indices = torch.zeros(len(top_score), dtype=torch.long, device=device)

max_indices[max_score_0 > max_score_1] = max_indices_0[max_score_0 > max_score_1]
max_indices[max_score_1 > max_score_0] = max_indices_1[max_score_1 > max_score_0]

CPU times: user 53.3 ms, sys: 28.7 ms, total: 82 ms
Wall time: 81.6 ms


In [6]:
nodes = torch.arange(1,x.shape[0]+1,device = "cuda:0")
max_indices_copy = max_indices[:]
max_indices_pairs = torch.index_select(max_indices,0,max_indices_copy-1)
print(nodes.shape,max_indices_pairs.shape)

torch.Size([10000]) torch.Size([10000])


In [7]:
max_indices_matches = torch.eq(max_indices_pairs,nodes)

RuntimeError: CUDA error: device-side assert triggered

In [22]:
nodes_remaining = torch.ones(x.shape[0])
edges_shifted = 1 + e
nodes = torch.arange(1,x.shape[0]+1)

In [26]:
print(nodes.max())
print(edges_shifted.max())

tensor(100000)
tensor(100000, device='cuda:0')


In [34]:
%%time
_, counts_0 = torch.unique(e[0],return_counts=True)
_, counts_1 = torch.unique(e[1],return_counts=True)
max_neighbors = max(counts_0.max(),counts_1.max())

CPU times: user 3 ms, sys: 0 ns, total: 3 ms
Wall time: 2.36 ms


In [45]:
num_nodes = 10
num_edges = 100
x_small = torch.rand((num_nodes, 3), device=device).float()
e_small = torch.randint(0, len(x_small), (2, num_edges), device=device).long()
edge_score_small = torch.cat([
    torch.rand(int(num_edges*0.9), device=device).float()*0.4,
    torch.rand(int(num_edges*0.1), device=device).float()
])
_, counts_0 = torch.unique(e_small[0],return_counts=True)
_, counts_1 = torch.unique(e_small[1],return_counts=True)
max_neighbors = max(counts_0.max(),counts_1.max())

In [46]:
preferences = torch.zeros((x_small.shape[0],max_neighbors,2))

In [48]:
preferences.shape

torch.Size([10, 15, 2])

In [49]:
preferences[:,:,0] = torch.gather(e_small[1],) 

tensor([4, 3, 4, 5, 0, 2, 7, 6, 0, 9, 1, 4, 0, 7, 6, 9, 1, 6, 3, 3, 5, 8, 7, 6,
        8, 7, 7, 6, 8, 2, 2, 3, 4, 6, 0, 4, 4, 1, 2, 6, 8, 0, 3, 3, 1, 4, 4, 8,
        8, 2, 2, 1, 8, 7, 9, 7, 9, 4, 0, 8, 0, 6, 8, 5, 2, 1, 2, 4, 3, 2, 7, 2,
        3, 9, 5, 3, 7, 4, 1, 8, 8, 7, 1, 9, 8, 1, 7, 4, 7, 1, 3, 5, 6, 2, 0, 8,
        8, 2, 5, 9], device='cuda:0')

In [50]:
e_small

tensor([[4, 3, 4, 5, 0, 2, 7, 6, 0, 9, 1, 4, 0, 7, 6, 9, 1, 6, 3, 3, 5, 8, 7, 6,
         8, 7, 7, 6, 8, 2, 2, 3, 4, 6, 0, 4, 4, 1, 2, 6, 8, 0, 3, 3, 1, 4, 4, 8,
         8, 2, 2, 1, 8, 7, 9, 7, 9, 4, 0, 8, 0, 6, 8, 5, 2, 1, 2, 4, 3, 2, 7, 2,
         3, 9, 5, 3, 7, 4, 1, 8, 8, 7, 1, 9, 8, 1, 7, 4, 7, 1, 3, 5, 6, 2, 0, 8,
         8, 2, 5, 9],
        [1, 7, 9, 8, 1, 8, 6, 3, 8, 4, 9, 7, 1, 8, 1, 0, 7, 6, 6, 3, 3, 8, 2, 2,
         8, 6, 5, 9, 6, 9, 0, 2, 9, 4, 4, 6, 4, 6, 7, 9, 2, 1, 2, 1, 8, 3, 6, 4,
         3, 7, 7, 7, 2, 7, 7, 6, 8, 7, 0, 6, 2, 5, 3, 4, 5, 3, 2, 4, 0, 7, 1, 3,
         9, 6, 3, 4, 7, 2, 1, 2, 7, 0, 0, 1, 5, 8, 9, 4, 7, 3, 5, 1, 3, 7, 9, 9,
         5, 9, 4, 6]], device='cuda:0')

In [15]:
e_small = torch.tensor(
    [[1,4,3,2,5,2,4],
     [4,3,2,5,3,1,2]])
scores_small = torch.tensor([1.0,0.0,0.4,0.8,])
m = torch.tensor([4,4,5,1,4],device='cpu')
r = ~( m[m-1] == torch.arange(1,6))
print(r.type(torch.ByteTensor))

tensor([0, 1, 1, 0, 1], dtype=torch.uint8)


In [11]:
n = torch.ones_like(m)
n = n.type(torch.ByteTensor)

tensor([1, 1, 1, 1, 1], dtype=torch.uint8)


In [53]:
from torch_scatter import scatter_max
import networkx as nx
device = "cuda:0"

num_nodes = 10000
num_edges = 100000
G = nx.gnm_random_graph(num_nodes, num_edges, seed=None, directed=False)
e = torch.tensor(np.array(G.edges)).T
e = e.to(device)
x = torch.rand((num_nodes,1),device=device).float()
edge_score = torch.cat([
    torch.rand(int(num_edges*0.9), device=device).float(),
    torch.rand(int(num_edges*0.1), device=device).float()
])

In [54]:
%%time

#x = torch.rand((9,1),device=device).float()
#e = torch.tensor([
   # [0,0,0,0,1,1,2,3,3,4,5,6],
  #  [1,3,7,5,6,8,6,4,8,5,8,7]],device=device)
#edge_score = torch.tensor([0.3,0.7,0.8,0.9,0.3,1.0,1.0,0.5,0.9,1.0,0.3,0.9],device=device)

#used for comparing against max_indices
nodes = torch.arange(x.shape[0])
nodes = nodes.to(device)

nodes_remaining = torch.ones_like(nodes,dtype = torch.bool)
nodes_remaining = nodes_remaining.to(device)
edges_remaining = torch.ones_like(e[0],dtype=torch.bool)
edges_remaining = edges_remaining.to(device)
ratio = 1.0
i = 0

while i < 10 and ratio > 0.05:    
    #get max edge score for each node and edge index where it occurs
    max_score_0, max_indices_0 = scatter_max(edge_score, e[0], dim=0, dim_size=x.shape[0])
    max_score_1, max_indices_1 = scatter_max(edge_score, e[1], dim=0, dim_size=x.shape[0])

    #stack scores for each direction
    stacked_score, stacked_indices = torch.stack([max_score_0, max_score_1]), torch.stack([max_indices_0, max_indices_1]).T
    top_score , _ = torch.max(stacked_score, dim=0)
    top_score = top_score.to(device)
    
    #get max neighbor for each node
    max_indices = torch.zeros(len(top_score), dtype=torch.long, device=device)
    max_indices[max_score_0 > max_score_1] = e[1][max_indices_0[max_score_0 > max_score_1]]
    max_indices[max_score_1 > max_score_0] = e[0][max_indices_1[max_score_1 > max_score_0]]
    
    #find edges where each node is the other's max index
    edge_index_match_0 = max_indices[e[1]] == e[0]
    edge_index_match_1 = max_indices[e[0]] == e[1]
    node_0_valid = nodes_remaining[e[0]]
    node_1_valid = nodes_remaining[e[1]]
    edge_index_match = edge_index_match_0 & edge_index_match_1 & node_0_valid & node_1_valid
    
    #update the remaining edges based on which ones should be removed
    edges_remaining &= ~edge_index_match
    edges_contracted = e[:,edge_index_match]
    
    #update the remaining nodes based on which ones should be removed
    nodes_removed = torch.flatten(edges_contracted)
    nodes_remaining[nodes_removed] = 0.0
    
    #zero out the edge scores of every edge that has >= 1 node being removed
    edge_score_zero_mask = (e[..., None] == nodes[nodes_removed]).any(-1).any(0)
    edge_score *= ~edge_score_zero_mask
    ratio = (torch.sum(nodes_remaining)/nodes_remaining.shape[0]).item()
    i += 1
    print(ratio)
    
edges_contracted = e[:,~edges_remaining]
new_e = e[:,edges_remaining]
#print(new_e)
clustered_indices = torch.arange(edges_contracted.shape[1]).to(device)
remaining_indices = edges_contracted.shape[1] + torch.arange(torch.sum(nodes_remaining)).to(device)
new_node_index_map = torch.cat([
    torch.stack([edges_contracted[0],clustered_indices]),
    torch.stack([edges_contracted[1],clustered_indices]),
    torch.stack([nodes[nodes_remaining],remaining_indices])],dim=-1)
new_node_index_map = new_node_index_map[:,torch.argsort(new_node_index_map[0])]
#print(new_node_index_map)
cluster = new_node_index_map[1,:]
#print(cluster)
new_x = scatter_add(x, cluster, dim=0, dim_size=torch.unique(cluster).shape[0])
N = new_x.size(0)
new_edge_index, _ = coalesce(cluster[new_e], None, N, N)
#print(new_edge_index)

0.5055999755859375
0.2773999869823456
0.1605999916791916
0.094200000166893
0.05979999899864197
0.04839999973773956
CPU times: user 23.2 ms, sys: 8.02 ms, total: 31.2 ms
Wall time: 30 ms


In [36]:
%%time


tensor([[0, 0, 0, 1, 3, 3, 5, 6],
        [1, 3, 5, 6, 4, 8, 8, 7]], device='cuda:0')
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8],
        [0, 1, 2, 4, 3, 3, 2, 0, 1]], device='cuda:0')
tensor([0, 1, 2, 4, 3, 3, 2, 0, 1], device='cuda:0')
tensor([[0, 0, 0, 1, 2, 3, 4, 4],
        [1, 3, 4, 2, 0, 1, 1, 3]], device='cuda:0')
CPU times: user 5.76 ms, sys: 739 µs, total: 6.5 ms
Wall time: 5.52 ms


torch.Size([185921, 1])