Skip to content

Commit

Permalink
refactor: use protocol
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>
  • Loading branch information
henryiii committed Nov 1, 2023
1 parent e32ad8e commit 1474016
Show file tree
Hide file tree
Showing 14 changed files with 99 additions and 63 deletions.
38 changes: 20 additions & 18 deletions src/validate_pyproject/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from enum import Enum
from functools import partial, reduce
from itertools import chain
from types import MappingProxyType, ModuleType, SimpleNamespace
from types import MappingProxyType, ModuleType
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -36,7 +36,7 @@
_chain_iter = chain.from_iterable

if TYPE_CHECKING: # pragma: no cover
from .plugins import PluginWrapper
from .plugins import PluginProtocol


try: # pragma: no cover
Expand Down Expand Up @@ -93,7 +93,7 @@ class SchemaRegistry(Mapping[str, Schema]):
itself, all schemas provided by plugins **MUST** have a top level ``$id``.
"""

def __init__(self, plugins: Sequence["PluginWrapper"] = ()):
def __init__(self, plugins: Sequence["PluginProtocol"] = ()):
self._schemas: Dict[str, Tuple[str, str, Schema]] = {}
# (which part of the TOML, who defines, schema)

Expand Down Expand Up @@ -188,28 +188,35 @@ def __getitem__(self, key: str) -> Callable[[str], Schema]:
return self._registry.__getitem__


def load_from_uri(tool_uri: str) -> Any:
def load_from_uri(tool_uri: str) -> tuple[str, Any]:
tool_info = urllib.parse.urlparse(tool_uri)
if tool_info.netloc:
url = f"{tool_info.scheme}://{tool_info.netloc}/{tool_info.path}"
with urllib.request.urlopen(url) as f:
if not url.startswith(("http:", "https:")):
raise ValueError("URL must start with 'http:' or 'https:'")
with urllib.request.urlopen(url) as f: # noqa: S310
contents = json.load(f)
else:
with open(tool_info.path, "rb") as f:
contents = json.load(f)
for fragment in tool_info.fragment.split("/"):
if fragment:
schema = contents["$schema"]
contents = contents[fragment]
contents["$schema"] = schema
return tool_info.fragment, contents

return contents

class RemotePlugin:
def __init__(self, tool: str, fragment: str, schema: Dict[str, Any]):
self.id = schema.get("$id", f"external:{tool}")
self.tool = tool
self.schema = schema
self.help_text = f"{tool} <external>"
self._fragment = fragment # Unused ATM


class Validator:
_plugins: Sequence["PluginProtocol"]

def __init__(
self,
plugins: Union[Sequence["PluginWrapper"], AllPlugins] = ALL_PLUGINS,
plugins: Union[Sequence["PluginProtocol"], AllPlugins] = ALL_PLUGINS,
format_validators: Mapping[str, FormatValidationFn] = FORMAT_FUNCTIONS,
extra_validations: Sequence[ValidationFn] = EXTRA_VALIDATIONS,
load_tools: Sequence[str] = (),
Expand Down Expand Up @@ -238,12 +245,7 @@ def __init__(
self._plugins = (
*self._plugins,
*(
SimpleNamespace(
id=v.get("$id", f"external:{k}"),
tool=k,
schema=v,
help_text=f"{k} <external>",
)
RemotePlugin(tool=k, fragment=v[0], schema=v[1])
for k, v in self._external.items()
),
)
Expand Down
4 changes: 0 additions & 4 deletions src/validate_pyproject/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,3 @@ def uint16(value: int) -> bool:

def uint(value: int) -> bool:
return 0 <= value < 2**64


def int(value: int) -> bool:
return -(2**63) <= value < 2**63
34 changes: 32 additions & 2 deletions src/validate_pyproject/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
"""

import sys
import typing
from string import Template
from textwrap import dedent
from typing import Any, Callable, Iterable, List, Optional, cast
from typing import Any, Callable, Iterable, List, Optional

from .. import __version__
from ..types import Plugin
Expand All @@ -19,10 +20,35 @@
else: # pragma: no cover
from importlib_metadata import EntryPoint, entry_points

if typing.TYPE_CHECKING:
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
else:
Protocol = object

ENTRYPOINT_GROUP = "validate_pyproject.tool_schema"


class PluginProtocol(Protocol):
@property
def id(self):
...

@property
def tool(self):
...

@property
def schema(self):
...

@property
def help_text(self) -> str:
...


class PluginWrapper:
def __init__(self, tool: str, load_fn: Plugin):
self._tool = tool
Expand Down Expand Up @@ -51,6 +77,10 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.tool!r}, {self.id})"


if typing.TYPE_CHECKING:
_: PluginProtocol = typing.cast(PluginWrapper, None)


