Skip to content

Commit

Permalink
tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dvadym committed Dec 1, 2023
1 parent 661341b commit f288b3b
Showing 1 changed file with 12 additions and 11 deletions.
23 changes: 12 additions & 11 deletions analysis/tests/utility_analysis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,11 @@ def test_multi_parameters(self):

multi_param = analysis.MultiParameterConfiguration(
max_partitions_contributed=[1, 2],
max_contributions_per_partition=[1, 2])
max_contributions_per_partition=[1, 2],
partition_selection_strategy=[
pipeline_dp.PartitionSelectionStrategy.TRUNCATED_GEOMETRIC,
pipeline_dp.PartitionSelectionStrategy.GAUSSIAN_THRESHOLDING
])

# Input collection has 1 privacy id, which contributes to 2 partitions
# 1 and 2 times correspondingly.
Expand All @@ -253,8 +257,6 @@ def test_multi_parameters(self):
partition_extractor=lambda x: x[1],
value_extractor=lambda x: 0)

public_partitions = ["pk0", "pk1"]

output, _ = analysis.perform_utility_analysis(
col=input,
backend=pipeline_dp.LocalBackend(),
Expand All @@ -264,7 +266,6 @@ def test_multi_parameters(self):
aggregate_params=aggregate_params,
multi_param_configuration=multi_param),
data_extractors=data_extractors,
public_partitions=public_partitions,
)

utility_reports = list(output)
Expand All @@ -273,16 +274,16 @@ def test_multi_parameters(self):
self.assertLen(utility_reports, 2) # one report per each configuration.

# Check the parameter configuration
expected_noise_std = [3.02734375, 8.56262117843085]
expected_noise_std = [5.9765625, 16.904271487740903]
expected_l0_error = [-0.5, 0]
expected_partition_info = metrics.PartitionsInfo(
public_partitions=True,
num_dataset_partitions=2,
num_non_public_partitions=0,
num_empty_partitions=0)
for i_configuration, report in enumerate(utility_reports):
self.assertEqual(report.configuration_index, i_configuration)
self.assertEqual(report.partitions_info, expected_partition_info)
self.assertFalse(report.partitions_info.public_partitions)
self.assertEqual(report.partitions_info.num_dataset_partitions, 2)
self.assertEqual(report.partitions_info.num_dataset_partitions, 2)
self.assertEqual(
report.partitions_info.strategy,
multi_param.partition_selection_strategy[i_configuration])
self.assertLen(report.metric_errors, 1) # metrics for COUNT
errors = report.metric_errors[0]
self.assertEqual(errors.metric, pipeline_dp.Metrics.COUNT)
Expand Down

0 comments on commit f288b3b

Please sign in to comment.