diff --git a/CHANGELOG.md b/CHANGELOG.md index 259ae64a231d..41545494a34c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Support `in_channels` with `tuple` in `GENConv` for bipartite message passing ([#5627](https://github.com/pyg-team/pytorch_geometric/pull/5627), [#5641](https://github.com/pyg-team/pytorch_geometric/pull/5641)) - Handle cases of not having enough possible negative edges in `RandomLinkSplit` ([#5642](https://github.com/pyg-team/pytorch_geometric/pull/5642)) - Fix `RGCN+pyg-lib` for `LongTensor` input ([#5610](https://github.com/pyg-team/pytorch_geometric/pull/5610)) -- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701] (https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714)), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722)) +- Improved type hint support ([#5603](https://github.com/pyg-team/pytorch_geometric/pull/5603), [#5659](https://github.com/pyg-team/pytorch_geometric/pull/5659), [#5664](https://github.com/pyg-team/pytorch_geometric/pull/5664), [#5665](https://github.com/pyg-team/pytorch_geometric/pull/5665), [#5666](https://github.com/pyg-team/pytorch_geometric/pull/5666), [#5667](https://github.com/pyg-team/pytorch_geometric/pull/5667), [#5668](https://github.com/pyg-team/pytorch_geometric/pull/5668), [#5669](https://github.com/pyg-team/pytorch_geometric/pull/5669), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5673), [#5675](https://github.com/pyg-team/pytorch_geometric/pull/5675), [#5673](https://github.com/pyg-team/pytorch_geometric/pull/5676), [#5678](https://github.com/pyg-team/pytorch_geometric/pull/5678), [#5682](https://github.com/pyg-team/pytorch_geometric/pull/5682), [#5683](https://github.com/pyg-team/pytorch_geometric/pull/5683), [#5684](https://github.com/pyg-team/pytorch_geometric/pull/5684), [#5685](https://github.com/pyg-team/pytorch_geometric/pull/5685), [#5687](https://github.com/pyg-team/pytorch_geometric/pull/5687), [#5688](https://github.com/pyg-team/pytorch_geometric/pull/5688), [#5695](https://github.com/pyg-team/pytorch_geometric/pull/5695), [#5699](https://github.com/pyg-team/pytorch_geometric/pull/5699), [#5701] (https://github.com/pyg-team/pytorch_geometric/pull/5701), [#5702](https://github.com/pyg-team/pytorch_geometric/pull/5702), [#5703](https://github.com/pyg-team/pytorch_geometric/pull/5703), [#5706](https://github.com/pyg-team/pytorch_geometric/pull/5706), [#5707](https://github.com/pyg-team/pytorch_geometric/pull/5707), [#5710](https://github.com/pyg-team/pytorch_geometric/pull/5710), [#5714](https://github.com/pyg-team/pytorch_geometric/pull/5714)), [#5722](https://github.com/pyg-team/pytorch_geometric/pull/5722), [#5729](https://github.com/pyg-team/pytorch_geometric/pull/5729)) - Avoid modifying `mode_kwargs` in `MultiAggregation` ([#5601](https://github.com/pyg-team/pytorch_geometric/pull/5601)) - Changed `BatchNorm` to allow for batches of size one during training ([#5530](https://github.com/pyg-team/pytorch_geometric/pull/5530), [#5614](https://github.com/pyg-team/pytorch_geometric/pull/5614)) - Integrated better temporal sampling support by requiring that local neighborhoods are sorted according to time ([#5516](https://github.com/pyg-team/pytorch_geometric/issues/5516), [#5602](https://github.com/pyg-team/pytorch_geometric/issues/5602)) diff --git a/test/nn/norm/test_graph_size_norm.py b/test/nn/norm/test_graph_size_norm.py index e8fa32ad4ea0..8b70795ed4d7 100644 --- a/test/nn/norm/test_graph_size_norm.py +++ b/test/nn/norm/test_graph_size_norm.py @@ -1,10 +1,19 @@ import torch from torch_geometric.nn import GraphSizeNorm +from torch_geometric.testing import is_full_test def test_graph_size_norm(): + x = torch.randn(100, 16) batch = torch.repeat_interleave(torch.full((10, ), 10, dtype=torch.long)) + norm = GraphSizeNorm() - out = norm(torch.randn(100, 16), batch) + assert str(norm) == 'GraphSizeNorm()' + + out = norm(x, batch) assert out.size() == (100, 16) + + if is_full_test(): + jit = torch.jit.script(norm) + assert torch.allclose(jit(x, batch), out) diff --git a/torch_geometric/nn/norm/graph_size_norm.py b/torch_geometric/nn/norm/graph_size_norm.py index f3ba10bc928e..38e638491b66 100644 --- a/torch_geometric/nn/norm/graph_size_norm.py +++ b/torch_geometric/nn/norm/graph_size_norm.py @@ -1,6 +1,8 @@ import torch import torch.nn as nn +from torch import Tensor +from torch_geometric.typing import OptTensor from torch_geometric.utils import degree @@ -16,10 +18,13 @@ class GraphSizeNorm(nn.Module): def __init__(self): super().__init__() - def forward(self, x, batch=None): + def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor: """""" if batch is None: batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device) inv_sqrt_deg = degree(batch, dtype=x.dtype).pow(-0.5) return x * inv_sqrt_deg.index_select(0, batch).view(-1, 1) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()'