-
Notifications
You must be signed in to change notification settings - Fork 387
/
_deprecated.py
122 lines (103 loc) · 4.37 KB
/
_deprecated.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from typing import Any, Callable, Optional, Tuple
from torch import Tensor
from typing_extensions import Literal
from torchmetrics.functional.audio.pit import permutation_invariant_training, pit_permutate
from torchmetrics.functional.audio.sdr import scale_invariant_signal_distortion_ratio, signal_distortion_ratio
from torchmetrics.functional.audio.snr import scale_invariant_signal_noise_ratio, signal_noise_ratio
from torchmetrics.utilities.prints import _deprecated_root_import_func
def _permutation_invariant_training(
    preds: Tensor,
    target: Tensor,
    metric_func: Callable,
    mode: Literal["speaker-wise", "permutation-wise"] = "speaker-wise",
    eval_func: Literal["max", "min"] = "max",
    **kwargs: Any
) -> Tuple[Tensor, Tensor]:
    """Deprecated-import shim around ``permutation_invariant_training``.

    Calls ``_deprecated_root_import_func("permutation_invariant_training", "audio")``
    (presumably to notify the caller about the deprecated import path) and then
    forwards every argument, by keyword, to ``permutation_invariant_training``.

    Args:
        preds: forwarded unchanged as ``preds``.
        target: forwarded unchanged as ``target``.
        metric_func: forwarded unchanged as ``metric_func``.
        mode: forwarded unchanged as ``mode``.
        eval_func: forwarded unchanged as ``eval_func``.
        kwargs: extra keyword arguments forwarded unchanged.

    Returns:
        Whatever ``permutation_invariant_training`` returns (annotated as a pair of tensors).

    >>> from torch import tensor
    >>> preds = tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
    >>> target = tensor([[[ 1.0958, -0.1648, 0.5228], [-0.4100, 1.1942, -0.5103]]])
    >>> best_metric, best_perm = _permutation_invariant_training(
    ...     preds, target, _scale_invariant_signal_distortion_ratio)
    >>> best_metric
    tensor([-5.1091])
    >>> best_perm
    tensor([[0, 1]])
    >>> pit_permutate(preds, best_perm)
    tensor([[[-0.0579,  0.3560, -0.9604],
             [-0.1719,  0.3205,  0.2951]]])

    """
    # The deprecation hook must run before the real implementation is invoked.
    _deprecated_root_import_func("permutation_invariant_training", "audio")
    result = permutation_invariant_training(
        preds=preds,
        target=target,
        metric_func=metric_func,
        mode=mode,
        eval_func=eval_func,
        **kwargs,
    )
    return result
def _pit_permutate(preds: Tensor, perm: Tensor) -> Tensor:
    """Deprecated-import shim around ``pit_permutate``.

    Calls ``_deprecated_root_import_func("pit_permutate", "audio")`` (presumably to
    notify the caller about the deprecated import path) and then forwards ``preds``
    and ``perm`` by keyword to ``pit_permutate``, returning its result.
    """
    # The deprecation hook must run before the real implementation is invoked.
    _deprecated_root_import_func("pit_permutate", "audio")
    permuted = pit_permutate(preds=preds, perm=perm)
    return permuted
def _scale_invariant_signal_distortion_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
    """Deprecated-import shim around ``scale_invariant_signal_distortion_ratio``.

    Calls ``_deprecated_root_import_func("scale_invariant_signal_distortion_ratio", "audio")``
    (presumably to notify the caller about the deprecated import path) and then
    forwards ``preds``, ``target`` and ``zero_mean`` by keyword to
    ``scale_invariant_signal_distortion_ratio``, returning its result.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> _scale_invariant_signal_distortion_ratio(preds, target)
    tensor(18.4030)

    """
    # The deprecation hook must run before the real implementation is invoked.
    _deprecated_root_import_func("scale_invariant_signal_distortion_ratio", "audio")
    si_sdr = scale_invariant_signal_distortion_ratio(
        preds=preds,
        target=target,
        zero_mean=zero_mean,
    )
    return si_sdr
