Skip to content

Commit

Permalink
Merge pull request #1449 from iafork/iwana-20211016T2236-more_typing
Browse files Browse the repository at this point in the history
Add type hints
  • Loading branch information
nicholascar committed Dec 1, 2021
2 parents 690312a + 2e812af commit 1ae579d
Show file tree
Hide file tree
Showing 18 changed files with 340 additions and 144 deletions.
164 changes: 108 additions & 56 deletions rdflib/graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from typing import Optional, Union, Type, cast, overload, Generator, Tuple
from typing import (
IO,
Any,
Iterable,
Optional,
Union,
Type,
cast,
overload,
Generator,
Tuple,
)
import logging
from warnings import warn
import random
Expand All @@ -21,7 +32,7 @@
import tempfile
import pathlib

from io import BytesIO, BufferedIOBase
from io import BytesIO
from urllib.parse import urlparse

assert Literal # avoid warning
Expand Down Expand Up @@ -313,15 +324,19 @@ class Graph(Node):
"""

def __init__(
self, store="default", identifier=None, namespace_manager=None, base=None
self,
store: Union[Store, str] = "default",
identifier: Optional[Union[Node, str]] = None,
namespace_manager: Optional[NamespaceManager] = None,
base: Optional[str] = None,
):
super(Graph, self).__init__()
self.base = base
self.__identifier = identifier or BNode()

self.__identifier: Node
self.__identifier = identifier or BNode() # type: ignore[assignment]
if not isinstance(self.__identifier, Node):
self.__identifier = URIRef(self.__identifier)

self.__identifier = URIRef(self.__identifier) # type: ignore[unreachable]
self.__store: Store
if not isinstance(store, Store):
# TODO: error handling
self.__store = store = plugin.get(store, Store)()
Expand Down Expand Up @@ -404,7 +419,7 @@ def close(self, commit_pending_transaction=False):
"""
return self.__store.close(commit_pending_transaction=commit_pending_transaction)

def add(self, triple):
def add(self, triple: Tuple[Node, Node, Node]):
"""Add a triple with self as context"""
s, p, o = triple
assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,)
Expand All @@ -413,7 +428,7 @@ def add(self, triple):
self.__store.add((s, p, o), self, quoted=False)
return self

def addN(self, quads):
def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]):
"""Add a sequence of triple with context"""

