In [1]:
from typing import Callable, List, NamedTuple, Optional, Tuple, Union

import torch
from torch import Tensor
from torch_sparse import SparseTensor

class EdgeIndex(NamedTuple):
    """One sampled hop in COO ``edge_index`` form.

    ``to`` returns a new ``EdgeIndex`` whose tensors have been passed through
    ``Tensor.to`` with the given arguments; ``size`` is carried over unchanged.
    """

    edge_index: Tensor  # connectivity tensor of this hop
    e_id: Optional[Tensor]  # optional edge-id tensor; None when not tracked
    size: Tuple[int, int]  # hop size pair, copied verbatim by ``to``

    def to(self, *args, **kwargs):
        """Move/cast ``edge_index`` and (if present) ``e_id``.

        All positional and keyword arguments are forwarded unchanged to
        ``Tensor.to``. A missing ``e_id`` stays ``None``.
        """
        moved_edge_index = self.edge_index.to(*args, **kwargs)
        if self.e_id is None:
            moved_e_id = None
        else:
            moved_e_id = self.e_id.to(*args, **kwargs)
        return EdgeIndex(edge_index=moved_edge_index, e_id=moved_e_id, size=self.size)


class Adj(NamedTuple):
    """One sampled hop stored as a ``SparseTensor`` adjacency.

    ``to`` returns a new ``Adj`` with ``adj_t`` and (if present) ``e_id``
    passed through their ``.to`` methods; ``size`` is carried over unchanged.
    """

    adj_t: SparseTensor  # sparse adjacency of this hop (presumably transposed, per the name)
    e_id: Optional[Tensor]  # optional edge-id tensor; None when not tracked
    size: Tuple[int, int]  # hop size pair, copied verbatim by ``to``

    def to(self, *args, **kwargs):
        """Forward all arguments to ``adj_t.to`` and ``e_id.to`` (when not None)."""
        moved_adj_t = self.adj_t.to(*args, **kwargs)
        if self.e_id is None:
            moved_e_id = None
        else:
            moved_e_id = self.e_id.to(*args, **kwargs)
        return Adj(adj_t=moved_adj_t, e_id=moved_e_id, size=self.size)

In [30]:
import os

from torch_geometric.datasets import Flickr

# Dataset root. Override with the FLICKR_ROOT environment variable rather than
# editing this machine-specific absolute path; the default is unchanged.
FLICKR_ROOT = os.environ.get("FLICKR_ROOT", "/mnt/nfs-ssd/raw-datasets/pyg-format/Flickr")

dataset = Flickr(FLICKR_ROOT)
data = dataset[0]  # the single graph stored in the dataset
data

In [13]:
# Flat index tensor of the training nodes: the positions where train_mask is set.
node_idx = torch.nonzero(data.train_mask, as_tuple=False).view(-1)
len(node_idx)

44625

In [27]:
from torch.utils.data import DataLoader

# Mini-batch settings. The shuffle is driven by a seeded generator so the drawn
# batch (and everything sampled from it later) is reproducible on Restart & Run All.
BATCH_SIZE = 2
LOADER_SEED = 0

