In [1]:
import torch
from torch_scatter import scatter

In [2]:
def sparse_dense_mul(s, d):
  i = s._indices()
  v = s._values()
  dv = d[i[0,:]]  # get values from relevant entries of dense matrix

  return torch.sparse_coo_tensor(i, v * dv, s.size(), device=s.device)

def index_scatter(sp, index, n_nodes):
  i = sp._indices()
  v = sp._values()

  new_0th = index[i[0]]
  new_1th = i[1]

  sp = torch.sparse_coo_tensor(torch.vstack((new_0th, new_1th)), v, size=(n_nodes, n_nodes), device = index.device)

  sp = sp.coalesce()
  return sp

In [3]:
n_nodes = 100
device = "cpu"
n_adj = 300


dense_identity = torch.eye(100, device = device)
sparse_identity = dense_identity.to_sparse()

#Adj Matrix is sparse but only the index is required as the values are only 0 or 1
adj = torch.randint(low = 0, high = n_nodes, size = (2, n_adj))
rev = torch.flip(adj, dims = (0, ))
adj = torch.hstack((adj, rev))

torch.allclose(dense_identity, sparse_identity.to_dense())

True

In [4]:
sparse_onehot_j = sparse_identity.index_select(0, adj[0])
dense_onehot_j = dense_identity.index_select(0, adj[0])

torch.allclose(dense_onehot_j, sparse_onehot_j.to_dense())

True

In [5]:
alpha = torch.arange(600)

In [6]:
sparse_muled_onehot_j = sparse_dense_mul(sparse_onehot_j, alpha)
dense_muled_onehot_j = dense_onehot_j * alpha.unsqueeze(-1)

torch.allclose(dense_muled_onehot_j, sparse_muled_onehot_j.to_dense())

True

In [7]:
sparse_fin = index_scatter(sparse_muled_onehot_j, adj[1], n_nodes)
dense_fin = scatter(dense_muled_onehot_j, adj[1], 0, dim_size= n_nodes, reduce="add")

torch.allclose(dense_fin, sparse_fin.to_dense())

True

In [8]:
sparse_yep = sparse_fin + sparse_identity

# Timetest

In [65]:
def timetest(n_nodes, n_adj, device):
    
    coords = torch.arange(0, n_nodes, device = device).unsqueeze(0).repeat([2, 1])
    sparse_identity = torch.sparse_coo_tensor(coords, torch.ones(n_nodes, device = device), (n_nodes, n_nodes), device = device)

    adj = torch.randint(low = 0, high = n_nodes, size = (2, n_adj), device= device)
    rev = torch.flip(adj, dims = (0, ))
    adj = torch.hstack((adj, rev))

    sparse_onehot_j = sparse_identity.index_select(0, adj[0])

    alpha = torch.arange(2 * n_adj, device = device)

    sparse_muled_onehot_j = sparse_dense_mul(sparse_onehot_j, alpha)

    sparse_fin = index_scatter(sparse_muled_onehot_j, adj[1], n_nodes)

    sparse_fin + sparse_identity

In [9]:
def timetest_dense(n_nodes, n_adj, device):
    for i in range(256):
        eye = torch.eye(n_nodes, device = device)
        onehot = torch.eye(n_nodes, device = device)
        for e in range(10):


            adj = torch.randint(low = 0, high = n_nodes, size = (2, n_adj), device = device)
            rev = torch.flip(adj, dims = (0, ))
            adj = torch.hstack((adj, rev))

            dense_onehot_j = onehot.index_select(0, adj[0])

            alpha = torch.arange(2 * n_adj, device = device)

            dense_muled_onehot_j = dense_onehot_j * alpha.unsqueeze(-1)

            dense_fin = scatter(dense_muled_onehot_j, adj[1], 0, dim_size= n_nodes, reduce="add")

            onehot = dense_fin + eye

In [75]:
%timeit timetest(100, 300, "cuda")
%timeit timetest(100, 300, "cpu")
%timeit timetest_dense(100, 300, "cuda")
%timeit timetest_dense(100, 300, "cpu")

359 ms ± 6.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
18.7 ms ± 902 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
120 µs ± 886 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
201 µs ± 7.73 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [10]:
#%timeit timetest(10000, 10000, "cuda")
#%timeit timetest(10000, 10000, "cpu")
%timeit timetest_dense(200, 400, "cuda")
#%timeit timetest_dense(10000, 10000, "cpu")

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking arugment for argument index in method wrapper_index_select)