Skip to content

Commit

Permalink
feat: add more substitution tests
Browse files Browse the repository at this point in the history
  • Loading branch information
BenTenmann committed Dec 9, 2021
1 parent 802bb3d commit c17c71c
Showing 1 changed file with 71 additions and 2 deletions.
73 changes: 71 additions & 2 deletions tests/test_substitution.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,29 @@
'incomplete_1'
]

add_tokens_single_value = [
('-', 1.),
('+', 1.),
('*', 1.)
]

add_tokens_multiple_values = [
('-', [0., 0., 1.]),
('+', [0., 0., 1.]),
('*', [0., 0., 1.])
]

add_tokens_existing = [
('hello', 1.),
('world', 1.)
]

add_tokens_multiple_values_wrong_dim = [
('-', [0., 1.]),
('+', [0., 0., 1., 0.]),
('*', [0.])
]


# ------ Fixtures ---------------------------------------------------------------------------------------------------- #
@pytest.fixture()
Expand Down Expand Up @@ -83,5 +106,51 @@ def test_substitution_matrix_from_json_error(mock_json_reader, file_name):
setriq.SubstitutionMatrix.from_json(file_name)


def test_substitution_matrix_add_token():
pass
@pytest.mark.parametrize(['token', 'value'], add_tokens_single_value)
def test_substitution_matrix_add_token_single_value(substitution_matrix_parts, token, value):
idx, scoring = substitution_matrix_parts()

sm = setriq.SubstitutionMatrix(index=idx, substitution_matrix=scoring)
n = len(sm)

lm = sm.add_token(token, value)
assert len(lm) == (n + 1)
assert lm(token, token) == value

sm.add_token(token, value, inplace=True)
assert len(sm) == (n + 1)
assert sm(token, token) == value


@pytest.mark.parametrize(['token', 'values'], add_tokens_multiple_values)
def test_substitution_matrix_add_token_multiple_values(substitution_matrix_parts, token, values):
idx, scoring = substitution_matrix_parts()

sm = setriq.SubstitutionMatrix(index=idx, substitution_matrix=scoring)
n = len(sm)

lm = sm.add_token(token, values)
assert len(lm) == (n + 1)
assert lm(token, token) == values[-1]

sm.add_token(token, values, inplace=True)
assert len(sm) == (n + 1)
assert sm(token, token) == values[-1]


@pytest.mark.parametrize(['token', 'value'], add_tokens_existing)
def test_substitution_matrix_add_token_existing(substitution_matrix_parts, token, value):
idx, scoring = substitution_matrix_parts()

sm = setriq.SubstitutionMatrix(index=idx, substitution_matrix=scoring)
with pytest.raises(ValueError):
sm.add_token(token, value)


@pytest.mark.parametrize(['token', 'values'], add_tokens_multiple_values_wrong_dim)
def test_substitution_matrix_add_token_existing(substitution_matrix_parts, token, values):
idx, scoring = substitution_matrix_parts()

sm = setriq.SubstitutionMatrix(index=idx, substitution_matrix=scoring)
with pytest.raises(ValueError):
sm.add_token(token, values)

0 comments on commit c17c71c

Please sign in to comment.