In [2]:
from typing import Callable, List, NamedTuple, Optional, Tuple, Union

from torch_geometric.datasets import Flickr
import torch

dataset = Flickr("/mnt/nfs-ssd/raw-datasets/pyg-format/Flickr")
data = dataset[0]
data

Data(x=[89250, 500], edge_index=[2, 899756], y=[89250], train_mask=[89250], val_mask=[89250], test_mask=[89250])

In [3]:
from torch_geometric.utils import *

l_edge = data.edge_index
# l_edge, _ = add_self_loops(data.edge_index)
contains_self_loops(l_edge)

False

In [None]:
from torch import nn
import torch.nn.functional as F
from torch_geometric.data import Data

w_ego_root = nn.Parameter(torch.Tensor(data.num_features))
w_ego_u = nn.Parameter(torch.Tensor(data.num_features))
w_layer_v = nn.Parameter(torch.Tensor(data.num_features, 1))
w_layer_u = nn.Parameter(torch.Tensor(data.num_features, 1))
w_threshold = nn.Parameter(torch.Tensor(data.num_features, 1))

nn.init.constant_(w_ego_root, 1)
nn.init.constant_(w_ego_u, 1)
nn.init.constant_(w_layer_v, 1)
nn.init.constant_(w_layer_u, 1)
nn.init.constant_(w_threshold, 1e-5)

cos = nn.CosineSimilarity(dim=-1, eps=1e-6)

def ego_kernel(h_root, h_u):
    h_root = h_root * w_ego_root
    h_u = h_u * w_ego_u

    return cos(h_root, h_u)


def layer_kernel(h_v, h_u, adj):
    h_v = h_v @ w_layer_v
    h_u = h_u @ w_layer_u
    h_msg = adj @ h_v

    return F.normalize(F.relu(h_msg + h_u), dim = 0).view(-1)


def node_importance(self_adj_t):
    adj_t = self_adj_t.fill_value(1., dtype=torch.float)
    adj_t = adj_t.set_diag()
    deg = adj_t.sum(dim=1).to(torch.float)
    deg_inv_sqrt = deg.pow(-1)
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    adj_t = deg_inv_sqrt.view(-1, 1) * adj_t * deg_inv_sqrt.view(1, -1)
    return adj_t.sum(dim=1).pow(0.5)


def edge_global_to_batch(n_idx, e_idx, undirected = False):
    """
     将全局edgeId转换成batch内id
    """
    batch_idx = []
    for edge in e_idx.t():
        batch_idx.append([
            (n_idx == edge[0]).nonzero(as_tuple=True)[0],
            (n_idx == edge[1]).nonzero(as_tuple=True)[0]
        ])

    batch_idx = torch.tensor(batch_idx).t()
    batch_idx, _ = add_remaining_self_loops(batch_idx)
    if undirected:
        batch_idx = to_undirected(batch_idx)

    val, idx = torch.sort(batch_idx[0])
    return torch.stack([val, batch_idx[1][idx]])


def get_node_p(n_idx, p_dict, method = 'sum'):
    """
        methods: [sum, mean, max]
    """
    p_arr = []
    for nid in n_idx:
        p_node = p_dict[nid.item()]
        assert p_node != [], nid
        if method == 'sum':
            p_node = sum(p_node)
        elif method == 'mean':
            p_node = sum(p_node) / len(p_node)
        elif method == 'max':
            p_node = max(p_node)
        else:
            raise NotImplementedError(method+ "not implemented!")
        p_arr.append(p_node.view(1,1))

    return torch.cat(p_arr).view(-1)


class EgoData(Data):
    def __init__(self, x, edge_index, y, p=None):
        y = y[0] if y.numel() > 1 else y
        super().__init__(x, edge_index, y=y)
        # self.ego_index = ego_index
        self.p = p

    # def __inc__(self, key, value, *args, **kwargs):
    #     if key == 'ego_ptr':
    #         return self.num_nodes
    #     else:
    #         return super().__inc__(key, value, *args, **kwargs)


