Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type annotations for Approximations #258 #322

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/sympc/approximations/exponential.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
"""function used to calculate exp of a given tensor."""

from sympc.tensor import MPCTensor
from sympc.tensor.register_approximation import register_approximation

def exp(value, iterations=8):

@register_approximation("exp")
def exp(value: MPCTensor, iterations: int = 8) -> MPCTensor:
r"""Approximates the exponential function using a limit approximation.

exp(x) = \lim_{n -> infty} (1 + x / n) ^ n
Here we compute exp by choosing n = 2 ** d for some large d equal to
iterations. We then compute (1 + x / n) once and square `d` times.

Args:
value: tensor whose exp is to be calculated
value (MPCTensor): tensor whose exp is to be calculated
iterations (int): number of iterations for limit approximation

Ref: https://github.com/LaRiffle/approximate-models
Expand Down
11 changes: 9 additions & 2 deletions src/sympc/approximations/log.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
"""fucntion used to calculate log of given tensor."""

# stdlib
# from typing import TypeVar
# MPCTensor = TypeVar("MPCTensor")

from sympc.approximations.exponential import exp
from sympc.tensor import MPCTensor
from sympc.tensor import RegisterApproximation


def log(self, iterations=2, exp_iterations=8):
@RegisterApproximation("log")
def log(self: MPCTensor, iterations: int = 2, exp_iterations: int = 8) -> MPCTensor:
"""Approximates the natural logarithm using 8th order modified Householder iterations.

Recall that Householder method is an algorithm to solve a non linear equation f(x) = 0.
Expand All @@ -14,7 +21,7 @@ def log(self, iterations=2, exp_iterations=8):
y_{n+1} = y_n - h * (1 + h / 2 + h^2 / 3 + h^3 / 6 + h^4 / 5 + h^5 / 7)

Args:
self: tensor whose log has to be calculated
self (MPCTensor): tensor whose log has to be calculated
iterations (int): number of iterations for 6th order modified
Householder approximation.
exp_iterations (int): number of iterations for limit approximation of exp
Expand Down
16 changes: 11 additions & 5 deletions src/sympc/approximations/reciprocal.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,34 @@
"""function used to calculate reciprocal of a given tensor."""

# stdlib
# from typing import TypeVar

from sympc.approximations.exponential import exp
from sympc.approximations.log import log
from sympc.approximations.utils import modulus
from sympc.approximations.utils import sign
from sympc.tensor import MPCTensor
from sympc.tensor import RegisterApproximation


def reciprocal(self, method: str = "NR", nr_iters: int = 10):
@RegisterApproximation("reciprocal")
def reciprocal(self: MPCTensor, method: str = "NR", nr_iters: int = 10) -> MPCTensor:
r"""Calculate the reciprocal using the algorithm specified in the method args.

Ref: https://github.com/facebookresearch/CrypTen

Args:
self: input data
nr_iters: Number of iterations for Newton-Raphson
method: 'NR' - `Newton-Raphson`_ method computes the reciprocal using iterations
self (MPCTensor): input data
nr_iters (int): Number of iterations for Newton-Raphson
method (str): 'NR' - `Newton-Raphson`_ method computes the reciprocal using iterations
of :math:`x_{i+1} = (2x_i - self * x_i^2)` and uses
:math:`3*exp(-(x-.5)) + 0.003` as an initial guess by default

'log' - Computes the reciprocal of the input from the observation that:
:math:`x^{-1} = exp(-log(x))`

Returns:
Reciprocal of `self`
MPCTensor : Reciprocal of `self`

Raises:
ValueError: if the given method is not supported
Expand Down
13 changes: 9 additions & 4 deletions src/sympc/approximations/sigmoid.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,30 @@
"""function used to calculate sigmoid of a given tensor."""

# stdlib
from typing import Any

# third party
import torch

from sympc.approximations.exponential import exp
from sympc.approximations.reciprocal import reciprocal
from sympc.approximations.utils import sign
from sympc.tensor import MPCTensor
from sympc.tensor import RegisterApproximation

# from typing import TypeVar


def sigmoid(tensor: Any, method: str = "exp") -> Any:
@RegisterApproximation("sigmoid")
def sigmoid(tensor: MPCTensor, method: str = "exp") -> MPCTensor:
"""Approximates the sigmoid function using a given method.

Args:
tensor (Any): tensor to calculate sigmoid
tensor (MPCTensor): tensor to calculate sigmoid
method (str): (default = "chebyshev")
Possible values: "exp", "maclaurin", "chebyshev"

Returns:
tensor (Any): the calulated sigmoid value
MPCTensor: the calulated sigmoid value

Raises:
ValueError: if the given method is not supported
Expand Down
3 changes: 3 additions & 0 deletions src/sympc/approximations/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
from sympc.approximations.log import log
from sympc.approximations.reciprocal import reciprocal
from sympc.tensor import MPCTensor
from sympc.tensor import RegisterApproximation


