diff --git a/celldetection/models/filters.py b/celldetection/models/filters.py index f6ca960..219bc38 100644 --- a/celldetection/models/filters.py +++ b/celldetection/models/filters.py @@ -6,7 +6,7 @@ from typing import Union from torch.nn.common_types import _size_2_t -__all__ = ['Filter2d', 'PascalFilter2d', 'ScharrFilter2d'] +__all__ = ['Filter2d', 'PascalFilter2d', 'ScharrFilter2d', 'SobelFilter2d'] class Filter2d(nn.Conv2d): @@ -34,7 +34,7 @@ def __init__( ... [1, 0, -1], ... ], dtype=torch.float32) ... sobel_layer = Filter2d(in_channels=3, kernel=sobel, padding=1, trainable=False) - ... sobel_layer + ... sobel_layer, sobel_layer.weight (Filter2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3, bias=False), tensor([[ 1., 0., -1.], [ 2., 0., -2.], @@ -182,3 +182,52 @@ def get_kernel2d(transpose=False): if transpose: kernel = kernel.T return kernel + + +class SobelFilter2d(Filter2d): + def __init__( + self, + in_channels: int, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + padding_mode: str = 'zeros', + device=None, + dtype=None, + odd_padding=True, + trainable=False, + transpose=False + ) -> None: + """Sobel Filter 2d. + + Applies the 3x3 Sobel image gradient operator. + + References: + - https://en.wikipedia.org/wiki/Sobel_operator + + Args: + in_channels: Number of input channels. + stride: Stride. + padding: Padding. + dilation: Spacing between kernel elements. + padding_mode: One of ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``. + device: Device. + dtype: Data type. + odd_padding: Whether to apply one-sided padding to account for even kernel sizes. + trainable: Whether the kernel should be trainable. + transpose: ``False`` for :math:`h_x` kernel, ``True`` for :math:`h_y` kernel. + """ + super().__init__(in_channels=in_channels, kernel=self.get_kernel2d(transpose), + stride=stride, padding=padding, dilation=dilation, odd_padding=odd_padding, + trainable=trainable, padding_mode=padding_mode, device=device, dtype=dtype) + + @staticmethod + def get_kernel2d(transpose=False): + sobel = torch.as_tensor([ + [1, 0, -1], + [2, 0, -2], + [1, 0, -1.], + ]) + if transpose: + sobel = sobel.T + return sobel