diff --git a/flamesdk/resources/client_apis/clients/result_client.py b/flamesdk/resources/client_apis/clients/result_client.py index f5087d3..3286eb0 100644 --- a/flamesdk/resources/client_apis/clients/result_client.py +++ b/flamesdk/resources/client_apis/clients/result_client.py @@ -1,3 +1,4 @@ +import math from io import BytesIO from typing import Any, Literal, Optional from httpx import Client @@ -5,6 +6,11 @@ import re from flamesdk.resources.utils import flame_log +from typing_extensions import TypedDict + +class LocalDifferentialPrivacyParams(TypedDict, total=True): + epsilon: float + sensitivity: float class ResultClient: @@ -26,6 +32,7 @@ def push_result(self, remote_node_id: Optional[str] = None, type: Literal["final", "global", "local"] = "final", output_type: Literal['str', 'bytes', 'pickle'] = 'pickle', + local_dp: LocalDifferentialPrivacyParams = None, silent: bool = False) -> dict[str, str]: """ Pushes the result to the hub. Making it available for analysts to download. @@ -35,6 +42,7 @@ def push_result(self, :param remote_node_id: optional remote node id (used for accessing remote node's public key for encryption) :param type: location to save the result, final saves in the hub to be downloaded, global saves in central instance of MinIO, local saves in the node :param output_type: the type of the result, str, bytes or pickle only for final results + :param local_dp: parameters for local differential privacy, only for final floating-point type results :param silent: if True, the response will not be logged :return: """ @@ -48,20 +56,55 @@ def push_result(self, if tag and not re.match(r'^[a-z0-9]{1,2}|[a-z0-9][a-z0-9-]{,30}[a-z0-9]+$', tag): raise ValueError("Tag must consist only of lowercase letters, numbers, and hyphens") - if (type == 'final') and (output_type == 'str'): + # check if local dp parameters have been supplied + use_local_dp = isinstance(local_dp, dict) + + if use_local_dp: + # check if result is a numeric value + if not isinstance(result, (float, int)): + raise ValueError("Local differential privacy can only be applied on numeric values") + + # check if result is finite + if not math.isfinite(result): + raise ValueError("Result is not finite") + + # check if final result submission is requested + if type != "final": + raise ValueError("Local differential privacy is only supported for submission of final results") + + # print warning if output_type other than str is specified + if output_type != "str": + flame_log( + f"Result submission with local differential privacy requested but output type is set to `{output_type}`." + "`str` is enforced but this may change in a future version.", + silent + ) + + # write as string to request body + file_body = str(result).encode("utf-8") + elif (type == 'final') and (output_type == 'str'): file_body = str(result).encode('utf-8') elif (type == 'final') and (output_type == 'bytes'): file_body = bytes(result) else: file_body = pickle.dumps(result) + data = {} + if remote_node_id: data = {"remote_node_id": remote_node_id} elif tag: data = {"tag": tag} - else: - data = {} - response = self.client.put(f"/{type}/", + + request_path = f"/{type}/" + + if use_local_dp: + # append to request path + request_path += "localdp" + # local_dp is guaranteed to not be None, so remap values to string and update request data mapping + data.update({k: str(v) for k, v in local_dp.items()}) + + response = self.client.put(request_path, files={"file": BytesIO(file_body)}, data=data, headers=[('Connection', 'close')])