Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: load tool(s) via URL #121

Merged
merged 12 commits into from
Nov 29, 2023
8 changes: 6 additions & 2 deletions .coveragerc
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@ source =
# Regexes for lines to exclude from consideration
exclude_lines =
# Have to re-enable the standard pragma
# (exclude_also would be better, but not available on Python 3.6)
pragma: no cover

# Don't complain about missing debug-only code:
def __repr__
if self\.debug

# Don't complain if tests don't hit defensive assertion code:
raise AssertionError
Expand All @@ -31,4 +31,8 @@ exclude_lines =
if __name__ == .__main__.:

if TYPE_CHECKING:
if typing.TYPE_CHECKING:
if typing\.TYPE_CHECKING:
^\s+\.\.\.$

# Support for Pyodide (WASM)
if sys\.platform == .emscripten. and .pyodide. in sys\.modules:
40 changes: 20 additions & 20 deletions src/validate_pyproject/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
import json
import logging
import sys
import typing
from enum import Enum
from functools import partial, reduce
from itertools import chain
from types import MappingProxyType, ModuleType
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Iterator,
Expand All @@ -19,7 +18,6 @@
Tuple,
TypeVar,
Union,
cast,
)

import fastjsonschema as FJS
Expand All @@ -30,14 +28,13 @@
from .types import FormatValidationFn, Schema, ValidationFn

_logger = logging.getLogger(__name__)
_chain_iter = chain.from_iterable

if TYPE_CHECKING: # pragma: no cover
from .plugins import PluginWrapper
if typing.TYPE_CHECKING: # pragma: no cover
from .plugins import PluginProtocol