self.__store.addN(
Expand All @@ -434,7 +449,9 @@ def remove(self, triple):
self.__store.remove(triple, context=self)
return self

def triples(self, triple):
def triples(
self, triple: Tuple[Optional[Node], Union[None, Path, Node], Optional[Node]]
):
"""Generator over the triple store
Returns triples that match the given triple pattern. If triple pattern
Expand Down Expand Up @@ -652,17 +669,17 @@ def set(self, triple):
self.add((subject, predicate, object_))
return self

def subjects(self, predicate=None, object=None):
def subjects(self, predicate=None, object=None) -> Iterable[Node]:
"""A generator of subjects with the given predicate and object"""
for s, p, o in self.triples((None, predicate, object)):
yield s

def predicates(self, subject=None, object=None):
def predicates(self, subject=None, object=None) -> Iterable[Node]:
"""A generator of predicates with the given subject and object"""
for s, p, o in self.triples((subject, None, object)):
yield p

def objects(self, subject=None, predicate=None):
def objects(self, subject=None, predicate=None) -> Iterable[Node]:
"""A generator of objects with the given subject and predicate"""
for s, p, o in self.triples((subject, predicate, None)):
yield o
Expand Down Expand Up @@ -1019,45 +1036,32 @@ def serialize(
@overload
def serialize(
self,
*,
destination: None = ...,
format: str = ...,
base: Optional[str] = ...,
*,
encoding: str,
**args,
) -> bytes:
...

# no destination and None positional encoding
@overload
def serialize(
self,
destination: None,
format: str,
base: Optional[str],
encoding: None,
**args,
) -> str:
...

# no destination and None keyword encoding
# no destination and None encoding
@overload
def serialize(
self,
*,
destination: None = ...,
format: str = ...,
base: Optional[str] = ...,
encoding: None = None,
encoding: None = ...,
**args,
) -> str:
...

# non-none destination
# non-None destination
@overload
def serialize(
self,
destination: Union[str, BufferedIOBase, pathlib.PurePath],
destination: Union[str, pathlib.PurePath, IO[bytes]],
format: str = ...,
base: Optional[str] = ...,
encoding: Optional[str] = ...,
Expand All @@ -1069,21 +1073,21 @@ def serialize(
@overload
def serialize(
self,
destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None,
format: str = "turtle",
base: Optional[str] = None,
encoding: Optional[str] = None,
destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = ...,
format: str = ...,
base: Optional[str] = ...,
encoding: Optional[str] = ...,
**args,
) -> Union[bytes, str, "Graph"]:
...

def serialize(
self,
destination: Union[str, BufferedIOBase, pathlib.PurePath, None] = None,
destination: Optional[Union[str, pathlib.PurePath, IO[bytes]]] = None,
format: str = "turtle",
base: Optional[str] = None,
encoding: Optional[str] = None,
**args,
**args: Any,
) -> Union[bytes, str, "Graph"]:
"""Serialize the Graph to destination
Expand All @@ -1104,7 +1108,7 @@ def serialize(
base = self.base

serializer = plugin.get(format, Serializer)(self)
stream: BufferedIOBase
stream: IO[bytes]
if destination is None:
stream = BytesIO()
if encoding is None:
Expand All @@ -1114,7 +1118,7 @@ def serialize(
serializer.serialize(stream, base=base, encoding=encoding, **args)
return stream.getvalue()
if hasattr(destination, "write"):
stream = cast(BufferedIOBase, destination)
stream = cast(IO[bytes], destination)
serializer.serialize(stream, base=base, encoding=encoding, **args)
else:
if isinstance(destination, pathlib.PurePath):
Expand Down Expand Up @@ -1149,10 +1153,10 @@ def parse(
self,
source=None,
publicID=None,
format=None,
format: Optional[str] = None,
location=None,
file=None,
data=None,
data: Optional[Union[str, bytes, bytearray]] = None,
**args,
):
"""
Expand Down Expand Up @@ -1249,7 +1253,8 @@ def parse(
could_not_guess_format = True
parser = plugin.get(format, Parser)()
try:
parser.parse(source, self, **args)
# TODO FIXME: Parser.parse should have **kwargs argument.
parser.parse(source, self, **args) # type: ignore[call-arg]
except SyntaxError as se:
if could_not_guess_format:
raise ParserError(
Expand Down Expand Up @@ -1537,7 +1542,12 @@ class ConjunctiveGraph(Graph):
All queries are carried out against the union of all graphs.
"""

def __init__(self, store="default", identifier=None, default_graph_base=None):
def __init__(
self,
store: Union[Store, str] = "default",
identifier: Optional[Union[Node, str]] = None,
default_graph_base: Optional[str] = None,
):
super(ConjunctiveGraph, self).__init__(store, identifier=identifier)
assert self.store.context_aware, (
"ConjunctiveGraph must be backed by" " a context aware store."
Expand All @@ -1555,7 +1565,31 @@ def __str__(self):
)
return pattern % self.store.__class__.__name__

def _spoc(self, triple_or_quad, default=False):
@overload
def _spoc(
self,
triple_or_quad: Union[
Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]
],
default: bool = False,
) -> Tuple[Node, Node, Node, Optional[Graph]]:
...

@overload
def _spoc(
self,
triple_or_quad: None,
default: bool = False,
) -> Tuple[None, None, None, Optional[Graph]]:
...

def _spoc(
self,
triple_or_quad: Optional[
Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]]
],
default: bool = False,
) -> Tuple[Optional[Node], Optional[Node], Optional[Node], Optional[Graph]]:
"""
helper method for having methods that support
either triples or quads
Expand All @@ -1564,9 +1598,9 @@ def _spoc(self, triple_or_quad, default=False):
return (None, None, None, self.default_context if default else None)
if len(triple_or_quad) == 3:
c = self.default_context if default else None
(s, p, o) = triple_or_quad
(s, p, o) = triple_or_quad # type: ignore[misc]
elif len(triple_or_quad) == 4:
(s, p, o, c) = triple_or_quad
(s, p, o, c) = triple_or_quad # type: ignore[misc]
c = self._graph(c)
return s, p, o, c

Expand All @@ -1577,7 +1611,7 @@ def __contains__(self, triple_or_quad):
return True
return False

def add(self, triple_or_quad):
def add(self, triple_or_quad: Union[Tuple[Node, Node, Node, Optional[Any]], Tuple[Node, Node, Node]]) -> "ConjunctiveGraph": # type: ignore[override]
"""
Add a triple or quad to the store.
Expand All @@ -1591,15 +1625,23 @@ def add(self, triple_or_quad):
self.store.add((s, p, o), context=c, quoted=False)
return self

