-
Notifications
You must be signed in to change notification settings - Fork 387
/
pesq.py
168 lines (143 loc) · 6 KB
/
pesq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
from deprecate import deprecated, void
from torch import Tensor, tensor
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality
from torchmetrics.metric import Metric
from torchmetrics.utilities import _future_warning
from torchmetrics.utilities.imports import _PESQ_AVAILABLE
# The doctests of both the metric and its deprecated alias need the optional
# ``pesq`` backend. Presumably this mapping is read by the doctest tooling
# (e.g. pytest-doctestplus) to skip those doctests when ``pesq`` is missing.
__doctest_requires__ = {
    ("PerceptualEvaluationSpeechQuality", "PESQ"): ["pesq"],
}
class PerceptualEvaluationSpeechQuality(Metric):
    """Perceptual Evaluation of Speech Quality (PESQ).

    This is a wrapper for the ``pesq`` package [1]. Note that the inputs will be moved to ``cpu``
    to perform the metric calculation. The per-batch scores are moved back to the device of the
    metric states before they are accumulated.

    .. note:: using this metric requires you to have ``pesq`` installed. Either install as ``pip install
        torchmetrics[audio]`` or ``pip install pesq``

    Forward accepts

    - ``preds``: ``shape [...,time]``
    - ``target``: ``shape [...,time]``

    Args:
        fs:
            sampling frequency, should be 16000 or 8000 (Hz)
        mode:
            ``'wb'`` (wide-band) or ``'nb'`` (narrow-band)
        compute_on_step:
            Forward only calls ``update()`` and returns ``None`` if this is set to ``False``.
        dist_sync_on_step:
            Synchronize metric state across processes at each ``forward()``
            before returning the value at the step
        process_group:
            Specify the process group on which synchronization is called.
            default: ``None`` (which selects the entire world)
        dist_sync_fn:
            Callback that performs the allgather operation on the metric state. When ``None``, DDP
            will be used to perform the allgather

    Raises:
        ModuleNotFoundError:
            If ``pesq`` package is not installed
        ValueError:
            If ``fs`` is not either ``8000`` or ``16000``
        ValueError:
            If ``mode`` is not either ``"wb"`` or ``"nb"``

    Example:
        >>> from torchmetrics.audio.pesq import PerceptualEvaluationSpeechQuality
        >>> import torch
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn(8000)
        >>> target = torch.randn(8000)
        >>> nb_pesq = PerceptualEvaluationSpeechQuality(8000, 'nb')
        >>> nb_pesq(preds, target)
        tensor(2.2076)
        >>> wb_pesq = PerceptualEvaluationSpeechQuality(16000, 'wb')
        >>> wb_pesq(preds, target)
        tensor(1.7359)

    References:
        [1] https://github.com/ludlows/python-pesq
    """

    sum_pesq: Tensor
    total: Tensor
    is_differentiable = False
    higher_is_better = True

    def __init__(
        self,
        fs: int,
        mode: str,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
    ) -> None:
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        if not _PESQ_AVAILABLE:
            raise ModuleNotFoundError(
                "PerceptualEvaluationSpeechQuality metric requires that `pesq` is installed."
                " Either install as `pip install torchmetrics[audio]` or `pip install pesq`."
            )
        if fs not in (8000, 16000):
            raise ValueError(f"Expected argument `fs` to either be 8000 or 16000 but got {fs}")
        self.fs = fs
        if mode not in ("wb", "nb"):
            raise ValueError(f"Expected argument `mode` to either be 'wb' or 'nb' but got {mode}")
        self.mode = mode

        # Running sum of every per-sample PESQ score seen so far (summed across processes).
        self.add_state("sum_pesq", default=tensor(0.0), dist_reduce_fx="sum")
        # Number of PESQ scores that contributed to ``sum_pesq`` (summed across processes).
        self.add_state("total", default=tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update state with predictions and targets.

        Args:
            preds: Predictions from model
            target: Ground truth values
        """
        # The trailing ``False`` is presumably the functional's ``keep_same_device`` flag. Rather than
        # following ``preds``, the scores are moved explicitly to the device of the metric state, so
        # the in-place accumulation below never mixes devices.
        pesq_batch = perceptual_evaluation_speech_quality(preds, target, self.fs, self.mode, False).to(
            self.sum_pesq.device
        )
        # Every element of ``pesq_batch`` is one score, so the mean can be formed in ``compute``.
        self.sum_pesq += pesq_batch.sum()
        self.total += pesq_batch.numel()

    def compute(self) -> Tensor:
        """Computes average PESQ.

        Returns:
            ``sum_pesq / total``, the mean of all accumulated scores. If nothing has been
            accumulated, ``total`` is 0 and the division yields NaN.
        """
        return self.sum_pesq / self.total
class PESQ(PerceptualEvaluationSpeechQuality):
    """Perceptual Evaluation of Speech Quality (PESQ).

    .. deprecated:: v0.7
        Use :class:`torchmetrics.audio.PerceptualEvaluationSpeechQuality`. Will be removed in v0.8.

    Accepts the same constructor arguments, in the same order, as
    :class:`PerceptualEvaluationSpeechQuality`.

    Example:
        >>> import torch
        >>> g = torch.manual_seed(1)
        >>> preds = torch.randn(8000)
        >>> target = torch.randn(8000)
        >>> nb_pesq = PESQ(8000, 'nb')
        >>> nb_pesq(preds, target)
        tensor(2.2076)
        >>> wb_pesq = PESQ(16000, 'wb')
        >>> wb_pesq(preds, target)
        tensor(1.7359)
    """

    # The ``deprecated`` decorator is configured with ``target=PerceptualEvaluationSpeechQuality``,
    # ``deprecated_in="0.7"``, ``remove_in="0.8"`` and ``stream=_future_warning``. Presumably it
    # forwards these arguments to the target's initializer and emits the warning on that stream.
    # The body below performs no initialization of its own.
    # NOTE(review): the decorator likely maps arguments by name, so keep this signature identical
    # to ``PerceptualEvaluationSpeechQuality.__init__`` — confirm against pyDeprecate docs.
    @deprecated(target=PerceptualEvaluationSpeechQuality, deprecated_in="0.7", remove_in="0.8", stream=_future_warning)
    def __init__(
        self,
        fs: int,
        mode: str,
        compute_on_step: bool = True,
        dist_sync_on_step: bool = False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Optional[Callable[[Tensor], Tensor]] = None,
    ) -> None:
        # ``void`` only consumes the arguments (presumably to silence unused-argument warnings).
        void(fs, mode, compute_on_step, dist_sync_on_step, process_group, dist_sync_fn)