diff --git a/src/sympc/approximations/exponential.py b/src/sympc/approximations/exponential.py index 46965a7b..0c6bf573 100644 --- a/src/sympc/approximations/exponential.py +++ b/src/sympc/approximations/exponential.py @@ -1,7 +1,11 @@ """function used to calculate exp of a given tensor.""" +from sympc.tensor import MPCTensor +from sympc.tensor.register_approximation import register_approximation -def exp(value, iterations=8): + +@register_approximation("exp") +def exp(value: MPCTensor, iterations: int = 8) -> MPCTensor: r"""Approximates the exponential function using a limit approximation. exp(x) = \lim_{n -> infty} (1 + x / n) ^ n @@ -9,7 +13,7 @@ def exp(value, iterations=8): iterations. We then compute (1 + x / n) once and square `d` times. Args: - value: tensor whose exp is to be calculated + value (MPCTensor): tensor whose exp is to be calculated iterations (int): number of iterations for limit approximation Ref: https://github.com/LaRiffle/approximate-models diff --git a/src/sympc/approximations/log.py b/src/sympc/approximations/log.py index 3b6176a6..9662f061 100644 --- a/src/sympc/approximations/log.py +++ b/src/sympc/approximations/log.py @@ -1,9 +1,16 @@ """fucntion used to calculate log of given tensor.""" +# stdlib +# from typing import TypeVar +# MPCTensor = TypeVar("MPCTensor") + from sympc.approximations.exponential import exp +from sympc.tensor import MPCTensor +from sympc.tensor import RegisterApproximation -def log(self, iterations=2, exp_iterations=8): +@RegisterApproximation("log") +def log(self: MPCTensor, iterations: int = 2, exp_iterations: int = 8) -> MPCTensor: """Approximates the natural logarithm using 8th order modified Householder iterations. Recall that Householder method is an algorithm to solve a non linear equation f(x) = 0. @@ -14,7 +21,7 @@ def log(self, iterations=2, exp_iterations=8): y_{n+1} = y_n - h * (1 + h / 2 + h^2 / 3 + h^3 / 6 + h^4 / 5 + h^5 / 7) Args: - self: tensor whose log has to be calculated + self (MPCTensor): tensor whose log has to be calculated iterations (int): number of iterations for 6th order modified Householder approximation. exp_iterations (int): number of iterations for limit approximation of exp diff --git a/src/sympc/approximations/reciprocal.py b/src/sympc/approximations/reciprocal.py index 39837285..f356e201 100644 --- a/src/sympc/approximations/reciprocal.py +++ b/src/sympc/approximations/reciprocal.py @@ -1,20 +1,26 @@ """function used to calculate reciprocal of a given tensor.""" +# stdlib +# from typing import TypeVar + from sympc.approximations.exponential import exp from sympc.approximations.log import log from sympc.approximations.utils import modulus from sympc.approximations.utils import sign +from sympc.tensor import MPCTensor +from sympc.tensor import RegisterApproximation -def reciprocal(self, method: str = "NR", nr_iters: int = 10): +@RegisterApproximation("reciprocal") +def reciprocal(self: MPCTensor, method: str = "NR", nr_iters: int = 10) -> MPCTensor: r"""Calculate the reciprocal using the algorithm specified in the method args. Ref: https://github.com/facebookresearch/CrypTen Args: - self: input data - nr_iters: Number of iterations for Newton-Raphson - method: 'NR' - `Newton-Raphson`_ method computes the reciprocal using iterations + self (MPCTensor): input data + nr_iters (int): Number of iterations for Newton-Raphson + method (str): 'NR' - `Newton-Raphson`_ method computes the reciprocal using iterations of :math:`x_{i+1} = (2x_i - self * x_i^2)` and uses :math:`3*exp(-(x-.5)) + 0.003` as an initial guess by default @@ -22,7 +28,7 @@ def reciprocal(self, method: str = "NR", nr_iters: int = 10): :math:`x^{-1} = exp(-log(x))` Returns: - Reciprocal of `self` + MPCTensor : Reciprocal of `self` Raises: ValueError: if the given method is not supported diff --git a/src/sympc/approximations/sigmoid.py b/src/sympc/approximations/sigmoid.py index 2a5337ed..7d30fea6 100644 --- a/src/sympc/approximations/sigmoid.py +++ b/src/sympc/approximations/sigmoid.py @@ -1,6 +1,6 @@ """function used to calculate sigmoid of a given tensor.""" + # stdlib -from typing import Any # third party import torch @@ -8,18 +8,23 @@ from sympc.approximations.exponential import exp from sympc.approximations.reciprocal import reciprocal from sympc.approximations.utils import sign +from sympc.tensor import MPCTensor +from sympc.tensor import RegisterApproximation + +# from typing import TypeVar -def sigmoid(tensor: Any, method: str = "exp") -> Any: +@RegisterApproximation("sigmoid") +def sigmoid(tensor: MPCTensor, method: str = "exp") -> MPCTensor: """Approximates the sigmoid function using a given method. Args: - tensor (Any): tensor to calculate sigmoid + tensor (MPCTensor): tensor to calculate sigmoid method (str): (default = "chebyshev") Possible values: "exp", "maclaurin", "chebyshev" Returns: - tensor (Any): the calulated sigmoid value + MPCTensor: the calulated sigmoid value Raises: ValueError: if the given method is not supported diff --git a/src/sympc/approximations/softmax.py b/src/sympc/approximations/softmax.py index 54d48ac5..8cd7c69e 100644 --- a/src/sympc/approximations/softmax.py +++ b/src/sympc/approximations/softmax.py @@ -7,8 +7,10 @@ from sympc.approximations.log import log from sympc.approximations.reciprocal import reciprocal from sympc.tensor import MPCTensor +from sympc.tensor import RegisterApproximation +@RegisterApproximation("softmax") def softmax(tensor: MPCTensor, dim: Optional[int] = None) -> MPCTensor: """Calculates tanh of given tensor's elements along the given dimension. @@ -36,6 +38,7 @@ def softmax(tensor: MPCTensor, dim: Optional[int] = None) -> MPCTensor: return numerator * reciprocal(denominator) +@RegisterApproximation("log_softmax") def log_softmax(tensor: MPCTensor, dim: Optional[int] = None) -> MPCTensor: """Applies a softmax followed by a logarithm. diff --git a/src/sympc/approximations/tanh.py b/src/sympc/approximations/tanh.py index 6844c9d5..9a5e9e6c 100644 --- a/src/sympc/approximations/tanh.py +++ b/src/sympc/approximations/tanh.py @@ -12,21 +12,23 @@ from sympc.approximations.sigmoid import sigmoid from sympc.module.nn import relu from sympc.tensor import MPCTensor +from sympc.tensor import RegisterApproximation from sympc.tensor.static import stack -def _tanh_sigmoid(tensor): +def _tanh_sigmoid(tensor: MPCTensor) -> MPCTensor: """Compute the tanh using the sigmoid approximation. Args: - tensor (tensor): values where tanh should be approximated + tensor (MPCTensor): values where tanh should be approximated Returns: - tensor (tensor): tanh calculated using sigmoid + MPCTensor: tanh calculated using sigmoid """ return 2 * sigmoid(2 * tensor) - 1 +@RegisterApproximation("tanh") def tanh(tensor: MPCTensor, method: str = "sigmoid") -> MPCTensor: """Calculates tanh of given tensor. diff --git a/src/sympc/approximations/utils.py b/src/sympc/approximations/utils.py index c8189954..8173d302 100644 --- a/src/sympc/approximations/utils.py +++ b/src/sympc/approximations/utils.py @@ -1,11 +1,15 @@ """Utility functions for approximation functions.""" +from sympc.tensor import MPCTensor +from sympc.tensor import RegisterApproximation -def sign(data): + +@RegisterApproximation("sign") +def sign(data: MPCTensor) -> MPCTensor: """Calculate sign of given tensor. Args: - data: tensor whose sign has to be determined + data (MPCTensor): tensor whose sign has to be determined Returns: MPCTensor: tensor with the determined sign @@ -13,11 +17,12 @@ def sign(data): return (data > 0) + (data < 0) * (-1) -def modulus(data): +@RegisterApproximation("modulus") +def modulus(data: MPCTensor) -> MPCTensor: """Calculation of modulus for a given tensor. Args: - data(MPCTensor): tensor whose modulus has to be calculated + data (MPCTensor): tensor whose modulus has to be calculated Returns: MPCTensor: the required modulus diff --git a/src/sympc/tensor/__init__.py b/src/sympc/tensor/__init__.py index 98462f79..afae4bce 100644 --- a/src/sympc/tensor/__init__.py +++ b/src/sympc/tensor/__init__.py @@ -5,6 +5,7 @@ from . import static from .mpc_tensor import METHODS_TO_ADD from .mpc_tensor import MPCTensor +from .register_approximation import RegisterApproximation from .replicatedshare_tensor import PRIME_NUMBER from .replicatedshare_tensor import ReplicatedSharedTensor @@ -12,6 +13,7 @@ "ShareTensor", "ReplicatedSharedTensor", "MPCTensor", + "RegisterApproximation", "METHODS_TO_ADD", "static", "PRIME_NUMBER", diff --git a/src/sympc/tensor/mpc_tensor.py b/src/sympc/tensor/mpc_tensor.py index 72172d2d..ef917b75 100644 --- a/src/sympc/tensor/mpc_tensor.py +++ b/src/sympc/tensor/mpc_tensor.py @@ -16,15 +16,17 @@ import torch import torchcsprng as csprng # type: ignore -from sympc.approximations import APPROXIMATIONS from sympc.config import Config from sympc.encoder import FixedPointEncoder from sympc.session import Session +from sympc.tensor import RegisterApproximation from sympc.tensor import ShareTensor from sympc.utils import ispointer from .tensor import SyMPCTensor +APPROXIMATIONS = RegisterApproximation.approx_dict + PROPERTIES_FORWARD_ALL_SHARES = {"T"} METHODS_FORWARD_ALL_SHARES = { "t", diff --git a/src/sympc/tensor/register_approximation.py b/src/sympc/tensor/register_approximation.py new file mode 100644 index 00000000..558a7729 --- /dev/null +++ b/src/sympc/tensor/register_approximation.py @@ -0,0 +1,42 @@ +"""Register decorator that keeps track of all the approximations we have.""" + + +class RegisterApproximation: + """Used to keep track of all the approximations we have. + + Arguments: + nfunc: the name of the function. + + This class is used as a register decorator class that keeps track of all the + approximation functions we have in a dictionary. + + """ + + approx_dict = {} + + def __init__(self, nfunc): + """Initializer for the RegisterApproximation class. + + Arguments: + nfunc: the name of the function. + """ + self.nfunc = nfunc + + def __call__(self, func): + """Returns a wrapper functions that adds an approximation function to the dictionary approx_dict. + + Arguments: + func: function to be added to the dictionary approx_dict. + + Returns: + wrapper functions of function func that was added to the approx_dict dictionary. + + """ + self.approx_dict[self.nfunc] = func + + def wrapper(*args, **kwargs): + + res = func(*args, **kwargs) + return res + + return wrapper