def iterate_entry_points(group=ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
"""Produces a generator yielding an EntryPoint object for each plugin registered
via ``setuptools`` `entry point`_ mechanism.
Expand All @@ -62,7 +92,7 @@ def iterate_entry_points(group=ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
if hasattr(entries, "select"): # pragma: no cover
# The select method was introduced in importlib_metadata 3.9 (and Python 3.10)
# and the previous dict interface was declared deprecated
select = cast(
select = typing.cast(
Any, getattr(entries, "select") # noqa: B009
) # typecheck gymnastics
entries_: Iterable[EntryPoint] = select(group=group)
Expand Down
30 changes: 27 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
"""
Dummy conftest.py for validate_pyproject.
conftest.py for validate_pyproject.
If you don't know what this is for, just leave it empty.
Read more about conftest.py under:
- https://docs.pytest.org/en/stable/fixture.html
- https://docs.pytest.org/en/stable/writing_plugins.html
"""

# import pytest
from pathlib import Path
from typing import List

import pytest

HERE = Path(__file__).parent.resolve()
EXAMPLES = HERE / "examples"
INVALID = HERE / "invalid-examples"


def examples() -> List[str]:
return [str(f.relative_to(EXAMPLES)) for f in EXAMPLES.glob("**/*.toml")]


def invalid_examples() -> List[str]:
return [str(f.relative_to(INVALID)) for f in INVALID.glob("**/*.toml")]


@pytest.fixture(params=examples())
def example(request) -> Path:
return EXAMPLES / request.param


@pytest.fixture(params=invalid_examples())
def invalid_example(request) -> Path:
return INVALID / request.param
12 changes: 1 addition & 11 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from pathlib import Path

HERE = Path(__file__).parent
EXAMPLES = HERE / "examples"
INVALID = HERE / "invalid-examples"


def examples():
return [str(f.relative_to(EXAMPLES)) for f in EXAMPLES.glob("**/*.toml")]


def invalid_examples():
return [str(f.relative_to(INVALID)) for f in INVALID.glob("**/*.toml")]
HERE = Path(__file__).parent.resolve()


def error_file(p: Path) -> Path:
Expand Down
File renamed without changes.
29 changes: 12 additions & 17 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,28 @@
import logging
from pathlib import Path

import pytest

from validate_pyproject import _tomllib as tomllib
from validate_pyproject import api, cli
from validate_pyproject.error_reporting import ValidationError

from .helpers import EXAMPLES, INVALID, error_file, examples, invalid_examples
from .helpers import error_file


@pytest.mark.parametrize("example", examples())
def test_examples_api(example):
toml_equivalent = tomllib.loads((EXAMPLES / example).read_text())
def test_examples_api(example: Path) -> None:
toml_equivalent = tomllib.loads(example.read_text())
validator = api.Validator()
assert validator(toml_equivalent) is not None


@pytest.mark.parametrize("example", examples())
def test_examples_cli(example):
assert cli.run(["--dump-json", str(EXAMPLES / example)]) == 0 # no errors
def test_examples_cli(example: Path) -> None:
assert cli.run(["--dump-json", str(example)]) == 0 # no errors


@pytest.mark.parametrize("example", invalid_examples())
def test_invalid_examples_api(example):
example_file = INVALID / example
expected_error = error_file(example_file).read_text("utf-8")
toml_equivalent = tomllib.loads(example_file.read_text())
def test_invalid_examples_api(invalid_example: Path) -> None:
expected_error = error_file(invalid_example).read_text("utf-8")
toml_equivalent = tomllib.loads(invalid_example.read_text())
validator = api.Validator()
with pytest.raises(ValidationError) as exc_info:
validator(toml_equivalent)
Expand All @@ -36,13 +33,11 @@ def test_invalid_examples_api(example):
assert error in summary


@pytest.mark.parametrize("example", invalid_examples())
def test_invalid_examples_cli(example, caplog):
def test_invalid_examples_cli(invalid_example: Path, caplog) -> None:
caplog.set_level(logging.DEBUG)
example_file = INVALID / example
expected_error = error_file(example_file).read_text("utf-8")
expected_error = error_file(invalid_example).read_text("utf-8")
with pytest.raises(SystemExit) as exc_info:
cli.main([str(example_file)])
cli.main([str(invalid_example)])
assert exc_info.value.args == (1,)
for error in expected_error.splitlines():
assert error in caplog.text
15 changes: 7 additions & 8 deletions tests/test_pre_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from validate_pyproject import _tomllib as tomllib
from validate_pyproject.pre_compile import cli, pre_compile

from .helpers import EXAMPLES, INVALID, error_file, examples, invalid_examples
from .helpers import error_file

MAIN_FILE = "hello_world.py" # Let's use something different that `__init__.py`

Expand Down Expand Up @@ -136,20 +136,19 @@ def _validate(vendored_path, toml_equivalent):
return _validate


@pytest.mark.parametrize("example", examples())
@pytest.mark.parametrize("pre_compiled", _PRE_COMPILED)
def test_examples_api(tmp_path, pre_compiled_validate, example, pre_compiled):
toml_equivalent = tomllib.loads((EXAMPLES / example).read_text())
toml_equivalent = tomllib.loads(example.read_text())
pre_compiled_path = pre_compiled(Path(tmp_path))
assert pre_compiled_validate(pre_compiled_path, toml_equivalent) is not None


@pytest.mark.parametrize("example", invalid_examples())
@pytest.mark.parametrize("pre_compiled", _PRE_COMPILED)
def test_invalid_examples_api(tmp_path, pre_compiled_validate, example, pre_compiled):
example_file = INVALID / example
expected_error = error_file(example_file).read_text("utf-8")
toml_equivalent = tomllib.loads(example_file.read_text())
def test_invalid_examples_api(
tmp_path, pre_compiled_validate, invalid_example, pre_compiled
):
expected_error = error_file(invalid_example).read_text("utf-8")
toml_equivalent = tomllib.loads(invalid_example.read_text())
pre_compiled_path = pre_compiled(Path(tmp_path))
with pytest.raises(JsonSchemaValueException) as exc_info:
pre_compiled_validate(pre_compiled_path, toml_equivalent)
Expand Down

0 comments on commit 1474016

Please sign in to comment.