Skip to content

Commit

Permalink
Merge pull request #1633 from aucampia/iwana-20211025T2151-sparql_typing
Browse files Browse the repository at this point in the history
Add some typing for evaluation related functions in the SPARQL plugin.
  • Loading branch information
nicholascar committed Jan 4, 2022
2 parents 1226b52 + 251121f commit 12b5320
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 36 deletions.
68 changes: 36 additions & 32 deletions rdflib/plugins/sparql/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
import collections
import itertools
import re
from typing import Any, Deque, Dict, List, Union
from urllib.request import urlopen, Request
from urllib.parse import urlencode
import json as j
from pyparsing import ParseException

from rdflib import Variable, Graph, BNode, URIRef, Literal
from rdflib.plugins.sparql import CUSTOM_EVALS
from rdflib.plugins.sparql.parserutils import value
from rdflib.plugins.sparql.parserutils import CompValue, value
from rdflib.plugins.sparql.sparql import (
QueryContext,
AlreadyBound,
Expand All @@ -45,9 +46,10 @@

from rdflib.plugins.sparql.aggregates import Aggregator
from rdflib.plugins.sparql import parser
from rdflib.term import Identifier


def evalBGP(ctx, bgp):
def evalBGP(ctx: QueryContext, bgp: List[Any]):
"""
A basic graph pattern
"""
Expand Down Expand Up @@ -87,7 +89,7 @@ def evalBGP(ctx, bgp):
yield x


def evalExtend(ctx, extend):
def evalExtend(ctx: QueryContext, extend: CompValue):
# TODO: Deal with dict returned from evalPart from GROUP BY

for c in evalPart(ctx, extend.p):
Expand All @@ -102,7 +104,7 @@ def evalExtend(ctx, extend):
yield c


def evalLazyJoin(ctx, join):
def evalLazyJoin(ctx: QueryContext, join: CompValue):
"""
A lazy join will push the variables bound
in the first part to the second part,
Expand All @@ -115,7 +117,7 @@ def evalLazyJoin(ctx, join):
yield b.merge(a) # merge, as some bindings may have been forgotten


def evalJoin(ctx, join):
def evalJoin(ctx: QueryContext, join: CompValue):

# TODO: Deal with dict returned from evalPart from GROUP BY
# only ever for join.p1
Expand All @@ -128,7 +130,7 @@ def evalJoin(ctx, join):
return _join(a, b)


def evalUnion(ctx, union):
def evalUnion(ctx: QueryContext, union: CompValue):
branch1_branch2 = []
for x in evalPart(ctx, union.p1):
branch1_branch2.append(x)
Expand All @@ -137,13 +139,13 @@ def evalUnion(ctx, union):
return branch1_branch2


def evalMinus(ctx, minus):
def evalMinus(ctx: QueryContext, minus: CompValue):
    """
    Evaluate a MINUS pattern: keep solutions of the left branch that
    are not ruled out by any solution of the right branch.
    """
    left = evalPart(ctx, minus.p1)
    right = set(evalPart(ctx, minus.p2))
    return _minus(left, right)


def evalLeftJoin(ctx, join):
def evalLeftJoin(ctx: QueryContext, join: CompValue):
# import pdb; pdb.set_trace()
for a in evalPart(ctx, join.p1):
ok = False
Expand All @@ -167,7 +169,7 @@ def evalLeftJoin(ctx, join):
yield a


def evalFilter(ctx, part):
def evalFilter(ctx: QueryContext, part: CompValue):
# TODO: Deal with dict returned from evalPart!
for c in evalPart(ctx, part.p):
if _ebv(
Expand All @@ -177,7 +179,7 @@ def evalFilter(ctx, part):
yield c


def evalGraph(ctx, part):
def evalGraph(ctx: QueryContext, part: CompValue):

if ctx.dataset is None:
raise Exception(
Expand Down Expand Up @@ -210,7 +212,7 @@ def evalGraph(ctx, part):
yield x


def evalValues(ctx, part):
def evalValues(ctx: QueryContext, part):
for r in part.p.res:
c = ctx.push()
try:
Expand All @@ -223,15 +225,15 @@ def evalValues(ctx, part):
yield c.solution()


def evalMultiset(ctx, part):
def evalMultiset(ctx: QueryContext, part: CompValue):
    """
    Evaluate a multiset node: a VALUES clause gets its dedicated
    evaluator, everything else is delegated to the generic evalPart.
    """
    inner = part.p
    if inner.name == "values":
        return evalValues(ctx, part)
    return evalPart(ctx, inner)


def evalPart(ctx, part):
def evalPart(ctx: QueryContext, part: CompValue):

# try custom evaluation functions
for name, c in CUSTOM_EVALS.items():
Expand Down Expand Up @@ -299,7 +301,7 @@ def evalPart(ctx, part):
raise Exception("I dont know: %s" % part.name)


def evalServiceQuery(ctx, part):
def evalServiceQuery(ctx: QueryContext, part):
res = {}
match = re.match(
"^service <(.*)>[ \n]*{(.*)}[ \n]*$",
Expand Down Expand Up @@ -345,14 +347,14 @@ def evalServiceQuery(ctx, part):


"""
Build a query string to be used by the service call.
Build a query string to be used by the service call.
It is supposed to pass in the existing bound solutions.
Re-adds prefixes if added and sets the base.
Wraps it in select if needed.
"""


def _buildQueryStringForServiceCall(ctx, match):
def _buildQueryStringForServiceCall(ctx: QueryContext, match):

service_query = match.group(2)
try:
Expand All @@ -376,8 +378,8 @@ def _buildQueryStringForServiceCall(ctx, match):
return service_query


def _yieldBindingsFromServiceCallResult(ctx, r, variables):
res_dict = {}
def _yieldBindingsFromServiceCallResult(ctx: QueryContext, r, variables):
res_dict: Dict[Variable, Identifier] = {}
for var in variables:
if var in r and r[var]:
if r[var]["type"] == "uri":
Expand All @@ -395,21 +397,23 @@ def _yieldBindingsFromServiceCallResult(ctx, r, variables):
yield FrozenBindings(ctx, res_dict)


def evalGroup(ctx, group):
def evalGroup(ctx: QueryContext, group):
    """
    Evaluate a Group node by delegating straight to its child pattern.

    See http://www.w3.org/TR/sparql11-query/#defn_algGroup — the actual
    grouping work is performed by evalAggregateJoin, not here.
    """
    return evalPart(ctx, group.p)


def evalAggregateJoin(ctx, agg):
def evalAggregateJoin(ctx: QueryContext, agg):
# import pdb ; pdb.set_trace()
p = evalPart(ctx, agg.p)
# p is always a Group, we always get a dict back

group_expr = agg.p.expr
res = collections.defaultdict(lambda: Aggregator(aggregations=agg.A))
res: Dict[Any, Any] = collections.defaultdict(
lambda: Aggregator(aggregations=agg.A)
)

if group_expr is None:
# no grouping, just COUNT in SELECT clause
Expand All @@ -432,7 +436,7 @@ def evalAggregateJoin(ctx, agg):
yield FrozenBindings(ctx)


def evalOrderBy(ctx, part):
def evalOrderBy(ctx: QueryContext, part):

res = evalPart(ctx, part.p)

Expand All @@ -446,7 +450,7 @@ def evalOrderBy(ctx, part):
return res


def evalSlice(ctx, slice):
def evalSlice(ctx: QueryContext, slice):
res = evalPart(ctx, slice.p)

return itertools.islice(
Expand All @@ -456,7 +460,7 @@ def evalSlice(ctx, slice):
)


def evalReduced(ctx, part):
def evalReduced(ctx: QueryContext, part):
"""apply REDUCED to result
REDUCED is not as strict as DISTINCT, but if the incoming rows were sorted
Expand All @@ -477,7 +481,7 @@ def evalReduced(ctx, part):

# mixed data structure: set for lookup, deque for append/pop/remove
mru_set = set()
mru_queue = collections.deque()
mru_queue: Deque[Any] = collections.deque()

for row in evalPart(ctx, part.p):
if row in mru_set:
Expand All @@ -494,7 +498,7 @@ def evalReduced(ctx, part):
mru_queue.appendleft(row)


def evalDistinct(ctx, part):
def evalDistinct(ctx: QueryContext, part):
res = evalPart(ctx, part.p)

done = set()
Expand All @@ -504,13 +508,13 @@ def evalDistinct(ctx, part):
done.add(x)


def evalProject(ctx, project):
def evalProject(ctx: QueryContext, project):
    """
    Project each solution of the child pattern down to the selected
    variables (project.PV); returns a lazy generator of projected rows.
    """
    return (
        solution.project(project.PV)
        for solution in evalPart(ctx, project.p)
    )


def evalSelectQuery(ctx, query):
def evalSelectQuery(ctx: QueryContext, query):

res = {}
res["type_"] = "SELECT"
Expand All @@ -519,8 +523,8 @@ def evalSelectQuery(ctx, query):
return res


def evalAskQuery(ctx, query):
res = {}
def evalAskQuery(ctx: QueryContext, query):
res: Dict[str, Union[bool, str]] = {}
res["type_"] = "ASK"
res["askAnswer"] = False
for x in evalPart(ctx, query.p):
Expand All @@ -530,7 +534,7 @@ def evalAskQuery(ctx, query):
return res


def evalConstructQuery(ctx, query):
def evalConstructQuery(ctx: QueryContext, query):
template = query.template

if not template:
Expand All @@ -542,7 +546,7 @@ def evalConstructQuery(ctx, query):
for c in evalPart(ctx, query.p):
graph += _fillTemplate(template, c)

res = {}
res: Dict[str, Union[str, Graph]] = {}
res["type_"] = "CONSTRUCT"
res["graph"] = graph

Expand Down
9 changes: 5 additions & 4 deletions rdflib/plugins/sparql/evalutils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import collections
from typing import Dict, Iterable

from rdflib.term import Variable, Literal, BNode, URIRef

from rdflib.plugins.sparql.operators import EBV
from rdflib.plugins.sparql.parserutils import Expr, CompValue
from rdflib.plugins.sparql.sparql import SPARQLError, NotBoundError
from rdflib.plugins.sparql.sparql import FrozenDict, SPARQLError, NotBoundError


def _diff(a, b, expr):
def _diff(a: Iterable[FrozenDict], b: Iterable[FrozenDict], expr):
res = set()

for x in a:
Expand All @@ -17,13 +18,13 @@ def _diff(a, b, expr):
return res


def _minus(a, b):
def _minus(a: Iterable[FrozenDict], b: Iterable[FrozenDict]):
for x in a:
if all((not x.compatible(y)) or x.disjointDomain(y) for y in b):
yield x


def _join(a, b):
def _join(a: Iterable[FrozenDict], b: Iterable[Dict]):
for x in a:
for y in b:
if x.compatible(y):
Expand Down

0 comments on commit 12b5320

Please sign in to comment.