Skip to content

Commit

Permalink
Add tests for conditional scrabble proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
carriepl-mila committed Apr 25, 2024
1 parent de1f774 commit a60b1fd
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gflownet/proxy/scrabble.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __call__(
output[valid_states] = -1.0 * self.scores[states[valid_states]].sum(dim=1)
return output
elif isinstance(states, list):
if not self.conditional:
if self.conditional:
raise ValueError(
"The Scrabble proxy doesn't support input states in a non-tensor "
"format when in conditional mode."
Expand Down
46 changes: 46 additions & 0 deletions tests/gflownet/proxy/test_scrabble_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def env():
return Scrabble(max_length=7, device="cpu")


@pytest.fixture
def env_conditional():
return Scrabble(max_length=7, device="cpu", conditional=True)


@pytest.mark.parametrize(
"samples, scores_expected",
[
Expand Down Expand Up @@ -95,3 +100,44 @@ def test__scrabble_scorer__returns_expected_scores_input_state2proxy(
sample_proxy = env.state2proxy()
score = proxy(sample_proxy)
assert score.tolist() == [-1.0 * score_expected]


@pytest.mark.parametrize(
"condition, sample, score_expected",
[
(
"ABCDTTT",
"C A T",
3 + 1 + 1,
),
(
"AACDOZZ",
"D O G Z",
0, # Word not in condition and not in vocabulary
),
(
"ABCDEFG",
"B I R D",
0, # Word in vocabulary but not in condition
),
(
"FGGNNN",
"G F N",
0, # Word in condition but not in vocabulary
),
(
"DEFINRS",
"F R I E N D S",
4 + 1 + 1 + 1 + 1 + 2 + 1,
),
],
)
def test__scrabble_scorer__returns_expected_scores_conditional_input_state2proxy(
env_conditional, proxy, condition, sample, score_expected
):
proxy.setup(env_conditional)
env_conditional.reset(condition=condition)
env_conditional.set_state(env_conditional.readable2state(sample))
sample_proxy = env_conditional.state2proxy()
score = proxy(sample_proxy)
assert score.tolist() == [-1.0 * score_expected]

0 comments on commit a60b1fd

Please sign in to comment.