From 12317c18d4a1660c8c16427c0c607d079bca6646 Mon Sep 17 00:00:00 2001 From: Alon Maor <48641682+alonmr@users.noreply.github.com> Date: Wed, 15 May 2024 12:06:12 +0300 Subject: [PATCH] [Runs] Enrich run with artifacts when getting a single run [1.6.x] (#5560) --- server/api/crud/runs.py | 27 +++++++- tests/api/crud/test_runs.py | 129 +++++++++++++++++++++++++++++++++++- 2 files changed, 152 insertions(+), 4 deletions(-) diff --git a/server/api/crud/runs.py b/server/api/crud/runs.py index 12c361e4bd8..e2473bc696b 100644 --- a/server/api/crud/runs.py +++ b/server/api/crud/runs.py @@ -141,13 +141,36 @@ def get_run( db_session: sqlalchemy.orm.Session, uid: str, iter: int, - project: str = mlrun.mlconf.default_project, + project: str = None, ) -> dict: project = project or mlrun.mlconf.default_project - return server.api.utils.singletons.db.get_db().read_run( + run = server.api.utils.singletons.db.get_db().read_run( db_session, uid, project, iter ) + # Since we don't store the artifacts in the run body, we need to fetch them separately + # The client may be using them as in pipeline as input for the next step + producer_uri = None + producer_id = run["metadata"].get("labels", {}).get("workflow") + if not producer_id: + producer_id = uid + else: + # Producer URI is the URI of the MLClientCtx object that produced the artifact + producer_uri = f"{project}/{run['metadata']['uid']}" + + artifacts = server.api.crud.Artifacts().list_artifacts( + db_session, + producer_id=producer_id, + producer_uri=producer_uri, + project=project, + ) + + if artifacts or "artifacts" in run.get("status", {}): + run.setdefault("status", {}) + run["status"]["artifacts"] = artifacts + + return run + def list_runs( self, db_session: sqlalchemy.orm.Session, diff --git a/tests/api/crud/test_runs.py b/tests/api/crud/test_runs.py index 354756d6d9d..3401a2c7c0a 100644 --- a/tests/api/crud/test_runs.py +++ b/tests/api/crud/test_runs.py @@ -15,6 +15,7 @@ import unittest.mock import uuid +import deepdiff import pytest import sqlalchemy.orm from kubernetes import client as k8s_client @@ -39,6 +40,7 @@ async def test_delete_runs_with_resources(self, db: sqlalchemy.orm.Session): { "metadata": { "name": "run-name", + "uid": "uid", "labels": { "kind": "job", }, @@ -237,6 +239,7 @@ def test_run_abortion_failure(self, db: sqlalchemy.orm.Session): { "metadata": { "name": "run-name", + "uid": run_uid, "labels": { "kind": "job", }, @@ -307,7 +310,9 @@ def test_store_run_strip_artifacts_metadata(self, db: sqlalchemy.orm.Session): project=project, ) - run = server.api.crud.Runs().get_run(db, run_uid, 0, project) + runs = server.api.crud.Runs().list_runs(db, project=project) + assert len(runs) == 1 + run = runs[0] assert "artifacts" not in run["status"] assert run["status"]["artifact_uris"] == { "key1": "store://artifacts/project-name/db_key1@tree1", @@ -375,10 +380,130 @@ def test_update_run_strip_artifacts_metadata(self, db: sqlalchemy.orm.Session): }, ) - run = server.api.crud.Runs().get_run(db, run_uid, 0, project) + runs = server.api.crud.Runs().list_runs(db, project=project) + assert len(runs) == 1 + run = runs[0] assert "artifacts" not in run["status"] assert run["status"]["artifact_uris"] == { "key1": "store://artifacts/project-name/db_key1@tree1", "key2": "store://artifacts/project-name/db_key2@tree2", "key3": "store://artifacts/project-name/db_key3@tree3", } + + def test_get_run_restore_artifacts_metadata(self, db: sqlalchemy.orm.Session): + project = "project-name" + run_uid = str(uuid.uuid4()) + artifacts = self._generate_artifacts(project, run_uid) + + for artifact in artifacts: + server.api.crud.Artifacts().store_artifact( + db, + artifact["spec"]["db_key"], + artifact, + project=project, + ) + + server.api.crud.Runs().store_run( + db, + { + "metadata": { + "name": "run-name", + "uid": run_uid, + "labels": { + "kind": "job", + }, + }, + "status": { + "artifacts": artifacts, + }, + }, + run_uid, + project=project, + ) + + self._validate_run_artifacts(artifacts, db, project, run_uid) + + def test_get_workflow_run_restore_artifacts_metadata( + self, db: sqlalchemy.orm.Session + ): + project = "project-name" + run_uid = str(uuid.uuid4()) + workflow_uid = str(uuid.uuid4()) + artifacts = self._generate_artifacts(project, run_uid, workflow_uid) + + for artifact in artifacts: + server.api.crud.Artifacts().store_artifact( + db, + artifact["spec"]["db_key"], + artifact, + project=project, + ) + + server.api.crud.Runs().store_run( + db, + { + "metadata": { + "name": "run-name", + "uid": run_uid, + "labels": { + "kind": "job", + "workflow": workflow_uid, + }, + }, + "status": { + "artifacts": artifacts, + }, + }, + run_uid, + project=project, + ) + + self._validate_run_artifacts(artifacts, db, project, run_uid) + + @staticmethod + def _generate_artifacts(project, run_uid, workflow_uid=None, artifacts_len=2): + artifacts = [] + i = 0 + while len(artifacts) < artifacts_len: + artifact = { + "kind": "artifact", + "metadata": { + "key": f"key{i}", + "tree": workflow_uid or run_uid, + "uid": f"uid{i}", + "project": project, + "iter": None, + }, + "spec": { + "db_key": f"db_key{i}", + }, + "status": {}, + } + if workflow_uid: + artifact["spec"]["producer"] = { + "uri": f"{project}/{run_uid}", + } + artifacts.append(artifact) + i += 1 + return artifacts + + @staticmethod + def _validate_run_artifacts(artifacts, db, project, run_uid): + run = server.api.crud.Runs().get_run(db, run_uid, 0, project) + assert "artifacts" in run["status"] + enriched_artifacts = list(run["status"]["artifacts"]) + + def sort_by_key(e): + return e["metadata"]["key"] + + enriched_artifacts.sort(key=sort_by_key) + artifacts.sort(key=sort_by_key) + for artifact, enriched_artifact in zip(artifacts, enriched_artifacts): + assert ( + deepdiff.DeepDiff( + artifact, + enriched_artifact, + exclude_paths="root['metadata']['tag']", + ) + == {} + )