Skip to content

Commit

Permalink
Add Scharr filter Module
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Apr 15, 2022
1 parent e3e0a89 commit e2fdd62
Showing 1 changed file with 51 additions and 1 deletion.
52 changes: 51 additions & 1 deletion celldetection/models/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Union
from torch.nn.common_types import _size_2_t

__all__ = ['Filter2d', 'PascalFilter2d']
__all__ = ['Filter2d', 'PascalFilter2d', 'ScharrFilter2d']


class Filter2d(nn.Conv2d):
Expand Down Expand Up @@ -132,3 +132,53 @@ def get_kernel1d(kernel_size, normalize=True):
def get_kernel2d(kernel_size, normalize=True):
k = PascalFilter2d.get_kernel1d(kernel_size, normalize)
return torch.as_tensor(np.outer(k, k))


class ScharrFilter2d(Filter2d):
def __init__(
self,
in_channels: int,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
padding_mode: str = 'zeros',
device=None,
dtype=None,
odd_padding=True,
trainable=False,
transpose=False
) -> None:
"""Scharr Filter 2d.
Applies the Scharr gradient operator, a 3x3 kernel optimized for rotational symmetry.
References:
- https://archiv.ub.uni-heidelberg.de/volltextserver/962/
- https://en.wikipedia.org/wiki/Sobel_operator#Alternative_operators
Args:
in_channels: Number of input channels.
stride: Stride.
padding: Padding.
dilation: Spacing between kernel elements.
padding_mode: One of ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``.
device: Device.
dtype: Data type.
odd_padding: Whether to apply one-sided padding to account for even kernel sizes.
trainable: Whether the kernel should be trainable.
transpose: ``False`` for :math:`h_x` kernel, ``True`` for :math:`h_y` kernel.
"""
super().__init__(in_channels=in_channels, kernel=self.get_kernel2d(transpose),
stride=stride, padding=padding, dilation=dilation, odd_padding=odd_padding,
trainable=trainable, padding_mode=padding_mode, device=device, dtype=dtype)

@staticmethod
def get_kernel2d(transpose=False):
kernel = torch.as_tensor([
[47, 0, -47],
[162, 0, -162],
[47, 0, -47.],
])
if transpose:
kernel = kernel.T
return kernel

0 comments on commit e2fdd62

Please sign in to comment.