diff --git a/src/cryptojwt/key_bundle.py b/src/cryptojwt/key_bundle.py index 2c87966e..30026f2e 100755 --- a/src/cryptojwt/key_bundle.py +++ b/src/cryptojwt/key_bundle.py @@ -152,7 +152,7 @@ class KeyBundle: def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True, fileformat="jwks", keytype="RSA", keyusage=None, kid='', - httpc=None): + httpc=None, httpc_params=None): """ Contains a set of keys that have a common origin. The sources can be serveral: @@ -171,6 +171,8 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True, :param keyusage: What the key loaded from file should be used for. Only applicable for DER files :param httpc: A HTTP client function + :param httpc_params: Additional parameters to pass to the HTTP client + function """ self._keys = [] @@ -193,6 +195,7 @@ def __init__(self, keys=None, source="", cache_time=300, verify_ssl=True, else: self.httpc = requests.request self.verify_ssl = verify_ssl + self.httpc_params = httpc_params or {} if keys: self.source = None @@ -314,13 +317,11 @@ def do_remote(self): :return: True or False if load was successful """ if self.verify_ssl is not None: - args = {"verify": self.verify_ssl} - else: - args = {} + self.httpc_params["verify"] = self.verify_ssl try: LOGGER.debug('KeyBundle fetch keys from: %s', self.source) - _http_resp = self.httpc('GET', self.source, **args) + _http_resp = self.httpc('GET', self.source, **self.httpc_params) except Exception as err: LOGGER.error(err) raise UpdateFailed( diff --git a/tests/test_03_key_bundle.py b/tests/test_03_key_bundle.py index b935ee98..e65f5aa6 100755 --- a/tests/test_03_key_bundle.py +++ b/tests/test_03_key_bundle.py @@ -5,6 +5,8 @@ import time import pytest +import requests +import responses from cryptography.hazmat.primitives.asymmetric import rsa from cryptojwt.jwk.ec import new_ec_key from cryptojwt.jwk.hmac import SYMKey @@ -471,6 +473,31 @@ def test_local_jwk_copy(): # assert len(kb.get('oct')) == 1 +@pytest.fixture() +def mocked_jwks_response(): + with responses.RequestsMock() as rsps: + yield rsps + + +def test_httpc_params_1(): + source = 'https://login.salesforce.com/id/keys' # From test_jwks_url() + # Mock response + responses.add(method=responses.GET, url=source, json=JWKS_DICT, status=200) + httpc_params = {'timeout': (2, 2)} # connect, read timeouts in seconds + kb = KeyBundle(source=source, httpc=requests.request, + httpc_params=httpc_params) + assert kb.do_remote() + + +def test_httpc_params_2(): + httpc_params = {'timeout': 0} + kb = KeyBundle(source='https://login.salesforce.com/id/keys', + httpc=requests.request, httpc_params=httpc_params) + # Will always fail to fetch the JWKS because the timeout cannot be set + # to 0s + assert not kb.update() + + def test_update_2(): rsa_key = new_rsa_key() _jwks = {"keys": [rsa_key.serialize()]}