Update proxies: outputs are not negative by default anymore; Restore proxy2reward and beta values in configs but as proxy config
alexhernandezgarcia committed May 7, 2024
1 parent 7f9e136 commit 4267acb
Showing 15 changed files with 54 additions and 16 deletions.
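
The experiment configs below restore an explicit proxy block carrying a reward_function identifier and a beta value. As a reading aid, here is a minimal sketch of the exponential proxy-to-reward mapping those blocks select, under the assumption (not stated in this diff) that the identifier computes exp(beta * proxy); the actual implementation lives in gflownet/proxy/base.py.

import torch

# Sketch only: assumes reward_function "exponential" means reward = exp(beta * proxy).
# With proxy outputs no longer negated by default, beta > 0 makes higher
# proxy scores map to exponentially higher rewards.
def exponential_proxy2reward(proxy: torch.Tensor, beta: float) -> torch.Tensor:
    return torch.exp(beta * proxy)

proxy_scores = torch.tensor([0.0, 0.5, 1.0])
print(exponential_proxy2reward(proxy_scores, beta=8.0))  # tensor([1.0000e+00, 5.4598e+01, 2.9810e+03])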
5 changes: 5 additions & 0 deletions config/experiments/crystals/albatross.yaml
@@ -50,6 +50,11 @@ env:
 output_csv: ccrystal_val.csv
 output_pkl: ccrystal_val.pkl

+# Proxy
+proxy:
+  reward_function: exponential
+  beta: 8
+
 # GFlowNet hyperparameters
 gflownet:
   random_action_prob: 0.1
5 changes: 5 additions & 0 deletions config/experiments/crystals/albatross_sg_first.yaml
@@ -52,6 +52,11 @@ env:
 output_csv: ccrystal_val.csv
 output_pkl: ccrystal_val.pkl

+# Proxy
+proxy:
+  reward_function: exponential
+  beta: 8
+
 # GFlowNet hyperparameters
 gflownet:
   random_action_prob: 0.1
5 changes: 5 additions & 0 deletions config/experiments/crystals/lattice_parameters.yaml
@@ -12,6 +12,11 @@ env:
 buffer:
   replay_capacity: 1000

+# Proxy
+proxy:
+  reward_function: exponential
+  beta: 0.3
+
 # GFlowNet hyperparameters
 gflownet:
   random_action_prob: 0.1
5 changes: 5 additions & 0 deletions config/experiments/crystals/pigeon.yaml
@@ -54,6 +54,11 @@ env:
 output_csv: ccrystal_val.csv
 output_pkl: ccrystal_val.pkl

+# Proxy
+proxy:
+  reward_function: exponential
+  beta: 8
+
 # GFlowNet hyperparameters
 gflownet:
   random_action_prob: 0.1
5 changes: 5 additions & 0 deletions config/experiments/neurips23/crystal-comp-sg-lp.yaml
@@ -17,6 +17,11 @@ env:
 composition_kwargs:
   elements: 89

+# Proxy
+proxy:
+  reward_function: exponential
+  beta: 1
+
 # GFlowNet hyperparameters
 gflownet:
   random_action_prob: 0.1
6 changes: 5 additions & 1 deletion config/experiments/simple_tetris.yaml
@@ -20,6 +20,10 @@ env:
 output_pkl: simple_tetris_val.pkl
 n: 100

+proxy:
+  reward_function: exponential
+  beta: 10
+
 gflownet:
   random_action_prob: 0.3
   optimizer:
@@ -42,4 +46,4 @@ device: cpu
 logger:
   do:
     online: True
-  project_name: simple_tetris
\ No newline at end of file
+  project_name: simple_tetris
5 changes: 5 additions & 0 deletions config/experiments/tree.yaml
@@ -19,6 +19,11 @@ env:
 buffer:
   replay_capacity: 100

+# Proxy
+proxy:
+  reward_function: exponential
+  beta: 32
+
 # GFlowNet hyperparameters
 gflownet:
   random_action_prob: 0.1
5 changes: 5 additions & 0 deletions config/experiments/workshop23/discrete-matbench.yaml
@@ -21,6 +21,11 @@ env:
 buffer:
   replay_capacity: 0

+# Proxy
+proxy:
+  reward_function: exponential
+  beta: 1
+
 # GFlowNet hyperparameters
 gflownet:
   random_action_prob: 0.1
6 changes: 3 additions & 3 deletions config/proxy/base.yaml
@@ -1,14 +1,14 @@
 _target_: gflownet.proxy.base.Proxy

 # Reward function: string identifier of the proxy-to-reward function:
-# - identity
-# - absolute (default)
+# - identity (default)
+# - absolute
 # - power
 # - exponential
 # - shift
 # - product
 # Alternatively, it can be a callable of the function itself.
-reward_function: absolute
+reward_function: identity
 # A callable of the proxy-to-logreward function.
 # None by default, which takes the log of the proxy-to-reward function
 logreward_function: null
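
For illustration, a hypothetical dispatch for the string identifiers documented above (the per-function parameters such as beta are assumptions; the real resolution logic lives in gflownet.proxy.base.Proxy):

from typing import Callable, Union

import torch

# Hypothetical identifier-to-callable table; the names follow the comment block
# in config/proxy/base.yaml, but the extra parameters are invented for this sketch.
REWARD_FUNCTIONS = {
    "identity": lambda proxy: proxy,
    "absolute": lambda proxy: torch.abs(proxy),
    "power": lambda proxy, alpha=1.0: proxy**alpha,
    "exponential": lambda proxy, beta=1.0: torch.exp(beta * proxy),
    "shift": lambda proxy, delta=0.0: proxy + delta,
    "product": lambda proxy, factor=1.0: proxy * factor,
}

def resolve_reward_function(spec: Union[str, Callable]) -> Callable:
    # Per the config comment, a callable may also be passed directly.
    return spec if callable(spec) else REWARD_FUNCTIONS[spec]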
2 changes: 1 addition & 1 deletion gflownet/proxy/base.py
@@ -24,7 +24,7 @@ def __init__(
         self,
         device,
         float_precision,
-        reward_function: Optional[Union[Callable, str]] = "absolute",
+        reward_function: Optional[Union[Callable, str]] = "identity",
         logreward_function: Optional[Callable] = None,
         reward_function_kwargs: Optional[dict] = {},
         reward_min: float = 0.0,
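
What the new default means in practice, as a toy sketch (ToyProxy is invented; only the signature above comes from the diff): with reward_function="identity", a proxy's now non-negative outputs are usable as rewards directly, where the old "absolute" default flipped the sign of negative outputs.

import torch

# Invented toy proxy illustrating the new convention: __call__ returns
# non-negative scores, so the "identity" default passes them through unchanged.
class ToyProxy:
    def __call__(self, states: torch.Tensor) -> torch.Tensor:
        return states.sum(dim=1).clamp(min=0.0)

proxy = ToyProxy()
states = torch.tensor([[1.0, 2.0], [0.5, 0.25]])
print(proxy(states))  # tensor([3.0000, 0.7500]) -- valid rewards as-is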
3 changes: 1 addition & 2 deletions gflownet/proxy/corners.py
@@ -41,8 +41,7 @@ def min(self):

     def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]:
         return (
-            -1.0
-            * self.mulnormal_norm
+            self.mulnormal_norm
             * torch.exp(
                 -0.5
                 * (
2 changes: 1 addition & 1 deletion gflownet/proxy/scrabble.py
@@ -93,7 +93,7 @@ def __call__(
                ):
                    scores.append(0.0)
                else:
-                   scores.append(-1.0 * self._sum_scores(sample))
+                   scores.append(self._sum_scores(sample))
            return tfloat(scores, device=self.device, float_type=self.float)
        else:
            raise NotImplementedError(
4 changes: 2 additions & 2 deletions gflownet/proxy/tetris.py
@@ -17,9 +17,9 @@ def setup(self, env=None):
     @property
     def norm(self):
         if self.normalize:
-            return -(self.height * self.width)
+            return (self.height * self.width)
         else:
-            return -1.0
+            return 1.0

     def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]:
         if states.dim() == 2:
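
With the sign flip, the Tetris normalization constant is positive, so a normalized score lands in [0, 1] instead of [-1, 0]. A sketch with illustrative numbers (the raw score is hypothetical; only the norm expression comes from the diff):

# Illustrative values only: with normalize=True the proxy divides its raw
# score by norm, which is now +(height * width) rather than its negation.
height, width = 20, 10
norm = height * width  # 200; before this commit, -200
raw_score = 150.0  # hypothetical raw proxy score for a partly filled board
print(raw_score / norm)  # 0.75, now in [0, 1]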
8 changes: 4 additions & 4 deletions gflownet/proxy/torus.py
@@ -19,10 +19,10 @@ def setup(self, env=None):
     def min(self):
         if not hasattr(self, "_min"):
             if self.normalize:
-                self._min = torch.tensor(-1.0, device=self.device, dtype=self.float)
+                self._min = torch.tensor(0.0, device=self.device, dtype=self.float)
             else:
                 self._min = torch.tensor(
-                    -((self.n_dim * 2) ** 3), device=self.device, dtype=self.float
+                    ((self.n_dim * 2) ** 3), device=self.device, dtype=self.float
                 )
         return self._min

@@ -31,10 +31,10 @@ def norm(self):
         if not hasattr(self, "_norm"):
             if self.normalize:
                 self._norm = torch.tensor(
-                    -((self.n_dim * 2) ** 3), device=self.device, dtype=self.float
+                    ((self.n_dim * 2) ** 3), device=self.device, dtype=self.float
                 )
             else:
-                self._norm = torch.tensor(-1.0, device=self.device, dtype=self.float)
+                self._norm = torch.tensor(1.0, device=self.device, dtype=self.float)
         return self._norm

     def __call__(self, states: TensorType["batch", "state_dim"]) -> TensorType["batch"]:
4 changes: 2 additions & 2 deletions gflownet/proxy/uniform.py
@@ -13,10 +13,10 @@ def __init__(self, **kwargs):
     def __call__(
         self, states: Union[List, TensorType["batch", "state_dim"]]
     ) -> TensorType["batch"]:
-        return -1.0 * torch.ones(len(states), device=self.device, dtype=self.float)
+        return torch.ones(len(states), device=self.device, dtype=self.float)

     @property
     def min(self):
         if not hasattr(self, "_min"):
-            self._min = torch.tensor(-1.0, device=self.device, dtype=self.float)
+            self._min = torch.tensor(1.0, device=self.device, dtype=self.float)
         return self._min
