|
2 | 2 | import torch
|
3 | 3 | import numpy as np
|
4 | 4 | from skimage.metrics import peak_signal_noise_ratio as ski_psnr
|
| 5 | +from skimage.metrics import structural_similarity as ski_ssim |
5 | 6 |
|
6 | 7 | from pytorch_lightning.metrics.functional import (
|
7 | 8 | mae,
|
8 | 9 | mse,
|
9 | 10 | psnr,
|
10 | 11 | rmse,
|
11 |
| - rmsle |
| 12 | + rmsle, |
| 13 | + ssim |
12 | 14 | )
|
13 | 15 |
|
14 | 16 |
|
@@ -93,3 +95,50 @@ def test_psnr_against_sklearn(sklearn_metric, torch_metric):
|
93 | 95 | sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
|
94 | 96 | pl_score = torch_metric(pred, target, data_range=n_cls_target)
|
95 | 97 | assert torch.allclose(sk_score, pl_score)
|
| 98 | + |
| 99 | + |
| 100 | +@pytest.mark.parametrize(['size', 'channel', 'plus', 'multichannel'], [ |
| 101 | + pytest.param(16, 1, 0.125, False), |
| 102 | + pytest.param(32, 1, 0.25, False), |
| 103 | + pytest.param(48, 3, 0.5, True), |
| 104 | + pytest.param(64, 4, 0.75, True), |
| 105 | + pytest.param(128, 5, 1, True) |
| 106 | +]) |
| 107 | +def test_ssim(size, channel, plus, multichannel): |
| 108 | + device = "cuda" if torch.cuda.is_available() else "cpu" |
| 109 | + pred = torch.rand(1, channel, size, size, device=device) |
| 110 | + target = pred + plus |
| 111 | + ssim_idx = ssim(pred, target) |
| 112 | + np_pred = np.random.rand(size, size, channel) |
| 113 | + if multichannel is False: |
| 114 | + np_pred = np_pred[:, :, 0] |
| 115 | + np_target = np.add(np_pred, plus) |
| 116 | + sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True) |
| 117 | + assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2) |
| 118 | + |
| 119 | + ssim_idx = ssim(pred, pred) |
| 120 | + assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device)) |
| 121 | + |
| 122 | + |
| 123 | +@pytest.mark.parametrize(['pred', 'target', 'kernel', 'sigma'], [ |
| 124 | + pytest.param([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # shape |
| 125 | + pytest.param([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape) |
| 126 | + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma) |
| 127 | + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma) |
| 128 | + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma) |
| 129 | + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input |
| 130 | + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]), # invalid kernel input |
| 131 | + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]), # invalid kernel input |
| 132 | + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]), # invalid sigma input |
| 133 | + pytest.param([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]), # invalid sigma input |
| 134 | +]) |
| 135 | +def test_ssim_invalid_inputs(pred, target, kernel, sigma): |
| 136 | + pred_t = torch.rand(pred) |
| 137 | + target_t = torch.rand(target, dtype=torch.float64) |
| 138 | + with pytest.raises(TypeError): |
| 139 | + ssim(pred_t, target_t) |
| 140 | + |
| 141 | + pred = torch.rand(pred) |
| 142 | + target = torch.rand(target) |
| 143 | + with pytest.raises(ValueError): |
| 144 | + ssim(pred, target, kernel, sigma) |
0 commit comments