In [1]:
import torch
import numpy as np
import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def featAdd(addIdx, addfeat, memfeat, cudafeat):
    """Scatter new feature rows into the split (mem / cuda) feature stores.

    Conceptually this is ``feat[addIdx] = addfeat`` for a feature table that
    is split across two tensors. The sign of each entry of ``addIdx`` selects
    the destination store:

    * ``addIdx[i] < 0``  -> ``memfeat[-addIdx[i]] = addfeat[i]``
    * ``addIdx[i] > 0``  -> ``cudafeat[addIdx[i]] = addfeat[i]``
    * ``addIdx[i] == 0`` -> row ``i`` is skipped (slot 0 is never written).

    Args:
        addIdx: 1-D integer tensor of signed slot indices. It must already
            hold store slots (i.e. node ids already translated through the
            slot map), not raw node ids.
        addfeat: features to insert; row ``i`` belongs to ``addIdx[i]``.
        memfeat: store addressed by the negative entries; updated in place.
        cudafeat: store addressed by the positive entries; updated in place.

    Returns:
        None. The elapsed wall-clock time is printed.
    """
    start = time.time()

    # Boolean masks select exactly the rows the sign test picks, in order.
    to_mem = addIdx < 0
    to_cuda = addIdx > 0

    mem_slots = -addIdx[to_mem]
    cuda_slots = addIdx[to_cuda]

    # Every index tensor is moved to the device of the tensor it indexes, and
    # every value to the device of its destination. Nothing is forced onto
    # CUDA, so no tensor is indexed with indices living on another device and
    # the function also works when all tensors are on the CPU.
    mem_rows = addfeat[to_mem.to(addfeat.device)]
    cuda_rows = addfeat[to_cuda.to(addfeat.device)]

    memfeat[mem_slots.to(memfeat.device)] = mem_rows.to(memfeat.device)
    cudafeat[cuda_slots.to(cudafeat.device)] = cuda_rows.to(cudafeat.device)

    # NOTE(review): when cudafeat is on a CUDA device the work may still be
    # queued when the clock is read, so treat this number as a rough
    # indication only - confirm before using it as a benchmark.
    print(f"featAdd {time.time() - start:.4f}s")

# Test: run featAdd on a 7-row batch that mixes negative (memfeat) and
# positive (cudafeat) slot indices, then check both stores.
addIdx = torch.tensor([-1,-3,2,4,1,3,-2])
memfeat = torch.tensor([-1,5,6,7,8,9,10])
cudafeat = torch.tensor([-1,11,12,13,14,15,16]).cuda()
addfeat = torch.tensor([20,21,22,23,24,25,26])
featAdd(addIdx, addfeat, memfeat,cudafeat)
print("memfeat:",memfeat)
print("cudafeat:",cudafeat)
print("addfeat:",addfeat)

# Pin the expected result so a regression fails the cell instead of only
# changing the printed output.
assert memfeat.tolist() == [-1, 20, 26, 21, 8, 9, 10]
assert cudafeat.tolist() == [-1, 24, 22, 25, 23, 15, 16]
assert addfeat.tolist() == [20, 21, 22, 23, 24, 25, 26]  # input must be untouched

featAdd 0.0390s
memfeat: tensor([-1, 20, 26, 21,  8,  9, 10])
cudafeat: tensor([-1, 24, 22, 25, 23, 15, 16], device='cuda:0')
addfeat: tensor([20, 21, 22, 23, 24, 25, 26])


In [None]:
def init_cac(lossMap, feat, memfeat, cudafeat, map):
    """Distribute a full feature table over memfeat / cudafeat and fill the slot map.

    Row ``i`` of ``feat`` is routed by ``lossMap[i]``:

    * ``lossMap[i] == -1`` -> copied into ``memfeat``; ``map[i]`` is set to
      the negated slot number (``-1, -2, ...`` in row order).
    * otherwise            -> copied into ``cudafeat``; ``map[i]`` is set to
      the slot number (``1, 2, ...`` in row order).

    Slots are handed out starting at 1, so slot 0 of both stores is left
    untouched.

    Args:
        lossMap: 1-D tensor with one entry per row of ``feat``.
        feat: source feature table.
        memfeat: destination for the ``lossMap == -1`` rows; written in place.
        cudafeat: destination for the remaining rows; written in place.
        map: per-row slot table; written in place.

    Returns:
        None.
    """
    keep = lossMap != -1

    # Index each tensor with a mask on that tensor's own device instead of
    # assuming everything lives on 'cuda'.
    keep_on_feat = keep.to(feat.device)
    keep_on_map = keep.to(map.device)

    # Rows marked -1 go to memfeat and get negative slot numbers.
    cutfeat = feat[~keep_on_feat]
    n_cut = cutfeat.shape[0]
    memfeat[1:n_cut + 1] = cutfeat
    map[~keep_on_map] = -torch.arange(1, n_cut + 1, device=map.device, dtype=map.dtype)

    # All other rows go to cudafeat and get positive slot numbers.
    savefeat = feat[keep_on_feat]
    n_save = savefeat.shape[0]
    cudafeat[1:n_save + 1] = savefeat
    map[keep_on_map] = torch.arange(1, n_save + 1, device=map.device, dtype=map.dtype)

