Skip to content

Commit

Permalink
Fix generation of SPARK test code
Browse files Browse the repository at this point in the history
Ref. #993
  • Loading branch information
treiher committed Sep 23, 2022
1 parent 73070a4 commit 6324077
Showing 1 changed file with 40 additions and 11 deletions.
51 changes: 40 additions & 11 deletions tools/generate_spark_test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

"""Generate the SPARK code for the test project and all feature tests."""

from __future__ import annotations

import filecmp
import logging
import sys
from pathlib import Path

from rflx.common import unique
from rflx.generator import Generator
from rflx.integration import Integration
from rflx.model import Model, Session, Type
from rflx.specification import Parser
from tests.const import SPEC_DIR
from tests.unit.generator_test import MODELS
Expand All @@ -35,17 +38,25 @@
]


def main() -> int:
def main() -> None:
generate_spark_tests()
generate_feature_tests()


def generate_spark_tests() -> None:
remove_ada_files(OUTPUT_DIRECTORY)

parser = Parser()
parser.parse(*SPECIFICATION_FILES)
model = merge_models([parser.create_model(), *MODELS])
Generator(
"RFLX",
reproducible=True,
ignore_unsupported_checksum=True,
).generate(model, Integration(), OUTPUT_DIRECTORY)

for model in [parser.create_model(), *MODELS]:
Generator(
"RFLX",
reproducible=True,
ignore_unsupported_checksum=True,
).generate(model, Integration(), OUTPUT_DIRECTORY)

def generate_feature_tests() -> None:
generate(SHARED_DIRECTORY)
shared_files = {f.name: f for f in (SHARED_DIRECTORY / "generated").iterdir()}

Expand All @@ -56,12 +67,13 @@ def main() -> int:
f.unlink()
f.symlink_to(f"../../shared/generated/{f.name}")

return 0


def generate(feature_test: Path) -> None:
output_directory = feature_test / "generated"
output_directory.mkdir(exist_ok=True)

remove_ada_files(output_directory)

parser = Parser()
parser.parse(feature_test / "test.rflx")
Generator(
Expand All @@ -71,5 +83,22 @@ def generate(feature_test: Path) -> None:
).generate(parser.create_model(), parser.get_integration(), output_directory)


def remove_ada_files(directory: Path) -> None:
for f in directory.glob("*.ad?"):
print(f"Removing {f}")
f.unlink()


def merge_models(models: list[Model]) -> Model:
types: list[Type] = []
sessions: list[Session] = []

for m in models:
types.extend(m.types)
sessions.extend(m.sessions)

return Model(list(unique(types)), list(unique(sessions)))


if __name__ == "__main__":
sys.exit(main())
main()

0 comments on commit 6324077

Please sign in to comment.