Skip to content

Commit

Permalink
[Runs] Store artifact URIs instead of full artifacts in body [1.6.x] (m…
Browse files Browse the repository at this point in the history
  • Loading branch information
alonmr committed May 12, 2024
1 parent 5223d5d commit a54892f
Show file tree
Hide file tree
Showing 6 changed files with 172 additions and 6 deletions.
2 changes: 2 additions & 0 deletions mlrun/lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"parameters",
"results",
"artifacts",
"artifact_uris",
"error",
]

Expand Down Expand Up @@ -63,6 +64,7 @@ def to_rows(self, extend_iterations=False):
get_in(run, "spec.parameters", ""),
get_in(run, "status.results", ""),
get_in(run, "status.artifacts", []),
get_in(run, "status.artifact_uris", {}),
get_in(run, "status.error", ""),
]
if extend_iterations and iterations:
Expand Down
9 changes: 7 additions & 2 deletions mlrun/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1058,6 +1058,7 @@ def __init__(
ui_url=None,
reason: str = None,
notifications: Dict[str, Notification] = None,
artifact_uris: dict[str, str] = None,
):
self.state = state or "created"
self.status_text = status_text
Expand All @@ -1072,6 +1073,8 @@ def __init__(
self.ui_url = ui_url
self.reason = reason
self.notifications = notifications or {}
# Artifact key -> URI mapping, since the full artifacts are not stored in the runs DB table
self.artifact_uris = artifact_uris or {}

def is_failed(self) -> Optional[bool]:
"""
Expand Down Expand Up @@ -1384,8 +1387,10 @@ def refresh(self):
iter=self.metadata.iteration,
)
if run:
self.status = RunStatus.from_dict(run.get("status", {}))
self.status.from_dict(run.get("status", {}))
run_status = run.get("status", {})
# Artifacts are not stored in the DB, so we need to preserve them here
run_status["artifacts"] = self.status.artifacts
self.status = RunStatus.from_dict(run_status)
return self

def show(self):
Expand Down
12 changes: 9 additions & 3 deletions mlrun/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,12 +404,18 @@ def time_str(x):
df.drop("labels", axis=1, inplace=True)
df.drop("inputs", axis=1, inplace=True)
df.drop("artifacts", axis=1, inplace=True)
df.drop("artifact_uris", axis=1, inplace=True)
else:
df["labels"] = df["labels"].apply(dict_html)
df["inputs"] = df["inputs"].apply(inputs_html)
df["artifacts"] = df["artifacts"].apply(
lambda artifacts: artifacts_html(artifacts, "target_path"),
)
if df["artifact_uris"][0]:
df["artifact_uris"] = df["artifact_uris"].apply(dict_html)
df.drop("artifacts", axis=1, inplace=True)
else:
df["artifacts"] = df["artifacts"].apply(
lambda artifacts: artifacts_html(artifacts, "target_path"),
)
df.drop("artifact_uris", axis=1, inplace=True)

def expand_error(x):
if x["state"] == "error":
Expand Down
30 changes: 30 additions & 0 deletions server/api/crud/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sqlalchemy.orm
from fastapi.concurrency import run_in_threadpool

import mlrun.artifacts
import mlrun.common.schemas
import mlrun.config
import mlrun.errors
Expand Down Expand Up @@ -55,6 +56,19 @@ def store_run(
data, server.api.constants.MaskOperations.REDACT
)

# Clients before 1.7.0 send the full artifact metadata in the run object, we need to strip it
# to avoid bloating the DB.
data.setdefault("status", {})
artifacts = data["status"].get("artifacts", [])
artifact_uris = data["status"].get("artifact_uris", {})
for artifact in artifacts:
artifact = mlrun.artifacts.dict_to_artifact(artifact)
artifact_uris[artifact.key] = artifact.uri

if artifact_uris:
data["status"]["artifact_uris"] = artifact_uris
data["status"].pop("artifacts", None)

server.api.utils.singletons.db.get_db().store_run(
db_session,
data,
Expand All @@ -73,6 +87,22 @@ def update_run(
):
project = project or mlrun.mlconf.default_project
logger.debug("Updating run", project=project, uid=uid, iter=iter)

# Clients before 1.7.0 send the full artifact metadata in the run object, we need to strip it
# to avoid bloating the DB.
artifacts = data.get("status.artifacts", None)
artifact_uris = data.get("status.artifact_uris", None)
# If neither was given, nothing to do. Otherwise, we merge the two fields into artifact_uris.
if artifacts is not None or artifact_uris is not None:
artifacts = artifacts or []
artifact_uris = artifact_uris or {}
for artifact in artifacts:
artifact = mlrun.artifacts.dict_to_artifact(artifact)
artifact_uris[artifact.key] = artifact.uri

data["status.artifact_uris"] = artifact_uris
data.pop("status.artifacts", None)

# TODO: Abort run moved to a separate endpoint, remove this section once in 1.8.0
# (once 1.5.x clients are not supported)
if (
Expand Down
2 changes: 1 addition & 1 deletion server/api/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_log(self, session, uid, project="", offset=0, size=0):
def store_run(
self,
session,
struct,
run_data,
uid,
project="",
iter=0,
Expand Down
123 changes: 123 additions & 0 deletions tests/api/crud/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,126 @@ def test_run_abortion_failure(self, db: sqlalchemy.orm.Session):
run = server.api.crud.Runs().get_run(db, run_uid, 0, project)
assert run["status"]["state"] == mlrun.runtimes.constants.RunStates.error
assert run["status"]["error"] == "Failed to abort run, error: BOOM"

def test_store_run_strip_artifacts_metadata(self, db: sqlalchemy.orm.Session):
project = "project-name"
run_uid = str(uuid.uuid4())
server.api.crud.Runs().store_run(
db,
{
"metadata": {
"name": "run-name",
"labels": {
"kind": "job",
},
},
"status": {
"artifact_uris": {
"key1": "this should be replaced",
"key2": "store://artifacts/project-name/db_key2@tree2",
},
"artifacts": [
{
"metadata": {
"key": "key1",
"tree": "tree1",
"uid": "uid1",
"project": project,
},
"spec": {
"db_key": "db_key1",
},
},
{
"metadata": {
"key": "key3",
"tree": "tree3",
"uid": "uid3",
"project": project,
},
"spec": {
"db_key": "db_key3",
},
},
],
},
},
run_uid,
project=project,
)

run = server.api.crud.Runs().get_run(db, run_uid, 0, project)
assert "artifacts" not in run["status"]
assert run["status"]["artifact_uris"] == {
"key1": "store://artifacts/project-name/db_key1@tree1",
"key2": "store://artifacts/project-name/db_key2@tree2",
"key3": "store://artifacts/project-name/db_key3@tree3",
}

def test_update_run_strip_artifacts_metadata(self, db: sqlalchemy.orm.Session):
project = "project-name"
run_uid = str(uuid.uuid4())
server.api.crud.Runs().store_run(
db,
{
"metadata": {
"name": "run-name",
"labels": {
"kind": "job",
},
},
"status": {
"artifact_uris": {
"key1": "this should be replaced",
"key2": "store://artifacts/project-name/db_key2@tree2",
},
},
},
run_uid,
project=project,
)

server.api.crud.Runs().update_run(
db,
project,
run_uid,
iter=0,
data={
"status.artifact_uris": {
"key1": "this should be replaced",
"key2": "store://artifacts/project-name/db_key2@tree2",
},
"status.artifacts": [
{
"metadata": {
"key": "key1",
"tree": "tree1",
"uid": "uid1",
"project": project,
},
"spec": {
"db_key": "db_key1",
},
},
{
"metadata": {
"key": "key3",
"tree": "tree3",
"uid": "uid3",
"project": project,
},
"spec": {
"db_key": "db_key3",
},
},
],
},
)

run = server.api.crud.Runs().get_run(db, run_uid, 0, project)
assert "artifacts" not in run["status"]
assert run["status"]["artifact_uris"] == {
"key1": "store://artifacts/project-name/db_key1@tree1",
"key2": "store://artifacts/project-name/db_key2@tree2",
"key3": "store://artifacts/project-name/db_key3@tree3",
}

0 comments on commit a54892f

Please sign in to comment.