From 56d51c2dec38e5e108cf020fc07d5aae4cdbef13 Mon Sep 17 00:00:00 2001 From: ericup Date: Wed, 20 Apr 2022 17:45:19 +0200 Subject: [PATCH] Add box filter Module --- celldetection/models/filters.py | 40 ++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/celldetection/models/filters.py b/celldetection/models/filters.py index 219bc38..c4d5d86 100644 --- a/celldetection/models/filters.py +++ b/celldetection/models/filters.py @@ -6,7 +6,7 @@ from typing import Union from torch.nn.common_types import _size_2_t -__all__ = ['Filter2d', 'PascalFilter2d', 'ScharrFilter2d', 'SobelFilter2d'] +__all__ = ['Filter2d', 'PascalFilter2d', 'ScharrFilter2d', 'SobelFilter2d', 'BoxFilter2d'] class Filter2d(nn.Conv2d): @@ -231,3 +231,41 @@ def get_kernel2d(transpose=False): if transpose: sobel = sobel.T return sobel + + +class BoxFilter2d(Filter2d): + def __init__( + self, + in_channels: int, + kernel_size, + stride: _size_2_t = 1, + padding: Union[str, _size_2_t] = 0, + dilation: _size_2_t = 1, + padding_mode: str = 'zeros', + device=None, + dtype=None, + odd_padding=True, + trainable=False, + normalize=True + ) -> None: + """Box Filter 2d. + + Args: + in_channels: Number of input channels. + stride: Stride. + padding: Padding. + dilation: Spacing between kernel elements. + padding_mode: One of ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``. + device: Device. + dtype: Data type. + odd_padding: Whether to apply one-sided padding to account for even kernel sizes. + trainable: Whether the kernel should be trainable. + normalize: Whether to normalize the kernel to retain magnitude. + """ + super().__init__(in_channels=in_channels, kernel=self.get_kernel2d(kernel_size, normalize), + stride=stride, padding=padding, dilation=dilation, odd_padding=odd_padding, + trainable=trainable, padding_mode=padding_mode, device=device, dtype=dtype) + + @staticmethod + def get_kernel2d(kernel_size, normalize=True): + return torch.ones((kernel_size, kernel_size)) / (kernel_size ** 2 if normalize else 1)