In [1]:
import jax
from jax import numpy as jnp
import numpy as np

In [2]:
# stdlib
from functools import lru_cache
from typing import Dict
from typing import List
from typing import Optional

# third party
from autodp import dp_bank
from autodp import fdp_bank
from autodp.autodp_core import Mechanism
from nacl.signing import VerifyKey
import numpy as np

In [None]:
# Helpers to serialize/deserialize np.int64 values;
# syft.serde does not appear to support np.int64 directly.
def numpy64tolist(value: np.int64) -> List:
    """Convert a NumPy value into its plain-Python form for serialization.

    This is the serialize half of the ``delta0`` serde override used by
    ``iDPGaussianMechanism``.

    :param value: normally an ``np.int64``. Anything exposing ``tolist()`` is
        converted through it; objects without ``tolist()`` (for example the
        built-in ``int`` literal ``0`` that ``delta0`` is initialised with)
        are returned unchanged instead of raising ``AttributeError``.
    :return: the result of ``value.tolist()``. For a NumPy scalar this is a
        built-in Python number rather than a list, despite the annotation.
    """
    to_builtin = getattr(value, "tolist", None)
    if to_builtin is None:
        # Already a built-in value: there is nothing to convert.
        return value
    return to_builtin()


def listtonumpy64(value: List) -> np.int64:
    """Deserialize counterpart of ``numpy64tolist``.

    :param value: the plain-Python form produced during serialization
    :return: ``value`` passed through the ``np.int64`` constructor
    """
    restored = np.int64(value)
    return restored

In [None]:
# returns the privacy budget spent by each entity
@lru_cache(maxsize=None)
def _individual_RDP_gaussian(
    sigma: float, value: float, L: float, alpha: float
) -> float:
    return (alpha * (L**2) * (value**2)) / (2 * (sigma**2))


def individual_RDP_gaussian(params: Dict, alpha: float) -> np.float64:
    """
    Validate the inputs, then evaluate the Gaussian RDP epsilon at order alpha.

    :param params:
        'sigma' --- is the normalized noise level: std divided by global L2 sensitivity
        'value' --- is the output of query on a data point
        'L' --- is the Lipschitz constant of query with respect to the output of query on a data point
    :param alpha: The order of the Renyi Divergence; must not be negative
    :return: Evaluation of the RDP's epsilon
    :raises KeyError: if params lacks 'sigma', 'value' or 'L'
    :raises ValueError: if sigma is not strictly positive or alpha is negative
    """
    # NOTE(review): iDPGaussianMechanism in this file builds its params with
    # 'private_value' / 'public_value' and no 'value' key; presumably a caller
    # fills in 'value' before the RDP curve is evaluated -- confirm.
    sigma = params["sigma"]
    value = params["value"]
    L = params["L"]

    if sigma <= 0:
        raise ValueError("Sigma should be above 0")
    if alpha < 0:
        # This guard is about alpha, so the message must name alpha rather
        # than sigma.
        raise ValueError("Alpha should not be below 0")

    # Keyword arguments keep the call independent of the helper's
    # positional parameter order.
    return _individual_RDP_gaussian(sigma=sigma, alpha=alpha, value=value, L=L)