In [10]:
# Scratch check: assigning through a boolean mask writes the source values,
# in order, into the positions where the mask is True.
table = torch.arange(10, dtype=torch.int32)
addFeat = torch.tensor([-2, -3, -4], dtype=torch.int32)

# True exactly at positions 5, 6 and 7.
mask = (table >= 5) & (table < 8)

print(table)
print(addFeat)
table[mask] = addFeat
print(table)

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)
tensor([-2, -3, -4], dtype=torch.int32)
tensor([ 0,  1,  2,  3,  4, -2, -3, -4,  8,  9], dtype=torch.int32)


In [12]:
#cuda and cpu convert
def loss_feat_cac(lossMap, memfeat, cudafeat, map):
    """Swap kept nodes' features from memfeat into cudafeat and update the slot map.

    ``map[node]`` is a signed slot: a negative value ``-s`` means the node's
    features are at ``memfeat[s]``, any other value ``s`` means
    ``cudafeat[s]``. A node ``i`` is "kept" when ``lossMap[i] != -1``.

    Every kept node that currently sits in ``memfeat`` trades places with a
    non-kept node that currently sits in ``cudafeat`` (taken in ascending
    node order): the two feature rows are exchanged and so are the two
    ``map`` entries. Nodes that are not involved in a swap are left as is.

    Preconditions (from the caller):
        * ``memfeat`` / ``cudafeat`` together hold the features of every node
          of the current subgraph (extra rows are fine).
        * There are at least as many non-kept nodes in ``cudafeat`` as kept
          nodes in ``memfeat``.

    Args:
        lossMap: 1-D tensor; entry ``i`` is ``-1`` when node ``i`` is not kept.
        memfeat: feature store addressed by negative slots; updated in place.
        cudafeat: feature store addressed by non-negative slots; updated in place.
        map: 1-D integer tensor, node id -> signed slot; updated in place.

    Raises:
        RuntimeError: if there are fewer non-kept nodes in ``cudafeat`` than
            kept nodes in ``memfeat``. Raised before any tensor is modified.
    """
    start = time.time()

    # All node-id bookkeeping happens on the device of `map`, so `map` is
    # never indexed with indices that live on a different device.
    idx_device = map.device

    saveNode = torch.nonzero(lossMap != -1).reshape(-1).to(idx_device)

    # Kept nodes whose features currently live in memfeat, and their slots.
    saveIdxMap_mem = saveNode[map[saveNode] < 0]
    saveIdx_mem = -map[saveIdxMap_mem]

    # Non-kept nodes whose features currently live in cudafeat.
    evictable = torch.ones(map.shape[0], dtype=torch.bool, device=idx_device)
    evictable[saveNode] = False
    evictable[map < 0] = False
    candidates = torch.nonzero(evictable).reshape(-1)

    n_swap = saveIdx_mem.shape[0]
    if candidates.shape[0] < n_swap:
        # Fail before touching memfeat / cudafeat / map so they stay consistent.
        raise RuntimeError(
            f"loss_feat_cac: need {n_swap} non-kept nodes in cudafeat to swap out, "
            f"but only {candidates.shape[0]} are available"
        )

    nsaveMap_cuda = candidates[:n_swap]
    nsaveIdx_cuda = map[nsaveMap_cuda]  # advanced indexing -> independent copy

    # Exchange the feature rows, moving indices and values to each store's device.
    mem_slots = saveIdx_mem.to(memfeat.device)
    cuda_slots = nsaveIdx_cuda.to(cudafeat.device)
    cuda_tmp = cudafeat[cuda_slots]
    cudafeat[cuda_slots] = memfeat[mem_slots].to(cudafeat.device)
    memfeat[mem_slots] = cuda_tmp.to(memfeat.device)

    # Exchange the map entries of the swapped node pairs.
    map[nsaveMap_cuda] = map[saveIdxMap_mem]
    map[saveIdxMap_mem] = nsaveIdx_cuda
    print(f"loss_feat_cac {time.time() - start:.4f}s")

# Test: kept nodes 0, 1, 6, 7 start in memfeat (negative slots) and must be
# swapped with the non-kept nodes 4, 8, 9, 10 that start in cudafeat.
# NOTE: the global name `map` shadows the Python builtin for the rest of the
# kernel session; it is kept to match the parameter name of loss_feat_cac.
map = torch.tensor([-1,-2,1,2,3,-3,-4,-5,4,5,6])
lossMap = torch.tensor([1,1,1,1,-1,-1,1,1])
memfeat = torch.tensor([-1,11,12,13,14,15])
cudafeat = torch.tensor([-1,21,22,23,24,25,26,27]).cuda()
loss_feat_cac(lossMap,memfeat,cudafeat,map)
print("memfeat :",memfeat)
print("cudafeat:",cudafeat)

# Pin the expected result so a regression fails the cell instead of only
# changing the printed output.
assert memfeat.tolist() == [-1, 23, 24, 13, 25, 26]
assert cudafeat.tolist() == [-1, 21, 22, 11, 12, 14, 15, 27]
assert map.tolist() == [3, 4, 1, 2, -1, -3, 5, 6, -2, -4, -5]

loss_feat_cac 0.0008s
memfeat : tensor([-1, 23, 24, 13, 25, 26])
cudafeat: tensor([-1, 21, 22, 11, 12, 14, 15, 27], device='cuda:0')
