Skip to content

Commit

Permalink
[Runs] Enrich run with artifacts when getting a single run [1.6.x] (m…
Browse files Browse the repository at this point in the history
…lrun#5560)

(cherry picked from commit 12317c1)
  • Loading branch information
alonmr committed May 15, 2024
1 parent 7a85be1 commit 142e600
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 4 deletions.
27 changes: 25 additions & 2 deletions server/api/crud/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,36 @@ def get_run(
db_session: sqlalchemy.orm.Session,
uid: str,
iter: int,
project: str = mlrun.mlconf.default_project,
project: str = None,
) -> dict:
project = project or mlrun.mlconf.default_project
return server.api.utils.singletons.db.get_db().read_run(
run = server.api.utils.singletons.db.get_db().read_run(
db_session, uid, project, iter
)

# Since we don't store the artifacts in the run body, we need to fetch them separately
# The client may be using them as in pipeline as input for the next step
producer_uri = None
producer_id = run["metadata"].get("labels", {}).get("workflow")
if not producer_id:
producer_id = uid
else:
# Producer URI is the URI of the MLClientCtx object that produced the artifact
producer_uri = f"{project}/{run['metadata']['uid']}"

artifacts = server.api.crud.Artifacts().list_artifacts(
db_session,
producer_id=producer_id,
producer_uri=producer_uri,
project=project,
)

if artifacts or "artifacts" in run.get("status", {}):
run.setdefault("status", {})
run["status"]["artifacts"] = artifacts

return run

def list_runs(
self,
db_session: sqlalchemy.orm.Session,
Expand Down
129 changes: 127 additions & 2 deletions tests/api/crud/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import unittest.mock
import uuid

import deepdiff
import pytest
import sqlalchemy.orm
from kubernetes import client as k8s_client
Expand All @@ -39,6 +40,7 @@ async def test_delete_runs_with_resources(self, db: sqlalchemy.orm.Session):
{
"metadata": {
"name": "run-name",
"uid": "uid",
"labels": {
"kind": "job",
},
Expand Down Expand Up @@ -237,6 +239,7 @@ def test_run_abortion_failure(self, db: sqlalchemy.orm.Session):
{
"metadata": {
"name": "run-name",
"uid": run_uid,
"labels": {
"kind": "job",
},
Expand Down Expand Up @@ -307,7 +310,9 @@ def test_store_run_strip_artifacts_metadata(self, db: sqlalchemy.orm.Session):
project=project,
)

run = server.api.crud.Runs().get_run(db, run_uid, 0, project)
runs = server.api.crud.Runs().list_runs(db, project=project)
assert len(runs) == 1
run = runs[0]
assert "artifacts" not in run["status"]
assert run["status"]["artifact_uris"] == {
"key1": "store://artifacts/project-name/db_key1@tree1",
Expand Down Expand Up @@ -375,10 +380,130 @@ def test_update_run_strip_artifacts_metadata(self, db: sqlalchemy.orm.Session):
},
)

run = server.api.crud.Runs().get_run(db, run_uid, 0, project)
runs = server.api.crud.Runs().list_runs(db, project=project)
assert len(runs) == 1
run = runs[0]
assert "artifacts" not in run["status"]
assert run["status"]["artifact_uris"] == {
"key1": "store://artifacts/project-name/db_key1@tree1",
"key2": "store://artifacts/project-name/db_key2@tree2",
"key3": "store://artifacts/project-name/db_key3@tree3",
}

def test_get_run_restore_artifacts_metadata(self, db: sqlalchemy.orm.Session):
project = "project-name"
run_uid = str(uuid.uuid4())
artifacts = self._generate_artifacts(project, run_uid)

for artifact in artifacts:
server.api.crud.Artifacts().store_artifact(
db,
artifact["spec"]["db_key"],
artifact,
project=project,
)

server.api.crud.Runs().store_run(
db,
{
"metadata": {
"name": "run-name",
"uid": run_uid,
"labels": {
"kind": "job",
},
},
"status": {
"artifacts": artifacts,
},
},
run_uid,
project=project,
)

self._validate_run_artifacts(artifacts, db, project, run_uid)

def test_get_workflow_run_restore_artifacts_metadata(
self, db: sqlalchemy.orm.Session
):
project = "project-name"
run_uid = str(uuid.uuid4())
workflow_uid = str(uuid.uuid4())
artifacts = self._generate_artifacts(project, run_uid, workflow_uid)

for artifact in artifacts:
server.api.crud.Artifacts().store_artifact(
db,
artifact["spec"]["db_key"],
artifact,
project=project,
)

server.api.crud.Runs().store_run(
db,
{
"metadata": {
"name": "run-name",
"uid": run_uid,
"labels": {
"kind": "job",
"workflow": workflow_uid,
},
},
"status": {
"artifacts": artifacts,
},
},
run_uid,
project=project,
)

self._validate_run_artifacts(artifacts, db, project, run_uid)

@staticmethod
def _generate_artifacts(project, run_uid, workflow_uid=None, artifacts_len=2):
artifacts = []
i = 0
while len(artifacts) < artifacts_len:
artifact = {
"kind": "artifact",
"metadata": {
"key": f"key{i}",
"tree": workflow_uid or run_uid,
"uid": f"uid{i}",
"project": project,
"iter": None,
},
"spec": {
"db_key": f"db_key{i}",
},
"status": {},
}
if workflow_uid:
artifact["spec"]["producer"] = {
"uri": f"{project}/{run_uid}",
}
artifacts.append(artifact)
i += 1
return artifacts

@staticmethod
def _validate_run_artifacts(artifacts, db, project, run_uid):
run = server.api.crud.Runs().get_run(db, run_uid, 0, project)
assert "artifacts" in run["status"]
enriched_artifacts = list(run["status"]["artifacts"])

def sort_by_key(e):
return e["metadata"]["key"]

enriched_artifacts.sort(key=sort_by_key)
artifacts.sort(key=sort_by_key)
for artifact, enriched_artifact in zip(artifacts, enriched_artifacts):
assert (
deepdiff.DeepDiff(
artifact,
enriched_artifact,
exclude_paths="root['metadata']['tag']",
)
== {}
)

0 comments on commit 142e600

Please sign in to comment.