In [1]:
import numpy as np
import torch
import dgl
import time


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def coo2csr_sort(row,col):
    """Convert a COO edge list to CSR by sorting on ``row``.

    Args:
        row: 1-D tensor of non-negative integer ids; this is the side that
            gets compressed into the row-pointer array.
        col: 1-D tensor with the same length as ``row`` holding the paired ids.

    Returns:
        Tuple ``(inptr, indice)``:
            * ``inptr`` -- int64 row pointer with ``row.max() + 2`` entries
              (just ``[0]`` for empty input); entries ``i`` and ``i + 1``
              delimit the slice of ``indice`` that belongs to row ``i``.
            * ``indice`` -- ``col`` reordered so entries of the same row are
              contiguous, keeping the dtype of ``col``.
    """
    # stable=True keeps entries that share a row in their input order; the
    # default sort does not promise that.
    _, order = torch.sort(row, dim=0, stable=True)
    indice = col[order]
    # bincount does not depend on element order, so the unsorted ids suffice.
    counts = torch.bincount(row)
    # Create the leading zero with the dtype/device of the counts so the
    # concatenation also works when the inputs live on a GPU.
    start = torch.zeros(1, dtype=counts.dtype, device=counts.device)
    inptr = torch.cat([start, torch.cumsum(counts, dim=0)])
    return inptr,indice

def coo2csr_dgl(srcs,dsts):
    """Obtain CSR arrays for a COO edge list by letting DGL do the conversion.

    The edge list is handed to ``dgl.graph`` in swapped order, ``(dsts, srcs)``,
    so the CSR that DGL produces is compressed on ``dsts`` and its index array
    holds the matching ``srcs`` values.  In the recorded output of this
    notebook the pointer array covers every node id up to the largest id seen
    in either list, not only the ids present in ``dsts``.

    Args:
        srcs: 1-D integer tensor with the source id of each edge.
        dsts: 1-D integer tensor with the destination id of each edge.

    Returns:
        Tuple ``(indptr, indices)`` taken from ``adj_sparse(fmt='csr')``; the
        third element of that result is discarded.
    """
    swapped_graph = dgl.graph((dsts, srcs))
    # Restrict the graph to the CSR format before asking for its arrays.
    csr_graph = swapped_graph.formats('csr')
    row_ptr, col_idx, _ = csr_graph.adj_sparse(fmt='csr')
    return row_ptr, col_idx



In [3]:
def cooTocsr(srcList,dstList,sliceNUM=1,device=torch.device('cpu')):
    """Convert a COO edge list to CSR on the GPU with ``dgl.cooTocsr``.

    The result is compressed on ``dstList`` and stores the ``srcList`` values
    as indices.  The row pointer is computed up front from a bincount of
    ``dstList``; the index array is then filled by ``dgl.cooTocsr``, either in
    one call or slice by slice.

    NOTE(review): ``dgl.cooTocsr`` is not defined in this file; it is presumably
    a custom kernel that writes into ``indice`` in place and uses ``addr`` as
    a per-row write cursor -- confirm against its implementation.

    Args:
        srcList: 1-D integer tensor with the source id of each edge.
        dstList: 1-D tensor of non-negative integer destination ids, one per
            edge.
        sliceNUM: number of chunks the edge list is split into before being
            sent to the GPU.  ``1`` uploads everything in a single call.
        device: accepted for backward compatibility but never read; all work
            runs on the current CUDA device.

    Returns:
        Tuple ``(inptr, indice)`` of int32 CPU tensors: the CSR row pointer
        and the index array.
    """
    dstList = dstList.cuda()

    # Row pointer: a leading zero followed by the running per-destination
    # edge count, narrowed to int32.
    leading_zero = torch.zeros(1, dtype=torch.int32, device=dstList.device)
    inptr = torch.cat([leading_zero, torch.cumsum(torch.bincount(dstList), dim=0)]).to(torch.int32)

    # Allocate the output directly on the GPU as int32 instead of building a
    # host tensor, casting it and uploading it.
    indice = torch.zeros_like(srcList, dtype=torch.int32, device=dstList.device)

    # Copy of the row starts, kept separate so that inptr itself is not the
    # tensor handed to the kernel as its third argument.
    addr = inptr.clone()[:-1]

    if sliceNUM == 1:
        src_gpu = srcList.cuda()
        # The kernel compresses its fourth argument, so dst goes before src.
        dgl.cooTocsr(inptr, indice, addr, dstList, src_gpu)
        return inptr.cpu(), indice.cpu()

    # Drop the full GPU copy of dst and stream both lists slice by slice.
    dstList = dstList.cpu()
    src_batches = torch.chunk(srcList, sliceNUM, dim=0)
    dst_batches = torch.chunk(dstList, sliceNUM, dim=0)
    for src_batch, dst_batch in zip(src_batches, dst_batches):
        src_batch = src_batch.cuda()
        dst_batch = dst_batch.cuda()
        # Compress on dst, store src.
        dgl.cooTocsr(inptr, indice, addr, dst_batch, src_batch)
    return inptr.cpu(), indice.cpu()

In [9]:
# Toy edge list with seven edges; both tensors are int32 and stay on the CPU.
dst = torch.tensor([0, 0, 0, 1, 1, 3, 3], dtype=torch.int32)
src = torch.tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)

# First compress on src ...
indptr, indices = coo2csr_sort(src, dst)
print("ptr: ", indptr)
print("indices: ", indices)

print("-" * 10)

# ... then on dst.
indptr, indices = coo2csr_sort(dst, src)
print("ptr: ", indptr)
print("indices: ", indices)

ptr:  tensor([0, 0, 1, 2, 3, 4, 5, 6, 7])
indices:  tensor([0, 0, 0, 1, 1, 3, 3], dtype=torch.int32)
----------
ptr:  tensor([0, 3, 5, 5, 7])
indices:  tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)


In [10]:
# Same toy edge list through the DGL-backed converter, in both argument orders.
indptr, indices = coo2csr_dgl(src, dst)
print("ptr: ", indptr)
print("indices: ", indices)

print("-" * 10)

indptr, indices = coo2csr_dgl(dst, src)
print("ptr: ", indptr)
print("indices: ", indices)

ptr:  tensor([0, 3, 5, 5, 7, 7, 7, 7, 7], dtype=torch.int32)
indices:  tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)
----------
ptr:  tensor([0, 0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)
indices:  tensor([0, 0, 0, 1, 1, 3, 3], dtype=torch.int32)


In [12]:
# Same toy edge list through the GPU converter (single slice), in both
# argument orders.
indptr, indices = cooTocsr(src, dst, sliceNUM=1, device=torch.device('cpu'))
print("ptr: ", indptr)
print("indices: ", indices)

print("-" * 10)

indptr, indices = cooTocsr(dst, src, sliceNUM=1, device=torch.device('cpu'))
print("ptr: ", indptr)
print("indices: ", indices)

ptr:  tensor([0, 3, 5, 5, 7], dtype=torch.int32)
indices:  tensor([1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)
----------
ptr:  tensor([0, 0, 1, 2, 3, 4, 5, 6, 7], dtype=torch.int32)
indices:  tensor([0, 0, 0, 1, 1, 3, 3], dtype=torch.int32)
