Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions openeo/extra/artifacts/_s3sts/sts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from __future__ import annotations

import logging
import time
from random import randint
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -12,12 +15,16 @@
from openeo.rest.connection import Connection
from openeo.util import Rfc3339

_log = logging.getLogger(__name__)


class OpenEOSTSClient:
_MAX_STS_ATTEMPTS = 3

def __init__(self, config: S3STSConfig):
self.config = config

def assume_from_openeo_connection(self, connection: Connection) -> AWSSTSCredentials:
def assume_from_openeo_connection(self, connection: Connection, attempt: int = 0) -> AWSSTSCredentials:
"""
Takes an OpenEO connection object and returns temporary credentials to interact with S3
"""
Expand All @@ -27,14 +34,31 @@ def assume_from_openeo_connection(self, connection: Connection) -> AWSSTSCredent
raise ProviderSpecificException("Only connections that have BearerAuth can be used.")
auth_token = auth.bearer.split("/")

return AWSSTSCredentials.from_assume_role_response(
self._get_sts_client().assume_role_with_web_identity(
RoleArn=self._get_aws_access_role(),
RoleSessionName=f"artifact-helper-{Rfc3339().now_utc()}",
WebIdentityToken=auth_token[2],
DurationSeconds=43200,
try:
# Do an API call with OpenEO to trigger a refresh of our token if it were stale.
connection.describe_account()
return AWSSTSCredentials.from_assume_role_response(
self._get_sts_client().assume_role_with_web_identity(
RoleArn=self._get_aws_access_role(),
RoleSessionName=f"artifact-helper-{Rfc3339().now_utc()}",
WebIdentityToken=auth_token[2],
DurationSeconds=43200,
)
)
)
except Exception as e:
_log.warning("Failed to get credentials for STS access")

if attempt < self._MAX_STS_ATTEMPTS:
# backoff with jitter
max_sleep_ms = 500 * (2**attempt)
sleep_ms = randint(0, max_sleep_ms)
_log.info(f"Retrying STS access in {sleep_ms} ms")
time.sleep(sleep_ms / 1000.0)
attempt += 1
_log.info(f"Retrying to get credentials for STS access {attempt}/{self._MAX_STS_ATTEMPTS}")
return self.assume_from_openeo_connection(connection, attempt)
else:
raise RuntimeError("Could not get credentials from STS") from e

def _get_sts_client(self) -> STSClient:
return self.config.build_client("sts")
Expand Down
1 change: 1 addition & 0 deletions tests/extra/artifacts/_s3sts/test_s3sts.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def conn_with_s3sts_capabilities(
requests_mock, extra_api_capabilities, advertised_s3sts_config
) -> Iterator[Connection]:
requests_mock.get(API_URL, json={"api_version": "1.0.0", **extra_api_capabilities})
requests_mock.get(f"{API_URL}me", json={})
conn = Connection(API_URL)
conn.auth = BearerAuth("oidc/fake/token")
yield conn
Expand Down
Loading