Skip to content

Commit

Permalink
[config] add QuicConfiguration.load_cert_chain method
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaine committed Sep 24, 2019
1 parent 5337bae commit 214808c
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 86 deletions.
42 changes: 23 additions & 19 deletions aioquic/quic/configuration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from dataclasses import dataclass, field
from os import PathLike
from typing import Any, List, Optional, TextIO

from ..tls import SessionTicket
from ..tls import SessionTicket, load_pem_private_key, load_pem_x509_certificates
from .logger import QuicLogger
from .packet import QuicProtocolVersion

Expand All @@ -17,15 +18,6 @@ class QuicConfiguration:
A list of supported ALPN protocols.
"""

certificate: Any = None
"""
The server's TLS certificate.
See :func:`cryptography.x509.load_pem_x509_certificate`.
.. note:: This is only used by servers.
"""

connection_id_length: int = 8
"""
The length in bytes of local connection IDs.
Expand All @@ -43,15 +35,6 @@ class QuicConfiguration:
Whether this is the client side of the QUIC connection.
"""

private_key: Any = None
"""
The server's TLS private key.
See :func:`cryptography.hazmat.primitives.serialization.load_pem_private_key`.
.. note:: This is only used by servers.
"""

quic_logger: Optional[QuicLogger] = None
"""
The :class:`~aioquic.quic.logger.QuicLogger` instance to log events to.
Expand All @@ -76,9 +59,30 @@ class QuicConfiguration:
The TLS session ticket which should be used for session resumption.
"""

certificate: Any = None
certificate_chain: List[Any] = field(default_factory=list)
private_key: Any = None
supported_versions: List[int] = field(
default_factory=lambda: [
QuicProtocolVersion.DRAFT_23,
QuicProtocolVersion.DRAFT_22,
]
)

def load_cert_chain(
self,
certfile: PathLike,
keyfile: Optional[PathLike] = None,
password: Optional[str] = None,
):
"""
Load a private key and the corresponding certificate.
"""
with open(certfile, "rb") as fp:
certificates = load_pem_x509_certificates(fp.read())
self.certificate = certificates[0]
self.certificate_chain = certificates[1:]

if keyfile is not None:
with open(keyfile, "rb") as fp:
self.private_key = load_pem_private_key(fp.read(), password=password)
1 change: 1 addition & 0 deletions aioquic/quic/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,6 +1036,7 @@ def _initialize(self, peer_cid: bytes) -> None:
self.tls = tls.Context(is_client=self._is_client, logger=self._logger)
self.tls.alpn_protocols = self._configuration.alpn_protocols
self.tls.certificate = self._configuration.certificate
self.tls.certificate_chain = self._configuration.certificate_chain
self.tls.certificate_private_key = self._configuration.private_key
self.tls.handshake_extensions = [
(
Expand Down
33 changes: 31 additions & 2 deletions aioquic/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives import hashes, hmac, serialization
from cryptography.hazmat.primitives.asymmetric import dsa, ec, padding, rsa, x25519
from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand
from cryptography.hazmat.primitives.serialization import Encoding, PublicFormat
Expand Down Expand Up @@ -164,6 +164,33 @@ def hkdf_extract(
return h.finalize()


def load_pem_private_key(
data: bytes, password: Optional[str]
) -> Union[dsa.DSAPrivateKey, ec.EllipticCurvePrivateKey, rsa.RSAPrivateKey]:
"""
Load a PEM-encoded private key.
"""
return serialization.load_pem_private_key(
data, password=password, backend=default_backend()
)


def load_pem_x509_certificates(data: bytes) -> List[x509.Certificate]:
"""
Load a chain of PEM-encoded X509 certificates.
"""
boundary = b"-----END CERTIFICATE-----\n"
certificates = []
for chunk in data.split(boundary):
if chunk:
certificates.append(
x509.load_pem_x509_certificate(
chunk + boundary, backend=default_backend()
)
)
return certificates


def verify_certificate(
certificate: x509.Certificate, server_name: Optional[str]
) -> None:
Expand Down Expand Up @@ -1064,6 +1091,7 @@ def __init__(
self.alpn_negotiated: Optional[str] = None
self.alpn_protocols: Optional[List[str]] = None
self.certificate: Optional[x509.Certificate] = None
self.certificate_chain: List[x509.Certificate] = []
self.certificate_private_key: Optional[
Union[dsa.DSAPublicKey, ec.EllipticCurvePublicKey, rsa.RSAPublicKey]
] = None
Expand Down Expand Up @@ -1650,7 +1678,8 @@ def _server_handle_hello(
Certificate(
request_context=b"",
certificates=[
(self.certificate.public_bytes(Encoding.DER), b"")
(x.public_bytes(Encoding.DER), b"")
for x in [self.certificate] + self.certificate_chain
],
),
)
Expand Down
19 changes: 4 additions & 15 deletions examples/http3_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@

import wsproto
import wsproto.events
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from aioquic.asyncio import QuicConnectionProtocol, serve
from aioquic.h0.connection import H0Connection
Expand Down Expand Up @@ -449,24 +446,16 @@ def end_trace(self, trace: QuicLoggerTrace) -> None:
else:
secrets_log_file = None

# load SSL certificate and key
with open(args.certificate, "rb") as fp:
certificate = x509.load_pem_x509_certificate(
fp.read(), backend=default_backend()
)
with open(args.private_key, "rb") as fp:
private_key = serialization.load_pem_private_key(
fp.read(), password=None, backend=default_backend()
)

configuration = QuicConfiguration(
alpn_protocols=["h3-23", "h3-22", "hq-23", "hq-22"],
certificate=certificate,
is_client=False,
private_key=private_key,
quic_logger=quic_logger,
secrets_log_file=secrets_log_file,
)

# load SSL certificate and key
configuration.load_cert_chain(args.certificate, args.private_key)

ticket_store = SessionTicketStore()

if uvloop is not None:
Expand Down
26 changes: 10 additions & 16 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from aioquic.quic.logger import QuicLogger
from aioquic.quic.packet import QuicProtocolVersion

from .utils import SERVER_CERTIFICATE, SERVER_PRIVATE_KEY, generate_ec_certificate, run
from .utils import SERVER_CERTFILE, SERVER_KEYFILE, generate_ec_certificate, run

real_sendto = socket.socket.sendto

Expand Down Expand Up @@ -59,11 +59,8 @@ async def serve():

async def run_server(configuration=None, **kwargs):
if configuration is None:
configuration = QuicConfiguration(
certificate=SERVER_CERTIFICATE,
private_key=SERVER_PRIVATE_KEY,
is_client=False,
)
configuration = QuicConfiguration(is_client=False)
configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)
return await serve(
host="::",
port="4433",
Expand Down Expand Up @@ -133,18 +130,15 @@ def test_connect_and_serve_with_packet_loss(self, mock_sendto):
and received in the presence of packet loss (randomized 25% in each direction).
"""
data = b"Z" * 65536

