-
-
Notifications
You must be signed in to change notification settings - Fork 39
/
sigmoid.py
43 lines (33 loc) · 1.31 KB
/
sigmoid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""function used to calculate sigmoid of a given tensor."""
# third party
import torch
from sympc.approximations.exponential import exp
from sympc.approximations.reci import reciprocal
from sympc.approximations.utils import sign
from sympc.tensor.mpc_tensor import MPCTensor
def sigmoid(tensor: "MPCTensor", method: str = "exp") -> "MPCTensor":
    """Approximate the sigmoid function 1 / (1 + e^-x) of a tensor.

    Args:
        tensor: tensor to calculate the sigmoid of.
        method (str): approximation method (default = "exp").
            Possible values: "exp", "maclaurin".

    Returns:
        MPCTensor: the calculated sigmoid value.

    Raises:
        ValueError: if ``method`` is not one of the supported values.
    """
    if method == "exp":
        # Uses sigmoid(-x) = 1 - sigmoid(x): evaluate on x = tensor * sign(tensor)
        # (|tensor|, assuming sign() yields +/-1) so exp only sees non-positive
        # arguments, then fold the sign back in around the 0.5 midpoint.
        _sign = sign(tensor)
        x = tensor * _sign
        ones = tensor * 0 + 1
        half = ones / 2
        # Scale by the public constant -1 directly; multiplying by `ones`
        # (itself an MPCTensor) would add a tensor-by-tensor product for no
        # mathematical gain.
        result = reciprocal(ones + exp(-1 * x), method="nr")
        return (result - half) * _sign + half

    if method == "maclaurin":
        # Coefficients of the odd-degree polynomial approximation, paired
        # with the degree each one multiplies.
        weights = torch.tensor([0.5, 1.91204779e-01, -4.58667307e-03, 4.20690803e-05])
        degrees = [0, 1, 3, 5]
        # Initialize with the degree-0 term to avoid evaluating tensor ** 0.
        one = tensor * 0 + 1
        result = one * weights[0]
        for weight, degree in zip(weights[1:], degrees[1:]):
            result += (tensor ** degree) * weight
        return result

    # Previously an unknown method silently returned None.
    raise ValueError(
        f"Unsupported sigmoid method: {method!r}. Choose 'exp' or 'maclaurin'."
    )