Skip to content

Commit

Permalink
Fix generation of SPARK test code
Browse files Browse the repository at this point in the history
Ref. #993
  • Loading branch information
treiher committed Sep 22, 2022
1 parent 3bcc943 commit 6fa3050
Showing 1 changed file with 38 additions and 11 deletions.
49 changes: 38 additions & 11 deletions tools/generate_spark_test_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

import filecmp
import logging
import sys
from pathlib import Path

from rflx.common import unique
from rflx.generator import Generator
from rflx.integration import Integration
from rflx.model import Model, Session, Type
from rflx.specification import Parser
from tests.const import SPEC_DIR
from tests.unit.generator_test import MODELS
Expand All @@ -35,17 +36,25 @@
]


def main() -> int:
def main() -> None:
generate_spark_tests()
generate_feature_tests()


def generate_spark_tests() -> None:
remove_ada_files(OUTPUT_DIRECTORY)

parser = Parser()
parser.parse(*SPECIFICATION_FILES)
model = merge_models([parser.create_model(), *MODELS])
Generator(
"RFLX",
reproducible=True,
ignore_unsupported_checksum=True,
).generate(model, Integration(), OUTPUT_DIRECTORY)

for model in [parser.create_model(), *MODELS]:
Generator(
"RFLX",
reproducible=True,
ignore_unsupported_checksum=True,
).generate(model, Integration(), OUTPUT_DIRECTORY)

def generate_feature_tests() -> None:
generate(SHARED_DIRECTORY)
shared_files = {f.name: f for f in (SHARED_DIRECTORY / "generated").iterdir()}

Expand All @@ -56,12 +65,13 @@ def main() -> int:
f.unlink()
f.symlink_to(f"../../shared/generated/{f.name}")

return 0


def generate(feature_test: Path) -> None:
output_directory = feature_test / "generated"
output_directory.mkdir(exist_ok=True)

remove_ada_files(output_directory)

parser = Parser()
parser.parse(feature_test / "test.rflx")
Generator(
Expand All @@ -71,5 +81,22 @@ def generate(feature_test: Path) -> None:
).generate(parser.create_model(), parser.get_integration(), output_directory)


def remove_ada_files(directory: Path) -> None:
for f in directory.glob("*.ad?"):
print(f"Removing {f}")
f.unlink()


def merge_models(models: list[Model]) -> Model:
types: list[Type] = []
sessions: list[Session] = []

for m in models:
types.extend(m.types)
sessions.extend(m.sessions)

return Model(list(unique(types)), list(unique(sessions)))


if __name__ == "__main__":
sys.exit(main())
main()

0 comments on commit 6fa3050

Please sign in to comment.