Skip to content

Commit bda7cf1

Browse files
author
Jeff Yang
authored
metrics: add SSIM (Lightning-AI#2671)
* metrics: add SSIM * Update CHANGELOG.md fix codefactor issue fix doctest fix doctest fix test * added test for raise Error
1 parent d0b8e85 commit bda7cf1

File tree

8 files changed

+255
-3
lines changed

8 files changed

+255
-3
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))
1213
- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))
1314

1415
- Added support for `Trainer(num_sanity_val_steps=-1)` to check all validation data before training ([#2246](https://github.com/PyTorchLightning/pytorch-lightning/pull/2246))

docs/source/metrics.rst

+12
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,12 @@ RMSLE
234234
.. autoclass:: pytorch_lightning.metrics.regression.RMSLE
235235
:noindex:
236236

237+
SSIM
238+
^^^^
239+
240+
.. autoclass:: pytorch_lightning.metrics.regression.SSIM
241+
:noindex:
242+
237243
----------------
238244

239245
Functional Metrics
@@ -403,6 +409,12 @@ psnr (F)
403409
.. autofunction:: pytorch_lightning.metrics.functional.psnr
404410
:noindex:
405411

412+
ssim (F)
413+
^^^^^^^^
414+
415+
.. autofunction:: pytorch_lightning.metrics.functional.ssim
416+
:noindex:
417+
406418
stat_scores_multiple_classes (F)
407419
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
408420

pytorch_lightning/metrics/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
PSNR,
77
RMSE,
88
RMSLE,
9+
SSIM
910
)
1011
from pytorch_lightning.metrics.classification import (
1112
Accuracy,
@@ -54,6 +55,7 @@
5455
"PSNR",
5556
"RMSE",
5657
"RMSLE",
58+
"SSIM"
5759
]
5860
__sequence_metrics = ["BLEUScore"]
5961
__all__ = __regression_metrics + __classification_metrics + ["SklearnMetric"] + __sequence_metrics

pytorch_lightning/metrics/functional/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -26,5 +26,6 @@
2626
psnr,
2727
rmse,
2828
rmsle,
29+
ssim
2930
)
3031
from pytorch_lightning.metrics.functional.nlp import bleu_score

pytorch_lightning/metrics/functional/regression.py

+115
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import Sequence
2+
13
import torch
24
from torch.nn import functional as F
35

@@ -182,3 +184,116 @@ def psnr(
182184
psnr_base_e = 2 * torch.log(data_range) - torch.log(mse_score)
183185
psnr = psnr_base_e * (10 / torch.log(torch.tensor(base)))
184186
return psnr
187+
188+
189+
def _gaussian_kernel(channel, kernel_size, sigma, device, dtype=torch.float32):
    """
    Build a normalized 2D gaussian kernel laid out for a grouped (per-channel) convolution.

    Args:
        channel: Number of channels; the same 2D kernel is repeated once per channel
        kernel_size: Pair of window sizes, one per kernel axis
        sigma: Pair of standard deviations of the gaussian, one per kernel axis
        device: Device on which the kernel is created
        dtype: Floating point type of the kernel. Default: ``torch.float32``

    Returns:
        A Tensor of shape ``(channel, 1, kernel_size[0], kernel_size[1])``;
        the weights of every channel slice sum to one
    """

    def gaussian(size, std):
        # sample positions centred on zero, e.g. size=5 -> [-2, -1, 0, 1, 2]
        coords = torch.arange(
            start=(1 - size) / 2, end=(1 + size) / 2, step=1, dtype=dtype, device=device
        )
        weights = torch.exp(-coords.pow(2) / (2 * pow(std, 2)))
        return (weights / weights.sum()).unsqueeze(dim=0)  # (1, size)

    kernel_x = gaussian(kernel_size[0], sigma[0])
    kernel_y = gaussian(kernel_size[1], sigma[1])
    # outer product of the two 1D kernels: (kernel_size[0], 1) @ (1, kernel_size[1])
    kernel = torch.matmul(kernel_x.t(), kernel_y)

    return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])