server_configuration = QuicConfiguration(
idle_timeout=300.0, is_client=False, quic_logger=QuicLogger()
)
server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)

server, response = run(
asyncio.gather(
run_server(
configuration=QuicConfiguration(
certificate=SERVER_CERTIFICATE,
idle_timeout=300.0,
is_client=False,
private_key=SERVER_PRIVATE_KEY,
quic_logger=QuicLogger(),
),
stateless_retry=True,
),
run_server(configuration=server_configuration, stateless_retry=True),
run_client(
"127.0.0.1",
configuration=QuicConfiguration(
Expand Down
38 changes: 14 additions & 24 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from aioquic.quic.packet_builder import QuicDeliveryState, QuicPacketBuilder

from .utils import SERVER_CERTIFICATE, SERVER_PRIVATE_KEY
from .utils import SERVER_CERTFILE, SERVER_KEYFILE

CLIENT_ADDR = ("1.2.3.4", 1234)

Expand Down Expand Up @@ -95,16 +95,12 @@ def client_and_server(
client._ack_delay = 0
client_patch(client)

server = QuicConnection(
configuration=QuicConfiguration(
is_client=False,
certificate=SERVER_CERTIFICATE,
private_key=SERVER_PRIVATE_KEY,
quic_logger=QuicLogger(),
**server_options
),
**server_kwargs
server_configuration = QuicConfiguration(
is_client=False, quic_logger=QuicLogger(), **server_options
)
server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)

server = QuicConnection(configuration=server_configuration, **server_kwargs)
server._ack_delay = 0
server_patch(server)

Expand Down Expand Up @@ -278,13 +274,10 @@ def datagram_sizes(items):
client = QuicConnection(configuration=QuicConfiguration(is_client=True))
client._ack_delay = 0

server = QuicConnection(
configuration=QuicConfiguration(
is_client=False,
certificate=SERVER_CERTIFICATE,
private_key=SERVER_PRIVATE_KEY,
)
)
server_configuration = QuicConfiguration(is_client=False)
server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)

server = QuicConnection(configuration=server_configuration)
server._ack_delay = 0

# client sends INITIAL
Expand Down Expand Up @@ -350,13 +343,10 @@ def datagram_sizes(items):
client = QuicConnection(configuration=QuicConfiguration(is_client=True))
client._ack_delay = 0

server = QuicConnection(
configuration=QuicConfiguration(
is_client=False,
certificate=SERVER_CERTIFICATE,
private_key=SERVER_PRIVATE_KEY,
)
)
server_configuration = QuicConfiguration(is_client=False)
server_configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)

server = QuicConnection(configuration=server_configuration)
server._ack_delay = 0

# client sends INITIAL
Expand Down
10 changes: 7 additions & 3 deletions tests/test_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from aioquic import tls
from aioquic.buffer import Buffer, BufferReadError
from aioquic.quic.configuration import QuicConfiguration
from aioquic.tls import (
Certificate,
CertificateVerify,
Expand Down Expand Up @@ -36,7 +37,7 @@
verify_certificate,
)

from .utils import SERVER_CERTIFICATE, SERVER_PRIVATE_KEY, generate_ec_certificate, load
from .utils import SERVER_CERTFILE, SERVER_KEYFILE, generate_ec_certificate, load

CERTIFICATE_DATA = load("tls_certificate.bin")[11:-2]
CERTIFICATE_VERIFY_SIGNATURE = load("tls_certificate_verify.bin")[-384:]
Expand Down Expand Up @@ -103,9 +104,12 @@ def create_client(self):
return client

def create_server(self):
configuration = QuicConfiguration(is_client=False)
configuration.load_cert_chain(SERVER_CERTFILE, SERVER_KEYFILE)

server = Context(is_client=False)
server.certificate = SERVER_CERTIFICATE
server.certificate_private_key = SERVER_PRIVATE_KEY
server.certificate = configuration.certificate
server.certificate_private_key = configuration.private_key
server.handshake_extensions = [
(
tls.ExtensionType.QUIC_TRANSPORT_PARAMETERS,
Expand Down
10 changes: 3 additions & 7 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec


Expand Down Expand Up @@ -46,12 +46,8 @@ def run(coro):
return asyncio.get_event_loop().run_until_complete(coro)


SERVER_CERTIFICATE = x509.load_pem_x509_certificate(
load("ssl_cert.pem"), backend=default_backend()
)
SERVER_PRIVATE_KEY = serialization.load_pem_private_key(
load("ssl_key.pem"), password=None, backend=default_backend()
)
SERVER_CERTFILE = os.path.join(os.path.dirname(__file__), "ssl_cert.pem")
SERVER_KEYFILE = os.path.join(os.path.dirname(__file__), "ssl_key.pem")

if os.environ.get("AIOQUIC_DEBUG"):
logging.basicConfig(level=logging.DEBUG)

0 comments on commit 214808c

Please sign in to comment.