@@ -26,6 +26,7 @@ def __init__(
26
26
api = None ,
27
27
num_nearest_neighbors = 1 ,
28
28
backend = "auto" ,
29
+ vote_normalization_level = "sample" ,
29
30
):
30
31
"""Constructor.
31
32
@@ -58,6 +59,10 @@ def __init__(
58
59
private samples is large. It requires the installation of `faiss-gpu` or `faiss-cpu` package. See
59
60
https://faiss.ai/
60
61
:type backend: str, optional
62
+ :param vote_normalization_level: The level of normalization for the votes. It should be one of the following:
63
+ "sample" (normalize the votes from each private sample to have l2 norm = 1), "client" (normalize the votes
64
+ from all private samples of the same client to have l2 norm = 1). Defaults to "sample".
65
+ :type vote_normalization_level: str, optional
61
66
:raises ValueError: If the `api` is not provided when `lookahead_degree` is greater than 0
62
67
:raises ValueError: If the `backend` is unknown
63
68
"""
@@ -86,6 +91,8 @@ def __init__(
86
91
else :
87
92
raise ValueError (f"Unknown backend: { backend } " )
88
93
94
+ self ._vote_normalization_level = vote_normalization_level
95
+
89
96
def _log_lookahead (self , syn_data , lookahead_id ):
90
97
"""Log the lookahead data.
91
98
@@ -163,6 +170,7 @@ def compute_histogram(self, priv_data, syn_data):
163
170
:type priv_data: :py:class:`pe.data.Data`
164
171
:param syn_data: The synthetic data
165
172
:type syn_data: :py:class:`pe.data.Data`
173
+ :raises ValueError: If the `vote_normalization_level` is unknown
166
174
:return: The private data, possibly with the additional embedding column, and the synthetic data, with the
167
175
computed histogram in the column :py:const:`pe.constant.data.CLEAN_HISTOGRAM_COLUMN_NAME` and possibly with
168
176
the additional embedding column
@@ -189,10 +197,22 @@ def compute_histogram(self, priv_data, syn_data):
189
197
)
190
198
self ._log_voting_details (priv_data = priv_data , syn_data = syn_data , ids = ids )
191
199
192
- counter = Counter (list (ids .flatten ()))
200
+ priv_data = priv_data .reset_index (drop = True )
201
+ if self ._vote_normalization_level == "client" :
202
+ priv_data_list = priv_data .split_by_client ()
203
+ elif self ._vote_normalization_level == "sample" :
204
+ priv_data_list = priv_data .split_by_index ()
205
+ else :
206
+ raise ValueError (f"Unknown vote normalization level: { self ._vote_normalization_level } " )
207
+
193
208
count = np .zeros (shape = syn_embedding .shape [0 ], dtype = np .float32 )
194
- count [list (counter .keys ())] = list (counter .values ())
195
- count /= np .sqrt (self ._num_nearest_neighbors )
209
+ for sub_priv_data in priv_data_list :
210
+ sub_count = np .zeros (shape = syn_embedding .shape [0 ], dtype = np .float32 )
211
+ sub_ids = ids [sub_priv_data .data_frame .index ]
212
+ counter = Counter (list (sub_ids .flatten ()))
213
+ sub_count [list (counter .keys ())] = list (counter .values ())
214
+ sub_count /= np .linalg .norm (sub_count )
215
+ count += sub_count
196
216
197
217
syn_data .data_frame [CLEAN_HISTOGRAM_COLUMN_NAME ] = count
198
218
0 commit comments