202+
203+
204+
def ssim(
    pred: torch.Tensor,
    target: torch.Tensor,
    kernel_size: Sequence[int] = (11, 11),
    sigma: Sequence[float] = (1.5, 1.5),
    reduction: str = "elementwise_mean",
    data_range: float = None,
    k1: float = 0.01,
    k2: float = 0.03
) -> torch.Tensor:
    """
    Computes Structural Similarity Index Measure

    Args:
        pred: Estimated image, shape BxCxHxW
        target: Ground truth image, same shape and dtype as ``pred``
        kernel_size: Size of the gaussian kernel (two odd positive ints). Default: (11, 11)
        sigma: Standard deviation of the gaussian kernel (two positive values). Default: (1.5, 1.5)
        reduction: A method for reducing ssim over all elements in the ``pred`` tensor. Default: ``elementwise_mean``

            Available reduction methods:
            - elementwise_mean: takes the mean
            - none: no reduction is applied
            - sum: add elements

        data_range: Range of the image. If ``None``, it is determined from the image (max - min)
        k1: Parameter of SSIM. Default: 0.01
        k2: Parameter of SSIM. Default: 0.03

    Returns:
        A Tensor with SSIM

    Raises:
        TypeError: If ``pred`` and ``target`` have different data types
        ValueError: If the shapes differ or are not 4-dimensional, if ``kernel_size``
            or ``sigma`` does not have length two, if a kernel size is not an odd
            positive number, or if a sigma is not positive

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 1.25
        >>> ssim(pred, target)
        tensor(0.9520)
    """

    if pred.dtype != target.dtype:
        raise TypeError(
            "Expected `pred` and `target` to have the same data type."
            f" Got pred: {pred.dtype} and target: {target.dtype}."
        )

    if pred.shape != target.shape:
        raise ValueError(
            "Expected `pred` and `target` to have the same shape."
            f" Got pred: {pred.shape} and target: {target.shape}."
        )

    if len(pred.shape) != 4 or len(target.shape) != 4:
        raise ValueError(
            "Expected `pred` and `target` to have BxCxHxW shape."
            f" Got pred: {pred.shape} and target: {target.shape}."
        )

    if len(kernel_size) != 2 or len(sigma) != 2:
        raise ValueError(
            "Expected `kernel_size` and `sigma` to have the length of two."
            f" Got kernel_size: {len(kernel_size)} and sigma: {len(sigma)}."
        )

    if any(x % 2 == 0 or x <= 0 for x in kernel_size):
        raise ValueError(f"Expected `kernel_size` to have odd positive number. Got {kernel_size}.")

    if any(y <= 0 for y in sigma):
        raise ValueError(f"Expected `sigma` to have positive number. Got {sigma}.")

    if data_range is None:
        # use the wider of the two observed value ranges
        data_range = max(pred.max() - pred.min(), target.max() - target.min())

    # stabilizing constants of the luminance (C1) and contrast/structure (C2) terms
    C1 = pow(k1 * data_range, 2)
    C2 = pow(k2 * data_range, 2)
    device = pred.device

    channel = pred.size(1)
    kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
    if pred.is_floating_point():
        # conv2d needs input and weight of the same dtype; without this cast
        # any floating input other than float32 (e.g. float64) is rejected
        kernel = kernel.to(dtype=pred.dtype)

    # gaussian-weighted local means; groups=channel filters every channel independently
    mu_pred = F.conv2d(pred, kernel, groups=channel)
    mu_target = F.conv2d(target, kernel, groups=channel)

    mu_pred_sq = mu_pred.pow(2)
    mu_target_sq = mu_target.pow(2)
    mu_pred_target = mu_pred * mu_target

    # local (co)variances via E[xy] - E[x]E[y]
    sigma_pred_sq = F.conv2d(pred * pred, kernel, groups=channel) - mu_pred_sq
    sigma_target_sq = F.conv2d(target * target, kernel, groups=channel) - mu_target_sq
    sigma_pred_target = F.conv2d(pred * target, kernel, groups=channel) - mu_pred_target

    upper = 2 * sigma_pred_target + C2
    lower = sigma_pred_sq + sigma_target_sq + C2

    ssim_idx = ((2 * mu_pred_target + C1) * upper) / ((mu_pred_sq + mu_target_sq + C1) * lower)

    return reduce(ssim_idx, reduction)

pytorch_lightning/metrics/regression.py

+63-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1+
from typing import Sequence
2+
13
import torch
24

35
from pytorch_lightning.metrics.functional.regression import (
46
mae,
57
mse,
68
psnr,
79
rmse,
8-
rmsle
10+
rmsle,
11+
ssim
912
)
1013
from pytorch_lightning.metrics.metric import Metric
1114

@@ -229,3 +232,62 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
229232
A Tensor with psnr score.
230233
"""
231234
return psnr(pred, target, self.data_range, self.base, self.reduction)
235+
236+
237+
class SSIM(Metric):
    """
    Computes Structural Similarity Index Measure

    Example:

        >>> pred = torch.rand([16, 1, 16, 16])
        >>> target = pred * 1.25
        >>> metric = SSIM()
        >>> metric(pred, target)
        tensor(0.9520)
    """

    def __init__(
        self,
        kernel_size: Sequence[int] = (11, 11),
        sigma: Sequence[float] = (1.5, 1.5),
        reduction: str = "elementwise_mean",
        data_range: float = None,
        k1: float = 0.01,
        k2: float = 0.03
    ):
        """
        Store the SSIM settings that are forwarded to the functional ``ssim`` on every call.

        Args:
            kernel_size: Size of the gaussian kernel. Default: (11, 11)
            sigma: Standard deviation of the gaussian kernel. Default: (1.5, 1.5)
            reduction: A method for reducing ssim. Default: ``elementwise_mean``

                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: no reduction is applied
                - sum: add elements

            data_range: Range of the image. If ``None``, it is determined from the image (max - min)
            k1: Parameter of SSIM. Default: 0.01
            k2: Parameter of SSIM. Default: 0.03
        """
        super().__init__(name="ssim")
        # gaussian window configuration
        self.kernel_size = kernel_size
        self.sigma = sigma
        # SSIM constants and value range
        self.data_range = data_range
        self.k1 = k1
        self.k2 = k2
        # how the resulting ssim values are aggregated
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: Estimated image
            target: Ground truth image

        Return:
            torch.Tensor: SSIM Score
        """
        return ssim(
            pred,
            target,
            kernel_size=self.kernel_size,
            sigma=self.sigma,
            reduction=self.reduction,
            data_range=self.data_range,
            k1=self.k1,
            k2=self.k2,
        )

