In [1]:
import numpy as np
import torch
import dgl
import time


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# G_MEM: 16G  (presumably the GPU memory these limits were sized for — confirm)
# Edge budget: nodeShuffle splits the edge list into
# len(srcs) // (MAXEDGE // 2) + 1 chunks before moving them to the GPU.
MAXEDGE = 1_000_000_000
# Node-count threshold: nodeShuffle variants that consult it take the direct
# mapByNodeSet path when len(raw_node) <= MAXSHUFFLE.
MAXSHUFFLE = 30_000_000


In [3]:
def convert_to_tensor(data, dtype=torch.int32):
    """Return ``data`` as a torch tensor with the requested dtype.

    Args:
        data: A NumPy array, a torch tensor (or any object exposing a
            compatible ``.to(dtype)`` method), or plain Python data such as a
            list of ints that ``torch.as_tensor`` can interpret.
        dtype: Target torch dtype. Defaults to ``torch.int32``.

    Returns:
        A tensor of dtype ``dtype`` holding the values of ``data``.
    """
    if isinstance(data, np.ndarray):
        return torch.from_numpy(data).to(dtype)
    if hasattr(data, "to"):
        # Tensors (and tensor-like objects) convert themselves.
        return data.to(dtype)
    # Lists, tuples and scalars have no ``.to``; build a tensor from them
    # instead of failing with AttributeError.
    return torch.as_tensor(data, dtype=dtype)

def remapEdgeId(uniTable,srcList,dstList,device=torch.device('cpu'),remap=None):
    """Translate node IDs in ``srcList``/``dstList`` to their positions in ``uniTable``.

    A dense lookup table ``remap`` is used such that
    ``remap[uniTable[i]] == i``. It is built from ``uniTable`` unless a
    previously built table is passed in, which lets callers reuse it across
    batches.

    Args:
        uniTable: 1-D integer tensor of node IDs; the position of an ID in
            this tensor is its new ID. Only read when ``remap`` is None.
        srcList: Integer tensor of source node IDs, or None to skip.
        dstList: Integer tensor of destination node IDs, or None to skip.
        device: Device on which the lookup table is built and the lookups run.
        remap: Optional lookup table returned by an earlier call.

    Returns:
        ``(srcList, dstList, remap)``: the remapped lists as int32 CPU tensors
        (None where the input was None) and the lookup table.

    Note:
        The table is zero-initialised, so an ID that is smaller than
        ``max(uniTable)`` but absent from ``uniTable`` silently maps to 0.
    """
    if remap is None:
        # Build the remap table: remap[old_id] = position of old_id in uniTable.
        # uniTable is moved to `device` first so the index tensor and the
        # table always live on the same device.
        uniIndex = uniTable.to(device=device, dtype=torch.int64)
        remap = torch.zeros(int(uniIndex.max()) + 1, dtype=torch.int32, device=device)
        remap[uniIndex] = torch.arange(len(uniIndex), dtype=torch.int32, device=device)
    if srcList is not None:
        srcList = remap[srcList.to(device).to(torch.int64)].cpu()
    if dstList is not None:
        dstList = remap[dstList.to(device).to(torch.int64)].cpu()
    return srcList,dstList,remap


def nodeShuffle(raw_node,raw_graph):
    """Remap the node IDs of an interleaved edge list in GPU-sized chunks.

    Args:
        raw_node: Node IDs (NumPy array or tensor); converted to int32 and
            moved to the GPU.
        raw_graph: Flat edge list laid out as ``[src0, dst0, src1, dst1, ...]``.

    Returns:
        ``(srcs_tensor, dsts_tensor, uniTable)``: remapped source and
        destination IDs and the node table, all on the CPU.

    Requires a CUDA device. ``dgl.mapByNodeSet`` is invoked once with
    one-element placeholder edge tensors and only its third return value is
    kept as the node table — presumably this fixes the node ordering; confirm
    against the dgl build in use.
    """
    # Even positions hold sources, odd positions hold destinations.
    src_cpu = convert_to_tensor(raw_graph[::2], dtype=torch.int32)
    dst_cpu = convert_to_tensor(raw_graph[1::2], dtype=torch.int32)
    node_gpu = convert_to_tensor(raw_node, dtype=torch.int32).cuda()
    uniTable = torch.ones(len(node_gpu), dtype=torch.int32, device="cuda")

    # Number of chunks, chosen so each chunk stays around MAXEDGE // 2 edges.
    num_chunks = len(src_cpu) // (MAXEDGE // 2) + 1
    src_parts = list(torch.chunk(src_cpu, num_chunks, dim=0))
    dst_parts = list(torch.chunk(dst_cpu, num_chunks, dim=0))

    start = time.time()
    _, _, uniTable = dgl.mapByNodeSet(
        node_gpu, uniTable, node_gpu[:1].clone(), node_gpu[:1].clone()
    )

    # The lookup table is built on the first chunk and reused afterwards.
    remap = None
    gpu = torch.device('cuda:0')
    for i in range(min(len(src_parts), len(dst_parts))):
        src_parts[i], dst_parts[i], remap = remapEdgeId(
            uniTable, src_parts[i], dst_parts[i], remap=remap, device=gpu
        )

    srcs_tensor = torch.cat(src_parts)
    dsts_tensor = torch.cat(dst_parts)
    uniTable = uniTable.cpu()
    print(f"node shuffle time :{time.time()-start:.4f}")
    return srcs_tensor,dsts_tensor,uniTable

