Skip to content

Commit

Permalink
Merge pull request #41 from joarleymoraes/master
Browse files Browse the repository at this point in the history
added support to set timeout for TAXII requests
  • Loading branch information
traut committed Jun 19, 2017
2 parents bbef6ce + bf53cf7 commit 11a1d42
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 13 deletions.
9 changes: 6 additions & 3 deletions cabby/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class AbstractClient(object):
taxii_version = None

def __init__(self, host=None, discovery_path=None, port=None,
use_https=False, headers=None):
use_https=False, headers=None, timeout=None):

self.host = host
self.port = port
Expand All @@ -47,6 +47,7 @@ def __init__(self, host=None, discovery_path=None, port=None,

self.jwt_token = None
self.headers = headers or {}
self.timeout = timeout

self.log = logging.getLogger(
"{}.{}".format(self.__module__, self.__class__.__name__))
Expand Down Expand Up @@ -193,13 +194,15 @@ def _execute_request(self, request, uri=None, service_type=None):
'key_file': self.key_file,
'key_password': self.key_password,
'ca_cert': self.ca_cert
})
},
timeout=self.timeout)
else:
message = dispatcher.send_taxii_request(
session,
self._prepare_url(uri),
request,
taxii_binding=self.taxii_binding)
taxii_binding=self.taxii_binding,
timeout=self.timeout)

return message

Expand Down
14 changes: 9 additions & 5 deletions cabby/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def raise_http_error(status_code, response_stream=None):


def send_taxii_request(session, url, request, taxii_binding=None,
tls_details=None):
tls_details=None, timeout=None):
'''
Send XML message to a TAXII service and parse a response.
'''
Expand All @@ -55,7 +55,7 @@ def send_taxii_request(session, url, request, taxii_binding=None,
# https://github.com/kennethreitz/requests/issues/2519 is fixed
try:
response = get_response_using_key_pass(
url, request_body, session, **tls_details)
url, request_body, session, timeout=timeout, **tls_details)
except urllib.error.HTTPError as e:
log.error(
"Error while connecting to {}".format(url),
Expand All @@ -64,7 +64,8 @@ def send_taxii_request(session, url, request, taxii_binding=None,

stream, headers = response, response.headers
else:
response = session.post(url, data=request_body, stream=True)
response = session.post(url, data=request_body, stream=True,
timeout=timeout)
if not response.ok:
raise_http_error(response.status_code, response.raw)

Expand Down Expand Up @@ -369,7 +370,7 @@ def obtain_jwt_token(session, jwt_url, username, password):


def get_response_using_key_pass(url, data, session, cert_file, key_file,
key_password, ca_cert=None):
key_password, ca_cert=None, timeout=None):

if sys.version_info < (2, 7, 9):
raise ValueError(
Expand Down Expand Up @@ -405,4 +406,7 @@ def get_response_using_key_pass(url, data, session, cert_file, key_file,

request = urllib.request.Request(url, data, headers)

return opener.open(request)
if timeout:
return opener.open(request, timeout=timeout)
else:
return opener.open(request)
41 changes: 36 additions & 5 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import json
import gzip
import sys
import requests
from time import sleep

from six import StringIO

Expand Down Expand Up @@ -55,7 +57,6 @@ def get_sent_message(version):

@pytest.mark.parametrize("version", [11, 10])
def test_set_headers(version):

httpretty.reset()
httpretty.enable()

Expand Down Expand Up @@ -86,7 +87,6 @@ def test_set_headers(version):

@pytest.mark.parametrize("version", [11, 10])
def test_invalid_response(version):

httpretty.reset()
httpretty.enable()

Expand All @@ -112,7 +112,6 @@ def test_invalid_response(version):

@pytest.mark.parametrize("version", [11, 10])
def test_invalid_response_status(version):

httpretty.reset()
httpretty.enable()

Expand All @@ -132,7 +131,6 @@ def test_invalid_response_status(version):

@pytest.mark.parametrize("version", [11, 10])
def test_jwt_auth_response(version):

httpretty.reset()
httpretty.enable()

Expand Down Expand Up @@ -202,7 +200,6 @@ def compress(text):

@pytest.mark.parametrize("version", [11, 10])
def test_gzip_response(version):

httpretty.reset()
httpretty.enable()

Expand All @@ -222,3 +219,37 @@ def test_gzip_response(version):

httpretty.disable()
httpretty.reset()


@pytest.mark.parametrize("version", [11, 10])
def test_timeout(version):
httpretty.reset()
httpretty.enable()

timeout_in_sec = 1

client = make_client(version)
#
# configure to raise the error before the timeout
#
client.timeout = timeout_in_sec / 2.0

def timeout_request_callback(request, uri, headers):
sleep(timeout_in_sec)

return 200, headers, {'result': 'success'}

uri = get_fix(version).DISCOVERY_URI_HTTP

httpretty.register_uri(
httpretty.POST,
uri,
body=timeout_request_callback,
content_type='application/json'
)

with pytest.raises(requests.exceptions.Timeout):
client.discover_services(uri=uri)

httpretty.disable()
httpretty.reset()

0 comments on commit 11a1d42

Please sign in to comment.