Skip to content

Commit

Permalink
Changes in Branin proxy and test
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed May 29, 2024
1 parent d9f7353 commit 553d377
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 7 deletions.
25 changes: 18 additions & 7 deletions gflownet/proxy/box/branin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torchtyping import TensorType

from gflownet.proxy.base import Proxy
from gflownet.utils.common import tfloat

X1_DOMAIN = [-5, 10]
X1_LENGTH = X1_DOMAIN[1] - X1_DOMAIN[0]
Expand All @@ -31,6 +32,7 @@ def __init__(
self,
fidelity=1.0,
do_domain_map=True,
shift_to_negative=False,
reward_function="product",
rewareward_function_kwargs={"beta": -1},
**kwargs
Expand All @@ -51,7 +53,15 @@ def __init__(
super().__init__(**kwargs)
self.fidelity = fidelity
self.do_domain_map = do_domain_map
self.shift_to_negative = shift_to_negative
self.function_mf_botorch = AugmentedBranin(negate=False)
# Constants
self.domain_left = tfloat(
[[X1_DOMAIN[0], X2_DOMAIN[0]]], float_type=self.float, device=self.device
)
self.domain_length = tfloat(
[[X1_LENGTH, X2_LENGTH]], float_type=self.float, device=self.device
)
# Modes and extremum compatible with 100x100 grid
self.modes = [
[12.4, 81.833],
Expand All @@ -69,7 +79,7 @@ def __call__(self, states: TensorType["batch", "2"]) -> TensorType["batch"]:
"""
)
if self.do_domain_map:
states = Branin.map_to_standard_domain(states)
states = self.map_to_standard_domain(states)
# Append fidelity as a new dimension of states
states = torch.cat(
[
Expand All @@ -81,7 +91,10 @@ def __call__(self, states: TensorType["batch", "2"]) -> TensorType["batch"]:
],
dim=1,
)
return Branin.map_to_negative_range(self.function_mf_botorch(states))
if self.shift_to_negative:
return Branin.map_to_negative_range(self.function_mf_botorch(states))
else:
return self.function_mf_botorch(states)

@property
def min(self):
Expand All @@ -91,18 +104,16 @@ def min(self):
)
return self._min

@staticmethod
def map_to_standard_domain(
states: TensorType["batch", "2"]
self,
states: TensorType["batch", "2"],
) -> TensorType["batch", "2"]:
"""
Maps a batch of input states onto the domain typically used to evaluate the
Branin function. See X1_DOMAIN and X2_DOMAIN. It assumes that the inputs are on
[-1, 1] x [-1, 1].
"""
states[:, 0] = X1_DOMAIN[0] + ((states[:, 0] + 1.0) * X1_LENGTH) / 2.0
states[:, 1] = X2_DOMAIN[0] + ((states[:, 1] + 1.0) * X2_LENGTH) / 2.0
return states
return self.domain_left + ((states + 1.0) * self.domain_length) / 2.0

@staticmethod
def map_to_negative_range(values: TensorType["batch"]) -> TensorType["batch"]:
Expand Down
24 changes: 24 additions & 0 deletions tests/gflownet/proxy/test_branin.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,27 @@ def test__map_to_standard_domain__returns_expected(
assert torch.allclose(
proxy.map_to_standard_domain(samples), samples_standard_domain
)


@pytest.mark.parametrize(
"samples, proxy_expected",
[
(
[
[-1.0, -1.0],
[-1.0, 1.0],
[1.0, -1.0],
[1.0, 1.0],
[0.0, -1.0],
[0.0, 1.0],
[0.0, 0.0],
],
[308.1291, 17.5083, 10.9609, 145.8722, 10.3079, 150.4520, 24.1300],
),
],
)
def test__proxy__returns_expected(proxy_default, samples, proxy_expected):
proxy = proxy_default
samples = tfloat(samples, float_type=proxy.float, device=proxy.device)
proxy_expected = tfloat(proxy_expected, float_type=proxy.float, device=proxy.device)
assert torch.allclose(proxy(samples), proxy_expected)

0 comments on commit 553d377

Please sign in to comment.