In [None]:
# Example of a specific mechanism that inherits the Mechanism class
@serializable(recursive_serde=True)
class iDPGaussianMechanism(Mechanism):
    """Gaussian ``Mechanism`` whose RDP curve is ``individual_RDP_gaussian``.

    The instance keeps its parameters in ``self.params`` and registers up to
    two privacy curves with ``propagate_updates``: an RDP curve and a direct
    approximate-DP curve.
    """

    # Attribute names exposed to the recursive serde machinery.
    # NOTE(review): only "name", "params", "entity_name", "delta0" and
    # "user_key" are assigned in __init__ below; the remaining entries are
    # presumably provided by Mechanism or set elsewhere -- confirm.
    __attr_allowlist__ = [
        "name",
        "params",
        "entity_name",
        "fdp",
        "eps_pureDP",
        "delta0",
        "RDP_off",
        "approxDP_off",
        "fdp_off",
        "use_basic_rdp_to_approx_dp_conversion",
        "use_fdp_based_rdp_to_approx_dp_conversion",
        "user_key",
    ]

    # delta0 is a numpy.int64 number (not supported by syft.serde), so it is
    # routed through a [serialize, deserialize] pair of helpers.
    __serde_overrides__ = {
        "delta0": [numpy64tolist, listtonumpy64],
    }

    def __init__(
        self,
        sigma: float,
        squared_l2_norm: float,
        squared_l2_norm_upper_bound: float,
        L: float,
        entity_name: str,
        name: str = "Gaussian",
        RDP_off: bool = False,
        approxDP_off: bool = False,
        use_basic_rdp_to_approx_dp_conversion: bool = False,
        use_fdp_based_rdp_to_approx_dp_conversion: bool = False,
        user_key: Optional[VerifyKey] = None,
    ):
        """Build the mechanism and register its privacy curves.

        :param sigma: std of the noise divided by the l2 sensitivity
        :param squared_l2_norm: stored as ``params["private_value"]``
        :param squared_l2_norm_upper_bound: stored as ``params["public_value"]``
        :param L: stored as ``params["L"]``
        :param entity_name: stored on the instance as ``entity_name``
        :param name: mechanism name, used when composing
        :param RDP_off: when True, no RDP curve is registered
        :param approxDP_off: when True, no direct approxDP curve is registered
        :param use_basic_rdp_to_approx_dp_conversion: accepted but not used
            in this constructor
        :param use_fdp_based_rdp_to_approx_dp_conversion: accepted but not
            used in this constructor
        :param user_key: optional verify key stored on the instance
        """

        # the sigma parameter is the std of the noise divided by the l2 sensitivity
        Mechanism.__init__(self)

        self.user_key = user_key

        self.name = name  # When composing
        # This will be useful for the Calibrator
        self.params = dict(
            sigma=float(sigma),
            private_value=float(squared_l2_norm),
            public_value=float(squared_l2_norm_upper_bound),
            L=float(L),
        )

        self.entity_name = entity_name
        # TODO: should a generic unspecified mechanism have a name and a param dictionary?

        self.delta0 = 0

        def rdp_curve(alpha):
            # Reads self.params at call time, so later changes to the dict
            # are reflected in the curve.
            return individual_RDP_gaussian(self.params, alpha)

        def approxdp_curve(delta):
            # Direct implementation of approxDP; uses the raw ``sigma``
            # argument rather than the float-coerced copy in self.params.
            return dp_bank.get_eps_ana_gaussian(sigma, delta)

        if not RDP_off:
            # This is the default setting with fast computation of RDP to approx-DP
            self.propagate_updates(rdp_curve, "RDP")

        if not approxDP_off:
            self.propagate_updates(approxdp_curve, "approxDP_func")

        # Discussion:  Sometimes delta as a function of eps has a closed-form solution
        # while eps as a function of delta does not
        # Shall we represent delta as a function of eps instead?

In [3]:
import flax
from jax import numpy as np
import numpy as np

In [None]:

@flax.struct.dataclass
class GaussianMechanism:
    """Container for the parameters of a batch of Gaussian mechanisms.

    Declared through ``flax.struct.dataclass`` so that instances can be used
    with JAX transformations.
    """

    # NOTE(review): the field meanings below are inferred from their names
    # and from iDPGaussianMechanism's constructor arguments -- confirm.
    sigma: float  # noise level
    private_squared_l2_norm: jnp.ndarray  # presumably the actual squared L2 norms
    public_squared_l2_norm: jnp.ndarray  # presumably the public upper bounds
    lipschitz_bound: jnp.ndarray  # presumably the Lipschitz constant(s) L
    entity_indices: jnp.ndarray
    entity_lookup: jnp.ndarray