try: # pragma: no cover
if sys.version_info[:2] < (3, 7) or TYPE_CHECKING: # See #22
if sys.version_info[:2] < (3, 7) or typing.TYPE_CHECKING: # See #22
from importlib_resources import files
else:
from importlib.resources import files
Expand Down Expand Up @@ -90,11 +87,11 @@ class SchemaRegistry(Mapping[str, Schema]):
itself, all schemas provided by plugins **MUST** have a top level ``$id``.
"""

def __init__(self, plugins: Sequence["PluginWrapper"] = ()):
def __init__(self, plugins: Sequence["PluginProtocol"] = ()):
self._schemas: Dict[str, Tuple[str, str, Schema]] = {}
# (which part of the TOML, who defines, schema)

top_level = cast(dict, load(TOP_LEVEL_SCHEMA)) # Make it mutable
top_level = typing.cast(dict, load(TOP_LEVEL_SCHEMA)) # Make it mutable
self._spec_version = top_level["$schema"]
top_properties = top_level["properties"]
tool_properties = top_properties["tool"].setdefault("properties", {})
Expand All @@ -108,16 +105,15 @@ def __init__(self, plugins: Sequence["PluginWrapper"] = ()):
self._schemas = {sid: ("project", origin, project_table_schema)}

# Add tools using Plugins

for plugin in plugins:
pid, tool, schema = plugin.id, plugin.tool, plugin.schema
if plugin.tool in tool_properties:
_logger.warning(f"{plugin.id} overwrites `tool.{plugin.tool}` schema")
else:
_logger.info(f"{pid} defines `tool.{tool}` schema")
sid = self._ensure_compatibility(tool, schema)["$id"]
tool_properties[tool] = {"$ref": sid}
self._schemas[sid] = (f"tool.{tool}", pid, schema)
_logger.info(f"{plugin.id} defines `tool.{plugin.tool}` schema")
sid = self._ensure_compatibility(plugin.tool, plugin.schema)["$id"]
sref = f"{sid}#{plugin.fragment}" if plugin.fragment else sid
tool_properties[plugin.tool] = {"$ref": sref}
self._schemas[sid] = (f"tool.{plugin.tool}", plugin.id, plugin.schema)

self._main_id = sid = top_level["$id"]
main_schema = Schema(top_level)
Expand Down Expand Up @@ -186,11 +182,15 @@ def __getitem__(self, key: str) -> Callable[[str], Schema]:


class Validator:
_plugins: Sequence["PluginProtocol"]

def __init__(
self,
plugins: Union[Sequence["PluginWrapper"], AllPlugins] = ALL_PLUGINS,
plugins: Union[Sequence["PluginProtocol"], AllPlugins] = ALL_PLUGINS,
format_validators: Mapping[str, FormatValidationFn] = FORMAT_FUNCTIONS,
extra_validations: Sequence[ValidationFn] = EXTRA_VALIDATIONS,
*,
extra_plugins: Sequence["PluginProtocol"] = (),
):
self._code_cache: Optional[str] = None
self._cache: Optional[ValidationFn] = None
Expand All @@ -203,9 +203,9 @@ def __init__(
if plugins is ALL_PLUGINS:
from .plugins import list_from_entry_points

self._plugins = tuple(list_from_entry_points())
else:
self._plugins = tuple(plugins) # force immutability / read only
plugins = list_from_entry_points()

self._plugins = (*plugins, *extra_plugins)

self._schema_registry = SchemaRegistry(self._plugins)
self.handlers = RefHandler(self._schema_registry)
Expand Down Expand Up @@ -245,7 +245,7 @@ def __call__(self, pyproject: T) -> T:
if self._cache is None:
compiled = FJS.compile(self.schema, self.handlers, dict(self.formats))
fn = partial(compiled, custom_formats=self._format_validators)
self._cache = cast(ValidationFn, fn)
self._cache = typing.cast(ValidationFn, fn)

with detailed_errors():
self._cache(pyproject)
Expand Down
12 changes: 11 additions & 1 deletion src/validate_pyproject/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .errors import ValidationError
from .plugins import PluginWrapper
from .plugins import list_from_entry_points as list_plugins_from_entry_points
from .remote import RemotePlugin

_logger = logging.getLogger(__package__)
T = TypeVar("T", bound=NamedTuple)
Expand Down Expand Up @@ -99,12 +100,19 @@ def critical_logging():
action="store_true",
help="Print the JSON equivalent to the given TOML",
),
"tool": dict(
flags=("-t", "--tool"),
action="append",
dest="tool",
help="External tools file/url(s) to load, of the form name=URL#path",
),
}


class CliParams(NamedTuple):
input_file: List[io.TextIOBase]
plugins: List[PluginWrapper]
tool: List[str]
loglevel: int = logging.WARNING
dump_json: bool = False

Expand Down Expand Up @@ -147,6 +155,7 @@ def parse_args(
enabled = params.pop("enable", ())
disabled = params.pop("disable", ())
params["plugins"] = select_plugins(plugins, enabled, disabled)
params["tool"] = params["tool"] or []
return params_class(**params) # type: ignore[call-overload]


Expand Down Expand Up @@ -205,7 +214,8 @@ def run(args: Sequence[str] = ()):
plugins: List[PluginWrapper] = list_plugins_from_entry_points()
params: CliParams = parse_args(args, plugins)
setup_logging(params.loglevel)
validator = Validator(plugins=params.plugins)
tool_plugins = [RemotePlugin.from_str(t) for t in params.tool]
validator = Validator(params.plugins, extra_plugins=tool_plugins)

exceptions = _ExceptionGroup()
for file in params.input_file:
Expand Down
15 changes: 15 additions & 0 deletions src/validate_pyproject/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,21 @@
from .error_reporting import ValidationError


class URLMissingTool(RuntimeError):
_DESC = """\
The '--tool' option requires a tool name.

Correct form is '--tool <tool-name>={url}', with an optional
'#json/pointer' at the end.
"""
__doc__ = _DESC

def __init__(self, url: str):
msg = dedent(self._DESC).strip()
msg = msg.format(url=url)
super().__init__(msg)


class InvalidSchemaVersion(JsonSchemaDefinitionException):
_DESC = """\
All schemas used in the validator should be specified using the same version \
Expand Down
11 changes: 8 additions & 3 deletions src/validate_pyproject/formats.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import builtins
import logging
import os
import re
Expand Down Expand Up @@ -275,13 +276,17 @@ def python_entrypoint_reference(value: str) -> bool:
return all(python_identifier(i.strip()) for i in identifiers)


def uint8(value: int) -> bool:
def uint8(value: builtins.int) -> bool:
return 0 <= value < 2**8


