diff --git a/.github/codecov.yml b/.github/codecov.yml index ab58d8b71c..eb64341e32 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -29,6 +29,11 @@ coverage: threshold: 0.5% flags: - tests + railjson_generator: + target: auto + threshold: 0.5% + flags: + - railjson_generator patch: default: off @@ -48,3 +53,6 @@ flags: tests: paths: - tests/ + railjson_generator: + paths: + - python/railjson_generator/ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 847c92c485..da291b0b81 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -159,6 +159,22 @@ jobs: run: | cd python/railjson_generator poetry run pytype -j auto + + - name: Pytest + run: | + cd python/railjson_generator + poetry run pytest --cov --cov-report=xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + name: codecov + flags: railjson_generator + directory: ./python/railjson_generator + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false + verbose: true + files: coverage.xml check_integration_tests: runs-on: ubuntu-latest diff --git a/python/railjson_generator/README.md b/python/railjson_generator/README.md index 82c86d7cf4..efd1b95280 100644 --- a/python/railjson_generator/README.md +++ b/python/railjson_generator/README.md @@ -1,5 +1,11 @@ # RAILJSON GENERATOR +Use poetry to install dependencies: + +```sh +poetry install +``` + ## Running generation scripts To run a generation script, pass its output directory as its first argument: @@ -15,7 +21,7 @@ This library provides an helper to generate multiple infrastructures at once: poetry run python3 -m railjson_generator /tmp/all_infras scripts/*.py ``` -## How to use +## API ### Infra Builder @@ -82,3 +88,25 @@ Route can either be manually created, or generated using `generate_routes`, and ## Example You can find a complete example [here](./railjson_generator/scripts/examples/example_script.py). + +## Testing + +```sh +poetry run pytest +``` + +## Linting + +Use pflake8 and pytype to check for style issues and potential errors. + +```sh +$ poetry run pflake8 --config ./pyproject.toml +$ poetry run pytype -j auto +``` + +Use black and isort to fix formatting. + +```sh +$ poetry run black . +$ poetry run isort . +``` diff --git a/python/railjson_generator/poetry.lock b/python/railjson_generator/poetry.lock index 5647fb0ca4..44ee65efd0 100644 --- a/python/railjson_generator/poetry.lock +++ b/python/railjson_generator/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "annotated-types" @@ -89,6 +89,87 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coverage" +version = "7.3.4" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "coverage-7.3.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:aff2bd3d585969cc4486bfc69655e862028b689404563e6b549e6a8244f226df"}, + {file = "coverage-7.3.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e4353923f38d752ecfbd3f1f20bf7a3546993ae5ecd7c07fd2f25d40b4e54571"}, + {file = "coverage-7.3.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea473c37872f0159294f7073f3fa72f68b03a129799f3533b2bb44d5e9fa4f82"}, + {file = "coverage-7.3.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5214362abf26e254d749fc0c18af4c57b532a4bfde1a057565616dd3b8d7cc94"}, + {file = "coverage-7.3.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f99b7d3f7a7adfa3d11e3a48d1a91bb65739555dd6a0d3fa68aa5852d962e5b1"}, + {file = "coverage-7.3.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:74397a1263275bea9d736572d4cf338efaade2de9ff759f9c26bcdceb383bb49"}, + {file = "coverage-7.3.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:f154bd866318185ef5865ace5be3ac047b6d1cc0aeecf53bf83fe846f4384d5d"}, + {file = "coverage-7.3.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e0d84099ea7cba9ff467f9c6f747e3fc3906e2aadac1ce7b41add72e8d0a3712"}, + {file = "coverage-7.3.4-cp310-cp310-win32.whl", hash = "sha256:3f477fb8a56e0c603587b8278d9dbd32e54bcc2922d62405f65574bd76eba78a"}, + {file = "coverage-7.3.4-cp310-cp310-win_amd64.whl", hash = "sha256:c75738ce13d257efbb6633a049fb2ed8e87e2e6c2e906c52d1093a4d08d67c6b"}, + {file = "coverage-7.3.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:997aa14b3e014339d8101b9886063c5d06238848905d9ad6c6eabe533440a9a7"}, + {file = "coverage-7.3.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8a9c5bc5db3eb4cd55ecb8397d8e9b70247904f8eca718cc53c12dcc98e59fc8"}, + {file = "coverage-7.3.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:27ee94f088397d1feea3cb524e4313ff0410ead7d968029ecc4bc5a7e1d34fbf"}, + {file = "coverage-7.3.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8ce03e25e18dd9bf44723e83bc202114817f3367789052dc9e5b5c79f40cf59d"}, + {file = "coverage-7.3.4-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85072e99474d894e5df582faec04abe137b28972d5e466999bc64fc37f564a03"}, + {file = "coverage-7.3.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a877810ef918d0d345b783fc569608804f3ed2507bf32f14f652e4eaf5d8f8d0"}, + {file = "coverage-7.3.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:9ac17b94ab4ca66cf803f2b22d47e392f0977f9da838bf71d1f0db6c32893cb9"}, + {file = "coverage-7.3.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:36d75ef2acab74dc948d0b537ef021306796da551e8ac8b467810911000af66a"}, + {file = "coverage-7.3.4-cp311-cp311-win32.whl", hash = "sha256:47ee56c2cd445ea35a8cc3ad5c8134cb9bece3a5cb50bb8265514208d0a65928"}, + {file = "coverage-7.3.4-cp311-cp311-win_amd64.whl", hash = "sha256:11ab62d0ce5d9324915726f611f511a761efcca970bd49d876cf831b4de65be5"}, + {file = "coverage-7.3.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:33e63c578f4acce1b6cd292a66bc30164495010f1091d4b7529d014845cd9bee"}, + {file = "coverage-7.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:782693b817218169bfeb9b9ba7f4a9f242764e180ac9589b45112571f32a0ba6"}, + {file = "coverage-7.3.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c4277ddaad9293454da19121c59f2d850f16bcb27f71f89a5c4836906eb35ef"}, + {file = "coverage-7.3.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3d892a19ae24b9801771a5a989fb3e850bd1ad2e2b6e83e949c65e8f37bc67a1"}, + {file = "coverage-7.3.4-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3024ec1b3a221bd10b5d87337d0373c2bcaf7afd86d42081afe39b3e1820323b"}, + {file = "coverage-7.3.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a1c3e9d2bbd6f3f79cfecd6f20854f4dc0c6e0ec317df2b265266d0dc06535f1"}, + {file = "coverage-7.3.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e91029d7f151d8bf5ab7d8bfe2c3dbefd239759d642b211a677bc0709c9fdb96"}, + {file = "coverage-7.3.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:6879fe41c60080aa4bb59703a526c54e0412b77e649a0d06a61782ecf0853ee1"}, + {file = "coverage-7.3.4-cp312-cp312-win32.whl", hash = "sha256:fd2f8a641f8f193968afdc8fd1697e602e199931012b574194052d132a79be13"}, + {file = "coverage-7.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:d1d0ce6c6947a3a4aa5479bebceff2c807b9f3b529b637e2b33dea4468d75fc7"}, + {file = "coverage-7.3.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:36797b3625d1da885b369bdaaa3b0d9fb8865caed3c2b8230afaa6005434aa2f"}, + {file = "coverage-7.3.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bfed0ec4b419fbc807dec417c401499ea869436910e1ca524cfb4f81cf3f60e7"}, + {file = "coverage-7.3.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f97ff5a9fc2ca47f3383482858dd2cb8ddbf7514427eecf5aa5f7992d0571429"}, + {file = "coverage-7.3.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:607b6c6b35aa49defaebf4526729bd5238bc36fe3ef1a417d9839e1d96ee1e4c"}, + {file = "coverage-7.3.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a8e258dcc335055ab59fe79f1dec217d9fb0cdace103d6b5c6df6b75915e7959"}, + {file = "coverage-7.3.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a02ac7c51819702b384fea5ee033a7c202f732a2a2f1fe6c41e3d4019828c8d3"}, + {file = "coverage-7.3.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b710869a15b8caf02e31d16487a931dbe78335462a122c8603bb9bd401ff6fb2"}, + {file = "coverage-7.3.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:c6a23ae9348a7a92e7f750f9b7e828448e428e99c24616dec93a0720342f241d"}, + {file = "coverage-7.3.4-cp38-cp38-win32.whl", hash = "sha256:758ebaf74578b73f727acc4e8ab4b16ab6f22a5ffd7dd254e5946aba42a4ce76"}, + {file = "coverage-7.3.4-cp38-cp38-win_amd64.whl", hash = "sha256:309ed6a559bc942b7cc721f2976326efbfe81fc2b8f601c722bff927328507dc"}, + {file = "coverage-7.3.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:aefbb29dc56317a4fcb2f3857d5bce9b881038ed7e5aa5d3bcab25bd23f57328"}, + {file = "coverage-7.3.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:183c16173a70caf92e2dfcfe7c7a576de6fa9edc4119b8e13f91db7ca33a7923"}, + {file = "coverage-7.3.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a4184dcbe4f98d86470273e758f1d24191ca095412e4335ff27b417291f5964"}, + {file = "coverage-7.3.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93698ac0995516ccdca55342599a1463ed2e2d8942316da31686d4d614597ef9"}, + {file = "coverage-7.3.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fb220b3596358a86361139edce40d97da7458412d412e1e10c8e1970ee8c09ab"}, + {file = "coverage-7.3.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d5b14abde6f8d969e6b9dd8c7a013d9a2b52af1235fe7bebef25ad5c8f47fa18"}, + {file = "coverage-7.3.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:610afaf929dc0e09a5eef6981edb6a57a46b7eceff151947b836d869d6d567c1"}, + {file = "coverage-7.3.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d6ed790728fb71e6b8247bd28e77e99d0c276dff952389b5388169b8ca7b1c28"}, + {file = "coverage-7.3.4-cp39-cp39-win32.whl", hash = "sha256:c15fdfb141fcf6a900e68bfa35689e1256a670db32b96e7a931cab4a0e1600e5"}, + {file = "coverage-7.3.4-cp39-cp39-win_amd64.whl", hash = "sha256:38d0b307c4d99a7aca4e00cad4311b7c51b7ac38fb7dea2abe0d182dd4008e05"}, + {file = "coverage-7.3.4-pp38.pp39.pp310-none-any.whl", hash = "sha256:b1e0f25ae99cf247abfb3f0fac7ae25739e4cd96bf1afa3537827c576b4847e5"}, + {file = "coverage-7.3.4.tar.gz", hash = "sha256:020d56d2da5bc22a0e00a5b0d54597ee91ad72446fa4cf1b97c35022f6b6dbf0"}, +] + +[package.dependencies] +tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} + +[package.extras] +toml = ["tomli"] + +[[package]] +name = "exceptiongroup" +version = "1.2.0" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"}, + {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"}, +] + +[package.extras] +test = ["pytest (>=6)"] + [[package]] name = "flake8" version = "6.0.0" @@ -138,6 +219,17 @@ files = [ [package.dependencies] networkx = ">=2" +[[package]] +name = "iniconfig" +version = "2.0.0" +description = "brain-dead simple config-ini parsing" +optional = false +python-versions = ">=3.7" +files = [ + {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, + {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, +] + [[package]] name = "isort" version = "5.12.0" @@ -373,6 +465,17 @@ pydantic = "2.1.1" type = "directory" url = "../osrd_schemas" +[[package]] +name = "packaging" +version = "23.2" +description = "Core utilities for Python packages" +optional = false +python-versions = ">=3.7" +files = [ + {file = "packaging-23.2-py3-none-any.whl", hash = "sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7"}, + {file = "packaging-23.2.tar.gz", hash = "sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5"}, +] + [[package]] name = "pathspec" version = "0.11.1" @@ -399,6 +502,21 @@ files = [ docs = ["furo (>=2022.12.7)", "proselint (>=0.13)", "sphinx (>=6.1.3)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)"] test = ["appdirs (==1.4.4)", "covdefaults (>=2.2.2)", "pytest (>=7.2.1)", "pytest-cov (>=4)", "pytest-mock (>=3.10)"] +[[package]] +name = "pluggy" +version = "1.3.0" +description = "plugin and hook calling mechanisms for python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pluggy-1.3.0-py3-none-any.whl", hash = "sha256:d89c696a773f8bd377d18e5ecda92b7a3793cbe66c87060a6fb58c7b6e1061f7"}, + {file = "pluggy-1.3.0.tar.gz", hash = "sha256:cf61ae8f126ac6f7c451172cf30e3e43d3ca77615509771b3a984a0730651e12"}, +] + +[package.extras] +dev = ["pre-commit", "tox"] +testing = ["pytest", "pytest-benchmark"] + [[package]] name = "pycnite" version = "2023.10.11" @@ -607,6 +725,46 @@ files = [ flake8 = "6.0.0" tomli = {version = "*", markers = "python_version < \"3.11\""} +[[package]] +name = "pytest" +version = "7.4.3" +description = "pytest: simple powerful testing with Python" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, + {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} +iniconfig = "*" +packaging = "*" +pluggy = ">=0.12,<2.0" +tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""} + +[package.extras] +testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] + +[[package]] +name = "pytest-cov" +version = "4.1.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6"}, + {file = "pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] + [[package]] name = "pytype" version = "2023.10.24" @@ -766,4 +924,4 @@ typing-extensions = ">=3.7.4" [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "f39d4c50cb72650e500f425a415b38b14bc7ebc0dd2bd701e320317d879c45a9" +content-hash = "b212825044907464691c1b60aa66bdb706045bf1f4a8a188f695a7d123bfd1a1" diff --git a/python/railjson_generator/pyproject.toml b/python/railjson_generator/pyproject.toml index a32fbe65fb..95f874514b 100644 --- a/python/railjson_generator/pyproject.toml +++ b/python/railjson_generator/pyproject.toml @@ -7,6 +7,8 @@ authors = ["OSRD "] [tool.poetry.dependencies] python = ">=3.9,<3.12" osrd-schemas = { path = "../osrd_schemas/", develop = false } +pytest = "^7.4.3" +pytest-cov = "^4.1.0" [tool.poetry.group.dev.dependencies] black = "^22.12.0" diff --git a/python/railjson_generator/railjson_generator/external_generated_inputs.py b/python/railjson_generator/railjson_generator/external_generated_inputs.py index 40c2615da9..97d29c28be 100644 --- a/python/railjson_generator/railjson_generator/external_generated_inputs.py +++ b/python/railjson_generator/railjson_generator/external_generated_inputs.py @@ -1,9 +1,11 @@ from dataclasses import dataclass, field +from os import PathLike from typing import List from osrd_schemas import external_generated_inputs from railjson_generator.schema.infra.range_elements import TrackRange +from railjson_generator.schema.infra.track_section import TrackSection @dataclass @@ -12,10 +14,12 @@ class ElectricalProfile: power_class: str track_ranges: List[TrackRange] = field(default_factory=list) - def add_track_range(self, track, begin, end): + def add_track_range(self, track: TrackSection, begin: float, end: float): + """Build a track range and add it to the profile.""" self.track_ranges.append(TrackRange(track=track, begin=begin, end=end)) def to_rjs(self): + """Return the corresponding railjson object.""" return external_generated_inputs.ElectricalProfile( value=self.value, power_class=self.power_class, @@ -28,14 +32,17 @@ class ExternalGeneratedInputs: electrical_profiles: List[ElectricalProfile] = field(default_factory=list) def add_electrical_profile(self, *args, **kwargs) -> ElectricalProfile: + """Build an electrical profile, add it to the inputs, and return it.""" self.electrical_profiles.append(ElectricalProfile(*args, **kwargs)) return self.electrical_profiles[-1] - def save(self, path): + def save(self, path: PathLike): + """Write to the path as railjson.""" with open(path, "w") as f: f.write(self.to_rjs().model_dump_json(indent=2)) def to_rjs(self): + """Return the corresponding railjson `ElectricalProfileSet`.""" return external_generated_inputs.ElectricalProfileSet( levels=[profile.to_rjs() for profile in self.electrical_profiles], level_order={ diff --git a/python/railjson_generator/railjson_generator/infra_builder.py b/python/railjson_generator/railjson_generator/infra_builder.py index 8755bf15d9..029554de0d 100644 --- a/python/railjson_generator/railjson_generator/infra_builder.py +++ b/python/railjson_generator/railjson_generator/infra_builder.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Iterable +from typing import Iterable, Optional, Tuple from .schema.infra.endpoint import TrackEndpoint from .schema.infra.infra import Infra @@ -18,7 +18,7 @@ from .utils import generate_routes -def _check_connections(endpoint, connections): +def _check_connections(endpoint: TrackEndpoint, connections: Iterable[Tuple[TrackEndpoint, Optional[SwitchGroup]]]): switches = [] for connected_endpoint, switch_group in connections: if connected_endpoint == endpoint: @@ -39,7 +39,7 @@ def _check_connections(endpoint, connections): def _register_connection(endpoint_a: TrackEndpoint, endpoint_b: TrackEndpoint, switch_group: SwitchGroup): - """Connect two track endpoints together""" + """Connect two track endpoints together.""" a_neighbors = endpoint_a.get_neighbors() b_neighbors = endpoint_b.get_neighbors() a_neighbors.append((endpoint_b, switch_group)) @@ -56,12 +56,14 @@ class InfraBuilder: infra: Infra = field(default_factory=Infra) - def add_track_section(self, *args, **kwargs): + def add_track_section(self, *args, **kwargs) -> TrackSection: + """Build a track section, add it to the infra, and return it.""" track = TrackSection(index=len(self.infra.track_sections), *args, **kwargs) self.infra.track_sections.append(track) return track - def add_point_switch(self, base: TrackEndpoint, left: TrackEndpoint, right: TrackEndpoint, **kwargs): + def add_point_switch(self, base: TrackEndpoint, left: TrackEndpoint, right: TrackEndpoint, **kwargs) -> PointSwitch: + """Build a point switch, add it to the infra, and return it.""" switch = PointSwitch(A=base, B1=left, B2=right, **kwargs) _register_connection(base, left, switch.group("A_B1")) _register_connection(base, right, switch.group("A_B2")) @@ -70,7 +72,8 @@ def add_point_switch(self, base: TrackEndpoint, left: TrackEndpoint, right: Trac def add_crossing( self, north: TrackEndpoint, south: TrackEndpoint, east: TrackEndpoint, west: TrackEndpoint, **kwargs - ): + ) -> Crossing: + """Build a crossing, add it to the infra, and return it.""" switch = Crossing(A1=north, B1=south, B2=east, A2=west, **kwargs) _register_connection(north, south, switch.group("STATIC")) _register_connection(east, west, switch.group("STATIC")) @@ -79,7 +82,8 @@ def add_crossing( def add_double_slip_switch( self, north_1: TrackEndpoint, north_2: TrackEndpoint, south_1: TrackEndpoint, south_2: TrackEndpoint, **kwargs - ): + ) -> DoubleSlipSwitch: + """Build a double slip switch, add it to the infra, and return it.""" switch = DoubleSlipSwitch(A1=north_1, A2=north_2, B1=south_1, B2=south_2, **kwargs) for (src, dst), group_name in [ ((north_1, south_1), "A1_B1"), @@ -91,21 +95,25 @@ def add_double_slip_switch( self.infra.switches.append(switch) return switch - def add_link(self, source: TrackEndpoint, destination: TrackEndpoint, **kwargs): + def add_link(self, source: TrackEndpoint, destination: TrackEndpoint, **kwargs) -> Link: + """Build a link, add it to the infra, and return it.""" switch = Link(A=source, B=destination, **kwargs) self.infra.switches.append(switch) _register_connection(source, destination, switch.group("STATIC")) return switch - def add_operational_point(self, *args, **kwargs): + def add_operational_point(self, *args, **kwargs) -> OperationalPoint: + """Build an operational point, add it to the infra, and return it.""" self.infra.operational_points.append(OperationalPoint(*args, **kwargs)) return self.infra.operational_points[-1] - def add_speed_section(self, *args, **kwargs): + def add_speed_section(self, *args, **kwargs) -> SpeedSection: + """Build a speed section, add it to the infra, and return it.""" self.infra.speed_sections.append(SpeedSection(*args, **kwargs)) return self.infra.speed_sections[-1] - def add_neutral_section(self, *args, **kwargs): + def add_neutral_section(self, *args, **kwargs) -> NeutralSection: + """Build a neutral section, add it to the infra, and return it.""" self.infra.neutral_sections.append(NeutralSection(*args, **kwargs)) return self.infra.neutral_sections[-1] @@ -119,7 +127,7 @@ def _auto_gen_buffer_stops(self): track.add_buffer_stop(position=track.length) def register_route(self, route: Route): - """Adds a route to the infrastructure""" + """Add a route to the infrastructure.""" self.infra.routes.append(route) def _prepare_infra(self): @@ -131,10 +139,10 @@ def _prepare_infra(self): track.sort_signals() track.sort_waypoints() - def generate_routes(self, progressive_release=True) -> Iterable[Route]: + def generate_routes(self, progressive_release: bool = True) -> Iterable[Route]: """ Generate routes using signaling and detectors. - Route need to be manually registered using register_route. + Routes need to be manually registered using register_route. Buffer stops will be added where missing. Keyword arguments: @@ -143,7 +151,7 @@ def generate_routes(self, progressive_release=True) -> Iterable[Route]: self._prepare_infra() return generate_routes(self.infra, progressive_release) - def build(self, progressive_release=True): + def build(self, progressive_release: bool = True) -> Infra: """Build the RailJSON infrastructure. Routes will be generated if missing.""" self._prepare_infra() diff --git a/python/railjson_generator/railjson_generator/schema/location.py b/python/railjson_generator/railjson_generator/schema/location.py index d63f65372e..0f2e483e31 100644 --- a/python/railjson_generator/railjson_generator/schema/location.py +++ b/python/railjson_generator/railjson_generator/schema/location.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from typing import Any, Dict from railjson_generator.schema.infra.direction import Direction from railjson_generator.schema.infra.track_section import TrackSection @@ -9,7 +10,8 @@ class Location: track_section: TrackSection offset: float - def format(self): + def format(self) -> Dict[str, Any]: + """Return a summary of the location as a dictionary.""" return { "track_section": self.track_section.label, "offset": self.offset, @@ -20,12 +22,14 @@ def format(self): class DirectedLocation(Location): direction: Direction - def format(self): + def format(self) -> Dict[str, Any]: + """Return a summary of the directed location as a dictionary.""" return { **super().format(), "direction": self.direction.name, } @staticmethod - def from_location(location: Location, direction: Direction): + def from_location(location: Location, direction: Direction) -> "DirectedLocation": + """Return a directed location with the given direction.""" return DirectedLocation(location.track_section, location.offset, direction) diff --git a/python/railjson_generator/railjson_generator/schema/test_location.py b/python/railjson_generator/railjson_generator/schema/test_location.py new file mode 100644 index 0000000000..8e2116eb50 --- /dev/null +++ b/python/railjson_generator/railjson_generator/schema/test_location.py @@ -0,0 +1,30 @@ +from railjson_generator.schema.infra.direction import Direction +from railjson_generator.schema.infra.track_section import TrackSection +from railjson_generator.schema.location import DirectedLocation, Location + + +class TestLocation: + def test_format(self): + ts = TrackSection(label="ts", length=1) + location = Location(ts, offset=0) + + assert location.format() == {"track_section": "ts", "offset": 0} + + +class TestDirectedLocation: + def test_from_location(self): + ts = TrackSection(label="ts", length=1) + location = Location(ts, offset=0) + + assert DirectedLocation.from_location(location, Direction.START_TO_STOP) == DirectedLocation( + ts, offset=0, direction=Direction.START_TO_STOP + ) + assert DirectedLocation.from_location(location, Direction.STOP_TO_START) == DirectedLocation( + ts, offset=0, direction=Direction.STOP_TO_START + ) + + def test_format(self): + ts = TrackSection(label="ts", length=1) + location = DirectedLocation(ts, offset=0, direction=Direction.START_TO_STOP) + + assert location.format() == {"track_section": "ts", "offset": 0, "direction": "START_TO_STOP"} diff --git a/python/railjson_generator/railjson_generator/simulation_builder.py b/python/railjson_generator/railjson_generator/simulation_builder.py index 42e5f18c92..ac5a528b8c 100644 --- a/python/railjson_generator/railjson_generator/simulation_builder.py +++ b/python/railjson_generator/railjson_generator/simulation_builder.py @@ -15,7 +15,7 @@ class SimulationBuilder: simulation: Simulation = field(default_factory=Simulation) def add_train_schedule(self, *locations: Location, **kwargs) -> TrainSchedule: - """Creates a train schedule group containing only this train schedule.""" + """Create a train schedule group containing only this train schedule.""" train_schedule = TrainSchedule(**kwargs) self.add_train_schedule_group(locations, train_schedule) return train_schedule @@ -23,7 +23,7 @@ def add_train_schedule(self, *locations: Location, **kwargs) -> TrainSchedule: def add_train_schedule_group( self, locations: Sequence[Union[Location, DirectedLocation]], *train_schedules: TrainSchedule ) -> TrainScheduleGroup: - """Creates a train schedule group containing the given train schedules. + """Create a train schedule group containing the given train schedules. Simple locations are expanded to directed locations in all directions.""" if len(locations) < 2: @@ -39,4 +39,5 @@ def add_train_schedule_group( return train_schedule_group def build(self) -> Simulation: + """Return the simulation object.""" return self.simulation diff --git a/python/railjson_generator/railjson_generator/test_external_generated_inputs.py b/python/railjson_generator/railjson_generator/test_external_generated_inputs.py new file mode 100644 index 0000000000..ba8fc668d0 --- /dev/null +++ b/python/railjson_generator/railjson_generator/test_external_generated_inputs.py @@ -0,0 +1,69 @@ +from osrd_schemas import external_generated_inputs, infra + +from railjson_generator.external_generated_inputs import ( + ElectricalProfile, + ExternalGeneratedInputs, +) +from railjson_generator.schema.infra.range_elements import TrackRange +from railjson_generator.schema.infra.track_section import TrackSection + + +class TestElectricalProfile: + def test_add_track_range(self): + ep = ElectricalProfile(value="dummy", power_class="dummy") + assert ep.track_ranges == [] + track = TrackSection(length=1) + + ep.add_track_range(track=track, begin=0, end=1) + + assert ep.track_ranges == [TrackRange(track=track, begin=0, end=1)] + + def test_to_rjs(self): + ep = ElectricalProfile(value="value", power_class="power_class") + track = TrackSection(length=1) + ep.add_track_range(track=track, begin=0, end=1) + + assert ep.to_rjs() == external_generated_inputs.ElectricalProfile( + value="value", power_class="power_class", track_ranges=[infra.TrackRange(track=track.id, begin=0, end=1)] + ) + + +class TestExternalGeneratedInputs: + def test_add_electrical_profile(self): + egi = ExternalGeneratedInputs() + assert egi.electrical_profiles == [] + + egi.add_electrical_profile(value="value", power_class="power_class") + + assert egi.electrical_profiles == [ElectricalProfile(value="value", power_class="power_class")] + + def test_to_rjs(self): + egi = ExternalGeneratedInputs() + egi.add_electrical_profile(value="value", power_class="power_class") + track = TrackSection(length=1) + egi.electrical_profiles[0].add_track_range(track=track, begin=0, end=1) + + assert egi.to_rjs() == external_generated_inputs.ElectricalProfileSet( + levels=[ + external_generated_inputs.ElectricalProfile( + value="value", + power_class="power_class", + track_ranges=[infra.TrackRange(track=track.id, begin=0, end=1)], + ) + ], + level_order={"25000": ["25000", "22500", "20000"]}, + ) + + def test_save(self, tmp_path): + import json + + egi = ExternalGeneratedInputs() + egi.add_electrical_profile(value="value", power_class="power_class") + track = TrackSection(length=1) + egi.electrical_profiles[0].add_track_range(track=track, begin=0, end=1) + path = tmp_path / "test_external_generated_inputs_test_save.json" + + egi.save(path) + + with open(path) as f: + assert external_generated_inputs.ElectricalProfileSet(**json.load(f)) == egi.to_rjs() diff --git a/python/railjson_generator/railjson_generator/test_infra_builder.py b/python/railjson_generator/railjson_generator/test_infra_builder.py new file mode 100644 index 0000000000..ad5ff1c106 --- /dev/null +++ b/python/railjson_generator/railjson_generator/test_infra_builder.py @@ -0,0 +1,339 @@ +import pytest + +from railjson_generator.infra_builder import InfraBuilder +from railjson_generator.schema.infra.direction import Direction +from railjson_generator.schema.infra.neutral_section import NeutralSection +from railjson_generator.schema.infra.operational_point import OperationalPoint +from railjson_generator.schema.infra.route import Route +from railjson_generator.schema.infra.speed_section import SpeedSection +from railjson_generator.schema.infra.switch import ( + Crossing, + DoubleSlipSwitch, + Link, + PointSwitch, +) +from railjson_generator.schema.infra.track_section import TrackSection +from railjson_generator.schema.infra.waypoint import BufferStop, Detector + + +class TestInfraBuilder: + def test_add_track_section(self): + ib = InfraBuilder() + assert ib.infra.track_sections == [] + + track = ib.add_track_section(length=1) + + assert track == TrackSection(index=0, length=1, label=track.label) + assert ib.infra.track_sections == [track] + + def test_add_point_switch(self): + ib = InfraBuilder() + assert ib.infra.switches == [] + # x y + # =========o====== + # \===== + # z + x = ib.add_track_section(length=1) + base = x.end() + y = ib.add_track_section(length=1) + left = y.begin() + z = ib.add_track_section(length=1) + right = z.begin() + + switch = ib.add_point_switch(base, left, right) + + assert switch == PointSwitch(A=base, B1=left, B2=right, label=switch.label) + assert ib.infra.switches == [switch] + assert base.get_neighbors() == [(left, switch.group("A_B1")), (right, switch.group("A_B2"))] + assert left.get_neighbors() == [(base, switch.group("A_B1"))] + assert right.get_neighbors() == [(base, switch.group("A_B2"))] + + def test_add_crossing(self): + ib = InfraBuilder() + assert ib.infra.switches == [] + # w y + # ========\ /======= + # o + # ========/ \======= + # x z + w = ib.add_track_section(length=1) + north = w.end() + x = ib.add_track_section(length=1) + west = x.end() + y = ib.add_track_section(length=1) + east = y.begin() + z = ib.add_track_section(length=1) + south = z.begin() + + crossing = ib.add_crossing(north, south, east, west) + + assert crossing == Crossing(A1=north, B1=south, B2=east, A2=west, label=crossing.label) + assert ib.infra.switches == [crossing] + assert north.get_neighbors() == [(south, crossing.group("STATIC"))] + assert south.get_neighbors() == [(north, crossing.group("STATIC"))] + assert east.get_neighbors() == [(west, crossing.group("STATIC"))] + assert west.get_neighbors() == [(east, crossing.group("STATIC"))] + + def test_add_double_slip_switch(self): + ib = InfraBuilder() + assert ib.infra.switches == [] + # Here, we can one can go ahead OR "turn". + # w y + # ================== + # o + # ================== + # x z + w = ib.add_track_section(length=1) + north_1 = w.end() + x = ib.add_track_section(length=1) + south_1 = x.end() + y = ib.add_track_section(length=1) + north_2 = y.begin() + z = ib.add_track_section(length=1) + south_2 = z.begin() + + switch = ib.add_double_slip_switch(north_1, north_2, south_1, south_2) + + assert switch == DoubleSlipSwitch(A1=north_1, A2=north_2, B1=south_1, B2=south_2, label=switch.label) + assert ib.infra.switches == [switch] + assert north_1.get_neighbors() == [(south_1, switch.group("A1_B1")), (south_2, switch.group("A1_B2"))] + assert north_2.get_neighbors() == [(south_1, switch.group("A2_B1")), (south_2, switch.group("A2_B2"))] + assert south_1.get_neighbors() == [(north_1, switch.group("A1_B1")), (north_2, switch.group("A2_B1"))] + assert south_2.get_neighbors() == [(north_1, switch.group("A1_B2")), (north_2, switch.group("A2_B2"))] + + def test_add_link(self): + ib = InfraBuilder() + assert ib.infra.switches == [] + # x y + # ========o======== + x = ib.add_track_section(length=1) + source = x.end() + y = ib.add_track_section(length=1) + destination = y.begin() + + link = ib.add_link(source, destination) + + assert link == Link(A=source, B=destination, label=link.label) + assert ib.infra.switches == [link] + assert source.get_neighbors() == [(destination, link.group("STATIC"))] + assert destination.get_neighbors() == [(source, link.group("STATIC"))] + + def test_add_operational_point(self): + ib = InfraBuilder() + assert ib.infra.operational_points == [] + + op = ib.add_operational_point(label="label") + + assert op == OperationalPoint(label="label") + assert ib.infra.operational_points == [op] + + def test_add_speed_section(self): + ib = InfraBuilder() + assert ib.infra.speed_sections == [] + + ss = ib.add_speed_section(speed_limit=1) + + assert ss == SpeedSection(label=ss.label, speed_limit=1) + assert ib.infra.speed_sections == [ss] + + def test_add_neutral_section(self): + ib = InfraBuilder() + assert ib.infra.neutral_sections == [] + + ns = ib.add_neutral_section() + + assert ns == NeutralSection(label=ns.label) + assert ib.infra.neutral_sections == [ns] + + def test_register_route(self): + ib = InfraBuilder() + assert ib.infra.routes == [] + start = BufferStop(position=0) + stop = BufferStop(position=1) + route = Route(entry_point_direction=Direction.START_TO_STOP, waypoints=[start, stop], release_waypoints=[]) + + ib.register_route(route) + + # Be careful, Route.__eq__ only compares labels. + assert ib.infra.routes == [route] + + def test_generate_routes(self): + ib = InfraBuilder() + assert ib.infra.routes == [] + # Possible routes are x->y, x->z, and mirrors. + # x y + # ======d==o====== + # \===== + # z + detector = Detector(position=0.5) + x = ib.add_track_section(length=1, waypoints=[detector]) + base = x.end() + y = ib.add_track_section(length=1) + left = y.begin() + z = ib.add_track_section(length=1) + right = z.begin() + switch = ib.add_point_switch(base, left, right) + + routes = ib.generate_routes(progressive_release=True) + + assert x.waypoints[0] == BufferStop(position=0, label=x.waypoints[0].label) + assert y.waypoints[-1] == BufferStop(position=1, label=y.waypoints[-1].label) + assert z.waypoints[-1] == BufferStop(position=1, label=z.waypoints[-1].label) + xy = Route( + entry_point_direction=Direction.START_TO_STOP, + waypoints=[x.waypoints[0], y.waypoints[-1]], + release_waypoints=[detector], + switches_directions={switch.label: "A_B1"}, + label=f"rt.{x.waypoints[0].label}->{y.waypoints[-1].label}", + ).to_rjs() + yx = Route( + entry_point_direction=Direction.STOP_TO_START, + waypoints=[y.waypoints[-1], x.waypoints[0]], + release_waypoints=[detector], + switches_directions={switch.label: "A_B1"}, + label=f"rt.{y.waypoints[-1].label}->{x.waypoints[0].label}", + ).to_rjs() + xz = Route( + entry_point_direction=Direction.START_TO_STOP, + waypoints=[x.waypoints[0], z.waypoints[-1]], + release_waypoints=[detector], + switches_directions={switch.label: "A_B2"}, + label=f"rt.{x.waypoints[0].label}->{z.waypoints[-1].label}", + ).to_rjs() + zx = Route( + entry_point_direction=Direction.STOP_TO_START, + waypoints=[z.waypoints[-1], x.waypoints[0]], + release_waypoints=[detector], + switches_directions={switch.label: "A_B2"}, + label=f"rt.{z.waypoints[-1].label}->{x.waypoints[0].label}", + ).to_rjs() + # Route.__eq__ only compares labels, so let's compare resulting rjs instead. + routes = [route.to_rjs() for route in routes] + assert len(routes) == 4 + assert xy in routes + assert yx in routes + assert xz in routes + assert zx in routes + + def test_generate_routes_without_release(self): + ib = InfraBuilder() + assert ib.infra.routes == [] + # Possible routes are x->y, x->z, and mirrors. + # x y + # ======d==o====== + # \===== + # z + detector = Detector(position=0.5) + x = ib.add_track_section(length=1, waypoints=[detector]) + base = x.end() + y = ib.add_track_section(length=1) + left = y.begin() + z = ib.add_track_section(length=1) + right = z.begin() + switch = ib.add_point_switch(base, left, right) + + routes = ib.generate_routes(progressive_release=False) + + assert x.waypoints[0] == BufferStop(position=0, label=x.waypoints[0].label) + assert y.waypoints[-1] == BufferStop(position=1, label=y.waypoints[-1].label) + assert z.waypoints[-1] == BufferStop(position=1, label=z.waypoints[-1].label) + xy = Route( + entry_point_direction=Direction.START_TO_STOP, + waypoints=[x.waypoints[0], y.waypoints[-1]], + release_waypoints=[], + switches_directions={switch.label: "A_B1"}, + label=f"rt.{x.waypoints[0].label}->{y.waypoints[-1].label}", + ).to_rjs() + yx = Route( + entry_point_direction=Direction.STOP_TO_START, + waypoints=[y.waypoints[-1], x.waypoints[0]], + release_waypoints=[], + switches_directions={switch.label: "A_B1"}, + label=f"rt.{y.waypoints[-1].label}->{x.waypoints[0].label}", + ).to_rjs() + xz = Route( + entry_point_direction=Direction.START_TO_STOP, + waypoints=[x.waypoints[0], z.waypoints[-1]], + release_waypoints=[], + switches_directions={switch.label: "A_B2"}, + label=f"rt.{x.waypoints[0].label}->{z.waypoints[-1].label}", + ).to_rjs() + zx = Route( + entry_point_direction=Direction.STOP_TO_START, + waypoints=[z.waypoints[-1], x.waypoints[0]], + release_waypoints=[], + switches_directions={switch.label: "A_B2"}, + label=f"rt.{z.waypoints[-1].label}->{x.waypoints[0].label}", + ).to_rjs() + # Route.__eq__ only compares labels, so let's compare resulting rjs instead. + routes = [route.to_rjs() for route in routes] + assert len(routes) == 4 + assert xy in routes + assert yx in routes + assert xz in routes + assert zx in routes + + def test_build_with_duplicates(self): + ib = InfraBuilder() + track = ib.add_track_section(length=1) + ib.add_track_section(length=1, label=track.label) + + with pytest.raises(ValueError, match="Duplicates found"): + ib.build() + + def test_build(self): + ib = InfraBuilder() + # Possible routes are x->y, x->z, and mirrors. + # x y + # ======d==o====== + # \===== + # z + detector = Detector(position=0.5) + x = ib.add_track_section(length=1, waypoints=[detector]) + base = x.end() + y = ib.add_track_section(length=1) + left = y.begin() + z = ib.add_track_section(length=1) + right = z.begin() + switch = ib.add_point_switch(base, left, right) + + infra = ib.build() + + assert x.waypoints[0] == BufferStop(position=0, label=x.waypoints[0].label) + assert y.waypoints[-1] == BufferStop(position=1, label=y.waypoints[-1].label) + assert z.waypoints[-1] == BufferStop(position=1, label=z.waypoints[-1].label) + xy = Route( + entry_point_direction=Direction.START_TO_STOP, + waypoints=[x.waypoints[0], y.waypoints[-1]], + release_waypoints=[detector], + switches_directions={switch.label: "A_B1"}, + label=f"rt.{x.waypoints[0].label}->{y.waypoints[-1].label}", + ).to_rjs() + yx = Route( + entry_point_direction=Direction.STOP_TO_START, + waypoints=[y.waypoints[-1], x.waypoints[0]], + release_waypoints=[detector], + switches_directions={switch.label: "A_B1"}, + label=f"rt.{y.waypoints[-1].label}->{x.waypoints[0].label}", + ).to_rjs() + xz = Route( + entry_point_direction=Direction.START_TO_STOP, + waypoints=[x.waypoints[0], z.waypoints[-1]], + release_waypoints=[detector], + switches_directions={switch.label: "A_B2"}, + label=f"rt.{x.waypoints[0].label}->{z.waypoints[-1].label}", + ).to_rjs() + zx = Route( + entry_point_direction=Direction.STOP_TO_START, + waypoints=[z.waypoints[-1], x.waypoints[0]], + release_waypoints=[detector], + switches_directions={switch.label: "A_B2"}, + label=f"rt.{z.waypoints[-1].label}->{x.waypoints[0].label}", + ).to_rjs() + # Route.__eq__ only compares labels, so let's compare resulting rjs instead. + routes = [route.to_rjs() for route in infra.routes] + assert len(routes) == 4 + assert xy in routes + assert yx in routes + assert xz in routes + assert zx in routes diff --git a/python/railjson_generator/railjson_generator/test_simulation_builder.py b/python/railjson_generator/railjson_generator/test_simulation_builder.py new file mode 100644 index 0000000000..2b0cc6bb3d --- /dev/null +++ b/python/railjson_generator/railjson_generator/test_simulation_builder.py @@ -0,0 +1,91 @@ +import pytest + +from railjson_generator.schema.infra.direction import Direction +from railjson_generator.schema.infra.track_section import TrackSection +from railjson_generator.schema.location import DirectedLocation, Location +from railjson_generator.schema.simulation.train_schedule import ( + TrainSchedule, + TrainScheduleGroup, +) +from railjson_generator.simulation_builder import SimulationBuilder + + +class TestSimulationBuilder: + def test_add_train_schedule(self): + sb = SimulationBuilder() + assert sb.simulation.train_schedule_groups == [] + ts = TrackSection(length=1) + location1 = Location(ts, offset=0) + location2 = Location(ts, offset=1) + + train_schedule = sb.add_train_schedule(location1, location2) + + assert train_schedule == TrainSchedule(label=train_schedule.label) + assert sb.simulation.train_schedule_groups == [ + TrainScheduleGroup( + schedules=[train_schedule], + waypoints=[ + [ + DirectedLocation.from_location(location1, Direction.START_TO_STOP), + DirectedLocation.from_location(location1, Direction.STOP_TO_START), + ], + [ + DirectedLocation.from_location(location2, Direction.START_TO_STOP), + DirectedLocation.from_location(location2, Direction.STOP_TO_START), + ], + ], + id=sb.simulation.train_schedule_groups[0].id, + ) + ] + + def test_add_train_schedule_group_missing_location(self): + sb = SimulationBuilder() + locations = [Location(TrackSection(length=1), offset=0)] + ts = TrainSchedule() + + with pytest.raises(ValueError, match="Expected at least 2 locations, got 1"): + sb.add_train_schedule_group(locations, ts) + + def test_add_train_schedule_group(self): + sb = SimulationBuilder() + assert sb.simulation.train_schedule_groups == [] + ts = TrackSection(length=1) + location1 = DirectedLocation(ts, offset=0, direction=Direction.START_TO_STOP) + location2 = DirectedLocation(ts, offset=1, direction=Direction.STOP_TO_START) + ts = TrainSchedule() + + train_schedule_group = sb.add_train_schedule_group([location1, location2], ts) + + tsg = TrainScheduleGroup( + schedules=[ts], + waypoints=[[location1], [location2]], + id=train_schedule_group.id, + ) + assert train_schedule_group == tsg + assert sb.simulation.train_schedule_groups == [tsg] + + def test_build(self): + sb = SimulationBuilder() + ts = TrackSection(length=1) + location1 = Location(ts, offset=0) + location2 = Location(ts, offset=1) + train_schedule = sb.add_train_schedule(location1, location2) + + simulation = sb.build() + + assert simulation.train_schedule_groups == [ + TrainScheduleGroup( + schedules=[train_schedule], + waypoints=[ + [ + DirectedLocation.from_location(location1, Direction.START_TO_STOP), + DirectedLocation.from_location(location1, Direction.STOP_TO_START), + ], + [ + DirectedLocation.from_location(location2, Direction.START_TO_STOP), + DirectedLocation.from_location(location2, Direction.STOP_TO_START), + ], + ], + id=sb.simulation.train_schedule_groups[0].id, + ) + ] diff --git a/python/railjson_generator/railjson_generator/utils/routes_generator.py b/python/railjson_generator/railjson_generator/utils/routes_generator.py index d74fca347d..192333b7f6 100644 --- a/python/railjson_generator/railjson_generator/utils/routes_generator.py +++ b/python/railjson_generator/railjson_generator/utils/routes_generator.py @@ -15,8 +15,8 @@ from railjson_generator.schema.infra.waypoint import BufferStop, Waypoint -def follow_track_link(connections) -> Optional[TrackEndpoint]: - """Follow a track link. If there is no track link on this endpoint, returns None""" +def follow_track_link(connections: List[Tuple[TrackEndpoint, Optional[SwitchGroup]]]) -> Optional[TrackEndpoint]: + """Follow a track link. If there is no track link on this endpoint, return None.""" if not connections: return None (endpoint, switch_group) = connections[0] @@ -28,7 +28,7 @@ def follow_track_link(connections) -> Optional[TrackEndpoint]: def _explore_signals( track: TrackSection, det_i: Optional[int], signal_direction: Direction ) -> Iterable[Tuple[TrackSection, Signal]]: - """Find signals which are associated with a given detector""" + """Find signals which are associated with a given detector.""" signal_iterator = reversed(track.signals) if signal_direction == Direction.START_TO_STOP else iter(track.signals) pos_filter: Callable[[float], bool] @@ -39,19 +39,19 @@ def _explore_signals( if not track.waypoints: continue_exploring = True - def pos_filter(pos): + def pos_filter(pos: float) -> bool: return True else: continue_exploring = False if signal_direction == Direction.START_TO_STOP: - def pos_filter(pos): + def pos_filter(pos: float) -> bool: return pos > track.waypoints[-1].position else: - def pos_filter(pos): + def pos_filter(pos: float) -> bool: return pos < track.waypoints[0].position else: @@ -65,7 +65,7 @@ def pos_filter(pos): continue_exploring = False prev_waypoint_pos = track.waypoints[det_i - 1].position - def pos_filter(pos): + def pos_filter(pos: float) -> bool: return prev_waypoint_pos < pos <= waypoint_pos else: @@ -76,7 +76,7 @@ def pos_filter(pos): continue_exploring = False prev_waypoint_pos = track.waypoints[det_i + 1].position - def pos_filter(pos): + def pos_filter(pos: float) -> bool: return waypoint_pos <= pos < prev_waypoint_pos # explore the signals in range on the track @@ -107,7 +107,7 @@ class DetectorProps: decr_signals: List[Signal] -def find_detector_properties(infra: Infra): +def find_detector_properties(infra: Infra) -> Dict[str, DetectorProps]: det_props: Dict[str, DetectorProps] = {} for track in infra.track_sections: for det_i, det in enumerate(track.waypoints): @@ -128,11 +128,11 @@ class ZonePath: switches_directions: Dict[str, str] = field(default_factory=dict) @property - def entry(self): + def entry(self) -> Tuple[str, Direction]: return (self.entry_det.label, self.entry_dir) @property - def exit(self): + def exit(self) -> Tuple[str, Direction]: return (self.exit_det.label, self.exit_dir) @@ -155,7 +155,7 @@ def build(self, entry_det: Waypoint, entry_dir: Direction, exit_det: Waypoint, e def search_zone_paths(infra: Infra) -> List[ZonePath]: - """Enumerate all possible paths inside zones""" + """Enumerate all possible paths inside zones.""" res = [] for track in infra.track_sections: @@ -164,9 +164,7 @@ def search_zone_paths(infra: Infra) -> List[ZonePath]: waypoints = track.waypoints # create paths between inner waypoints - for i in range(len(waypoints) - 1): - cur_waypoint = waypoints[i] - next_waypoint = waypoints[i + 1] + for cur_waypoint, next_waypoint in zip(waypoints, waypoints[1:]): res.append( ZonePath( entry_det=cur_waypoint, @@ -222,7 +220,7 @@ class IncompleteRoute: switches_directions: Dict[str, str] @staticmethod - def from_zonepath(zone_path: ZonePath): + def from_zonepath(zone_path: ZonePath) -> "IncompleteRoute": return IncompleteRoute(path=[zone_path], switches_directions={**zone_path.switches_directions}) def fork(self, new_zone_path: ZonePath) -> Optional["IncompleteRoute"]: @@ -232,13 +230,13 @@ def fork(self, new_zone_path: ZonePath) -> Optional["IncompleteRoute"]: new_switches_directions = {**self.switches_directions, **new_zone_path.switches_directions} return IncompleteRoute(path=new_path, switches_directions=new_switches_directions) - def dir_waypoints(self): + def dir_waypoints(self) -> List[Tuple[Waypoint, Direction]]: return [ (self.path[0].entry_det, self.path[0].entry_dir), *((zone_path.exit_det, zone_path.exit_dir) for zone_path in self.path), ] - def waypoints(self): + def waypoints(self) -> List[Waypoint]: return [waypoint for waypoint, _ in self.dir_waypoints()] @@ -247,7 +245,7 @@ def generate_route_paths(det_props: Dict[str, DetectorProps], zone_paths: List[Z for zone_path in zone_paths: graph[zone_path.entry].append(zone_path) - def is_route_delim(det, direction): + def is_route_delim(det: Waypoint, direction: Direction) -> bool: if isinstance(det, BufferStop): return True props = det_props[det.label] @@ -280,7 +278,7 @@ def is_route_delim(det, direction): incomplete_routes.append(new_route) -def generate_routes(infra: Infra, progressive_release=True) -> Iterable[Route]: +def generate_routes(infra: Infra, progressive_release: bool = True) -> Iterable[Route]: det_props = find_detector_properties(infra) zone_paths = search_zone_paths(infra) diff --git a/python/railjson_generator/railjson_generator/utils/test_routes_generator.py b/python/railjson_generator/railjson_generator/utils/test_routes_generator.py new file mode 100644 index 0000000000..414d0f3eda --- /dev/null +++ b/python/railjson_generator/railjson_generator/utils/test_routes_generator.py @@ -0,0 +1,428 @@ +from railjson_generator.infra_builder import InfraBuilder +from railjson_generator.schema.infra.direction import Direction +from railjson_generator.schema.infra.endpoint import Endpoint, TrackEndpoint +from railjson_generator.schema.infra.infra import Infra +from railjson_generator.schema.infra.route import Route +from railjson_generator.schema.infra.signal import Signal +from railjson_generator.schema.infra.switch import Switch, SwitchGroup +from railjson_generator.schema.infra.track_section import TrackSection +from railjson_generator.schema.infra.waypoint import BufferStop, Detector +from railjson_generator.utils.routes_generator import ( + DetectorProps, + IncompleteRoute, + ZonePath, + ZonePathStep, + find_detector_properties, + follow_track_link, + generate_route_paths, + generate_routes, + search_zone_paths, +) + + +def test_follow_track_link_empty(): + assert follow_track_link([]) is None + + +def test_follow_track_link_switch(): + ts = TrackSection(length=1) + endpoint = Endpoint.BEGIN + te = TrackEndpoint(ts, endpoint) + switch = Switch() + switch_group = SwitchGroup(switch, "group") + + next_te = follow_track_link([(te, switch_group)]) + + assert next_te is None + + +def test_follow_track_link(): + ts = TrackSection(length=1) + endpoint = Endpoint.BEGIN + te = TrackEndpoint(ts, endpoint) + + next_te = follow_track_link([(te, None)]) + + assert next_te == te + + +def test_find_detector_properties_empty(): + infra = Infra() + + assert find_detector_properties(infra) == {} + + +def test_find_detector_properties(): + # bs0======det======bs1 + # s025> y, x->z, and mirrors. + # x y + # =========o====== + # \===== + # z + x = ib.add_track_section(length=1) + base = x.end() + y = ib.add_track_section(length=1) + left = y.begin() + z = ib.add_track_section(length=1) + right = z.begin() + switch = ib.add_point_switch(base, left, right) + infra = ib.build() + zps = search_zone_paths(infra) + dps = find_detector_properties(infra) + + routes = list(generate_route_paths(dps, zps)) + + assert len(routes) == 4 + xy = IncompleteRoute( + path=[ + ZonePath( + x.waypoints[0], + Direction.START_TO_STOP, + y.waypoints[-1], + Direction.START_TO_STOP, + {switch.label: "A_B1"}, + ) + ], + switches_directions={switch.label: "A_B1"}, + ) + assert xy in routes + yx = IncompleteRoute( + path=[ + ZonePath( + y.waypoints[-1], + Direction.STOP_TO_START, + x.waypoints[0], + Direction.STOP_TO_START, + {switch.label: "A_B1"}, + ) + ], + switches_directions={switch.label: "A_B1"}, + ) + assert yx in routes + xz = IncompleteRoute( + path=[ + ZonePath( + x.waypoints[0], + Direction.START_TO_STOP, + z.waypoints[-1], + Direction.START_TO_STOP, + {switch.label: "A_B2"}, + ) + ], + switches_directions={switch.label: "A_B2"}, + ) + assert xz in routes + zx = IncompleteRoute( + path=[ + ZonePath( + z.waypoints[-1], + Direction.STOP_TO_START, + x.waypoints[0], + Direction.STOP_TO_START, + {switch.label: "A_B2"}, + ) + ], + switches_directions={switch.label: "A_B2"}, + ) + assert zx in routes + + +def test_generate_routes_without_progressive_release(): + ib = InfraBuilder() + # Possible routes are x->y, x->z, and mirrors. + # x y + # =========o====== + # \===== + # z + x = ib.add_track_section(length=1) + base = x.end() + y = ib.add_track_section(length=1) + left = y.begin() + z = ib.add_track_section(length=1) + right = z.begin() + switch = ib.add_point_switch(base, left, right) + infra = ib.build() + + routes = list(generate_routes(infra, progressive_release=False)) + + assert len(routes) == 4 + # Careful, Route.__eq__ only compares labels. + routes_rjs = [r.to_rjs() for r in routes] + xy = Route( + waypoints=[x.waypoints[0], y.waypoints[-1]], + release_waypoints=[], + entry_point_direction=Direction.START_TO_STOP, + switches_directions={switch.label: "A_B1"}, + ).to_rjs() + assert xy in routes_rjs + yx = Route( + waypoints=[y.waypoints[-1], x.waypoints[0]], + release_waypoints=[], + entry_point_direction=Direction.STOP_TO_START, + switches_directions={switch.label: "A_B1"}, + ).to_rjs() + assert yx in routes_rjs + xz = Route( + waypoints=[x.waypoints[0], z.waypoints[-1]], + release_waypoints=[], + entry_point_direction=Direction.START_TO_STOP, + switches_directions={switch.label: "A_B2"}, + ).to_rjs() + assert xz in routes_rjs + zx = Route( + waypoints=[z.waypoints[-1], x.waypoints[0]], + release_waypoints=[], + entry_point_direction=Direction.STOP_TO_START, + switches_directions={switch.label: "A_B2"}, + ).to_rjs() + assert zx in routes_rjs + + +def test_generate_routes(): + ib = InfraBuilder() + # Possible routes are x->z, and z->x. + # x y z + # ======d1=o=====d2=o======= + x = ib.add_track_section(length=1) + x.add_detector(position=0.75) + y = ib.add_track_section(length=1) + y.add_detector(position=0.75) + link_xy = ib.add_link(x.end(), y.begin()) + z = ib.add_track_section(length=1) + link_yz = ib.add_link(y.end(), z.begin()) + ib.build() + + routes = list(generate_routes(ib.infra, progressive_release=True)) + + assert len(routes) == 2 + # Careful, Route.__eq__ only compares labels. + routes_rjs = [r.to_rjs() for r in routes] + xz = Route( + waypoints=[x.waypoints[0], z.waypoints[-1]], + release_waypoints=[x.waypoints[-1], y.waypoints[-1]], + entry_point_direction=Direction.START_TO_STOP, + switches_directions={link_xy.label: "STATIC", link_yz.label: "STATIC"}, + ).to_rjs() + assert xz in routes_rjs + zx = Route( + waypoints=[z.waypoints[-1], x.waypoints[0]], + release_waypoints=[y.waypoints[-1], x.waypoints[-1]], + entry_point_direction=Direction.STOP_TO_START, + switches_directions={link_xy.label: "STATIC", link_yz.label: "STATIC"}, + ).to_rjs() + assert zx in routes_rjs