Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 47 additions & 4 deletions flamesdk/resources/client_apis/clients/result_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import math
from io import BytesIO
from typing import Any, Literal, Optional
from httpx import Client
import pickle
import re

from flamesdk.resources.utils import flame_log
from typing_extensions import TypedDict

class LocalDifferentialPrivacyParams(TypedDict, total=True):
epsilon: float
sensitivity: float


class ResultClient:
Expand All @@ -26,6 +32,7 @@ def push_result(self,
remote_node_id: Optional[str] = None,
type: Literal["final", "global", "local"] = "final",
output_type: Literal['str', 'bytes', 'pickle'] = 'pickle',
local_dp: LocalDifferentialPrivacyParams = None,
silent: bool = False) -> dict[str, str]:
"""
Pushes the result to the hub. Making it available for analysts to download.
Expand All @@ -35,6 +42,7 @@ def push_result(self,
:param remote_node_id: optional remote node id (used for accessing remote node's public key for encryption)
:param type: location to save the result, final saves in the hub to be downloaded, global saves in central instance of MinIO, local saves in the node
:param output_type: the type of the result, str, bytes or pickle only for final results
:param local_dp: parameters for local differential privacy, only for final floating-point type results
:param silent: if True, the response will not be logged
:return:
"""
Expand All @@ -48,20 +56,55 @@ def push_result(self,
if tag and not re.match(r'^[a-z0-9]{1,2}|[a-z0-9][a-z0-9-]{,30}[a-z0-9]+$', tag):
raise ValueError("Tag must consist only of lowercase letters, numbers, and hyphens")

if (type == 'final') and (output_type == 'str'):
# check if local dp parameters have been supplied
use_local_dp = isinstance(local_dp, dict)

if use_local_dp:
# check if result is a numeric value
if not isinstance(result, (float, int)):
raise ValueError("Local differential privacy can only be applied on numeric values")

# check if result is finite
if not math.isfinite(result):
raise ValueError("Result is not finite")

# check if final result submission is requested
if type != "final":
raise ValueError("Local differential privacy is only supported for submission of final results")

# print warning if output_type other than str is specified
if output_type != "str":
flame_log(
f"Result submission with local differential privacy requested but output type is set to `{output_type}`."
"`str` is enforced but this may change in a future version.",
silent
)

# write as string to request body
file_body = str(result).encode("utf-8")
elif (type == 'final') and (output_type == 'str'):
file_body = str(result).encode('utf-8')
elif (type == 'final') and (output_type == 'bytes'):
file_body = bytes(result)
else:
file_body = pickle.dumps(result)

data = {}

if remote_node_id:
data = {"remote_node_id": remote_node_id}
elif tag:
data = {"tag": tag}
else:
data = {}
response = self.client.put(f"/{type}/",

request_path = f"/{type}/"

if use_local_dp:
# append to request path
request_path += "localdp"
# local_dp is guaranteed to not be None, so remap values to string and update request data mapping
data.update({k: str(v) for k, v in local_dp.items()})

response = self.client.put(request_path,
files={"file": BytesIO(file_body)},
data=data,
headers=[('Connection', 'close')])
Expand Down