Skip to content

Commit

Permalink
Fix substitution of relations
Browse files Browse the repository at this point in the history
Ref. #320
  • Loading branch information
treiher committed Jul 8, 2020
1 parent a6c6ca6 commit e2df8a5
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
19 changes: 2 additions & 17 deletions rflx/generator/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,6 @@ def byte_aggregate(aggregate: Aggregate) -> Aggregate:
)
return equal_call if isinstance(expression, Equal) else Not(equal_call)

literals = [
l
for t in message.types.values()
if isinstance(t, Enumeration)
for l in t.literals.keys()
]

def field_value(field: Field) -> Expr:
if public:
return Call(f"Get_{field.name}", [Variable("Ctx")])
Expand All @@ -147,23 +140,15 @@ def field_value(field: Field) -> Expr:
if (
isinstance(expression.left, Variable)
and Field(expression.left.name) in message.fields
and (
isinstance(expression.right, Number)
or (
isinstance(expression.right, Variable) and expression.right.name in literals
)
)
and isinstance(expression.right, Number)
):
return expression.__class__(
field_value(Field(expression.left.name)), expression.right
)
if (
isinstance(expression.right, Variable)
and Field(expression.right.name) in message.fields
and (
isinstance(expression.left, Number)
or (isinstance(expression.left, Variable) and expression.left.name in literals)
)
and isinstance(expression.left, Number)
):
return expression.__class__(
expression.left, field_value(Field(expression.right.name))
Expand Down
28 changes: 27 additions & 1 deletion tests/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Callable
from typing import Callable, Tuple

import pytest

Expand Down Expand Up @@ -203,6 +203,32 @@ def test_substitution_relation_aggregate(
)


@pytest.mark.parametrize(
"expressions,expected",
[
(
(expr.Variable("Length"), expr.Number(1)),
(expr.Call("Get_Length", [expr.Variable("Ctx")]), expr.Number(1)),
),
(
(expr.Number(1), expr.Variable("Length")),
(expr.Number(1), expr.Call("Get_Length", [expr.Variable("Ctx")])),
),
((expr.Number(1), expr.Variable("Unknown")), (expr.Number(1), expr.Variable("Unknown"))),
],
)
@pytest.mark.parametrize("relation", [expr.Less, expr.Equal, expr.Greater, expr.NotEqual])
def test_substitution_relation_scalar(
relation: Callable[[expr.Expr, expr.Expr], expr.Relation],
expressions: Tuple[expr.Expr, expr.Expr],
expected: Tuple[expr.Expr, expr.Expr],
) -> None:
assert_equal(
relation(*expressions).substituted(common.substitution(TLV_MESSAGE, public=True)),
relation(*expected),
)


def test_prefixed_type_name() -> None:
assert common.prefixed_type_name(ID("Modular"), "P") == ID("P.Modular")
for t in BUILTIN_TYPES:
Expand Down

0 comments on commit e2df8a5

Please sign in to comment.