From a91ad650773987e3f88e6a992da9d0dc90e10929 Mon Sep 17 00:00:00 2001 From: Mathias Laurin Date: Sat, 20 May 2023 12:05:29 +0200 Subject: [PATCH] tls: Wrap `mbedtls_ssl_set_mtu` Closes #82 --- src/mbedtls/_tls.pxd | 7 +++++++ src/mbedtls/_tls.pyi | 1 + src/mbedtls/_tls.pyx | 14 ++++++++++++++ src/mbedtls/tls.py | 11 +++++++++++ tests/test_tls.py | 27 +++++++++++++++++++++++++++ 5 files changed, 60 insertions(+) diff --git a/src/mbedtls/_tls.pxd b/src/mbedtls/_tls.pxd index 118f0994..88fde581 100644 --- a/src/mbedtls/_tls.pxd +++ b/src/mbedtls/_tls.pxd @@ -1,6 +1,8 @@ # SPDX-License-Identifier: MIT # Copyright (c) 2018, Mathias Laurin +from libc.stdint cimport uint16_t + cimport mbedtls._ecdh as _ecdh cimport mbedtls._ringbuf as _rb cimport mbedtls._timing as _timing @@ -313,6 +315,11 @@ cdef extern from "mbedtls/ssl.h" nogil: mbedtls_ssl_recv_p f_recv, mbedtls_ssl_recv_timeout_p f_recv_timeout) + void mbedtls_ssl_set_mtu( + mbedtls_ssl_context *ssl, + uint16_t mtu, + ) + void mbedtls_ssl_set_timer_cb( # DTLS mbedtls_ssl_context *ssl, diff --git a/src/mbedtls/_tls.pyi b/src/mbedtls/_tls.pyi index d840d751..943f473b 100644 --- a/src/mbedtls/_tls.pyi +++ b/src/mbedtls/_tls.pyi @@ -111,6 +111,7 @@ class MbedTLSBuffer: def _server_hostname(self) -> str: ... def shutdown(self) -> None: ... def setcookieparam(self, info: bytes) -> None: ... + def setmtu(self, mtu: int) -> None: ... def read(self, amt: int) -> bytes: ... def readinto(self, buffer: bytes, amt: int) -> int: ... def write(self, buffer: bytes) -> int: ... diff --git a/src/mbedtls/_tls.pyx b/src/mbedtls/_tls.pyx index dd167ae1..32b5d4ba 100644 --- a/src/mbedtls/_tls.pyx +++ b/src/mbedtls/_tls.pyx @@ -1130,6 +1130,20 @@ cdef class MbedTLSBuffer: info.size, ) + def setmtu(self, mtu): + """Set Maxiumum Transport Unit (MTU) for DTLS. + + Set to zero to unset. + + Raises: + OverflowError: If value cannot be converted to UInt16. + + """ + # DTLS + if not isinstance(mtu, int): + raise TypeError(mtu) + _tls.mbedtls_ssl_set_mtu(&self._ctx, mtu) + def _reset(self): _exc.check_error(_tls.mbedtls_ssl_session_reset(&self._ctx)) diff --git a/src/mbedtls/tls.py b/src/mbedtls/tls.py index aade0c92..1860048b 100644 --- a/src/mbedtls/tls.py +++ b/src/mbedtls/tls.py @@ -280,6 +280,17 @@ def _handshake_state(self) -> HandshakeStep: # pylint: disable=protected-access return self._buffer._handshake_state + def setmtu(self, mtu: int) -> None: + """Set Maxiumum Transport Unit (MTU) for DTLS. + + Set to zero to unset. + + Raises: + OverflowError: If value cannot be converted to UInt16. + + """ + self._buffer.setmtu(mtu) + def accept(self) -> Tuple[TLSWrappedSocket, _Address]: if self.type == _pysocket.SOCK_STREAM: conn, address = self._socket.accept() diff --git a/tests/test_tls.py b/tests/test_tls.py index b2f04229..850ba673 100644 --- a/tests/test_tls.py +++ b/tests/test_tls.py @@ -827,6 +827,33 @@ def test_psk(self) -> None: assert do_send(secret, src=client, dst=server) == secret assert do_send(secret, src=server, dst=client) == secret + @pytest.mark.parametrize("mtu_cli", [0, 128, 380, 500, (1 << 16) - 1]) + @pytest.mark.parametrize("mtu_srv", [0, 128, 380, 500, (1 << 16) - 1]) + def test_psk_set_mtu(self, mtu_cli: int, mtu_srv: int) -> None: + psk = ("cli", b"secret") + server = ServerContext( + DTLSConfiguration( + pre_shared_key_store=dict((psk,)), + validate_certificates=False, + ) + ).wrap_buffers() + server.setmtu(mtu_srv) + client = ClientContext( + DTLSConfiguration( + pre_shared_key=psk, + validate_certificates=False, + ), + ).wrap_buffers("hostname") + client.setmtu(mtu_cli) + make_hello_verify_request( + client=client, server=server, cookie="🍪🍪🍪".encode() + ) + make_full_handshake(client=client, server=server) + + secret = b"a very secret message" + assert do_send(secret, src=client, dst=server) == secret + assert do_send(secret, src=server, dst=client) == secret + def test_resume_from_pickle(self) -> None: psk = ("cli", b"secret") server = ServerContext(