Skip to content

Commit

Permalink
tls: Add support for PSK (TLS and DTLS)
Browse files Browse the repository at this point in the history
  • Loading branch information
Synss committed Feb 20, 2020
1 parent 76e2163 commit 40d0436
Show file tree
Hide file tree
Showing 5 changed files with 369 additions and 16 deletions.
1 change: 1 addition & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[next]

* tls: Add support to PSK for (D)TLS
* tls: Fixup `access()` method for DTLS on Python 2.7 and 3.4.

[1.0.0] - 2020-01-05
Expand Down
30 changes: 30 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,33 @@ Now, the DTLS communication is complete.
>>> dtls_cli.close()
>>> runner.join(0.1)
>>> dtls_srv.close()


Pre-shared key (PSK) for TLS and DTLS
-------------------------------------

PSK authentication is supported for TLS and DTLS, both server
and client side. The client configuration is a tuple with an
identifier (UTF-8 encoded) and the secret key,

>>> cli_conf = DTLSConfiguration(
... pre_shared_key=("client42", b"the secret")
... )

and the server configuration receives the key store as a
`Mapping[unicode, bytes]` of identifiers and keys. For example,

>>> srv_conf = DTLSConfiguration(
... ciphers=(
... "TLS-ECDHE-PSK-WITH-AES-128-CBC-SHA256",
... "TLS-PSK-WITH-AES-128-CBC-SHA256",
... ),
... pre_shared_key_store={
... "client0": b"a secret",
... "client1": b"other secret",
... "client42": b"the secret",
... "client100": b"yet another one",
... },
... )

The rest of the session is the same as in the previous sections.
31 changes: 27 additions & 4 deletions src/mbedtls/tls.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ cdef extern from "mbedtls/ssl.h" nogil:
# set_anti_replay
unsigned int anti_replay


unsigned int endpoint
unsigned int transport
# set_validate_certificates
Expand All @@ -164,6 +163,12 @@ cdef extern from "mbedtls/ssl.h" nogil:
_x509.mbedtls_x509_crt *ca_crl
# set_sni_callback
# f_sni / p_sni
# for mbedtls_ssl_conf_psk_cb
void *p_psk
unsigned char *psk
size_t psk_len
unsigned char *psk_identity
size_t psk_identity_len

ctypedef struct mbedtls_ssl_context:
const mbedtls_ssl_config *conf
Expand Down Expand Up @@ -230,7 +235,12 @@ cdef extern from "mbedtls/ssl.h" nogil:
_x509.mbedtls_x509_crt *own_cert,
_pk.mbedtls_pk_context *pk_key)

# mbedtls_ssl_conf_psk
int mbedtls_ssl_conf_psk(
mbedtls_ssl_config *conf,
const unsigned char *psk,
size_t psk_len,
const unsigned char *psk_identity,
size_t psk_identity_len)
# mbedtls_ssl_conf_dh_param
# mbedtls_ssl_conf_dh_param_ctx
# mbedtls_ssl_conf_dhm_min_bitlen
Expand Down Expand Up @@ -271,7 +281,10 @@ cdef extern from "mbedtls/ssl.h" nogil:
)

