Skip to content

Commit

Permalink
Allow overwrite when schemas refer to the same tool (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
abravalheri committed May 20, 2024
2 parents ddafc5a + b1a973c commit 8781a9f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 7 deletions.
15 changes: 11 additions & 4 deletions src/validate_pyproject/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,16 @@ def __init__(self, plugins: Sequence["PluginProtocol"] = ()):

# Add tools using Plugins
for plugin in plugins:
allow_overwrite: Optional[str] = None
if plugin.tool in tool_properties:
_logger.warning(f"{plugin.id} overwrites `tool.{plugin.tool}` schema")
allow_overwrite = plugin.schema.get("$id")
else:
_logger.info(f"{plugin.id} defines `tool.{plugin.tool}` schema")
sid = self._ensure_compatibility(plugin.tool, plugin.schema)["$id"]
compatible = self._ensure_compatibility(
plugin.tool, plugin.schema, allow_overwrite
)
sid = compatible["$id"]
sref = f"{sid}#{plugin.fragment}" if plugin.fragment else sid
tool_properties[plugin.tool] = {"$ref": sref}
self._schemas[sid] = (f"tool.{plugin.tool}", plugin.id, plugin.schema)
Expand All @@ -133,11 +138,13 @@ def main(self) -> str:
"""Top level schema for validating a ``pyproject.toml`` file"""
return self._main_id

def _ensure_compatibility(self, reference: str, schema: Schema) -> Schema:
if "$id" not in schema:
def _ensure_compatibility(
self, reference: str, schema: Schema, allow_overwrite: Optional[str] = None
) -> Schema:
if "$id" not in schema or not schema["$id"]:
raise errors.SchemaMissingId(reference)
sid = schema["$id"]
if sid in self._schemas:
if sid in self._schemas and sid != allow_overwrite:
raise errors.SchemaWithDuplicatedId(sid)
version = schema.get("$schema")
# Support schemas with missing trailing # (incorrect, but required before 0.15)
Expand Down
4 changes: 3 additions & 1 deletion src/validate_pyproject/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def iterate_entry_points(group: str = ENTRYPOINT_GROUP) -> Iterable[EntryPoint]:
# TODO: Once Python 3.10 becomes the oldest version supported, this fallback and
# conditional statement can be removed.
entries_ = (plugin for plugin in entries.get(group, []))
deduplicated = {e.name: e for e in sorted(entries_, key=lambda e: e.name)}
deduplicated = {
e.name: e for e in sorted(entries_, key=lambda e: (e.name, e.value))
}
return list(deduplicated.values())


Expand Down
12 changes: 10 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,19 @@ def test_incompatible_versions(self):
with pytest.raises(errors.InvalidSchemaVersion):
api.SchemaRegistry([plg])

def test_duplicated_id(self):
plg = [plugins.PluginWrapper("plg", self.fake_plugin) for _ in range(2)]
def test_duplicated_id_different_tools(self):
schema = self.fake_plugin("plg")
fn = wraps(self.fake_plugin)(lambda _: schema) # Same ID
plg = [plugins.PluginWrapper(f"plg{i}", fn) for i in range(2)]
with pytest.raises(errors.SchemaWithDuplicatedId):
api.SchemaRegistry(plg)

def test_allow_overwrite_same_tool(self):
plg = [plugins.PluginWrapper("plg", self.fake_plugin) for _ in range(2)]
registry = api.SchemaRegistry(plg)
sid = self.fake_plugin("plg")["$id"]
assert sid in registry

def test_missing_id(self):
def _fake_plugin(name):
plg = dict(self.fake_plugin(name))
Expand Down

0 comments on commit 8781a9f

Please sign in to comment.