Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed

- `DataCube.sar_backscatter()`: add corresponding band names to metadata when enabling "mask", "contributing_area", "local_incidence_angle" or "ellipsoid_incidence_angle" ([#804](https://github.com/Open-EO/openeo-python-client/issues/804))
- Proactively refresh access/bearer token in `MultiBackendJobManager` before launching a job start thread ([#817](https://github.com/Open-EO/openeo-python-client/issues/817))


## [0.45.0] - 2025-09-17
Expand Down
19 changes: 19 additions & 0 deletions openeo/extra/job_management/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,8 @@ def __init__(
)
self._thread = None
self._worker_pool = None
# Generic cache
self._cache = {}

def add_backend(
self,
Expand Down Expand Up @@ -650,6 +652,8 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
# start job if not yet done by callback
try:
job_con = job.connection
# Proactively refresh bearer token (because task in thread will not be able to do that)
self._refresh_bearer_token(connection=job_con)
task = _JobStartTask(
root_url=job_con.root_url,
bearer_token=job_con.auth.bearer if isinstance(job_con.auth, BearerAuth) else None,
Expand All @@ -670,6 +674,21 @@ def _launch_job(self, start_job, df, i, backend_name, stats: Optional[dict] = No
df.loc[i, "status"] = "skipped"
stats["start_job skipped"] += 1

def _refresh_bearer_token(self, connection: Connection, *, max_age: float = 60) -> None:
"""
Helper to proactively refresh the bearer (access) token of the connection
(but not too often, based on `max_age`).
"""
# TODO: be smarter about timing, e.g. by inspecting expiry of current token?
now = time.time()
key = f"connection:{id(connection)}:refresh-time"
if self._cache.get(key, 0) + max_age < now:
refreshed = connection.try_access_token_refresh()
if refreshed:
self._cache[key] = now
else:
_log.warning("Failed to proactively refresh bearer token")

def _process_threadworker_updates(
self,
worker_pool: _JobManagerWorkerThreadPool,
Expand Down
16 changes: 16 additions & 0 deletions openeo/rest/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,22 @@ def at_url(cls, root_url: str, *, requests_mock, capabilities: Optional[dict] =
connection = Connection(root_url)
return cls(requests_mock=requests_mock, connection=connection)

def setup_credentials_oidc(self, *, issuer: str = "https://oidc.test", id: str = "oi"):
self._requests_mock.get(
self.connection.build_url("/credentials/oidc"),
json={
"providers": [
{
"id": id,
"issuer": issuer,
"title": id,
"scopes": ["openid"],
}
]
},
)
return self

def setup_collection(
self,
collection_id: str,
Expand Down
5 changes: 5 additions & 0 deletions openeo/rest/auth/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def token_callback_resource_owner_password_credentials(self, params: dict, conte
assert params["scope"] == self.expected_fields["scope"]
return self._build_token_response()

def token_callback_block_400(self, params: dict, context):
"""Failing callback with 400 Bad Request"""
context.status_code = 400
return "block_400"

def device_code_callback(self, request: requests_mock.request._RequestObjectProxy, context):
params = self._get_query_params(query=request.text)
assert params["client_id"] == self.expected_client_id
Expand Down
85 changes: 56 additions & 29 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,28 +342,32 @@ def _authenticate_oidc(
*,
provider_id: str,
store_refresh_token: bool = False,
fallback_refresh_token_to_store: Optional[str] = None,
auto_renew_from_refresh_token: bool = False,
fallback_refresh_token: Optional[str] = None,
oidc_auth_renewer: Optional[OidcAuthenticator] = None,
) -> Connection:
"""
Authenticate through OIDC and set up bearer token (based on OIDC access_token) for further requests.
"""
tokens = authenticator.get_tokens(request_refresh_token=store_refresh_token)
request_refresh_token = store_refresh_token or (not oidc_auth_renewer and auto_renew_from_refresh_token)
tokens = authenticator.get_tokens(request_refresh_token=request_refresh_token)
_log.info("Obtained tokens: {t}".format(t=[k for k, v in tokens._asdict().items() if v]))

refresh_token = tokens.refresh_token or fallback_refresh_token
if store_refresh_token:
refresh_token = tokens.refresh_token or fallback_refresh_token_to_store
if refresh_token:
self._get_refresh_token_store().set_refresh_token(
issuer=authenticator.provider_info.issuer,
client_id=authenticator.client_id,
refresh_token=refresh_token
)
if not oidc_auth_renewer:
oidc_auth_renewer = OidcRefreshTokenAuthenticator(
client_info=authenticator.client_info, refresh_token=refresh_token
)
else:
_log.warning("No OIDC refresh token to store.")
if not oidc_auth_renewer and auto_renew_from_refresh_token and refresh_token:
oidc_auth_renewer = OidcRefreshTokenAuthenticator(
client_info=authenticator.client_info, refresh_token=refresh_token
)

token = tokens.access_token
self.auth = OidcBearerAuth(provider_id=provider_id, access_token=token)
self._oidc_auth_renewer = oidc_auth_renewer
Expand Down Expand Up @@ -452,7 +456,12 @@ def authenticate_oidc_resource_owner_password_credentials(
authenticator = OidcResourceOwnerPasswordAuthenticator(
client_info=client_info, username=username, password=password
)
return self._authenticate_oidc(authenticator, provider_id=provider_id, store_refresh_token=store_refresh_token)
return self._authenticate_oidc(
authenticator,
provider_id=provider_id,
store_refresh_token=store_refresh_token,
oidc_auth_renewer=authenticator,
)

def authenticate_oidc_refresh_token(
self,
Expand Down Expand Up @@ -493,7 +502,7 @@ def authenticate_oidc_refresh_token(
authenticator,
provider_id=provider_id,
store_refresh_token=store_refresh_token,
fallback_refresh_token_to_store=refresh_token,
fallback_refresh_token=refresh_token,
oidc_auth_renewer=authenticator,
)

Expand Down Expand Up @@ -534,7 +543,13 @@ def authenticate_oidc_device(
authenticator = OidcDeviceAuthenticator(
client_info=client_info, use_pkce=use_pkce, max_poll_time=max_poll_time, **kwargs
)
return self._authenticate_oidc(authenticator, provider_id=provider_id, store_refresh_token=store_refresh_token)
return self._authenticate_oidc(
authenticator,
provider_id=provider_id,
store_refresh_token=store_refresh_token,
# TODO: expose `auto_renew_from_refresh_token` directly as option instead of reusing `store_refresh_token` arg?
auto_renew_from_refresh_token=store_refresh_token,
)

def authenticate_oidc(
self,
Expand Down Expand Up @@ -604,7 +619,8 @@ def authenticate_oidc(
authenticator,
provider_id=provider_id,
store_refresh_token=store_refresh_token,
fallback_refresh_token_to_store=refresh_token,
fallback_refresh_token=refresh_token,
oidc_auth_renewer=authenticator,
)
# TODO: pluggable/jupyter-aware display function?
print("Authenticated using refresh token.")
Expand All @@ -622,6 +638,8 @@ def authenticate_oidc(
authenticator,
provider_id=provider_id,
store_refresh_token=store_refresh_token,
# TODO: expose `auto_renew_from_refresh_token` directly as option instead of reusing `store_refresh_token` arg?
auto_renew_from_refresh_token=store_refresh_token,
)
print("Authenticated using device code flow.")
return con
Expand Down Expand Up @@ -665,6 +683,28 @@ def authenticate_bearer_token(self, bearer_token: str) -> Connection:
self._oidc_auth_renewer = None
return self

def try_access_token_refresh(self, *, reason: Optional[str] = None) -> bool:
"""
Try to get a fresh access token if possible.
Returns whether a new access token was obtained.
"""
reason = f" Reason: {reason}" if reason else ""
if isinstance(self.auth, OidcBearerAuth) and self._oidc_auth_renewer:
try:
self._authenticate_oidc(
authenticator=self._oidc_auth_renewer,
provider_id=self._oidc_auth_renewer.provider_info.id,
store_refresh_token=False,
oidc_auth_renewer=self._oidc_auth_renewer,
)
_log.info(f"Obtained new access token (grant {self._oidc_auth_renewer.grant_type!r}).{reason}")
return True
except OpenEoClientException as auth_exc:
_log.error(
f"Failed to obtain new access token (grant {self._oidc_auth_renewer.grant_type!r}): {auth_exc!r}.{reason}"
)
return False

def request(
self,
method: str,
Expand All @@ -690,24 +730,11 @@ def _request():
api_exc.http_status_code in {HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN}
and api_exc.code == "TokenInvalid"
):
# Auth token expired: can we refresh?
if isinstance(self.auth, OidcBearerAuth) and self._oidc_auth_renewer:
msg = f"OIDC access token expired ({api_exc.http_status_code} {api_exc.code})."
try:
self._authenticate_oidc(
authenticator=self._oidc_auth_renewer,
provider_id=self._oidc_auth_renewer.provider_info.id,
store_refresh_token=False,
oidc_auth_renewer=self._oidc_auth_renewer,
)
_log.info(f"{msg} Obtained new access token (grant {self._oidc_auth_renewer.grant_type!r}).")
except OpenEoClientException as auth_exc:
_log.error(
f"{msg} Failed to obtain new access token (grant {self._oidc_auth_renewer.grant_type!r}): {auth_exc!r}."
)
else:
# Retry request.
return _request()
# Retry if we can refresh the access token
if self.try_access_token_refresh(
reason=f"OIDC access token expired ({api_exc.http_status_code} {api_exc.code})."
):
return _request()
raise

def describe_account(self) -> dict:
Expand Down
49 changes: 48 additions & 1 deletion tests/extra/job_management/test_job_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
_TaskResult,
)
from openeo.rest._testing import OPENEO_BACKEND, DummyBackend, build_capabilities
from openeo.rest.auth.testing import OidcMock
from openeo.util import rfc3339
from openeo.utils.version import ComparableVersion

Expand Down Expand Up @@ -269,7 +270,7 @@ def test_create_job_db(self, tmp_path, job_manager, job_manager_root_dir, sleep_
assert set(result.status) == {"finished"}
assert set(result.backend_name) == {"foo", "bar"}

def test_basic_threading(self, tmp_path, job_manager, job_manager_root_dir, sleep_mock):
def test_start_job_thread_basic(self, tmp_path, job_manager, job_manager_root_dir, sleep_mock):
df = pd.DataFrame(
{
"year": [2018, 2019, 2020, 2021, 2022],
Expand Down Expand Up @@ -868,6 +869,52 @@ def execute(self):
assert any("Skipping invalid db_update" in msg for msg in caplog.messages)
assert any("Skipping invalid stats_update" in msg for msg in caplog.messages)

def test_refresh_bearer_token_before_start(
self,
tmp_path,
job_manager,
dummy_backend_foo,
dummy_backend_bar,
job_manager_root_dir,
sleep_mock,
requests_mock,
):

client_id = "client123"
client_secret = "$3cr3t"
oidc_issuer = "https://oidc.test/"
oidc_mock = OidcMock(
requests_mock=requests_mock,
expected_grant_type="client_credentials",
expected_client_id=client_id,
expected_fields={"client_secret": client_secret, "scope": "openid"},
oidc_issuer=oidc_issuer,
)
dummy_backend_foo.setup_credentials_oidc(issuer=oidc_issuer)
dummy_backend_bar.setup_credentials_oidc(issuer=oidc_issuer)
dummy_backend_foo.connection.authenticate_oidc_client_credentials(client_id="client123", client_secret="$3cr3t")
dummy_backend_bar.connection.authenticate_oidc_client_credentials(client_id="client123", client_secret="$3cr3t")

# After this setup, we have 2 client credential token requests (one for each backend)
assert len(oidc_mock.grant_request_history) == 2

df = pd.DataFrame({"year": [2020, 2021, 2022, 2023, 2024]})
job_db_path = tmp_path / "jobs.csv"
job_db = CsvJobDatabase(job_db_path).initialize_from_df(df)
run_stats = job_manager.run_jobs(job_db=job_db, start_job=self._create_year_job)

assert run_stats == dirty_equals.IsPartialDict(
{
"job_queued_for_start": 5,
"job started running": 5,
"job finished": 5,
}
)

# Because of proactive+throttled token refreshing,
# we should have 2 additional token requests now
assert len(oidc_mock.grant_request_history) == 4


JOB_DB_DF_BASICS = pd.DataFrame(
{
Expand Down
Loading
Loading