In [1]:
from typing import Optional

import torch
from torch import Tensor
from torch.nn import LSTM
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from torch_geometric.nn.aggr import Aggregation


class LSTMAggregation(Aggregation):
    r"""Performs LSTM-style aggregation in which the elements to aggregate are
    interpreted as a sequence, as described in the `"Inductive Representation
    Learning on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper.

    Each group is densified with :meth:`to_dense_batch` and packed according
    to its true length, so the LSTM never consumes padding. The group is then
    reduced to the LSTM output at its last real element. Empty groups
    aggregate to all-zero rows.

    .. warning::
        :class:`LSTMAggregation` is not a permutation-invariant operator.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample (the LSTM hidden
            size). Passing ``bidirectional=True`` doubles the actual output
            width, and a non-zero ``proj_size`` replaces it.
        **kwargs (optional): Additional arguments of :class:`torch.nn.LSTM`.
            ``batch_first`` is always set to :obj:`True` here and must not be
            passed again.
    """
    def __init__(self, in_channels: int, out_channels: int, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.lstm = LSTM(in_channels, out_channels, batch_first=True, **kwargs)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initializes the weights of the wrapped :class:`torch.nn.LSTM`."""
        self.lstm.reset_parameters()

    def _output_size(self) -> int:
        # Feature width of one LSTM output step: hidden (or projection) size
        # times the number of directions.
        hidden = getattr(self.lstm, 'proj_size', 0) or self.lstm.hidden_size
        return hidden * (2 if self.lstm.bidirectional else 1)

    def forward(self, x: Tensor, index: Optional[Tensor] = None,
                ptr: Optional[Tensor] = None, dim_size: Optional[int] = None,
                dim: int = -2) -> Tensor:
        r"""Aggregates ``x`` group-wise with the LSTM.

        Args:
            x (Tensor): Elements to aggregate.
            index, ptr, dim_size, dim: Grouping arguments, forwarded unchanged
                to :meth:`to_dense_batch`.

        Returns:
            Tensor: One row per group. Each row is the LSTM output at that
            group's last real element; empty groups give all-zero rows.
        """
        # Assumes PyG's to_dense_batch contract: a right-padded batch of
        # shape [num_groups, max_len, in_channels] plus a boolean
        # [num_groups, max_len] mask of real (non-padding) entries.
        x, mask = self.to_dense_batch(x, index, ptr, dim_size, dim)
        num_groups, max_len = x.size(0), x.size(1)

        # Nothing to run the LSTM over: return a correctly shaped zero tensor.
        if num_groups == 0 or max_len == 0:
            return x.new_zeros(num_groups, self._output_size())

        lengths = mask.sum(dim=1)

        # pack_padded_sequence rejects zero-length sequences and needs CPU
        # lengths. Empty groups are therefore run over a single padding step
        # and zeroed out below.
        packed = pack_padded_sequence(x, lengths.clamp(min=1).cpu(),
                                      batch_first=True, enforce_sorted=False)
        packed_out, _ = self.lstm(packed)
        # pad_packed_sequence restores the original (unsorted) group order.
        out, _ = pad_packed_sequence(packed_out, batch_first=True)

        # Take the output at each group's last real element. This mirrors the
        # original "last timestep" read-out, but without trailing padding.
        last = (lengths - 1).clamp(min=0)
        out = out[torch.arange(num_groups, device=out.device), last]
        return out.masked_fill((lengths == 0).unsqueeze(-1), 0.0)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')


In [3]:
# Scratch: reorder the rows of `a` by ascending value of `b` (presumably
# per-sequence lengths). torch.sort returns the sorted values and the
# permutation `idx` that produces them.
rng = torch.Generator().manual_seed(0)  # local seed: reproducible `a` on re-run
b = torch.tensor([3, 9, 6])
a = torch.rand(3, 4, generator=rng)
print(a, b)
sorted_b, idx = torch.sort(b)  # avoid shadowing the builtin `sorted`
assert torch.equal(b[idx], sorted_b)
print(idx)
a = a[idx]
print(a)

tensor([[1.4817e-02, 1.3328e-04, 3.1582e-01, 2.1440e-01],
        [9.7630e-02, 3.3602e-01, 6.9767e-01, 2.0946e-01],
        [3.8457e-01, 7.2020e-01, 1.1124e-01, 9.1535e-01]]) tensor([3, 9, 6])
tensor([0, 2, 1])
tensor([[1.4817e-02, 1.3328e-04, 3.1582e-01, 2.1440e-01],
        [3.8457e-01, 7.2020e-01, 1.1124e-01, 9.1535e-01],
        [9.7630e-02, 3.3602e-01, 6.9767e-01, 2.0946e-01]])