In [4]:
# Toy graph for exercising nodeShuffle: 7 directed edges whose node IDs are
# shifted by 10 so they are not already zero-based.
src = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.int32) + 10
dst = torch.tensor([0, 0, 0, 1, 1, 3, 3], dtype=torch.int32) + 10
# Node table covering IDs 8..19, a superset of the IDs used by the edges.
nodeId = torch.arange(20, dtype=torch.int32)[8:]
# Interleave as [s0, d0, s1, d1, ...], then repeat the edge list four times.
graph = torch.stack((src, dst), dim=1).flatten().repeat(4)
print("G: ",graph)
print("src: ",graph[::2])
print("dst: ",graph[1::2])
print("nodeId: ",nodeId)

G:  tensor([11, 10, 12, 10, 13, 10, 14, 11, 15, 11, 16, 13, 17, 13, 11, 10, 12, 10,
        13, 10, 14, 11, 15, 11, 16, 13, 17, 13, 11, 10, 12, 10, 13, 10, 14, 11,
        15, 11, 16, 13, 17, 13, 11, 10, 12, 10, 13, 10, 14, 11, 15, 11, 16, 13,
        17, 13], dtype=torch.int32)
src:  tensor([11, 12, 13, 14, 15, 16, 17, 11, 12, 13, 14, 15, 16, 17, 11, 12, 13, 14,
        15, 16, 17, 11, 12, 13, 14, 15, 16, 17], dtype=torch.int32)
dst:  tensor([10, 10, 10, 11, 11, 13, 13, 10, 10, 10, 11, 11, 13, 13, 10, 10, 10, 11,
        11, 13, 13, 10, 10, 10, 11, 11, 13, 13], dtype=torch.int32)
nodeId:  tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=torch.int32)


In [5]:
# Demo override of the limits defined in the earlier config cell (the
# "G_MEM: 16G" label there does not describe these toy values).
# NOTE: this rebinds the module-level constants, so every cell run after this
# one sees the small values until the config cell is re-run.
MAXEDGE = 16    # 28 demo edges -> 28 // (16 // 2) + 1 = 4 chunks in nodeShuffle
MAXSHUFFLE = 8   # below the 12 demo nodes; only read by nodeShuffle definitions that branch on it


# Show the base edge pattern, then the remapped edge list and node table.
print("src: ",src)
print("dst: ",dst)
print("-"*20)
# Needs `nodeId` and `graph` from the previous cell and a CUDA device.
srcs_tensor,dsts_tensor,uni = nodeShuffle(nodeId,graph)
print("src: ",srcs_tensor)
print("dst: ",dsts_tensor)
print("uni: ",uni)

src:  tensor([11, 12, 13, 14, 15, 16, 17], dtype=torch.int32)
dst:  tensor([10, 10, 10, 11, 11, 13, 13], dtype=torch.int32)
--------------------
node shuffle time :0.1625
src:  tensor([3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5, 6, 7, 8, 9, 3, 4, 5,
        6, 7, 8, 9], dtype=torch.int32)
dst:  tensor([2, 2, 2, 3, 3, 5, 5, 2, 2, 2, 3, 3, 5, 5, 2, 2, 2, 3, 3, 5, 5, 2, 2, 2,
        3, 3, 5, 5], dtype=torch.int32)
