Skip to content

Commit

Permalink
Fix generation of SPARK test code
Browse files Browse the repository at this point in the history
Ref. #993
  • Loading branch information
treiher committed Sep 23, 2022
1 parent 3bcc943 commit 9aba2cb
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions tools/generate_spark_test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

"""Generate the SPARK code for the test project and all feature tests."""

from __future__ import annotations

import filecmp
import logging
import sys
from pathlib import Path

from rflx.common import unique
from rflx.generator import Generator
from rflx.integration import Integration
from rflx.model import Model, Session, Type
from rflx.specification import Parser
from tests.const import SPEC_DIR
from tests.unit.generator_test import MODELS
Expand All @@ -35,17 +38,25 @@
]


def main() -> int:
def main() -> None:
generate_spark_tests()
generate_feature_tests()


def generate_spark_tests() -> None:
remove_ada_files(OUTPUT_DIRECTORY)

parser = Parser()
parser.parse(*SPECIFICATION_FILES)
model = merge_models([parser.create_model(), *MODELS])
Generator(
"RFLX",
reproducible=True,
ignore_unsupported_checksum=True,
).generate(model, Integration(), OUTPUT_DIRECTORY)

for model in [parser.create_model(), *MODELS]:
Generator(
"RFLX",
reproducible=True,
ignore_unsupported_checksum=True,
).generate(model, Integration(), OUTPUT_DIRECTORY)

def generate_feature_tests() -> None:
generate(SHARED_DIRECTORY)
shared_files = {f.name: f for f in (SHARED_DIRECTORY / "generated").iterdir()}

Expand All @@ -56,12 +67,13 @@ def main() -> int:
f.unlink()
f.symlink_to(f"../../shared/generated/{f.name}")

return 0


def generate(feature_test: Path) -> None:
output_directory = feature_test / "generated"
output_directory.mkdir(exist_ok=True)

remove_ada_files(output_directory)

parser = Parser()
parser.parse(feature_test / "test.rflx")
Generator(
Expand All @@ -71,5 +83,22 @@ def generate(feature_test: Path) -> None:
).generate(parser.create_model(), parser.get_integration(), output_directory)


def remove_ada_files(directory: Path) -> None:
for f in directory.glob("*.ad?"):
print(f"Removing {f}")
f.unlink()


def merge_models(models: list[Model]) -> Model:
types: list[Type] = []
sessions: list[Session] = []

for m in models:
types.extend(m.types)
sessions.extend(m.sessions)

return Model(list(unique(types)), list(unique(sessions)))


if __name__ == "__main__":
sys.exit(main())
main()

0 comments on commit 9aba2cb

Please sign in to comment.