def _graph(self, c):
@overload
def _graph(self, c: Union[Graph, Node, str]) -> Graph:
...

@overload
def _graph(self, c: None) -> None:
...

def _graph(self, c: Optional[Union[Graph, Node, str]]) -> Optional[Graph]:
if c is None:
return None
if not isinstance(c, Graph):
return self.get_context(c)
else:
return c

def addN(self, quads):
def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]):
"""Add a sequence of triples with context"""

self.store.addN(
Expand Down Expand Up @@ -1689,13 +1731,19 @@ def contexts(self, triple=None):
else:
yield self.get_context(context)

def get_context(self, identifier, quoted=False, base=None):
def get_context(
self,
identifier: Optional[Union[Node, str]],
quoted: bool = False,
base: Optional[str] = None,
) -> Graph:
"""Return a context graph for the given identifier
identifier must be a URIRef or BNode.
"""
# TODO: FIXME - why is ConjunctiveGraph passed as namespace_manager?
return Graph(
store=self.store, identifier=identifier, namespace_manager=self, base=base
store=self.store, identifier=identifier, namespace_manager=self, base=base # type: ignore[arg-type]
)

def remove_context(self, context):
Expand Down Expand Up @@ -1747,6 +1795,7 @@ def parse(
context = Graph(store=self.store, identifier=g_id)
context.remove((None, None, None)) # hmm ?
context.parse(source, publicID=publicID, format=format, **args)
# TODO: FIXME: This should not return context, but self.
return context

def __reduce__(self):
Expand Down Expand Up @@ -1977,7 +2026,7 @@ class QuotedGraph(Graph):
def __init__(self, store, identifier):
super(QuotedGraph, self).__init__(store, identifier)

def add(self, triple):
def add(self, triple: Tuple[Node, Node, Node]):
"""Add a triple with self as context"""
s, p, o = triple
assert isinstance(s, Node), "Subject %s must be an rdflib term" % (s,)
Expand All @@ -1987,7 +2036,7 @@ def add(self, triple):
self.store.add((s, p, o), self, quoted=True)
return self

def addN(self, quads):
def addN(self, quads: Tuple[Node, Node, Node, Any]) -> "QuotedGraph": # type: ignore[override]
"""Add a sequence of triple with context"""

self.store.addN(
Expand Down Expand Up @@ -2261,7 +2310,7 @@ class BatchAddGraph(object):
"""

def __init__(self, graph, batch_size=1000, batch_addn=False):
def __init__(self, graph: Graph, batch_size: int = 1000, batch_addn: bool = False):
if not batch_size or batch_size < 2:
raise ValueError("batch_size must be a positive number")
self.graph = graph
Expand All @@ -2278,7 +2327,10 @@ def reset(self):
self.count = 0
return self

def add(self, triple_or_quad):
def add(
self,
triple_or_quad: Union[Tuple[Node, Node, Node], Tuple[Node, Node, Node, Any]],
) -> "BatchAddGraph":
"""
Add a triple to the buffer
Expand All @@ -2294,7 +2346,7 @@ def add(self, triple_or_quad):
self.batch.append(triple_or_quad)
return self

def addN(self, quads):
def addN(self, quads: Iterable[Tuple[Node, Node, Node, Any]]):
if self.__batch_addn:
for q in quads:
self.add(q)
Expand Down

0 comments on commit 1ae579d

Please sign in to comment.