uni:  tensor([ 8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=torch.int32)


In [None]:
def nodeShuffle(raw_node,raw_graph):
    """Remap the node IDs of an interleaved edge list, batching work for the GPU.

    Two strategies are used depending on the number of nodes:

    * ``len(raw_node) <= MAXSHUFFLE``: every edge batch is moved to the GPU
      and passed to ``dgl.mapByNodeSet`` together with the node table.
    * otherwise: ``dgl.mapByNodeSet`` is called once with one-element
      placeholder edge tensors to obtain the node table, and each edge batch
      is then translated with ``remapEdgeId``; the lookup table is built on
      the first batch and reused for the rest.

    Args:
        raw_node: Node IDs (NumPy array or tensor); converted to int32 and
            moved to the GPU.
        raw_graph: Flat edge list laid out as ``[src0, dst0, src1, dst1, ...]``.

    Returns:
        ``(srcs_tensor, dsts_tensor, uni)``: remapped source and destination
        IDs and the node table, all on the CPU.

    Requires a CUDA device. The exact semantics of ``dgl.mapByNodeSet`` are
    not visible here — confirm against the dgl build in use.

    NOTE: the notebook also defines ``nodeShuffle`` in an earlier cell;
    running this cell rebinds the name to this implementation.
    """
    # Even positions hold sources, odd positions hold destinations.
    srcs = raw_graph[::2]
    dsts = raw_graph[1::2]
    raw_node = convert_to_tensor(raw_node, dtype=torch.int32).cuda()
    srcs_tensor = convert_to_tensor(srcs, dtype=torch.int32)
    dsts_tensor = convert_to_tensor(dsts, dtype=torch.int32)
    # Allocate directly as int32 on the GPU instead of float-on-CPU -> cast -> copy.
    uni = torch.ones(len(raw_node), dtype=torch.int32, device="cuda")
    if len(raw_node) <= MAXSHUFFLE:
        # Despite the name, batch_size is the number of chunks to split into.
        batch_size = len(srcs) // (MAXEDGE// 8) + 1
        print("batch size :",batch_size)
        src_batches = list(torch.chunk(srcs_tensor, batch_size, dim=0))
        dst_batches = list(torch.chunk(dsts_tensor, batch_size, dim=0))
        start = time.time()
        for index,(src_batch,dst_batch) in enumerate(zip(src_batches, dst_batches)):
            print("slice map by node")
            src_batch = src_batch.cuda()
            dst_batch = dst_batch.cuda()
            # `uni` is threaded through so each batch sees the table produced
            # by the previous one.
            srcShuffled,dstShuffled,uni = dgl.mapByNodeSet(raw_node,uni,src_batch,dst_batch)
            # Move results back to the CPU right away to free GPU memory.
            src_batches[index] = srcShuffled.cpu()
            dst_batches[index] = dstShuffled.cpu()
        srcs_tensor = torch.cat(src_batches)
        dsts_tensor = torch.cat(dst_batches)
        uni = uni.cpu()
        print(f"using time :{time.time()-start:.4f}")
        return srcs_tensor,dsts_tensor,uni
    else:
        batch_size = len(srcs) // (MAXEDGE//2) + 1
        src_batches = list(torch.chunk(srcs_tensor, batch_size, dim=0))
        dst_batches = list(torch.chunk(dsts_tensor, batch_size, dim=0))
        start = time.time()
        # One-element placeholders: only the returned node table is needed.
        src_emp = raw_node[:1].clone()
        dst_emp = raw_node[:1].clone()
        _,_,uni = dgl.mapByNodeSet(raw_node,uni,src_emp,dst_emp)
        # remapEdgeId returns (src, dst, remap); keep the table so it is
        # built once rather than for every batch.
        remap = None
        for index,(src_batch,dst_batch) in enumerate(zip(src_batches, dst_batches)):
            print("slice remap")
            srcShuffled,dstShuffled,remap = remapEdgeId(uni,src_batch,dst_batch,remap=remap,device=torch.device('cuda:0'))
            src_batches[index] = srcShuffled
            dst_batches[index] = dstShuffled
        srcs_tensor = torch.cat(src_batches)
        dsts_tensor = torch.cat(dst_batches)
        uni = uni.cpu()
        print(f"node shuffle time :{time.time()-start:.4f}")
        return srcs_tensor,dsts_tensor,uni