Skip to content

Commit

Permalink
Merge pull request #626 from OpenIDC/refactor_client
Browse files Browse the repository at this point in the history
Refactor client
  • Loading branch information
tpazderka committed Mar 19, 2019
2 parents 64e5944 + e480f7c commit 21ed5af
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 144 deletions.
24 changes: 12 additions & 12 deletions src/oic/oauth2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from oic import CC_METHOD
from oic import OIDCONF_PATTERN
from oic import unreserved
from oic.exception import CommunicationError
from oic.oauth2.base import PBase
from oic.oauth2.exception import GrantError
from oic.oauth2.exception import HttpError
Expand Down Expand Up @@ -896,21 +897,20 @@ def provider_config(self, issuer, keys=True, endpoints=True,
url = serv_pattern % _issuer

pcr = None
r = self.http_request(url)
r = self.http_request(url, allow_redirects=True)
if r.status_code == 200:
pcr = response_cls().from_json(r.text)
elif r.status_code == 302:
while r.status_code == 302:
r = self.http_request(r.headers["location"])
if r.status_code == 200:
pcr = response_cls().from_json(r.text)
break

if pcr is None:
raise PyoidcError("Trying '%s', status %s" % (url, r.status_code))
try:
pcr = response_cls().from_json(r.text)
except Exception as e:
# FIXME: This should catch specific exception from `from_json()`
_err_txt = "Faulty provider config response: {}".format(e)
logger.error(sanitize(_err_txt))
raise ParseError(_err_txt)
else:
raise CommunicationError("Trying '%s', status %s" % (url, r.status_code))

self.store_response(pcr, r.text)
self.handle_provider_config(pcr, issuer, keys, endpoints)

return pcr


Expand Down
119 changes: 18 additions & 101 deletions src/oic/oic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from oic.exception import AccessDenied
from oic.exception import AuthnToOld
from oic.exception import AuthzError
from oic.exception import CommunicationError
from oic.exception import IssuerMismatch
from oic.exception import MissingParameter
from oic.exception import ParameterError
from oic.exception import PyoidcError
Expand Down Expand Up @@ -62,7 +60,6 @@
from oic.oic.message import UserInfoRequest
from oic.utils import time_util
from oic.utils.http_util import Response
from oic.utils.keyio import KeyJar
from oic.utils.sanitize import sanitize
from oic.utils.webfinger import OIC_ISSUER
from oic.utils.webfinger import WebFinger
Expand Down Expand Up @@ -470,10 +467,8 @@ def construct_AuthorizationRequest(self, request=AuthorizationRequest,
request_param = "request"
del kwargs["request_method"]

areq = oauth2.Client.construct_AuthorizationRequest(self, request,
request_args,
extra_args,
**kwargs)
areq = super().construct_AuthorizationRequest(request=request, request_args=request_args, extra_args=extra_args,
**kwargs)

if request_param:
alg = None
Expand Down Expand Up @@ -526,19 +521,16 @@ def construct_AccessTokenRequest(self, request=AccessTokenRequest,
request_args=None, extra_args=None,
**kwargs):

return oauth2.Client.construct_AccessTokenRequest(self, request,
request_args,
extra_args, **kwargs)
return super().construct_AccessTokenRequest(request=request, request_args=request_args, extra_args=extra_args,
**kwargs)

def construct_RefreshAccessTokenRequest(self,
request=RefreshAccessTokenRequest,
request_args=None, extra_args=None,
**kwargs):

return oauth2.Client.construct_RefreshAccessTokenRequest(self, request,
request_args,
extra_args,
**kwargs)
return super().construct_RefreshAccessTokenRequest(requests=request, request_args=request_args,
extra_args=extra_args, **kwargs)

def construct_UserInfoRequest(self, request=UserInfoRequest,
request_args=None, extra_args=None,
Expand Down Expand Up @@ -643,12 +635,9 @@ def do_authorization_request(self, request=AuthorizationRequest,
_args, code_verifier = self.add_code_challenge()
request_args.update(_args)

return oauth2.Client.do_authorization_request(self, request, state,
body_type, method,
request_args,
extra_args, http_args,
response_cls,
algs=algs)
return super().do_authorization_request(request=request, state=state, body_type=body_type, method=method,
request_args=request_args, extra_args=extra_args, http_args=http_args,
response_cls=response_cls, algs=algs)

def do_access_token_request(self, request=AccessTokenRequest,
scope="", state="", body_type="json",
Expand All @@ -657,11 +646,10 @@ def do_access_token_request(self, request=AccessTokenRequest,
response_cls=AccessTokenResponse,
authn_method="client_secret_basic", **kwargs):

atr = oauth2.Client.do_access_token_request(self, request, scope,
state, body_type, method,
request_args, extra_args,
http_args, response_cls,
authn_method, **kwargs)
atr = super().do_access_token_request(request=request, scope=scope, state=state, body_type=body_type,
method=method, request_args=request_args, extra_args=extra_args,
http_args=http_args, response_cls=response_cls, authn_method=authn_method,
**kwargs)
try:
_idt = atr['id_token']
except KeyError:
Expand All @@ -681,11 +669,9 @@ def do_access_token_refresh(self, request=RefreshAccessTokenRequest,
response_cls=AccessTokenResponse,
**kwargs):

return oauth2.Client.do_access_token_refresh(self, request, state,
body_type, method,
request_args,
extra_args, http_args,
response_cls, **kwargs)
return super().do_access_token_refresh(request=request, state=state, body_type=body_type, method=method,
requset_args=request_args, extra_args=extra_args, http_args=http_args,
response_cls=response_cls, **kwargs)

def do_registration_request(self, request=RegistrationRequest,
scope="", state="", body_type="json",
Expand Down Expand Up @@ -958,80 +944,11 @@ def get_userinfo_claims(self, access_token, endpoint, method="POST",
self.store_response(res, resp.text)
return res

def handle_provider_config(self, pcr, issuer, keys=True, endpoints=True):
"""
Deal with Provider Config Response.
:param pcr: The ProviderConfigResponse instance
:param issuer: The one I thought should be the issuer of the config
:param keys: Should I deal with keys
:param endpoints: Should I deal with endpoints, that is store them as attributes in self.
"""
if "issuer" in pcr:
_pcr_issuer = pcr["issuer"]
if pcr["issuer"].endswith("/"):
if issuer.endswith("/"):
_issuer = issuer
else:
_issuer = issuer + "/"
else:
if issuer.endswith("/"):
_issuer = issuer[:-1]
else:
_issuer = issuer

try:
self.allow["issuer_mismatch"]
except KeyError:
if _issuer != _pcr_issuer:
raise IssuerMismatch("'%s' != '%s'" % (_issuer, _pcr_issuer), pcr)

self.provider_info = pcr
else:
_pcr_issuer = issuer

if endpoints:
for key, val in pcr.items():
if key.endswith("_endpoint"):
setattr(self, key, val)

if keys:
if self.keyjar is None:
self.keyjar = KeyJar(verify_ssl=self.verify_ssl)

self.keyjar.load_keys(pcr, _pcr_issuer)

def provider_config(self, issuer, keys=True, endpoints=True,
response_cls=ProviderConfigurationResponse,
serv_pattern=OIDCONF_PATTERN):
if issuer.endswith("/"):
_issuer = issuer[:-1]
else:
_issuer = issuer

url = serv_pattern % _issuer

pcr = None
r = self.http_request(url, allow_redirects=True)
if r.status_code == 200:
try:
pcr = response_cls().from_json(r.text)
except Exception as e:
# FIXME: This should catch specific exception from `from_json()`
_err_txt = "Faulty provider config response: {}".format(e)
logger.error(sanitize(_err_txt))
raise ParseError(_err_txt)

logger.debug("Provider info: %s" % sanitize(pcr))
if pcr is None:
raise CommunicationError(
"Trying '%s', status %s" % (url, r.status_code))

self.store_response(pcr, r.text)

self.handle_provider_config(pcr, issuer, keys, endpoints)

return pcr
return super().provider_config(issuer=issuer, keys=keys, endpoints=endpoints, response_cls=response_cls,
serv_pattern=serv_pattern)

def unpack_aggregated_claims(self, userinfo):
if userinfo["_claim_sources"]:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_oauth2_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from urllib.parse import urlparse

import pytest
import responses

from oic import rndstr
from oic.exception import AuthzError
Expand All @@ -11,6 +12,7 @@
from oic.oauth2.consumer import stateID
from oic.oauth2.message import SINGLE_OPTIONAL_INT
from oic.oauth2.message import AccessTokenResponse
from oic.oauth2.message import ASConfigurationResponse
from oic.oauth2.message import AuthorizationErrorResponse
from oic.oauth2.message import AuthorizationResponse
from oic.oauth2.message import MissingRequiredAttribute
Expand Down Expand Up @@ -220,6 +222,17 @@ def test_consumer_client_auth_info(self):
assert ha == {}
assert extra == {'auth_method': 'bearer_body'}

def test_provider_config(self):
c = Consumer(None, None)
response = ASConfigurationResponse(**{'issuer': 'https://example.com',
'end_session_endpoint': 'https://example.com/end_session'})
with responses.RequestsMock() as rsps:
rsps.add(responses.GET, 'https://example.com/.well-known/openid-configuration', json=response.to_dict())
info = c.provider_config('https://example.com')
assert isinstance(info, ASConfigurationResponse)
assert _eq(info.keys(), ['issuer', 'version', 'end_session_endpoint'])
assert info["end_session_endpoint"] == "https://example.com/end_session"

def test_client_get_access_token_request(self):
self.consumer.client_secret = "secret0"
_state = "state"
Expand Down
31 changes: 0 additions & 31 deletions tests/test_oic_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,37 +492,6 @@ def test_discover(self, fake_oic_server):
res = c.discover(principal)
assert res == "https://localhost:8088/"

def test_provider_config(self, fake_oic_server):
c = Consumer(None, None)
mfos = fake_oic_server("https://example.com")
mfos.keyjar = SRVKEYS
c.http_request = mfos.http_request

principal = "foo@example.com"

res = c.discover(principal)
info = c.provider_config(res)
assert isinstance(info, ProviderConfigurationResponse)
assert _eq(info.keys(), ['registration_endpoint', 'jwks_uri',
'check_session_endpoint',
'refresh_session_endpoint',
'register_endpoint',
'subject_types_supported',
'token_endpoint_auth_methods_supported',
'id_token_signing_alg_values_supported',
'grant_types_supported', 'user_info_endpoint',
'claims_parameter_supported',
'request_parameter_supported',
'discovery_endpoint', 'issuer',
'authorization_endpoint', 'scopes_supported',
'require_request_uri_registration',
'identifiers_supported', 'token_endpoint',
'request_uri_parameter_supported', 'version',
'response_types_supported',
'end_session_endpoint', 'flows_supported'])

assert info["end_session_endpoint"] == "https://example.com/end_session"

def test_client_register(self, fake_oic_server):
c = Consumer(None, None)

Expand Down

0 comments on commit 21ed5af

Please sign in to comment.