def unique(x, dim=-1):
    unique, inverse = torch.unique(x, return_inverse=True, dim=dim)
    perm = torch.arange(inverse.size(dim), dtype=inverse.dtype, device=inverse.device)
    inverse, perm = inverse.flip([dim]), perm.flip([dim])
    return unique, inverse.new_empty(unique.size(dim)).scatter_(dim, inverse, perm)

In [101]:
%%time

from torch_sparse import SparseTensor
from collections import defaultdict
from torch_scatter import scatter

# torch.autograd.set_detect_anomaly(True)


row, col = l_edge.cpu()
self_adj_t = SparseTensor(
    row=row, col=col, value=torch.arange(col.size(0)),
    sparse_sizes=(data.num_nodes, data.num_nodes)).t()

x = data.x
n_imp = node_importance(self_adj_t)


# 入参
self_budget = 200
a = 0.5
p_gather = 'sum'
# margin_expand = False  # 每一轮只用边缘节点探索

for batch_i in range(1):
    batch_node = torch.tensor([13687], dtype=torch.int64)

    e_id = []  # 存储所有边index
    n_id = [batch_node] # 存储所有点index
    # n_p = defaultdict(list)  # 储存所有点权重
    n_p = [torch.tensor([1.])]

    ego_store = defaultdict(float)

    # def get_ego_score(u):
    #     s_arr = []
    #     for u_id in list(u):
    #         score = ego_store[u_id]
    #         score = ego_kernel(x[batch_node], x[u_id]) if score == 0.0 else score
    #         s_arr.append(score)
    #     return torch.cat(s_arr)

    # 初始化
    # n_p[batch_node.item()] = [torch.tensor(1)]
    r_max = (x[batch_node] @ w_threshold).view(-1)
    p_norm = float('-inf')
    v = batch_node
    budget = self_budget
    hop = 0
    while budget > 0:
        """
        只采节点不采边
        """
        adj_t, u = self_adj_t.sample_adj(v, -1, replace=False)
        u_size = u.size(-1)
        row, col, layer_e = adj_t.coo()
        adj = SparseTensor(row=col, col=row, sparse_sizes=adj_t.sparse_sizes()[::-1])

        """计算p_u"""
        # ego_score = get_ego_score(u)
        ego_score = ego_kernel(x[batch_node], x[u])
        layer_score = layer_kernel(x[v], x[u], adj)
        p_u = (a * ego_score + (1 - a) * layer_score) \
              * n_imp[u] \
              * (budget / self_budget)
        # p_norm = max(torch.max(p_u).detach(), p_norm)
        if hop == 0: p_norm = p_u[0].clone().detach()
        p_u /= p_norm

        """计算mask"""
        p_clip = torch.clamp(p_u, min = 0, max=1)
        num_sample = torch.clamp(p_clip.sum(), 1, u_size).int()
        num_sample = min(budget, num_sample)
        mask = torch.zeros((u_size,), dtype=torch.bool)
        mask[torch.multinomial(p_clip, num_sample)] = 1
        budget -= num_sample
        # mask = torch.bernoulli(torch.clamp(p_u, min=0, max=1)).to(torch.bool)
        # layer_cost = sum(mask).item()
        # if layer_cost > budget:
        #     _, p_id = torch.sort(p_u[mask], dim=-1, descending=True)
        #     mask = torch.zeros_like(mask)
        #     mask[p_id[:budget]] = 1
        # budget -= layer_cost

        p_u -= r_max  # 为了让r_max可导
        mask = mask & (p_u > 0)
        if sum(mask).item() < 1: break

        _, col, layer_e = adj_t.coo()
        edges = [layer_e[col[:] == i] for i in torch.arange(u_size)[mask]]
        e_id.append(torch.cat(edges))

        sampled_u = u[mask]
        # margin_u = torch.tensor(list(set(sampled_u.numpy()).difference(set(n_id.numpy()))),dtype=n_id.dtype)
        # n_id = torch.cat((n_id, margin_u))
        n_id.append(sampled_u)
        n_p.append(p_u[mask])

        # for idx, p_val in zip(sampled_u.tolist(), p_u[mask]):
        #     n_p[idx].append(p_val)

        # v = margin_u if margin_expand else sampled_u
        v = sampled_u
        hop += 1

    e_id = torch.cat(e_id).unique()
    n_id, n_mask = torch.cat(n_id).unique(return_inverse=True)
    # n_id, n_mask = unique(torch.cat(n_id))
    p = scatter(torch.cat(n_p), n_mask, dim=-1, reduce=p_gather)

    # n_id[[0, n_mask[0]]] = n_id[[n_mask[0], 0]]
    batch_edge = edge_global_to_batch(n_id, l_edge[:, e_id])
    # p = n_p[n_mask]
    # p = get_node_p(n_id, n_p, 'sum')


    ego_data = EgoData(x[n_id], batch_edge, data.y[n_id[n_mask[0]]], p)
    ego_data.hop = hop
    ego_data.ego_ptr = n_mask[0]

    # print("hop_num:", hop, ego_data)

