-
Notifications
You must be signed in to change notification settings - Fork 387
/
_deprecated.py
140 lines (111 loc) · 5.28 KB
/
_deprecated.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
from typing import Optional, Tuple
from torch import Tensor
from torchmetrics.functional.retrieval.average_precision import retrieval_average_precision
from torchmetrics.functional.retrieval.fall_out import retrieval_fall_out
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
from torchmetrics.utilities.prints import _deprecated_root_import_func
def _retrieval_average_precision(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Compute retrieval average precision through the deprecated root-level import path.

    First emits the root-import deprecation notice via ``_deprecated_root_import_func``, then forwards
    every argument unchanged to ``retrieval_average_precision``.

    Args:
        preds: prediction scores, passed through as ``preds``
        target: relevance labels, passed through as ``target``
        top_k: optional cut-off, passed through as ``top_k``

    Returns:
        Whatever ``retrieval_average_precision`` returns (a scalar tensor in the example below).

    Example:
        >>> from torch import tensor
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> _retrieval_average_precision(preds, target)
        tensor(0.8333)

    """
    # The notice is issued before computing, so it fires even if the metric call raises.
    _deprecated_root_import_func("retrieval_average_precision", "retrieval")
    average_precision = retrieval_average_precision(top_k=top_k, target=target, preds=preds)
    return average_precision
def _retrieval_fall_out(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Compute retrieval fall-out through the deprecated root-level import path.

    First emits the root-import deprecation notice via ``_deprecated_root_import_func``, then forwards
    every argument unchanged to ``retrieval_fall_out``.

    Args:
        preds: prediction scores, passed through as ``preds``
        target: relevance labels, passed through as ``target``
        top_k: optional cut-off, passed through as ``top_k``

    Returns:
        Whatever ``retrieval_fall_out`` returns (a scalar tensor in the example below).

    Example:
        >>> from torch import tensor
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> _retrieval_fall_out(preds, target, top_k=2)
        tensor(1.)

    """
    # The notice is issued before computing, so it fires even if the metric call raises.
    _deprecated_root_import_func("retrieval_fall_out", "retrieval")
    fall_out = retrieval_fall_out(top_k=top_k, target=target, preds=preds)
    return fall_out
def _retrieval_hit_rate(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Compute retrieval hit rate through the deprecated root-level import path.

    First emits the root-import deprecation notice via ``_deprecated_root_import_func``, then forwards
    every argument unchanged to ``retrieval_hit_rate``.

    Args:
        preds: prediction scores, passed through as ``preds``
        target: relevance labels, passed through as ``target``
        top_k: optional cut-off, passed through as ``top_k``

    Returns:
        Whatever ``retrieval_hit_rate`` returns (a scalar tensor in the example below).

    Example:
        >>> from torch import tensor
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> _retrieval_hit_rate(preds, target, top_k=2)
        tensor(1.)

    """
    # The notice is issued before computing, so it fires even if the metric call raises.
    _deprecated_root_import_func("retrieval_hit_rate", "retrieval")
    hit_rate = retrieval_hit_rate(top_k=top_k, target=target, preds=preds)
    return hit_rate
def _retrieval_normalized_dcg(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Compute retrieval normalized DCG through the deprecated root-level import path.

    First emits the root-import deprecation notice via ``_deprecated_root_import_func``, then forwards
    every argument unchanged to ``retrieval_normalized_dcg``.

    Args:
        preds: prediction scores, passed through as ``preds``
        target: relevance values (graded integers in the example below), passed through as ``target``
        top_k: optional cut-off, passed through as ``top_k``

    Returns:
        Whatever ``retrieval_normalized_dcg`` returns (a scalar tensor in the example below).

    Example:
        >>> from torch import tensor
        >>> preds = tensor([.1, .2, .3, 4, 70])
        >>> target = tensor([10, 0, 0, 1, 5])
        >>> _retrieval_normalized_dcg(preds, target)
        tensor(0.6957)

    """
    # The notice is issued before computing, so it fires even if the metric call raises.
    _deprecated_root_import_func("retrieval_normalized_dcg", "retrieval")
    ndcg = retrieval_normalized_dcg(top_k=top_k, target=target, preds=preds)
    return ndcg
def _retrieval_precision(
    preds: Tensor, target: Tensor, top_k: Optional[int] = None, adaptive_k: bool = False
) -> Tensor:
    """Compute retrieval precision through the deprecated root-level import path.

    First emits the root-import deprecation notice via ``_deprecated_root_import_func``, then forwards
    every argument unchanged to ``retrieval_precision``.

    Args:
        preds: prediction scores, passed through as ``preds``
        target: relevance labels, passed through as ``target``
        top_k: optional cut-off, passed through as ``top_k``
        adaptive_k: flag passed through as ``adaptive_k``; its meaning is defined by ``retrieval_precision``

    Returns:
        Whatever ``retrieval_precision`` returns (a scalar tensor in the example below).

    Example:
        >>> from torch import tensor
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> _retrieval_precision(preds, target, top_k=2)
        tensor(0.5000)

    """
    # The notice is issued before computing, so it fires even if the metric call raises.
    _deprecated_root_import_func("retrieval_precision", "retrieval")
    precision = retrieval_precision(adaptive_k=adaptive_k, top_k=top_k, target=target, preds=preds)
    return precision
def _retrieval_precision_recall_curve(
    preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
    """Compute the retrieval precision-recall curve through the deprecated root-level import path.

    First emits the root-import deprecation notice via ``_deprecated_root_import_func``, then forwards
    every argument unchanged to ``retrieval_precision_recall_curve``.

    Args:
        preds: prediction scores, passed through as ``preds``
        target: relevance labels, passed through as ``target``
        max_k: optional largest cut-off, passed through as ``max_k``
        adaptive_k: flag passed through as ``adaptive_k``; its meaning is defined by
            ``retrieval_precision_recall_curve``

    Returns:
        The object returned by ``retrieval_precision_recall_curve``, unchanged. In the example below it
        unpacks into precisions, recalls and the evaluated ``k`` values.

    Example:
        >>> from torch import tensor
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> precisions, recalls, top_k = _retrieval_precision_recall_curve(preds, target, max_k=2)
        >>> precisions
        tensor([1.0000, 0.5000])
        >>> recalls
        tensor([0.5000, 0.5000])
        >>> top_k
        tensor([1, 2])

    """
    # The notice is issued before computing, so it fires even if the metric call raises.
    _deprecated_root_import_func("retrieval_precision_recall_curve", "retrieval")
    # Return the callee's result as-is (not unpacked and re-packed) so its exact type is preserved.
    curve = retrieval_precision_recall_curve(adaptive_k=adaptive_k, max_k=max_k, target=target, preds=preds)
    return curve
def _retrieval_r_precision(preds: Tensor, target: Tensor) -> Tensor:
    """Compute retrieval R-precision through the deprecated root-level import path.

    First emits the root-import deprecation notice via ``_deprecated_root_import_func``, then forwards
    both arguments unchanged to ``retrieval_r_precision``.

    Args:
        preds: prediction scores, passed through as ``preds``
        target: relevance labels, passed through as ``target``

    Returns:
        Whatever ``retrieval_r_precision`` returns (a scalar tensor in the example below).

    Example:
        >>> from torch import tensor
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> _retrieval_r_precision(preds, target)
        tensor(0.5000)

    """
    # The notice is issued before computing, so it fires even if the metric call raises.
    _deprecated_root_import_func("retrieval_r_precision", "retrieval")
    r_precision = retrieval_r_precision(target=target, preds=preds)
    return r_precision
def _retrieval_recall(preds: Tensor, target: Tensor, top_k: Optional[int] = None) -> Tensor:
    """Compute retrieval recall through the deprecated root-level import path.

    First emits the root-import deprecation notice via ``_deprecated_root_import_func``, then forwards
    every argument unchanged to ``retrieval_recall``.

    Args:
        preds: prediction scores, passed through as ``preds``
        target: relevance labels, passed through as ``target``
        top_k: optional cut-off, passed through as ``top_k``

    Returns:
        Whatever ``retrieval_recall`` returns (a scalar tensor in the example below).

    Example:
        >>> from torch import tensor
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([True, False, True])
        >>> _retrieval_recall(preds, target, top_k=2)
        tensor(0.5000)

    """
    # The notice is issued before computing, so it fires even if the metric call raises.
    _deprecated_root_import_func("retrieval_recall", "retrieval")
    recall = retrieval_recall(top_k=top_k, target=target, preds=preds)
    return recall
def _retrieval_reciprocal_rank(preds: Tensor, target: Tensor) -> Tensor:
    """Compute retrieval reciprocal rank through the deprecated root-level import path.

    First emits the root-import deprecation notice via ``_deprecated_root_import_func``, then forwards
    both arguments unchanged to ``retrieval_reciprocal_rank``.

    Args:
        preds: prediction scores, passed through as ``preds``
        target: relevance labels, passed through as ``target``

    Returns:
        Whatever ``retrieval_reciprocal_rank`` returns (a scalar tensor in the example below).

    Example:
        >>> from torch import tensor
        >>> preds = tensor([0.2, 0.3, 0.5])
        >>> target = tensor([False, True, False])
        >>> _retrieval_reciprocal_rank(preds, target)
        tensor(0.5000)

    """
    # The notice is issued before computing, so it fires even if the metric call raises.
    _deprecated_root_import_func("retrieval_reciprocal_rank", "retrieval")
    reciprocal_rank = retrieval_reciprocal_rank(target=target, preds=preds)
    return reciprocal_rank