Adapt Hartmann and implement tests
alexhernandezgarcia committed May 29, 2024
1 parent a1d8cd0 commit ad4229c
Showing 2 changed files with 241 additions and 23 deletions.
86 changes: 63 additions & 23 deletions gflownet/proxy/box/hartmann.py
@@ -7,35 +7,72 @@
This code is based on the implementation by Nikita Saxena (nikita-0209) in
https://github.com/alexhernandezgarcia/activelearning
The implementation assumes that the inputs will be on [0, 1]^6 as is typical in the
uses of the Hartmann function. The original range is negative, which is the convention
for other proxy classes, and negate=False is used in the call to the BoTorch method in
order to keep the range.
The implementation assumes that the inputs will be on [-1, 1]^6, which are by default
re-mapped to the standard domain [0, 1]^6 of the Hartmann function. The original
function takes negative values and poses a minimisation problem. By default, the
proxy values remain in this (negative) range and the absolute value of the proxy
values is used as the reward function.
"""

from typing import Callable, Optional, Union

import numpy as np
import torch
from botorch.test_functions.multi_fidelity import AugmentedHartmann
from torchtyping import TensorType

from gflownet.proxy.base import Proxy

# Global optimum, according to BoTorch
OPTIMUM = -3.32237
# A rough estimate of modes
MODES = [
[0.2, 0.2, 0.5, 0.3, 0.3, 0.7],
[0.4, 0.9, 0.9, 0.6, 0.1, 0.0],
[0.3, 0.1, 0.4, 0.3, 0.3, 0.7],
[0.4, 0.9, 0.4, 0.6, 0.0, 0.0],
[0.4, 0.9, 0.6, 0.6, 0.3, 0.0],
]


class Hartmann(Proxy):
def __init__(self, fidelity=1.0, **kwargs):
def __init__(
self,
fidelity=1.0,
do_domain_map: bool = True,
negate: bool = False,
reward_function: Optional[Union[Callable, str]] = "absolute",
**kwargs
):
"""
Parameters
----------
fidelity : float
Fidelity of the Hartmann oracle. 1.0 corresponds to the original Hartmann.
Smaller values (down to 0.0) reduce the fidelity of the oracle.
do_domain_map : bool
If True, the states are assumed to be in [-1, 1]^6 and are re-mapped to the
standard domain [0, 1]^6 before calling the botorch method. If False,
the botorch method is called directly on the state values.
negate : bool
If True, proxy values are multiplied by -1.
reward_function : str or Callable
The transformation applied to the proxy outputs to obtain a GFlowNet
reward. By default, the reward function is the absolute value of proxy
outputs.
See: https://botorch.org/api/test_functions.html
"""
# Replace the value of reward_function in kwargs by the one passed explicitly
# as a parameter
kwargs["reward_function"] = reward_function
super().__init__(**kwargs)
self.fidelity = fidelity
self.function_mf_botorch = AugmentedHartmann(negate=False)
# This is just a rough estimate of modes
self.modes = [
[0.2, 0.2, 0.5, 0.3, 0.3, 0.7],
[0.4, 0.9, 0.9, 0.6, 0.1, 0.0],
[0.3, 0.1, 0.4, 0.3, 0.3, 0.7],
[0.4, 0.9, 0.4, 0.6, 0.0, 0.0],
[0.4, 0.9, 0.6, 0.6, 0.3, 0.0],
]
# Global optimum, according to BoTorch
self.extremum = -3.32237
self.do_domain_map = do_domain_map
self.function_mf_botorch = AugmentedHartmann(negate=negate)
# Optimum
self._optimum = torch.tensor(OPTIMUM, device=self.device, dtype=self.float)
if negate:
self._optimum *= -1.0

def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]:
if states.shape[1] != 6:
@@ -58,10 +95,13 @@ def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]:
)
return self.function_mf_botorch(states)

@property
def min(self):
if not hasattr(self, "_min"):
self._min = torch.tensor(
self.extremum, device=self.device, dtype=self.float
)
return self._min
def map_to_standard_domain(
self,
states: TensorType["batch", "6"],
) -> TensorType["batch", "6"]:
"""
Maps a batch of input states onto the domain typically used to evaluate the
Hartmann function, that is [0, 1]^6. See DOMAIN and LENGTH. It assumes that the
inputs are on [-1, 1]^6.
"""
return (states + 1.0) / 2.0
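
Taken together, the new options can be exercised as in the minimal sketch below. This is a usage sketch, not part of the commit: it assumes the constructor keyword arguments (`device`, `float_precision`) and the `rewards` method behave as exercised by the tests that follow, and it reuses samples and expected values from those tests.

```python
import torch

from gflownet.proxy.box.hartmann import Hartmann

# Default proxy: outputs are negative (minimisation problem) and rewards are,
# by default, the absolute values of the proxy outputs.
proxy = Hartmann(device="cpu", float_precision=32)

# Two samples on [-1, 1]^6, reused from the tests below; the second scores
# close to the global optimum of -3.32237.
states = torch.tensor(
    [
        [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
        [0.2, 0.2, 0.5, 0.3, 0.3, 0.7],
    ]
)

scores = proxy(states)  # approx. [-5.4972e-35, -3.2216], per the tests
rewards = proxy.rewards(states)  # approx. [5.4972e-35, 3.2216]
```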
178 changes: 178 additions & 0 deletions tests/gflownet/proxy/test_hartmann.py
@@ -0,0 +1,178 @@
import pytest
import torch

from gflownet.envs.cube import ContinuousCube
from gflownet.envs.grid import Grid
from gflownet.proxy.box.hartmann import Hartmann
from gflownet.utils.common import tfloat


@pytest.fixture()
def proxy_default():
return Hartmann(device="cpu", float_precision=32)


@pytest.fixture()
def proxy_negate_exp_reward():
return Hartmann(
negate=True,
reward_function="exponential",
reward_function_kwargs={"beta": 1.0},
device="cpu",
float_precision=32,
)


@pytest.fixture()
def proxy_fid01_exp_reward():
return Hartmann(
fidelity=0.1,
reward_function="exponential",
reward_function_kwargs={"beta": -1.0},
device="cpu",
float_precision=32,
)


@pytest.fixture
def grid():
return Grid(n_dim=6, length=10, device="cpu")


@pytest.fixture
def cube():
return ContinuousCube(n_dim=6, n_comp=3, min_incr=0.1)


@pytest.mark.parametrize(
"samples, samples_standard_domain",
[
(
[
[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[-1.0, -0.5, 0.5, -0.2, 0.2, 1.0],
],
[
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[0.0, 0.25, 0.75, 0.4, 0.6, 1.0],
],
),
],
)
def test__map_to_standard_domain__returns_expected(
proxy_default, samples, samples_standard_domain
):
proxy = proxy_default
samples = tfloat(samples, float_type=proxy.float, device=proxy.device)
samples_standard_domain = tfloat(
samples_standard_domain, float_type=proxy.float, device=proxy.device
)
assert torch.allclose(
proxy.map_to_standard_domain(samples), samples_standard_domain
)


@pytest.mark.parametrize(
"proxy, samples, proxy_expected",
[
(
"proxy_default",
[
[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[-1.0, -0.5, 0.5, -0.2, 0.2, 1.0],
[0.2, 0.2, 0.5, 0.3, 0.3, 0.7],
],
[-5.4972e-35, -3.4085e-05, -2.5341e-04, -3.2216],
),
(
"proxy_negate_exp_reward",
[
[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[-1.0, -0.5, 0.5, -0.2, 0.2, 1.0],
[0.2, 0.2, 0.5, 0.3, 0.3, 0.7],
],
[5.4972e-35, 3.4085e-05, 2.5341e-04, 3.2216],
),
(
"proxy_fid01_exp_reward",
[
[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[-1.0, -0.5, 0.5, -0.2, 0.2, 1.0],
[0.2, 0.2, 0.5, 0.3, 0.3, 0.7],
],
[-5.4971e-35, -3.4084e-05, -2.5340e-04, -3.1874],
),
],
)
def test__proxy__returns_expected(proxy, samples, proxy_expected, request):
proxy = request.getfixturevalue(proxy)
samples = tfloat(samples, float_type=proxy.float, device=proxy.device)
proxy_expected = tfloat(proxy_expected, float_type=proxy.float, device=proxy.device)
assert torch.allclose(proxy(samples), proxy_expected, atol=1e-04)


@pytest.mark.parametrize(
"proxy, samples, rewards_expected",
[
(
"proxy_default",
[
[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[-1.0, -0.5, 0.5, -0.2, 0.2, 1.0],
[0.2, 0.2, 0.5, 0.3, 0.3, 0.7],
],
[5.4972e-35, 3.4085e-05, 2.5341e-04, 3.2216],
),
(
"proxy_negate_exp_reward",
[
[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[-1.0, -0.5, 0.5, -0.2, 0.2, 1.0],
[0.2, 0.2, 0.5, 0.3, 0.3, 0.7],
],
[1.0000, 1.0000, 1.0003, 25.0672],
),
(
"proxy_fid01_exp_reward",
[
[-1.0, -1.0, -1.0, -1.0, -1.0, -1.0],
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
[-1.0, -0.5, 0.5, -0.2, 0.2, 1.0],
[0.2, 0.2, 0.5, 0.3, 0.3, 0.7],
],
[1.0000, 1.0000, 1.0003, 24.2251],
),
],
)
def test__rewards__returns_expected(proxy, samples, rewards_expected, request):
proxy = request.getfixturevalue(proxy)
samples = tfloat(samples, float_type=proxy.float, device=proxy.device)
rewards_expected = tfloat(
rewards_expected, float_type=proxy.float, device=proxy.device
)
assert torch.allclose(proxy.rewards(samples), rewards_expected, atol=1e-04)
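
The expected rewards above are consistent with an exponential transformation of the form `reward = exp(beta * proxy)`, with `beta` as passed in `reward_function_kwargs`: exp(1.0 * 3.2216) ≈ 25.0672 for the negated proxy, and exp(-1.0 * (-3.1874)) ≈ 24.2251 at fidelity 0.1. A quick sanity check, assuming that form (the actual transformation lives in the `Proxy` base class, not in this diff):

```python
import torch

# Assumed form of the exponential reward: reward = exp(beta * proxy)
proxy_values = torch.tensor([3.2216, -3.1874])  # negated / fidelity-0.1 cases
betas = torch.tensor([1.0, -1.0])
print(torch.exp(betas * proxy_values))  # approx. tensor([25.0672, 24.2253])
```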


@pytest.mark.parametrize(
"proxy, max_reward_expected",
[
(
"proxy_default",
3.32237,
),
(
"proxy_negate_exp_reward",
27.7260,
),
],
)
def test__get_max_reward__returns_expected(proxy, max_reward_expected, request):
proxy = request.getfixturevalue(proxy)
assert torch.isclose(proxy.get_max_reward(), torch.tensor(max_reward_expected))
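
Likewise, the expected maximum reward of `proxy_negate_exp_reward` matches exponentiating the negated global optimum, again assuming `reward = exp(beta * proxy)`:

```python
import math

# exp(beta * optimum) with beta = 1.0 and the negated optimum 3.32237
print(math.exp(3.32237))  # approx. 27.726, the expected max reward
```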
