Skip to content

Commit

Permalink
Add Sobel filter Module
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Apr 19, 2022
1 parent e2fdd62 commit e99f05c
Showing 1 changed file with 51 additions and 2 deletions.
53 changes: 51 additions & 2 deletions celldetection/models/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Union
from torch.nn.common_types import _size_2_t

__all__ = ['Filter2d', 'PascalFilter2d', 'ScharrFilter2d']
__all__ = ['Filter2d', 'PascalFilter2d', 'ScharrFilter2d', 'SobelFilter2d']


class Filter2d(nn.Conv2d):
Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(
... [1, 0, -1],
... ], dtype=torch.float32)
... sobel_layer = Filter2d(in_channels=3, kernel=sobel, padding=1, trainable=False)
... sobel_layer
... sobel_layer, sobel_layer.weight
(Filter2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=3, bias=False),
tensor([[ 1., 0., -1.],
[ 2., 0., -2.],
Expand Down Expand Up @@ -182,3 +182,52 @@ def get_kernel2d(transpose=False):
if transpose:
kernel = kernel.T
return kernel


class SobelFilter2d(Filter2d):
def __init__(
self,
in_channels: int,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
padding_mode: str = 'zeros',
device=None,
dtype=None,
odd_padding=True,
trainable=False,
transpose=False
) -> None:
"""Sobel Filter 2d.
Applies the 3x3 Sobel image gradient operator.
References:
- https://en.wikipedia.org/wiki/Sobel_operator
Args:
in_channels: Number of input channels.
stride: Stride.
padding: Padding.
dilation: Spacing between kernel elements.
padding_mode: One of ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``.
device: Device.
dtype: Data type.
odd_padding: Whether to apply one-sided padding to account for even kernel sizes.
trainable: Whether the kernel should be trainable.
transpose: ``False`` for :math:`h_x` kernel, ``True`` for :math:`h_y` kernel.
"""
super().__init__(in_channels=in_channels, kernel=self.get_kernel2d(transpose),
stride=stride, padding=padding, dilation=dilation, odd_padding=odd_padding,
trainable=trainable, padding_mode=padding_mode, device=device, dtype=dtype)

@staticmethod
def get_kernel2d(transpose=False):
sobel = torch.as_tensor([
[1, 0, -1],
[2, 0, -2],
[1, 0, -1.],
])
if transpose:
sobel = sobel.T
return sobel

0 comments on commit e99f05c

Please sign in to comment.