Skip to content

Commit

Permalink
Add box filter Module
Browse files Browse the repository at this point in the history
  • Loading branch information
ericup committed Apr 20, 2022
1 parent e99f05c commit 56d51c2
Showing 1 changed file with 39 additions and 1 deletion.
40 changes: 39 additions & 1 deletion celldetection/models/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Union
from torch.nn.common_types import _size_2_t

__all__ = ['Filter2d', 'PascalFilter2d', 'ScharrFilter2d', 'SobelFilter2d']
__all__ = ['Filter2d', 'PascalFilter2d', 'ScharrFilter2d', 'SobelFilter2d', 'BoxFilter2d']


class Filter2d(nn.Conv2d):
Expand Down Expand Up @@ -231,3 +231,41 @@ def get_kernel2d(transpose=False):
if transpose:
sobel = sobel.T
return sobel


class BoxFilter2d(Filter2d):
def __init__(
self,
in_channels: int,
kernel_size,
stride: _size_2_t = 1,
padding: Union[str, _size_2_t] = 0,
dilation: _size_2_t = 1,
padding_mode: str = 'zeros',
device=None,
dtype=None,
odd_padding=True,
trainable=False,
normalize=True
) -> None:
"""Box Filter 2d.
Args:
in_channels: Number of input channels.
stride: Stride.
padding: Padding.
dilation: Spacing between kernel elements.
padding_mode: One of ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``.
device: Device.
dtype: Data type.
odd_padding: Whether to apply one-sided padding to account for even kernel sizes.
trainable: Whether the kernel should be trainable.
normalize: Whether to normalize the kernel to retain magnitude.
"""
super().__init__(in_channels=in_channels, kernel=self.get_kernel2d(kernel_size, normalize),
stride=stride, padding=padding, dilation=dilation, odd_padding=odd_padding,
trainable=trainable, padding_mode=padding_mode, device=device, dtype=dtype)

@staticmethod
def get_kernel2d(kernel_size, normalize=True):
return torch.ones((kernel_size, kernel_size)) / (kernel_size ** 2 if normalize else 1)

0 comments on commit 56d51c2

Please sign in to comment.