In [None]:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import scatter
import torch.nn as nn
import torch

# Fix the RNG so the shuffled mini-batch and the random node scores are reproducible.
torch.manual_seed(0)

# Padding value for empty slots of the dense (graph x node) score matrix. It only needs to
# make exp() underflow to 0 in the softmax; -60000 presumably was chosen because it is still
# representable in float16 (min ~ -65504).
FILL_VALUE = -60000.0
# Nucleus threshold on the per-graph cumulative probability of sorted nodes.
TOP_P = 0.5

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

softmax = nn.Softmax(dim=-1)

# Only the first mini-batch is inspected (the previous version looped and broke immediately).
data = next(iter(loader))
edge_index, batch = data.edge_index, data.batch
device = batch.device

# Random per-node scores standing in for a learned scoring function.
x = torch.rand(batch.size(0), device=device)

# Number of nodes in each graph of the mini-batch.
num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')
print(num_nodes)
batch_size, max_num_nodes = num_nodes.size(0), int(num_nodes.max())

# Exclusive prefix sum of node counts: offset of each graph's first node.
# Assumes the nodes of each graph are stored contiguously, as a PyG batch vector implies.
cum_num_nodes = torch.cat(
    [num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]], dim=0)

# index - cum_num_nodes[batch] is each node's position within its own graph;
# adding batch * max_num_nodes gives its slot in the flattened dense matrix.
# Created on batch's device so the arithmetic below does not mix CPU and GPU tensors.
index = torch.arange(batch.size(0), dtype=torch.long, device=device)
index = (index - cum_num_nodes[batch]) + (batch * max_num_nodes)

dense_x = x.new_full((batch_size * max_num_nodes, ), FILL_VALUE)
dense_x[index] = x
dense_x = dense_x.view(batch_size, max_num_nodes)
# Per-graph distribution over nodes; padded slots get (numerically) zero probability.
dense_x = softmax(dense_x)

probs, perm = dense_x.sort(dim=-1, descending=True)
cum_probs = torch.cumsum(probs, dim=-1)

# NOTE(review): this keeps the nodes whose cumulative probability is strictly below TOP_P,
# so the node that crosses the threshold is excluded (standard top-p sampling keeps it).
# Hence k can be 0, and the clamp guarantees at least one node per graph.
nucleus = (cum_probs < TOP_P) + 0
k = torch.count_nonzero(nucleus, dim=1).reshape(-1)
k = torch.clamp(k, min=1)

# Map per-graph sorted positions back to global node indices.
perm = perm + cum_num_nodes.view(-1, 1)
perm = perm.view(-1)

# Flat positions of the first k[i] sorted slots of every row, in row-major order.
# This is a vectorized replacement for concatenating arange(k[i]) + i * max_num_nodes per graph.
keep = torch.arange(max_num_nodes, device=device).view(1, -1) < k.view(-1, 1)
index = keep.view(-1).nonzero().view(-1)

perm = perm[index]

In [62]:
import torch
import torch.nn as nn

# Tail-free-style truncation experiment on random scores: for each row, locate where the
# "tail" of the sorted softmax distribution begins.
torch.manual_seed(0)

# Threshold on the cumulative normalized second-difference weight that marks the tail.
TAIL_Z = 0.9

softmax = nn.Softmax(dim=-1)

x = torch.rand(20, 32)
probs, indices = x.sort(dim=-1, descending=True)
sm = softmax(probs)

# First, then second, finite difference along the sorted axis: (20, 31) then (20, 30).
grad = sm[:, 1:] - sm[:, :-1]
grad = grad[:, 1:] - grad[:, :-1]

# Magnitude of the second difference (the name is historical: these are abs() values).
only_pos = torch.abs(grad)
# Named row_sums (not `sum`) so the built-in sum() is not shadowed for later cells.
# NOTE(review): a row whose second differences are all zero would divide by 0 -> NaN weights.
row_sums = torch.sum(only_pos, dim=1, keepdim=True)
sec_weights = only_pos / row_sums
print(sec_weights[:3])

# 1 where the running weight has passed TAIL_Z, else 0 (cast to int with +0).
cum_weights = (torch.cumsum(sec_weights, dim=1) > TAIL_Z) + 0
# First position past the threshold, shifted by one (presumably to map a second-difference
# index back to a position in the sorted distribution).
# NOTE(review): if a row never exceeds TAIL_Z, argmax returns 0 and tail_ids is 1.
tail_ids = torch.argmax(cum_weights, dim=1) + 1
print(tail_ids)

# Pair each row index (historically named `logits`) with its tail position: shape (20, 2).
logits = torch.arange(x.size(0))
logits_inds = torch.stack((logits, tail_ids), dim=1)

tensor([[2.5583e-02, 1.6077e-02, 8.1055e-02, 1.5515e-02, 2.3373e-02, 1.5247e-02,
         6.0502e-02, 5.6048e-02, 5.8988e-02, 1.1227e-04, 5.0377e-03, 1.2401e-01,
         1.2557e-01, 3.4591e-03, 3.1316e-02, 1.5845e-02, 6.2644e-02, 6.0905e-02,
         8.3703e-03, 8.4836e-03, 5.3586e-02, 5.7458e-02, 1.3956e-02, 2.5187e-02,
         1.2593e-02, 5.8418e-03, 1.7118e-02, 4.0480e-03, 7.6291e-03, 4.4511e-03],
        [6.3088e-03, 4.2135e-02, 5.9452e-02, 3.9655e-03, 1.6937e-01, 1.2830e-01,
         3.3971e-02, 8.6658e-03, 8.9451e-03, 9.1150e-02, 1.2069e-01, 1.6342e-03,
         2.4037e-03, 2.2121e-03, 3.5896e-03, 5.5878e-02, 1.1495e-02, 2.1983e-02,
         1.2883e-02, 6.7439e-03, 1.6919e-02, 3.2339e-02, 2.7544e-02, 2.5374e-02,
         4.9401e-02, 1.2011e-02, 2.6817e-02, 5.6069e-03, 1.2048e-02, 1.5602e-04],
        [7.7877e-02, 1.0929e-02, 1.3849e-02, 7.3203e-02, 8.8657e-03, 1.1420e-02,
         2.6735e-02, 9.0363e-02, 3.7434e-02, 2.6756e-02, 8.2849e-03, 2.6076e-02,
         3.8743e-02, 8.100