Skip to content

Commit

Permalink
Fix code generation for comparison of field with aggregate
Browse files Browse the repository at this point in the history
Ref. #328
  • Loading branch information
treiher committed Jul 9, 2020
1 parent d0d3ba2 commit 8f408b9
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 2 deletions.
7 changes: 6 additions & 1 deletion rflx/generator/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Mod,
Name,
NamedAggregate,
Not,
NotEqual,
Number,
Or,
Expand Down Expand Up @@ -87,6 +88,7 @@ def byte_aggregate(aggregate: Aggregate) -> Aggregate:
field = Field(expression.right.name)
aggregate = byte_aggregate(expression.left)
if field and aggregate:
assert field in message.fields
if embedded:
return Equal(
Indexed(
Expand Down Expand Up @@ -118,7 +120,10 @@ def byte_aggregate(aggregate: Aggregate) -> Aggregate:
),
aggregate,
)
return Call("Equal", [Variable("Ctx"), Variable(field.affixed_name), aggregate])
equal_call = Call(
"Equal", [Variable("Ctx"), Variable(field.affixed_name), aggregate]
)
return equal_call if isinstance(expression, Equal) else Not(equal_call)

literals = [
l
Expand Down
41 changes: 40 additions & 1 deletion tests/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from pathlib import Path
from typing import Callable

import pytest

import rflx.expression as expr
from rflx.generator import Generator, common, const
from rflx.identifier import ID
from rflx.model import Model, Type
from rflx.model import BUILTIN_TYPES, Model, Type
from tests.models import (
ARRAYS_MODEL,
DERIVATION_MODEL,
Expand All @@ -15,8 +17,10 @@
NULL_MESSAGE_IN_TLV_MESSAGE_MODEL,
NULL_MODEL,
RANGE_INTEGER,
TLV_MESSAGE,
TLV_MODEL,
)
from tests.utils import assert_equal

TESTDIR = Path("generated")

Expand Down Expand Up @@ -170,6 +174,41 @@ def test_derivation_body() -> None:
assert_body(generator)


@pytest.mark.parametrize(
"left,right",
[
(expr.Variable("Value"), expr.Aggregate(expr.Number(1), expr.Number(2))),
(expr.Aggregate(expr.Number(1), expr.Number(2)), expr.Variable("Value")),
],
)
@pytest.mark.parametrize("relation", [expr.Equal, expr.NotEqual])
def test_substitution_relation_aggregate(
relation: Callable[[expr.Expr, expr.Expr], expr.Relation], left: expr.Expr, right: expr.Expr
) -> None:
equal_call = expr.Call(
"Equal",
[
expr.Variable("Ctx"),
expr.Variable("F_Value"),
expr.Aggregate(
expr.Val(expr.Variable("Types.Byte"), expr.Number(1)),
expr.Val(expr.Variable("Types.Byte"), expr.Number(2)),
),
],
)

assert_equal(
relation(left, right).substituted(common.substitution(TLV_MESSAGE)),
equal_call if relation == expr.Equal else expr.Not(equal_call),
)


def test_prefixed_type_name() -> None:
assert common.prefixed_type_name(ID("Modular"), "P") == ID("P.Modular")
for t in BUILTIN_TYPES:
assert common.prefixed_type_name(t, "P") == t


def test_base_type_name() -> None:
assert common.base_type_name(MODULAR_INTEGER) == ID("Modular")
assert common.base_type_name(RANGE_INTEGER) == ID("Range_Base")
Expand Down

0 comments on commit 8f408b9

Please sign in to comment.