# Speeding up Edge Contraction Algorithm

In [2]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import sys

import numpy as np
import torch

sys.path.append("..")


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

## Roadmap

- [X] Create toy graph
- [X] Get timings of original algorithm
- [X] Implement CuGraph connected components
- [X] Get CC timings
- [ ] Implement CC into the PyGeometric function
- [ ] Explore a vectorized version of original idea - only one edge contracted per node
- [ ] Explore sorting vs. random choice of edge

### Toy Graph

In [3]:
num_nodes = 10000
num_edges = 100000
x = torch.rand((num_nodes, 3), device=device).float()
e = torch.randint(0, len(x), (2, num_edges), device=device).long()
edge_score = torch.cat([
    torch.rand(int(num_edges*0.9), device=device).float()*0.4,
    torch.rand(int(num_edges*0.1), device=device).float()
])

### Original Algorithm

In [4]:
from torch_scatter import scatter_add
from torch_sparse import coalesce

In [5]:
def __merge_edges_original__(x, edge_index, batch, edge_score):
        
    nodes_remaining = set(range(x.size(0)))

    cluster = torch.empty_like(batch, device=x.device).long()
    edge_argsort = torch.argsort(edge_score, descending=True)

    # Iterate through all edges, selecting it if it is not incident to
    # another already chosen edge.
    i = 0
    new_edge_indices = []
   # edge_index_cpu = edge_index.cpu()
    for edge_idx in edge_argsort.tolist():
        source = edge_index[0, edge_idx].item()
        if source not in nodes_remaining:
            continue

        target = edge_index[1, edge_idx].item()
        if target not in nodes_remaining:
            continue

        new_edge_indices.append(edge_idx)

        cluster[source] = i
        nodes_remaining.remove(source)

        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)

        i += 1

    # The remaining nodes are simply kept.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1
#     cluster = cluster.to(x.device)

    # We compute the new features as an addition of the old ones.
    new_x = scatter_add(x, cluster, dim=0, dim_size=i)
    new_edge_score = edge_score[new_edge_indices]
    if len(nodes_remaining) > 0:
        remaining_score = x.new_ones(
            (new_x.size(0) - len(new_edge_indices), ))
        new_edge_score = torch.cat([new_edge_score, remaining_score])
    new_x = new_x * new_edge_score.view(-1, 1)

    N = new_x.size(0)
    new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

    new_batch = x.new_empty(new_x.size(0), dtype=torch.long, device=device)
#     batch = batch.to(x.device)
    new_batch = new_batch.scatter_(0, cluster, batch)

#     unpool_info = self.unpool_description(edge_index=edge_index,
#                                           cluster=cluster, batch=batch,
#                                           new_edge_score=new_edge_score)

#     return new_x, new_edge_index, new_batch, unpool_info
    return new_x, new_edge_index, new_batch

In [6]:
def __merge_edges__(x, edge_index, batch, edge_score):
        
    nodes_remaining = set(range(x.size(0)))

    cluster = torch.empty_like(batch, device=x.device).long()
    edge_argsort = torch.argsort(edge_score, descending=True)

    # Iterate through all edges, selecting it if it is not incident to
    # another already chosen edge.
    i = 0
    new_edge_indices = []
   # edge_index_cpu = edge_index.cpu()
    for edge_idx in edge_argsort.tolist():
        source = edge_index[0, edge_idx]
        if source not in nodes_remaining:
            continue

        target = edge_index[1, edge_idx]
        if target not in nodes_remaining:
            continue

        new_edge_indices.append(edge_idx)

        cluster[source] = i
        nodes_remaining.remove(source)

        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)

        i += 1

    # The remaining nodes are simply kept.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1
#     cluster = cluster.to(x.device)

    # We compute the new features as an addition of the old ones.
    new_x = scatter_add(x, cluster, dim=0, dim_size=i)
    new_edge_score = edge_score[new_edge_indices]
    if len(nodes_remaining) > 0:
        remaining_score = x.new_ones(
            (new_x.size(0) - len(new_edge_indices), ))
        new_edge_score = torch.cat([new_edge_score, remaining_score])
    new_x = new_x * new_edge_score.view(-1, 1)

    N = new_x.size(0)
    new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

    new_batch = x.new_empty(new_x.size(0), dtype=torch.long, device=device)
#     batch = batch.to(x.device)
    new_batch = new_batch.scatter_(0, cluster, batch)

#     unpool_info = self.unpool_description(edge_index=edge_index,
#                                           cluster=cluster, batch=batch,
#                                           new_edge_score=new_edge_score)