# mbedtls_ssl_conf_session_cache
# mbedtls_ssl_conf_psk_cb
void mbedtls_ssl_conf_psk_cb(
mbedtls_ssl_config *conf,
int (*f_psk)(void *, mbedtls_ssl_context *, const unsigned char *, size_t),
void *psk_store)
void mbedtls_ssl_conf_sni(
mbedtls_ssl_config *conf,
int (*f_sni)(void *, mbedtls_ssl_context *, const unsigned char*,
Expand Down Expand Up @@ -328,7 +341,10 @@ cdef extern from "mbedtls/ssl.h" nogil:
int mbedtls_ssl_set_session(
const mbedtls_ssl_context *ssl,
mbedtls_ssl_session *session)
# mbedtls_ssl_set_hs_psk
int mbedtls_ssl_set_hs_psk(
mbedtls_ssl_context *ssl,
const unsigned char *psk,
size_t psk_len)
int mbedtls_ssl_set_hostname(
mbedtls_ssl_context *ssl,
const char *hostname)
Expand Down Expand Up @@ -395,11 +411,16 @@ cdef class _DTLSCookie:
cdef mbedtls_ssl_cookie_ctx _ctx


cdef class _PSKSToreProxy:
cdef object _mapping


cdef class _BaseConfiguration:
cdef mbedtls_ssl_config _ctx
cdef _chain
cdef int *_ciphers
cdef char **_protos
cdef _PSKSToreProxy _store
# cdef'd because we aim at a non-writable structure.
cdef _set_validate_certificates(self, validate)
cdef _set_certificate_chain(self, chain)
Expand All @@ -409,6 +430,8 @@ cdef class _BaseConfiguration:
cdef _set_highest_supported_version(self, version)
cdef _set_trust_store(self, object store)
cdef _set_sni_callback(self, callback)
cdef _set_pre_shared_key(self, psk)
cdef _set_pre_shared_key_store(self, psk_store)


cdef class TLSConfiguration(_BaseConfiguration):
Expand Down
138 changes: 134 additions & 4 deletions src/mbedtls/tls.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,35 @@ from mbedtls.exceptions import *
cdef _rnd.Random __rng = _rnd.default_rng()


cdef class _PSKSToreProxy:
def __init__(self, psk_store):
if not isinstance(psk_store, abc.Mapping):
raise TypeError("Mapping expected but got %r instead" % psk_store)
self._mapping = psk_store

def unwrap(self):
return self._mapping

def __repr__(self):
return self._mapping.__repr__()

def __str__(self):
return self._mapping.__str__()

def __getitem__(self, key):
return self._mapping.__getitem__(key)

def __iter__(self):
return self._mapping.__iter__()

def __len__(self):
return self._mapping.__len__()


# Python 2.7: `register()` can be used as a decorator from 3.3.
abc.Mapping.register(_PSKSToreProxy)


@cython.boundscheck(False)
cdef void _my_debug(void *ctx, int level,
const char *file, int line, const char *str) nogil:
Expand Down Expand Up @@ -70,6 +99,27 @@ cdef int buffer_read(void *ctx, unsigned char *buf, const size_t len) nogil:
return _rb.c_readinto(c_buf.recv_ctx, buf, len)


@cython.boundscheck(False)
cdef int _psk_cb(
void *parameter,
_tls.mbedtls_ssl_context *ctx,
const unsigned char *c_identity,
size_t c_identity_len
) nogil:
"""Wrapper for the PSK callback."""
# If a valid PSK identity is found, call `mbedtls_ssl_set_hs_psk()` and
# return 0. Otherwise, return 1.
with gil:
store = <_tls._PSKSToreProxy> parameter
identity = c_identity[:c_identity_len]
try:
psk = store[identity.decode("utf8")]
_tls.mbedtls_ssl_set_hs_psk(ctx, psk, len(psk))
return 0
except Exception:
return 1


def _set_debug_level(int level):
"""Set debug level for logging."""
_tls.mbedtls_debug_set_threshold(level)
Expand Down Expand Up @@ -118,7 +168,7 @@ def ciphers_available():
while ids[n]:
ciphersuites.append(__get_ciphersuite_name(ids[n]))
n += 1
return ciphersuites
return tuple(ciphersuites)


class NextProtocol(Enum):
Expand Down Expand Up @@ -290,6 +340,8 @@ cdef class _BaseConfiguration:
highest_supported_version=None,
trust_store=None,
sni_callback=None,
pre_shared_key=None,
pre_shared_key_store=None,
_transport=None,
):
check_error(_tls.mbedtls_ssl_config_defaults(
Expand All @@ -306,6 +358,8 @@ cdef class _BaseConfiguration:
self._set_highest_supported_version(highest_supported_version)
self._set_trust_store(trust_store)
self._set_sni_callback(sni_callback)
self._set_pre_shared_key(pre_shared_key)
self._set_pre_shared_key_store(pre_shared_key_store)

# Set random engine.
_tls.mbedtls_ssl_conf_rng(
Expand Down Expand Up @@ -345,7 +399,9 @@ cdef class _BaseConfiguration:
"lowest_supported_version=%r, "
"highest_supported_version=%r, "
"trust_store=%r, "
"sni_callback=%r)"
"sni_callback=%r, "
"pre_shared_key=%r, "
"pre_shared_key_store=%r)"
% (type(self).__name__,
self.validate_certificates,
self.certificate_chain,
Expand All @@ -354,7 +410,10 @@ cdef class _BaseConfiguration:
self.lowest_supported_version,
self.highest_supported_version,
self.trust_store,
self.sni_callback))
self.sni_callback,
self.pre_shared_key,
self.pre_shared_key_store,
))

cdef _set_validate_certificates(self, validate):
"""Set the certificate verification mode.
Expand Down Expand Up @@ -446,7 +505,7 @@ cdef class _BaseConfiguration:
if cipher_id == 0:
break
ciphers.append(__get_ciphersuite_name(cipher_id))
return ciphers
return tuple(ciphers)

cdef _set_inner_protocols(self, protocols):
"""
Expand Down Expand Up @@ -556,6 +615,49 @@ cdef class _BaseConfiguration:
def sni_callback(self):
return None

cdef _set_pre_shared_key(self, psk):
"""Set a pre shared key (PSK) for the client.
Args:
psk ([Tuple[unicode, bytes]]): A tuple with the key and the exected
identity name.
"""
if psk is None:
return
try:
identity, key = psk
except ValueError:
raise TypeError("expected a tuple (name, key)")
c_identity = identity.encode("utf8")
check_error(_tls.mbedtls_ssl_conf_psk(
&self._ctx,
key, len(key),
c_identity, len(c_identity)))

@property
def pre_shared_key(self):
if self._ctx.psk == NULL or self._ctx.psk_identity == NULL:
return None
key = self._ctx.psk[:self._ctx.psk_len]
c_identity = self._ctx.psk_identity[:self._ctx.psk_identity_len]
identity = c_identity.decode("utf8")
return (identity, key)

cdef _set_pre_shared_key_store(self, psk_store):
# server-side
if psk_store is None:
return
self._store = _PSKSToreProxy(psk_store) # ownership
_tls.mbedtls_ssl_conf_psk_cb(&self._ctx, _psk_cb, <void *> self._store)

@property
def pre_shared_key_store(self):
if self._ctx.p_psk == NULL:
return None
psk_store = <_tls._PSKSToreProxy> self._ctx.p_psk
return psk_store.unwrap()

def update(self, *args):
raise NotImplementedError

Expand All @@ -572,6 +674,8 @@ cdef class TLSConfiguration(_BaseConfiguration):
highest_supported_version=None,
trust_store=None,
sni_callback=None,
pre_shared_key=None,
pre_shared_key_store=None,
):
super().__init__(
validate_certificates=validate_certificates,
Expand All @@ -582,6 +686,8 @@ cdef class TLSConfiguration(_BaseConfiguration):
highest_supported_version=highest_supported_version,
trust_store=trust_store,
sni_callback=sni_callback,
pre_shared_key=pre_shared_key,
pre_shared_key_store=pre_shared_key_store,
_transport=_tls.MBEDTLS_SSL_TRANSPORT_STREAM,
)

