Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions gemd/entity/attribute/base_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
from gemd.entity.file_link import FileLink
from gemd.entity.link_by_uid import LinkByUID

from typing import Optional, Union, Iterable, List, Type
from abc import abstractmethod


class BaseAttribute(DictSerializable):
"""
Expand All @@ -31,8 +34,14 @@ class BaseAttribute(DictSerializable):

"""

def __init__(self, name, *, template=None, origin="unknown", value=None, notes=None,
file_links=None):
def __init__(self,
name: str,
*,
template: Union[AttributeTemplate, LinkByUID, None] = None,
origin: Union[Origin, str] = Origin.UNKNOWN,
value: BaseValue = None,
notes: str = None,
file_links: Optional[Union[Iterable[FileLink], FileLink]] = None):
self.name = name
self.notes = notes

Expand All @@ -47,12 +56,12 @@ def __init__(self, name, *, template=None, origin="unknown", value=None, notes=N
self.file_links = file_links

@property
def value(self):
def value(self) -> BaseValue:
"""Get value."""
return self._value

@value.setter
def value(self, value):
def value(self, value: BaseValue):
if value is None:
self._value = None
elif isinstance(value, (BaseValue, str, bool)):
Expand All @@ -61,36 +70,41 @@ def value(self, value):
raise TypeError("value must be a BaseValue, string or bool: {}".format(value))

@property
def template(self):
def template(self) -> Optional[Union[AttributeTemplate, LinkByUID]]:
"""Get template."""
return self._template

@template.setter
def template(self, template):
def template(self, template: Optional[Union[AttributeTemplate, LinkByUID]]):
if template is None:
self._template = None
elif isinstance(template, (LinkByUID, AttributeTemplate)):
elif isinstance(template, (self._template_type(), LinkByUID)):
self._template = template
else:
raise TypeError("template must be a BaseAttributeTemplate or "
"LinkByUID: {}".format(template))

@staticmethod
@abstractmethod
def _template_type() -> Type:
"""Get the expected type of template for this object (property of child)."""

@property
def origin(self):
def origin(self) -> str:
"""Get origin."""
return self._origin

@origin.setter
def origin(self, origin):
def origin(self, origin: Union[Origin, str]):
if origin is None:
raise ValueError("origin must be specified (but may be `unknown`)")
self._origin = Origin.get_value(origin)

@property
def file_links(self):
def file_links(self) -> List[FileLink]:
"""Get file links."""
return self._file_links

@file_links.setter
def file_links(self, file_links):
def file_links(self, file_links: Optional[Union[Iterable[FileLink], FileLink]]):
self._file_links = validate_list(file_links, FileLink)
7 changes: 7 additions & 0 deletions gemd/entity/attribute/condition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from gemd.entity.attribute.base_attribute import BaseAttribute
from gemd.entity.template import ConditionTemplate

from typing import Type


class Condition(BaseAttribute):
Expand Down Expand Up @@ -29,3 +32,7 @@ class Condition(BaseAttribute):
"""

typ = "condition"

@staticmethod
def _template_type() -> Type:
return ConditionTemplate
7 changes: 7 additions & 0 deletions gemd/entity/attribute/parameter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from gemd.entity.attribute.base_attribute import BaseAttribute
from gemd.entity.template import ParameterTemplate

from typing import Type


class Parameter(BaseAttribute):
Expand Down Expand Up @@ -30,3 +33,7 @@ class Parameter(BaseAttribute):
"""

typ = "parameter"

@staticmethod
def _template_type() -> Type:
return ParameterTemplate
7 changes: 7 additions & 0 deletions gemd/entity/attribute/property.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from gemd.entity.attribute.base_attribute import BaseAttribute
from gemd.entity.template import PropertyTemplate

from typing import Type


class Property(BaseAttribute):
Expand Down Expand Up @@ -29,3 +32,7 @@ class Property(BaseAttribute):
"""

typ = "property"

@staticmethod
def _template_type() -> Type:
return PropertyTemplate
23 changes: 14 additions & 9 deletions gemd/entity/attribute/property_and_conditions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from gemd.entity.attribute.condition import Condition
from gemd.entity.attribute.property import Property
from gemd.entity.template.property_template import PropertyTemplate
from gemd.entity.value.base_value import BaseValue
from gemd.entity.dict_serializable import DictSerializable
from gemd.entity.link_by_uid import LinkByUID
from gemd.entity.setters import validate_list

from typing import Optional, Union, Iterable, List


class PropertyAndConditions(DictSerializable):
"""
Expand All @@ -21,50 +26,50 @@ class PropertyAndConditions(DictSerializable):

typ = "property_and_conditions"

def __init__(self, property=None, conditions=None):
def __init__(self, property: Property = None, conditions: Iterable[Condition] = None):
self._property = None
self.property = property
self._conditions = None
self.conditions = conditions

@property
def conditions(self):
def conditions(self) -> List[Condition]:
"""Get conditions."""
return self._conditions

@conditions.setter
def conditions(self, conditions):
def conditions(self, conditions: Iterable[Condition]):
self._conditions = validate_list(conditions, Condition)

# Horrible hacks to make templates work in the short term
@property
def name(self):
def name(self) -> str:
"""Get name of attribute (use name of property)."""
return self.property.name

@property
def template(self):
def template(self) -> Optional[Union[PropertyTemplate, LinkByUID]]:
"""Get template of attribute (use template of property)."""
return self.property.template

@property
def origin(self):
def origin(self) -> str:
"""Get origin of attribute (use origin of property)."""
return self.property.origin

@property
def value(self):
def value(self) -> BaseValue:
"""Get value of attribute (use value of property)."""
return self.property.value

# NOTE: this definition must go last, or else it overrides the property decorator
@property
def property(self):
def property(self) -> Property:
"""Get property."""
return self._property

@property.setter
def property(self, value):
def property(self, value: Property):
if isinstance(value, Property):
self._property = value
else:
Expand Down
40 changes: 24 additions & 16 deletions gemd/entity/base_entity.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Base class for all entities."""
from typing import Optional, Dict, FrozenSet
from collections.abc import Collection
from typing import Optional, Union, Iterable, List, Set, FrozenSet, Mapping, Dict

from gemd.entity.dict_serializable import DictSerializable
from gemd.entity.has_dependencies import HasDependencies
from gemd.entity.case_insensitive_dict import CaseInsensitiveDict
from gemd.entity.setters import validate_list


class BaseEntity(DictSerializable):
Expand All @@ -25,42 +26,37 @@ class BaseEntity(DictSerializable):

typ = "base"

def __init__(self, uids, tags):
def __init__(self, uids: Mapping[str, str], tags: Iterable[str]):
self._tags = None
self.tags = tags

self._uids = None
self.uids = uids

@property
def tags(self):
def tags(self) -> List[str]:
"""Get the tags."""
return self._tags

@tags.setter
def tags(self, tags):
if tags is None:
self._tags = []
elif isinstance(tags, list):
self._tags = tags
else:
self._tags = [tags]
def tags(self, tags: Iterable[str]):
self._tags = validate_list(tags, str)

@property
def uids(self):
def uids(self) -> Mapping[str, str]:
"""Get the uids."""
return self._uids

@uids.setter
def uids(self, uids):
def uids(self, uids: Mapping[str, str]):
if uids is None:
self._uids = CaseInsensitiveDict()
elif isinstance(uids, dict):
elif isinstance(uids, Mapping):
self._uids = CaseInsensitiveDict(**uids)
else:
self._uids = CaseInsensitiveDict(**{uids[0]: uids[1]})

def add_uid(self, scope, uid):
def add_uid(self, scope: str, uid: str):
"""
Add a uid.

Expand Down Expand Up @@ -106,6 +102,18 @@ def to_link(self,

return LinkByUID(scope=scope, id=uid)

def all_dependencies(self) -> Set[Union["BaseEntity", "LinkByUID"]]:
"""Return a set of all immediate dependencies (no recursion)."""
result = set()
queue = [type(self)]
while queue:
cls = queue.pop()
if issubclass(cls, HasDependencies) and \
"_local_dependencies" not in cls.__abstractmethods__:
result |= cls._local_dependencies(self)
queue.extend(cls.__bases__)
return result

@staticmethod
def _cached_equals(this: 'BaseEntity',
that: 'BaseEntity',
Expand Down Expand Up @@ -146,7 +154,7 @@ def _cached_equals(this: 'BaseEntity',
if BaseEntity._cached_equals(this_value, that_value, cache=cache) is False:
cache[cache_key] = False # Mark as failed
return False
elif isinstance(this_value, Collection) and isinstance(that_value, Collection) \
elif isinstance(this_value, Iterable) and isinstance(that_value, Iterable) \
and not isinstance(this_value, str) and not isinstance(that_value, str):
# Necessary to maintain context for recursive parts of the structure
this_list = list(this_value)
Expand Down
21 changes: 11 additions & 10 deletions gemd/entity/dict_serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import inspect
import functools
from typing import Union, Iterable, List, Mapping, Dict, Any

# There are some weird (probably resolvable) errors during object cloning if this is an
# instance variable of DictSerializable.
Expand All @@ -17,7 +18,7 @@ class DictSerializable(ABC):
skip = set()

@classmethod
def from_dict(cls, d):
def from_dict(cls, d: Mapping[str, Any]) -> "DictSerializable":
"""
Reconstitute the object from a dictionary.

Expand Down Expand Up @@ -47,13 +48,13 @@ def from_dict(cls, d):

@classmethod
@functools.lru_cache(maxsize=None)
def _init_sig(cls):
def _init_sig(cls) -> List[str]:
"""Internal method for generating the argument names for the class init method."""
expected_arg_names = inspect.getfullargspec(cls.__init__).args
expected_arg_names += inspect.getfullargspec(cls.__init__).kwonlyargs
return expected_arg_names

def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
"""
Convert the object to a dictionary.

Expand All @@ -68,7 +69,7 @@ def as_dict(self):
attributes["type"] = self.typ
return attributes

def dump(self):
def dump(self) -> str:
"""
Convert the object to a JSON dictionary, so that every entry is serialized.

Expand All @@ -85,7 +86,7 @@ def dump(self):
return json.loads(encoder.raw_dumps(self))

@staticmethod
def build(d):
def build(d: Mapping[str, Any]) -> "DictSerializable":
"""
Build an object from a JSON dictionary.

Expand All @@ -107,7 +108,7 @@ def build(d):
encoder = GEMDJson()
return encoder.raw_loads(encoder.raw_dumps(d))

def __repr__(self):
def __repr__(self) -> str:
object_dict = self.as_dict()
# as_dict() skips over keys in `skip`, but they should be in the representation.
skipped_keys = {x.lstrip('_') for x in vars(self) if x in self.skip}
Expand All @@ -116,7 +117,7 @@ def __repr__(self):
object_dict[key] = self._name_repr(skipped_field)
return str(object_dict)

def _name_repr(self, entity):
def _name_repr(self, entity: Union[Iterable["DictSerializable"], "DictSerializable"]) -> str:
"""
A representation of an object or a list of objects that uses the name and type.

Expand All @@ -135,15 +136,15 @@ def _name_repr(self, entity):
A representation of `entity` using its name.

"""
if isinstance(entity, (list, tuple)):
if isinstance(entity, Iterable):
return [self._name_repr(item) for item in entity]
elif entity is None:
return None
else:
name = getattr(entity, 'name', '<unknown name>')
return "<{} '{}'>".format(type(entity).__name__, name)
return f"<{type(entity).__name__} '{name}'>"

def _dict_for_compare(self):
def _dict_for_compare(self) -> Dict[str, Any]:
"""Which fields & values are relevant to an equality test."""
return self.as_dict()

Expand Down
Loading