Skip to content

Commit

Permalink
feat(dataset): store s3 credentials per bucket (#3339)
Browse files Browse the repository at this point in the history
  • Loading branch information
m-alisafaee committed Mar 10, 2023
1 parent a685c52 commit 717a780
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 13 deletions.
23 changes: 15 additions & 8 deletions renku/core/dataset/providers/s3.py
Expand Up @@ -145,22 +145,29 @@ def get_credentials_section_name(self) -> str:
NOTE: This method should be overridden by subclasses to allow multiple credentials per provider if needed.
"""
return self.provider.endpoint.lower()
return f"{self.provider.bucket}.{self.provider.endpoint.lower()}"


def parse_s3_uri(uri: str) -> Tuple[str, str, str]:
    """Extract hostname, bucket name, and path within the bucket from a given URI.

    NOTE: We only support s3://<hostname>/<bucket-name>/<path> at the moment.

    Args:
        uri(str): The S3 URI to parse.

    Returns:
        Tuple[str, str, str]: A ``(hostname, bucket, path)`` tuple; ``path`` has no leading or trailing slashes
            and is empty when the URI points at the bucket itself.

    Raises:
        errors.ParameterError: If the scheme is not ``s3`` (case-insensitive), or if the hostname or the bucket
            name is missing.
    """
    parsed_uri = urllib.parse.urlparse(uri)

    hostname = parsed_uri.netloc
    # Strip surrounding slashes first so that "s3://host///bucket/..." still yields "bucket" as the first segment.
    path = parsed_uri.path.strip("/")
    bucket, _, path = path.partition("/")

    # Validate everything before returning; each failure gets its own message so the user knows what to fix.
    if parsed_uri.scheme.lower() != "s3":
        raise errors.ParameterError(f"Invalid S3 scheme: {uri}. Valid format is 's3://<hostname>/<bucket-name>/<path>'")
    if not hostname:
        raise errors.ParameterError(
            f"Hostname is missing in S3 URI: {uri}. Valid format is 's3://<hostname>/<bucket-name>/<path>'"
        )
    if not bucket:
        raise errors.ParameterError(
            f"Bucket name is missing in S3 URI: {uri}. Valid format is 's3://<hostname>/<bucket-name>/<path>'"
        )

    return hostname, bucket, path.strip("/")
5 changes: 4 additions & 1 deletion tests/cli/fixtures/cli_providers.py
Expand Up @@ -101,7 +101,10 @@ def cloud_storage_credentials(project):
# S3
s3_access_key_id = os.getenv("CLOUD_STORAGE_S3_ACCESS_KEY_ID", "")
s3_secret_access_key = os.getenv("CLOUD_STORAGE_S3_SECRET_ACCESS_KEY", "")
s3_section = "os.zhdk.cloud.switch.ch"
s3_section = "renku-python-test-public.os.zhdk.cloud.switch.ch"
set_value(section=s3_section, key="access-key-id", value=s3_access_key_id, global_only=True)
set_value(section=s3_section, key="secret-access-key", value=s3_secret_access_key, global_only=True)
s3_section = "renku-python-integration-test.os.zhdk.cloud.switch.ch"
set_value(section=s3_section, key="access-key-id", value=s3_access_key_id, global_only=True)
set_value(section=s3_section, key="secret-access-key", value=s3_secret_access_key, global_only=True)

Expand Down
40 changes: 36 additions & 4 deletions tests/core/test_dataset.py
Expand Up @@ -21,16 +21,16 @@
import pytest

from renku.core import errors
from renku.core.config import get_value
from renku.core.dataset.dataset_add import get_dataset_file_path_within_dataset
from renku.core.dataset.providers.s3 import parse_s3_uri
from renku.core.dataset.providers.s3 import S3Credentials, S3Provider, parse_s3_uri
from renku.domain_model.dataset import Dataset
from renku.domain_model.enums import ConfigFilter


@pytest.mark.parametrize(
"uri, endpoint, bucket, path",
[
("s3://no.bucket.path/", "no.bucket.path", "", ""),
("s3://no.bucket.path///", "no.bucket.path", "", ""),
("s3://no.path/bucket/", "no.path", "bucket", ""),
("S3://uppercase.scheme/bucket/path", "uppercase.scheme", "bucket", "path"),
("s3://slashes.are.stripped///bucket///path/to/data//", "slashes.are.stripped", "bucket", "path/to/data"),
Expand All @@ -45,13 +45,45 @@ def test_valid_s3_uri(uri, endpoint, bucket, path):
assert path == parsed_path


@pytest.mark.parametrize(
    "uri",
    [
        "https://invalid.scheme/bucket/",
        "s3:no-endpoint/bucket/path",
        "s3://no.bucket.path/",
        "s3://no.bucket.path///",
    ],
)
def test_invalid_s3_uri(uri):
    """Test that an invalid S3 URI raises an error.

    Covers a non-S3 scheme, a URI without a network location, and URIs that name no bucket.
    """
    with pytest.raises(errors.ParameterError):
        parse_s3_uri(uri=uri)


def test_s3_credential_is_per_bucket(project):
    """Test that S3 credentials for two buckets on the same host are stored under separate sections.

    NOTE(review): the ``project`` fixture is not used in the body; presumably it provides the project/config
    context that the credential store needs -- confirm against the fixture definition.
    """
    host = "s3.host"

    first_provider = S3Provider(uri=f"s3://{host}/bucket-1/")
    first_credentials = S3Credentials(first_provider)
    first_credentials["access-key-id"] = "id-1"
    first_credentials["secret-access-key"] = "key-1"

    second_provider = S3Provider(uri=f"s3://{host}/bucket-2/")
    second_credentials = S3Credentials(second_provider)
    second_credentials["access-key-id"] = "id-2"
    second_credentials["secret-access-key"] = "key-2"

    # Store both only after both are populated, in the same order as they were created.
    for credentials in (first_credentials, second_credentials):
        credentials.store()

    # Each bucket's access key must be readable from its own "<bucket>.<host>" section of the global config.
    stored_first_id = get_value(
        section=f"bucket-1.{host}", key="access-key-id", config_filter=ConfigFilter.GLOBAL_ONLY
    )
    assert "id-1" == stored_first_id

    stored_second_id = get_value(
        section=f"bucket-2.{host}", key="access-key-id", config_filter=ConfigFilter.GLOBAL_ONLY
    )
    assert "id-2" == stored_second_id


@pytest.mark.parametrize(
"entity_path, within_dataset_path",
[
Expand Down

0 comments on commit 717a780

Please sign in to comment.