CPU times: user 7.92 s, sys: 129 ms, total: 8.05 s
Wall time: 168 ms


In [104]:
l_edge[:, 45090]

tensor([ 571, 4104])

In [1497]:
batch_data.p[0].backward()
w_layer_v.grad

tensor([[-4.6185e-08],
        [-7.3994e-08],
        [ 3.3367e-08],
        [-7.9778e-08],
        [ 3.3293e-08],
        [ 1.5478e-07],
        [ 7.8158e-10],
        [-5.6782e-08],
        [-3.7173e-08],
        [-4.8598e-08],
        [-6.1097e-09],
        [ 2.7541e-08],
        [-9.2305e-08],
        [-8.9921e-08],
        [ 8.5099e-08],
        [ 1.8860e-08],
        [ 6.3238e-08],
        [-6.5001e-08],
        [ 3.3741e-08],
        [-8.7589e-08],
        [ 2.3794e-08],
        [ 7.7014e-08],
        [-6.9723e-08],
        [ 1.0773e-08],
        [ 1.0246e-07],
        [-1.2371e-08],
        [-6.3099e-08],
        [-4.7066e-08],
        [-8.9762e-08],
        [-7.9280e-08],
        [ 6.3772e-08],
        [-6.7350e-08],
        [ 2.2183e-08],
        [ 3.4333e-09],
        [ 1.6413e-08],
        [ 3.4389e-08],
        [-9.0412e-08],
        [-4.0481e-08],
        [-4.2033e-08],
        [-7.5425e-08],
        [ 4.5143e-08],
        [-1.1203e-07],
        [ 7.0886e-08],
        [ 8

In [6]:
from torch_geometric.transforms import ToSparseTensor
from torch_geometric.data import Batch
from torch_geometric.loader.utils import filter_data

batch_data = Batch.from_data_list([ego_data, ego_data])

# batch_data = filter_data(data, n_id, batch_edge[0], batch_edge[1], e_id)
# batch_data.batch_size = batch_node.numel()
print(batch_data)
# to_sparse = ToSparseTensor()
# to_sparse(batch_data)
# batch_data

AttributeError: 'NoneType' object has no attribute 'numel'

In [1439]:
batch_data.ego_ptr = (batch_data.ego_ptr + batch_data.ptr[:-1])
xr = batch_data.x[batch_data.ego_ptr]
torch.equal(xr[0], xr[1])

True

In [1486]:
SparseTensor.from_edge_index(data.edge_index)

SparseTensor(row=tensor([    0,     0,     0,  ..., 89249, 89249, 89249]),
             col=tensor([    1,  1694,  2569,  ..., 28631, 52736, 78119]),
             size=(89250, 89250), nnz=899756, density=0.01%)

In [None]:
x = torch.rand(2,5)
index = torch.LongTensor([2,3])
print(x)

# 如果想在x的第一个维度上选择x[2]和x[0]
y = torch.index_select(x, dim=1, index=index)
y

In [1262]:
def unique(x, dim=-1):
    unique, inverse = torch.unique(x, return_inverse=True, dim=dim)
    perm = torch.arange(inverse.size(dim), dtype=inverse.dtype, device=inverse.device)
    inverse, perm = inverse.flip([dim]), perm.flip([dim])
    return unique, inverse.new_empty(unique.size(dim)).scatter_(dim, inverse, perm)

t = torch.tensor([3, 3 ,5, 5, 3, 2, 1])
print(unique(t))

(tensor([1, 2, 3, 5]), tensor([6, 5, 0, 2]))


In [1270]:
def _unique(x, dim=-1):
    unique, inverse = torch.unique(x, return_inverse=True, dim=dim)
    perm = torch.arange(inverse.size(dim), dtype=inverse.dtype, device=inverse.device)
    inverse, perm = inverse.flip([dim]), perm.flip([dim])
    return unique, inverse.new_empty(unique.size(dim)).scatter_(dim, inverse, perm)

t = torch.tensor([3, 3 ,5, 5, 3, 2, 1])

dim = -1
unique, inverse = torch.unique(t, return_inverse=True, dim=dim)
perm = torch.arange(inverse.size(dim), dtype=inverse.dtype, device=inverse.device)
inverse, perm = inverse.flip([dim]), perm.flip([dim])
print(inverse, perm)
_unique(t)


tensor([0, 1, 2, 3, 3, 2, 2]) tensor([6, 5, 4, 3, 2, 1, 0])


(tensor([1, 2, 3, 5]), tensor([6, 5, 0, 2]))

In [None]:
from torch_scatter import scatter

t_, inv = torch.unique(t, return_inverse=True)
print(inv)
scatter(t, inv, dim=-1, reduce='sum')

In [1469]:
a = []
torch.cat(a)

NotImplementedError: There were no tensor arguments to this function (e.g., you passed an empty list of Tensors), but no fallback function is registered for schema aten::_cat.  This usually means that this function requires a non-empty list of Tensors, or that you (the operator writer) forgot to register a fallback function.  Available functions are [CPU, CUDA, QuantizedCPU, BackendSelect, Python, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradLazy, AutogradXPU, AutogradMLC, AutogradHPU, AutogradNestedTensor, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, AutocastCPU, Autocast, Batched, VmapMode, Functionalize].

CPU: registered at aten/src/ATen/RegisterCPU.cpp:21063 [kernel]
CUDA: registered at aten/src/ATen/RegisterCUDA.cpp:29726 [kernel]
QuantizedCPU: registered at aten/src/ATen/RegisterQuantizedCPU.cpp:1258 [kernel]
BackendSelect: fallthrough registered at ../aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at ../aten/src/ATen/core/PythonFallbackKernel.cpp:47 [backend fallback]
Named: registered at ../aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at ../aten/src/ATen/ConjugateFallback.cpp:18 [backend fallback]
Negative: registered at ../aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at ../aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at ../aten/src/ATen/core/VariableFallbackKernel.cpp:64 [backend fallback]
AutogradOther: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradCPU: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradCUDA: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradXLA: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradLazy: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradXPU: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradMLC: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradHPU: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradNestedTensor: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradPrivateUse1: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradPrivateUse2: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
AutogradPrivateUse3: registered at ../torch/csrc/autograd/generated/VariableType_3.cpp:11380 [autograd kernel]
Tracer: registered at ../torch/csrc/autograd/generated/TraceType_3.cpp:11220 [kernel]
AutocastCPU: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:461 [backend fallback]
Autocast: fallthrough registered at ../aten/src/ATen/autocast_mode.cpp:305 [backend fallback]
Batched: registered at ../aten/src/ATen/BatchingRegistrations.cpp:1059 [backend fallback]
VmapMode: fallthrough registered at ../aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
Functionalize: registered at ../aten/src/ATen/FunctionalizeFallbackKernel.cpp:52 [backend fallback]


In [1489]:
# batch_data.adj_t = SparseTensor(
#             row=batch_data.edge_index[1], col=batch_data.edge_index[0],
#             sparse_sizes=batch_data.edge_stores[0].size()[::-1],
#             is_sorted=True)
# delattr(batch_data, 'edge_index')
batch_data.adj_t.sparse_sizes()[::-1]


(368, 368)

In [1494]:
torch.tensor([[torch.tensor([0]), torch.tensor([3])], [torch.tensor([0]), torch.tensor([])]]).t()

ValueError: only one element tensors can be converted to Python scalars