Skip to content

Commit

Permalink
[examples] use asyncio.run() to run main entry point
Browse files Browse the repository at this point in the history
  • Loading branch information
jlaine committed Jul 17, 2022
1 parent cd54b92 commit 2c9d904
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 50 deletions.
7 changes: 3 additions & 4 deletions examples/doq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def save_session_ticket(ticket):
pickle.dump(ticket, fp)


async def run(
async def main(
configuration: QuicConfiguration,
host: str,
port: int,
Expand Down Expand Up @@ -155,9 +155,8 @@ async def run(
else:
logger.debug("No session ticket defined...")

loop = asyncio.get_event_loop()
loop.run_until_complete(
run(
asyncio.run(
main(
configuration=configuration,
host=args.host,
port=args.port,
Expand Down
44 changes: 29 additions & 15 deletions examples/doq_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ def pop(self, label: bytes) -> Optional[SessionTicket]:
return self.tickets.pop(label, None)


async def main(
host: str,
port: int,
configuration: QuicConfiguration,
session_ticket_store: SessionTicketStore,
retry: bool,
) -> None:
await serve(
args.host,
args.port,
configuration=configuration,
create_protocol=DnsServerProtocol,
session_ticket_fetcher=session_ticket_store.pop,
session_ticket_handler=session_ticket_store.add,
retry=args.retry,
)
await asyncio.Future()


if __name__ == "__main__":

parser = argparse.ArgumentParser(description="DNS over QUIC server")
Expand Down Expand Up @@ -99,6 +118,7 @@ def pop(self, label: bytes) -> Optional[SessionTicket]:
level=logging.DEBUG if args.verbose else logging.INFO,
)

# create QUIC logger
if args.quic_log:
quic_logger = QuicFileLogger(args.quic_log)
else:
Expand All @@ -112,21 +132,15 @@ def pop(self, label: bytes) -> Optional[SessionTicket]:

configuration.load_cert_chain(args.certificate, args.private_key)

ticket_store = SessionTicketStore()

loop = asyncio.get_event_loop()
loop.run_until_complete(
serve(
args.host,
args.port,
configuration=configuration,
create_protocol=DnsServerProtocol,
session_ticket_fetcher=ticket_store.pop,
session_ticket_handler=ticket_store.add,
retry=args.retry,
)
)
try:
loop.run_forever()
asyncio.run(
main(
host=args.host,
port=args.port,
configuration=configuration,
session_ticket_store=SessionTicketStore(),
retry=args.retry,
)
)
except KeyboardInterrupt:
pass
7 changes: 3 additions & 4 deletions examples/http3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def save_session_ticket(ticket: SessionTicket) -> None:
pickle.dump(ticket, fp)


async def run(
async def main(
configuration: QuicConfiguration,
urls: List[str],
data: Optional[str],
Expand Down Expand Up @@ -544,9 +544,8 @@ async def run(

if uvloop is not None:
uvloop.install()
loop = asyncio.get_event_loop()
loop.run_until_complete(
run(
asyncio.run(
main(
configuration=configuration,
urls=args.url,
data=args.data,
Expand Down
44 changes: 29 additions & 15 deletions examples/http3_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,25 @@ def pop(self, label: bytes) -> Optional[SessionTicket]:
return self.tickets.pop(label, None)


async def main(
host: str,
port: int,
configuration: QuicConfiguration,
session_ticket_store: SessionTicketStore,
retry: bool,
) -> None:
await serve(
args.host,
args.port,
configuration=configuration,
create_protocol=HttpServerProtocol,
session_ticket_fetcher=session_ticket_store.pop,
session_ticket_handler=session_ticket_store.add,
retry=args.retry,
)
await asyncio.Future()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="QUIC server")
parser.add_argument(
Expand Down Expand Up @@ -565,23 +584,18 @@ def pop(self, label: bytes) -> Optional[SessionTicket]:
# load SSL certificate and key
configuration.load_cert_chain(args.certificate, args.private_key)

ticket_store = SessionTicketStore()

if uvloop is not None:
uvloop.install()
loop = asyncio.get_event_loop()
loop.run_until_complete(
serve(
args.host,
args.port,
configuration=configuration,
create_protocol=HttpServerProtocol,
session_ticket_fetcher=ticket_store.pop,
session_ticket_handler=ticket_store.add,
retry=args.retry,
)
)

try:
loop.run_forever()
asyncio.run(
main(
host=args.host,
port=args.port,
configuration=configuration,
session_ticket_store=SessionTicketStore(),
retry=args.retry,
)
)
except KeyboardInterrupt:
pass
7 changes: 3 additions & 4 deletions examples/httpx_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def save_session_ticket(ticket):
pickle.dump(ticket, fp)


async def run(
async def main(
configuration: QuicConfiguration,
url: str,
data: str,
Expand Down Expand Up @@ -287,9 +287,8 @@ async def run(
except FileNotFoundError:
pass

loop = asyncio.get_event_loop()
loop.run_until_complete(
run(
asyncio.run(
main(
configuration=configuration,
url=args.url,
data=args.data,
Expand Down
9 changes: 5 additions & 4 deletions examples/interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ async def test_nat_rebinding(server: Server, configuration: QuicConfiguration):

# replace transport
protocol._transport.close()
loop = asyncio.get_event_loop()
await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0))

# cause more traffic
Expand Down Expand Up @@ -379,6 +380,7 @@ async def test_address_mobility(server: Server, configuration: QuicConfiguration

# replace transport
protocol._transport.close()
loop = asyncio.get_event_loop()
await loop.create_datagram_endpoint(lambda: protocol, local_addr=("::", 0))

# change connection ID
Expand Down Expand Up @@ -480,7 +482,7 @@ def print_result(server: Server) -> None:
print("%s%s%s" % (server.name, " " * (20 - len(server.name)), result))


async def run(servers, tests, quic_log=False, secrets_log_file=None) -> None:
async def main(servers, tests, quic_log=False, secrets_log_file=None) -> None:
for server in servers:
if server.structured_logging:
server.result |= Result.L
Expand Down Expand Up @@ -557,9 +559,8 @@ async def run(servers, tests, quic_log=False, secrets_log_file=None) -> None:
if args.test:
tests = list(filter(lambda x: x[0] == args.test, tests))

loop = asyncio.get_event_loop()
loop.run_until_complete(
run(
asyncio.run(
main(
servers=servers,
tests=tests,
quic_log=args.quic_log,
Expand Down
11 changes: 7 additions & 4 deletions examples/siduck_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def quic_event_received(self, event: QuicEvent) -> None:
waiter.set_result(None)


async def run(configuration: QuicConfiguration, host: str, port: int) -> None:
async def main(configuration: QuicConfiguration, host: str, port: int) -> None:
async with connect(
host, port, configuration=configuration, create_protocol=SiduckClient
) as client:
Expand Down Expand Up @@ -91,7 +91,10 @@ async def run(configuration: QuicConfiguration, host: str, port: int) -> None:
if args.secrets_log:
configuration.secrets_log_file = open(args.secrets_log, "a")

loop = asyncio.get_event_loop()
loop.run_until_complete(
run(configuration=configuration, host=args.host, port=args.port)
asyncio.run(
main(
configuration=configuration,
host=args.host,
port=args.port,
)
)

0 comments on commit 2c9d904

Please sign in to comment.