#     return new_x, new_edge_index, new_batch, unpool_info
    return new_x, new_edge_index, new_batch


In [7]:
%%time
new_x, new_edge_index, new_batch = __merge_edges__(x, e, torch.zeros(x.shape[0], device=device).long(), edge_score)

CPU times: user 11.3 s, sys: 269 ms, total: 11.5 s
Wall time: 11.6 s


### CuGraph Connected Components

In [8]:
import cugraph
import cudf
import pandas as pd
import cupy as cp
from torch.utils.dlpack import from_dlpack, to_dlpack

ModuleNotFoundError: No module named 'cugraph'

In [None]:
passing_edges = e[:, edge_score > 0.5]

In [None]:
%%time
passing_edges = cudf.from_dlpack(to_dlpack(passing_edges.T))

In [None]:
%%time
G = cugraph.Graph()
G.from_cudf_edgelist(passing_edges, source=0, destination=1, edge_attr=None)

In [None]:
%%time
labels = cugraph.components.connectivity.weakly_connected_components(G)

This all seems to work fine, so let's build it into a new method

#### TODO

In [None]:
def __merge_edges__(x, edge_index, batch, edge_score):
        
    nodes_remaining = set(range(x.size(0)))

    cluster = torch.empty_like(batch, device=x.device).long()
    edge_argsort = torch.argsort(edge_score, descending=True)

    # Iterate through all edges, selecting it if it is not incident to
    # another already chosen edge.
    i = 0
    new_edge_indices = []
   # edge_index_cpu = edge_index.cpu()
    for edge_idx in edge_argsort.tolist():
        source = edge_index[0, edge_idx]
        if source not in nodes_remaining:
            continue

        target = edge_index[1, edge_idx]
        if target not in nodes_remaining:
            continue

        new_edge_indices.append(edge_idx)

        cluster[source] = i
        nodes_remaining.remove(source)

        if source != target:
            cluster[target] = i
            nodes_remaining.remove(target)

        i += 1

    # The remaining nodes are simply kept.
    for node_idx in nodes_remaining:
        cluster[node_idx] = i
        i += 1
#     cluster = cluster.to(x.device)

    # We compute the new features as an addition of the old ones.
    new_x = scatter_add(x, cluster, dim=0, dim_size=i)
    new_edge_score = edge_score[new_edge_indices]
    if len(nodes_remaining) > 0:
        remaining_score = x.new_ones(
            (new_x.size(0) - len(new_edge_indices), ))
        new_edge_score = torch.cat([new_edge_score, remaining_score])
    new_x = new_x * new_edge_score.view(-1, 1)

    N = new_x.size(0)
    new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

    new_batch = x.new_empty(new_x.size(0), dtype=torch.long, device=device)
#     batch = batch.to(x.device)
    new_batch = new_batch.scatter_(0, cluster, batch)

#     unpool_info = self.unpool_description(edge_index=edge_index,
#                                           cluster=cluster, batch=batch,
#                                           new_edge_score=new_edge_score)

#     return new_x, new_edge_index, new_batch, unpool_info
    return new_x, new_edge_index, new_batch


## Stable Roommates Approach

In [4]:
from torch_scatter import scatter_max

In [5]:
%%time
max_score_0, max_indices_0 = scatter_max(edge_score, e[0], dim=0, dim_size=x.shape[0])
max_score_1, max_indices_1 = scatter_max(edge_score, e[1], dim=0, dim_size=x.shape[0])

CPU times: user 63.2 ms, sys: 7.57 ms, total: 70.7 ms
Wall time: 71.1 ms


In [6]:
stacked_score, stacked_indices = torch.stack([max_score_0, max_score_1]), torch.stack([max_indices_0, max_indices_1]).T
top_score = torch.argmax(stacked_score, dim=0)

In [7]:
max_indices = torch.zeros(len(top_score), dtype=torch.long, device=device)

In [8]:
max_indices[max_score_0 > max_score_1] = max_indices_0[max_score_0 > max_score_1]
max_indices[max_score_1 > max_score_0] = max_indices_1[max_score_1 > max_score_0]

In [9]:
max_indices

tensor([919493, 990488, 718536,  ..., 208696,  13784, 941206], device='cuda:0')

Get timing for comparison

In [5]:
%%time

max_score_0, max_indices_0 = scatter_max(edge_score, e[0], dim=0, dim_size=x.shape[0])
max_score_1, max_indices_1 = scatter_max(edge_score, e[1], dim=0, dim_size=x.shape[0])

