-
Notifications
You must be signed in to change notification settings - Fork 0
/
functional.py
203 lines (167 loc) · 6.62 KB
/
functional.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import torch
import torch.nn.functional as F
from torch.types import _dtype as DType
from typing import Any, Callable, Optional, Union, List, Tuple
# Public API of this module.
__all__ = [
    'masked_softmax',
    'masked_sum',
    'masked_mean',
    'masked_max',
    'masked_min',
]
def _fill_with_mask(input: torch.Tensor, mask: torch.Tensor, fill_value) -> torch.Tensor:
inverted_mask = (1.0 - mask.float()).bool()
return input.masked_fill(inverted_mask, fill_value)
def _call_torch(func: Callable, **kwargs) -> Any:
if kwargs["dim"] is None:
kwargs.pop("dim")
kwargs.pop("keepdim")
return func(**kwargs)
def masked_softmax(
    input: torch.Tensor,
    mask: torch.Tensor,
    dim: Optional[int] = None,
    _stacklevel: int = 3,
    dtype: Optional[DType] = None
) -> torch.Tensor:
    """
    Softmax over ``input`` that assigns zero probability to masked positions.

    Masked (0) positions are set to ``-inf`` before calling
    ``torch.nn.functional.softmax``.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    mask : torch.Tensor
        A 0-1 mask broadcastable with ``input``, 0 for masked tokens and
        1 for unmasked ones. If ``None``, a plain softmax is returned.
    dim : int
        Dimension along which softmax is computed.
    dtype : torch.dtype, optional
        Desired data type of the returned tensor.
    """
    if mask is not None:
        # NOTE(review): a slice whose mask is all zero becomes all -inf and
        # softmax yields NaN there — unchanged from the original behavior.
        input = _fill_with_mask(input, mask, -float('inf'))
    return F.softmax(input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
def masked_sum(
    input: torch.Tensor,
    mask: torch.Tensor,
    dim: Optional[Union[int, List[int], Tuple[int]]] = None,
    keepdim: bool = False,
    dtype: Optional[DType] = None
) -> torch.Tensor:
    """
    Sum of ``input`` in which masked-out elements contribute nothing.

    Masked (0) positions are zeroed before the reduction.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    mask : torch.Tensor
        A 0-1 mask broadcastable with ``input``, 0 for masked tokens and
        1 for unmasked ones. If ``None``, a plain ``sum`` is computed.
    dim : int or List[int] or Tuple[int], optional
        Dimension(s) to reduce; when omitted, all elements are reduced.
    keepdim : bool, optional, default=False
        Whether the reduced dimension(s) are retained with size 1.
    dtype : torch.dtype, optional
        Desired data type of the returned tensor.
    """
    target = input if mask is None else _fill_with_mask(input, mask, 0.)
    return _call_torch(target.sum, dim=dim, keepdim=keepdim, dtype=dtype)
def masked_mean(
    input: torch.Tensor,
    mask: torch.Tensor,
    dim: Optional[Union[int, List[int], Tuple[int]]] = None,
    keepdim: bool = False,
    dtype: Optional[DType] = None
) -> torch.Tensor:
    """
    Mean of ``input`` computed over unmasked elements only.

    The masked sum is divided by the number of unmasked elements.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    mask : torch.Tensor
        A 0-1 mask broadcastable with ``input``, 0 for masked tokens and
        1 for unmasked ones. If ``None``, a plain ``mean`` is computed.
    dim : int or List[int] or Tuple[int], optional
        Dimension(s) to reduce; when omitted, all elements are reduced.
    keepdim : bool, optional, default=False
        Whether the reduced dimension(s) are retained with size 1.
    dtype : torch.dtype, optional
        Desired data type of the returned tensor.
    """
    if mask is None:
        return _call_torch(input.mean, dim=dim, keepdim=keepdim, dtype=dtype)
    # Denominator: count of unmasked elements, clamped so a fully-masked
    # reduction divides by 1 (yielding 0) instead of by 0 (yielding NaN).
    # NOTE(review): if ``mask`` broadcasts over dims it does not itself
    # have, this count reflects the un-broadcast mask — confirm callers
    # pass masks shaped like ``input`` along the reduced dimensions.
    count = _call_torch(mask.float().sum, dim=dim, keepdim=keepdim)
    count = count.clamp(min=1.).to(dtype)
    total = masked_sum(input, mask, dim, keepdim, dtype)
    return total / count
def masked_max(
    input: torch.Tensor,
    mask: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False
) -> torch.Tensor:
    """
    Maximum of ``input`` over unmasked elements only.

    Masked (0) positions are set to ``-inf`` before the reduction so they
    can never win.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    mask : torch.Tensor
        A 0-1 mask broadcastable with ``input``, 0 for masked tokens and
        1 for unmasked ones. If ``None``, a plain ``max`` is computed.
    dim : int, optional
        Dimension to reduce. When omitted, the single maximum over all
        unmasked elements is returned; when given, a namedtuple
        ``(values, indices)`` is returned as with ``torch.max``.
    keepdim : bool, optional, default=False
        Whether the reduced dimension is retained with size 1.
    """
    if mask is not None:
        input = _fill_with_mask(input, mask, -float('inf'))
    return _call_torch(input.max, dim=dim, keepdim=keepdim)
def masked_min(
    input: torch.Tensor,
    mask: torch.Tensor,
    dim: Optional[int] = None,
    keepdim: bool = False
) -> torch.Tensor:
    """
    Minimum of ``input`` over unmasked elements only.

    Masked (0) positions are set to ``+inf`` before the reduction so they
    can never win.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor.
    mask : torch.Tensor
        A 0-1 mask broadcastable with ``input``, 0 for masked tokens and
        1 for unmasked ones. If ``None``, a plain ``min`` is computed.
    dim : int, optional
        Dimension to reduce. When omitted, the single minimum over all
        unmasked elements is returned; when given, a namedtuple
        ``(values, indices)`` is returned as with ``torch.min``.
    keepdim : bool, optional, default=False
        Whether the reduced dimension is retained with size 1.
    """
    if mask is not None:
        input = _fill_with_mask(input, mask, float('inf'))
    return _call_torch(input.min, dim=dim, keepdim=keepdim)