From d5d354f765151589cadeb723b4a946cc405fdc9f Mon Sep 17 00:00:00 2001 From: Iwan Aucamp Date: Sun, 15 May 2022 15:03:38 +0200 Subject: [PATCH] feat: add typing for `rdflib/plugins/sparql` (#1926) This patch adds typing to two files and changes imports so that they are more specific to the module in which classes are defined. This patch contains no runtime changes. I'm adding this to make it easier to spot bugs in new PRs to SPARQL code. Also: * Disable some pep8-naming errors for `rdflib/plugins/sparql/*` --- pyproject.toml | 11 ++- rdflib/plugins/sparql/evaluate.py | 92 +++++++++++++------ rdflib/plugins/sparql/sparql.py | 147 +++++++++++++++++------------- 3 files changed, 158 insertions(+), 92 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 5ee2a25b8..3d400e56b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,16 @@ pyflakes = [ "+*", ] pep8-naming = ["+*"] -"flake8-*" = ["+*"] + +[tool.flakeheaven.exceptions."rdflib/plugins/sparql/*"] +pep8-naming = [ + "-N802", + "-N803", + "-N806", + "-N812", + "-N816", + "-N801", +] [tool.black] diff --git a/rdflib/plugins/sparql/evaluate.py b/rdflib/plugins/sparql/evaluate.py index 6cdbe7540..edd322e56 100644 --- a/rdflib/plugins/sparql/evaluate.py +++ b/rdflib/plugins/sparql/evaluate.py @@ -18,13 +18,13 @@ import itertools import json as j import re -from typing import Any, Deque, Dict, List, Union +from typing import Any, Deque, Dict, Generator, Iterable, List, Tuple, Union from urllib.parse import urlencode from urllib.request import Request, urlopen from pyparsing import ParseException -from rdflib import BNode, Graph, Literal, URIRef, Variable +from rdflib.graph import Graph from rdflib.plugins.sparql import CUSTOM_EVALS, parser from rdflib.plugins.sparql.aggregates import Aggregator from rdflib.plugins.sparql.evalutils import ( @@ -42,13 +42,19 @@ AlreadyBound, Bindings, FrozenBindings, + FrozenDict, + Query, QueryContext, SPARQLError, ) -from rdflib.term import Identifier +from rdflib.term import BNode, Identifier, Literal, URIRef, Variable +_Triple = Tuple[Identifier, Identifier, Identifier] -def evalBGP(ctx: QueryContext, bgp: List[Any]): + +def evalBGP( + ctx: QueryContext, bgp: List[_Triple] +) -> Generator[FrozenBindings, None, None]: """ A basic graph pattern """ @@ -63,7 +69,8 @@ def evalBGP(ctx: QueryContext, bgp: List[Any]): _p = ctx[p] _o = ctx[o] - for ss, sp, so in ctx.graph.triples((_s, _p, _o)): + # type error: Item "None" of "Optional[Graph]" has no attribute "triples" + for ss, sp, so in ctx.graph.triples((_s, _p, _o)): # type: ignore[union-attr] if None in (_s, _p, _o): c = ctx.push() else: @@ -88,7 +95,9 @@ def evalBGP(ctx: QueryContext, bgp: List[Any]): yield x -def evalExtend(ctx: QueryContext, extend: CompValue): +def evalExtend( + ctx: QueryContext, extend: CompValue +) -> Generator[FrozenBindings, None, None]: # TODO: Deal with dict returned from evalPart from GROUP BY for c in evalPart(ctx, extend.p): @@ -103,7 +112,9 @@ def evalExtend(ctx: QueryContext, extend: CompValue): yield c -def evalLazyJoin(ctx: QueryContext, join: CompValue): +def evalLazyJoin( + ctx: QueryContext, join: CompValue +) -> Generator[FrozenBindings, None, None]: """ A lazy join will push the variables bound in the first part to the second part, @@ -116,7 +127,7 @@ def evalLazyJoin(ctx: QueryContext, join: CompValue): yield b.merge(a) # merge, as some bindings may have been forgotten -def evalJoin(ctx: QueryContext, join: CompValue): +def evalJoin(ctx: QueryContext, join: CompValue) -> Generator[FrozenDict, None, None]: # TODO: Deal with dict returned from evalPart from GROUP BY # only ever for join.p1 @@ -129,7 +140,7 @@ def evalJoin(ctx: QueryContext, join: CompValue): return _join(a, b) -def evalUnion(ctx: QueryContext, union: CompValue): +def evalUnion(ctx: QueryContext, union: CompValue) -> Iterable[FrozenBindings]: branch1_branch2 = [] for x in evalPart(ctx, union.p1): branch1_branch2.append(x) @@ -138,13 +149,15 @@ def evalUnion(ctx: QueryContext, union: CompValue): return branch1_branch2 -def evalMinus(ctx: QueryContext, minus: CompValue): +def evalMinus(ctx: QueryContext, minus: CompValue) -> Generator[FrozenDict, None, None]: a = evalPart(ctx, minus.p1) b = set(evalPart(ctx, minus.p2)) return _minus(a, b) -def evalLeftJoin(ctx: QueryContext, join: CompValue): +def evalLeftJoin( + ctx: QueryContext, join: CompValue +) -> Generator[FrozenBindings, None, None]: # import pdb; pdb.set_trace() for a in evalPart(ctx, join.p1): ok = False @@ -168,7 +181,9 @@ def evalLeftJoin(ctx: QueryContext, join: CompValue): yield a -def evalFilter(ctx: QueryContext, part: CompValue): +def evalFilter( + ctx: QueryContext, part: CompValue +) -> Generator[FrozenBindings, None, None]: # TODO: Deal with dict returned from evalPart! for c in evalPart(ctx, part.p): if _ebv( @@ -178,7 +193,9 @@ def evalFilter(ctx: QueryContext, part: CompValue): yield c -def evalGraph(ctx: QueryContext, part: CompValue): +def evalGraph( + ctx: QueryContext, part: CompValue +) -> Generator[FrozenBindings, None, None]: if ctx.dataset is None: raise Exception( @@ -211,7 +228,9 @@ def evalGraph(ctx: QueryContext, part: CompValue): yield x -def evalValues(ctx: QueryContext, part): +def evalValues( + ctx: QueryContext, part: CompValue +) -> Generator[FrozenBindings, None, None]: for r in part.p.res: c = ctx.push() try: @@ -337,7 +356,8 @@ def evalServiceQuery(ctx: QueryContext, part): res = json["results"]["bindings"] if len(res) > 0: for r in res: - for bound in _yieldBindingsFromServiceCallResult(ctx, r, variables): + # type error: Argument 2 to "_yieldBindingsFromServiceCallResult" has incompatible type "str"; expected "Dict[str, Dict[str, str]]" + for bound in _yieldBindingsFromServiceCallResult(ctx, r, variables): # type: ignore[arg-type] yield bound else: raise Exception( @@ -353,7 +373,7 @@ def evalServiceQuery(ctx: QueryContext, part): """ -def _buildQueryStringForServiceCall(ctx: QueryContext, match): +def _buildQueryStringForServiceCall(ctx: QueryContext, match: re.Match) -> str: service_query = match.group(2) try: @@ -361,10 +381,12 @@ def _buildQueryStringForServiceCall(ctx: QueryContext, match): except ParseException: # This could be because we don't have a select around the service call. service_query = "SELECT REDUCED * WHERE {" + service_query + "}" - for p in ctx.prologue.namespace_manager.store.namespaces(): + # type error: Item "None" of "Optional[Prologue]" has no attribute "namespace_manager" + for p in ctx.prologue.namespace_manager.store.namespaces(): # type: ignore[union-attr] service_query = "PREFIX " + p[0] + ":" + p[1].n3() + " " + service_query # re add the base if one was defined - base = ctx.prologue.base + # type error: Item "None" of "Optional[Prologue]" has no attribute "base" [union-attr] + base = ctx.prologue.base # type: ignore[union-attr] if base is not None and len(base) > 0: service_query = "BASE <" + base + "> " + service_query sol = ctx.solution() @@ -377,7 +399,9 @@ def _buildQueryStringForServiceCall(ctx: QueryContext, match): return service_query -def _yieldBindingsFromServiceCallResult(ctx: QueryContext, r, variables): +def _yieldBindingsFromServiceCallResult( + ctx: QueryContext, r: Dict[str, Dict[str, str]], variables: List[str] +) -> Generator[FrozenBindings, None, None]: res_dict: Dict[Variable, Identifier] = {} for var in variables: if var in r and r[var]: @@ -396,7 +420,7 @@ def _yieldBindingsFromServiceCallResult(ctx: QueryContext, r, variables): yield FrozenBindings(ctx, res_dict) -def evalGroup(ctx: QueryContext, group): +def evalGroup(ctx: QueryContext, group: CompValue): """ http://www.w3.org/TR/sparql11-query/#defn_algGroup """ @@ -404,7 +428,9 @@ def evalGroup(ctx: QueryContext, group): return evalPart(ctx, group.p) -def evalAggregateJoin(ctx: QueryContext, agg): +def evalAggregateJoin( + ctx: QueryContext, agg: CompValue +) -> Generator[FrozenBindings, None, None]: # import pdb ; pdb.set_trace() p = evalPart(ctx, agg.p) # p is always a Group, we always get a dict back @@ -435,7 +461,9 @@ def evalAggregateJoin(ctx: QueryContext, agg): yield FrozenBindings(ctx) -def evalOrderBy(ctx: QueryContext, part): +def evalOrderBy( + ctx: QueryContext, part: CompValue +) -> Generator[FrozenBindings, None, None]: res = evalPart(ctx, part.p) @@ -449,7 +477,7 @@ def evalOrderBy(ctx: QueryContext, part): return res -def evalSlice(ctx: QueryContext, slice): +def evalSlice(ctx: QueryContext, slice: CompValue): res = evalPart(ctx, slice.p) return itertools.islice( @@ -459,7 +487,9 @@ def evalSlice(ctx: QueryContext, slice): ) -def evalReduced(ctx: QueryContext, part): +def evalReduced( + ctx: QueryContext, part: CompValue +) -> Generator[FrozenBindings, None, None]: """apply REDUCED to result REDUCED is not as strict as DISTINCT, but if the incoming rows were sorted @@ -497,7 +527,9 @@ def evalReduced(ctx: QueryContext, part): mru_queue.appendleft(row) -def evalDistinct(ctx: QueryContext, part): +def evalDistinct( + ctx: QueryContext, part: CompValue +) -> Generator[FrozenBindings, None, None]: res = evalPart(ctx, part.p) done = set() @@ -507,13 +539,13 @@ def evalDistinct(ctx: QueryContext, part): done.add(x) -def evalProject(ctx: QueryContext, project): +def evalProject(ctx: QueryContext, project: CompValue): res = evalPart(ctx, project.p) return (row.project(project.PV) for row in res) -def evalSelectQuery(ctx: QueryContext, query): +def evalSelectQuery(ctx: QueryContext, query: CompValue): res = {} res["type_"] = "SELECT" @@ -522,7 +554,7 @@ def evalSelectQuery(ctx: QueryContext, query): return res -def evalAskQuery(ctx: QueryContext, query): +def evalAskQuery(ctx: QueryContext, query: CompValue): res: Dict[str, Union[bool, str]] = {} res["type_"] = "ASK" res["askAnswer"] = False @@ -533,7 +565,7 @@ def evalAskQuery(ctx: QueryContext, query): return res -def evalConstructQuery(ctx: QueryContext, query): +def evalConstructQuery(ctx: QueryContext, query) -> Dict[str, Union[str, Graph]]: template = query.template if not template: @@ -552,7 +584,7 @@ def evalConstructQuery(ctx: QueryContext, query): return res -def evalQuery(graph, query, initBindings, base=None): +def evalQuery(graph: Graph, query: Query, initBindings, base=None): initBindings = dict((Variable(k), v) for k, v in initBindings.items()) diff --git a/rdflib/plugins/sparql/sparql.py b/rdflib/plugins/sparql/sparql.py index 552530083..4846ec740 100644 --- a/rdflib/plugins/sparql/sparql.py +++ b/rdflib/plugins/sparql/sparql.py @@ -1,24 +1,26 @@ import collections import datetime import itertools +import typing as t +from typing import Any, Container, Dict, Iterable, List, Optional, Tuple, Union import isodate import rdflib.plugins.sparql -from rdflib import BNode, ConjunctiveGraph, Graph, Literal, URIRef, Variable from rdflib.compat import Mapping, MutableMapping +from rdflib.graph import ConjunctiveGraph, Graph from rdflib.namespace import NamespaceManager from rdflib.plugins.sparql.parserutils import CompValue -from rdflib.term import Node +from rdflib.term import BNode, Identifier, Literal, Node, URIRef, Variable class SPARQLError(Exception): - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): Exception.__init__(self, msg) class NotBoundError(SPARQLError): - def __init__(self, msg=None): + def __init__(self, msg: Optional[str] = None): SPARQLError.__init__(self, msg) @@ -30,7 +32,7 @@ def __init__(self): class SPARQLTypeError(SPARQLError): - def __init__(self, msg): + def __init__(self, msg: Optional[str]): SPARQLError.__init__(self, msg) @@ -45,11 +47,11 @@ class Bindings(MutableMapping): In python 3.3 this could be a collections.ChainMap """ - def __init__(self, outer=None, d=[]): - self._d = dict(d) + def __init__(self, outer: Optional["Bindings"] = None, d=[]): + self._d: Dict[str, str] = dict(d) self.outer = outer - def __getitem__(self, key): + def __getitem__(self, key: str) -> str: if key in self._d: return self._d[key] @@ -57,26 +59,26 @@ def __getitem__(self, key): raise KeyError() return self.outer[key] - def __contains__(self, key): + def __contains__(self, key: Any) -> bool: try: self[key] return True except KeyError: return False - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: Any) -> None: self._d[key] = value - def __delitem__(self, key): + def __delitem__(self, key: str) -> None: raise Exception("DelItem is not implemented!") def __len__(self) -> int: i = 0 - d = self + d: Optional[Bindings] = self while d is not None: i += len(d._d) d = d.outer - return i # type: ignore[unreachable] + return i def __iter__(self): d = self @@ -84,10 +86,11 @@ def __iter__(self): yield from d._d d = d.outer - def __str__(self): - return "Bindings({" + ", ".join((k, self[k]) for k in self) + "})" + def __str__(self) -> str: + # type error: Generator has incompatible item type "Tuple[Any, str]"; expected "str" + return "Bindings({" + ", ".join((k, self[k]) for k in self) + "})" # type: ignore[misc] - def __repr__(self): + def __repr__(self) -> str: return str(self) @@ -99,20 +102,20 @@ class FrozenDict(Mapping): """ - def __init__(self, *args, **kwargs): - self._d = dict(*args, **kwargs) - self._hash = None + def __init__(self, *args: Any, **kwargs: Any): + self._d: Dict[Identifier, Identifier] = dict(*args, **kwargs) + self._hash: Optional[int] = None def __iter__(self): return iter(self._d) - def __len__(self): + def __len__(self) -> int: return len(self._d) - def __getitem__(self, key): + def __getitem__(self, key: Identifier) -> Identifier: return self._d[key] - def __hash__(self): + def __hash__(self) -> int: # It would have been simpler and maybe more obvious to # use hash(tuple(sorted(self._d.items()))) from this discussion # so far, but this solution is O(n). I don't know what kind of @@ -125,13 +128,13 @@ def __hash__(self): self._hash ^= hash(value) return self._hash - def project(self, vars): + def project(self, vars: Container[Variable]) -> "FrozenDict": return FrozenDict((x for x in self.items() if x[0] in vars)) - def disjointDomain(self, other): + def disjointDomain(self, other: t.Mapping[Identifier, Identifier]) -> bool: return not bool(set(self).intersection(other)) - def compatible(self, other): + def compatible(self, other: t.Mapping[Identifier, Identifier]) -> bool: for k in self: try: if self[k] != other[k]: @@ -141,24 +144,24 @@ def compatible(self, other): return True - def merge(self, other): + def merge(self, other: t.Mapping[Identifier, Identifier]) -> "FrozenDict": res = FrozenDict(itertools.chain(self.items(), other.items())) return res - def __str__(self): + def __str__(self) -> str: return str(self._d) - def __repr__(self): + def __repr__(self) -> str: return repr(self._d) class FrozenBindings(FrozenDict): - def __init__(self, ctx, *args, **kwargs): + def __init__(self, ctx: "QueryContext", *args, **kwargs): FrozenDict.__init__(self, *args, **kwargs) self.ctx = ctx - def __getitem__(self, key): + def __getitem__(self, key: Union[Identifier, str]) -> Identifier: if not isinstance(key, Node): key = Variable(key) @@ -167,30 +170,34 @@ def __getitem__(self, key): return key if key not in self._d: - return self.ctx.initBindings[key] + # type error: Value of type "Optional[Dict[Variable, Identifier]]" is not indexable + # type error: Invalid index type "Union[BNode, Variable]" for "Optional[Dict[Variable, Identifier]]"; expected type "Variable" + return self.ctx.initBindings[key] # type: ignore[index] else: return self._d[key] - def project(self, vars): + def project(self, vars: Container[Variable]) -> "FrozenBindings": return FrozenBindings(self.ctx, (x for x in self.items() if x[0] in vars)) - def merge(self, other): + def merge(self, other: t.Mapping[Identifier, Identifier]) -> "FrozenBindings": res = FrozenBindings(self.ctx, itertools.chain(self.items(), other.items())) return res @property - def now(self): + def now(self) -> datetime.datetime: return self.ctx.now @property - def bnodes(self): + def bnodes(self) -> t.Mapping[Identifier, BNode]: return self.ctx.bnodes @property - def prologue(self): + def prologue(self) -> Optional["Prologue"]: return self.ctx.prologue - def forget(self, before, _except=None): + def forget( + self, before: "QueryContext", _except: Optional[Container[Variable]] = None + ): """ return a frozen dict only of bindings made in self since before @@ -206,7 +213,8 @@ def forget(self, before, _except=None): for x in self.items() if ( x[0] in _except - or x[0] in self.ctx.initBindings + # type error: Unsupported right operand type for in ("Optional[Dict[Variable, Identifier]]") + or x[0] in self.ctx.initBindings # type: ignore[operator] or before[x[0]] is None ) ), @@ -224,12 +232,19 @@ class QueryContext(object): Query context - passed along when evaluating the query """ - def __init__(self, graph=None, bindings=None, initBindings=None): + def __init__( + self, + graph: Optional[Graph] = None, + bindings: Optional[Union[Bindings, FrozenBindings, List[Any]]] = None, + initBindings: Optional[Dict[Variable, Identifier]] = None, + ): self.initBindings = initBindings self.bindings = Bindings(d=bindings or []) if initBindings: self.bindings.update(initBindings) + self.graph: Optional[Graph] + self._dataset: Optional[ConjunctiveGraph] if isinstance(graph, ConjunctiveGraph): self._dataset = graph if rdflib.plugins.sparql.SPARQL_DEFAULT_GRAPH_UNION: @@ -240,10 +255,12 @@ def __init__(self, graph=None, bindings=None, initBindings=None): self._dataset = None self.graph = graph - self.prologue = None - self._now = None + self.prologue: Optional[Prologue] = None + self._now: Optional[datetime.datetime] = None - self.bnodes = collections.defaultdict(BNode) + self.bnodes: t.MutableMapping[Identifier, BNode] = collections.defaultdict( + BNode + ) @property def now(self) -> datetime.datetime: @@ -251,7 +268,9 @@ def now(self) -> datetime.datetime: self._now = datetime.datetime.now(isodate.tzinfo.UTC) return self._now - def clone(self, bindings=None): + def clone( + self, bindings: Optional[Union[FrozenBindings, Bindings, List[Any]]] = None + ) -> "QueryContext": r = QueryContext( self._dataset if self._dataset is not None else self.graph, bindings or self.bindings, @@ -263,7 +282,7 @@ def clone(self, bindings=None): return r @property - def dataset(self): + def dataset(self) -> ConjunctiveGraph: """ "current dataset""" if self._dataset is None: raise Exception( @@ -273,7 +292,7 @@ def dataset(self): ) return self._dataset - def load(self, source, default=False, **kwargs): + def load(self, source: URIRef, default: bool = False, **kwargs): def _load(graph, source): try: return graph.parse(source, format="turtle", **kwargs) @@ -298,7 +317,8 @@ def _load(graph, source): # we are not loading - if we already know the graph # being "loaded", just add it to the default-graph if default: - self.graph += self.dataset.get_context(source) + # Unsupported left operand type for + ("None") + self.graph += self.dataset.get_context(source) # type: ignore[operator] else: if default: @@ -306,7 +326,7 @@ def _load(graph, source): else: _load(self.dataset, source) - def __getitem__(self, key): + def __getitem__(self, key) -> Any: # in SPARQL BNodes are just labels if not isinstance(key, (BNode, Variable)): return key @@ -315,13 +335,13 @@ def __getitem__(self, key): except KeyError: return None - def get(self, key, default=None): + def get(self, key: Variable, default: Optional[Any] = None): try: return self[key] except KeyError: return default - def solution(self, vars=None): + def solution(self, vars: Optional[Iterable[Variable]] = None) -> FrozenBindings: """ Return a static copy of the current variable bindings as dict """ @@ -332,25 +352,25 @@ def solution(self, vars=None): else: return FrozenBindings(self, self.bindings.items()) - def __setitem__(self, key, value): + def __setitem__(self, key: Identifier, value: Identifier) -> None: if key in self.bindings and self.bindings[key] != value: raise AlreadyBound() self.bindings[key] = value - def pushGraph(self, graph): + def pushGraph(self, graph: Optional[Graph]) -> "QueryContext": r = self.clone() r.graph = graph return r - def push(self): + def push(self) -> "QueryContext": r = self.clone(Bindings(self.bindings)) return r - def clean(self): + def clean(self) -> "QueryContext": return self.clone([]) - def thaw(self, frozenbindings): + def thaw(self, frozenbindings: FrozenBindings) -> "QueryContext": """ Create a new read/write query context from the given solution """ @@ -365,19 +385,21 @@ class Prologue: """ def __init__(self): - self.base = None + self.base: Optional[str] = None self.namespace_manager = NamespaceManager(Graph()) # ns man needs a store - def resolvePName(self, prefix, localname): + def resolvePName(self, prefix: Optional[str], localname: Optional[str]) -> URIRef: ns = self.namespace_manager.store.namespace(prefix or "") if ns is None: raise Exception("Unknown namespace prefix : %s" % prefix) return URIRef(ns + (localname or "")) - def bind(self, prefix, uri): + def bind(self, prefix: Optional[str], uri: Any) -> None: self.namespace_manager.bind(prefix, uri, replace=True) - def absolutize(self, iri): + def absolutize( + self, iri: Optional[Union[CompValue, str]] + ) -> Optional[Union[CompValue, str]]: """ Apply BASE / PREFIXes to URIs (and to datatypes in Literals) @@ -389,8 +411,9 @@ def absolutize(self, iri): if iri.name == "pname": return self.resolvePName(iri.prefix, iri.localname) if iri.name == "literal": + # type error: Argument "datatype" to "Literal" has incompatible type "Union[CompValue, Identifier, None]"; expected "Optional[str]" return Literal( - iri.string, lang=iri.lang, datatype=self.absolutize(iri.datatype) + iri.string, lang=iri.lang, datatype=self.absolutize(iri.datatype) # type: ignore[arg-type] ) elif isinstance(iri, URIRef) and not ":" in iri: return URIRef(iri, base=self.base) @@ -403,9 +426,10 @@ class Query: A parsed and translated query """ - def __init__(self, prologue, algebra): + def __init__(self, prologue: Prologue, algebra: CompValue): self.prologue = prologue self.algebra = algebra + self._original_args: Tuple[str, Mapping[str, str], Optional[str]] class Update: @@ -413,6 +437,7 @@ class Update: A parsed and translated update """ - def __init__(self, prologue, algebra): + def __init__(self, prologue: Prologue, algebra: List[CompValue]): self.prologue = prologue self.algebra = algebra + self._original_args: Tuple[str, Mapping[str, str], Optional[str]]