Skip to content

Commit

Permalink
Set default sample size for linear models in noise-free and serial ca…
Browse files Browse the repository at this point in the history
…se to 2 * n + 2 (#14)
  • Loading branch information
timmens committed Jan 22, 2024
1 parent b0faac1 commit a5895dc
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
10 changes: 8 additions & 2 deletions src/tranquilo/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,17 @@ def get_default_acceptance_decider(noisy):
return "noisy" if noisy else "classic"


def get_default_sample_size(model_type, x):
def get_default_sample_size(model_type, x, noisy, batch_size):
if model_type == "quadratic":
out = 2 * len(x) + 1
else:
out = len(x) + 1
# Use one point more for the standard least-squares case. Benchmarks have not
# shown an improved performance for the noisy or parallel case with one
# additional point.
if noisy or batch_size > 1:
out = len(x) + 1
else:
out = len(x) + 2

return out

Expand Down
8 changes: 6 additions & 2 deletions src/tranquilo/process_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,8 @@ def process_arguments(
sample_size=sample_size,
model_type=model_type,
x=x,
noisy=noisy,
batch_size=batch_size,
)
model_fitter = _process_model_fitter(
model_fitter, model_type=model_type, sample_size=target_sample_size, x=x
Expand Down Expand Up @@ -285,9 +287,11 @@ def _process_sample_filter(sample_filter, batch_size):
return out


def _process_sample_size(sample_size, model_type, x):
def _process_sample_size(sample_size, model_type, x, noisy, batch_size):
if sample_size is None:
out = get_default_sample_size(model_type=model_type, x=x)
out = get_default_sample_size(
model_type=model_type, x=x, noisy=noisy, batch_size=batch_size
)
elif callable(sample_size):
out = sample_size(x=x, model_type=model_type)
else:
Expand Down
35 changes: 31 additions & 4 deletions tests/test_process_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,42 @@ def test_process_batch_size_invalid():

def test_process_sample_size():
x = np.arange(3)
assert _process_sample_size(sample_size=None, model_type="linear", x=x) == 4
assert _process_sample_size(sample_size=None, model_type="quadratic", x=x) == 7
assert _process_sample_size(10, None, None) == 10
assert (
_process_sample_size(
sample_size=None, model_type="linear", x=x, noisy=True, batch_size=1
)
== 4
)
assert (
_process_sample_size(
sample_size=None, model_type="linear", x=x, noisy=False, batch_size=2
)
== 4
)
assert (
_process_sample_size(
sample_size=None, model_type="linear", x=x, noisy=False, batch_size=1
)
== 5
)
assert (
_process_sample_size(
sample_size=None, model_type="quadratic", x=x, noisy=False, batch_size=1
)
== 7
)
assert _process_sample_size(10, None, None, False, 1) == 10


def test_process_sample_size_callable():
x = np.arange(3)
sample_size = lambda x, model_type: len(x) ** 2
assert _process_sample_size(sample_size=sample_size, model_type="linear", x=x) == 9
assert (
_process_sample_size(
sample_size=sample_size, model_type="linear", x=x, noisy=False, batch_size=1
)
== 9
)


def test_process_model_type():
Expand Down

0 comments on commit a5895dc

Please sign in to comment.