Skip to content

Commit

Permalink
fix missing ceil_mode in export
Browse files Browse the repository at this point in the history
Signed-off-by: chenhe <chenhe@megarobo.tech>
  • Loading branch information
chAwater committed Jul 18, 2023
1 parent 630de18 commit 655a059
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions escnn/nn/modules/pooling/pointwise_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,9 @@ def export(self):
self.eval()

if self.d == 2:
return torch.nn.MaxPool2d(self.kernel_size, self.stride, self.padding, self.dilation).eval()
return torch.nn.MaxPool2d(self.kernel_size, self.stride, self.padding, self.dilation, ceil_mode=self.ceil_mode).eval()
elif self.d == 3:
return torch.nn.MaxPool3d(self.kernel_size, self.stride, self.padding, self.dilation).eval()
return torch.nn.MaxPool3d(self.kernel_size, self.stride, self.padding, self.dilation, ceil_mode=self.ceil_mode).eval()
else:
raise NotImplementedError

Expand Down Expand Up @@ -303,13 +303,13 @@ def __init__(self,
Channel-wise max-pooling: each channel is treated independently.
This module works exactly as :class:`torch.nn.MaxPool2D`, wrapping it in the
:class:`~escnn.nn.EquivariantModule` interface.
Notice that not all representations support this kind of pooling. In general, only representations which support
pointwise non-linearities do.
.. warning ::
Even if the input tensor has a `coords` attribute, the output of this module will not have one.
Args:
in_type (FieldType): the input field type
kernel_size: the size of the window to take a max over
Expand Down Expand Up @@ -457,5 +457,3 @@ def __init__(self,
# for backward compatibility
PointwiseMaxPool = PointwiseMaxPool2D
PointwiseMaxPoolAntialiased = PointwiseMaxPoolAntialiased2D


0 comments on commit 655a059

Please sign in to comment.