Skip to content

Commit

Permalink
tls: Added support for mbedtls_ssl_conf_read_timeout (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
sebachm94 committed Jan 5, 2024
1 parent 665fe81 commit 9a30d6f
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 1 deletion.
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[next]

* tls: Add `mbedtls_ssl_conf_read_timeout`, for the read timeout
configuration

[2.8.0] - 2023-11-28

* ci: Update wheels to mbedtls 2.28.6
Expand Down
8 changes: 7 additions & 1 deletion src/mbedtls/_tls.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ cdef extern from "mbedtls/ssl.h" nogil:
# set_handshake_timeout
unsigned int hs_timeout_min
unsigned int hs_timeout_max
# set_read_timeout
unsigned int read_timeout

unsigned int endpoint
unsigned int transport
Expand Down Expand Up @@ -247,7 +249,10 @@ cdef extern from "mbedtls/ssl.h" nogil:
void *p_dbg
)

# mbedtls_ssl_conf_read_timeout
void mbedtls_ssl_conf_read_timeout(
mbedtls_ssl_config *conf,
unsigned int timeout
)
# mbedtls_ssl_conf_session_tickets_cb
# mbedtls_ssl_conf_export_keys_cb

Expand Down Expand Up @@ -427,6 +432,7 @@ cdef class MbedTLSConfiguration:
cdef _set_max_fragmentation_length(self, object mfl)
cdef _set_anti_replay(self, mode)
cdef _set_handshake_timeout(self, minimum, maximum)
cdef _set_read_timeout(self, timeout)
cdef _set_cookie(self, _DTLSCookie cookie)
cdef _set_sni_callback(self, callback)
cdef _set_pre_shared_key(self, psk)
Expand Down
34 changes: 34 additions & 0 deletions src/mbedtls/_tls.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,7 @@ cdef class MbedTLSConfiguration:
# badmac_limit
handshake_timeout_min,
handshake_timeout_max,
read_timeout,
sni_callback,
pre_shared_key,
pre_shared_key_store,
Expand All @@ -423,6 +424,7 @@ cdef class MbedTLSConfiguration:
self._set_handshake_timeout(
handshake_timeout_min, handshake_timeout_max
)
self._set_read_timeout(read_timeout)
self._set_sni_callback(sni_callback)
self._set_pre_shared_key(pre_shared_key)
self._set_pre_shared_key_store(pre_shared_key_store)
Expand Down Expand Up @@ -479,6 +481,7 @@ cdef class MbedTLSConfiguration:
self.anti_replay,
self.handshake_timeout_min,
self.handshake_timeout_max,
self.read_timeout,
self.sni_callback,
self.pre_shared_key,
self.pre_shared_key_store,
Expand Down Expand Up @@ -802,6 +805,34 @@ cdef class MbedTLSConfiguration:

return float(self._ctx.hs_timeout_max) / 1000.0

cdef _set_read_timeout(self, timeout):
"""Set TLS/DTLS read timeout.
Use 0 for no timeout.
Args:
timeout (float, optional): read timeout in seconds.
"""
if timeout is None:
return

def validate(extremum, *, default: float) -> float:
if extremum is None:
return default
if extremum < 0.0:
raise ValueError(extremum)
return extremum

_tls.mbedtls_ssl_conf_read_timeout(
&self._ctx,
int(1000.0 * validate(timeout, default=0))
)

@property
def read_timeout(self):
"""Read timeout in seconds. Use 0 for no timeout. (default 0)."""
return float(self._ctx.read_timeout) / 1000.0

cdef _set_sni_callback(self, callback):
# PEP 543, optional, server-side only
if callback is None:
Expand Down Expand Up @@ -916,6 +947,7 @@ cdef class _BaseContext:
anti_replay=None,
handshake_timeout_min=None,
handshake_timeout_max=None,
read_timeout=None,
sni_callback=configuration.sni_callback,
pre_shared_key=configuration.pre_shared_key,
pre_shared_key_store=configuration.pre_shared_key_store,
Expand All @@ -938,6 +970,7 @@ cdef class _BaseContext:
anti_replay=configuration.anti_replay,
handshake_timeout_min=configuration.handshake_timeout_min,
handshake_timeout_max=configuration.handshake_timeout_max,
read_timeout=configuration.read_timeout,
sni_callback=configuration.sni_callback,
pre_shared_key=configuration.pre_shared_key,
pre_shared_key_store=configuration.pre_shared_key_store,
Expand Down Expand Up @@ -991,6 +1024,7 @@ cdef class _BaseContext:
anti_replay=self._conf.anti_replay,
handshake_timeout_min=self._conf.handshake_timeout_min,
handshake_timeout_max=self._conf.handshake_timeout_max,
read_timeout=self._conf.read_timeout,
sni_callback=self._conf.sni_callback,
pre_shared_key=self._conf.pre_shared_key,
pre_shared_key_store=self._conf.pre_shared_key_store,
Expand Down
11 changes: 11 additions & 0 deletions src/mbedtls/_tlsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ class TLSConfiguration:
highest_supported_version: TLSVersion = TLSVersion.MAXIMUM_SUPPORTED
trust_store: Optional[TrustStore] = None
max_fragmentation_length: Optional[MaxFragmentLength] = None
read_timeout: float = 0.0
sni_callback: Optional[ServerNameCallback] = None
pre_shared_key: Optional[Tuple[str, bytes]] = None
pre_shared_key_store: Mapping[str, bytes] = field(default_factory=dict)
Expand Down Expand Up @@ -250,6 +251,7 @@ def __eq__(self, other: object) -> bool:
self.sni_callback == other.sni_callback,
self.pre_shared_key == other.pre_shared_key,
self.pre_shared_key_store == other.pre_shared_key_store,
self.read_timeout == other.read_timeout,
)
)