def _signal_distortion_ratio(
    preds: Tensor,
    target: Tensor,
    use_cg_iter: Optional[int] = None,
    filter_length: int = 512,
    zero_mean: bool = False,
    load_diag: Optional[float] = None,
) -> Tensor:
    """Deprecated-import shim around ``signal_distortion_ratio``.

    Calls ``_deprecated_root_import_func("signal_distortion_ratio", "audio")``
    (presumably to notify the caller about the deprecated import path) and then
    forwards every argument, by keyword, to ``signal_distortion_ratio``.

    Args:
        preds: forwarded unchanged as ``preds``.
        target: forwarded unchanged as ``target``.
        use_cg_iter: forwarded unchanged as ``use_cg_iter``.
        filter_length: forwarded unchanged as ``filter_length``.
        zero_mean: forwarded unchanged as ``zero_mean``.
        load_diag: forwarded unchanged as ``load_diag``.

    Returns:
        Whatever ``signal_distortion_ratio`` returns.

    >>> import torch
    >>> g = torch.manual_seed(1)
    >>> preds = torch.randn(8000)
    >>> target = torch.randn(8000)
    >>> _signal_distortion_ratio(preds, target)
    tensor(-12.0589)
    >>> # use with permutation_invariant_training
    >>> preds = torch.randn(4, 2, 8000)  # [batch, spk, time]
    >>> target = torch.randn(4, 2, 8000)
    >>> best_metric, best_perm = _permutation_invariant_training(preds, target, _signal_distortion_ratio)
    >>> best_metric
    tensor([-11.6375, -11.4358, -11.7148, -11.6325])
    >>> best_perm
    tensor([[1, 0],
            [0, 1],
            [1, 0],
            [0, 1]])

    """
    # The deprecation hook must run before the real implementation is invoked.
    _deprecated_root_import_func("signal_distortion_ratio", "audio")
    sdr = signal_distortion_ratio(
        preds=preds,
        target=target,
        use_cg_iter=use_cg_iter,
        filter_length=filter_length,
        zero_mean=zero_mean,
        load_diag=load_diag,
    )
    return sdr
def _scale_invariant_signal_noise_ratio(preds: Tensor, target: Tensor) -> Tensor:
    """Deprecated-import shim around ``scale_invariant_signal_noise_ratio``.

    Calls ``_deprecated_root_import_func("scale_invariant_signal_noise_ratio", "audio")``
    (presumably to notify the caller about the deprecated import path) and then
    forwards ``preds`` and ``target`` by keyword to
    ``scale_invariant_signal_noise_ratio``, returning its result.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> _scale_invariant_signal_noise_ratio(preds, target)
    tensor(15.0918)

    """
    # The deprecation hook must run before the real implementation is invoked.
    _deprecated_root_import_func("scale_invariant_signal_noise_ratio", "audio")
    si_snr = scale_invariant_signal_noise_ratio(preds=preds, target=target)
    return si_snr
def _signal_noise_ratio(preds: Tensor, target: Tensor, zero_mean: bool = False) -> Tensor:
    """Deprecated-import shim around ``signal_noise_ratio``.

    Calls ``_deprecated_root_import_func("signal_noise_ratio", "audio")`` (presumably
    to notify the caller about the deprecated import path) and then forwards
    ``preds``, ``target`` and ``zero_mean`` by keyword to ``signal_noise_ratio``,
    returning its result.

    >>> from torch import tensor
    >>> target = tensor([3.0, -0.5, 2.0, 7.0])
    >>> preds = tensor([2.5, 0.0, 2.0, 8.0])
    >>> _signal_noise_ratio(preds, target)
    tensor(16.1805)

    """
    # The deprecation hook must run before the real implementation is invoked.
    _deprecated_root_import_func("signal_noise_ratio", "audio")
    snr = signal_noise_ratio(
        preds=preds,
        target=target,
        zero_mean=zero_mean,
    )
    return snr