Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Path resolution fixes for DatabricksArtifactRepository #4

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions mlflow/store/artifact/databricks_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, INTERNAL_ERROR
from mlflow.protos.databricks_artifacts_pb2 import DatabricksMlflowArtifactsService, \
GetCredentialsForWrite, GetCredentialsForRead, ArtifactCredentialType
from mlflow.protos.service_pb2 import MlflowService, ListArtifacts
from mlflow.protos.service_pb2 import MlflowService, GetRun, ListArtifacts
from mlflow.store.artifact.artifact_repo import ArtifactRepository
from mlflow.utils.databricks_utils import get_databricks_host_creds
from mlflow.utils.file_utils import relative_path_to_artifact_path, yield_file_in_chunks
Expand Down Expand Up @@ -54,6 +54,17 @@ def __init__(self, artifact_uri):
error_code=INVALID_PARAMETER_VALUE)
self.run_id = self._extract_run_id(self.artifact_uri)

# Fetch the artifact root for the MLflow Run associated with `artifact_uri` and compute
# the path of `artifact_uri` relative to the MLflow Run's artifact root
# (the `run_relative_artifact_repo_root_path`). All operations performed on this artifact
# repository will be performed relative to this computed location
artifact_repo_root_path = extract_and_normalize_path(artifact_uri)
run_artifact_root_uri = self._get_run_artifact_root(self.run_id)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We fetch the actual run artifact root from the MLflow Tracking Service because it seems somewhat brittle to assume that the root is exactly `dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/artifacts`. We must still assume that the root is located at or beneath `dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>`, but we no longer need to assume that it ends in the `artifacts` subdirectory: the root may just as well live in some other subdirectory (e.g., `dbfs:/databricks/mlflow-tracking/<EXP_ID>/<RUN_ID>/my/other/awesome/root`).

run_artifact_root_path = extract_and_normalize_path(run_artifact_root_uri)
self.run_relative_artifact_repo_root_path = posixpath.relpath(
path=artifact_repo_root_path, start=run_artifact_root_path
)

@staticmethod
def _extract_run_id(artifact_uri):
"""
Expand All @@ -76,6 +87,12 @@ def _call_endpoint(self, service, api, json_body):
return call_endpoint(get_databricks_host_creds(),
endpoint, method, json_body, response_proto)

def _get_run_artifact_root(self, run_id):
    """
    Look up the artifact URI recorded for an MLflow Run.

    Serializes a ``GetRun`` request for ``run_id``, sends it to the
    ``MlflowService`` ``GetRun`` endpoint via ``self._call_endpoint``, and
    reads the artifact URI from the returned run's info.

    :param run_id: ID of the run whose artifact URI should be fetched.
    :return: The ``run.info.artifact_uri`` field of the ``GetRun`` response.
    """
    get_run_request = message_to_json(GetRun(run_id=run_id))
    response = self._call_endpoint(MlflowService, GetRun, get_run_request)
    return response.run.info.artifact_uri

def _get_write_credentials(self, run_id, path=None):
json_body = message_to_json(GetCredentialsForWrite(run_id=run_id, path=path))
return self._call_endpoint(DatabricksMlflowArtifactsService,
Expand Down Expand Up @@ -187,8 +204,13 @@ def log_artifact(self, local_file, artifact_path=None):
basename = os.path.basename(local_file)
artifact_path = artifact_path or ""
artifact_path = posixpath.join(artifact_path, basename)
write_credentials = self._get_write_credentials(self.run_id, artifact_path)
self._upload_to_cloud(write_credentials, local_file, artifact_path)
if len(artifact_path) > 0:
run_relative_artifact_path = posixpath.join(
self.run_relative_artifact_repo_root_path, artifact_path)
else:
run_relative_artifact_path = self.run_relative_artifact_repo_root_path
write_credentials = self._get_write_credentials(self.run_id, run_relative_artifact_path)
self._upload_to_cloud(write_credentials, local_file, run_relative_artifact_path)

def log_artifacts(self, local_dir, artifact_path=None):
artifact_path = artifact_path or ""
Expand All @@ -203,7 +225,12 @@ def log_artifacts(self, local_dir, artifact_path=None):
self.log_artifact(file_path, artifact_subdir)

def list_artifacts(self, path=None):
json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=path))
if path:
run_relative_path = posixpath.join(
self.run_relative_artifact_repo_root_path, path)
else:
run_relative_path = self.run_relative_artifact_repo_root_path
json_body = message_to_json(ListArtifacts(run_id=self.run_id, path=run_relative_path))
artifact_list = self._call_endpoint(MlflowService, ListArtifacts, json_body).files
# If `path` is a file, ListArtifacts returns a single list element with the
# same name as `path`. The list_artifacts API expects us to return an empty list in this
Expand All @@ -212,13 +239,17 @@ def list_artifacts(self, path=None):
and not artifact_list[0].is_dir:
return []
infos = list()
for file in artifact_list:
artifact_size = None if file.is_dir else file.file_size
infos.append(FileInfo(file.path, file.is_dir, artifact_size))
for output_file in artifact_list:
file_rel_path = posixpath.relpath(
path=output_file.path, start=self.run_relative_artifact_repo_root_path)
artifact_size = None if output_file.is_dir else output_file.file_size
infos.append(FileInfo(file_rel_path, output_file.is_dir, artifact_size))
return infos

def _download_file(self, remote_file_path, local_path):
read_credentials = self._get_read_credentials(self.run_id, remote_file_path)
run_relative_remote_file_path = posixpath.join(
self.run_relative_artifact_repo_root_path, remote_file_path)
read_credentials = self._get_read_credentials(self.run_id, run_relative_remote_file_path)
self._download_from_cloud(read_credentials.credentials, local_path)

def delete_artifacts(self, artifact_path=None):
Expand Down