loader = DataLoader(
    node_idx,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

batch_nodes = next(iter(loader))  # one batch of training-node indices
batch_nodes

tensor([13687, 37207])


In [87]:
# Multi-hop neighbor sampling with SparseTensor.sample_adj, producing one
# EdgeIndex per hop.
batch_size: int = len(batch_nodes)
NUM_NEIGHBORS = [3, 3, 3]  # neighbors sampled per node at each hop (no replacement)

edge_row, edge_col = data.edge_index.cpu()
# Store each edge's position in data.edge_index as the sparse value so sampled
# edges can be traced back (read out below as e_id); .t() transposes the adjacency.
self_adj_t = SparseTensor(
    row=edge_row, col=edge_col, value=torch.arange(edge_col.size(0)),
    sparse_sizes=(data.num_nodes, data.num_nodes)).t()

adjs = []
n_id = batch_nodes
for num_neighbors in NUM_NEIGHBORS:
    # n_id is expanded each hop: existing node ids first, newly reached ones appended.
    adj_t, n_id = self_adj_t.sample_adj(n_id, num_neighbors, replace=False)
    print(adj_t)
    print(n_id)
    e_id = adj_t.storage.value()
    size = adj_t.sparse_sizes()[::-1]

    row, col, _ = adj_t.coo()
    # Swap to (col, row) so edge_index rows match the reversed `size` ordering.
    edge_index = torch.stack([col, row], dim=0)
    adjs.append(EdgeIndex(edge_index, e_id, size))

# Reverse so the last-sampled (outermost) hop comes first.
adjs = adjs[0] if len(adjs) == 1 else adjs[::-1]
out = (batch_size, n_id, adjs)

SparseTensor(row=tensor([0, 0, 0, 1, 1, 1]),
             col=tensor([2, 3, 4, 5, 6, 7]),
             val=tensor([328131, 109541, 489350, 473863, 166128, 102761]),
             size=(2, 8), nnz=6, density=37.50%)
tensor([13687, 37207, 14920,  2314, 30343, 28682,  4458,  2062])
SparseTensor(row=tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7]),
             col=tensor([ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,  1, 17, 18,
                           19, 20, 21, 22, 23, 24]),
             val=tensor([328131, 109541, 489350, 473863, 166128, 102761, 196699, 311928,  47906,
                           455109, 270704, 421367, 542562, 277492, 183038, 548857, 551357, 340557,
                           187529, 585395, 212426, 800920, 426154, 324422]),
             size=(8, 25), nnz=24, density=12.00%)
tensor([13687, 37207, 14920,  2314, 30343, 28682,  4458,  2062,  5856, 13577,
          611, 26765, 10478, 23268, 36477, 10965,  5208, 37498, 15926,

In [118]:
# out[2] lists the hops outermost-first: index 0 is the 3rd (last) sampled hop,
# index 1 is the 2nd hop.
edge_3 = out[2][0].edge_index
edge_1 = out[2][1].edge_index  # NOTE: this is the 2nd hop (out[2][1]), despite the name
print(edge_3)
# Are the 2nd-hop edges exactly the leading edge_1.size(1) columns of the 3rd hop?
torch.equal(edge_1, edge_3[:, :edge_1.size(1)])

tensor([[ 2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,  1, 17, 18,
         19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 38, 39, 40, 41, 42, 43, 44, 45, 46,  8, 46, 47, 48, 49, 50,
         51, 52, 53, 54, 55, 56,  6, 57, 58, 59, 60, 61,  7, 62, 63, 64, 65, 66,
         67, 68, 69],
        [ 0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,
          6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11,
         12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17,
         18, 18, 18, 19, 19, 19, 20, 20, 20, 21, 21, 21, 22, 22, 22, 23, 23, 23,
         24, 24, 24]])


True

In [138]:
from torch_geometric.loader.utils import (
    filter_data,
    to_csc
)
# Dedicated name so the DataLoader cell's `batch_nodes` is not overwritten
# (re-running the sampling cell above would otherwise silently use this node).
seed_nodes = torch.tensor([13687])
NUM_NEIGHBORS_PER_HOP = [-1, -1]  # two hops; -1 means include all neighbors

# Keep the CSC row indices separate from the sampled `row` returned below.
colptr, csc_row, perm = to_csc(data, torch.device('cpu'))

sample_fn = torch.ops.torch_sparse.neighbor_sample
node, row, col, edge = sample_fn(
    colptr,
    csc_row,
    seed_nodes,
    NUM_NEIGHBORS_PER_HOP,
    False,  # replace
    True,  # directed
)
batch_data = filter_data(data, node, row, col, edge, perm)
batch_data.batch_size = seed_nodes.numel()
batch_data

Data(x=[822, 500], edge_index=[2, 1737], y=[822], train_mask=[822], val_mask=[822], test_mask=[822], batch_size=1)

---

In [5]:
import torch

# torch.multinomial rejects a request for zero samples (here on an empty
# distribution) with a RuntimeError. Catch it so Run All continues and the
# error text remains visible as this cell's output.
multinomial_error = None
try:
    torch.multinomial(torch.tensor([]), 0)
except RuntimeError as err:
    multinomial_error = err
    print(f"RuntimeError: {err}")

RuntimeError: cannot sample n_sample <= 0 samples

In [6]:
torch.clamp(torch.tensor(2), min=1, max=0)  # min > max: every element becomes max, giving tensor(0)

tensor(0)