In [2]:
import time
from pathlib import Path

import numpy as np
import torch
import dgl


  from .autonotebook import tqdm as notebook_tqdm


In [14]:
# Chunking limits for the relabelling below.
# NOTE(review): the original note "G_MEM: 16G" presumably records the GPU memory
# budget these limits were sized for — confirm.
MAXEDGE = 1_000_000_000    # cap on src+dst entries per chunk in nodeShuffle (i.e. MAXEDGE // 2 edges)
MAXSHUFFLE = 30_000_000    # not referenced in the visible cells


In [27]:
def convert_to_tensor(data, dtype=torch.int32):
    """Return ``data`` as a torch tensor of ``dtype``.

    Accepts a NumPy array (wrapped without copying by ``torch.as_tensor`` before the
    dtype cast), an existing tensor (kept on its current device), or any other
    array-like that ``torch.as_tensor`` understands (lists, tuples, NumPy scalars).
    ``Tensor.to`` returns the input unchanged when it already has ``dtype``, so no
    copy is made in that case.

    Args:
        data: array-like to convert.
        dtype: target torch dtype (default ``torch.int32``).

    Returns:
        A ``torch.Tensor`` with dtype ``dtype``.
    """
    return torch.as_tensor(data).to(dtype)

def remapEdgeId(uniTable,srcList,dstList,device=torch.device('cpu'),remap=None):
    """Translate edge endpoint IDs into their positions within ``uniTable``.

    The lookup table ``remap`` satisfies ``remap[uniTable[i]] == i``. It is
    zero-initialised, so IDs that do not occur in ``uniTable`` silently map to 0.
    IDs larger than ``max(uniTable)`` are out of range for the table and make the
    indexing fail. If ``uniTable`` contains duplicates, which position wins is
    unspecified.

    Args:
        uniTable: 1-D integer tensor of original node IDs; position ``i`` becomes new ID ``i``.
        srcList: 1-D integer tensor of original IDs to translate, or ``None`` to skip.
        dstList: 1-D integer tensor of original IDs to translate, or ``None`` to skip.
        device: device on which the lookup table is built (only used when ``remap`` is None).
        remap: a table returned by an earlier call with the same ``uniTable``; pass it back
            to avoid rebuilding it for every edge batch.

    Returns:
        ``(srcList, dstList, remap)``: the translated lists as CPU tensors (or ``None``
        where ``None`` was passed) and the lookup table. A table built here is int32
        and lives on ``device``.
    """
    if remap is None:
        # Scatter position i into slot uniTable[i]; indices must be int64 and on the
        # same device as the table.
        table_ids = uniTable.to(device=device, dtype=torch.int64)
        table_size = int(table_ids.max()) + 1 if table_ids.numel() > 0 else 0
        remap = torch.zeros(table_size, dtype=torch.int32, device=device)
        remap[table_ids] = torch.arange(table_ids.numel(), dtype=torch.int32, device=device)

    def translate(ids):
        # Gather on the table's own device, then hand the result back on the CPU.
        if ids is None:
            return None
        return remap[ids.to(device=remap.device, dtype=torch.int64)].cpu()

    return translate(srcList), translate(dstList), remap


