In [None]:
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
from torch_geometric.utils import scatter
import torch.nn as nn
import torch

dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

softmax = nn.Softmax(dim=-1)

# Only the first mini-batch is processed: the loop body ends with `break`.
for step, data in enumerate(loader):
    edge_index = data.edge_index
    batch = data.batch

    # One random score per node in the mini-batch.
    x = torch.rand(batch.size(0))

    # Number of nodes belonging to each graph of the mini-batch.
    num_nodes = scatter(batch.new_ones(x.size(0)), batch, reduce='sum')
    print(num_nodes)
    batch_size = num_nodes.size(0)
    max_num_nodes = int(num_nodes.max())

    # Exclusive prefix sum of num_nodes: offset of each graph's first node.
    cum_num_nodes = torch.cat(
        (num_nodes.new_zeros(1), num_nodes.cumsum(dim=0)[:-1]), dim=0)

    node_ids = torch.arange(batch.size(0), dtype=torch.long)
    # node_ids - cum_num_nodes[batch] is each node's index inside its own graph.
    local_ids = node_ids - cum_num_nodes[batch]
    # Slot of each node in a flat (batch_size * max_num_nodes) layout.
    index = local_ids + batch * max_num_nodes

    # Unused (padding) slots keep -60000.0, so softmax gives them ~0 weight.
    dense_x = x.new_full((batch_size * max_num_nodes, ), -60000.0)
    dense_x[index] = x
    dense_x = softmax(dense_x.view(batch_size, max_num_nodes))

    probs, perm = torch.sort(dense_x, dim=-1, descending=True)
    cum_probs = probs.cumsum(dim=-1)

    # k[i]: how many leading sorted entries of graph i have a cumulative
    # probability still below 0.5 (forced to be at least 1).
    nucleus = (cum_probs < 0.5) + 0
    k = torch.count_nonzero(nucleus, dim=1).reshape(-1).clamp(min=1)

    # Turn per-graph local positions into positions in the flat node list.
    perm = (perm + cum_num_nodes.view(-1, 1)).view(-1)

    # For every graph, the flat positions of its first k[i] sorted entries.
    kept = []
    for i in range(batch_size):
        kept.append(torch.arange(k[i], device=x.device) + i * max_num_nodes)
    index = torch.cat(kept, dim=0)
    print(index)

    perm = perm[index]

    break

In [None]:
import torch
import torch.nn as nn

softmax = nn.Softmax(dim=-1)
temperature = 1.0

# Random scores (20 rows x 32 entries), scaled by the temperature and sorted
# in descending order along each row.
x = torch.rand(20, 32)
x = x / float(temperature)
probs, indices = x.sort(dim=-1, descending=True)
sm = softmax(probs)

# Second-order finite difference of the sorted softmax curve:
# (20, 32) -> (20, 31) -> (20, 30).
grad = sm[:, 1:] - sm[:, :-1]
grad = grad[:, 1:] - grad[:, :-1]
only_pos = torch.abs(grad)

# Per-row normaliser. Deliberately NOT called `sum`: binding that name in a
# notebook shadows the builtin `sum` for every later cell of the session.
weight_total = torch.sum(only_pos, dim=1).view(-1, 1)
sec_weights = only_pos / weight_total

# 0/1 mask of positions where the cumulative weight has exceeded 0.9; argmax
# returns the first such position, and +1 turns it into a count.
cum_weights = (torch.cumsum(sec_weights, dim=1) > 0.9) + 0
tail_ids = torch.argmax(cum_weights, dim=1) + 1
k = torch.clamp(tail_ids, min=1)

In [None]:
import torch
import torch.nn.functional as F
from torch import Tensor

# Random scores: 32 rows with 16 entries each, sorted ascending per row.
x = torch.rand(32, 16)
sorted_scores, sorted_indices = torch.sort(x, descending=False, dim=1)

# Running sum of the sorted scores, min-max rescaled per row to [0, 1].
cdf = torch.cumsum(sorted_scores, dim=1)
cdf_min = cdf.min(dim=1)[0].unsqueeze(dim=1)
cdf_max = cdf.max(dim=1)[0].unsqueeze(dim=1)
normalized_cdf = (cdf - cdf_min) / (cdf_max - cdf_min)

B = normalized_cdf.shape[0]

# Number of sampling points: half the row length.
n_tokens = normalized_cdf.shape[1] // 2

# Evenly spaced targets in [0, 1], one copy per row.
ys = (
    torch.linspace(
        start=0,
        end=1.0,
        steps=n_tokens,
        device=normalized_cdf.device,
    )
    .unsqueeze(0)
    .repeat(B, 1)
)

# Smallest non-zero value of each row's normalised cdf (zeros are pushed out
# of the way by adding 1e8 before taking the minimum).
ys_start = (
    torch.min(normalized_cdf + (normalized_cdf == 0).float() * 1e8, dim=1)[0]
    .unsqueeze(-1)
    .expand_as(ys)
)

# 0, 1, ..., n_tokens - 1 as floats. torch.arange replaces the deprecated
# torch.range(0, n_tokens - 1); the explicit dtype keeps the float values
# that torch.range used to produce.
steps = (
    torch.arange(
        n_tokens,
        dtype=normalized_cdf.dtype,
        device=normalized_cdf.device,
    )
    .unsqueeze(0)
    .expand_as(ys_start)
)