stacked_score, stacked_indices = torch.stack([max_score_0, max_score_1]), torch.stack([max_indices_0, max_indices_1]).T
top_score = torch.argmax(stacked_score, dim=0)

max_indices = torch.zeros(len(top_score), dtype=torch.long, device=device)

max_indices[max_score_0 > max_score_1] = max_indices_0[max_score_0 > max_score_1]
max_indices[max_score_1 > max_score_0] = max_indices_1[max_score_1 > max_score_0]

CPU times: user 53.3 ms, sys: 28.7 ms, total: 82 ms
Wall time: 81.6 ms


In [6]:
nodes = torch.arange(1,x.shape[0]+1,device = "cuda:0")
max_indices_copy = max_indices[:]
max_indices_pairs = torch.index_select(max_indices,0,max_indices_copy-1)
print(nodes.shape,max_indices_pairs.shape)

torch.Size([10000]) torch.Size([10000])


In [7]:
max_indices_matches = torch.eq(max_indices_pairs,nodes)

RuntimeError: CUDA error: device-side assert triggered

In [22]:
nodes_remaining = torch.ones(x.shape[0])
edges_shifted = 1 + e
nodes = torch.arange(1,x.shape[0]+1)

In [26]:
print(nodes.max())
print(edges_shifted.max())

tensor(100000)
tensor(100000, device='cuda:0')


In [34]:
%%time
_, counts_0 = torch.unique(e[0],return_counts=True)
_, counts_1 = torch.unique(e[1],return_counts=True)
max_neighbors = max(counts_0.max(),counts_1.max())

CPU times: user 3 ms, sys: 0 ns, total: 3 ms
Wall time: 2.36 ms


In [45]:
num_nodes = 10
num_edges = 100
x_small = torch.rand((num_nodes, 3), device=device).float()
e_small = torch.randint(0, len(x_small), (2, num_edges), device=device).long()
edge_score_small = torch.cat([
    torch.rand(int(num_edges*0.9), device=device).float()*0.4,
    torch.rand(int(num_edges*0.1), device=device).float()
])
_, counts_0 = torch.unique(e_small[0],return_counts=True)
_, counts_1 = torch.unique(e_small[1],return_counts=True)
max_neighbors = max(counts_0.max(),counts_1.max())

In [46]:
preferences = torch.zeros((x_small.shape[0],max_neighbors,2))

In [48]:
preferences.shape

torch.Size([10, 15, 2])

In [49]:
preferences[:,:,0] = torch.gather(e_small[1],) 

tensor([4, 3, 4, 5, 0, 2, 7, 6, 0, 9, 1, 4, 0, 7, 6, 9, 1, 6, 3, 3, 5, 8, 7, 6,
        8, 7, 7, 6, 8, 2, 2, 3, 4, 6, 0, 4, 4, 1, 2, 6, 8, 0, 3, 3, 1, 4, 4, 8,
        8, 2, 2, 1, 8, 7, 9, 7, 9, 4, 0, 8, 0, 6, 8, 5, 2, 1, 2, 4, 3, 2, 7, 2,
        3, 9, 5, 3, 7, 4, 1, 8, 8, 7, 1, 9, 8, 1, 7, 4, 7, 1, 3, 5, 6, 2, 0, 8,
        8, 2, 5, 9], device='cuda:0')

In [50]:
e_small

tensor([[4, 3, 4, 5, 0, 2, 7, 6, 0, 9, 1, 4, 0, 7, 6, 9, 1, 6, 3, 3, 5, 8, 7, 6,
         8, 7, 7, 6, 8, 2, 2, 3, 4, 6, 0, 4, 4, 1, 2, 6, 8, 0, 3, 3, 1, 4, 4, 8,
         8, 2, 2, 1, 8, 7, 9, 7, 9, 4, 0, 8, 0, 6, 8, 5, 2, 1, 2, 4, 3, 2, 7, 2,
         3, 9, 5, 3, 7, 4, 1, 8, 8, 7, 1, 9, 8, 1, 7, 4, 7, 1, 3, 5, 6, 2, 0, 8,
         8, 2, 5, 9],
        [1, 7, 9, 8, 1, 8, 6, 3, 8, 4, 9, 7, 1, 8, 1, 0, 7, 6, 6, 3, 3, 8, 2, 2,
         8, 6, 5, 9, 6, 9, 0, 2, 9, 4, 4, 6, 4, 6, 7, 9, 2, 1, 2, 1, 8, 3, 6, 4,
         3, 7, 7, 7, 2, 7, 7, 6, 8, 7, 0, 6, 2, 5, 3, 4, 5, 3, 2, 4, 0, 7, 1, 3,
         9, 6, 3, 4, 7, 2, 1, 2, 7, 0, 0, 1, 5, 8, 9, 4, 7, 3, 5, 1, 3, 7, 9, 9,
         5, 9, 4, 6]], device='cuda:0')

