Skip to content

Commit

Permalink
Merge pull request #319 from alexhernandezgarcia/evaluator-fix
Browse files Browse the repository at this point in the history
Bug fix of Evaluator (Alex's fault)
  • Loading branch information
alexhernandezgarcia committed Jun 6, 2024
2 parents 5d4b45e + 6da67bd commit c972237
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gflownet/evaluator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
values are the figures.
"""

fig_kde_pred = fig_kde_true = fig_reward_samples = None
fig_kde_pred = fig_kde_true = fig_reward_samples = fig_samples_topk = None

if hasattr(self.gfn.env, "plot_reward_samples") and x_sampled is not None:
(sample_space_batch, rewards_sample_space) = (
Expand Down
2 changes: 2 additions & 0 deletions tests/gflownet/envs/test_tetris.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def setup(self, env):
self.env = env
self.repeats = {
"test__reset__state_is_source": 10,
"test__gflownet_minimal_runs": 0,
}
self.n_states = {} # TODO: Populate.

Expand All @@ -559,6 +560,7 @@ def setup(self, env_full):
self.env = env_full
self.repeats = {
"test__reset__state_is_source": 10,
"test__gflownet_minimal_runs": 0,
}
self.n_states = {} # TODO: Populate.

Expand Down
3 changes: 3 additions & 0 deletions tests/gflownet/evaluator/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ def test__eval(gflownet_for_tests, parameterization):
elif parameterization == "ctorus":
for figname, fig in figs.items():
assert isinstance(figname, str)
# plot_samples_topk not implemented in ctorus
if figname == "Samples TopK":
continue
assert isinstance(fig, plt.Figure)
else:
raise ValueError(f"Unknown parameterization: {parameterization}")

0 comments on commit c972237

Please sign in to comment.