Skip to content

Commit

Permalink
Choose sample filter depedent on batch size (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
timmens committed Jan 21, 2024
1 parent 055d7a2 commit b0faac1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
4 changes: 4 additions & 0 deletions src/tranquilo/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ def get_default_stagnation_options(noisy, batch_size):
return out


def get_default_sample_filter(batch_size):
return "drop_excess" if batch_size > 1 else "keep_all"


def get_default_radius_options(x):
return RadiusOptions(initial_radius=0.1 * np.max(np.abs(x)))

Expand Down
13 changes: 12 additions & 1 deletion src/tranquilo/process_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_default_sample_size,
get_default_search_radius_factor,
get_default_stagnation_options,
get_default_sample_filter,
update_option_bundle,
NoiseAdaptationOptions,
)
Expand Down Expand Up @@ -73,7 +74,7 @@ def process_arguments(
# component names and related options
sampler="optimal_hull",
sampler_options=None,
sample_filter="keep_all",
sample_filter=None,
sample_filter_options=None,
model_fitter=None,
model_fitter_options=None,
Expand Down Expand Up @@ -156,6 +157,7 @@ def process_arguments(
acceptance_decider = _process_acceptance_decider(acceptance_decider, noisy)

# process options that depend on arguments with dependent defaults
sample_filter = _process_sample_filter(sample_filter, batch_size)
stagnation_options = update_option_bundle(
get_default_stagnation_options(noisy, batch_size=batch_size), stagnation_options
)
Expand Down Expand Up @@ -274,6 +276,15 @@ def _process_batch_size(batch_size, n_cores):
return int(batch_size)


def _process_sample_filter(sample_filter, batch_size):
if sample_filter is None:
out = get_default_sample_filter(batch_size)
else:
out = sample_filter

return out


def _process_sample_size(sample_size, model_type, x):
if sample_size is None:
out = get_default_sample_size(model_type=model_type, x=x)
Expand Down

0 comments on commit b0faac1

Please sign in to comment.