Expand All @@ -263,6 +265,7 @@ def update(
highest_supported_version: _Wrap[TLSVersion] = _DEFAULT_VALUE,
trust_store: _Wrap[TrustStore] = _DEFAULT_VALUE,
max_fragmentation_length: _Wrap[MaxFragmentLength] = _DEFAULT_VALUE,
read_timeout: _Wrap[float] = _DEFAULT_VALUE,
sni_callback: _Wrap[Optional[ServerNameCallback]] = _DEFAULT_VALUE,
pre_shared_key: _Wrap[Tuple[str, bytes]] = _DEFAULT_VALUE,
pre_shared_key_store: _Wrap[Mapping[str, bytes]] = _DEFAULT_VALUE,
Expand Down Expand Up @@ -296,6 +299,7 @@ def update(
self.max_fragmentation_length,
),
sni_callback=_unwrap(sni_callback, self.sni_callback),
read_timeout=_unwrap(read_timeout, self.read_timeout),
pre_shared_key=_unwrap(pre_shared_key, self.pre_shared_key),
pre_shared_key_store=_unwrap(
pre_shared_key_store,
Expand All @@ -318,6 +322,7 @@ class DTLSConfiguration:
anti_replay: bool = True
handshake_timeout_min: float = 1.0
handshake_timeout_max: float = 60.0
read_timeout: float = 0.0
sni_callback: Optional[ServerNameCallback] = None
pre_shared_key: Optional[Tuple[str, bytes]] = None
pre_shared_key_store: Mapping[str, bytes] = field(default_factory=dict)
Expand Down Expand Up @@ -375,6 +380,7 @@ def __eq__(self, other: object) -> bool:
self.anti_replay == other.anti_replay,
self.handshake_timeout_min == other.handshake_timeout_min,
self.handshake_timeout_max == other.handshake_timeout_max,
self.read_timeout == other.read_timeout,
self.sni_callback == other.sni_callback,
self.pre_shared_key == other.pre_shared_key,
self.pre_shared_key_store == other.pre_shared_key_store,
Expand All @@ -394,6 +400,7 @@ def update(
anti_replay: _Wrap[bool] = _DEFAULT_VALUE,
handshake_timeout_min: _Wrap[float] = _DEFAULT_VALUE,
handshake_timeout_max: _Wrap[float] = _DEFAULT_VALUE,
read_timeout: _Wrap[float] = _DEFAULT_VALUE,
sni_callback: _Wrap[ServerNameCallback] = _DEFAULT_VALUE,
pre_shared_key: _Wrap[Tuple[str, bytes]] = _DEFAULT_VALUE,
pre_shared_key_store: _Wrap[Mapping[str, bytes]] = _DEFAULT_VALUE,
Expand Down Expand Up @@ -435,6 +442,10 @@ def update(
handshake_timeout_max,
self.handshake_timeout_max,
),
read_timeout=_unwrap(
read_timeout,
self.read_timeout,
),
sni_callback=_unwrap(sni_callback, self.sni_callback),
pre_shared_key=_unwrap(pre_shared_key, self.pre_shared_key),
pre_shared_key_store=_unwrap(
Expand Down
14 changes: 14 additions & 0 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,20 @@ def test_handshake_timeout_default(
assert_conf_invariant(conf, handshake_timeout_min=hs_min)
assert_conf_invariant(conf, handshake_timeout_max=hs_max)

@pytest.mark.parametrize("timeout", [1, 10, 5.3, 300])
def test_read_timeout_default(
self,
conf: Union[TLSConfiguration, DTLSConfiguration],
timeout: float,
) -> None:
assert conf.read_timeout == 0
conf_ = conf.update(
read_timeout=timeout,
)
assert conf_.read_timeout == timeout

assert_conf_invariant(conf, read_timeout=timeout)


class TestContext:
@pytest.fixture(params=[Purpose.SERVER_AUTH, Purpose.CLIENT_AUTH])
Expand Down

0 comments on commit 9a30d6f

Please sign in to comment.