-
Notifications
You must be signed in to change notification settings - Fork 388
/
pesq.py
131 lines (111 loc) · 4.6 KB
/
pesq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from deprecate import deprecated, void
from torchmetrics.utilities.imports import _PESQ_AVAILABLE
if _PESQ_AVAILABLE:
import pesq as pesq_backend
else:
pesq_backend = None
import torch
from torch import Tensor
from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.checks import _check_same_shape
# Doctest gating table (``__doctest_requires__`` convention of pytest-doctestplus): the key names the
# objects in this module that carry doctests, the value lists the packages those doctests need.
__doctest_requires__ = {("perceptual_evaluation_speech_quality", "pesq"): ["pesq"]}
def perceptual_evaluation_speech_quality(
    preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bool = False
) -> Tensor:
    r"""Compute PESQ (Perceptual Evaluation of Speech Quality) of ``preds`` against ``target``.

    This is a wrapper for the ``pesq`` package [1]. The inputs are detached, moved to `cpu` and converted
    to NumPy arrays before scoring; every signal along the last dimension is scored separately.

    .. note:: using this metric requires you to have ``pesq`` installed. Either install as ``pip install
        torchmetrics[audio]`` or ``pip install pesq``

    Args:
        preds:
            shape ``[...,time]``
        target:
            shape ``[...,time]``
        fs:
            sampling frequency, should be 16000 or 8000 (Hz)
        mode:
            'wb' (wide-band) or 'nb' (narrow-band)
        keep_same_device:
            whether to move the pesq value to the device of preds

    Returns:
        pesq value of shape [...]

    Raises:
        ModuleNotFoundError:
            If ``pesq`` package is not installed
        ValueError:
            If ``fs`` is not either ``8000`` or ``16000``
        ValueError:
            If ``mode`` is not either ``"wb"`` or ``"nb"``

    Example:
        >>> from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
        >>> import torch
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn(8000)
        >>> target = torch.randn(8000)
        >>> perceptual_evaluation_speech_quality(preds, target, 8000, 'nb')
        tensor(2.2076)
        >>> perceptual_evaluation_speech_quality(preds, target, 16000, 'wb')
        tensor(1.7359)

    References:
        [1] https://github.com/ludlows/python-pesq
    """
    # Argument validation, in order: backend availability, sampling rate, band mode, matching shapes.
    if not _PESQ_AVAILABLE:
        install_hint = (
            "PESQ metric requires that pesq is installed."
            " Either install as `pip install torchmetrics[audio]` or `pip install pesq`."
        )
        raise ModuleNotFoundError(install_hint)
    if fs not in (8000, 16000):
        raise ValueError(f"Expected argument `fs` to either be 8000 or 16000 but got {fs}")
    if mode not in ("wb", "nb"):
        raise ValueError(f"Expected argument `mode` to either be 'wb' or 'nb' but got {mode}")
    _check_same_shape(preds, target)

    if preds.ndim == 1:
        # A single signal: the backend is called once, with the reference (target) before the estimate.
        score = pesq_backend.pesq(fs, target.detach().cpu().numpy(), preds.detach().cpu().numpy(), mode)
        pesq_val = torch.tensor(score)
    else:
        # Collapse all leading dimensions into one batch axis so each row is one signal.
        num_frames = preds.shape[-1]
        degraded_rows = preds.reshape(-1, num_frames).detach().cpu().numpy()
        reference_rows = target.reshape(-1, num_frames).detach().cpu().numpy()
        # Per-row scores are gathered in a float64 array before being turned into a tensor.
        row_scores = np.array(
            [pesq_backend.pesq(fs, ref, deg, mode) for ref, deg in zip(reference_rows, degraded_rows)],
            dtype=np.float64,
        )
        # Restore the original leading dimensions (everything except the time axis).
        pesq_val = torch.from_numpy(row_scores).reshape(preds.shape[:-1])

    return pesq_val.to(preds.device) if keep_same_device else pesq_val
@deprecated(
    target=perceptual_evaluation_speech_quality,
    deprecated_in="0.7",
    remove_in="0.8",
    stream=_future_warning,
)
def pesq(preds: Tensor, target: Tensor, fs: int, mode: str, keep_same_device: bool = False) -> Tensor:
    r"""Deprecated alias of PESQ (Perceptual Evaluation of Speech Quality).

    The signature mirrors :func:`perceptual_evaluation_speech_quality`, which the ``deprecated``
    decorator is configured with as its ``target``.

    .. deprecated:: v0.7
        Use :func:`torchmetrics.functional.audio.perceptual_evaluation_speech_quality`. Will be removed in v0.8.

    Example:
        >>> import torch
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn(8000)
        >>> target = torch.randn(8000)
        >>> pesq(preds, target, 8000, 'nb')
        tensor(2.2076)
        >>> pesq(preds, target, 16000, 'wb')
        tensor(1.7359)
    """
    # NOTE(review): ``void`` comes from the ``deprecate`` package; presumably it only references the
    # arguments so they do not look unused while the decorator handles the call — confirm in its docs.
    return void(preds, target, fs, mode, keep_same_device)