|
1 |
| -from ussl import * |
2 |
| -import ussl as _ussl |
| 1 | +import tls |
| 2 | +from tls import ( |
| 3 | + CERT_NONE, |
| 4 | + CERT_OPTIONAL, |
| 5 | + CERT_REQUIRED, |
| 6 | + MBEDTLS_VERSION, |
| 7 | + PROTOCOL_TLS_CLIENT, |
| 8 | + PROTOCOL_TLS_SERVER, |
| 9 | +) |
3 | 10 |
|
4 |
| -# Constants |
5 |
| -for sym in "CERT_NONE", "CERT_OPTIONAL", "CERT_REQUIRED": |
6 |
| - if sym not in globals(): |
7 |
| - globals()[sym] = object() |
| 11 | + |
| 12 | +class SSLContext: |
| 13 | + def __init__(self, *args): |
| 14 | + self._context = tls.SSLContext(*args) |
| 15 | + self._context.verify_mode = CERT_NONE |
| 16 | + |
| 17 | + @property |
| 18 | + def verify_mode(self): |
| 19 | + return self._context.verify_mode |
| 20 | + |
| 21 | + @verify_mode.setter |
| 22 | + def verify_mode(self, val): |
| 23 | + self._context.verify_mode = val |
| 24 | + |
| 25 | + def load_cert_chain(self, certfile, keyfile): |
| 26 | + if isinstance(certfile, str): |
| 27 | + with open(certfile, "rb") as f: |
| 28 | + certfile = f.read() |
| 29 | + if isinstance(keyfile, str): |
| 30 | + with open(keyfile, "rb") as f: |
| 31 | + keyfile = f.read() |
| 32 | + self._context.load_cert_chain(certfile, keyfile) |
| 33 | + |
| 34 | + def load_verify_locations(self, cafile=None, cadata=None): |
| 35 | + if cafile: |
| 36 | + with open(cafile, "rb") as f: |
| 37 | + cadata = f.read() |
| 38 | + self._context.load_verify_locations(cadata) |
| 39 | + |
| 40 | + def wrap_socket( |
| 41 | + self, sock, server_side=False, do_handshake_on_connect=True, server_hostname=None |
| 42 | + ): |
| 43 | + return self._context.wrap_socket( |
| 44 | + sock, |
| 45 | + server_side=server_side, |
| 46 | + do_handshake_on_connect=do_handshake_on_connect, |
| 47 | + server_hostname=server_hostname, |
| 48 | + ) |
8 | 49 |
|
9 | 50 |
|
10 | 51 | def wrap_socket(
|
11 | 52 | sock,
|
12 |
| - keyfile=None, |
13 |
| - certfile=None, |
14 | 53 | server_side=False,
|
| 54 | + key=None, |
| 55 | + cert=None, |
15 | 56 | cert_reqs=CERT_NONE,
|
16 |
| - *, |
17 |
| - ca_certs=None, |
18 |
| - server_hostname=None |
| 57 | + cadata=None, |
| 58 | + server_hostname=None, |
| 59 | + do_handshake=True, |
19 | 60 | ):
|
20 |
| - # TODO: More arguments accepted by CPython could also be handled here. |
21 |
| - # That would allow us to accept ca_certs as a positional argument, which |
22 |
| - # we should. |
23 |
| - kw = {} |
24 |
| - if keyfile is not None: |
25 |
| - kw["keyfile"] = keyfile |
26 |
| - if certfile is not None: |
27 |
| - kw["certfile"] = certfile |
28 |
| - if server_side is not False: |
29 |
| - kw["server_side"] = server_side |
30 |
| - if cert_reqs is not CERT_NONE: |
31 |
| - kw["cert_reqs"] = cert_reqs |
32 |
| - if ca_certs is not None: |
33 |
| - kw["ca_certs"] = ca_certs |
34 |
| - if server_hostname is not None: |
35 |
| - kw["server_hostname"] = server_hostname |
36 |
| - return _ussl.wrap_socket(sock, **kw) |
| 61 | + con = SSLContext(PROTOCOL_TLS_SERVER if server_side else PROTOCOL_TLS_CLIENT) |
| 62 | + if cert or key: |
| 63 | + con.load_cert_chain(cert, key) |
| 64 | + if cadata: |
| 65 | + con.load_verify_locations(cadata=cadata) |
| 66 | + con.verify_mode = cert_reqs |
| 67 | + return con.wrap_socket( |
| 68 | + sock, |
| 69 | + server_side=server_side, |
| 70 | + do_handshake_on_connect=do_handshake, |
| 71 | + server_hostname=server_hostname, |
| 72 | + ) |
0 commit comments