Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classes, functions, and attributes.

ApiClient
ApiClientFactory
OIDCSessionBuilder
AuthorizationSessionBuilder
SessionConfiguration


Expand Down
4 changes: 2 additions & 2 deletions src/ansys/openapi/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

__version__ = metadata_backport("ansys-openapi-common")["version"]

from ._session import ApiClientFactory, OIDCSessionBuilder
from ._session import ApiClientFactory, AuthorizationSessionBuilder
from ._util import SessionConfiguration, generate_user_agent
from ._exceptions import ApiConnectionException, ApiException, AuthenticationWarning
from ._api_client import ApiClient
Expand All @@ -26,7 +26,7 @@
"AuthenticationWarning",
"create_session_from_granta_stk",
"generate_user_agent",
"OIDCSessionBuilder",
"AuthorizationSessionBuilder",
"ApiBase",
"ApiClientBase",
"ModelBase",
Expand Down
4 changes: 2 additions & 2 deletions src/ansys/openapi/common/_contrib/granta.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ def create_session_from_granta_stk(
cached_token_key = auth_settings["token_key"]
return (
ApiClientFactory(sl_url, api_session_configuration)
.with_oidc(idp_session_configuration)
.with_oidc_authorization_flow(idp_session_configuration)
.with_stored_token(cached_token_key)
.connect()
)
elif mode == "oidc_token":
refresh_token = auth_settings["refresh_token"]
return (
ApiClientFactory(sl_url, api_session_configuration)
.with_oidc(idp_session_configuration)
.with_oidc_authorization_flow(idp_session_configuration)
.with_token(refresh_token)
.connect()
)
Expand Down
51 changes: 50 additions & 1 deletion src/ansys/openapi/common/_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import keyring
import requests
from requests.models import CaseInsensitiveDict
from requests_auth import OAuth2AuthorizationCodePKCE, InvalidGrantRequest # type: ignore[import]
from requests_auth import ( # type: ignore[import]
OAuth2AuthorizationCodePKCE,
OAuth2ClientCredentials,
InvalidGrantRequest,
)
from requests_auth.authentication import OAuth2 # type: ignore[import]

from ._util import (
Expand All @@ -25,6 +29,51 @@
_log_tokens = False


def get_client_credential_auth(
token_url: str,
client_id: str,
client_secret: str,
scope: Optional[str] = "",
session: Optional[requests.Session] = None,
) -> OAuth2ClientCredentials:

"""Get a requests auth object for the Client Credential OIDC flow.
Provides a wrapper around the requests_auth.OAuth2ClientCredentials class.
Parameters
----------
token_url: :class:`str`
OAuth 2 token URL.
client_id : :class:`str`
Resource owner username. Provided by the Identity provider.
client_secret : :class:`str`
Resource owner password. Provided by the Identity provider.
scope : :class:`str`, optional
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do we provide multiple scopes here, using a comma separated list?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. In re-reading the requests-auth docs, the docstring there says "Scope parameter sent to token URL as body. Can also be a list of scopes." - my interpretation of this is that it can be str or list[str]. https://colin-b.github.io/requests_auth/#client-credentials-flow

Single scope or list of scopes required by the application.
session: :class:`requests.Session`, optional
Session used to retrieve the token. Used if a specific configuration is required,
e.g. to disable SSL certificate verification.
Returns
-------
:class:`requests_auth.OAuth2ClientCredentials`
Requests Client Credentials auth object.
Notes
-----
OIDC Authentication requires the ``[oidc]`` extra to be installed.
"""

return OAuth2ClientCredentials(
token_url=token_url,
client_id=client_id,
client_secret=client_secret,
scope=scope,
session=session,
)


class OIDCSessionFactory:
"""
Creates an OpenID Connect session with the configuration fetched from the API server. This class uses either
Expand Down
93 changes: 78 additions & 15 deletions src/ansys/openapi/common/_session.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import warnings
from typing import Tuple, Union, Container, Optional, Mapping, TypeVar, Any
from typing import Tuple, Union, Optional, Mapping, TypeVar, Any, Callable
from functools import wraps
from copy import copy

import requests
from urllib3.util.retry import Retry
Expand Down Expand Up @@ -31,7 +33,7 @@
# noinspection PyUnresolvedReferences
import requests_auth # type: ignore[import]
import keyring
from ._oidc import OIDCSessionFactory
from ._oidc import OIDCSessionFactory, get_client_credential_auth
except ImportError:
_oidc_enabled = False

Expand All @@ -51,6 +53,29 @@

_platform_windows = False

Return_Type = TypeVar("Return_Type")


def require_oidc(func: Callable[..., Return_Type]) -> Callable[..., Return_Type]:
"""Enforce that OIDC features are enabled before executing the wrapped function/method.

Raises
------
ImportError
If the OIDC features have not been installed.
"""

@wraps(func)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nice

def wrapper(*args: Any, **kwargs: Any) -> Return_Type:
if not _oidc_enabled:
raise ImportError(
"OpenID Connect features are not enabled. To use them, run `pip install ansys-openapi-common[oidc]`."
)
return func(*args, **kwargs)

return wrapper


# Required to allow the ApiClientFactory to be subclassed. This ensures that Pylance
# understands that the subclass is returned by the builder methods instead of the base class
Api_Client_Factory = TypeVar("Api_Client_Factory", bound="ApiClientFactory")
Expand Down Expand Up @@ -261,11 +286,52 @@ def with_autologon(self: Api_Client_Factory) -> Api_Client_Factory:
return self
raise ConnectionError("Unable to connect with autologon.")

def with_oidc(
@require_oidc
def with_oidc_client_credentials_flow(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like it would make more sense to take the token_url here, and leave the api_url as the actual service_url. We could always fall back to the api_url as the token_url if it isn't specified, but it just seems a bit clumsy right now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should get all the information we need in the initial 401 response from the API url and from the well known endpoint at the identity provider. I'll check this branch out and have a look at what it's doing today

self: Api_Client_Factory,
client_id: str,
client_secret: str,
scope: Optional[str] = "",
) -> Api_Client_Factory:
"""Set up client authentication for use with OpenID Connect using the Client Credentials flow.

Parameters
----------
client_id : :class:`str`
Resource owner username. Provided by the Identity provider.
client_secret : :class:`str`
Resource owner password. Provided by the Identity provider.
scope : Union[:class:`str`, :class:`list`[:class:`str`]], optional
Single scope or list of scopes required by the application.

Returns
-------
:class:`~ansys.openapi.common.ApiClientFactory`
Current client factory object.

Notes
-----
OIDC Authentication requires the ``[oidc]`` extra to be installed.
"""

auth = get_client_credential_auth(
token_url=self._api_url,
client_id=client_id,
client_secret=client_secret,
scope=scope,
session=copy(self._session),
)
self._session.auth = auth
self._configured = True
return self

@require_oidc
def with_oidc_authorization_flow(
self,
idp_session_configuration: Optional[SessionConfiguration] = None,
) -> "OIDCSessionBuilder":
"""Set up client authentication for use with OpenID Connect.
) -> "AuthorizationSessionBuilder":
"""Set up client authentication for use with OpenID Connect using the authorization flow. Currently
only authorization flow with PKCE is supported.

Parameters
----------
Expand All @@ -274,20 +340,17 @@ def with_oidc(

Returns
-------
:class:`~ansys.openapi.common.OIDCSessionBuilder`
:class:`~ansys.openapi.common.AuthorizationSessionBuilder`
Builder object to authenticate via OIDC.

Notes
-----
OIDC Authentication requires the ``[oidc]`` extra to be installed.
"""
if not _oidc_enabled:
raise ImportError(
"OpenID Connect features are not enabled. To use them, run `pip install ansys-openapi-common[oidc]`."
)

initial_response = self._session.get(self._api_url)
if self.__handle_initial_response(initial_response):
return OIDCSessionBuilder(self)
return AuthorizationSessionBuilder(self)

session_factory = OIDCSessionFactory(
self._session,
Expand All @@ -296,7 +359,7 @@ def with_oidc(
idp_session_configuration,
)

return OIDCSessionBuilder(self, session_factory)
return AuthorizationSessionBuilder(self, session_factory)

def __test_connection(self) -> bool:
"""Attempt to connect to the API server. If this returns a 2XX status code, the method returns
Expand Down Expand Up @@ -380,9 +443,9 @@ def __get_authenticate_header(
return parse_authenticate(response.headers["www-authenticate"])


class OIDCSessionBuilder:
"""Helps create OpenID Connect sessions from different types of input and provides OIDC-specific
configuration options.
class AuthorizationSessionBuilder:
"""Helps create OpenID Connect Authorize Flow sessions from different types of input and provides
configuration options specific to the Authorization Flow.

Parameters
----------
Expand Down
1 change: 0 additions & 1 deletion src/ansys/openapi/common/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import pyparsing as pp
from collections import OrderedDict
from http.server import BaseHTTPRequestHandler, HTTPServer
from itertools import chain
from typing import Dict, Union, List, Optional, Tuple, Any, Collection, cast

Expand Down
4 changes: 2 additions & 2 deletions tests/contrib/test_granta_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_provided_token_session():
builder.with_token.return_value = MagicMock()

with patch.object(
ApiClientFactory, "with_oidc", return_value=builder
ApiClientFactory, "with_oidc_authorization_flow", return_value=builder
) as mock_method:
_ = create_session_from_granta_stk(stk_token_config)
mock_method.assert_called_once_with(None)
Expand All @@ -79,7 +79,7 @@ def test_stored_token_session():
builder.with_stored_token.return_value = MagicMock()

with patch.object(
ApiClientFactory, "with_oidc", return_value=builder
ApiClientFactory, "with_oidc_authorization_flow", return_value=builder
) as mock_method:
_ = create_session_from_granta_stk(stk_token_config)
mock_method.assert_called_once_with(None)
Expand Down
4 changes: 3 additions & 1 deletion tests/test_missing_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def test_create_oidc_with_no_extra_throws(self, mocker):
from ansys.openapi.common import ApiClientFactory

with pytest.raises(ImportError) as excinfo:
_ = ApiClientFactory("http://www.my-api.com/v1.svc").with_oidc()
_ = ApiClientFactory(
"http://www.my-api.com/v1.svc"
).with_oidc_authorization_flow()

package_name = get_package_name()
assert f"`pip install {package_name}[oidc]`" in str(excinfo.value)
Expand Down
30 changes: 24 additions & 6 deletions tests/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,24 @@
import pytest
import requests
import requests_mock
from requests_auth.authentication import OAuth2
from requests_auth.authentication import OAuth2, OAuth2ClientCredentials
from unittest.mock import Mock, MagicMock
from covertable import make

from ansys.openapi.common import ApiClientFactory
from ansys.openapi.common._oidc import OIDCSessionFactory
from ansys.openapi.common._oidc import OIDCSessionFactory, get_client_credential_auth

TOKEN_URL = "https://www.example.com/token"
CLIENT_ID = "3acde603-9bb9-48e7-9eaa-c624c4fd40ca"
CLIENT_SECRET = "000a5e5ecdd7e2dbb2d61b4a7291e0b44b719c786ed1d677587e83ab60bf8bbf"
REQUIRED_HEADERS = {
"clientid": "3acde603-9bb9-48e7-9eaa-c624c4fd40ca",
"clientid": CLIENT_ID,
"authority": "authority.com",
"redirecturi": "http://localhost:1729",
}

WELL_KNOWN_PARAMETERS = {
"token_endpoint": "www.example.com/token",
"token_endpoint": TOKEN_URL,
"authorization_endpoint": "www.example.com/authorization",
}

Expand Down Expand Up @@ -233,7 +236,7 @@ def match_token_request(request):
headers={"WWW-Authenticate": "Bearer error=invalid_token"},
)
with pytest.raises(ValueError) as exception_info:
ApiClientFactory(api_url).with_oidc().with_token(
ApiClientFactory(api_url).with_oidc_authorization_flow().with_token(
refresh_token=refresh_token
)
assert "refresh token was invalid" in str(exception_info)
Expand Down Expand Up @@ -267,8 +270,23 @@ def test_endpoint_with_refresh_configures_correctly():
headers={"WWW-Authenticate": authenticate_header},
)

session = ApiClientFactory(secure_servicelayer_url).with_oidc()
session = ApiClientFactory(
secure_servicelayer_url
).with_oidc_authorization_flow()
auth = session._session_factory._auth

assert auth.token_url == f"{authority_url}token"
assert auth.refresh_data["client_id"] == client_id


def test_get_client_credential_auth():
result = get_client_credential_auth(TOKEN_URL, CLIENT_ID, CLIENT_SECRET, "scope")
assert isinstance(result, OAuth2ClientCredentials)


def test_get_client_credential_auth_succeeds_custom_session():
session = requests.session()
result = get_client_credential_auth(
TOKEN_URL, CLIENT_ID, CLIENT_SECRET, "scope", session
)
assert isinstance(result, OAuth2ClientCredentials)
Loading