In [15]:
e_small = torch.tensor(
    [[1,4,3,2,5,2,4],
     [4,3,2,5,3,1,2]])
scores_small = torch.tensor([1.0,0.0,0.4,0.8,])
m = torch.tensor([4,4,5,1,4],device='cpu')
r = ~( m[m-1] == torch.arange(1,6))
print(r.type(torch.ByteTensor))

tensor([0, 1, 1, 0, 1], dtype=torch.uint8)


In [11]:
n = torch.ones_like(m)
n = n.type(torch.ByteTensor)

tensor([1, 1, 1, 1, 1], dtype=torch.uint8)


In [79]:
from torch_scatter import scatter_max
device = "cuda:0"


num_nodes = 10000
num_edges = 200000
x = torch.rand((num_nodes, 3), device=device).float()
e = torch.randint(0, len(x), (2, num_edges), device=device).long()

edge_score = torch.cat([
    torch.rand(int(num_edges*0.9), device=device).float(),
    torch.rand(int(num_edges*0.1), device=device).float()
])
edge_score = edge_score[e[0] != e[1]]
e = e[:,e[0] != e[1]]

#x = torch.rand((6,1),device=device).float()
#e = torch.tensor([
#    [0,1,2,3,4],
#    [1,2,3,4,5]],device=device)
#edge_score = torch.tensor([1.0,0.9,0.8,0.9,1.0],device=device)

In [80]:
%%time
#used for comparing against max_indices
nodes = torch.arange(x.shape[0])
nodes = nodes.to(device)

nodes_remaining = torch.ones_like(nodes,dtype = torch.bool)
ratio = 1.0
edges_contracted = torch.empty(2,0,device=device)
i = 0

while i < 10 and ratio > 0.05:    
    max_score_0, max_indices_0 = scatter_max(edge_score, e[0], dim=0, dim_size=x.shape[0])
    max_score_1, max_indices_1 = scatter_max(edge_score, e[1], dim=0, dim_size=x.shape[0])

    stacked_score, stacked_indices = torch.stack([max_score_0, max_score_1]), torch.stack([max_indices_0, max_indices_1]).T
    top_score , _ = torch.max(stacked_score, dim=0)
    top_score = top_score.to(device)
    
    max_indices = torch.zeros(len(top_score), dtype=torch.long, device=device)
    max_indices[max_score_0 > max_score_1] = e[1][max_indices_0[max_score_0 > max_score_1]]
    max_indices[max_score_1 > max_score_0] = e[0][max_indices_1[max_score_1 > max_score_0]]
    
    #gets the max neighbor of the max neighbor of each node
    #if this equals the node itself, the nodes are both each other's max neighbors
    max_indices_pairs = max_indices[max_indices]
    max_indices_matches = (max_indices_pairs == nodes)
    
    #don't check any nodes who have already been contrated
    indices_remaining = top_score > 0.0
    max_indices_matches *= indices_remaining
    
    #update how many nodes are remaining
    nodes_remaining = ~max_indices_matches * nodes_remaining
    
    #find all edges with >= 1 node in the nodes being contracted. This is used to zero out the edge scores
    edge_score_zero_mask = (e[..., None] == nodes[max_indices_matches]).any(-1).any(0)
    edge_score *= ~edge_score_zero_mask
    
    #find all edges with both nodes being contracted. This is used to find which edges are being contracted
    edge_contract_mask_0 = ((e[...,None] == nodes[max_indices_matches])[0]).any(-1)
    edge_contract_mask_1 = ((e[...,None] == nodes[max_indices_matches])[1]).any(-1)
    edge_contract_mask = edge_contract_mask_0 * edge_contract_mask_1
    
    #keep track of which edges are contracted
    new_edges_contracted = e[:,edge_contract_mask]
    edges_contracted = torch.cat([edges_contracted,new_edges_contracted],dim=-1)
    print(edges_contracted.shape[1])
    
    ratio = (torch.sum(nodes_remaining)/nodes_remaining.shape[0]).item()
    i += 1
    print(ratio)

52130
0.4966000020503998
63460
0.26319998502731323
66344
0.14739999175071716
67125
0.09039999544620514
67423
0.05639999732375145
67571
0.03539999946951866
CPU times: user 86.4 ms, sys: 19.6 ms, total: 106 ms
Wall time: 104 ms
