Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: SPARQL XML result parsing #2044

Merged
merged 1 commit into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/validate.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ jobs:
- python-version: "3.9"
os: ubuntu-latest
TOX_EXTRA_COMMAND: "- black --check --diff ./rdflib"
TOXENV_SUFFIX: "-lxml"
- python-version: "3.10"
os: ubuntu-latest
TOX_EXTRA_COMMAND: "flake8 --exit-zero rdflib"
Expand Down
25 changes: 25 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,31 @@ and will be removed for release.
<!-- -->
<!-- -->


<!-- -->
<!-- -->
<!-- CHANGE BARRIER: START PR #2044 -->
<!-- -->
<!-- -->

- Fixed some issues with SPARQL XML result parsing that caused problems with
[`lxml`](https://lxml.de/). Closed [issue #2035](https://github.com/RDFLib/rdflib/issues/2035),
[issue #1847](https://github.com/RDFLib/rdflib/issues/1847).
[PR #2044](https://github.com/RDFLib/rdflib/pull/2044).
- Result parsing from
[`TextIO`](https://docs.python.org/3/library/typing.html#typing.TextIO)
streams now works correctly with `lxml` installed and with XML documents that
are not `utf-8` encoded.
- Elements inside `<results>` that are not `<result>` are now ignored.
- Elements inside `<result>` that are not `<binding>` are now ignored.
- Also added type hints to `rdflib.plugins.sparql.results.xmlresults`.

<!-- -->
<!-- -->
<!-- CHANGE BARRIER: END -->
<!-- -->
<!-- -->

<!-- -->
<!-- -->
<!-- CHANGE BARRIER: START -->
Expand Down
10 changes: 1 addition & 9 deletions rdflib/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,7 @@
import codecs
import re
import warnings
from typing import TYPE_CHECKING, Match

if TYPE_CHECKING:
import xml.etree.ElementTree as etree
else:
try:
from lxml import etree
except ImportError:
import xml.etree.ElementTree as etree
from typing import Match


def cast_bytes(s, enc="utf-8"):
Expand Down
113 changes: 83 additions & 30 deletions rdflib/plugins/sparql/results/xmlresults.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,34 @@
import logging
from typing import IO, Optional
import xml.etree.ElementTree as xml_etree # noqa: N813
from io import BytesIO
from typing import (
IO,
TYPE_CHECKING,
Any,
BinaryIO,
Dict,
Optional,
Sequence,
TextIO,
Tuple,
Union,
cast,
)
from xml.dom import XML_NAMESPACE
from xml.sax.saxutils import XMLGenerator
from xml.sax.xmlreader import AttributesNSImpl

from rdflib import BNode, Literal, URIRef, Variable
from rdflib.compat import etree
from rdflib.query import Result, ResultException, ResultParser, ResultSerializer
from rdflib.term import Identifier

try:
# https://adamj.eu/tech/2021/12/29/python-type-hints-optional-imports/
import lxml.etree as lxml_etree

FOUND_LXML = True
except ImportError:
FOUND_LXML = False

SPARQL_XML_NAMESPACE = "http://www.w3.org/2005/sparql-results#"
RESULTS_NS_ET = "{%s}" % SPARQL_XML_NAMESPACE
Expand All @@ -27,19 +49,32 @@

class XMLResultParser(ResultParser):
# TODO FIXME: content_type should be a keyword only arg.
def parse(self, source, content_type: Optional[str] = None): # type: ignore[override]
def parse(self, source: IO, content_type: Optional[str] = None): # type: ignore[override]
return XMLResult(source)


class XMLResult(Result):
def __init__(self, source, content_type: Optional[str] = None):

try:
# try use as if etree is from lxml, and if not use it as normal.
parser = etree.XMLParser(huge_tree=True) # type: ignore[call-arg]
tree = etree.parse(source, parser)
except TypeError:
tree = etree.parse(source)
def __init__(self, source: IO, content_type: Optional[str] = None):
parser_encoding: Optional[str] = None
if hasattr(source, "encoding"):
if TYPE_CHECKING:
assert isinstance(source, TextIO)
parser_encoding = "utf-8"
source_str = source.read()
source = BytesIO(source_str.encode(parser_encoding))
else:
if TYPE_CHECKING:
assert isinstance(source, BinaryIO)

if FOUND_LXML:
lxml_parser = lxml_etree.XMLParser(huge_tree=True, encoding=parser_encoding)
tree = cast(
xml_etree.ElementTree,
lxml_etree.parse(source, parser=lxml_parser),
)
else:
xml_parser = xml_etree.XMLParser(encoding=parser_encoding)
tree = xml_etree.parse(source, parser=xml_parser)

boolean = tree.find(RESULTS_NS_ET + "boolean")
results = tree.find(RESULTS_NS_ET + "results")
Expand All @@ -56,8 +91,18 @@ def __init__(self, source, content_type: Optional[str] = None):
if type_ == "SELECT":
self.bindings = []
for result in results: # type: ignore[union-attr]
if result.tag != f"{RESULTS_NS_ET}result":
# This is here because with lxml this also gets comments,
# not just elements. Also this should not operate on non
# "result" elements.
continue
r = {}
for binding in result:
if binding.tag != f"{RESULTS_NS_ET}binding":
# This is here because with lxml this also gets
# comments, not just elements. Also this should not
# operate on non "binding" elements.
continue
# type error: error: Argument 1 to "Variable" has incompatible type "Union[str, None, Any]"; expected "str"
# NOTE on type error: Element.get() can return None, and
# this will invariably fail if passed into Variable
Expand All @@ -80,7 +125,7 @@ def __init__(self, source, content_type: Optional[str] = None):
self.askAnswer = boolean.text.lower().strip() == "true" # type: ignore[union-attr]


def parseTerm(element):
def parseTerm(element: xml_etree.Element) -> Union[URIRef, Literal, BNode]:
"""rdflib object (Literal, URIRef, BNode) for the given
elementtree element"""
tag, text = element.tag, element.text
Expand All @@ -90,15 +135,17 @@ def parseTerm(element):
datatype = None
lang = None
if element.get("datatype", None):
datatype = URIRef(element.get("datatype"))
# type error: Argument 1 to "URIRef" has incompatible type "Optional[str]"; expected "str"
datatype = URIRef(element.get("datatype")) # type: ignore[arg-type]
elif element.get("{%s}lang" % XML_NAMESPACE, None):
lang = element.get("{%s}lang" % XML_NAMESPACE)

ret = Literal(text, datatype=datatype, lang=lang)

return ret
elif tag == RESULTS_NS_ET + "uri":
return URIRef(text)
# type error: Argument 1 to "URIRef" has incompatible type "Optional[str]"; expected "str"
return URIRef(text) # type: ignore[arg-type]
elif tag == RESULTS_NS_ET + "bnode":
return BNode(text)
else:
Expand All @@ -109,14 +156,14 @@ class XMLResultSerializer(ResultSerializer):
def __init__(self, result):
ResultSerializer.__init__(self, result)

def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs):

def serialize(self, stream: IO, encoding: str = "utf-8", **kwargs: Any) -> None:
writer = SPARQLXMLWriter(stream, encoding)
if self.result.type == "ASK":
writer.write_header([])
writer.write_ask(self.result.askAnswer)
else:
writer.write_header(self.result.vars)
# type error: Argument 1 to "write_header" of "SPARQLXMLWriter" has incompatible type "Optional[List[Variable]]"; expected "Sequence[Variable]"
writer.write_header(self.result.vars) # type: ignore[arg-type]
writer.write_results_header()
for b in self.result.bindings:
writer.write_start_result()
Expand All @@ -134,7 +181,7 @@ class SPARQLXMLWriter:
Python saxutils-based SPARQL XML Writer
"""

def __init__(self, output, encoding="utf-8"):
def __init__(self, output: IO, encoding: str = "utf-8"):
writer = XMLGenerator(output, encoding)
writer.startDocument()
writer.startPrefixMapping("", SPARQL_XML_NAMESPACE)
Expand All @@ -147,7 +194,7 @@ def __init__(self, output, encoding="utf-8"):
self._encoding = encoding
self._results = False

def write_header(self, allvarsL):
def write_header(self, allvarsL: Sequence[Variable]) -> None:
self.writer.startElementNS(
(SPARQL_XML_NAMESPACE, "head"), "head", AttributesNSImpl({}, {})
)
Expand All @@ -161,48 +208,52 @@ def write_header(self, allvarsL):
self.writer.startElementNS(
(SPARQL_XML_NAMESPACE, "variable"),
"variable",
AttributesNSImpl(attr_vals, attr_qnames),
# type error: Argument 1 to "AttributesNSImpl" has incompatible type "Dict[Tuple[None, str], str]"; expected "Mapping[Tuple[str, str], str]"
# type error: Argument 2 to "AttributesNSImpl" has incompatible type "Dict[Tuple[None, str], str]"; expected "Mapping[Tuple[str, str], str]" [arg-type]
AttributesNSImpl(attr_vals, attr_qnames), # type: ignore[arg-type]
)
self.writer.endElementNS((SPARQL_XML_NAMESPACE, "variable"), "variable")
self.writer.endElementNS((SPARQL_XML_NAMESPACE, "head"), "head")

def write_ask(self, val):
def write_ask(self, val: bool) -> None:
self.writer.startElementNS(
(SPARQL_XML_NAMESPACE, "boolean"), "boolean", AttributesNSImpl({}, {})
)
self.writer.characters(str(val).lower())
self.writer.endElementNS((SPARQL_XML_NAMESPACE, "boolean"), "boolean")

def write_results_header(self):
def write_results_header(self) -> None:
self.writer.startElementNS(
(SPARQL_XML_NAMESPACE, "results"), "results", AttributesNSImpl({}, {})
)
self._results = True

def write_start_result(self):
def write_start_result(self) -> None:
self.writer.startElementNS(
(SPARQL_XML_NAMESPACE, "result"), "result", AttributesNSImpl({}, {})
)
self._resultStarted = True

def write_end_result(self):
def write_end_result(self) -> None:
assert self._resultStarted
self.writer.endElementNS((SPARQL_XML_NAMESPACE, "result"), "result")
self._resultStarted = False

def write_binding(self, name, val):
def write_binding(self, name: Variable, val: Identifier):
assert self._resultStarted

attr_vals = {
attr_vals: Dict[Tuple[Optional[str], str], str] = {
(None, "name"): str(name),
}
attr_qnames = {
attr_qnames: Dict[Tuple[Optional[str], str], str] = {
(None, "name"): "name",
}
self.writer.startElementNS(
(SPARQL_XML_NAMESPACE, "binding"),
"binding",
AttributesNSImpl(attr_vals, attr_qnames),
# type error: Argument 1 to "AttributesNSImpl" has incompatible type "Dict[Tuple[None, str], str]"; expected "Mapping[Tuple[str, str], str]"
# type error: Argument 2 to "AttributesNSImpl" has incompatible type "Dict[Tuple[None, str], str]"; expected "Mapping[Tuple[str, str], str]"
AttributesNSImpl(attr_vals, attr_qnames), # type: ignore[arg-type]
)

if isinstance(val, URIRef):
Expand Down Expand Up @@ -230,7 +281,9 @@ def write_binding(self, name, val):
self.writer.startElementNS(
(SPARQL_XML_NAMESPACE, "literal"),
"literal",
AttributesNSImpl(attr_vals, attr_qnames),
# type error: Argument 1 to "AttributesNSImpl" has incompatible type "Dict[Tuple[Optional[str], str], str]"; expected "Mapping[Tuple[str, str], str]"
# type error: Argument 2 to "AttributesNSImpl" has incompatible type "Dict[Tuple[Optional[str], str], str]"; expected "Mapping[Tuple[str, str], str]"
AttributesNSImpl(attr_vals, attr_qnames), # type: ignore[arg-type]
)
self.writer.characters(val)
self.writer.endElementNS((SPARQL_XML_NAMESPACE, "literal"), "literal")
Expand All @@ -240,7 +293,7 @@ def write_binding(self, name, val):

self.writer.endElementNS((SPARQL_XML_NAMESPACE, "binding"), "binding")

def close(self):
def close(self) -> None:
if self._results:
self.writer.endElementNS((SPARQL_XML_NAMESPACE, "results"), "results")
self.writer.endElementNS((SPARQL_XML_NAMESPACE, "sparql"), "sparql")
Expand Down
29 changes: 22 additions & 7 deletions test/test_sparql/test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def make(cls, *result_format: ResultFormat) -> "ResultFormats":
ResultFormatTrait.HAS_PARSER,
ResultFormatTrait.HAS_SERIALIZER,
},
{"utf-8"},
{"utf-8", "utf-16"},
),
ResultFormat(
"tsv",
Expand Down Expand Up @@ -239,7 +239,7 @@ class DestRef:

@contextmanager
def make_dest(
tmp_path: Path, type: Optional[DestinationType]
tmp_path: Path, type: Optional[DestinationType], encoding: str
) -> Iterator[Optional[DestRef]]:
if type is None:
yield None
Expand All @@ -251,7 +251,8 @@ def make_dest(
with path.open("wb") as bfh:
yield DestRef(bfh, path)
elif type is DestinationType.TEXT_IO:
with path.open("w") as fh:
assert encoding is not None
with path.open("w", encoding=encoding) as fh:
yield DestRef(fh, path)
else:
raise ValueError(f"unsupported type {type}")
Expand Down Expand Up @@ -299,6 +300,10 @@ def make_select_result_serialize_parse_tests() -> Iterator[ParameterSet]:
raises=FileNotFoundError,
reason="string path handling does not work on windows",
)
xfails[("xml", DestinationType.STR_PATH, "utf-16")] = pytest.mark.xfail(
raises=FileNotFoundError,
reason="string path handling does not work on windows",
)
formats = [
format
for format in result_formats.values()
Expand Down Expand Up @@ -332,7 +337,7 @@ def test_select_result_serialize_parse(
specific format results in an equivalent result object.
"""
format, destination_type, encoding = args
with make_dest(tmp_path, destination_type) as dest_ref:
with make_dest(tmp_path, destination_type, encoding) as dest_ref:
destination = None if dest_ref is None else dest_ref.param
serialize_result = select_result.serialize(
destination=destination,
Expand All @@ -345,7 +350,8 @@ def test_select_result_serialize_parse(
serialized_data = serialize_result.decode(encoding)
else:
assert serialize_result is None
serialized_data = dest_ref.path.read_bytes().decode(encoding)
dest_bytes = dest_ref.path.read_bytes()
serialized_data = dest_bytes.decode(encoding)

logging.debug("serialized_data = %s", serialized_data)
check_serialized(format.name, select_result, serialized_data)
Expand All @@ -363,7 +369,7 @@ def serialize_select(select_result: Result, format: str, encoding: str) -> bytes
encoding
)
else:
result = select_result.serialize(format=format)
result = select_result.serialize(format=format, encoding=encoding)
assert result is not None
return result

Expand All @@ -377,8 +383,17 @@ def make_select_result_parse_serialized_tests() -> Iterator[ParameterSet]:
and ResultType.SELECT in format.supported_types
]
source_types = set(SourceType)
xfails[("csv", SourceType.BINARY_IO, "utf-16")] = pytest.mark.xfail(
raises=UnicodeDecodeError,
)
xfails[("json", SourceType.BINARY_IO, "utf-16")] = pytest.mark.xfail(
raises=UnicodeDecodeError,
)
xfails[("tsv", SourceType.BINARY_IO, "utf-16")] = pytest.mark.xfail(
raises=UnicodeDecodeError,
)
for format, destination_type in itertools.product(formats, source_types):
for encoding in {"utf-8"}:
for encoding in format.encodings:
xfail = xfails.get((format.name, destination_type, encoding))
marks = (xfail,) if xfail is not None else ()
yield pytest.param(
Expand Down