-
Notifications
You must be signed in to change notification settings - Fork 387
/
snr.py
132 lines (101 loc) · 4.92 KB
/
snr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio
from torchmetrics.utilities.checks import _check_same_shape
def signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
    r"""Calculate `Signal-to-noise ratio`_ (SNR_) metric for evaluating quality of audio.

    .. math::
        \text{SNR} = 10 \log_{10}\left(\frac{P_{signal}}{P_{noise}}\right)

    where :math:`P` denotes the power of each signal: :math:`P_{signal}` is taken from ``target`` and
    :math:`P_{noise}` from the residual ``target - preds`` (both as sums of squares over the last axis).
    The result is expressed in decibels. The SNR metric compares the level of the desired signal to the
    level of background noise. Therefore, a high value of SNR means that the audio is clear.

    Args:
        preds: float tensor with shape ``(...,time)``
        target: float tensor with shape ``(...,time)``
        zero_mean: if to zero mean target and preds or not

    Returns:
        Float tensor with shape ``(...,)`` of SNR values (in dB) per sample

    Raises:
        RuntimeError:
            If ``preds`` and ``target`` do not have the same shape

    Example:
        >>> import torch
        >>> from torchmetrics.functional.audio import signal_noise_ratio
        >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
        >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
        >>> signal_noise_ratio(preds, target)
        tensor(16.1805)

    """
    _check_same_shape(preds, target)
    # Machine epsilon of the input dtype, added to numerator and denominator so that a silent
    # target or a perfect prediction does not lead to log10(0) or a division by zero.
    eps = torch.finfo(preds.dtype).eps
    if zero_mean:
        target = target - torch.mean(target, dim=-1, keepdim=True)
        preds = preds - torch.mean(preds, dim=-1, keepdim=True)
    noise = target - preds
    snr_value = (torch.sum(target**2, dim=-1) + eps) / (torch.sum(noise**2, dim=-1) + eps)
    # Convert the power ratio to decibels.
    return 10 * torch.log10(snr_value)
def scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
    """`Scale-invariant signal-to-noise ratio`_ (SI-SNR).

    SI-SNR is evaluated as the scale-invariant signal-to-distortion ratio (SI-SDR) of the inputs with
    ``zero_mean=True``, i.e. with the mean removed from both signals before the metric is computed.

    Args:
        preds: float tensor with shape ``(...,time)``
        target: float tensor with shape ``(...,time)``

    Returns:
        Float tensor with shape ``(...,)`` of SI-SNR values per sample

    Raises:
        RuntimeError:
            If ``preds`` and ``target`` does not have the same shape

    Example:
        >>> import torch
        >>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
        >>> target = torch.tensor([3.0, -0.5, 2.0, 7.0])
        >>> preds = torch.tensor([2.5, 0.0, 2.0, 8.0])
        >>> scale_invariant_signal_noise_ratio(preds, target)
        tensor(15.0918)

    """
    # SI-SNR is SI-SDR on zero-mean signals, so the work is delegated to the SI-SDR implementation.
    si_snr = scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=True)
    return si_snr
def complex_scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
    """`Complex scale-invariant signal-to-noise ratio`_ (C-SI-SNR).

    Complex inputs are first converted to their real view, with a trailing axis of size 2 holding the
    real and imaginary parts. The ``(frequency, time, 2)`` axes are then flattened into a single axis, and
    the real-valued scale-invariant signal-to-distortion ratio is computed over it.

    Args:
        preds: real float tensor with shape ``(...,frequency,time,2)`` or complex float tensor with
            shape ``(..., frequency,time)``
        target: real float tensor with shape ``(...,frequency,time,2)`` or complex float tensor with
            shape ``(..., frequency,time)``
        zero_mean: When set to True, the mean of all signals is subtracted prior to computation of the metrics

    Returns:
        Float tensor with shape ``(...,)`` of C-SI-SNR values per sample

    Raises:
        RuntimeError:
            If ``preds`` or ``target`` is not of shape ``(..., frequency, time, 2)`` (after being converted
            to real if it is complex).
            If ``preds`` and ``target`` do not have the same shape.

    Example:
        >>> import torch
        >>> from torchmetrics.functional.audio import complex_scale_invariant_signal_noise_ratio
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn((1,257,100,2))
        >>> target = torch.randn((1,257,100,2))
        >>> complex_scale_invariant_signal_noise_ratio(preds, target)
        tensor([-63.4849])

    """
    # view_as_real appends a trailing axis of size 2 holding (real, imag).
    if preds.is_complex():
        preds = torch.view_as_real(preds)
    if target.is_complex():
        target = torch.view_as_real(target)

    if (preds.ndim < 3 or preds.shape[-1] != 2) or (target.ndim < 3 or target.shape[-1] != 2):
        raise RuntimeError(
            "Predictions and targets are expected to have the shape (..., frequency, time, 2),"
            f" but got {preds.shape} and {target.shape}."
        )

    # Compare the full shapes *before* flattening. For example, (F, T, 2) and (T, F, 2) flatten to the
    # same length, so any shape comparison made after the reshape below would not detect the mismatch.
    _check_same_shape(preds, target)

    # Merge the (frequency, time, 2) axes into one so the real-valued SI-SDR treats every bin as one signal.
    preds = preds.reshape(*preds.shape[:-3], -1)
    target = target.reshape(*target.shape[:-3], -1)
    return scale_invariant_signal_distortion_ratio(preds=preds, target=target, zero_mean=zero_mean)