Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"setuptools;python_version>='3.12'", # Python3.12 doesn't include setuptools automatically
"backoff>=2.0.0",
"pydantic>=1.10.0",
"PyJWT>=2.8.0"
]

[project.optional-dependencies]
Expand Down
22 changes: 21 additions & 1 deletion src/ansys/hps/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@
from typing import Union
import warnings

import jwt
import requests

from .auth.authenticate import authenticate
from .connection import create_session
from .exceptions import raise_for_status
from .exceptions import HPSError, raise_for_status
from .warnings import UnverifiedHTTPSRequestsWarning

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -193,6 +194,25 @@ def __init__(
# client credentials flow does not return a refresh token
self.refresh_token = tokens.get("refresh_token", None)

parsed_username = None

try:
parsed_jwt = jwt.decode(self.access_token, options={"verify_signature": False})
parsed_username = parsed_jwt["preferred_username"]
except:
log.warning("Could not retrieve preferred_username from access token.")

if parsed_username is not None:
if self.username is not None and self.username != parsed_username:
raise HPSError(
(
f"Username: '{self.username}' and "
f"preferred_username: '{parsed_username}' "
"from access token do not match."
)
)
self.username = parsed_username

self.session = create_session(
self.access_token,
verify=self.verify,
Expand Down
38 changes: 38 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import requests

from ansys.hps.client import Client
from ansys.hps.client.exceptions import HPSError

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,3 +77,40 @@ def test_authentication_workflows(url, username, password):
assert client2.access_token is not None
assert client2.refresh_token != client0.refresh_token
client2.refresh_access_token()


def test_authentication_username(url, username, password, keycloak_client):

# Password workflow
client0 = Client(url, username, password)
assert client0.username == username

# Impersonation
realm_clients = keycloak_client.get_clients()
rep_impersonation_client = next(
(x for x in realm_clients if x["clientId"] == "rep-impersonation"), None
)
assert rep_impersonation_client is not None
client1 = Client(
url=url,
client_id=rep_impersonation_client["clientId"],
client_secret=rep_impersonation_client["secret"],
)
assert client1.username == "service-account-rep-impersonation"


def test_authentication_username_exception(url, username, keycloak_client):

# Impersonation
realm_clients = keycloak_client.get_clients()
rep_impersonation_client = next(
(x for x in realm_clients if x["clientId"] == "rep-impersonation"), None
)
assert rep_impersonation_client is not None
with pytest.raises(HPSError):
Client(
url=url,
username=username,
client_id=rep_impersonation_client["clientId"],
client_secret=rep_impersonation_client["secret"],
)