Skip to content

Commit 25cdf85

Browse files
committed
add the support of client-level DP
1 parent d623e76 commit 25cdf85

File tree

3 files changed

+57
-3
lines changed

3 files changed

+57
-3
lines changed

pe/constant/data.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#: The column name of the label ID
22
LABEL_ID_COLUMN_NAME = "PE.LABEL_ID"
3+
#: The column name of the client ID (if using client-level DP)
4+
CLIENT_ID_COLUMN_NAME = "PE.CLIENT_ID"
35

46
#: The column name of the clean histogram
57
CLEAN_HISTOGRAM_COLUMN_NAME = "PE.CLEAN_HISTOGRAM"

pe/data/data.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import pandas as pd
44
import numpy as np
55
from pe.constant.data import LABEL_ID_COLUMN_NAME
6+
from pe.constant.data import CLIENT_ID_COLUMN_NAME
67

78

89
class Data:
@@ -183,3 +184,34 @@ def concat(cls, data_list, metadata=None):
183184
raise ValueError("Metadata must be the same")
184185
metadata = metadata_list[0]
185186
return Data(data_frame=pd.concat(data_frame_list), metadata=metadata)
187+
188+
def split_by_client(self):
    """Partition the data into one data object per client ID

    Rows are grouped by the value stored in the client ID column
    (:py:const:`pe.constant.data.CLIENT_ID_COLUMN_NAME`). Every resulting object shares the metadata of this
    object.

    NOTE(review): pandas ``groupby`` omits rows whose key is NaN by default, so rows without a client ID would
    not appear in any returned object -- confirm that this is intended.

    :raises ValueError: If the client ID column is not in the data frame
    :return: The list of data objects, one per distinct client ID
    :rtype: list[:py:class:`pe.data.Data`]
    """
    if CLIENT_ID_COLUMN_NAME in self.data_frame.columns:
        per_client = []
        # Each iteration yields (client ID, sub data frame); only the sub data frame is needed.
        for _, client_frame in self.data_frame.groupby(CLIENT_ID_COLUMN_NAME):
            per_client.append(Data(data_frame=client_frame, metadata=self.metadata))
        return per_client
    raise ValueError(f"{CLIENT_ID_COLUMN_NAME} not in data frame")
199+
200+
def split_by_index(self):
    """Partition the data into one data object per distinct index value

    Rows that share the same index label end up in the same data object. Every resulting object shares the
    metadata of this object.

    :return: The list of data objects, one per distinct index value
    :rtype: list[:py:class:`pe.data.Data`]
    """
    pieces = []
    # Group on the index labels themselves; the group key is not needed afterwards.
    for _, rows in self.data_frame.groupby(self.data_frame.index):
        pieces.append(Data(data_frame=rows, metadata=self.metadata))
    return pieces
208+
209+
def reset_index(self, **kwargs):
    """Build a new data object whose data frame has its index reset

    The metadata of this object is carried over unchanged.

    :param kwargs: The keyword arguments forwarded to the pandas ``reset_index`` function
    :type kwargs: dict
    :return: A new :py:class:`pe.data.Data` object wrapping the data frame returned by ``reset_index``
    :rtype: :py:class:`pe.data.Data`
    """
    reindexed_frame = self.data_frame.reset_index(**kwargs)
    return Data(data_frame=reindexed_frame, metadata=self.metadata)

pe/histogram/nearest_neighbors.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(
2626
api=None,
2727
num_nearest_neighbors=1,
2828
backend="auto",
29+
vote_normalization_level="sample",
2930
):
3031
"""Constructor.
3132
@@ -58,6 +59,10 @@ def __init__(
5859
private samples is large. It requires the installation of `faiss-gpu` or `faiss-cpu` package. See
5960
https://faiss.ai/
6061
:type backend: str, optional
62+
:param vote_normalization_level: The level of normalization for the votes. It should be one of the following:
63+
"sample" (normalize the votes from each private sample to have l2 norm = 1), "client" (normalize the votes
64+
from all private samples of the same client to have l2 norm = 1). Defaults to "sample"
65+
:type vote_normalization_level: str, optional
6166
:raises ValueError: If the `api` is not provided when `lookahead_degree` is greater than 0
6267
:raises ValueError: If the `backend` is unknown
6368
"""
@@ -86,6 +91,8 @@ def __init__(
8691
else:
8792
raise ValueError(f"Unknown backend: {backend}")
8893

94+
self._vote_normalization_level = vote_normalization_level
95+
8996
def _log_lookahead(self, syn_data, lookahead_id):
9097
"""Log the lookahead data.
9198
@@ -163,6 +170,7 @@ def compute_histogram(self, priv_data, syn_data):
163170
:type priv_data: :py:class:`pe.data.Data`
164171
:param syn_data: The synthetic data
165172
:type syn_data: :py:class:`pe.data.Data`
173+
:raises ValueError: If the `vote_normalization_level` is unknown
166174
:return: The private data, possibly with the additional embedding column, and the synthetic data, with the
167175
computed histogram in the column :py:const:`pe.constant.data.CLEAN_HISTOGRAM_COLUMN_NAME` and possibly with
168176
the additional embedding column
@@ -189,10 +197,22 @@ def compute_histogram(self, priv_data, syn_data):
189197
)
190198
self._log_voting_details(priv_data=priv_data, syn_data=syn_data, ids=ids)
191199

192-
counter = Counter(list(ids.flatten()))
200+
priv_data = priv_data.reset_index(drop=True)
201+
if self._vote_normalization_level == "client":
202+
priv_data_list = priv_data.split_by_client()
203+
elif self._vote_normalization_level == "sample":
204+
priv_data_list = priv_data.split_by_index()
205+
else:
206+
raise ValueError(f"Unknown vote normalization level: {self._vote_normalization_level}")
207+
193208
count = np.zeros(shape=syn_embedding.shape[0], dtype=np.float32)
194-
count[list(counter.keys())] = list(counter.values())
195-
count /= np.sqrt(self._num_nearest_neighbors)
209+
for sub_priv_data in priv_data_list:
210+
sub_count = np.zeros(shape=syn_embedding.shape[0], dtype=np.float32)
211+
sub_ids = ids[sub_priv_data.data_frame.index]
212+
counter = Counter(list(sub_ids.flatten()))
213+
sub_count[list(counter.keys())] = list(counter.values())
214+
sub_count /= np.linalg.norm(sub_count)
215+
count += sub_count
196216

197217
syn_data.data_frame[CLEAN_HISTOGRAM_COLUMN_NAME] = count
198218

0 commit comments

Comments
 (0)