diff --git a/conftest.py b/conftest.py
index 285f73efa9..90873cc0ab 100644
--- a/conftest.py
+++ b/conftest.py
@@ -29,6 +29,7 @@
 ]
 
 CORE_FIXTURE_LOCATIONS = [
+    "tests.core.fixtures.core_database",
     "tests.core.fixtures.core_datasets",
     "tests.core.fixtures.core_plugins",
     "tests.core.fixtures.core_projects",
diff --git a/renku/cli/graph.py b/renku/cli/graph.py
index c40574e57b..13ef1a9469 100644
--- a/renku/cli/graph.py
+++ b/renku/cli/graph.py
@@ -32,7 +32,6 @@
 )
 from renku.core.incubation.graph import status as get_status
 from renku.core.incubation.graph import update as perform_update
-from renku.core.models.workflow.dependency_graph import DependencyGraph
 from renku.core.utils.contexts import measure
 
 
@@ -117,8 +116,7 @@ def save(path):
     with measure("CREATE DEPENDENCY GRAPH"):
 
         def _to_png(client, path):
-            dg = DependencyGraph.from_json(client.dependency_graph_path)
-            dg.to_png(path=path)
+            client.dependency_graph.to_png(path=path)
 
         Command().command(_to_png).build().execute(path=path)
 
diff --git a/renku/core/incubation/database.py b/renku/core/incubation/database.py
new file mode 100644
index 0000000000..8de5e9f4f7
--- /dev/null
+++ b/renku/core/incubation/database.py
@@ -0,0 +1,652 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright 2018-2021- Swiss Data Science Center (SDSC)
+# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
+# Eidgenössische Technische Hochschule Zürich (ETHZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Custom database for store Persistent objects.""" + +import datetime +import hashlib +import json +from pathlib import Path +from typing import Dict, List, Optional, Union +from uuid import uuid4 + +from BTrees.OOBTree import OOBTree +from persistent import GHOST, UPTODATE, Persistent +from persistent.interfaces import IPickleCache +from ZODB.POSException import POSKeyError +from ZODB.utils import z64 +from zope.interface import implementer + +OID_TYPE = str +MARKER = object() + +"""NOTE: These are used as _p_serial to mark if an object was read from storage or is new""" +NEW = z64 # NOTE: Do not change this value since this is the default when a Persistent object is created +PERSISTED = b"1" * 8 + + +def get_type_name(object) -> Optional[str]: + """Return fully-qualified object's type name.""" + if object is None: + return None + + object_type = object if isinstance(object, type) else type(object) + return f"{object_type.__module__}.{object_type.__qualname__}" + + +def get_class(type_name: Optional[str]) -> Optional[type]: + """Return the class for a fully-qualified type name.""" + if type_name is None: + return None + + components = type_name.split(".") + module_name = components[0] + + if module_name not in ["renku", "datetime", "BTrees", "persistent"]: + raise TypeError(f"Objects of type '{type_name}' are not allowed") + + module = __import__(module_name) + + return get_attribute(module, components[1:]) + + +def get_attribute(object, name: Union[List[str], str]): + """Return an attribute of an object.""" + components = name.split(".") if isinstance(name, str) else name + + for component in components: + object = getattr(object, component) + + return object + + +class Database: + """The Metadata Object Database. + + This class is equivalent to a persistent.DataManager and implements persistent.interfaces.IPersistentDataManager + interface. + """ + + ROOT_OID = "root" + + def __init__(self, storage): + self._storage: Storage = storage + self._cache = Cache() + # The pre-cache is used by get to avoid infinite loops when objects load their state + self._pre_cache: Dict[OID_TYPE, Persistent] = {} + # Objects added explicitly by add() or when serializing other objects. After commit they are moved to _cache. 
+ self._objects_to_commit: Dict[OID_TYPE, Persistent] = {} + self._reader: ObjectReader = ObjectReader(database=self) + self._writer: ObjectWriter = ObjectWriter(database=self) + self._root: Optional[OOBTree] = None + + self._initialize_root() + + @classmethod + def from_path(cls, path: Union[Path, str]) -> "Database": + """Create a Storage and Database using the given path.""" + storage = Storage(path) + return Database(storage=storage) + + @staticmethod + def generate_oid(object: Persistent) -> OID_TYPE: + """Generate oid for a Persistent object based on its id.""" + oid = getattr(object, "_p_oid") + if oid: + assert isinstance(oid, OID_TYPE) + return oid + + id: str = getattr(object, "id", None) or getattr(object, "_id", None) + if id: + return Database.hash_id(id) + + return Database.new_oid() + + @staticmethod + def hash_id(id: str) -> OID_TYPE: + """Return oid from id.""" + return hashlib.sha3_256(id.encode("utf-8")).hexdigest() + + @staticmethod + def new_oid(): + """Generate a random oid.""" + return f"{uuid4().hex}{uuid4().hex}" + + @staticmethod + def _get_filename_from_oid(oid: OID_TYPE) -> str: + return oid.lower() + + def __getitem__(self, key) -> "Index": + return self._root[key] + + @property + def root(self): + """Return the database root object.""" + return self._root + + def _initialize_root(self): + """Initialize root object.""" + if not self._root: + try: + self._root = self.get(Database.ROOT_OID) + except POSKeyError: + self._root = OOBTree() + self._root._p_oid = Database.ROOT_OID + self.register(self._root) + + def add_index(self, name: str, object_type: type, attribute: str = None, key_type: type = None) -> "Index": + """Add an index.""" + assert name not in self._root, f"Index already exists: '{name}'" + + index = Index(name=name, object_type=object_type, attribute=attribute, key_type=key_type) + index._p_jar = self + + self._root[name] = index + + return index + + def register(self, object: Persistent): + """Register a Persistent object to be stored. + + NOTE: When a Persistent object is changed it calls this method. 
+ """ + assert isinstance(object, Persistent), f"Cannot add non-Persistent object: '{object}'" + + if object._p_oid is None: + object._p_oid = self.generate_oid(object) + + object._p_jar = self + # object._p_serial = NEW + self._objects_to_commit[object._p_oid] = object + + def get(self, oid: OID_TYPE) -> Persistent: + """Get the object by oid.""" + if oid != Database.ROOT_OID and oid in self._root: # NOTE: Avoid looping if getting "root" + return self._root[oid] + + object = self.get_cached(oid) + if object is not None: + return object + + data = self._storage.load(filename=self._get_filename_from_oid(oid)) + object = self._reader.deserialize(data) + object._p_changed = 0 + object._p_serial = PERSISTED + + # NOTE: Avoid infinite loop if object tries to load its state before it is added to the cache + self._pre_cache[oid] = object + self._cache[oid] = object + self._pre_cache.pop(oid) + + return object + + def get_cached(self, oid: OID_TYPE) -> Optional[Persistent]: + """Return an object if it is in the cache or will be committed.""" + object = self._cache.get(oid) + if object is not None: + return object + + object = self._pre_cache.get(oid) + if object is not None: + return object + + object = self._objects_to_commit.get(oid) + if object is not None: + return object + + def new_ghost(self, oid: OID_TYPE, object: Persistent): + """Create a new ghost object.""" + object._p_jar = self + self._cache.new_ghost(oid, object) + + def setstate(self, object: Persistent): + """Load the state for a ghost object.""" + data = self._storage.load(filename=self._get_filename_from_oid(object._p_oid)) + self._reader.set_ghost_state(object, data) + object._p_serial = PERSISTED + + def commit(self): + """Commit modified and new objects.""" + while self._objects_to_commit: + _, object = self._objects_to_commit.popitem() + if object._p_changed or object._p_serial == NEW: + self._store_object(object) + + def _store_object(self, object: Persistent): + data = self._writer.serialize(object) + self._storage.store(filename=self._get_filename_from_oid(object._p_oid), data=data) + + self._cache[object._p_oid] = object + + object._p_changed = 0 # NOTE: transition from changed to up-to-date + object._p_serial = PERSISTED + + def remove_from_cache(self, object: Persistent): + """Remove an object from cache.""" + oid = object._p_oid + self._cache.pop(oid, None) + self._pre_cache.pop(oid, None) + self._objects_to_commit.pop(oid, None) + + def readCurrent(self, object): + """We don't use this method but some Persistent logic require its existence.""" + assert object._p_jar is self + assert object._p_oid is not None + + def oldstate(self, object, tid): + """See persistent.interfaces.IPersistentDataManager::oldstate.""" + raise NotImplementedError + + +@implementer(IPickleCache) +class Cache: + """Database Cache.""" + + def __init__(self): + self._entries = {} + + def __len__(self): + return len(self._entries) + + def __getitem__(self, oid): + assert isinstance(oid, OID_TYPE), f"Invalid oid type: '{type(oid)}'" + return self._entries[oid] + + def __setitem__(self, oid, object): + assert isinstance(object, Persistent), f"Cannot cache non-Persistent objects: '{object}'" + assert isinstance(oid, OID_TYPE), f"Invalid oid type: '{type(oid)}'" + + assert object._p_jar is not None, "Cached object jar missing" + assert oid == object._p_oid, f"Cache key does not match oid: {oid} != {object._p_oid}" + + if oid in self._entries: + existing_data = self.get(oid) + if existing_data is not object: + raise ValueError(f"The same oid exists: 
{existing_data} != {object}") + + self._entries[oid] = object + + def __delitem__(self, oid): + assert isinstance(oid, OID_TYPE), f"Invalid oid type: '{type(oid)}'" + self._entries.pop(oid) + + def pop(self, oid, default=MARKER): + """Remove and return an object.""" + return self._entries.pop(oid) if default is MARKER else self._entries.pop(oid, default) + + def get(self, oid, default=None): + """See IPickleCache.""" + assert isinstance(oid, OID_TYPE), f"Invalid oid type: '{type(oid)}'" + return self._entries.get(oid, default) + + def new_ghost(self, oid, object): + """See IPickleCache.""" + assert object._p_oid is None, f"Object already has an oid: {object}" + assert object._p_jar is not None, f"Object does not have a jar: {object}" + assert oid not in self._entries, f"Duplicate oid: {oid}" + + object._p_oid = oid + if object._p_state != GHOST: + object._p_invalidate() + + self[oid] = object + + +class Index(Persistent): + """Database index.""" + + def __init__(self, *, name: str, object_type, attribute: Optional[str], key_type=None): + """ + Create an index where keys are extracted using `attribute` from an object or a key. + + @param name: Index's name + @param object_type: Type of objects that the index points to + @param attribute: Name of an attribute to be used to automatically generate a key (e.g. `entity.path`) + @param key_type: Type of keys. If not None then a key must be provided when updating the index + """ + assert name == name.lower(), f"Index name must be all lowercase: '{name}'." + + super().__init__() + + self._p_oid = f"{name}-index" + self._name: str = name + self._object_type = object_type + self._key_type = key_type + self._attribute: Optional[str] = attribute + self._entries: OOBTree = OOBTree() + self._entries._p_oid = name + + def __len__(self): + return len(self._entries) + + def __contains__(self, key): + return key in self._entries + + def __getitem__(self, key): + return self._entries[key] + + def __setitem__(self, key, value): + # NOTE: if Index is using a key object then we cannot check if key is valid. It's safer to use `add` method + # instead of setting values directly. + self._verify_and_get_key(object=value, key_object=None, key=key, missing_key_object_ok=True) + + self._entries[key] = value + + def __getstate__(self): + return { + "name": self._name, + "object_type": get_type_name(self._object_type), + "key_type": get_type_name(self._key_type), + "attribute": self._attribute, + "entries": self._entries, + } + + def __setstate__(self, data): + self._name = data.pop("name") + self._object_type = get_class(data.pop("object_type")) + self._key_type = get_class(data.pop("key_type")) + self._attribute = data.pop("attribute") + self._entries = data.pop("entries") + + @property + def name(self) -> str: + """Return Index's name.""" + return self._name + + @property + def object_type(self) -> type: + """Return Index's object_type.""" + return self._object_type + + def get(self, key, default=None): + """Return an entry based on its key.""" + return self._entries.get(key, default) + + def pop(self, key, default=MARKER): + """Remove and return an object.""" + return self._entries.pop(key) if default is MARKER else self._entries.pop(key, default) + + def values(self): + """Return an iterator of values.""" + return self._entries.values() + + def items(self): + """Return an iterator of keys and values.""" + return self._entries.items() + + def add(self, object: Persistent, *, key: Optional[str] = None, key_object=None): + """Update index with object. 
+ + If `Index._attribute` is not None then key is automatically generated. + Key is extracted from `key_object` if it is not None; otherwise, it's extracted from `object`. + """ + assert isinstance(object, self._object_type), f"Cannot add objects of type '{type(object)}'" + + key = self._verify_and_get_key(object=object, key_object=key_object, key=key, missing_key_object_ok=False) + self._entries[key] = object + + def generate_key(self, object: Persistent, *, key_object=None): + """Return index key for an object. + + Key is extracted from `key_object` if it is not None; otherwise, it's extracted from `object`. + """ + return self._verify_and_get_key(object=object, key_object=key_object, key=None, missing_key_object_ok=False) + + def _verify_and_get_key(self, *, object: Persistent, key_object, key, missing_key_object_ok): + if self._key_type: + if not missing_key_object_ok: + assert isinstance(key_object, self._key_type), f"Invalid key type: {type(key_object)} for '{self.name}'" + else: + assert key_object is None, f"Index '{self.name}' does not accept 'key_object'" + + if self._attribute: + key_object = key_object or object + correct_key = get_attribute(key_object, self._attribute) + if key is not None: + assert key == correct_key, f"Incorrect key for index '{self.name}': '{key}' != '{correct_key}'" + else: + assert key is not None, "No key is provided" + correct_key = key + + return correct_key + + +class Storage: + """Store Persistent objects on the disk.""" + + MIN_COMPRESSED_FILENAME_LENGTH = 64 + + def __init__(self, path: Union[Path, str]): + self.path = Path(path) + self.path.mkdir(parents=True, exist_ok=True) + + def store(self, filename: str, data: Union[Dict, List]): + """Store object.""" + assert isinstance(filename, str) + + compressed = len(filename) >= Storage.MIN_COMPRESSED_FILENAME_LENGTH + if compressed: + path = self.path / filename[0:2] / filename[2:4] / filename + path.parent.mkdir(parents=True, exist_ok=True) + open_func = open # TODO: Change this to gzip.open for the final version + else: + path = self.path / filename + open_func = open + + with open_func(path, "w") as file: + json.dump(data, file, ensure_ascii=False, sort_keys=True, indent=2) + + def load(self, filename: str): + """Load data for object with object id oid.""" + assert isinstance(filename, str) + + compressed = len(filename) >= Storage.MIN_COMPRESSED_FILENAME_LENGTH + if compressed: + path = self.path / filename[0:2] / filename[2:4] / filename + open_func = open # TODO: Change this to gzip.open for the final version + else: + path = self.path / filename + open_func = open + + if not path.exists(): + raise POSKeyError(filename) + + with open_func(path) as file: + data = json.load(file) + + return data + + +class ObjectWriter: + """Serialize objects for storage in storage.""" + + def __init__(self, database: Database): + self._database: Database = database + + def serialize(self, object: Persistent): + """Convert an object to JSON.""" + assert isinstance(object, Persistent), f"Cannot serialize object of type '{type(object)}': {object}" + assert object._p_oid, f"Object does not have an oid: '{object}'" + assert object._p_jar is not None, f"Object is not associated with a Database: '{object}'" + + state = object.__getstate__() + data = self._serialize_helper(state) + + if not isinstance(data, dict): + data = {"@value": data} + + data["@type"] = get_type_name(object) + data["@oid"] = object._p_oid + + return data + + def _serialize_helper(self, object): + # TODO: Add support for weakref. 
See persistent.wref.WeakRef + if object is None: + return None + elif isinstance(object, list): + return [self._serialize_helper(value) for value in object] + elif isinstance(object, tuple): + return tuple([self._serialize_helper(value) for value in object]) + elif isinstance(object, dict): + for key, value in object.items(): + object[key] = self._serialize_helper(value) + return object + elif isinstance(object, (int, float, str, bool)): + return object + elif isinstance(object, datetime.datetime): + return {"@type": get_type_name(object), "@value": object.isoformat()} + elif isinstance(object, Index): + # NOTE: Include Index objects directly to their parent object (i.e. root) + assert object._p_oid is not None, f"Index has no oid: {object}" + state = object.__getstate__() + state = self._serialize_helper(state) + state["@type"] = get_type_name(object) + state["@oid"] = object._p_oid + return state + elif isinstance(object, Persistent): + if not object._p_oid: + object._p_oid = Database.generate_oid(object) + if object._p_state not in [GHOST, UPTODATE] or (object._p_state == UPTODATE and object._p_serial == NEW): + self._database.register(object) + return {"@type": get_type_name(object), "@oid": object._p_oid, "@reference": True} + elif hasattr(object, "__getstate__"): + state = object.__getstate__() + if isinstance(state, dict) and "_id" in state: # TODO: Remove this once all Renku classes have 'id' field + state["id"] = state.pop("_id") + return self._serialize_helper(state) + else: + state = object.__dict__.copy() + state = self._serialize_helper(state) + state["@type"] = get_type_name(object) + if "_id" in state: # TODO: Remove this once all Renku classes have 'id' field + state["id"] = state.pop("_id") + return state + + +class ObjectReader: + """Deserialize objects loaded from storage.""" + + def __init__(self, database: Database): + self._classes: Dict[str, type] = {} + self._database = database + + def _get_class(self, type_name: str) -> type: + cls = self._classes.get(type_name) + if cls: + return cls + + cls = get_class(type_name) + + self._classes[type_name] = cls + return cls + + def set_ghost_state(self, object: Persistent, data: Dict): + """Set state of a Persistent ghost object.""" + state = self._deserialize_helper(data, create=False) + if isinstance(object, OOBTree): + state = self._to_tuple(state) + + object.__setstate__(state) + + def _to_tuple(self, data): + if isinstance(data, list): + return tuple(self._to_tuple(value) for value in data) + return data + + def deserialize(self, data): + """Convert JSON to Persistent object.""" + oid = data["@oid"] + + object = self._deserialize_helper(data) + + object._p_oid = oid + object._p_jar = self._database + + return object + + def _deserialize_helper(self, data, create=True): + # TODO WeakRef + if data is None: + return None + elif isinstance(data, (int, float, str, bool)): + return data + elif isinstance(data, list): + return [self._deserialize_helper(value) for value in data] + elif isinstance(data, tuple): + return tuple([self._deserialize_helper(value) for value in data]) + else: + assert isinstance(data, dict), f"Data must be a list: '{type(data)}'" + + object_type = data.pop("@type", None) + if not object_type: # NOTE: A normal dict value + assert "@oid" not in data + for key, value in data.items(): + data[key] = self._deserialize_helper(value) + return data + + cls = self._get_class(object_type) + + if issubclass(cls, datetime.datetime): + assert create + value = data["@value"] + return 
datetime.datetime.fromisoformat(value) + + oid: str = data.pop("@oid", None) + if oid: + assert isinstance(oid, str) + + if "@reference" in data and data["@reference"]: # A reference + object = self._database.get_cached(oid) + if object: + return object + assert issubclass(cls, Persistent) + object = cls.__new__(cls) + self._database.new_ghost(oid, object) + return object + elif issubclass(cls, Index): + object = self._database.get_cached(oid) + if object: + return object + object = cls.__new__(cls) + object._p_oid = oid + self.set_ghost_state(object, data) + return object + + if "@value" in data: + data = data["@value"] + + if isinstance(data, dict): + for key, value in data.items(): + data[key] = self._deserialize_helper(value) + else: + data = self._deserialize_helper(data) + + if not create: + return data + + if issubclass(cls, Persistent): + object = cls.__new__(cls) + if isinstance(object, OOBTree): + data = self._to_tuple(data) + + object.__setstate__(data) + else: + assert isinstance(data, dict) + object = cls(**data) + + return object diff --git a/renku/core/incubation/graph.py b/renku/core/incubation/graph.py index 6f8a8b4097..78eb657371 100644 --- a/renku/core/incubation/graph.py +++ b/renku/core/incubation/graph.py @@ -42,7 +42,6 @@ from renku.core.models.entities import Entity from renku.core.models.jsonld import load_yaml from renku.core.models.provenance.activities import Activity -from renku.core.models.provenance.activity import ActivityCollection from renku.core.models.provenance.provenance_graph import ProvenanceGraph from renku.core.models.workflow.dependency_graph import DependencyGraph from renku.core.models.workflow.plan import Plan @@ -54,6 +53,7 @@ from renku.core.utils.shacl import validate_graph GRAPH_METADATA_PATHS = [ + Path(RENKU_HOME) / Path(RepositoryApiMixin.DATABASE_PATH), Path(RENKU_HOME) / Path(RepositoryApiMixin.DEPENDENCY_GRAPH), Path(RENKU_HOME) / Path(RepositoryApiMixin.PROVENANCE_GRAPH), Path(RENKU_HOME) / Path(DatasetsApiMixin.DATASETS_PROVENANCE), @@ -69,7 +69,7 @@ def generate_graph(): def _generate_graph(client, force=False): """Generate graph and dataset provenance metadata.""" - def process_workflows(commit: Commit, provenance_graph: ProvenanceGraph): + def process_workflows(commit: Commit): for file_ in commit.diff(commit.parents or NULL_TREE, paths=f"{client.workflow_path}/*.yaml"): # Ignore deleted files (they appear as ADDED in this backwards diff) if file_.change_type == "A": @@ -85,9 +85,7 @@ def process_workflows(commit: Commit, provenance_graph: ProvenanceGraph): continue workflow = Activity.from_yaml(path=path, client=client) - activity_collection = ActivityCollection.from_activity(workflow, client.dependency_graph, client) - - provenance_graph.add(activity_collection) + client.update_graphs(workflow) def process_datasets(commit): files_diff = list(commit.diff(commit.parents or NULL_TREE, paths=".renku/datasets/*/*.yml")) @@ -118,13 +116,11 @@ def process_datasets(commit): client.initialize_graph() client.initialize_datasets_provenance() - provenance_graph = ProvenanceGraph.from_json(client.provenance_graph_path) - for n, commit in enumerate(commits, start=1): communication.echo(f"Processing commits {n}/{n_commits}", end="\r") try: - process_workflows(commit, provenance_graph) + process_workflows(commit) process_datasets(commit) except errors.MigrationError: communication.echo("") @@ -133,10 +129,10 @@ def process_datasets(commit): communication.echo("") communication.warn(f"Cannot process commit '{commit.hexsha}' - Exception: 
{traceback.format_exc()}") - client.dependency_graph.to_json() - provenance_graph.to_json() client.datasets_provenance.to_json() + client.database.commit() + def status(): """Return a command for getting workflow graph status.""" @@ -234,7 +230,6 @@ def _export_graph(client, format, workflows_only, strict): if not client.provenance_graph_path.exists(): raise errors.ParameterError("Graph is not generated.") - pg = ProvenanceGraph.from_json(client.provenance_graph_path, lazy=True) format = format.lower() if strict and format not in ["json-ld", "jsonld"]: raise errors.SHACLValidationError(f"'--strict' not supported for '{format}'") diff --git a/renku/core/management/git.py b/renku/core/management/git.py index 169b5782f6..14afa58aaf 100644 --- a/renku/core/management/git.py +++ b/renku/core/management/git.py @@ -285,11 +285,10 @@ def commit( committer = Actor("renku {0}".format(__version__), version_url) - change_types = {} + change_types = {git_unicode_unescape(item.a_path): item.change_type for item in self.repo.index.diff(None)} if commit_only == COMMIT_DIFF_STRATEGY: # Get diff generated in command. - change_types = {git_unicode_unescape(item.a_path): item.change_type for item in self.repo.index.diff(None)} staged_after = set(change_types.keys()) modified_after_change_types = { @@ -308,7 +307,7 @@ def commit( if isinstance(commit_only, list): for path_ in commit_only: p = self.path / path_ - if p.exists() or change_types.get(path_) == "D": + if p.exists() or change_types.get(str(path_)) == "D": self.repo.git.add(path_) if not commit_only: diff --git a/renku/core/management/migrations/models/v3.py b/renku/core/management/migrations/models/v3.py index d6c8350c1f..8a935780bb 100644 --- a/renku/core/management/migrations/models/v3.py +++ b/renku/core/management/migrations/models/v3.py @@ -37,7 +37,8 @@ from renku.core.models.datasets import generate_dataset_tag_id, generate_url_id from renku.core.models.git import get_user_info from renku.core.models.projects import generate_project_id -from renku.core.models.provenance.agents import generate_person_id +from renku.core.models.provenance import agents +from renku.core.utils.urls import get_host class Base: @@ -60,24 +61,12 @@ class Person(Base): email = None name = None - @staticmethod - def _fix_person_id(person, client=None): - """Fixes the id of a Person if it is not set.""" - if not person._id or "mailto:None" in person._id or person._id.startswith("_:"): - if not client and person.client: - client = person.client - person._id = generate_person_id(client=client, email=person.email, full_identity=person.full_identity) - - return person - @classmethod def from_git(cls, git, client=None): """Create an instance from a Git repo.""" name, email = get_user_info(git) instance = cls(name=name, email=email) - - instance = Person._fix_person_id(instance, client) - + instance.fix_id(client) return instance def __init__(self, **kwargs): @@ -92,6 +81,14 @@ def full_identity(self): affiliation = f" [{self.affiliation}]" if self.affiliation else "" return f"{self.name}{email}{affiliation}" + def fix_id(self, client=None): + """Fixes the id of a Person if it is not set.""" + if not self._id or "mailto:None" in self._id or self._id.startswith("_:"): + if not client and self.client: + client = self.client + hostname = get_host(client) + self._id = agents.Person.generate_id(email=self.email, full_identity=self.full_identity, hostname=hostname) + class Project(Base): """Project migration model.""" @@ -208,8 +205,7 @@ class Meta: def make_instance(self, data, 
**kwargs): """Transform loaded dict into corresponding object.""" instance = JsonLDSchema.make_instance(self, data, **kwargs) - - instance = Person._fix_person_id(instance) + instance.fix_id(client=None) return instance diff --git a/renku/core/management/repository.py b/renku/core/management/repository.py index da6e7d1e9c..e62791e911 100644 --- a/renku/core/management/repository.py +++ b/renku/core/management/repository.py @@ -25,7 +25,7 @@ from collections import defaultdict from contextlib import contextmanager from subprocess import check_output -from typing import Union +from typing import Optional, Union import attr import filelock @@ -35,6 +35,7 @@ from renku.core import errors from renku.core.compat import Path +from renku.core.incubation.database import Database from renku.core.management.config import RENKU_HOME from renku.core.models.enums import ConfigFilter from renku.core.models.projects import Project @@ -118,6 +119,9 @@ class RepositoryApiMixin(GitCore): PROVENANCE_GRAPH = "provenance.json" """File for storing ProvenanceGraph.""" + DATABASE_PATH: str = "metadata" + """Directory for metadata storage.""" + ACTIVITY_INDEX = "activity_index.yaml" """Caches activities that generated a path.""" @@ -147,10 +151,10 @@ class RepositoryApiMixin(GitCore): _remote_cache = attr.ib(factory=dict) - _dependency_graph = None - _migration_type = attr.ib(default=MigrationType.ALL) + _database = attr.ib(default=None) + def __attrs_post_init__(self): """Initialize computed attributes.""" #: Configure Renku path. @@ -230,7 +234,7 @@ def template_checksums(self): return self.renku_path / self.TEMPLATE_CHECKSUMS @property - def provenance_graph_path(self) -> str: + def provenance_graph_path(self) -> Path: """Path to store activity files.""" return self.renku_path / self.PROVENANCE_GRAPH @@ -239,6 +243,11 @@ def dependency_graph_path(self): """Path to the dependency graph file.""" return self.renku_path / self.DEPENDENCY_GRAPH + @property + def database_path(self) -> Path: + """Path to the metadata storage directory.""" + return self.renku_path / self.DATABASE_PATH + @cached_property def cwl_prefix(self): """Return a CWL prefix.""" @@ -250,10 +259,24 @@ def dependency_graph(self): """Return dependency graph if available.""" if not self.has_graph_files(): return - if not self._dependency_graph: - self._dependency_graph = DependencyGraph.from_json(self.dependency_graph_path) + return DependencyGraph.from_database(self.database) - return self._dependency_graph + @property + def provenance_graph(self) -> Optional[ProvenanceGraph]: + """Return provenance graph if available.""" + if not self.has_graph_files(): + return + return ProvenanceGraph.from_database(self.database) + + @property + def database(self) -> Optional[Database]: + """Return metadata storage if available.""" + if not self.has_graph_files(): + return + if not self._database: + self._database = Database.from_path(path=self.database_path) + + return self._database @property def project(self): @@ -494,27 +517,45 @@ def process_and_store_run(self, command_line_tool, name, description, keywords): def update_graphs(self, activity: Union[ProcessRun, WorkflowRun]): """Update Dependency and Provenance graphs from a ProcessRun/WorkflowRun.""" if not self.has_graph_files(): - return + return None dependency_graph = DependencyGraph.from_json(self.dependency_graph_path) provenance_graph = ProvenanceGraph.from_json(self.provenance_graph_path) activity_collection = ActivityCollection.from_activity(activity, dependency_graph, self) + 
self.provenance_graph.add(activity_collection) + database = self.database + provenance_graph.add(activity_collection) + for activity in activity_collection.activities: + database.get("activities").add(activity) + database.get("plans").add(activity.association.plan) + + database.commit() dependency_graph.to_json() provenance_graph.to_json() def has_graph_files(self): """Return true if dependency or provenance graph exists.""" - return self.dependency_graph_path.exists() or self.provenance_graph_path.exists() + return self.database_path.exists() def initialize_graph(self): """Create empty graph files.""" self.dependency_graph_path.write_text("[]") self.provenance_graph_path.write_text("[]") + self.database_path.mkdir(parents=True, exist_ok=True) + + database = self.database + + from renku.core.models.provenance.activity import Activity + from renku.core.models.workflow.plan import Plan + + database.add_index(name="activities", object_type=Activity, attribute="id") + database.add_index(name="plans", object_type=Plan, attribute="id") + def remove_graph_files(self): """Remove all graph files.""" try: @@ -525,6 +566,10 @@ def remove_graph_files(self): self.provenance_graph_path.unlink() except FileNotFoundError: pass + try: + shutil.rmtree(self.database_path) + except FileNotFoundError: + pass def init_repository(self, force=False, user=None, initial_branch=None): """Initialize an empty Renku repository.""" diff --git a/renku/core/management/storage.py b/renku/core/management/storage.py index ba06e84b23..5fc59c9216 100644 --- a/renku/core/management/storage.py +++ b/renku/core/management/storage.py @@ -32,9 +32,8 @@ from werkzeug.utils import cached_property from renku.core import errors -from renku.core.models.provenance.activities import Collection +from renku.core.models.provenance.activity import Collection from renku.core.models.provenance.datasets import DatasetProvenance -from renku.core.models.provenance.provenance_graph import ProvenanceGraph from renku.core.utils import communication from renku.core.utils.file_size import parse_file_size from renku.core.utils.git import add_to_git, run_command @@ -534,6 +533,25 @@ def migrate_files_to_lfs(self, paths): def _map_checksum(entity, checksum_mapping): """Update the checksum and id of an entity based on a mapping.""" + if entity.checksum not in checksum_mapping: + return False + + new_checksum = checksum_mapping[entity.checksum] + + entity.id = entity.id.replace(entity.checksum, new_checksum) + entity.checksum = new_checksum + + if isinstance(entity, Collection) and entity.members: + for member in entity.members: + _map_checksum(member, checksum_mapping) + + return True + + def _map_checksum_old(entity, checksum_mapping): + """Update the checksum and id of an entity based on a mapping.""" + # TODO: Remove this method once moved to Entity with 'id' field + from renku.core.models.provenance.activities import Collection + if entity.checksum not in checksum_mapping: return @@ -547,30 +565,34 @@ def _map_checksum(entity, checksum_mapping): _map_checksum(member, checksum_mapping) # NOTE: Update workflow provenance - provenance_graph = ProvenanceGraph.from_json(self.provenance_graph_path) + provenance_graph = self.provenance_graph for activity in provenance_graph.activities: + changed = False if activity.generations: for generation in activity.generations: entity = generation.entity - _map_checksum(entity, sha_mapping) + changed |= _map_checksum(entity, sha_mapping) if activity.usages: for usage in activity.usages: entity = usage.entity - 
_map_checksum(entity, sha_mapping) + changed |= _map_checksum(entity, sha_mapping) if activity.invalidations: for entity in activity.invalidations: - _map_checksum(entity, sha_mapping) + changed |= _map_checksum(entity, sha_mapping) + + if changed: + activity._p_changed = True - provenance_graph.to_json() + self.database.commit() # NOTE: Update datasets provenance datasets_provenance = DatasetProvenance.from_json(self.datasets_provenance_path) for dataset in datasets_provenance.datasets: for file_ in dataset.files: - _map_checksum(file_.entity, sha_mapping) + _map_checksum_old(file_.entity, sha_mapping) datasets_provenance.to_json() diff --git a/renku/core/models/cwl/annotation.py b/renku/core/models/cwl/annotation.py index f144036e9a..6ba04c4f39 100644 --- a/renku/core/models/cwl/annotation.py +++ b/renku/core/models/cwl/annotation.py @@ -17,35 +17,18 @@ # limitations under the License. """Represent an annotation for a workflow.""" -import attr from marshmallow import EXCLUDE from renku.core.models.calamus import JsonLDSchema, dcterms, fields, oa -@attr.s(eq=False, order=False) class Annotation: """Represents a custom annotation for a research object.""" - _id = attr.ib(kw_only=True) - - body = attr.ib(default=None, kw_only=True) - - source = attr.ib(default=None, kw_only=True) - - @classmethod - def from_jsonld(cls, data): - """Create an instance from JSON-LD data.""" - if isinstance(data, cls): - return data - if not isinstance(data, dict): - raise ValueError(data) - - return AnnotationSchema().load(data) - - def as_jsonld(self): - """Create JSON-LD.""" - return AnnotationSchema().dump(self) + def __init__(self, *, id: str, body=None, source=None): + self.id = id + self.body = body + self.source = source class AnnotationSchema(JsonLDSchema): @@ -58,6 +41,6 @@ class Meta: model = Annotation unknown = EXCLUDE - _id = fields.Id(init_name="id") + id = fields.Id() body = fields.RawJsonLD(oa.hasBody) source = fields.Raw(dcterms.creator) diff --git a/renku/core/models/datasets.py b/renku/core/models/datasets.py index 970a9ba0cf..23620caea1 100644 --- a/renku/core/models/datasets.py +++ b/renku/core/models/datasets.py @@ -37,7 +37,7 @@ from renku.core.models.refs import LinkReference from renku.core.utils.datetime8601 import parse_date from renku.core.utils.doi import extract_doi, is_doi -from renku.core.utils.urls import get_slug +from renku.core.utils.urls import get_host, get_slug NoneType = type(None) @@ -597,15 +597,6 @@ def _replace_identifier(self, new_identifier): self.url = self._id self._label = self.identifier - def _get_host(self): - # Determine the hostname for the resource URIs. - # If RENKU_DOMAIN is set, it overrides the host from remote. - # Default is localhost. 
- host = "localhost" - if self.client: - host = self.client.remote.get("host") or host - return os.environ.get("RENKU_DOMAIN") or host - def _set_id(self): self._id = generate_dataset_id(client=self.client, identifier=self.identifier) @@ -618,7 +609,7 @@ def __attrs_post_init__(self): self._label = self.identifier if self.derived_from: - host = self._get_host() + host = get_host(self.client) derived_from_id = self.derived_from._id derived_from_url = self.derived_from.url.get("@id") u = urlparse(derived_from_url) diff --git a/renku/core/models/entity.py b/renku/core/models/entity.py new file mode 100644 index 0000000000..60890b573e --- /dev/null +++ b/renku/core/models/entity.py @@ -0,0 +1,126 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2018-2021- Swiss Data Science Center (SDSC) +# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and +# Eidgenössische Technische Hochschule Zürich (ETHZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Represent provenance entities.""" + +from pathlib import Path +from typing import List, Union +from urllib.parse import quote + +from renku.core.models.calamus import JsonLDSchema, Nested, fields, prov, renku, wfprov +from renku.core.utils.git import get_object_hash + + +class Entity: + """Represent a file.""" + + def __init__(self, *, checksum: str, id: str = None, path: Union[Path, str]): + assert id is None or isinstance(id, str) + + self.id: str = id or Entity.generate_id(checksum, path) + self.path: Path = path + self.checksum: str = checksum + + @staticmethod + def generate_id(checksum: str, path: Union[Path, str]) -> str: + """Generate an Entity identifier.""" + quoted_path = quote(str(path).strip("/")) + + return f"/entities/{checksum}/{quoted_path}" + + @classmethod + def from_revision(cls, client, path: Union[Path, str], revision: str = "HEAD", find_previous: bool = True): + """Return dependency from given path and revision.""" + if find_previous: + revision = client.find_previous_commit(path, revision=revision) + + client, commit, path = client.resolve_in_submodules(revision, path) + + checksum = get_object_hash(repo=client.repo, revision=revision, path=path) + # TODO: What if checksum is None + # TODO: What would be checksum for a directory if it's not committed yet. + id = cls.generate_id(checksum=checksum, path=path) + + absolute_path = client.path / path + if str(path) != "." 
and absolute_path.is_dir(): + members = cls.get_directory_members(client, commit, absolute_path) + entity = Collection(id=id, checksum=checksum, path=path, members=members) + else: + entity = cls(id=id, checksum=checksum, path=path) + + return entity + + @classmethod + def get_directory_members(cls, client, commit, absolute_path: Path) -> List["Entity"]: + """Return first-level files/directories in a directory.""" + files_in_commit = commit.stats.files + members: List[Entity] = [] + + for member in absolute_path.iterdir(): + if member.name == ".gitkeep": + continue + + member_path = str(member.relative_to(client.path)) + find_previous = True + + if member_path in files_in_commit: + # we already know the newest commit, no need to look it up + find_previous = False + + try: + assert all(member_path != m.path for m in members) + + members.append(cls.from_revision(client, member_path, commit, find_previous=find_previous)) + except KeyError: + pass + + return members + + +class Collection(Entity): + """Represent a directory with files.""" + + def __init__(self, *, checksum: str, id: str = None, path: Union[Path, str], members: List[Entity] = None): + super().__init__(id=id, checksum=checksum, path=path) + self.members: List[Entity] = members or [] + + +class NewEntitySchema(JsonLDSchema): + """Entity Schema.""" + + class Meta: + """Meta class.""" + + # NOTE: wfprov.Artifact is not removed for compatibility with older project + rdf_type = [prov.Entity, wfprov.Artifact] + model = Entity + + checksum = fields.String(renku.checksum, missing=None) + id = fields.Id() + path = fields.String(prov.atLocation) + + +class NewCollectionSchema(NewEntitySchema): + """Entity Schema.""" + + class Meta: + """Meta class.""" + + rdf_type = prov.Collection + model = Collection + + members = Nested(prov.hadMember, [NewEntitySchema, "NewCollectionSchema"], many=True) diff --git a/renku/core/models/provenance/activity.py b/renku/core/models/provenance/activity.py index e55e344e12..6329852df5 100644 --- a/renku/core/models/provenance/activity.py +++ b/renku/core/models/provenance/activity.py @@ -17,19 +17,18 @@ # limitations under the License. 
"""Represent an execution of a Plan.""" -import pathlib from datetime import datetime -from pathlib import Path from typing import List, Optional, Union -from urllib.parse import quote, urlparse +from urllib.parse import urlparse from uuid import uuid4 -from git import Git, GitCommandError from marshmallow import EXCLUDE +from renku.core.incubation.database import Persistent +from renku.core.models import entities as old_entities from renku.core.models.calamus import JsonLDSchema, Nested, fields, oa, prov, renku from renku.core.models.cwl.annotation import Annotation, AnnotationSchema -from renku.core.models.entities import Collection, CollectionSchema, Entity, EntitySchema +from renku.core.models.entity import Collection, Entity, NewCollectionSchema, NewEntitySchema from renku.core.models.provenance import qualified as old_qualified from renku.core.models.provenance.activities import ProcessRun, WorkflowRun from renku.core.models.provenance.agents import Person, PersonSchema, SoftwareAgent, SoftwareAgentSchema @@ -42,7 +41,7 @@ ) from renku.core.models.workflow.dependency_graph import DependencyGraph from renku.core.models.workflow.plan import Plan, PlanSchema -from renku.core.utils.urls import get_host +from renku.core.utils.git import get_object_hash class Association: @@ -53,6 +52,11 @@ def __init__(self, *, agent: Union[Person, SoftwareAgent] = None, id: str, plan: self.id: str = id self.plan: Plan = plan + @staticmethod + def generate_id(activity_id: str) -> str: + """Generate a Association identifier.""" + return f"{activity_id}/association" # TODO: Does it make sense to use plural name here? + class Usage: """Represent a dependent path.""" @@ -61,6 +65,11 @@ def __init__(self, *, entity: Union[Collection, Entity], id: str): self.entity: Union[Collection, Entity] = entity self.id: str = id + @staticmethod + def generate_id(activity_id: str) -> str: + """Generate a Usage identifier.""" + return f"{activity_id}/usages/{uuid4()}" + class Generation: """Represent an act of generating a path.""" @@ -69,8 +78,13 @@ def __init__(self, *, entity: Union[Collection, Entity], id: str): self.entity: Union[Collection, Entity] = entity self.id: str = id + @staticmethod + def generate_id(activity_id: str) -> str: + """Generate a Generation identifier.""" + return f"{activity_id}/generations/{uuid4()}" + -class Activity: +class Activity(Persistent): """Represent an activity in the repository.""" def __init__( @@ -108,17 +122,18 @@ def __init__( @classmethod def from_process_run(cls, process_run: ProcessRun, plan: Plan, client, order: Optional[int] = None): """Create an Activity from a ProcessRun.""" - hostname = get_host(client) - activity_id = Activity.generate_id(hostname) + activity_id = Activity.generate_id() - association = Association(agent=process_run.association.agent, id=f"{activity_id}/association", plan=plan) + association = Association( + agent=process_run.association.agent, id=Association.generate_id(activity_id), plan=plan + ) # NOTE: The same entity can have the same id during different times in its lifetime (e.g. different commit_sha, # but the same content). When it gets flattened, some fields will have multiple values which will cause an error # during deserialization. Make sure that no such Entity attributes exists (store those information in the # Generation object). 
- invalidations = [_convert_invalidated_entity(e, activity_id, client) for e in process_run.invalidated] + invalidations = [_convert_invalidated_entity(e, client) for e in process_run.invalidated] generations = [_convert_generation(g, activity_id, client) for g in process_run.generated] usages = [_convert_usage(u, activity_id, client) for u in process_run.qualified_usage] @@ -140,10 +155,10 @@ def from_process_run(cls, process_run: ProcessRun, plan: Plan, client, order: Op ) @staticmethod - def generate_id(hostname: str) -> str: + def generate_id() -> str: """Generate an identifier for an activity.""" # TODO: make id generation idempotent - return f"https://{hostname}/activities/{uuid4()}" + return f"/activities/{uuid4()}" def _convert_usage(usage: old_qualified.Usage, activity_id: str, client) -> Usage: @@ -152,9 +167,7 @@ def _convert_usage(usage: old_qualified.Usage, activity_id: str, client) -> Usag entity = _convert_used_entity(usage.entity, commit_sha, activity_id, client) assert entity, f"Top entity was not found for Usage: {usage._id}, {usage.entity.path}" - id = f"{activity_id}/usages/{uuid4()}" - - return Usage(id=id, entity=entity) + return Usage(id=Usage.generate_id(activity_id), entity=entity) def _convert_generation(generation: old_qualified.Generation, activity_id: str, client) -> Generation: @@ -163,45 +176,43 @@ def _convert_generation(generation: old_qualified.Generation, activity_id: str, entity = _convert_generated_entity(generation.entity, commit_sha, activity_id, client) assert entity, f"Root entity was not found for Generation: {generation._id}" - id = f"{activity_id}/generations/{uuid4()}" + return Generation(id=Generation.generate_id(activity_id), entity=entity) - return Generation(id=id, entity=entity) - -def _convert_used_entity(entity: Entity, revision: str, activity_id: str, client) -> Optional[Entity]: - """Convert an Entity to one with proper metadata. +def _convert_used_entity(entity: old_entities.Entity, revision: str, activity_id: str, client) -> Optional[Entity]: + """Convert an old Entity to one with proper metadata. For Collections, add members that are modified in the same commit or before the revision. """ - assert isinstance(entity, Entity) + assert isinstance(entity, old_entities.Entity) - checksum = _get_object_hash(revision=revision, path=entity.path, client=client) + checksum = get_object_hash(repo=client.repo, revision=revision, path=entity.path) if not checksum: return None - id_ = _generate_entity_id(entity_checksum=checksum, path=entity.path, activity_id=activity_id) - - if isinstance(entity, Collection): - new_entity = Collection(id=id_, checksum=checksum, path=entity.path) + if isinstance(entity, old_entities.Collection): + members = [] for child in entity.members: new_child = _convert_used_entity(child, revision, activity_id, client) if not new_child: continue - new_entity.members.append(new_child) + members.append(new_child) + + new_entity = Collection(checksum=checksum, path=entity.path, members=members) else: - new_entity = Entity(id=id_, checksum=checksum, path=entity.path) + new_entity = Entity(checksum=checksum, path=entity.path) - assert type(new_entity) is type(entity) + assert new_entity.__class__.__name__ == entity.__class__.__name__ return new_entity -def _convert_generated_entity(entity: Entity, revision: str, activity_id: str, client) -> Optional[Entity]: +def _convert_generated_entity(entity: old_entities.Entity, revision: str, activity_id: str, client) -> Optional[Entity]: """Convert an Entity to one with proper metadata. 
For Collections, add members that are modified in the same commit as revision. """ - assert isinstance(entity, Entity) + assert isinstance(entity, old_entities.Entity) try: entity_commit = client.find_previous_commit(paths=entity.path, revision=revision) @@ -210,77 +221,48 @@ def _convert_generated_entity(entity: Entity, revision: str, activity_id: str, c if entity_commit.hexsha != revision: return None - checksum = _get_object_hash(revision=revision, path=entity.path, client=client) + checksum = get_object_hash(repo=client.repo, revision=revision, path=entity.path) if not checksum: return None - id_ = _generate_entity_id(entity_checksum=checksum, path=entity.path, activity_id=activity_id) - - if isinstance(entity, Collection): - new_entity = Collection(id=id_, checksum=checksum, path=entity.path) + if isinstance(entity, old_entities.Collection): + members = [] for child in entity.members: new_child = _convert_generated_entity(child, revision, activity_id, client) if not new_child: continue - new_entity.members.append(new_child) + members.append(new_child) + + new_entity = Collection(checksum=checksum, path=entity.path, members=members) else: - new_entity = Entity(id=id_, checksum=checksum, path=entity.path) + new_entity = Entity(checksum=checksum, path=entity.path) - assert type(new_entity) is type(entity) + assert new_entity.__class__.__name__ == entity.__class__.__name__ return new_entity -def _convert_invalidated_entity(entity: Entity, activity_id: str, client) -> Optional[Entity]: +def _convert_invalidated_entity(entity: old_entities.Entity, client) -> Optional[Entity]: """Convert an Entity to one with proper metadata.""" - assert isinstance(entity, Entity) - assert not isinstance(entity, Collection), f"Collection passed as invalidated: {entity._id}" + assert isinstance(entity, old_entities.Entity) + assert not isinstance(entity, old_entities.Collection), f"Collection passed as invalidated: {entity._id}" commit_sha = _extract_commit_sha(entity_id=entity._id) commit = client.find_previous_commit(revision=commit_sha, paths=entity.path) - commit_sha = commit.hexsha - checksum = _get_object_hash(revision=commit_sha, path=entity.path, client=client) + revision = commit.hexsha + checksum = get_object_hash(repo=client.repo, revision=revision, path=entity.path) if not checksum: - # Entity was deleted at commit_sha; get the one before it to have object_id - checksum = _get_object_hash(revision=f"{commit_sha}~", path=entity.path, client=client) + # Entity was deleted at revision; get the one before it to have object_id + checksum = get_object_hash(repo=client.repo, revision=f"{revision}~", path=entity.path) if not checksum: - print(f"Cannot find invalidated entity hash for {entity._id} at {commit_sha}:{entity.path}") + print(f"Cannot find invalidated entity hash for {entity._id} at {revision}:{entity.path}") return - id_ = _generate_entity_id(entity_checksum=checksum, path=entity.path, activity_id=activity_id) - new_entity = Entity(id=id_, checksum=checksum, path=entity.path) - assert type(new_entity) is type(entity) - - return new_entity - - -def _generate_entity_id(entity_checksum: str, path: Union[Path, str], activity_id: str) -> str: - quoted_path = quote(str(path)) - path = pathlib.posixpath.join("blob", entity_checksum, quoted_path) - - return urlparse(activity_id)._replace(path=path).geturl() + new_entity = Entity(checksum=checksum, path=entity.path) + assert new_entity.__class__.__name__ == entity.__class__.__name__ -def _get_object_hash(revision: str, path: Union[Path, str], client): - 
try: - return client.repo.git.rev_parse(f"{revision}:{str(path)}") - except GitCommandError: - # NOTE: The file can be in a submodule or it was not there when the command ran but was there when workflows - # were migrated (this can happen only for Usage); the project might be broken too. - return _get_object_hash_from_submodules(path, client) - - -def _get_object_hash_from_submodules(path: Union[Path, str], client) -> str: - for submodule in client.repo.submodules: - try: - path_in_submodule = Path(path).relative_to(submodule.path) - except ValueError: - continue - else: - try: - return Git(submodule.abspath).rev_parse(f"HEAD:{str(path_in_submodule)}") - except GitCommandError: - pass + return new_entity def _extract_commit_sha(entity_id: str) -> str: @@ -358,7 +340,7 @@ def get_process_runs(workflow_run: WorkflowRun) -> List[ProcessRun]: assert len(run.subprocesses) == 1, f"Run in ProcessRun has multiple steps: {run._id}" run = run.subprocesses[0] - plan = Plan.from_run(run=run, hostname=get_host(client)) + plan = Plan.from_run(run=run) plan = dependency_graph.add(plan) activity = Activity.from_process_run(process_run=process_run, plan=plan, client=client) @@ -398,7 +380,7 @@ class Meta: id = fields.Id() # TODO: DatasetSchema, DatasetFileSchema - entity = Nested(prov.entity, [EntitySchema, CollectionSchema]) + entity = Nested(prov.entity, [NewEntitySchema, NewCollectionSchema]) class GenerationSchema(JsonLDSchema): @@ -413,7 +395,7 @@ class Meta: id = fields.Id() # TODO: DatasetSchema, DatasetFileSchema - entity = Nested(prov.qualifiedGeneration, [EntitySchema, CollectionSchema], reverse=True) + entity = Nested(prov.qualifiedGeneration, [NewEntitySchema, NewCollectionSchema], reverse=True) class ActivitySchema(JsonLDSchema): @@ -432,7 +414,7 @@ class Meta: ended_at_time = fields.DateTime(prov.endedAtTime, add_value_types=True) generations = Nested(prov.activity, GenerationSchema, reverse=True, many=True, missing=None) id = fields.Id() - invalidations = Nested(prov.wasInvalidatedBy, EntitySchema, reverse=True, many=True, missing=None) + invalidations = Nested(prov.wasInvalidatedBy, NewEntitySchema, reverse=True, many=True, missing=None) order = fields.Integer(renku.order) parameters = Nested( renku.parameter, diff --git a/renku/core/models/provenance/agents.py b/renku/core/models/provenance/agents.py index 4d4fe0ff07..00a3c59e45 100644 --- a/renku/core/models/provenance/agents.py +++ b/renku/core/models/provenance/agents.py @@ -17,51 +17,75 @@ # limitations under the License. 
"""Represent provenance agents.""" -import os -import pathlib import re -import urllib import uuid from urllib.parse import quote -import attr -from attr.validators import instance_of from calamus.schema import JsonLDSchema from marshmallow import EXCLUDE from renku.core.models.calamus import StringList, fields, prov, rdfs, schema, wfprov from renku.core.models.git import get_user_info +from renku.core.utils.urls import get_host from renku.version import __version__, version_url -@attr.s(slots=True) class Person: """Represent a person.""" - client = attr.ib(default=None, kw_only=True) + def __init__( + self, + *, + affiliation: str = None, + alternate_name: str = None, + email: str = None, + id: str = None, + label: str = None, + name: str, + ): + self.validate_email(email) + + self.affiliation: str = affiliation + self.alternate_name: str = alternate_name + self.email: str = email + self.id: str = id + self.label: str = label or name + self.name: str = name - name = attr.ib(kw_only=True, validator=instance_of(str)) - email = attr.ib(default=None, kw_only=True) - label = attr.ib(kw_only=True) - affiliation = attr.ib(default=None, kw_only=True) - alternate_name = attr.ib(default=None, kw_only=True) - _id = attr.ib(default=None, kw_only=True) + # handle the case where ids were improperly set + if self.id == "mailto:None" or not self.id or self.id.startswith("_:"): + self.id = Person.generate_id(self.email, self.full_identity, hostname=get_host(client=None)) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, Person): + return False + return self.id == other.id and self.full_identity == other.full_identity + + def __hash__(self): + return hash((self.id, self.full_identity)) + + @staticmethod + def generate_id(email, full_identity, hostname): + """Generate identifier for Person.""" + if email: + return f"mailto:{email}" - def default_id(self): - """Set the default id.""" - return generate_person_id(email=self.email, client=self.client, full_identity=self.full_identity) + id = full_identity or str(uuid.uuid4()) + id = quote(id, safe="") - @email.validator - def check_email(self, attribute, value): + # TODO: Remove hostname part once migrating to new metadata + return f"https://{hostname}/persons/{id}" + + @staticmethod + def validate_email(email): """Check that the email is valid.""" - if self.email and not (isinstance(value, str) and re.match(r"[^@]+@[^@]+\.[^@]+", value)): + if not email: + return + if not isinstance(email, str) or not re.match(r"[^@]+@[^@]+\.[^@]+", email): raise ValueError("Email address is invalid.") - @label.default - def default_label(self): - """Set the default label.""" - return self.name - @classmethod def from_commit(cls, commit): """Create an instance from a Git commit.""" @@ -91,7 +115,7 @@ def full_identity(self): def from_git(cls, git): """Create an instance from a Git repo.""" name, email = get_user_info(git) - return cls(name=name, email=email) + return cls(email=email, name=name) @classmethod def from_string(cls, string): @@ -104,12 +128,12 @@ def from_string(cls, string): affiliation = affiliation.strip() affiliation = affiliation or None - return cls(name=name, email=email, affiliation=affiliation) + return cls(affiliation=affiliation, email=email, name=name) @classmethod - def from_dict(cls, obj): + def from_dict(cls, data): """Create and instance from a dictionary.""" - return cls(**obj) + return cls(**data) @classmethod def from_jsonld(cls, data): @@ -121,15 +145,6 @@ def from_jsonld(cls, data): return 
PersonSchema().load(data) - def __attrs_post_init__(self): - """Finish object initialization.""" - # handle the case where ids were improperly set - if self._id == "mailto:None" or not self._id or self._id.startswith("_:"): - self._id = self.default_id() - - if self.label is None: - self.label = self.default_label() - class PersonSchema(JsonLDSchema): """Person schema.""" @@ -141,67 +156,46 @@ class Meta: model = Person unknown = EXCLUDE - name = StringList(schema.name, missing=None) - email = fields.String(schema.email, missing=None) - label = StringList(rdfs.label, missing=None) affiliation = StringList(schema.affiliation, missing=None) alternate_name = StringList(schema.alternateName, missing=None) - _id = fields.Id(init_name="id") + email = fields.String(schema.email, missing=None) + id = fields.Id() + label = StringList(rdfs.label, missing=None) + name = StringList(schema.name, missing=None) -@attr.s(frozen=True, slots=True) class SoftwareAgent: """Represent executed software.""" - label = attr.ib(kw_only=True) + def __init__(self, *, id: str, label: str): + self.id: str = id + self.label: str = label + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, SoftwareAgent): + return False + return self.id == other.id and self.label == other.label - _id = attr.ib(kw_only=True) + def __hash__(self): + return hash((self.id, self.label)) @classmethod def from_commit(cls, commit): """Create an instance from a Git commit.""" + # FIXME: This method can return a Person object but SoftwareAgent is not its super class author = Person.from_commit(commit) if commit.author != commit.committer: return cls(label=commit.committer.name, id=commit.committer.email) return author - @classmethod - def from_jsonld(cls, data): - """Create an instance from JSON-LD data.""" - if isinstance(data, cls): - return data - if not isinstance(data, dict): - raise ValueError(data) - - return SoftwareAgentSchema().load(data) - - def as_jsonld(self): - """Create JSON-LD.""" - return SoftwareAgentSchema().dump(self) - # set up the default agent renku_agent = SoftwareAgent(label="renku {0}".format(__version__), id=version_url) -def generate_person_id(client, email, full_identity): - """Generate Person default id.""" - if email: - return "mailto:{email}".format(email=email) - - host = "localhost" - if client: - host = client.remote.get("host") or host - host = os.environ.get("RENKU_DOMAIN") or host - - id_ = full_identity or str(uuid.uuid4()) - - return urllib.parse.urljoin( - "https://{host}".format(host=host), pathlib.posixpath.join("/persons", quote(id_, safe="")) - ) - - class SoftwareAgentSchema(JsonLDSchema): """SoftwareAgent schema.""" @@ -213,4 +207,4 @@ class Meta: unknown = EXCLUDE label = fields.String(rdfs.label) - _id = fields.Id(init_name="id") + id = fields.Id() diff --git a/renku/core/models/provenance/provenance_graph.py b/renku/core/models/provenance/provenance_graph.py index fdaf2ccc93..7ecced3ded 100644 --- a/renku/core/models/provenance/provenance_graph.py +++ b/renku/core/models/provenance/provenance_graph.py @@ -24,6 +24,7 @@ from marshmallow import EXCLUDE from rdflib import ConjunctiveGraph +from renku.core.incubation.database import Database from renku.core.models.calamus import JsonLDSchema, Nested, schema from renku.core.models.provenance.activity import Activity, ActivityCollection, ActivitySchema @@ -37,7 +38,8 @@ def __init__(self, activities: List[Activity] = None): self._custom_bindings: Dict[str, str] = {} self._graph: Optional[ConjunctiveGraph] = None 
self._loaded: bool = False - self._order: int = 1 if len(self.activities) == 0 else max([a.order for a in self.activities]) + 1 + # TODO: Remove _order and rely on Activity's ended_at_time and started_at_time for ordering + self._order: int = len(self.activities) + 1 self._path: Optional[Path] = None @property @@ -52,8 +54,6 @@ def custom_bindings(self, custom_bindings: Dict[str, str]): def add(self, node: Union[Activity, ActivityCollection]) -> None: """Add an Activity/ActivityCollection to the graph.""" - assert self._loaded - activity_collection = node if isinstance(node, ActivityCollection) else ActivityCollection(activities=[node]) for activity in activity_collection.activities: @@ -62,6 +62,18 @@ def add(self, node: Union[Activity, ActivityCollection]) -> None: self._order += 1 self.activities.append(activity) + self._p_changed = True + + @classmethod + def from_database(cls, database: Database) -> "ProvenanceGraph": + """Return an instance from a metadata database.""" + activity_tree = database.get("activities") + activities = list(activity_tree.values()) + self = ProvenanceGraph(activities=activities) + # NOTE: If we sort then all ghost objects will be loaded which is not what we want + # self.activities.sort(key=lambda e: e.order) + return self + @classmethod def from_json(cls, path: Union[Path, str], lazy: bool = False) -> "ProvenanceGraph": """Return an instance from a JSON file.""" diff --git a/renku/core/models/workflow/dependency_graph.py b/renku/core/models/workflow/dependency_graph.py index 433d6f3880..ee9ad1c18b 100644 --- a/renku/core/models/workflow/dependency_graph.py +++ b/renku/core/models/workflow/dependency_graph.py @@ -25,6 +25,7 @@ import networkx from marshmallow import EXCLUDE +from renku.core.incubation.database import Database from renku.core.models.calamus import JsonLDSchema, Nested, schema from renku.core.models.workflow.plan import Plan, PlanSchema @@ -39,9 +40,27 @@ def __init__(self, plans: List[Plan] = None): self._plans: List[Plan] = plans or [] self._path = None - self._graph = networkx.DiGraph() - self._graph.add_nodes_from(self._plans) - self._connect_all_nodes() + # NOTE: If we connect nodes then all ghost objects will be loaded which is not what we want + self._graph = None + + @classmethod + def from_database(cls, database: Database) -> "DependencyGraph": + """Return an instance from a metadata database.""" + plan_tree = database.get("plans") + plans = list(plan_tree.values()) + self = DependencyGraph(plans=plans) + + return self + + @property + def graph(self) -> networkx.DiGraph: + """A networkx.DiGraph containing all plans.""" + if not self._graph: + self._graph = networkx.DiGraph() + self._graph.add_nodes_from(self._plans) + self._connect_all_nodes() + + return self._graph @property def plans(self) -> List[Plan]: @@ -64,7 +83,7 @@ def add(self, plan: Plan) -> Plan: self._add_helper(plan) # FIXME some existing projects have cyclic dependency; make this check outside this model. 
- # assert networkx.algorithms.dag.is_directed_acyclic_graph(self._graph) + # assert networkx.algorithms.dag.is_directed_acyclic_graph(self.graph) return plan @@ -77,15 +96,15 @@ def _find_similar_plan(self, plan: Plan) -> Optional[Plan]: def _add_helper(self, plan: Plan): self._plans.append(plan) - self._graph.add_node(plan) + self.graph.add_node(plan) self._connect_node_to_others(node=plan) def _connect_all_nodes(self): - for node in self._graph: + for node in self.graph: self._connect_node_to_others(node) def _connect_node_to_others(self, node: Plan): - for other_node in self._graph: + for other_node in self.graph: self._connect_two_nodes(from_=node, to_=other_node) self._connect_two_nodes(from_=other_node, to_=node) @@ -93,19 +112,19 @@ def _connect_two_nodes(self, from_: Plan, to_: Plan): for o in from_.outputs: for i in to_.inputs: if DependencyGraph._is_super_path(o.default_value, i.default_value): - self._graph.add_edge(from_, to_, name=o.default_value) + self.graph.add_edge(from_, to_, name=o.default_value) def visualize_graph(self): """Visualize graph using matplotlib.""" - networkx.draw(self._graph, with_labels=True, labels={n: n.name for n in self._graph.nodes}) + networkx.draw(self.graph, with_labels=True, labels={n: n.name for n in self.graph.nodes}) - pos = networkx.spring_layout(self._graph) - edge_labels = networkx.get_edge_attributes(self._graph, "name") - networkx.draw_networkx_edge_labels(self._graph, pos=pos, edge_labels=edge_labels) + pos = networkx.spring_layout(self.graph) + edge_labels = networkx.get_edge_attributes(self.graph, "name") + networkx.draw_networkx_edge_labels(self.graph, pos=pos, edge_labels=edge_labels) def to_png(self, path): """Create a PNG image from graph.""" - networkx.drawing.nx_pydot.to_pydot(self._graph).write_png(path) + networkx.drawing.nx_pydot.to_pydot(self.graph).write_png(path) @staticmethod def _is_super_path(parent, child): @@ -117,7 +136,7 @@ def get_dependent_paths(self, plan_id, path): """Get a list of downstream paths.""" nodes = deque() node: Plan - for node in self._graph: + for node in self.graph: if plan_id == node.id and any(self._is_super_path(path, p.default_value) for p in node.inputs): nodes.append(node) @@ -129,7 +148,7 @@ def get_dependent_paths(self, plan_id, path): outputs_paths = [o.default_value for o in node.outputs] paths.update(outputs_paths) - nodes.extend(self._graph.successors(node)) + nodes.extend(self.graph.successors(node)) return paths @@ -146,13 +165,13 @@ def node_has_deleted_inputs(node_): nodes_with_deleted_inputs = set() node: Plan for plan_id, path, _ in modified_usages: - for node in self._graph: + for node in self.graph: if plan_id == node.id and any(self._is_super_path(path, p.default_value) for p in node.inputs): nodes.add(node) - nodes.update(networkx.algorithms.dag.descendants(self._graph, node)) + nodes.update(networkx.algorithms.dag.descendants(self.graph, node)) sorted_nodes = [] - for node in networkx.algorithms.dag.topological_sort(self._graph): + for node in networkx.algorithms.dag.topological_sort(self.graph): if node in nodes: if node_has_deleted_inputs(node): nodes_with_deleted_inputs.add(node) diff --git a/renku/core/models/workflow/parameter.py b/renku/core/models/workflow/parameter.py index 2497ab38ce..c999a05887 100644 --- a/renku/core/models/workflow/parameter.py +++ b/renku/core/models/workflow/parameter.py @@ -25,12 +25,28 @@ from marshmallow import EXCLUDE from renku.core.models.calamus import JsonLDSchema, Nested, fields, rdfs, renku, schema -from 
renku.core.models.workflow.parameters import MappedIOStream, MappedIOStreamSchema from renku.core.utils.urls import get_slug RANDOM_ID_LENGTH = 4 +class MappedIOStream: + """Represents an IO stream (stdin, stdout, stderr).""" + + STREAMS = ["stdin", "stdout", "stderr"] + + def __init__(self, *, id: str = None, stream_type: str): + assert stream_type in MappedIOStream.STREAMS + + self.id: str = id or MappedIOStream.generate_id(stream_type) + self.stream_type = stream_type + + @staticmethod + def generate_id(stream_type: str) -> str: + """Generate an id for parameters.""" + return f"/iostreams/{stream_type}" + + class CommandParameterBase: """Represents a parameter for a Plan.""" @@ -61,9 +77,9 @@ def __init__( @staticmethod def _generate_id(plan_id: str, parameter_type: str, position: Optional[int], postfix: str = None) -> str: """Generate an id for parameters.""" - # https://localhost/plans/723fd784-9347-4081-84de-a6dbb067545b/inputs/1 - # https://localhost/plans/723fd784-9347-4081-84de-a6dbb067545b/inputs/stdin - # https://localhost/plans/723fd784-9347-4081-84de-a6dbb067545b/inputs/dda5fcbf-0098-4917-be46-dc12f5f7b675 + # /plans/723fd784-9347-4081-84de-a6dbb067545b/inputs/1 + # /plans/723fd784-9347-4081-84de-a6dbb067545b/inputs/stdin + # /plans/723fd784-9347-4081-84de-a6dbb067545b/inputs/dda5fcbf-0098-4917-be46-dc12f5f7b675 position = str(position) if position is not None else str(uuid4()) postfix = urllib.parse.quote(postfix) if postfix else position return f"{plan_id}/{parameter_type}/{postfix}" @@ -222,6 +238,20 @@ def _get_default_name(self) -> str: return self._generate_name(base="output") +class MappedIOStreamSchema(JsonLDSchema): + """MappedIOStream schema.""" + + class Meta: + """Meta class.""" + + rdf_type = renku.IOStream + model = MappedIOStream + unknown = EXCLUDE + + id = fields.Id() + stream_type = fields.String(renku.streamType) + + class CommandParameterBaseSchema(JsonLDSchema): """CommandParameterBase schema.""" diff --git a/renku/core/models/workflow/plan.py b/renku/core/models/workflow/plan.py index ee299d639a..c6c435f7ef 100644 --- a/renku/core/models/workflow/plan.py +++ b/renku/core/models/workflow/plan.py @@ -27,6 +27,7 @@ from marshmallow import EXCLUDE from werkzeug.utils import secure_filename +from renku.core.incubation.database import Persistent from renku.core.models.calamus import JsonLDSchema, Nested, fields, prov, renku, schema from renku.core.models.entities import Entity from renku.core.models.workflow import parameters as old_parameter @@ -37,6 +38,7 @@ CommandOutputSchema, CommandParameter, CommandParameterSchema, + MappedIOStream, ) from renku.core.models.workflow.run import Run from renku.core.utils.urls import get_host @@ -44,7 +46,7 @@ MAX_GENERATED_NAME_LENGTH = 25 -class Plan: +class Plan(Persistent): """Represent a `renku run` execution template.""" def __init__( @@ -75,11 +77,8 @@ def __init__( if not self.name: self.name = self._get_default_name() - def __repr__(self): - return self.name - @classmethod - def from_run(cls, run: Run, hostname: str): + def from_run(cls, run: Run): """Create a Plan from a Run.""" assert not run.subprocesses, f"Cannot create a Plan from a Run with subprocesses: {run._id}" @@ -88,7 +87,7 @@ def extract_run_uuid(run_id: str) -> str: return run_id.rstrip("/").rsplit("/", maxsplit=1)[-1] uuid = extract_run_uuid(run._id) - plan_id = cls.generate_id(hostname=hostname, uuid=uuid) + plan_id = cls.generate_id(uuid=uuid) def convert_argument(argument: old_parameter.CommandArgument) -> CommandParameter: """Convert an old 
CommandArgument to a new CommandParameter.""" @@ -108,12 +107,16 @@ def convert_input(input: old_parameter.CommandInput) -> CommandInput: """Convert an old CommandInput to a new CommandInput.""" assert isinstance(input, old_parameter.CommandInput) + mapped_to = input.mapped_to + if mapped_to: + mapped_to = MappedIOStream(stream_type=mapped_to.stream_type) + return CommandInput( default_value=input.consumes.path, description=input.description, id=CommandInput.generate_id(plan_id=plan_id, postfix=PurePosixPath(input._id).name), label=None, - mapped_to=input.mapped_to, + mapped_to=mapped_to, name=input.name, position=input.position, prefix=input.prefix, @@ -123,13 +126,17 @@ def convert_output(output: old_parameter.CommandOutput) -> CommandOutput: """Convert an old CommandOutput to a new CommandOutput.""" assert isinstance(output, old_parameter.CommandOutput) + mapped_to = output.mapped_to + if mapped_to: + mapped_to = MappedIOStream(stream_type=mapped_to.stream_type) + return CommandOutput( create_folder=output.create_folder, default_value=output.produces.path, description=output.description, id=CommandOutput.generate_id(plan_id=plan_id, postfix=PurePosixPath(output._id).name), label=None, - mapped_to=output.mapped_to, + mapped_to=mapped_to, name=output.name, position=output.position, prefix=output.prefix, @@ -148,10 +155,10 @@ def convert_output(output: old_parameter.CommandOutput) -> CommandOutput: ) @staticmethod - def generate_id(hostname: str, uuid: str) -> str: + def generate_id(uuid: str) -> str: """Generate an identifier for Plan.""" uuid = uuid or str(uuid4()) - return f"https://{hostname}/plans/{uuid}" + return f"/plans/{uuid}" def _get_default_name(self) -> str: name = "-".join(str(a) for a in self.to_argv()) @@ -236,24 +243,32 @@ def convert_parameter(argument: CommandParameter) -> old_parameter.CommandArgume ) def convert_input(input: CommandInput) -> old_parameter.CommandInput: + mapped_to = input.mapped_to + if mapped_to: + mapped_to = old_parameter.MappedIOStream(id=mapped_to.id, stream_type=mapped_to.stream_type) + return old_parameter.CommandInput( consumes=get_entity(input.default_value), description=input.description, id=input.id.replace(self.id, run_id), label=None, - mapped_to=input.mapped_to, + mapped_to=mapped_to, name=input.name, position=input.position, prefix=input.prefix, ) def convert_output(output: CommandOutput) -> old_parameter.CommandOutput: + mapped_to = output.mapped_to + if mapped_to: + mapped_to = old_parameter.MappedIOStream(id=mapped_to.id, stream_type=mapped_to.stream_type) + return old_parameter.CommandOutput( create_folder=output.create_folder, description=output.description, id=output.id.replace(self.id, run_id), label=None, - mapped_to=output.mapped_to, + mapped_to=mapped_to, name=output.name, position=output.position, prefix=output.prefix, diff --git a/renku/core/utils/git.py b/renku/core/utils/git.py index 3fa204179d..c50438e4f8 100644 --- a/renku/core/utils/git.py +++ b/renku/core/utils/git.py @@ -20,7 +20,11 @@ import math import pathlib import urllib +from pathlib import Path from subprocess import SubprocessError, run +from typing import Optional, Union + +from git import Git, GitCommandError, Repo from renku.core import errors from renku.core.models.git import GitURL @@ -105,3 +109,26 @@ def get_renku_repo_url(remote_url, deployment_hostname=None, access_token=None): hostname = deployment_hostname or parsed_remote.hostname return urllib.parse.urljoin(f"https://{credentials}{hostname}", path) + + +def get_object_hash(repo: Repo, revision: 
str, path: Union[Path, str]) -> Optional[str]: + """Return git hash of an object in a Repo or its submodule.""" + + def get_object_hash_from_submodules() -> Optional[str]: + for submodule in repo.submodules: + try: + path_in_submodule = Path(path).relative_to(submodule.path) + except ValueError: + continue + else: + try: + return Git(submodule.abspath).rev_parse(f"HEAD:{str(path_in_submodule)}") + except GitCommandError: + pass + + try: + return repo.git.rev_parse(f"{revision}:{str(path)}") + except GitCommandError: + # NOTE: The file can be in a submodule or it was not there when the command ran but was there when workflows + # were migrated (this can happen only for Usage); the project might be broken too. + return get_object_hash_from_submodules() diff --git a/setup.py b/setup.py index 7cd8834b9a..54c0206d6d 100644 --- a/setup.py +++ b/setup.py @@ -192,6 +192,7 @@ def run(self): "wcmatch>=6.0.0,<8.3", "werkzeug>=0.15.5,<2.0.2", "yagup>=0.1.1", + "ZODB==5.6.0", ] diff --git a/tests/cli/test_graph.py b/tests/cli/test_graph.py index 407ab17d4d..68b4b74aba 100644 --- a/tests/cli/test_graph.py +++ b/tests/cli/test_graph.py @@ -36,3 +36,15 @@ def test_graph_export_validation(runner, client, directory_tree, run, format): result = runner.invoke(cli, ["graph", "export", "--format", format, "--strict"]) assert 0 == result.exit_code, result.output + + +def test_graph(runner, client, directory_tree, run): + """Test graph generation.""" + assert 0 == runner.invoke(cli, ["dataset", "add", "-c", "my-data", str(directory_tree)]).exit_code + file1 = client.path / DATA_DIR / "my-data" / directory_tree.name / "file1" + file2 = client.path / DATA_DIR / "my-data" / directory_tree.name / "dir1" / "file2" + assert 0 == run(["run", "head", str(file1)], stdout="out1") + assert 0 == run(["run", "tail", str(file2)], stdout="out2") + + result = runner.invoke(cli, ["graph", "generate", "-f"]) + assert 0 == result.exit_code, result.output diff --git a/tests/cli/test_migrate.py b/tests/cli/test_migrate.py index f3e48fa29f..84a648aca1 100644 --- a/tests/cli/test_migrate.py +++ b/tests/cli/test_migrate.py @@ -271,7 +271,7 @@ def test_no_blank_node_after_dataset_migration(isolated_runner, old_dataset_proj dataset = old_dataset_project.load_dataset("201901_us_flights_1") - assert not dataset.creators[0]._id.startswith("_:") + assert not dataset.creators[0].id.startswith("_:") assert not dataset.same_as._id.startswith("_:") assert not dataset.tags[0]._id.startswith("_:") assert isinstance(dataset.license, str) diff --git a/tests/cli/test_run.py b/tests/cli/test_run.py index 88e3543cac..5fe4dba4d8 100644 --- a/tests/cli/test_run.py +++ b/tests/cli/test_run.py @@ -22,7 +22,6 @@ import pytest from renku.cli import cli -from renku.core.models.provenance.provenance_graph import ProvenanceGraph def test_run_simple(runner, project): @@ -158,7 +157,7 @@ def test_run_argument_parameters(runner, client): assert "delta-3" == plan.parameters[0].name assert "n-1" == plan.parameters[1].name - provenance_graph = ProvenanceGraph.from_json(client.provenance_graph_path) + provenance_graph = client.provenance_graph assert 1 == len(provenance_graph.activities) activity = provenance_graph.activities[0] diff --git a/tests/core/commands/test_client.py b/tests/core/commands/test_client.py index 4fc9c86042..a243fdc0ae 100644 --- a/tests/core/commands/test_client.py +++ b/tests/core/commands/test_client.py @@ -59,6 +59,7 @@ def test_safe_class_attributes(tmpdir): "ACTIVITY_INDEX", "CACHE", "CONFIG_NAME", + "DATABASE_PATH", "DATASETS", 
"DATASET_IMAGES", "DATASETS_PROVENANCE", @@ -84,8 +85,8 @@ def test_safe_class_attributes(tmpdir): "_CMD_STORAGE_TRACK", "_CMD_STORAGE_UNTRACK", "_LFS_HEADER", + "_database", "_datasets_provenance", - "_dependency_graph", "_global_config_dir", "_temporary_datasets_path", ] diff --git a/tests/core/commands/test_serialization.py b/tests/core/commands/test_serialization.py index 027b0297fe..0f87362684 100644 --- a/tests/core/commands/test_serialization.py +++ b/tests/core/commands/test_serialization.py @@ -41,7 +41,7 @@ def test_dataset_deserialization(client_with_datasets): for attribute, type_ in dataset_types.items(): assert type(dataset.__getattribute__(attribute)) in type_ - creator_types = {"email": str, "_id": str, "name": str, "affiliation": str} + creator_types = {"email": str, "id": str, "name": str, "affiliation": str} creator = client_with_datasets.load_dataset("dataset-1").creators[0] @@ -71,9 +71,9 @@ def test_dataset_creator_email(dataset_metadata): # modify the dataset metadata to change the creator dataset = Dataset.from_jsonld(dataset_metadata, client=LocalClient(".")) - dataset.creators[0]._id = "mailto:None" + dataset.creators[0].id = "mailto:None" dataset_broken = Dataset.from_jsonld(dataset.as_jsonld(), client=LocalClient(".")) - assert "mailto:None" not in dataset_broken.creators[0]._id + assert "mailto:None" not in dataset_broken.creators[0].id def test_calamus(client, dataset_metadata_before_calamus): diff --git a/tests/core/commands/test_storage.py b/tests/core/commands/test_storage.py index 0092c70324..f8be9392fd 100644 --- a/tests/core/commands/test_storage.py +++ b/tests/core/commands/test_storage.py @@ -104,7 +104,6 @@ def test_lfs_migrate(runner, project, client): client.repo.git.add("*") client.repo.index.commit("add files") dataset_checksum = client.repo.head.commit.tree["dataset_file"].hexsha - workflow_checksum = client.repo.head.commit.tree["workflow_file"].hexsha result = runner.invoke(cli, ["graph", "generate"]) assert 0 == result.exit_code @@ -121,7 +120,7 @@ def test_lfs_migrate(runner, project, client): previous_head = client.repo.head.commit.hexsha result = runner.invoke(cli, ["storage", "migrate", "--all"], input="y") - assert 0 == result.exit_code + assert 0 == result.exit_code, result.output assert "dataset_file" in result.output assert "workflow_file" in result.output assert "regular_file" in result.output @@ -129,13 +128,10 @@ def test_lfs_migrate(runner, project, client): assert previous_head != client.repo.head.commit.hexsha changed_files = client.repo.head.commit.stats.files.keys() - assert ".renku/dataset.json" in changed_files - assert ".renku/provenance.json" in changed_files + assert ".renku/metadata/activities" not in changed_files assert dataset_checksum not in (client.path / ".renku" / "dataset.json").read_text() - assert workflow_checksum not in (client.path / ".renku" / "provenance.json").read_text() - def test_lfs_migrate_no_changes(runner, project, client): """Test ``renku storage migrate`` command without broken files.""" @@ -173,7 +169,6 @@ def test_lfs_migrate_explicit_path(runner, project, client): client.repo.git.add("*") client.repo.index.commit("add files") dataset_checksum = client.repo.head.commit.tree["dataset_file"].hexsha - workflow_checksum = client.repo.head.commit.tree["workflow_file"].hexsha result = runner.invoke(cli, ["graph", "generate"]) assert 0 == result.exit_code @@ -193,6 +188,4 @@ def test_lfs_migrate_explicit_path(runner, project, client): assert dataset_checksum in (client.path / ".renku" / 
"dataset.json").read_text() - assert workflow_checksum in (client.path / ".renku" / "provenance.json").read_text() - assert "oid sha256:" in (client.path / "regular_file").read_text() diff --git a/tests/core/fixtures/core_database.py b/tests/core/fixtures/core_database.py new file mode 100644 index 0000000000..aea2d1faf5 --- /dev/null +++ b/tests/core/fixtures/core_database.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2021 Swiss Data Science Center (SDSC) +# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and +# Eidgenössische Technische Hochschule Zürich (ETHZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Renku fixtures for metadata Database.""" + +import copy +import datetime +from typing import Tuple + +import pytest +from ZODB.POSException import POSKeyError + +from renku.core.incubation.database import Database + + +class DummyStorage: + """An in-memory storage class.""" + + def __init__(self): + self._files = {} + self._modification_dates = {} + + def store(self, filename: str, data): + """Store object.""" + assert isinstance(filename, str) + + self._files[filename] = data + self._modification_dates[filename] = datetime.datetime.now() + + def load(self, filename: str): + """Load data for object with object id oid.""" + assert isinstance(filename, str) + + if filename not in self._files: + raise POSKeyError(filename) + + return copy.deepcopy(self._files[filename]) + + def get_modification_date(self, filename: str): + """Return modification date of a file.""" + return self._modification_dates[filename] + + def exists(self, filename: str): + """Return True if filename exists in the storage.""" + return filename in self._files + + +@pytest.fixture +def database() -> Tuple[Database, DummyStorage]: + """A Database with in-memory storage.""" + from renku.core.models.provenance.activity import Activity + from renku.core.models.workflow.plan import Plan + + storage = DummyStorage() + database = Database(storage=storage) + + database.add_index(name="activities", object_type=Activity, attribute="id") + database.add_index(name="plans", object_type=Plan, attribute="id") + + yield database, storage diff --git a/tests/core/incubation/test_database.py b/tests/core/incubation/test_database.py new file mode 100644 index 0000000000..8222bdbc05 --- /dev/null +++ b/tests/core/incubation/test_database.py @@ -0,0 +1,417 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2017-2021- Swiss Data Science Center (SDSC) +# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and +# Eidgenössische Technische Hochschule Zürich (ETHZ). +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test metadata Database.""" + +import pytest +from persistent import GHOST, UPTODATE +from persistent.list import PersistentList +from persistent.mapping import PersistentMapping + +from renku.cli import cli +from renku.core.incubation.database import PERSISTED, Database +from renku.core.models.entity import Entity +from renku.core.models.provenance.activity import Activity, Association, Usage +from renku.core.models.workflow.plan import Plan + + +def test_database_create(client, runner): + """Test database files are created in an empty project.""" + assert 0 == runner.invoke(cli, ["graph", "generate"]).exit_code + + assert not client.repo.is_dirty() + root_objects = ["root", "activities", "plans"] + for filename in root_objects: + assert (client.database_path / filename).exists() + + assert client.has_graph_files() + + +@pytest.mark.xfail +def test_database_recreate(client, runner): + """Test can force re-create the database.""" + assert 0 == runner.invoke(cli, ["graph", "generate"]).exit_code + + result = runner.invoke(cli, ["graph", "generate", "-f"]) + + assert 0 == result.exit_code, result.output + assert client.has_graph_files() + + +def test_database_add(database): + """Test adding an object to an index.""" + database, storage = database + + id = "/activities/42" + activity = Activity(id=id) + index = database.get("activities") + index.add(activity) + database.commit() + + root_objects = ["root", "activities", "plans"] + for filename in root_objects: + assert storage.exists(filename) + + oid = Database.hash_id(id) + assert storage.exists(oid) + + +def test_database_add_using_set_item(database): + """Test adding an object to the database using __setitem__.""" + database, storage = database + + id = "/activities/42" + activity_1 = Activity(id=id) + database["activities"][id] = activity_1 + + activity_2 = list(database.root["activities"].values())[0] + + assert activity_1 is activity_2 + + +def test_database_index_with_no_automatic_key(database): + """Test indexes with no automatic key attribute.""" + database, storage = database + index = database.add_index(name="manual", object_type=Activity) + + id = "/activities/42" + activity = Activity(id=id) + index.add(activity, key=id) + + database.commit() + + new_database = Database(storage=storage) + activity = new_database["manual"][id] + + assert id == activity.id + + oid = Database.hash_id(id) + assert storage.exists(oid) + + +def test_database_add_with_incorrect_key(database): + """Test adding an object to the database using __setitem__ with an incorrect key should fail.""" + database, storage = database + + id = "/activities/42" + activity_1 = Activity(id=id) + + with pytest.raises(AssertionError) as e: + database["activities"]["incorrect-key"] = activity_1 + + assert "Incorrect key for index 'activities': 'incorrect-key' != '/activities/42'" in str(e) + + +def test_database_add_fails_when_no_key_and_no_automatic_key(database): + """Test adding to an index with no automatic key fails if no key is provided.""" + database, storage = database + index = database.add_index(name="manual", object_type=Activity) + + activity = 
Activity(id="/activities/42") + + with pytest.raises(AssertionError) as e: + index.add(activity) + + assert "No key is provided" in str(e) + + +def test_database_no_file_created_if_not_committed(database): + """Test adding an object to a database does not create a file before commit.""" + database, storage = database + database.commit() + + assert storage.exists("root") + + id = "/activities/42" + activity = Activity(id=id) + database.get("activities").add(activity) + + oid = Database.hash_id(id) + assert not storage.exists(oid) + + +def test_database_update_required_object_only(database): + """Test adding an object to the database does not cause an update to all other objects.""" + database, storage = database + + index = database.get("activities") + + id_1 = "/activities/42" + activity_1 = Activity(id=id_1) + index.add(activity_1) + database.commit() + oid_1 = Database.hash_id(id_1) + modification_time_before = storage.get_modification_date(oid_1) + + id_2 = "/activities/43" + activity_2 = Activity(id=id_2) + index.add(activity_2) + database.commit() + + modification_time_after = storage.get_modification_date(oid_1) + + assert modification_time_before == modification_time_after + + +def test_database_update_required_root_objects_only(database): + """Test adding an object to an index does not cause an update to other indexes.""" + database, storage = database + + _ = database.root + database.commit() + + entity_modification_time_before = storage.get_modification_date("plans") + activity_modification_time_before = storage.get_modification_date("activities") + + activity = Activity(id="/activities/42") + database.get("activities").add(activity) + database.commit() + + entity_modification_time_after = storage.get_modification_date("plans") + activity_modification_time_after = storage.get_modification_date("activities") + + assert entity_modification_time_before == entity_modification_time_after + assert activity_modification_time_before != activity_modification_time_after + + +def test_database_add_non_persistent(database): + """Test adding a non-Persistent object to the database raises an error.""" + database, _ = database + + class Dummy: + id = 42 + + with pytest.raises(AssertionError) as e: + object = Dummy() + database.get("activities").add(object) + + assert "Cannot add objects of type" in str(e) + + +def test_database_loads_only_required_objects(database): + """Test loading an object does not load its Persistent members.""" + database, storage = database + + plan = Plan(id="/plan/9") + association = Association(id="association", plan=plan) + id = "/activities/42" + activity = Activity(id=id, association=association) + database.get("activities").add(activity) + database.commit() + + new_database = Database(storage=storage) + oid = Database.hash_id(id) + activity = new_database.get(oid) + + # Access a field to make sure that activity is loaded + _ = activity.id + + assert UPTODATE == activity._p_state + assert PERSISTED == activity._p_serial + assert GHOST == activity.association.plan._p_state + + assert UPTODATE == new_database.root["plans"]._p_state + assert UPTODATE == new_database.root["activities"]._p_state + + +def test_database_load_multiple(database): + """Test loading an object from multiple indexes returns the same object.""" + database, storage = database + database.add_index(name="associations", object_type=Activity, attribute="association.id") + + plan = Plan(id="/plan/9") + association = Association(id="/association/42", plan=plan) + id = "/activities/42" + activity = 
Activity(id=id, association=association) + database.get("activities").add(activity) + database.get("associations").add(activity) + database.commit() + + new_database = Database(storage=storage) + oid = Database.hash_id(id) + activity_1 = new_database.get(oid) + activity_2 = new_database.get("activities").get(id) + activity_3 = new_database.get("associations").get("/association/42") + + assert activity_1 is activity_2 + assert activity_2 is activity_3 + + +def test_database_index_update(database): + """Test adding objects with the same key, updates the Index.""" + database, storage = database + + index_name = "plan-names" + database.add_index(name=index_name, object_type=Plan, attribute="name") + + name = "same-name" + + plan_1 = Plan(id="/plans/42", name=name, description="old") + database.get(index_name).add(plan_1) + plan_2 = Plan(id="/plans/43", name=name, description="new") + database.get(index_name).add(plan_2) + assert plan_2 is database.get(index_name).get(name) + + database.commit() + + plan_3 = Plan(id="/plans/44", name=name, description="newer") + database.get(index_name).add(plan_3) + database.commit() + + new_database = Database(storage=storage) + plans = new_database.get(index_name) + plan = plans.get(name) + + assert "newer" == plan.description + + +def test_database_add_duplicate_index(database): + """Test cannot add an index with the same name.""" + database, _ = database + + same_name = "plans" + + with pytest.raises(AssertionError) as e: + database.add_index(name=same_name, object_type=Plan, attribute="name") + + assert "Index already exists: 'plans'" in str(e) + + +def test_database_index_different_key_type(database): + """Test adding an Index with a different key type.""" + database, storage = database + + index_name = "usages" + index = database.add_index(name=index_name, object_type=Activity, attribute="entity.path", key_type=Usage) + + entity = Entity(checksum="42", path="/dummy/path") + usage = Usage(entity=entity, id="/usages/42") + + activity = Activity(id="/activities/42", usages=[usage]) + database.get(index_name).add(activity, key_object=usage) + database.commit() + + new_database = Database(storage=storage) + usages = new_database[index_name] + activity = usages.get("/dummy/path") + + assert "/activities/42" == activity.id + assert "42" == activity.usages[0].entity.checksum + assert "/dummy/path" == activity.usages[0].entity.path + + key = index.generate_key(activity, key_object=usage) + + assert activity is usages[key] + + +def test_database_wrong_index_key_type(database): + """Test adding to an Index with a wrong key type.""" + database, _ = database + + index_name = "usages" + database.add_index(name=index_name, object_type=Activity, attribute="id", key_type=Usage) + + activity = Activity(id="/activities/42") + + with pytest.raises(AssertionError) as e: + database.get(index_name).add(activity) + + assert "Invalid key type" in str(e) + + +def test_database_missing_attribute(database): + """Test adding to an Index while object does not have the requires attribute.""" + database, _ = database + + index_name = "usages" + database.add_index(name=index_name, object_type=Activity, attribute="missing.attribute") + + activity = Activity(id="/activities/42") + + with pytest.raises(AttributeError) as e: + database.get(index_name).add(activity) + + assert "'Activity' object has no attribute 'missing'" in str(e) + + +def test_database_remove(database): + """Test removing an object from an index.""" + database, storage = database + + id = "/activities/42" + activity = 
Activity(id=id) + database.get("activities").add(activity) + database.commit() + + database = Database(storage=storage) + database.get("activities").pop(id) + database.commit() + + database = Database(storage=storage) + activity = database.get("activities").get(id, None) + + assert activity is None + # However, the file still exists in the storage + oid = Database.hash_id(id) + assert storage.exists(oid) + + +def test_database_remove_non_existing(database): + """Test removing a non-existing object from an index.""" + database, storage = database + + with pytest.raises(KeyError): + database.get("activities").pop("non-existing-key") + + object = database.get("activities").pop("non-existing-key", None) + + assert object is None + + +def test_database_persistent_collections(database): + """Test using Persistent collections.""" + database, storage = database + index_name = "collections" + database.add_index(name=index_name, object_type=PersistentMapping) + + entity_checksum = "42" + entity_path = "/dummy/path" + usage = Usage(entity=Entity(checksum=entity_checksum, path=entity_path), id="/usages/42") + id_1 = "/activities/1" + activity_1 = Activity(id=id_1, usages=[usage]) + id_2 = "/activities/2" + activity_2 = Activity(id=id_2, usages=[usage]) + + p_mapping = PersistentMapping() + + database[index_name][entity_path] = p_mapping + + p_list = PersistentList() + p_mapping[entity_checksum] = p_list + p_list.append(activity_1) + p_list.append(activity_2) + + database.commit() + + new_database = Database(storage=storage) + collections = new_database[index_name] + + id_3 = "/activities/3" + activity_3 = Activity(id=id_3, usages=[usage]) + collections[entity_path][entity_checksum].append(activity_3) + + assert {id_1, id_2, id_3} == {activity.id for activity in collections[entity_path][entity_checksum]}
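
The database tests above can be read as an end-to-end recipe for the new metadata store. The sketch below strings them together under the same assumptions as the fixtures: DummyStorage is the in-memory storage class from tests/core/fixtures/core_database.py, the "/plans/9" and "/activities/42" ids and the "echo-data" name are made up for illustration, and a real project would use the file-backed storage the client keeps under .renku/metadata rather than this dummy one.

    from renku.core.incubation.database import Database
    from renku.core.models.provenance.activity import Activity
    from renku.core.models.provenance.provenance_graph import ProvenanceGraph
    from renku.core.models.workflow.dependency_graph import DependencyGraph
    from renku.core.models.workflow.plan import Plan
    from tests.core.fixtures.core_database import DummyStorage

    # In-memory storage as used by the `database` fixture; real code points this at .renku/metadata.
    storage = DummyStorage()
    database = Database(storage=storage)

    # Indexes are named BTrees keyed by an object attribute, mirroring the fixture setup.
    database.add_index(name="activities", object_type=Activity, attribute="id")
    database.add_index(name="plans", object_type=Plan, attribute="id")

    plan = Plan(id="/plans/9", name="echo-data")
    activity = Activity(id="/activities/42")
    database.get("plans").add(plan)
    database.get("activities").add(activity)  # or: database["activities"][activity.id] = activity
    database.commit()  # writes "root", "activities", "plans" plus a file per stored object

    # Objects are stored under a hash of their id and come back from storage as lazy ghosts.
    oid = Database.hash_id("/activities/42")
    assert storage.exists(oid)

    # Reload from storage and rebuild the graphs from the indexes instead of from JSON files.
    database = Database(storage=storage)
    assert database.get("activities").get("/activities/42").id == activity.id

    provenance_graph = ProvenanceGraph.from_database(database)  # reads the "activities" index
    dependency_graph = DependencyGraph.from_database(database)  # reads the "plans" index
    _ = dependency_graph.graph  # the networkx.DiGraph is only built on first access

Keeping ProvenanceGraph and DependencyGraph as thin views over the Database indexes (from_database plus the lazy graph property) means only the objects that are actually accessed get loaded from storage; everything else stays a ZODB ghost, which is what the NOTE comments in the two graph modules are guarding against.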