# Bipartite

## PyG

In [1]:
import torch
import torch.nn.functional as F

cuda = torch.device('cuda') 
r = 8
x = torch.rand((16,128,128,128)).to(cuda)
y = F.avg_pool2d(x, r, r)

x_shape = x.shape
B,C,W,H = x_shape
y_shape = y.shape
By,Cy,Wy,Hy = y_shape

def flat_nodes(x,shape):
  B,C,W,H = shape
  x = x.reshape((-1,C,H*W))#.contiguous()
  x = x.transpose(1,2)#.contiguous()
  x = x.reshape((B*H*W,C))#.contiguous()
  return x

def unflat_nodes(x,shape):
  B,C,W,H = shape

  x = x.reshape((B,H*W,C))
  x = x.transpose(1,2)
  x = x.reshape((-1,C,H,W))
  return x


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import time
from torch_geometric.nn import knn

start_time = time.time()

x_f = flat_nodes(x,x_shape)
#print(x.shape)
y_f = flat_nodes(y,y_shape)
#print(y.shape)

batches_x = torch.linspace(0,B,steps=(B*H*W),dtype=torch.int64).to(device=x.device)
batches_y = torch.linspace(0,By,steps=(By*Hy*Wy),dtype=torch.int64).to(device=x.device)
assign_index = knn(y_f, x_f, 9, batches_y, batches_x)

print("--- %s seconds ---" % (time.time() - start_time))

--- 0.29327917098999023 seconds ---


## PyG + Torch

In [5]:
from gcn_lib.torch_edge import DenseDilatedKnnGraph
import torch.nn.functional as F
import torch
import time


start_time = time.time()
dilated_knn_graph = DenseDilatedKnnGraph(9)

y_f = y.reshape(B, C, -1, 1).contiguous() 
x_f = x.reshape(B, C, -1, 1).contiguous() 
edge_index = dilated_knn_graph(x_f, y_f, None)


x_j = edge_index[0]
x_i = edge_index[1]

count_batches = torch.linspace(0,B,steps=(9*H*W*B),dtype=torch.int64).to(device=x.device)

xx_j = x_j.reshape(-1) + ( count_batches  * (Hy*Wy))
xx_i = x_i.reshape(-1) + ( count_batches  * (H*W))
new_edge_index = torch.cat([xx_i.unsqueeze(0),xx_j.unsqueeze(0)], dim = 0)

x_f = flat_nodes(x, x.shape)
y_f = flat_nodes(y, y.shape)

print("--- %s seconds ---" % (time.time() - start_time))



--- 0.02605724334716797 seconds ---


In [6]:
x_j = edge_index[0]
x_i = edge_index[1]

print(assign_index.t().shape)
print(x_i.shape)
print(x_j.shape)

16*128*128*9 == 2359296


torch.Size([2359296, 2])
torch.Size([16, 16384, 9])
torch.Size([16, 16384, 9])


True

In [7]:
start = (B-1)*H*W*9
end = start+9
print(assign_index.t()[start:end])
print()
print(new_edge_index.t()[start:end])

tensor([[245760,   3982],
        [245760,   3916],
        [245760,   4024],
        [245760,   3840],
        [245760,   3881],
        [245760,   3846],
        [245760,   3999],
        [245760,   3913],
        [245760,   4034]], device='cuda:0')

tensor([[245760,   3982],
        [245760,   3916],
        [245760,   4024],
        [245760,   3840],
        [245760,   3881],
        [245760,   3913],
        [245760,   3846],
        [245760,   3999],
        [245760,   4034]], device='cuda:0')


# Single

In [1]:
import torch
import torch.nn.functional as F

cuda = torch.device('cuda') 
x = torch.rand((16,128,128,128)).to(cuda)


def flat_nodes(x,shape):
  B,C,W,H = shape
  x = x.reshape((-1,C,H*W))#.contiguous()
  x = x.transpose(1,2)#.contiguous()
  x = x.reshape((B*H*W,C))#.contiguous()
  return x

def unflat_nodes(x,shape):
  B,C,W,H = shape

  x = x.reshape((B,H*W,C))
  x = x.transpose(1,2)
  x = x.reshape((-1,C,H,W))
  return x


## PyG

In [2]:
import time
from torch_geometric.nn import knn_graph

start_time = time.time()

x_f = flat_nodes(x,x_shape)

batches_x = torch.linspace(0,B,steps=(B*H*W),dtype=torch.int64).to(device=x.device)
assign_index = knn_graph(x_f, 9, batches_x)

print("--- %s seconds ---" % (time.time() - start_time))

--- 11.07405948638916 seconds ---


## PyG + Torch

In [6]:
from gcn_lib.torch_edge import DenseDilatedKnnGraph
dilated_knn_graph = DenseDilatedKnnGraph(9)
x_f = x.reshape(B, C, -1, 1).contiguous()
edge_index = dilated_knn_graph(x_f, None, None)

In [7]:
x_j = edge_index[0]
x_i = edge_index[1]

In [9]:
x_j[0,0]

tensor([   0,  694,  713, 1552, 3460, 1767, 1068,  272, 1572])

In [1]:
from gcn_lib.torch_edge import DenseDilatedKnnGraph
import torch.nn.functional as F
import torch
import time

cuda = torch.device('cuda') 
x = torch.rand((16,128,64,64)).to(cuda)


def flat_nodes(x,shape):
  B,C,W,H = shape
  x = x.reshape((-1,C,H*W))#.contiguous()
  x = x.transpose(1,2)#.contiguous()
  x = x.reshape((B*H*W,C))#.contiguous()
  return x

def unflat_nodes(x,shape):
  B,C,W,H = shape

  x = x.reshape((B,H*W,C))
  x = x.transpose(1,2)
  x = x.reshape((-1,C,H,W))
  return x


dilated_knn_graph = DenseDilatedKnnGraph(9)
start_time = time.time()

B,C,W,H = x.shape
x_f = x.reshape(B, C, -1, 1)#.contiguous() 
edge_index = dilated_knn_graph(x_f, None, None)


x_j = edge_index[0]
x_i = edge_index[1]

count_batches = torch.linspace(0,B,steps=(9*H*W*B),dtype=torch.int64)
count_batches = count_batches.to(cuda)

xx_j = x_j.reshape(-1) + ( count_batches  * (H*W))
xx_i = x_i.reshape(-1) + ( count_batches  * (H*W))
new_edge_index = torch.cat([xx_i.unsqueeze(0),xx_j.unsqueeze(0)], dim = 0)

x_f = flat_nodes(x, x.shape)

print("--- %s seconds ---" % (time.time() - start_time))

RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

## Results

In [4]:
x_j = edge_index[0]
x_i = edge_index[1]

#print(assign_index.t().shape)
print(x_i.shape)
print(x_j.shape)

16*128*128*9 == 2359296

torch.Size([16, 16384, 9])
torch.Size([16, 16384, 9])


True

In [4]:
16*128*128

262144

In [8]:
start = (B-1)*H*W*9
end = start+9
""" print(assign_index.t()[start:end])
print() """
print(new_edge_index.t()[start:end])
torch.max(new_edge_index)

tensor([[245760, 245760],
        [245760, 254816],
        [245760, 260964],
        [245760, 259659],
        [245760, 251757],
        [245760, 257526],
        [245760, 259117],
        [245760, 254731],
        [245760, 247671]])


tensor(262143)