@RegisterApproximation("softmax")
def softmax(tensor: MPCTensor, dim: Optional[int] = None) -> MPCTensor:
"""Calculates tanh of given tensor's elements along the given dimension.

Expand Down Expand Up @@ -36,6 +38,7 @@ def softmax(tensor: MPCTensor, dim: Optional[int] = None) -> MPCTensor:
return numerator * reciprocal(denominator)


@RegisterApproximation("log_softmax")
def log_softmax(tensor: MPCTensor, dim: Optional[int] = None) -> MPCTensor:
"""Applies a softmax followed by a logarithm.

Expand Down
8 changes: 5 additions & 3 deletions src/sympc/approximations/tanh.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,23 @@
from sympc.approximations.sigmoid import sigmoid
from sympc.module.nn import relu
from sympc.tensor import MPCTensor
from sympc.tensor import RegisterApproximation
from sympc.tensor.static import stack


def _tanh_sigmoid(tensor):
def _tanh_sigmoid(tensor: MPCTensor) -> MPCTensor:
"""Compute the tanh using the sigmoid approximation.

Args:
tensor (tensor): values where tanh should be approximated
tensor (MPCTensor): values where tanh should be approximated

Returns:
tensor (tensor): tanh calculated using sigmoid
MPCTensor: tanh calculated using sigmoid
"""
return 2 * sigmoid(2 * tensor) - 1


@RegisterApproximation("tanh")
def tanh(tensor: MPCTensor, method: str = "sigmoid") -> MPCTensor:
"""Calculates tanh of given tensor.

Expand Down
13 changes: 9 additions & 4 deletions src/sympc/approximations/utils.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
"""Utility functions for approximation functions."""

from sympc.tensor import MPCTensor
from sympc.tensor import RegisterApproximation

def sign(data):

@RegisterApproximation("sign")
def sign(data: MPCTensor) -> MPCTensor:
"""Calculate sign of given tensor.

Args:
data: tensor whose sign has to be determined
data (MPCTensor): tensor whose sign has to be determined

Returns:
MPCTensor: tensor with the determined sign
"""
return (data > 0) + (data < 0) * (-1)


def modulus(data):
@RegisterApproximation("modulus")
def modulus(data: MPCTensor) -> MPCTensor:
"""Calculation of modulus for a given tensor.

Args:
data(MPCTensor): tensor whose modulus has to be calculated
data (MPCTensor): tensor whose modulus has to be calculated

Returns:
MPCTensor: the required modulus
Expand Down
2 changes: 2 additions & 0 deletions src/sympc/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from . import static
from .mpc_tensor import METHODS_TO_ADD
from .mpc_tensor import MPCTensor
from .register_approximation import RegisterApproximation
from .replicatedshare_tensor import PRIME_NUMBER
from .replicatedshare_tensor import ReplicatedSharedTensor

__all__ = [
"ShareTensor",
"ReplicatedSharedTensor",
"MPCTensor",
"RegisterApproximation",
"METHODS_TO_ADD",
"static",
"PRIME_NUMBER",
Expand Down
4 changes: 3 additions & 1 deletion src/sympc/tensor/mpc_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,17 @@
import torch
import torchcsprng as csprng # type: ignore

from sympc.approximations import APPROXIMATIONS
from sympc.config import Config
from sympc.encoder import FixedPointEncoder
from sympc.session import Session
from sympc.tensor import RegisterApproximation
from sympc.tensor import ShareTensor
from sympc.utils import ispointer

from .tensor import SyMPCTensor

APPROXIMATIONS = RegisterApproximation.approx_dict

PROPERTIES_FORWARD_ALL_SHARES = {"T"}
METHODS_FORWARD_ALL_SHARES = {
"t",
Expand Down
42 changes: 42 additions & 0 deletions src/sympc/tensor/register_approximation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Register decorator that keeps track of all the approximations we have."""


class RegisterApproximation:
"""Used to keep track of all the approximations we have.

Arguments:
nfunc: the name of the function.

This class is used as a register decorator class that keeps track of all the
approximation functions we have in a dictionary.

"""

approx_dict = {}

def __init__(self, nfunc):
"""Initializer for the RegisterApproximation class.

Arguments:
nfunc: the name of the function.
"""
self.nfunc = nfunc

def __call__(self, func):
"""Returns a wrapper functions that adds an approximation function to the dictionary approx_dict.

Arguments:
func: function to be added to the dictionary approx_dict.

Returns:
wrapper functions of function func that was added to the approx_dict dictionary.

"""
self.approx_dict[self.nfunc] = func

def wrapper(*args, **kwargs):

res = func(*args, **kwargs)
return res

return wrapper