# Since ys * (n_tokens - 1) == steps, this re-spaces the targets evenly
# between ys_start and 1 instead of between 0 and 1.
ys = ys_start + (((ys * (n_tokens - 1)) - ys_start * steps) / (n_tokens - 1))
ys = ys.unsqueeze(dim=2)

normalized_cdf = normalized_cdf.unsqueeze(dim=1)

expanded_ys = ys.expand(B, ys.shape[1], ys.shape[1])

N = sorted_scores.shape[1]
# Negative here (n_tokens < N): F.pad with a negative left pad drops that many
# leading cdf columns so the last dimension matches expanded_ys.
diff_tokens = ys.shape[1] - N

# For every target, the position of the closest remaining cdf entry.
tokens_to_pick_ind = torch.min(
    torch.abs(expanded_ys - F.pad(normalized_cdf, (diff_tokens, 0))),
    dim=2,
)[1]
print(tokens_to_pick_ind.size())
# Undo the column shift introduced by the pad so positions refer to the full
# sorted row again.
tokens_to_pick_ind = tokens_to_pick_ind - diff_tokens
print(tokens_to_pick_ind)


def get_unique_indices(indices: Tensor, max_value: int) -> Tensor:
    """Collapse repeated indices in every row, filling the freed slots.

    Each row of ``indices`` is sorted; within every run of equal values only
    the last occurrence is kept and the earlier ones are overwritten with
    ``max_value``. The rows are then sorted again, so the output has the same
    shape as the input.

    Args:
        indices: 2-D tensor of indices, one row per sample.
        max_value: Filler written over repeated entries. A filled slot is
            indistinguishable from a genuine ``max_value`` entry.

    Returns:
        Tensor with the same shape and dtype as ``indices``, sorted ascending
        along dim 1.
    """
    sorted_indices = torch.sort(indices, dim=1)[0]

    # An entry is a repeat when it equals its right neighbour. The last column
    # has no neighbour and is never a repeat. Comparing neighbours directly
    # avoids a padding sentinel, which would collide with a real index of the
    # same value (the previous pad value of 1 dropped a trailing index 1).
    is_repeat = torch.zeros_like(sorted_indices, dtype=torch.bool)
    is_repeat[:, :-1] = sorted_indices[:, 1:] == sorted_indices[:, :-1]

    unique_indices = sorted_indices.masked_fill(is_repeat, max_value)

    unique_indices = torch.sort(unique_indices, dim=1)[0]

    return unique_indices


# Replace repeated picks in each row by the filler value N - 1.
unique_indices = get_unique_indices(tokens_to_pick_ind, N - 1)  # [B x n_tokens]

# Translate positions in the sorted order back to the original column ids.
raw_indice = sorted_indices.gather(1, unique_indices)

print(raw_indice)

In [None]:
import torch
import numpy as np
from scipy import interpolate
import matplotlib.pyplot as plt
from numpy.random import random

# Eight integer positions and one random weight for each of them.
x = torch.arange(8)
print(x)
y = torch.rand(8)
print(y)

# Running sum of the weights, rescaled so that its maximum is 1.
cdf_y = y.cumsum(dim=-1)
cdf_y = cdf_y / cdf_y.max()
print(cdf_y)

# Linear interpolator from a cdf value to a position, plus eight uniform
# random draws in [0, 1). The interpolator is built but not evaluated here.
out = interpolate.interp1d(cdf_y, x)
uniform_samples = random(8)

In [None]:
import torch
import torch.nn as nn

softmax = nn.Softmax(dim=-1)

N = 17       # number of real entries
Max = 35     # padded length
ratio = 0.5  # fraction of the N entries to select

a = torch.rand(N)

# Pad the N scores to length Max; the -60000.0 filler gets ~0 softmax weight.
x = a.new_full((1 * Max, ), -60000.0)
# Index range follows N (previously a hard-coded 17 that had to be kept in
# sync with N by hand).
index = torch.arange(N)
x[index] = a

sm = softmax(x)
cdf = torch.cumsum(sm, dim=-1)

# Min-max rescale the cumulative sum to [0, 1].
normalized_cdf = (
    cdf - cdf.min(dim=0)[0]
) / (cdf.max(dim=0)[0] - cdf.min(dim=0)[0])
print(normalized_cdf)

T = int(ratio * N)

# T evenly spaced thresholds strictly inside (0, 1): both endpoints of the
# T + 2 point grid are dropped. Shaped (T, 1) to broadcast against the rows.
ys = torch.linspace(
    start=0,
    end=1.0,
    steps=T + 2,
)[1:T + 1].view(-1, 1)
print(ys)

# One copy of the cdf per threshold; argmax over the 0/1 mask returns the
# first position whose cdf value exceeds that threshold.
pre_cdf = normalized_cdf.repeat(T, 1)
print((pre_cdf > ys) + 0)
pre_selected = torch.argmax((pre_cdf > ys) + 0, dim=1)
print(pre_selected)

# Scratch check of concatenation followed by index-based selection.
xxx = torch.arange(10)
out = torch.cat([pre_selected, xxx])
indexx = torch.arange(5)
print(out)
pee = out[indexx]
print(pee)

In [212]:
# Sort 0..9 in descending order, then look up position 2 twice and show that
# unique() collapses the repeated value.
a, perm = torch.sort(torch.arange(10), dim=0, descending=True)
b = torch.tensor([2, 2])
print(b)
print(a)
print(torch.unique(a[b]))

tensor([2, 2])
tensor([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
tensor([7])
