Skip to content

Commit

Permalink
feat(dataset): support non-AWS S3 URI (#3159)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-alisafaee committed Oct 19, 2022
1 parent f9be486 commit b81bbe5
Show file tree
Hide file tree
Showing 6 changed files with 181 additions and 83 deletions.
54 changes: 43 additions & 11 deletions renku/core/dataset/providers/s3.py
Expand Up @@ -20,7 +20,7 @@
import re
import urllib
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple, cast

from renku.core import errors
from renku.core.dataset.providers.api import ProviderApi, ProviderCredentials, ProviderPriority
Expand All @@ -43,8 +43,11 @@ class S3Provider(ProviderApi):

def __init__(self, uri: Optional[str]):
    """Create an S3 provider for ``uri``.

    Args:
        uri: An S3 URI of the form ``s3://<endpoint>/<bucket-name>/<path>``.

    Raises:
        errors.ParameterError: If ``uri`` is not a valid S3 URI (raised by
            ``parse_s3_uri``).
    """
    super().__init__(uri=uri)

    # NOTE: Parse eagerly so that an invalid URI fails at construction time
    # rather than on first use. The endpoint is kept to support non-AWS S3.
    endpoint, bucket, _ = parse_s3_uri(uri=self.uri)

    self._bucket: str = bucket
    self._endpoint: str = endpoint

@staticmethod
def supports(uri: str) -> bool:
Expand Down Expand Up @@ -100,16 +103,16 @@ def add(uri: str, destination: Path, **kwargs) -> List["DatasetAddMetadata"]:
if not storage.exists(uri):
raise errors.ParameterError(f"S3 bucket '{uri}' doesn't exists.")

repository = project_context.repository
destination_path_in_repo = Path(destination).relative_to(project_context.repository.path)
hashes = storage.get_hashes(uri=uri)
return [
DatasetAddMetadata(
entity_path=Path(destination).relative_to(repository.path) / hash.path,
entity_path=destination_path_in_repo / hash.path,
url=hash.base_uri,
action=DatasetAddAction.NONE,
based_on=RemoteEntity(checksum=hash.hash if hash.hash else "", url=hash.base_uri, path=hash.path),
source=Path(hash.full_uri),
destination=Path(destination).relative_to(repository.path),
destination=destination_path_in_repo,
gitignored=True,
)
for hash in hashes
Expand All @@ -120,6 +123,11 @@ def bucket(self) -> str:
"""Return S3 bucket name."""
return self._bucket

@property
def endpoint(self) -> str:
    """Return the S3 endpoint host extracted from the provider's URI."""
    return self._endpoint

def on_create(self, dataset: "Dataset") -> None:
"""Hook to perform provider-specific actions on a newly-created dataset."""
credentials = S3Credentials(provider=self)
Expand All @@ -145,15 +153,39 @@ def get_credentials_names() -> Tuple[str, ...]:
"""Return a tuple of the required credentials for a provider."""
return "Access Key ID", "Secret Access Key"

@property
def provider(self) -> S3Provider:
    """Return the associated provider instance.

    NOTE: Narrows the base class's provider reference to ``S3Provider`` so
    that S3-specific attributes (e.g. ``endpoint``) type-check for callers.
    """
    return cast(S3Provider, self._provider)

def get_credentials_section_name(self) -> str:
    """Get section name for storing credentials.

    NOTE: This method should be overridden by subclasses to allow multiple
    credentials per provider if needed.

    Returns:
        The lower-cased endpoint host, so credentials are stored per S3
        endpoint rather than per full URI.
    """
    return self.provider.endpoint.lower()

def extract_bucket_and_path(uri: str) -> Tuple[str, str]:
"""Extract bucket name and path within the bucket from a given URI.

NOTE: We only support s3://<bucket-name>/<path> at the moment.
def create_renku_s3_uri(uri: str) -> str:
    """Rewrite a full ``s3://<endpoint>/<bucket>/<path>`` URI into the
    ``s3://<bucket>/<path>`` form that Renku works with internally.
    """
    _, bucket_name, path_in_bucket = parse_s3_uri(uri=uri)
    return f"s3://{bucket_name}/{path_in_bucket}"


def parse_s3_uri(uri: str) -> Tuple[str, str, str]:
    """Extract endpoint, bucket name, and path within the bucket from a given URI.

    NOTE: We only support s3://<endpoint>/<bucket-name>/<path> at the moment.

    Args:
        uri: The S3 URI to parse.

    Returns:
        A ``(endpoint, bucket, path)`` tuple; ``path`` may be empty when the
        URI points at a bucket root with a trailing component only.

    Raises:
        errors.ParameterError: If the scheme is not ``s3`` or endpoint/path
            are missing.
    """
    parsed_uri = urllib.parse.urlparse(uri)

    endpoint = parsed_uri.netloc
    path = parsed_uri.path.strip("/")

    if parsed_uri.scheme.lower() != "s3" or not endpoint or not path:
        raise errors.ParameterError(f"Invalid S3 URI: {uri}. Valid format is 's3://<endpoint>/<bucket-name>/<path>'")

    # The first path component is the bucket; the remainder is the key/path.
    bucket, _, path = path.partition("/")

    return endpoint, bucket, path.strip("/")
16 changes: 10 additions & 6 deletions renku/core/interface/storage.py
Expand Up @@ -20,7 +20,7 @@
import abc
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

if TYPE_CHECKING:
from renku.core.dataset.providers.api import ProviderApi, ProviderCredentials
Expand Down Expand Up @@ -70,15 +70,20 @@ def provider(self) -> "ProviderApi":
return self._provider

@abc.abstractmethod
def copy(self, uri: str, destination: Union[Path, str]) -> None:
    """Copy data from ``uri`` to ``destination``.

    Args:
        uri: Source URI on the remote storage.
        destination: Local path to copy the data to.
    """
    raise NotImplementedError

@abc.abstractmethod
def exists(self, uri: str) -> bool:
    """Checks if a remote storage URI exists.

    Args:
        uri: Remote storage URI to check.

    Returns:
        True if the URI exists on the remote storage, False otherwise.
    """
    raise NotImplementedError

@abc.abstractmethod
def get_configurations(self) -> Dict[str, str]:
    """Get required configurations to access the storage.

    Returns:
        A mapping of configuration names to values needed by the concrete
        storage backend.
    """
    raise NotImplementedError

@abc.abstractmethod
def get_hashes(self, uri: str, hash_type: str = "md5") -> List[FileHash]:
"""Get the hashes of all files at the uri."""
Expand All @@ -90,6 +95,5 @@ def mount(self, path: Union[Path, str]) -> None:
raise NotImplementedError

@abc.abstractmethod
def run_rclone_command(self, command: str, uri: str, *args, **kwargs) -> Any:
    """Run an RClone command, possibly adding storage-specific flags.

    Args:
        command: The RClone subcommand to execute (e.g. ``lsjson``).
        uri: The storage URI the command operates on.

    Returns:
        The command's output.
    """
    # NOTE(review): sibling abstract methods all end with this raise; the
    # rendered diff appears to have dropped it here — restored for consistency.
    raise NotImplementedError
50 changes: 24 additions & 26 deletions renku/infrastructure/storage/base.py
Expand Up @@ -17,31 +17,29 @@
# limitations under the License.
"""Base storage handler."""

import abc
import json
import os
import subprocess
from pathlib import Path
from typing import Any, List, Union
from typing import Any, Dict, List, Union

from renku.core import errors
from renku.core.interface.storage import FileHash, IStorage
from renku.core.util.util import NO_VALUE


class RCloneBaseStorage(IStorage):
class RCloneBaseStorage(IStorage, abc.ABC):
"""Base external storage handler class."""

def copy(self, uri: str, destination: Union[Path, str]) -> None:
    """Copy data from ``uri`` to ``destination``.

    Args:
        uri: Source URI on the remote storage.
        destination: Local path to copy to.
    """
    # ``copyto`` lets rclone copy to an explicit destination path.
    self.run_rclone_command("copyto", uri, str(destination))

def exists(self, uri: str) -> bool:
"""Checks if a remote storage URI exists."""
self.set_configurations()

try:
execute_rclone_command("lsf", uri, max_depth=1)
self.run_rclone_command("lsf", uri=uri, max_depth=1)
except errors.StorageObjectNotFound:
return False
else:
Expand All @@ -67,9 +65,7 @@ def get_hashes(self, uri: str, hash_type: str = "md5") -> List[FileHash]:
}
]
"""
self.set_configurations()

hashes_raw = execute_rclone_command("lsjson", "--hash", "-R", "--files-only", uri)
hashes_raw = self.run_rclone_command("lsjson", uri, hash=True, R=True, files_only=True)
hashes = json.loads(hashes_raw)
if not hashes:
raise errors.ParameterError(f"Cannot find URI: {uri}")
Expand All @@ -90,25 +86,32 @@ def get_hashes(self, uri: str, hash_type: str = "md5") -> List[FileHash]:

def mount(self, path: Union[Path, str]) -> None:
    """Mount the provider's URI to the given path.

    Args:
        path: Local mount point.
    """
    # Mount read-only as a background daemon; flags are passed through to
    # ``rclone mount``.
    self.run_rclone_command("mount", self.provider.uri, str(path), daemon=True, read_only=True, no_modtime=True)

def get_configurations(self) -> Dict[str, str]:
    """Get required configurations for rclone to access the storage.

    Returns:
        Mapping of RClone config environment variable names to values, built
        from the provider's credentials. Credentials that were never set
        (``NO_VALUE``) are skipped.
    """
    configurations = {}
    for name, value in self.credentials.items():
        if value is not NO_VALUE:
            # Translate the credential name into rclone's RCLONE_CONFIG_* form.
            name = get_rclone_env_var_name(self.provider.name, name)
            configurations[name] = value

    return configurations


def execute_rclone_command(command: str, *args: Any, **kwargs) -> str:
def run_rclone_command(command: str, *args: Any, env=None, **kwargs) -> str:
"""Execute an R-clone command."""
os_env = os.environ.copy()
if env:
os_env.update(env)

try:
result = subprocess.run(
("rclone", command, *transform_kwargs(**kwargs), *args),
("rclone", command, *transform_kwargs(**kwargs), *transform_args(*args)),
text=True,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=os_env,
)
except FileNotFoundError:
raise errors.RCloneException("RClone is not installed. See https://rclone.org/install/")
Expand All @@ -121,7 +124,7 @@ def execute_rclone_command(command: str, *args: Any, **kwargs) -> str:
if result.returncode in (3, 4):
raise errors.StorageObjectNotFound(all_outputs)
elif "AccessDenied" in all_outputs:
raise errors.AuthenticationError("Authentication failed when accessing the remote storage")
raise errors.AuthenticationError(f"Authentication failed when accessing the remote storage: {all_outputs}")
else:
raise errors.RCloneException(f"Remote storage operation failed: {all_outputs}")

Expand Down Expand Up @@ -154,8 +157,3 @@ def get_rclone_env_var_name(provider_name, name) -> str:
# See https://rclone.org/docs/#config-file
name = name.replace(" ", "_").replace("-", "_")
return f"RCLONE_CONFIG_{provider_name}_{name}".upper()


def set_rclone_env_var(name, value) -> None:
"""Set value for an RClone config env var."""
os.environ[name] = value
32 changes: 22 additions & 10 deletions renku/infrastructure/storage/s3.py
Expand Up @@ -17,20 +17,32 @@
# limitations under the License.
"""S3 storage handler."""

from renku.infrastructure.storage.base import RCloneBaseStorage, get_rclone_env_var_name, set_rclone_env_var
from typing import Any, Dict, cast

from renku.core.dataset.providers.s3 import S3Provider, create_renku_s3_uri
from renku.infrastructure.storage.base import RCloneBaseStorage, run_rclone_command


class S3Storage(RCloneBaseStorage):
    """S3 storage handler."""

    @property
    def provider(self) -> S3Provider:
        """Return the dataset provider for this storage handler."""
        return cast(S3Provider, self._provider)

    def get_configurations(self) -> Dict[str, str]:
        """Get required configurations for rclone to access the storage.

        Extends the credential-based configuration from the base class with
        S3-specific type, provider, and endpoint settings.
        """
        configurations = super().get_configurations()

        configurations["RCLONE_CONFIG_S3_TYPE"] = "s3"
        configurations["RCLONE_CONFIG_S3_PROVIDER"] = "AWS"
        # Allows non-AWS S3 backends: the endpoint comes from the dataset URI.
        configurations["RCLONE_CONFIG_S3_ENDPOINT"] = self.provider.endpoint

        return configurations

    def run_rclone_command(self, command: str, uri: str, *args, **kwargs) -> Any:
        """Run an RClone command, adding S3-specific flags.

        The ``s3://<endpoint>/<bucket>/<path>`` URI is rewritten to the
        ``s3://<bucket>/<path>`` form; the endpoint is supplied through the
        configuration environment variables instead.
        """
        uri = create_renku_s3_uri(uri=uri)

        return run_rclone_command(command, uri, *args, **kwargs, env=self.get_configurations())

0 comments on commit b81bbe5

Please sign in to comment.