def uint16(value: int) -> bool:
def uint16(value: builtins.int) -> bool:
return 0 <= value < 2**16


def uint(value: int) -> bool:
def uint(value: builtins.int) -> bool:
return 0 <= value < 2**64


def int(value: builtins.int) -> bool:
return -(2**63) <= value < 2**63
50 changes: 44 additions & 6 deletions src/validate_pyproject/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,40 +6,74 @@
"""

import sys
import typing
from string import Template
from textwrap import dedent
from typing import Any, Callable, Iterable, List, Optional, cast
from typing import Any, Callable, Iterable, List, Optional

from .. import __version__
from ..types import Plugin
from ..types import Plugin, Schema

if sys.version_info[:2] >= (3, 8): # pragma: no cover
# TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
from importlib.metadata import EntryPoint, entry_points
else: # pragma: no cover
from importlib_metadata import EntryPoint, entry_points

if typing.TYPE_CHECKING:
if sys.version_info < (3, 8):
from typing_extensions import Protocol
else:
from typing import Protocol
else:
Protocol = object

ENTRYPOINT_GROUP = "validate_pyproject.tool_schema"


class PluginProtocol(Protocol):
@property
def id(self) -> str:
...

@property
def tool(self) -> str:
...

@property
def schema(self) -> Schema:
...

@property
def help_text(self) -> str:
...

@property
def fragment(self) -> str:
...


class PluginWrapper:
def __init__(self, tool: str, load_fn: Plugin):
self._tool = tool
self._load_fn = load_fn

@property
def id(self):
def id(self) -> str:
return f"{self._load_fn.__module__}.{self._load_fn.__name__}"

@property
def tool(self):
def tool(self) -> str:
return self._tool

@property
def schema(self):
def schema(self) -> Schema:
return self._load_fn(self.tool)

@property
def fragment(self) -> str:
return ""

@property
def help_text(self) -> str:
tpl = self._load_fn.__doc__
Expand All @@ -51,6 +85,10 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.tool!r}, {self.id})"


if typing.TYPE_CHECKING:
_: PluginProtocol = typing.cast(PluginWrapper, None)


def iterate_entry_points(group=ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
"""Produces a generator yielding an EntryPoint object for each plugin registered
via ``setuptools`` `entry point`_ mechanism.
Expand All @@ -62,7 +100,7 @@ def iterate_entry_points(group=ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
if hasattr(entries, "select"): # pragma: no cover
# The select method was introduced in importlib_metadata 3.9 (and Python 3.10)
# and the previous dict interface was declared deprecated
select = cast(
select = typing.cast(
Any, getattr(entries, "select") # noqa: B009
) # typecheck gymnastics
entries_: Iterable[EntryPoint] = select(group=group)
Expand Down
10 changes: 6 additions & 4 deletions src/validate_pyproject/pre_compile/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import importlib_metadata as _M

if TYPE_CHECKING: # pragma: no cover
from ..plugins import PluginWrapper
from ..plugins import PluginProtocol


_logger = logging.getLogger(__name__)
Expand All @@ -28,12 +28,14 @@
)


def pre_compile(
def pre_compile( # noqa: PLR0913
output_dir: Union[str, os.PathLike] = ".",
main_file: str = "__init__.py",
original_cmd: str = "",
plugins: Union[api.AllPlugins, Sequence["PluginWrapper"]] = api.ALL_PLUGINS,
plugins: Union[api.AllPlugins, Sequence["PluginProtocol"]] = api.ALL_PLUGINS,
text_replacements: Mapping[str, str] = TEXT_REPLACEMENTS,
*,
extra_plugins: Sequence["PluginProtocol"] = (),
) -> Path:
"""Populate the given ``output_dir`` with all files necessary to perform
the validation.
Expand All @@ -45,7 +47,7 @@ def pre_compile(
out.mkdir(parents=True, exist_ok=True)
replacements = {**TEXT_REPLACEMENTS, **text_replacements}

validator = api.Validator(plugins)
validator = api.Validator(plugins, extra_plugins=extra_plugins)
header = "\n".join(NOCHECK_HEADERS)
code = replace_text(validator.generated_code, replacements)
_write(out / "fastjsonschema_validations.py", header + code)
Expand Down
Loading