diff --git a/pyproject.toml b/pyproject.toml index 019f4102a..ea4c440d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "setuptools;python_version>='3.12'", # Python3.12 doesn't include setuptools automatically "backoff>=2.0.0", "pydantic>=1.10.0", + "PyJWT>=2.8.0" ] [project.optional-dependencies] diff --git a/src/ansys/hps/client/client.py b/src/ansys/hps/client/client.py index 7f423cb2d..32a32959b 100644 --- a/src/ansys/hps/client/client.py +++ b/src/ansys/hps/client/client.py @@ -25,11 +25,12 @@ from typing import Union import warnings +import jwt import requests from .auth.authenticate import authenticate from .connection import create_session -from .exceptions import raise_for_status +from .exceptions import HPSError, raise_for_status from .warnings import UnverifiedHTTPSRequestsWarning log = logging.getLogger(__name__) @@ -193,6 +194,25 @@ def __init__( # client credentials flow does not return a refresh token self.refresh_token = tokens.get("refresh_token", None) + parsed_username = None + + try: + parsed_jwt = jwt.decode(self.access_token, options={"verify_signature": False}) + parsed_username = parsed_jwt["preferred_username"] + except: + log.warning("Could not retrieve preferred_username from access token.") + + if parsed_username is not None: + if self.username is not None and self.username != parsed_username: + raise HPSError( + ( + f"Username: '{self.username}' and " + f"preferred_username: '{parsed_username}' " + "from access token do not match." + ) + ) + self.username = parsed_username + self.session = create_session( self.access_token, verify=self.verify, diff --git a/tests/test_client.py b/tests/test_client.py index 9d40df55e..3f88425e1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -27,6 +27,7 @@ import requests from ansys.hps.client import Client +from ansys.hps.client.exceptions import HPSError log = logging.getLogger(__name__) @@ -76,3 +77,40 @@ def test_authentication_workflows(url, username, password): assert client2.access_token is not None assert client2.refresh_token != client0.refresh_token client2.refresh_access_token() + + +def test_authentication_username(url, username, password, keycloak_client): + + # Password workflow + client0 = Client(url, username, password) + assert client0.username == username + + # Impersonation + realm_clients = keycloak_client.get_clients() + rep_impersonation_client = next( + (x for x in realm_clients if x["clientId"] == "rep-impersonation"), None + ) + assert rep_impersonation_client is not None + client1 = Client( + url=url, + client_id=rep_impersonation_client["clientId"], + client_secret=rep_impersonation_client["secret"], + ) + assert client1.username == "service-account-rep-impersonation" + + +def test_authentication_username_exception(url, username, keycloak_client): + + # Impersonation + realm_clients = keycloak_client.get_clients() + rep_impersonation_client = next( + (x for x in realm_clients if x["clientId"] == "rep-impersonation"), None + ) + assert rep_impersonation_client is not None + with pytest.raises(HPSError): + Client( + url=url, + username=username, + client_id=rep_impersonation_client["clientId"], + client_secret=rep_impersonation_client["secret"], + )