Skip to content

Commit

Permalink
refactor: plugin API
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <henryschreineriii@gmail.com>
  • Loading branch information
henryiii committed Nov 28, 2023
1 parent 58a595f commit 9c31493
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 74 deletions.
68 changes: 5 additions & 63 deletions src/validate_pyproject/api.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
"""
Retrieve JSON schemas for validating dicts representing a ``pyproject.toml`` file.
"""
import io
import json
import logging
import sys
import typing
import urllib.parse
import urllib.request
from enum import Enum
from functools import partial, reduce
from itertools import chain
from types import MappingProxyType, ModuleType
from typing import (
Any,
Callable,
Dict,
Iterator,
Expand All @@ -33,7 +28,6 @@
from .types import FormatValidationFn, Schema, ValidationFn

_logger = logging.getLogger(__name__)
_chain_iter = chain.from_iterable

if typing.TYPE_CHECKING: # pragma: no cover
from .plugins import PluginProtocol
Expand All @@ -51,16 +45,6 @@ def read_text(package: Union[str, ModuleType], resource) -> str:
except ImportError: # pragma: no cover
from importlib.resources import read_text

if sys.platform == "emscripten" and "pyodide" in sys.modules:
from pyodide.http import open_url
else:

def open_url(url: str) -> io.StringIO:
if not url.startswith(("http:", "https:")):
raise ValueError("URL must start with 'http:' or 'https:'")
with urllib.request.urlopen(url) as response: # noqa: S310
return io.StringIO(response.read().decode("utf-8"))


T = TypeVar("T", bound=Mapping)
AllPlugins = Enum("AllPlugins", "ALL_PLUGINS")
Expand Down Expand Up @@ -197,33 +181,6 @@ def __getitem__(self, key: str) -> Callable[[str], Schema]:
return self._registry.__getitem__


def load_from_uri(tool_uri: str) -> Tuple[str, Any]:
tool_info = urllib.parse.urlparse(tool_uri)
if tool_info.netloc:
url = f"{tool_info.scheme}://{tool_info.netloc}{tool_info.path}"
with open_url(url) as f:
contents = json.load(f)
else:
with open(tool_info.path, "rb") as f:
contents = json.load(f)
return tool_info.fragment, contents


class RemotePlugin:
def __init__(self, tool: str, fragment: str, schema: Schema):
self.id = schema["$id"]
self.tool = tool
self.schema = schema
self.help_text = f"{tool} <external>"
self.fragment = fragment


if typing.TYPE_CHECKING:
from .plugins import PluginProtocol

_: PluginProtocol = typing.cast(RemotePlugin, None)


class Validator:
_plugins: Sequence["PluginProtocol"]

Expand All @@ -232,18 +189,12 @@ def __init__(
plugins: Union[Sequence["PluginProtocol"], AllPlugins] = ALL_PLUGINS,
format_validators: Mapping[str, FormatValidationFn] = FORMAT_FUNCTIONS,
extra_validations: Sequence[ValidationFn] = EXTRA_VALIDATIONS,
load_tools: Sequence[str] = (),
*,
extra_plugins: Sequence["PluginProtocol"] = (),
):
self._code_cache: Optional[str] = None
self._cache: Optional[ValidationFn] = None
self._schema: Optional[Schema] = None
self._external = {}

for tool in load_tools:
tool_name, _, tool_uri = tool.partition("=")
if not tool_uri:
raise errors.URLMissingTool(tool)
self._external[tool_name] = load_from_uri(tool_uri)

# Let's make the following options readonly
self._format_validators = MappingProxyType(format_validators)
Expand All @@ -252,18 +203,9 @@ def __init__(
if plugins is ALL_PLUGINS:
from .plugins import list_from_entry_points

self._plugins = tuple(list_from_entry_points())
else:
self._plugins = tuple(plugins) # force immutability / read only

if self._external:
self._plugins = (
*self._plugins,
*(
RemotePlugin(tool=k, fragment=v[0], schema=v[1])
for k, v in self._external.items()
),
)
plugins = list_from_entry_points()

self._plugins = (*plugins, *extra_plugins)

self._schema_registry = SchemaRegistry(self._plugins)
self.handlers = RefHandler(self._schema_registry)
Expand Down
4 changes: 3 additions & 1 deletion src/validate_pyproject/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .errors import ValidationError
from .plugins import PluginWrapper
from .plugins import list_from_entry_points as list_plugins_from_entry_points
from .remote import RemotePlugin

_logger = logging.getLogger(__package__)
T = TypeVar("T", bound=NamedTuple)
Expand Down Expand Up @@ -213,7 +214,8 @@ def run(args: Sequence[str] = ()):
plugins: List[PluginWrapper] = list_plugins_from_entry_points()
params: CliParams = parse_args(args, plugins)
setup_logging(params.loglevel)
validator = Validator(plugins=params.plugins, load_tools=params.tool)
tool_plugins = [RemotePlugin.from_str(t) for t in params.tool]
validator = Validator(params.plugins, extra_plugins=tool_plugins)

exceptions = _ExceptionGroup()
for file in params.input_file:
Expand Down
8 changes: 4 additions & 4 deletions src/validate_pyproject/pre_compile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import importlib_metadata as _M

if TYPE_CHECKING: # pragma: no cover
from ..plugins import PluginWrapper
from ..plugins import PluginProtocol


_logger = logging.getLogger(__name__)
Expand All @@ -32,10 +32,10 @@ def pre_compile( # noqa: PLR0913
output_dir: Union[str, os.PathLike] = ".",
main_file: str = "__init__.py",
original_cmd: str = "",
plugins: Union[api.AllPlugins, Sequence["PluginWrapper"]] = api.ALL_PLUGINS,
plugins: Union[api.AllPlugins, Sequence["PluginProtocol"]] = api.ALL_PLUGINS,
text_replacements: Mapping[str, str] = TEXT_REPLACEMENTS,
*,
load_tools: Sequence[str] = (),
extra_plugins: Sequence["PluginProtocol"] = (),
) -> Path:
"""Populate the given ``output_dir`` with all files necessary to perform
the validation.
Expand All @@ -47,7 +47,7 @@ def pre_compile( # noqa: PLR0913
out.mkdir(parents=True, exist_ok=True)
replacements = {**TEXT_REPLACEMENTS, **text_replacements}

validator = api.Validator(plugins, load_tools=load_tools)
validator = api.Validator(plugins, extra_plugins=extra_plugins)
header = "\n".join(NOCHECK_HEADERS)
code = replace_text(validator.generated_code, replacements)
_write(out / "fastjsonschema_validations.py", header + code)
Expand Down
6 changes: 5 additions & 1 deletion src/validate_pyproject/pre_compile/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .. import cli
from ..plugins import PluginWrapper
from ..plugins import list_from_entry_points as list_plugins_from_entry_points
from ..remote import RemotePlugin
from . import pre_compile

if sys.platform == "win32": # pragma: no cover
Expand Down Expand Up @@ -98,13 +99,16 @@ def run(args: Sequence[str] = ()):
desc = 'Generate files for "pre-compiling" `validate-pyproject`'
prms = cli.parse_args(args, plugins, desc, parser_spec, CliParams)
cli.setup_logging(prms.loglevel)

tool_plugins = [RemotePlugin.from_str(t) for t in prms.tool]

pre_compile(
prms.output_dir,
prms.main_file,
cmd,
prms.plugins,
prms.replacements,
load_tools=prms.tool,
extra_plugins=tool_plugins,
)
return 0

Expand Down
62 changes: 62 additions & 0 deletions src/validate_pyproject/remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import io
import json
import sys
import typing
import urllib.parse
import urllib.request
from typing import Tuple

from . import errors
from .types import Schema

if typing.TYPE_CHECKING:
if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self

if sys.platform == "emscripten" and "pyodide" in sys.modules:
from pyodide.http import open_url
else:

def open_url(url: str) -> io.StringIO:
if not url.startswith(("http:", "https:")):
raise ValueError("URL must start with 'http:' or 'https:'")
with urllib.request.urlopen(url) as response: # noqa: S310
return io.StringIO(response.read().decode("utf-8"))


__all__ = ["RemotePlugin"]


def load_from_uri(tool_uri: str) -> Tuple[str, Schema]:
tool_info = urllib.parse.urlparse(tool_uri)
if tool_info.netloc:
url = f"{tool_info.scheme}://{tool_info.netloc}{tool_info.path}"
with open_url(url) as f:
contents = json.load(f)
else:
with open(tool_info.path, "rb") as f:
contents = json.load(f)
return tool_info.fragment, contents


class RemotePlugin:
def __init__(self, tool: str, url: str):
self.tool = tool
self.fragment, self.schema = load_from_uri(url)
self.id = self.schema["$id"]
self.help_text = f"{tool} <external>"

@classmethod
def from_str(cls, tool_url: str) -> "Self":
tool, _, url = tool_url.partition("=")
if not url:
raise errors.URLMissingTool(tool)
return cls(tool, url)


if typing.TYPE_CHECKING:
from .plugins import PluginProtocol

_: PluginProtocol = typing.cast(RemotePlugin, None)
9 changes: 5 additions & 4 deletions tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@
from validate_pyproject import _tomllib as tomllib
from validate_pyproject import api, cli
from validate_pyproject.error_reporting import ValidationError
from validate_pyproject.remote import RemotePlugin

from .helpers import error_file, get_test_config


def test_examples_api(example: Path) -> None:
tools = get_test_config(example).get("tools", {})
load_tools = [f"{k}={v}" for k, v in tools.items()]
load_tools = [RemotePlugin.from_str(f"{k}={v}") for k, v in tools.items()]

toml_equivalent = tomllib.loads(example.read_text())
validator = api.Validator(load_tools=load_tools)
validator = api.Validator(extra_plugins=load_tools)
assert validator(toml_equivalent) is not None


Expand All @@ -28,11 +29,11 @@ def test_examples_cli(example: Path) -> None:

def test_invalid_examples_api(invalid_example: Path) -> None:
tools = get_test_config(invalid_example).get("tools", {})
load_tools = [f"{k}={v}" for k, v in tools.items()]
load_tools = [RemotePlugin.from_str(f"{k}={v}") for k, v in tools.items()]

expected_error = error_file(invalid_example).read_text("utf-8")
toml_equivalent = tomllib.loads(invalid_example.read_text())
validator = api.Validator(load_tools=load_tools)
validator = api.Validator(extra_plugins=load_tools)
with pytest.raises(ValidationError) as exc_info:
validator(toml_equivalent)
exception_message = str(exc_info.value)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_pre_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from validate_pyproject import _tomllib as tomllib
from validate_pyproject.pre_compile import cli, pre_compile
from validate_pyproject.remote import RemotePlugin

from .helpers import error_file, get_test_config

Expand Down Expand Up @@ -99,7 +100,8 @@ def test_vendoring_cli(tmp_path):


def api_pre_compile(tmp_path, *, load_tools: Sequence[str]) -> Path:
return pre_compile(Path(tmp_path / PRE_COMPILED_NAME), load_tools=load_tools)
plugins = [RemotePlugin.from_str(v) for v in load_tools]
return pre_compile(Path(tmp_path / PRE_COMPILED_NAME), extra_plugins=plugins)


def cli_pre_compile(tmp_path, *, load_tools: Sequence[str]) -> Path:
Expand Down

0 comments on commit 9c31493

Please sign in to comment.