tests/metrics/functional/test_regression.py

+50-1
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import torch
33
import numpy as np
44
from skimage.metrics import peak_signal_noise_ratio as ski_psnr
5+
from skimage.metrics import structural_similarity as ski_ssim
56

67
from pytorch_lightning.metrics.functional import (
78
mae,
89
mse,
910
psnr,
1011
rmse,
11-
rmsle
12+
rmsle,
13+
ssim
1214
)
1315

1416

@@ -93,3 +95,50 @@ def test_psnr_against_sklearn(sklearn_metric, torch_metric):
9395
sk_score = torch.tensor(sk_score, dtype=torch.float, device=device)
9496
pl_score = torch_metric(pred, target, data_range=n_cls_target)
9597
assert torch.allclose(sk_score, pl_score)
98+
99+
100+
@pytest.mark.parametrize(['size', 'channel', 'plus', 'multichannel'], [
    pytest.param(16, 1, 0.125, False),
    pytest.param(32, 1, 0.25, False),
    pytest.param(48, 3, 0.5, True),
    pytest.param(64, 4, 0.75, True),
    pytest.param(128, 5, 1, True)
])
def test_ssim(size, channel, plus, multichannel):
    """Compare ``ssim`` with scikit-image on the same image pair and check that ssim(x, x) == 1."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pred = torch.rand(1, channel, size, size, device=device)
    target = pred + plus
    ssim_idx = ssim(pred, target)

    # Build the reference from the *same* image: (C, H, W) -> (H, W, C) as scikit-image expects.
    # Drawing an unrelated random array here would only make the two scores agree statistically.
    np_pred = pred[0].permute(1, 2, 0).cpu().numpy()
    if multichannel is False:
        np_pred = np_pred[:, :, 0]
    np_target = np.add(np_pred, plus)
    sk_ssim_idx = ski_ssim(np_pred, np_target, win_size=11, multichannel=multichannel, gaussian_weights=True)
    assert torch.allclose(ssim_idx, torch.tensor(sk_ssim_idx, dtype=torch.float, device=device), atol=1e-2, rtol=1e-2)

    # an image is perfectly similar to itself
    ssim_idx = ssim(pred, pred)
    assert torch.allclose(ssim_idx, torch.tensor(1.0, device=device))
121+
122+
123+
@pytest.mark.parametrize('pred, target, kernel, sigma', [
    ([1, 1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # shape
    ([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]),  # len(shape)
    ([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]),  # len(kernel), len(sigma)
    ([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]),  # len(kernel), len(sigma)
    ([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]),  # len(kernel), len(sigma)
    ([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]),  # invalid kernel input
    ([1, 1, 16, 16], [1, 1, 16, 16], [11, 10], [1.5, 1.5]),  # invalid kernel input
    ([1, 1, 16, 16], [1, 1, 16, 16], [11, -11], [1.5, 1.5]),  # invalid kernel input
    ([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5, 0]),  # invalid sigma input
    ([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, -1.5]),  # invalid sigma input
])
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
    """``ssim`` rejects mismatched dtypes with TypeError and malformed arguments with ValueError."""
    # ``pred`` and ``target`` arrive as shapes; the tensors are created from them below
    float32_pred = torch.rand(pred)
    float64_target = torch.rand(target, dtype=torch.float64)
    with pytest.raises(TypeError):
        ssim(float32_pred, float64_target)

    pred_img = torch.rand(pred)
    target_img = torch.rand(target)
    with pytest.raises(ValueError):
        ssim(pred_img, target_img, kernel, sigma)

tests/metrics/test_regression.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from skimage.metrics import peak_signal_noise_ratio as ski_psnr
77

88
from pytorch_lightning.metrics.regression import (
9-
MAE, MSE, RMSE, RMSLE, PSNR
9+
MAE, MSE, RMSE, RMSLE, PSNR, SSIM
1010
)
1111

1212

@@ -58,3 +58,13 @@ def test_psnr():
5858
target = torch.tensor([0., 1, 2, 2])
5959
score = psnr(pred, target)
6060
assert isinstance(score, torch.Tensor)
61+
62+
63+
def test_ssim():
    """The ``SSIM`` metric is registered under the name 'ssim' and returns a tensor."""
    metric = SSIM()
    assert metric.name == 'ssim'

    pred = torch.rand([16, 1, 16, 16])
    target = pred * 1.25
    score = metric(pred, target)
    assert isinstance(score, torch.Tensor)

0 commit comments

Comments
 (0)