diff --git a/.travis.yml b/.travis.yml index aa18042f..8788a5a9 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,5 @@ language: python -dist: bionic +dist: focal python: - 2.7 diff --git a/gdcdatamodel/models/__init__.py b/gdcdatamodel/models/__init__.py index edd9bf94..41c74196 100644 --- a/gdcdatamodel/models/__init__.py +++ b/gdcdatamodel/models/__init__.py @@ -17,7 +17,7 @@ import os import sys -from sqlalchemy.orm.attributes import flag_modified +import six try: from functools import lru_cache @@ -44,16 +44,9 @@ versioning, ) -from sqlalchemy import ( - event, - and_ -) +from sqlalchemy import event, and_ -from psqlgraph import ( - Node, - Edge, - pg_property -) +from psqlgraph import Node, Edge, pg_property from psqlgraph import ext @@ -76,15 +69,15 @@ cls_add_indexes, get_secondary_key_indexes, ) -from gdcdatamodel.models.misc import FileReport # noqa +from gdcdatamodel.models.misc import FileReport # noqa from gdcdatamodel.models.versioned_nodes import VersionedNode # noqa from gdcdatamodel.models.utils import py3_to_bytes -logger = logging.getLogger('gdcdatamodel') +logger = logging.getLogger("gdcdatamodel") # These are properties that are defined outside of the JSONB column in # the database, inform later code to skip these -excluded_props = ['id', 'type'] +excluded_props = ["id", "type"] def remove_spaces(s): @@ -93,7 +86,7 @@ def remove_spaces(s): :param str s: String to remove spaces from """ - return s.replace(' ', '') + return s.replace(" ", "") def get_cls_package(package_namespace=None): @@ -119,7 +112,7 @@ def register_class(cls, package_namespace=None): pkg = ModuleType(m) sys.modules[m] = pkg globals()[package_namespace] = pkg - setattr(pkg, cls.__name__, cls) + setattr(pkg, cls.__name__, cls) else: globals()[cls.__name__] = cls @@ -133,54 +126,55 @@ def get_links(schema): """ links = {} - for entry in schema.get('links') or []: - if 'subgroup' in entry: - for link in entry['subgroup']: - links[link['name']] = link + for entry in schema.get("links") or []: + if "subgroup" in entry: + for link in entry["subgroup"]: + links[link["name"]] = link else: - links[entry['name']] = entry + links[entry["name"]] = entry return links def types_from_str(types): - return [a for type_ in types for a in { - 'string': [str], - 'number': [float, int], - 'integer': [int], - 'float': [float], - 'null': [str], - 'boolean': [bool], - 'array': [list], - None: [str], - }[type_]] + return [ + a + for type_ in types + for a in { + "string": [str], + "number": [float, int], + "integer": [int], + "float": [float], + "null": [str], + "boolean": [bool], + "array": [list], + None: [str], + }[type_] + ] def PropertyFactory(name, schema, key=None): - """Returns a pg_property (psqlgraph specific type of hybrid_property) - - """ + """Returns a pg_property (psqlgraph specific type of hybrid_property)""" key = name if key is None else key # Assert the dictionary has no references for properties - assert '$ref' not in schema.keys(), ( + assert "$ref" not in schema.keys(), ( "Found a JSON reference in dictionary. These should be resolved " - "at gdcdictionary module load time as of 2016-02-24") + "at gdcdictionary module load time as of 2016-02-24" + ) # None is the default for a schema type. types = None - if schema.get('oneOf'): + if schema.get("oneOf"): # We will handle an empty list after the oneOf/types field checks. # If it really an empty list, then use None as a value. types = [ - oneOf['type'] - for oneOf in schema['oneOf'] - if oneOf.get('type') + oneOf["type"] for oneOf in schema["oneOf"] if oneOf.get("type") ] or None # If there's both overwrite the 'oneOf' field in favor of the 'type' field. - if schema.get('type'): + if schema.get("type"): # Lookup property type and coerce to list - types = schema.get('type') + types = schema.get("type") # If None is all we have left over, then turn it into a list of None. types = [types] if not isinstance(types, list) else types @@ -189,23 +183,24 @@ def PropertyFactory(name, schema, key=None): python_types = types_from_str(types) # If there is an enum defined, grab it for pg_property validation - enum = schema.get('enum') + enum = schema.get("enum") # Create pg_property setter @pg_property(*python_types, enum=enum) def setter(self, val): self._set_property(key, val) + setter.__name__ = name return setter def get_class_name_from_id(_id): - return ''.join([a.capitalize() for a in _id.split('_')]) + return "".join([a.capitalize() for a in _id.split("_")]) def get_class_tablename_from_id(_id): - return 'node_{}'.format(_id.replace('_', '')) + return "node_{}".format(_id.replace("_", "")) def cls_inject_versioned_nodes_lookup(cls): @@ -226,36 +221,36 @@ def _versions(self): session = self.get_session() if not session: raise RuntimeError( - '{} not bound to a session. Try .get_versions(session).' - .format(self)) + "{} not bound to a session. Try .get_versions(session).".format(self) + ) return self.get_versions(session) def get_versions(self, session): - """Returns a query for node versions given a session. + """Returns a query for node versions given a session.""" - """ - - return session.query(VersionedNode)\ - .filter(VersionedNode.node_id == self.node_id)\ - .filter(VersionedNode.label == self.label)\ - .order_by(VersionedNode.key.desc()) + return ( + session.query(VersionedNode) + .filter(VersionedNode.node_id == self.node_id) + .filter(VersionedNode.label == self.label) + .order_by(VersionedNode.key.desc()) + ) cls._versions = _versions cls.get_versions = get_versions -def cls_inject_created_datetime_hook(cls, - updated_key="updated_datetime", - created_key="created_datetime"): +def cls_inject_created_datetime_hook( + cls, updated_key="updated_datetime", created_key="created_datetime" +): """Given a class, inject a SQLAlchemy hook that will write the timestamp of the last session flush to the :param:`updated_key` and :param:`created_key` properties. """ - @event.listens_for(cls, 'before_insert') + @event.listens_for(cls, "before_insert") def set_created_updated_datetimes(mapper, connection, target): - ts = target.get_session()._flush_timestamp.isoformat('T') + ts = target.get_session()._flush_timestamp.isoformat("T") if updated_key in target.props: target._props[updated_key] = ts if created_key in target.props: @@ -269,13 +264,13 @@ def cls_inject_updated_datetime_hook(cls, updated_key="updated_datetime"): """ - @event.listens_for(cls, 'before_update') + @event.listens_for(cls, "before_update") def set_updated_datetimes(mapper, connection, target): # SQLAlchemy fires this event when associations change, but we should # only adjust the timestamp if the object itself was modified. target_session = target.get_session() if target_session.is_modified(target, include_collections=False): - ts = target_session._flush_timestamp.isoformat('T') + ts = target_session._flush_timestamp.isoformat("T") if updated_key in target.props: target._props[updated_key] = ts @@ -298,33 +293,27 @@ def cls_inject_secondary_keys(cls, schema): """ - unique_keys = schema.get('uniqueKeys', []) - cls.__pg_secondary_keys = [ - keys for keys in unique_keys if 'id' not in keys - ] + unique_keys = schema.get("uniqueKeys", []) + cls.__pg_secondary_keys = [keys for keys in unique_keys if "id" not in keys] class SecondaryKeyComparator(Comparator): def __eq__(self, other): filters = [] cls = self.__clause_element__() - secondary_keys = getattr(cls, '__pg_secondary_keys', []) + secondary_keys = getattr(cls, "__pg_secondary_keys", []) for keys, values in zip(secondary_keys, other): - if 'id' in keys: + if "id" in keys: continue - other = { - key: val - for key, val - in zip(keys, values) - } + other = {key: val for key, val in zip(keys, values)} filters.append(cls._props.contains(other)) return and_(*filters) @property def _secondary_keys_dicts(self): vals = [] - secondary_keys = getattr(self, '__pg_secondary_keys', []) + secondary_keys = getattr(self, "__pg_secondary_keys", []) for keys in secondary_keys: - if 'id' in keys: + if "id" in keys: continue vals.append({key: getattr(self, key, None) for key in keys}) return vals @@ -332,7 +321,7 @@ def _secondary_keys_dicts(self): @hybrid_property def _secondary_keys(self): vals = [] - for keys in getattr(self, '__pg_secondary_keys', []): + for keys in getattr(self, "__pg_secondary_keys", []): vals.append(tuple(getattr(self, key) for key in keys)) return tuple(vals) @@ -349,14 +338,13 @@ def _secondary_keys(cls): def NodeFactory(_id, schema, node_cls=Node, package_namespace=None): - """Returns a node class given a schema. - - """ + """Returns a node class given a schema.""" name = get_class_name_from_id(_id) links = get_links(schema) tag_props = schema.get("tagProperties") + tag_config = schema.get("tagBuilderConfig", {}) @property def node_id(self, value): @@ -375,6 +363,41 @@ def tag_properties(self): """ return tag_props + @property + def tag_builder_config(self): + """Tagging configuration instance used for checking if node instance participates in versioning + + Returns: + versioning.TaggingConfig: config instance + """ + return versioning.TagBuilderConfig(cfg=tag_config) + + def is_taggable(self): + """Returns True if a node instance can be tagged and versioned + Returns: + bool: True if node can be tagged + """ + + return self.tag_builder_config.is_taggable(self) + + def get_tag_property_values(self): + """Values that are used for computing the node's tag + Returns: + list[str]: List of property values + """ + keys = [] + for prop in self.tag_properties: + property_val = self[prop] + + if not property_val: + raise ValueError( + "Property {0} must have a value on instance {1} for tagging to proceed".format( + prop, self + ) + ) + keys.append(six.ensure_str(property_val)) + return keys + @property def is_latest(self): """latest version of the node based on tagging @@ -407,67 +430,70 @@ def tag(self): # Pull the JSONB properties from the `properties` key attributes = { key: PropertyFactory(key, schema) - for key, schema in schema.get('properties', {}).items() - if key not in links - and key not in excluded_props + for key, schema in schema.get("properties", {}).items() + if key not in links and key not in excluded_props } skipped_dict_vals = [ - '$schema', - 'systemProperties', - 'additionalProperties', - 'links', - 'properties', - 'uniqueKeys', - 'id' , - 'tagProperties' + "$schema", + "systemProperties", + "additionalProperties", + "links", + "properties", + "uniqueKeys", + "id", + "tagProperties", + "tagBuilderConfig", ] - attributes['_dictionary'] = { + attributes["_dictionary"] = { key: schema[key] for key in schema if key not in skipped_dict_vals } # _defaults: default value for specified fields in the dictionary - attributes['_defaults'] = { - key: values['default'] - for key, values in schema.get('properties', {}).items() - if 'default' in values + attributes["_defaults"] = { + key: values["default"] + for key, values in schema.get("properties", {}).items() + if "default" in values } # _pg_links are out_edges, links TO other types - attributes['_pg_links'] = {} + attributes["_pg_links"] = {} # _pg_backrefs are in_edges, links FROM other types - attributes['_pg_backrefs'] = {} + attributes["_pg_backrefs"] = {} # _pg_edges are all edges, links to AND from other types - attributes['_pg_edges'] = {} + attributes["_pg_edges"] = {} # _related_cases_from_parents: get ids of related cases from this # nodes's sysan - attributes['_related_cases_from_cache'] = property( - related_cases_from_cache - ) + attributes["_related_cases_from_cache"] = property(related_cases_from_cache) if tag_props: attributes[versioning.TagKeys.tag] = tag attributes[versioning.TagKeys.version] = ver attributes["tag_properties"] = tag_properties attributes["is_latest"] = is_latest + attributes["tag_builder_config"] = tag_builder_config + attributes["is_taggable"] = is_taggable + attributes["get_tag_property_values"] = get_tag_property_values # _related_cases_from_parents: get ids of related cases from this # nodes parents - attributes['_related_cases_from_parents'] = property( - related_cases_from_parents - ) + attributes["_related_cases_from_parents"] = property(related_cases_from_parents) # Create the Node subclass! - cls = type(name, (node_cls,), dict( - __module__=get_cls_package(package_namespace), - __tablename__=get_class_tablename_from_id(_id), - __label__=_id, - id=node_id, - **attributes - )) + cls = type( + name, + (node_cls,), + dict( + __module__=get_cls_package(package_namespace), + __tablename__=get_class_tablename_from_id(_id), + __label__=_id, + id=node_id, + **attributes + ), + ) cls_inject_created_datetime_hook(cls) cls_inject_updated_datetime_hook(cls) @@ -496,36 +522,42 @@ def generate_edge_tablename(src_label, label, dst_label): """ - tablename = 'edge_{}{}{}'.format( - src_label.replace('_', ''), - label.replace('_', ''), - dst_label.replace('_', ''), + tablename = "edge_{}{}{}".format( + src_label.replace("_", ""), + label.replace("_", ""), + dst_label.replace("_", ""), ) # If the name is too long, prepend it with the first 8 hex of it's hash # truncate the each part of the name if len(tablename) > 40: oldname = tablename - logger.debug('Edge tablename {} too long, shortening'.format(oldname)) - tablename = 'edge_{}_{}'.format( + logger.debug("Edge tablename {} too long, shortening".format(oldname)) + tablename = "edge_{}_{}".format( hashlib.md5(py3_to_bytes(tablename)).hexdigest()[:8], "{}{}{}".format( - ''.join([a[:2] for a in src_label.split('_')])[:10], - ''.join([a[:2] for a in label.split('_')])[:7], - ''.join([a[:2] for a in dst_label.split('_')])[:10], - ) + "".join([a[:2] for a in src_label.split("_")])[:10], + "".join([a[:2] for a in label.split("_")])[:7], + "".join([a[:2] for a in dst_label.split("_")])[:10], + ), ) - logger.debug('Shortening {} -> {}'.format(oldname, tablename)) + logger.debug("Shortening {} -> {}".format(oldname, tablename)) return tablename -def EdgeFactory(name, label, src_label, dst_label, src_dst_assoc, - dst_src_assoc, - node_cls=Node, - edge_cls=Edge, - package_namespace=None, - _assigned_association_proxies=defaultdict(lambda: defaultdict(set))): +def EdgeFactory( + name, + label, + src_label, + dst_label, + src_dst_assoc, + dst_src_assoc, + node_cls=Node, + edge_cls=Edge, + package_namespace=None, + _assigned_association_proxies=defaultdict(lambda: defaultdict(set)), +): """Returns an edge class. :param name: The name of the edge class. @@ -566,14 +598,22 @@ def EdgeFactory(name, label, src_label, dst_label, src_dst_assoc, # Assert that we're not clobbering link names assoc_proxy_key = package_namespace - assert dst_src_assoc not in _assigned_association_proxies[assoc_proxy_key][dst_label], ( + assert ( + dst_src_assoc not in _assigned_association_proxies[assoc_proxy_key][dst_label] + ), ( "Attempted to assign backref '{link}' to node '{node}' but " - "the node already has an attribute called '{link}'" - .format(link=dst_src_assoc, node=dst_label)) - assert src_dst_assoc not in _assigned_association_proxies[assoc_proxy_key][src_label], ( + "the node already has an attribute called '{link}'".format( + link=dst_src_assoc, node=dst_label + ) + ) + assert ( + src_dst_assoc not in _assigned_association_proxies[assoc_proxy_key][src_label] + ), ( "Attempted to assign link '{link}' to node '{node}' but " - "the node already has an attribute called '{link}'" - .format(link=src_dst_assoc, node=src_label)) + "the node already has an attribute called '{link}'".format( + link=src_dst_assoc, node=src_label + ) + ) # Remember that we're adding this link and this backref _assigned_association_proxies[assoc_proxy_key][dst_label].add(dst_src_assoc) @@ -591,20 +631,24 @@ def EdgeFactory(name, label, src_label, dst_label, src_dst_assoc, cache_related_cases_on_delete, ] - cls = type(name, (edge_cls,), { - '__module__': cls_package, - '__label__': label, - '__tablename__': tablename, - '__src_class__': get_class_name_from_id(src_label), - '__dst_class__': get_class_name_from_id(dst_label), - '__src_dst_assoc__': src_dst_assoc, - '__dst_src_assoc__': dst_src_assoc, - '__src_table__': src_cls.__tablename__, - '__dst_table__': dst_cls.__tablename__, - '_session_hooks_before_insert': hooks_before_insert, - '_session_hooks_before_update': hooks_before_update, - '_session_hooks_before_delete': hooks_before_delete, - }) + cls = type( + name, + (edge_cls,), + { + "__module__": cls_package, + "__label__": label, + "__tablename__": tablename, + "__src_class__": get_class_name_from_id(src_label), + "__dst_class__": get_class_name_from_id(dst_label), + "__src_dst_assoc__": src_dst_assoc, + "__dst_src_assoc__": dst_src_assoc, + "__src_table__": src_cls.__tablename__, + "__dst_table__": dst_cls.__tablename__, + "_session_hooks_before_insert": hooks_before_insert, + "_session_hooks_before_update": hooks_before_update, + "_session_hooks_before_delete": hooks_before_delete, + }, + ) edge_cls.add_subclass(cls) register_class(cls, package_namespace) @@ -620,47 +664,50 @@ def load_nodes(dictionary, node_cls=None, package_namespace=None): """ node_cls = node_cls or Node for entity, subschema in dictionary.schema.items(): - _id = subschema['id'] + _id = subschema["id"] name = get_class_name_from_id(_id) if not node_cls.is_subclass_loaded(name): try: cls = NodeFactory(_id, subschema, node_cls, package_namespace) register_class(cls, package_namespace) except Exception: - print('Unable to load {}'.format(name)) + print("Unable to load {}".format(name)) raise -def parse_edge(src_label, - name, - edge_label, - subschema, - link, - dictionary, - node_cls=Node, - edge_cls=Edge, - package_namespace=None): +def parse_edge( + src_label, + name, + edge_label, + subschema, + link, + dictionary, + node_cls=Node, + edge_cls=Edge, + package_namespace=None, +): """Parse an edge from the dictionary and create and Edge subclass :returns: The outbound name of the edge """ - dst_label = link['target_type'] - backref = link['backref'] + dst_label = link["target_type"] + backref = link["backref"] - src_label = subschema['id'] + src_label = subschema["id"] if dst_label not in dictionary.schema: raise RuntimeError( - "Destination '{}' for edge '{}' from '{}' not defined" - .format(dst_label, name, src_label)) + "Destination '{}' for edge '{}' from '{}' not defined".format( + dst_label, name, src_label + ) + ) - dst_label = dictionary.schema[dst_label]['id'] - edge_name = ''.join(map(get_class_name_from_id, [ - src_label, edge_label, dst_label])) + dst_label = dictionary.schema[dst_label]["id"] + edge_name = "".join(map(get_class_name_from_id, [src_label, edge_label, dst_label])) if edge_cls.is_subclass_loaded(name): - return '_{}_out'.format(edge_name) + return "_{}_out".format(edge_name) edge = EdgeFactory( edge_name, @@ -671,10 +718,10 @@ def parse_edge(src_label, backref, node_cls=node_cls, edge_cls=edge_cls, - package_namespace=package_namespace + package_namespace=package_namespace, ) - return '_{}_out'.format(edge.__name__) + return "_{}_out".format(edge.__name__) def load_edges(dictionary, node_cls=Node, edge_cls=Edge, package_namespace=None): @@ -688,45 +735,48 @@ def load_edges(dictionary, node_cls=Node, edge_cls=Edge, package_namespace=None) src_cls = node_cls.get_subclass(src_label) if not src_cls: - raise RuntimeError('No source class labeled {}'.format(src_label)) + raise RuntimeError("No source class labeled {}".format(src_label)) for name, link in get_links(subschema).items(): - edge_label = link['label'] + edge_label = link["label"] edge_name = parse_edge( - src_label, name, edge_label, subschema, link, + src_label, + name, + edge_label, + subschema, + link, dictionary=dictionary, node_cls=node_cls, edge_cls=edge_cls, package_namespace=package_namespace, ) - src_cls._pg_links[link['name']] = { - 'edge_out': edge_name, - 'dst_type': node_cls.get_subclass(link['target_type']) + src_cls._pg_links[link["name"]] = { + "edge_out": edge_name, + "dst_type": node_cls.get_subclass(link["target_type"]), } for src_cls in node_cls.get_subclasses(): - cache_case = ( - not src_cls._dictionary['category'] in NOT_RELATED_CASES_CATEGORIES - or src_cls.get_label() in ['annotation'] - ) + cache_case = not src_cls._dictionary[ + "category" + ] in NOT_RELATED_CASES_CATEGORIES or src_cls.get_label() in ["annotation"] if not cache_case: continue link = { - 'name': RELATED_CASES_LINK_NAME, - 'multiplicity': 'many_to_one', - 'required': False, - 'target_type': 'case', - 'label': 'relates_to', - 'backref': '_related_{}'.format(src_cls.label), + "name": RELATED_CASES_LINK_NAME, + "multiplicity": "many_to_one", + "required": False, + "target_type": "case", + "label": "relates_to", + "backref": "_related_{}".format(src_cls.label), } parse_edge( src_cls.label, - link['name'], - 'relates_to', - {'id': src_cls.label}, + link["name"], + "relates_to", + {"id": src_cls.label}, link, dictionary=dictionary, node_cls=node_cls, @@ -745,10 +795,10 @@ def inject_pg_backrefs(dictionary, node_cls): for src_label, subschema in dictionary.schema.items(): for name, link in get_links(subschema).items(): - dst_cls = node_cls.get_subclass(link['target_type']) - dst_cls._pg_backrefs[link['backref']] = { - 'name': link['name'], - 'src_type': node_cls.get_subclass(src_label) + dst_cls = node_cls.get_subclass(link["target_type"]) + dst_cls._pg_backrefs[link["backref"]] = { + "name": link["name"], + "src_type": node_cls.get_subclass(src_label), } @@ -766,8 +816,8 @@ def find_backref(link, src_cls): """ - for prop, backref in link['dst_type']._pg_backrefs.items(): - if backref['src_type'] == cls: + for prop, backref in link["dst_type"]._pg_backrefs.items(): + if backref["src_type"] == cls: return prop def cls_inject_forward_edges(cls): @@ -780,8 +830,8 @@ def cls_inject_forward_edges(cls): for name, link in cls._pg_links.items(): cls._pg_edges[name] = { - 'backref': find_backref(link, cls), - 'type': link['dst_type'], + "backref": find_backref(link, cls), + "type": link["dst_type"], } def cls_inject_backward_edges(cls): @@ -794,8 +844,8 @@ def cls_inject_backward_edges(cls): for name, backref in cls._pg_backrefs.items(): cls._pg_edges[name] = { - 'backref': backref['name'], - 'type': backref['src_type'], + "backref": backref["name"], + "type": backref["src_type"], } for cls in node_cls.get_subclasses(): @@ -805,7 +855,7 @@ def cls_inject_backward_edges(cls): @lru_cache(maxsize=10) def load_dictionary(dictionary=None, package_namespace=None): - """ Loads all classes defined in dictionary, this method is expected to be called only once + """Loads all classes defined in dictionary, this method is expected to be called only once and very early in the application lifecycle. Subsequent calls are cached Args: dictionary: gdc dictionary or an extension of it @@ -817,6 +867,7 @@ def load_dictionary(dictionary=None, package_namespace=None): if dictionary is None: from gdcdictionary import gdcdictionary + dictionary = gdcdictionary node_cls, edge_cls = ext.register_base_class(package_namespace) @@ -838,4 +889,3 @@ def load_dictionary(dictionary=None, package_namespace=None): # load default dictionary if os.environ.get("LOAD_GDC_DICTIONARY", "True") == "True": load_dictionary() - diff --git a/gdcdatamodel/models/versioning.py b/gdcdatamodel/models/versioning.py index 78ada1ef..4b29ff16 100644 --- a/gdcdatamodel/models/versioning.py +++ b/gdcdatamodel/models/versioning.py @@ -3,8 +3,15 @@ import six from sqlalchemy import and_, event, select +try: + from functools import lru_cache +except ImportError: + from functools32 import lru_cache -UUID_NAMESPACE_SEED = os.getenv("UUID_NAMESPACE_SEED", "86bb916a-24c5-48e4-8a46-5ea73a379d47") + +UUID_NAMESPACE_SEED = os.getenv( + "UUID_NAMESPACE_SEED", "86bb916a-24c5-48e4-8a46-5ea73a379d47" +) UUID_NAMESPACE = uuid.UUID("urn:uuid:{}".format(UUID_NAMESPACE_SEED), version=4) @@ -14,12 +21,101 @@ class TagKeys: version = "ver" +class TaggingConstraint: + """Computes whether a node instance supports tagging or not""" + + def __init__(self, path, prop, values): + """ + Args: + path (str): full psqlgraph path to a parent node + prop (str): valid node property name + values (list[str]): list of possible value + """ + self.path = path + self.prop = prop + self.values = values + + def _resolve_target_node_from_path(self, node): + """Resolves to the final node instance that can be used to perform the matching + + e.g: if path = `aligned_reads.submitted_alinged_reads`, the final node used to perform the matching + is an instance of SubmittedAlignedReads, which can be reached by following the relationships defined in + the path + i.e: node["aligned_reads"][0]["submitted_aligned_reads"][0] + + this is equivalent to: + node.aligned_reads[0].submitted_aligned_reads[0] + + Args: + node (models.Node): Node instance + + Returns: + models.Node: node instance whose properties will be used for matching + """ + if not self.path: + return node + + for path in self.path.split("."): + # Since a node type can have multiple paths to a given parent + # this check allows instances that do not have this specific path + if len(node[path]) == 0: + return None + + node = node[path][0] + return node + + def match(self, node): + """Checks if a node has a value matching the prop and values field + + if it does, the particular instance will not participate in the entire tagging process + Args: + node (psqlgraph.Node): node instance + Returns: + Returns (bool) + """ + node = self._resolve_target_node_from_path(node) + return node and node[self.prop] in self.values + + +class TagBuilderConfig: + """A wrapper around the tagBuilderConfig definition in the dictionary yaml""" + + def __init__(self, cfg): + """ + + Args: + cfg (dict[str, Any]): The tagConfig section of the dictionary + """ + self.cfg = cfg + + def _constraints(self): + """Returns all constraints defined for a particular node type""" + + skip_criterion = self.cfg.get("ignoreEntries", []) + for criteria in skip_criterion: + yield TaggingConstraint( + path=criteria.get("path"), + prop=criteria["prop"], + values=criteria["values"], + ) + + def is_taggable(self, node): + """Returns true if node supports tagging else False. Ideally, instances that return false will not + have tag and version number set on them + + Returns: + bool: True for nodes that can be tagged + """ + return not any(criteria.match(node) for criteria in self._constraints()) + + def __generate_hash(seed, label): namespace = UUID_NAMESPACE name = "{}-{}".format(seed, label) return six.ensure_str(str(uuid.uuid5(namespace, name))) +@lru_cache(maxsize=None) def compute_tag(node): """Computes unique tag for given node Args: @@ -27,15 +123,12 @@ def compute_tag(node): Returns: str: computed tag """ - keys = [ - six.ensure_str(node.node_id if p == "node_id" else node.props[p]) - for p in node.tag_properties - ] + keys = node.get_tag_property_values() keys += sorted( [ - six.ensure_str(p.dst.tag or compute_tag(p.dst)) + six.ensure_str(compute_tag(p.dst)) for p in node.edges_out - if p.label != "relates_to" + if p.dst.is_taggable() and p.label != "relates_to" ] ) return __generate_hash(keys, node.label) @@ -76,6 +169,10 @@ def inject_set_tag_after_insert(cls): @event.listens_for(cls, "after_insert") def set_node_tag(mapper, conn, node): table = node.__table__ + + if not node.is_taggable(): + return # do nothing + tag = compute_tag(node) version = __get_tagged_version(node.node_id, table, tag, conn) diff --git a/requirements.txt b/requirements.txt index fe402540..2cfa5610 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ # -# This file is autogenerated by pip-compile +# This file is autogenerated by pip-compile with python 3.9 # To update, run: # # pip-compile @@ -8,9 +8,9 @@ attrs==21.2.0 # via jsonschema decorator==4.4.2 # via gdcdatamodel (setup.py) -git+https://github.com/NCI-GDC/gdc-ng-models.git@1.5.2#egg=gdc-ng-models +gdc-ng-models @ git+https://github.com/NCI-GDC/gdc-ng-models.git@1.5.2 # via gdcdatamodel (setup.py) -git+https://github.com/NCI-GDC/gdcdictionary.git@2.4.2#egg=gdcdictionary +gdcdictionary @ git+https://github.com/NCI-GDC/gdcdictionary.git@2.5.0-rc.0 # via gdcdatamodel (setup.py) graphviz==0.14.2 # via gdcdatamodel (setup.py) @@ -20,7 +20,7 @@ jsonschema==3.2.0 # via # gdcdatamodel (setup.py) # gdcdictionary -git+https://github.com/NCI-GDC/psqlgraph.git@3.4.0#egg=psqlgraph +psqlgraph @ git+https://github.com/NCI-GDC/psqlgraph.git@3.4.0 # via gdcdatamodel (setup.py) psycopg2==2.8.6 # via psqlgraph diff --git a/setup.py b/setup.py index 2ec1cac7..99eee56e 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ "jsonschema~=3.2", "pyrsistent<0.17.0", "decorator<=5.0.0", - "gdcdictionary @ git+https://github.com/NCI-GDC/gdcdictionary.git@2.4.2#egg=gdcdictionary", + "gdcdictionary @ git+https://github.com/NCI-GDC/gdcdictionary.git@2.5.0-rc.0#egg=gdcdictionary", "gdc-ng-models @ git+https://github.com/NCI-GDC/gdc-ng-models.git@1.5.2#egg=gdc-ng-models", "psqlgraph @ git+https://github.com/NCI-GDC/psqlgraph.git@3.4.0#egg=psqlgraph", ], diff --git a/test/schema/basic.yaml b/test/schema/basic.yaml index 63d8b3a5..0a084d97 100644 --- a/test/schema/basic.yaml +++ b/test/schema/basic.yaml @@ -80,6 +80,11 @@ center: type: string tagProperties: - code + tagBuilderConfig: + ignoreEntries: + - prop: code + values: + - A101 portion: id: portion category: biospecimen diff --git a/test/unit/test_tagging.py b/test/unit/test_tagging.py index 100a5185..f68a87ce 100644 --- a/test/unit/test_tagging.py +++ b/test/unit/test_tagging.py @@ -1,3 +1,5 @@ +import pytest + from gdcdatamodel.models import versioning as v from gdcdatamodel.models import basic # noqa @@ -20,13 +22,17 @@ def test_compute_tag(sample_data): for node in sample_data: print("\n..........{}...........".format(node)) v_tag = v.compute_tag(node) - assert v_tag == EXPECTED_TAGS[node.node_id], "invalid tag computed for {}".format(node.node_id) + assert ( + v_tag == EXPECTED_TAGS[node.node_id] + ), "invalid tag computed for {}".format(node.node_id) def test_multi_parent(sample_data): """Test version tag resolves to the same value independent of how the parents were attached""" - portion = basic.Portion(node_id="b9b6fdb3-6c31-4ed3-9f8c-67d4eae72102", submitter_id="portion_2") + portion = basic.Portion( + node_id="b9b6fdb3-6c31-4ed3-9f8c-67d4eae72102", submitter_id="portion_2" + ) v_tag = v.compute_tag(portion) assert v_tag == "5776f97a-a58b-5900-83da-43cbc7105796" @@ -42,16 +48,33 @@ def test_multi_parent(sample_data): portion.samples.append(sample) portion.centers.append(center) + + v.compute_tag.cache_clear() v_tag = v.compute_tag(portion) assert v_tag == "a9a67fae-d916-5843-bdf3-b7db0b7a82a2" # unlink portion.samples = [] portion.centers = [] + + v.compute_tag.cache_clear() v_tag = v.compute_tag(portion) assert v_tag == "5776f97a-a58b-5900-83da-43cbc7105796" portion.centers.append(center) portion.samples.append(sample) + + v.compute_tag.cache_clear() v_tag = v.compute_tag(portion) assert v_tag == "a9a67fae-d916-5843-bdf3-b7db0b7a82a2" + + +@pytest.mark.parametrize( + "node, is_taggable", + [ + (basic.Portion(node_id="A101", submitter_id="portion_2"), True), + (basic.Center(node_id="TEST-1", code="A101"), False), + ], +) +def test_node_is_taggable(node, is_taggable): + assert node.is_taggable() is is_taggable