def nodeShuffle(raw_node,raw_graph):
    """Relabel the edges of ``raw_graph`` into the compact node-ID space derived from ``raw_node``.

    ``raw_graph`` is a flat sequence of interleaved ``(src, dst)`` pairs
    (``raw_graph[0::2]`` are sources, ``raw_graph[1::2]`` destinations).
    ``dgl.mapByNodeSet`` is called on ``raw_node`` to obtain ``uniTable``. Every edge
    endpoint is then replaced by its position in ``uniTable`` via ``remapEdgeId``.
    Edges are processed in chunks of at most ``MAXEDGE // 2`` edges, which bounds the
    size of each per-chunk GPU copy.

    NOTE(review): ``dgl.mapByNodeSet`` does not appear in the public DGL API
    (presumably a project fork). The assumption that it deduplicates ``raw_node``
    into ``uniTable`` rests on this notebook's outputs — confirm.

    Args:
        raw_node: NumPy array or tensor of node IDs; moved to the current CUDA device.
        raw_graph: NumPy array or tensor of interleaved src/dst node IDs.

    Returns:
        ``(srcs_tensor, dsts_tensor, uniTable)`` as CPU tensors: relabelled sources,
        relabelled destinations, and the node table whose index is the new ID.

    Raises:
        ValueError: if ``raw_graph`` has odd length (an unpaired endpoint).
    """
    if len(raw_graph) % 2 != 0:
        raise ValueError("raw_graph must hold interleaved (src, dst) pairs; got odd length %d" % len(raw_graph))
    srcs, dsts = raw_graph[::2], raw_graph[1::2]
    raw_node = convert_to_tensor(raw_node, dtype=torch.int32).cuda()
    # Run every GPU step on the device raw_node landed on. Mixing "cuda" (the current
    # device) with a hard-coded 'cuda:0' breaks when the current device is not 0.
    device = raw_node.device
    srcs_tensor = convert_to_tensor(srcs, dtype=torch.int32)
    dsts_tensor = convert_to_tensor(dsts, dtype=torch.int32)

    # Fewest chunks such that each holds at most MAXEDGE // 2 edges (ceil division).
    # Use at least one chunk so torch.chunk stays valid for an empty edge list.
    edges_per_chunk = max(MAXEDGE // 2, 1)
    num_chunks = max(-(-len(srcs_tensor) // edges_per_chunk), 1)
    src_batches = torch.chunk(srcs_tensor, num_chunks, dim=0)
    dst_batches = torch.chunk(dsts_tensor, num_chunks, dim=0)

    # Buffer handed to dgl.mapByNodeSet.
    # NOTE(review): presumably overwritten with the deduplicated node table — confirm.
    uniTable = torch.ones(len(raw_node),dtype=torch.int32,device=device)
    # One-element placeholder edge lists ("padding, no use"): only the node table is kept.
    src_emp,dst_emp = raw_node[:1].clone(), raw_node[:1].clone()
    _, _, uniTable = dgl.mapByNodeSet(raw_node,uniTable,src_emp,dst_emp,rhsNeed=False,include_rhs_in_lhs=False)

    # The first chunk builds the ID -> position table; later chunks reuse it.
    remap = None
    shuffled_srcs, shuffled_dsts = [], []
    for src_batch, dst_batch in zip(src_batches, dst_batches):
        src_new, dst_new, remap = remapEdgeId(uniTable,src_batch,dst_batch,remap=remap,device=device)
        shuffled_srcs.append(src_new)
        shuffled_dsts.append(dst_new)
    srcs_tensor = torch.cat(shuffled_srcs)
    dsts_tensor = torch.cat(shuffled_dsts)
    uniTable = uniTable.cpu()
    return srcs_tensor,dsts_tensor,uniTable

In [3]:
# Toy example: 7 directed edges (src -> dst) with IDs offset by 10, tiled 4 times.
src = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.int32) + 10
dst = torch.tensor([0, 0, 0, 1, 1, 3, 3], dtype=torch.int32) + 10
nodeId = torch.arange(8, 20, dtype=torch.int32)
# Interleave into the flat [s0, d0, s1, d1, ...] layout that nodeShuffle splits
# with [::2] / [1::2], then repeat the whole edge list four times.
graph = torch.stack((src, dst), dim=1).flatten().repeat(4)
print("G: ", graph)
print("src: ", graph[::2])
print("dst: ", graph[1::2])
print("nodeId: ", nodeId)

G:  tensor([11, 10, 12, 10, 13, 10, 14, 11, 15, 11, 16, 13, 17, 13, 11, 10, 12, 10,
        13, 10, 14, 11, 15, 11, 16, 13, 17, 13, 11, 10, 12, 10, 13, 10, 14, 11,
        15, 11, 16, 13, 17, 13, 11, 10, 12, 10, 13, 10, 14, 11, 15, 11, 16, 13,
        17, 13], dtype=torch.int32)
src:  tensor([11, 12, 13, 14, 15, 16, 17, 11, 12, 13, 14, 15, 16, 17, 11, 12, 13, 14,
        15, 16, 17, 11, 12, 13, 14, 15, 16, 17], dtype=torch.int32)
dst:  tensor([10, 10, 10, 11, 11, 13, 13, 10, 10, 10, 11, 11, 13, 13, 10, 10, 10, 11,
        11, 13, 13, 10, 10, 10, 11, 11, 13, 13], dtype=torch.int32)
nodeId:  tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=torch.int32)


In [8]:
# Raw int32 dumps for partition FR/part0: a flat interleaved src/dst edge list and a node-ID list.
# NOTE(review): machine-specific absolute path — point PART_DIR at your own data checkout.
PART_DIR = Path("/home/bear/workspace/single-gnn/data/partition/FR/part0")
graph = torch.as_tensor(np.fromfile(PART_DIR / "raw_G.bin", dtype=np.int32))
node = torch.as_tensor(np.fromfile(PART_DIR / "raw_nodes.bin", dtype=np.int32))

In [21]:
# Tile the node list four times as int32 on the current CUDA device.
# NOTE(review): not idempotent — re-running this cell without re-running the load cell
# tiles the already-tiled tensor again.
node = torch.cat((node,) * 4).to(device="cuda", dtype=torch.int32)

In [28]:
node.size()

torch.Size([114585868])

In [None]:
# Relabel the whole partition's edges into the compact node-ID space and report the
# wall-clock time (nodeShuffle returns CPU tensors, so the GPU work has finished by then).
_t0 = time.perf_counter()
srcs_tensor, dsts_tensor, uniTable = nodeShuffle(node, graph)
print(f"nodeShuffle: {time.perf_counter() - _t0:.1f} s")

In [24]:
dsts_tensor.max()

tensor(28646466, dtype=torch.int32)

In [25]:
# Node table from nodeShuffle: position i holds the original node ID that is now labelled i.
uniTable

tensor([       0,        1,        2,  ..., 65598687, 65599292, 65599766],
       dtype=torch.int32)

In [26]:
uniTable.size()

torch.Size([28646467])