diff --git a/pyproject.toml b/pyproject.toml index 5145c2bfe..31c733ec3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ GitHub = "https://github.com/DiamondLightSource/python-murfey" "murfey.db_sql" = "murfey.cli.murfey_db_sql:run" "murfey.decrypt_password" = "murfey.cli.decrypt_db_password:run" "murfey.generate_key" = "murfey.cli.generate_crypto_key:run" +"murfey.generate_openapi_schema" = "murfey.cli.generate_openapi_schema:run" "murfey.generate_password" = "murfey.cli.generate_db_password:run" "murfey.generate_route_manifest" = "murfey.cli.generate_route_manifest:run" "murfey.instrument_server" = "murfey.instrument_server:run" diff --git a/src/murfey/cli/__init__.py b/src/murfey/cli/__init__.py index e69de29bb..a6c338cbd 100644 --- a/src/murfey/cli/__init__.py +++ b/src/murfey/cli/__init__.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +import argparse +import re +import textwrap + +import yaml + + +class LineWrapHelpFormatter(argparse.RawDescriptionHelpFormatter): + """ + A helper class for formatting the help messages the CLIs nicely. This implementation + will preserve indents at the start of a line and interpret newline metacharacters + accordingly. + + Credits: https://stackoverflow.com/a/35925919 + """ + + def _add_whitespace(self, idx, wspace_idx, text): + if idx == 0: + return text + return (" " * wspace_idx) + text + + def _split_lines(self, text, width): + text_rows = text.splitlines() + for idx, line in enumerate(text_rows): + search = re.search(r"\s*[0-9\-]{0,}\.?\s*", line) + if line.strip() == "": + text_rows[idx] = " " + elif search: + wspace_line = search.end() + lines = [ + self._add_whitespace(i, wspace_line, x) + for i, x in enumerate(textwrap.wrap(line, width)) + ] + text_rows[idx] = lines + return [item for sublist in text_rows for item in sublist] + + +class PrettierDumper(yaml.Dumper): + """ + Custom YAML Dumper class that sets `indentless` to False. This generates a YAML + file that is then compliant with Prettier's formatting style + """ + + def increase_indent(self, flow=False, indentless=False): + # Force 'indentless=False' so list items align with Prettier + return super(PrettierDumper, self).increase_indent(flow, indentless=False) + + +def prettier_str_representer(dumper, data): + """ + Helper function to format strings according to Prettier's standards: + - No quoting unless it can be misinterpreted as another data type + - When quoting, use double quotes unless string already contains double quotes + """ + + def is_implicitly_resolved(value: str) -> bool: + for ( + first_char, + resolvers, + ) in yaml.resolver.Resolver.yaml_implicit_resolvers.items(): + if first_char is None or (value and value[0] in first_char): + for resolver in resolvers: + if len(resolver) == 3: + _, regexp, _ = resolver + else: + _, regexp = resolver + if regexp.match(value): + return True + return False + + # If no quoting is needed, use default plain style + if not is_implicitly_resolved(data): + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + # If the string already contains double quotes, fall back to single quotes + if '"' in data and "'" not in data: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="'") + + # Otherwise, prefer double quotes + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style='"') + + +# Add the custom string representer to PrettierDumper +PrettierDumper.add_representer(str, prettier_str_representer) diff --git a/src/murfey/cli/generate_openapi_schema.py b/src/murfey/cli/generate_openapi_schema.py new file mode 100644 index 000000000..ed12fcdf5 --- /dev/null +++ b/src/murfey/cli/generate_openapi_schema.py @@ -0,0 +1,123 @@ +import contextlib +import io +import json +from argparse import ArgumentParser +from pathlib import Path + +import yaml +from fastapi.openapi.utils import get_openapi + +import murfey +from murfey.cli import LineWrapHelpFormatter, PrettierDumper + + +def run(): + # Set up argument parser + parser = ArgumentParser( + description=( + "Generates an OpenAPI schema of the chosen FastAPI server " + "and outputs it as either a JSON or YAML file" + ), + formatter_class=LineWrapHelpFormatter, + ) + parser.add_argument( + "--target", + "-t", + default="server", + help=( + "The target FastAPI server to construct the OpenAPI schema for. \n" + "OPTIONS: instrument-server | server \n" + "DEFAULT: server" + ), + ) + parser.add_argument( + "--output", + "-o", + default="yaml", + help=( + "Set the output format of the OpenAPI schema. \n" + "OPTIONS: json | yaml \n" + "DEFAULT: yaml" + ), + ) + parser.add_argument( + "--to-file", + "-f", + default="", + help=( + "Alternative file path and file name to save the schema as. " + "Can be a relative or absolute path. \n" + "By default, the schema will be saved to 'murfey/utils/', " + "and it will have the name 'openapi.json' or 'openapi.yaml'." + ), + ) + parser.add_argument( + "--debug", + action="store_true", + default=False, + help="Shows additional steps when setting ", + ) + args = parser.parse_args() + + # Load the relevant FastAPI app + target = str(args.target).lower() + + # Silence output during import; only return genuine errors + buffer = io.StringIO() + with contextlib.redirect_stdout(buffer), contextlib.redirect_stderr(buffer): + if target == "server": + from murfey.server.main import app + elif target == "instrument-server": + from murfey.instrument_server.main import app + else: + raise ValueError( + "Unexpected value for target server. It must be one of " + "'instrument-server' or 'server'" + ) + if args.debug: + print(f"Imported FastAPI app for {target}") + + if not app.openapi_schema: + schema = get_openapi( + title=app.title, + version=app.version, + openapi_version=app.openapi_version, + description=app.description, + routes=app.routes, + ) + if args.debug: + print(f"Constructed OpenAPI schema for {target}") + else: + schema = app.openapi_schema + if args.debug: + print(f"Loaded OpenAPI schema for {target}") + + output = str(args.output).lower() + if output not in ("json", "yaml"): + raise ValueError( + "Invalid file format selected. Output must be either 'json' or 'yaml'" + ) + murfey_dir = Path(murfey.__path__[0]) + save_path = ( + murfey_dir / "util" / f"openapi-{target}.{output}" + if not args.to_file + else Path(args.to_file) + ) + with open(save_path, "w") as f: + if output == "json": + json.dump(schema, f, indent=2) + else: + yaml.dump( + schema, + f, + Dumper=PrettierDumper, + default_flow_style=False, + sort_keys=False, + indent=2, + ) + print(f"OpenAPI schema saved to {save_path}") + exit() + + +if __name__ == "__main__": + run() diff --git a/src/murfey/cli/generate_route_manifest.py b/src/murfey/cli/generate_route_manifest.py index d0ee1e464..e88cba7a1 100644 --- a/src/murfey/cli/generate_route_manifest.py +++ b/src/murfey/cli/generate_route_manifest.py @@ -17,54 +17,7 @@ from fastapi import APIRouter import murfey - - -class PrettierDumper(yaml.Dumper): - """ - Custom YAML Dumper class that sets `indentless` to False. This generates a YAML - file that is then compliant with Prettier's formatting style - """ - - def increase_indent(self, flow=False, indentless=False): - # Force 'indentless=False' so list items align with Prettier - return super(PrettierDumper, self).increase_indent(flow, indentless=False) - - -def prettier_str_representer(dumper, data): - """ - Helper function to format strings according to Prettier's standards: - - No quoting unless it can be misinterpreted as another data type - - When quoting, use double quotes unless string already contains double quotes - """ - - def is_implicitly_resolved(value: str) -> bool: - for ( - first_char, - resolvers, - ) in yaml.resolver.Resolver.yaml_implicit_resolvers.items(): - if first_char is None or (value and value[0] in first_char): - for resolver in resolvers: - if len(resolver) == 3: - _, regexp, _ = resolver - else: - _, regexp = resolver - if regexp.match(value): - return True - return False - - # If no quoting is needed, use default plain style - if not is_implicitly_resolved(data): - return dumper.represent_scalar("tag:yaml.org,2002:str", data) - - # If the string already contains double quotes, fall back to single quotes - if '"' in data and "'" not in data: - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="'") - - # Otherwise, prefer double quotes - return dumper.represent_scalar("tag:yaml.org,2002:str", data, style='"') - - -PrettierDumper.add_representer(str, prettier_str_representer) +from murfey.cli import PrettierDumper def find_routers(name: str) -> dict[str, APIRouter]: diff --git a/tests/cli/test_generate_openapi_schema.py b/tests/cli/test_generate_openapi_schema.py new file mode 100644 index 000000000..8eab13b2c --- /dev/null +++ b/tests/cli/test_generate_openapi_schema.py @@ -0,0 +1,74 @@ +import sys +from pathlib import Path + +import pytest +from pytest_mock import MockerFixture + +import murfey +from murfey.cli.generate_openapi_schema import run + +params_matrix: tuple[tuple[str | None, str | None, bool], ...] = ( + # Target | Output | To File + (None, None, False), + ("instrument-server", "json", True), + ("server", "yaml", False), + ("instrument-server", "yaml", False), + ("server", "json", True), +) + + +@pytest.mark.parametrize("test_params", params_matrix) +def test_run( + mocker: MockerFixture, + tmp_path: Path, + test_params: tuple[str | None, str | None, bool], +): + # Unpack test params + target, output, to_file = test_params + + # Mock out print() and exit() + mock_print = mocker.patch("builtins.print") + mock_exit = mocker.patch("builtins.exit") + + # Construct the CLI args + sys_args = [""] + if target is not None: + sys_args.extend(["-t", target]) + if output is not None: + sys_args.extend(["-o", output]) + + target = target if target is not None else "server" + output = output if output is not None else "yaml" + if to_file: + save_path = tmp_path / f"openapi.{output}" + sys_args.extend(["-f", str(save_path)]) + else: + save_path = Path(murfey.__path__[0]) / "util" / f"openapi-{target}.{output}" + sys_args.extend(["--debug"]) + sys.argv = sys_args + + # Run the function and check that it runs as expected + run() + print_calls = mock_print.call_args_list + last_print_call = print_calls[-1] + last_printed = last_print_call.args[0] + assert last_printed.startswith("OpenAPI schema saved to") + mock_exit.assert_called_once() + assert save_path.exists() + + +failure_params_matrix = ( + ["-t", "blah"], + ["-o", "blah"], +) + + +@pytest.mark.parametrize("test_params", failure_params_matrix) +def test_run_fails(test_params: list[str]): + # Construct the CLI args + sys_args = [""] + sys_args.extend(test_params) + sys.argv = sys_args + + with pytest.raises(ValueError): + run()