-
-
Notifications
You must be signed in to change notification settings - Fork 958
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
218 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,7 @@ | |
spatial_gradient, | ||
box_blur, | ||
median_blur, | ||
filter2D, | ||
) | ||
from kornia.losses import ( | ||
ssim, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from typing import Tuple, List | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
def compute_padding(kernel_size: Tuple[int, int]) -> Tuple[int, int, int, int]:
    """Compute the 4-tuple of paddings for a 2D kernel size.

    The returned ordering is ``(left, right, top, bottom)``, as expected by
    ``torch.nn.functional.pad``; see
    https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
    """
    assert len(kernel_size) == 2, kernel_size
    pad_vert, pad_horz = ((k - 1) // 2 for k in kernel_size)
    return pad_horz, pad_horz, pad_vert, pad_vert
|
||
|
||
def filter2D(input: torch.Tensor, kernel: torch.Tensor,
             border_type: str = 'reflect') -> torch.Tensor:
    r"""Function that convolves a tensor with a kernel.

    The function applies a given kernel to a tensor. The kernel is applied
    independently at each depth channel of the tensor. Before applying the
    kernel, the function applies padding according to the specified mode so
    that the output remains in the same shape.

    Args:
        input (torch.Tensor): the input tensor with shape of
          :math:`(B, C, H, W)`.
        kernel (torch.Tensor): the kernel to be convolved with the input
          tensor. The kernel shape must be :math:`(B, kH, kW)`.
        border_type (str): the padding mode to be applied before convolving.
          The expected modes are: ``'constant'``, ``'reflect'``,
          ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.

    Return:
        torch.Tensor: the convolved tensor of same size and numbers of channels
        as the input.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}"
                        .format(type(input)))

    if not isinstance(kernel, torch.Tensor):
        raise TypeError("Input kernel type is not a torch.Tensor. Got {}"
                        .format(type(kernel)))

    if not isinstance(border_type, str):
        # BUG FIX: previously reported type(kernel) here, producing a
        # misleading error message for an invalid border_type.
        raise TypeError("Input border_type is not string. Got {}"
                        .format(type(border_type)))

    if not len(input.shape) == 4:
        raise ValueError("Invalid input shape, we expect BxCxHxW. Got: {}"
                         .format(input.shape))

    if not len(kernel.shape) == 3:
        raise ValueError("Invalid kernel shape, we expect BxHxW. Got: {}"
                         .format(kernel.shape))

    borders_list: List[str] = ['constant', 'reflect', 'replicate', 'circular']
    if border_type not in borders_list:
        raise ValueError("Invalid border_type, we expect the following: {0}."
                         "Got: {1}".format(borders_list, border_type))

    # prepare kernel: match the input's device/dtype, then replicate it once
    # per input channel so conv2d with groups=c filters each channel
    # independently.
    b, c, h, w = input.shape
    tmp_kernel: torch.Tensor = kernel.to(input.device).to(input.dtype)
    tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1)

    # pad the input so the convolution output keeps the input spatial size
    height, width = tmp_kernel.shape[-2:]
    padding_shape: Tuple[int, int, int, int] = compute_padding((height, width))
    input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)

    # convolve the padded tensor with the per-channel kernel
    return F.conv2d(input_pad, tmp_kernel, padding=0, stride=1, groups=c)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import Tuple | ||
|
||
import pytest | ||
|
||
import kornia | ||
import kornia.testing as utils # test utils | ||
|
||
import torch | ||
from torch.testing import assert_allclose | ||
from torch.autograd import gradcheck | ||
|
||
|
||
class TestBoxBlur:
    """Tests for the kornia box blur (module and functional APIs)."""

    @staticmethod
    def _step_image() -> torch.Tensor:
        # 1x1x5x5 image: three rows of ones stacked on two rows of twos.
        row_one = [1., 1., 1., 1., 1.]
        row_two = [2., 2., 2., 2., 2.]
        return torch.tensor([[[row_one, row_one, row_one, row_two, row_two]]])

    def test_shape(self):
        img = torch.zeros(1, 3, 4, 4)
        blur = kornia.filters.BoxBlur((3, 3))
        assert blur(img).shape == (1, 3, 4, 4)

    def test_shape_batch(self):
        img = torch.zeros(2, 6, 4, 4)
        blur = kornia.filters.BoxBlur((3, 3))
        assert blur(img).shape == (2, 6, 4, 4)

    def test_kernel_3x3(self):
        out = kornia.filters.box_blur(self._step_image(), (3, 3))
        # interior of the all-ones region averages to exactly 1
        assert_allclose(out[0, 0, 1, 1:4], torch.tensor(1.))

    def test_kernel_5x5(self):
        out = kornia.filters.box_blur(self._step_image(), (5, 5))
        assert_allclose(out[0, 0, 1, 2], torch.tensor(1.))

    def test_kernel_5x5_batch(self):
        batch_size = 3
        img = self._step_image().repeat(batch_size, 1, 1, 1)
        out = kornia.filters.box_blur(img, (5, 5))
        assert_allclose(out[0, 0, 1, 2], torch.tensor(1.))

    def test_gradcheck(self):
        img = torch.rand(1, 2, 5, 4)
        img = utils.tensor_to_gradcheck_var(img)  # to var
        assert gradcheck(kornia.filters.box_blur, (img, (3, 3),),
                         raise_exception=True)

    def test_jit(self):
        @torch.jit.script
        def op_script(input: torch.Tensor,
                      kernel_size: Tuple[int, int]) -> torch.Tensor:
            return kornia.filters.box_blur(input, kernel_size)

        img = torch.rand(2, 3, 4, 5)
        out = op_script(img, (3, 3))
        expected = kornia.filters.box_blur(img, (3, 3))
        assert_allclose(out, expected)