Expand All @@ -603,6 +709,8 @@ cdef class TLSConfiguration(_BaseConfiguration):
highest_supported_version=_DEFAULT_VALUE,
trust_store=_DEFAULT_VALUE,
sni_callback=_DEFAULT_VALUE,
pre_shared_key=_DEFAULT_VALUE,
pre_shared_key_store=_DEFAULT_VALUE,
):
"""Create a new ``TLSConfiguration``.
Expand Down Expand Up @@ -634,6 +742,12 @@ cdef class TLSConfiguration(_BaseConfiguration):
if sni_callback is _DEFAULT_VALUE:
sni_callback = self.sni_callback

if pre_shared_key is _DEFAULT_VALUE:
pre_shared_key = self.pre_shared_key

if pre_shared_key_store is _DEFAULT_VALUE:
pre_shared_key_store = self.pre_shared_key_store

return self.__class__(
validate_certificates=validate_certificates,
certificate_chain=certificate_chain,
Expand All @@ -643,6 +757,8 @@ cdef class TLSConfiguration(_BaseConfiguration):
highest_supported_version=highest_supported_version,
trust_store=trust_store,
sni_callback=sni_callback,
pre_shared_key=pre_shared_key,
pre_shared_key_store=pre_shared_key_store,
)


Expand All @@ -661,6 +777,8 @@ cdef class DTLSConfiguration(_BaseConfiguration):
# badmac_limit
# handshake_timeout
sni_callback=None,
pre_shared_key=None,
pre_shared_key_store=None,
):
super().__init__(
validate_certificates=validate_certificates,
Expand All @@ -671,6 +789,8 @@ cdef class DTLSConfiguration(_BaseConfiguration):
highest_supported_version=highest_supported_version,
trust_store=trust_store,
sni_callback=sni_callback,
pre_shared_key=pre_shared_key,
pre_shared_key_store=pre_shared_key_store,
_transport=_tls.MBEDTLS_SSL_TRANSPORT_DATAGRAM,
)
self._set_anti_replay(anti_replay)
Expand Down Expand Up @@ -730,6 +850,8 @@ cdef class DTLSConfiguration(_BaseConfiguration):
highest_supported_version=_DEFAULT_VALUE,
trust_store=_DEFAULT_VALUE,
sni_callback=_DEFAULT_VALUE,
pre_shared_key=_DEFAULT_VALUE,
pre_shared_key_store=_DEFAULT_VALUE,
anti_replay=_DEFAULT_VALUE,
):
"""Create a new ``DTLSConfiguration``.
Expand Down Expand Up @@ -762,6 +884,12 @@ cdef class DTLSConfiguration(_BaseConfiguration):
if sni_callback is _DEFAULT_VALUE:
sni_callback = self.sni_callback

if pre_shared_key is _DEFAULT_VALUE:
pre_shared_key = self.pre_shared_key

if pre_shared_key_store is _DEFAULT_VALUE:
pre_shared_key_store = self.pre_shared_key_store

if anti_replay is _DEFAULT_VALUE:
anti_replay = self.anti_replay

Expand All @@ -774,6 +902,8 @@ cdef class DTLSConfiguration(_BaseConfiguration):
highest_supported_version=highest_supported_version,
trust_store=trust_store,
sni_callback=sni_callback,
pre_shared_key=pre_shared_key,
pre_shared_key_store=pre_shared_key_store,
anti_replay=anti_replay,
)

Expand Down

0 comments on commit 40d0436

Please sign in to comment.