diff --git a/scripts/commands/config.py b/scripts/commands/config.py index 403ee8d..edd34dd 100644 --- a/scripts/commands/config.py +++ b/scripts/commands/config.py @@ -18,7 +18,7 @@ from common.fs import validate_repo from lib.output import run_and_output -from lib.parsers import load_toml_safe +from lib.parsers import load_toml_safe, load_yaml_safe MAX_CONFIG_READ_BYTES = 32_000 @@ -55,80 +55,13 @@ def _get_nested(d: Any, *keys: str) -> Any | None: return d -def _parse_yaml_simple(content: str) -> dict: - result: dict = {} - stack: list[tuple[int, dict | list]] = [(-1, result)] - list_key_stack: list[str | None] = [None] - - for line in content.splitlines(): - if not line.strip() or line.strip().startswith("#"): - continue - - indent = len(line) - len(line.lstrip()) - stripped = line.strip() - - while len(stack) > 1 and stack[-1][0] >= indent: - stack.pop() - list_key_stack.pop() - - parent = stack[-1][1] - - if stripped.startswith("- "): - value = stripped[2:].strip() - - if isinstance(parent, dict): - last_key = list_key_stack[-1] - - if last_key and isinstance(parent.get(last_key), list): - parent[last_key].append(_yaml_scalar(value)) - elif ":" in stripped: - key, _, rest = stripped.partition(":") - key = key.strip() - rest = rest.strip() - - if rest: - if isinstance(parent, dict): - parent[key] = _yaml_scalar(rest) - list_key_stack[-1] = key - else: - if isinstance(parent, dict): - child: dict = {} - parent[key] = child - stack.append((indent, child)) - list_key_stack.append(key) - - return result - - -def _yaml_scalar(s: str): - s = s.strip().strip('"').strip("'") - - if s.lower() == "true": - return True - - if s.lower() == "false": - return False - - try: - return int(s) - except ValueError: - pass - - try: - return float(s) - except ValueError: - pass - - return s - - def _parse_by_extension(raw: str, fname: str) -> dict: """Parse raw config file content based on file extension.""" if fname.endswith(".json") or fname in (".prettierrc", ".eslintrc"): return _parse_json_safe(raw) if raw.strip().startswith("{") else {} if fname.endswith((".yml", ".yaml")): - return _parse_yaml_simple(raw) + return load_yaml_safe(raw) if fname.endswith(".toml"): return load_toml_safe(raw) diff --git a/scripts/lib/parsers.py b/scripts/lib/parsers.py index 3a60831..aaeee51 100644 --- a/scripts/lib/parsers.py +++ b/scripts/lib/parsers.py @@ -101,3 +101,12 @@ def load_yaml(content: str) -> Any: raise ParserUnavailableError("YAML parser unavailable: install 'PyYAML'") return mod.safe_load(content) + + +def load_yaml_safe(content: str) -> dict[str, Any]: + """Parse a YAML document, returning {} on any error.""" + try: + data = load_yaml(content) + return data if isinstance(data, dict) else {} + except (ParserUnavailableError, Exception): + return {} diff --git a/tests/test_config.py b/tests/test_config.py index 822b8c2..81052d7 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -2,10 +2,9 @@ _parse_editorconfig, _parse_editorconfig_for_lang, _parse_ini_section, - _parse_yaml_simple, detect, ) -from lib.parsers import load_toml_safe +from lib.parsers import load_toml_safe, load_yaml_safe from test_support import create_repo, create_sample_repo @@ -23,7 +22,7 @@ def test_config_parsers_cover_toml_yaml_ini_and_editorconfig(tmp_path): '[tool.demo]\nenabled = true\nnames = [\n "a",\n "b",\n]\n' ) - parsed_yaml = _parse_yaml_simple("tool:\n enabled: true\n count: 3\n") + parsed_yaml = load_yaml_safe("tool:\n enabled: true\n count: 3\n") parsed_ini = _parse_ini_section("[flake8]\nmax-line-length = 88\n", "[flake8]") ec_path = tmp_path / ".editorconfig" @@ -104,3 +103,33 @@ def fake_import(name, *args, **kwargs): monkeypatch.setattr(builtins, "__import__", fake_import) assert load_toml_safe('[tool]\nkey = "value"') == {} + + +def test_config_load_yaml_safe_handles_invalid_yaml(): + assert load_yaml_safe("invalid: [broken") == {} + + +def test_config_load_yaml_safe_normalizes_non_dict_output(): + assert load_yaml_safe("tool:\n enabled: true") == {"tool": {"enabled": True}} + assert load_yaml_safe("- item1\n- item2") == {} + + +def test_config_load_yaml_safe_returns_empty_on_unavailable(monkeypatch): + import lib.parsers as parsers_mod + + monkeypatch.setattr(parsers_mod, "_yaml_module", None) + monkeypatch.setattr(parsers_mod, "_yaml_checked", False) + + import builtins + + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "yaml": + raise ImportError(name) + + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + + assert load_yaml_safe("tool:\n key: value") == {}