Skip to content

Commit

Permalink
Fix: (SSIM) propagate device if gaussian_kernel is False, add test (#…
Browse files Browse the repository at this point in the history
…1149)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
(cherry picked from commit 45ac6b0)
  • Loading branch information
krshrimali authored and Borda committed Jul 22, 2022
1 parent af8b327 commit 7a5e9b7
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -46,6 +46,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed JaccardIndex multi-label compute ([#1125](https://github.com/Lightning-AI/metrics/pull/1125))


- Fix SSIM propagate device if `gaussian_kernel` is False, add test ([#1149](https://github.com/Lightning-AI/metrics/pull/1149))



## [0.9.2] - 2022-06-29

Expand Down
17 changes: 17 additions & 0 deletions tests/image/test_ssim.py
Expand Up @@ -125,6 +125,23 @@ def test_ssim(self, preds, target, sigma, ddp, dist_sync_on_step):
dist_sync_on_step=dist_sync_on_step,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_ssim_without_gaussian_kernel(self, preds, target, sigma, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
StructuralSimilarityIndexMeasure,
partial(_sk_ssim, data_range=1.0, sigma=sigma, kernel_size=None),
metric_args={
"gaussian_kernel": False,
"data_range": 1.0,
"sigma": sigma,
},
dist_sync_on_step=dist_sync_on_step,
)

def test_ssim_functional(self, preds, target, sigma):
self.run_functional_metric_test(
preds,
Expand Down
4 changes: 3 additions & 1 deletion torchmetrics/functional/image/ssim.py
Expand Up @@ -150,7 +150,9 @@ def _ssim_compute(
kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device)

if not gaussian_kernel:
kernel = torch.ones((1, 1, *kernel_size)) / torch.prod(Tensor(kernel_size))
kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod(
torch.tensor(kernel_size, dtype=dtype, device=device)
)

input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)

Expand Down

0 comments on commit 7a5e9b7

Please sign in to comment.