-
Notifications
You must be signed in to change notification settings - Fork 23
/
test_transports.py
132 lines (95 loc) · 3.54 KB
/
test_transports.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
# SPDX-FileCopyrightText: AISEC Pentesting Team
#
# SPDX-License-Identifier: Apache-2.0
import asyncio
import binascii
from collections.abc import AsyncIterator, Callable
import pytest
from gallia.log import setup_logging
from gallia.transports import BaseTransport, TargetURI, TCPLinesTransport, TCPTransport
# Address the test server listens on; the tests connect to the same host/port.
listen_target = TargetURI("tcp://127.0.0.1:1234")
# Payloads sent through the echo tests. The comma matters: adjacent bytes
# literals are concatenated by the parser, so ``[b"hello" b"tcp"]`` would be a
# single-element list containing b"hellotcp".
test_data = [b"hello", b"tcp"]

setup_logging()
class TCPServer:
    """Minimal asyncio TCP server used as the peer in transport tests.

    Every accepted connection is wrapped in a ``TCPTransport`` and handed to
    the test through :meth:`accept`.
    """

    def __init__(self) -> None:
        # Holds at most one accepted connection that no test has claimed yet.
        self.queue: asyncio.Queue[TCPTransport] = asyncio.Queue(1)
        # Created by listen(); None while the server is not listening.
        self.server: asyncio.Server | None = None

    async def _accept_cb(
        self,
        reader: asyncio.StreamReader,
        writer: asyncio.StreamWriter,
    ) -> None:
        # asyncio.start_server callback: wrap the already-connected streams in
        # a transport. This end never dials out, so a bare "tcp://" URI is
        # passed as the target.
        await self.queue.put(TCPTransport(TargetURI("tcp://"), reader, writer))

    async def listen(self, target: TargetURI) -> None:
        """Start listening on the host and port of ``target``."""
        self.server = await asyncio.start_server(
            self._accept_cb,
            host=target.hostname,
            port=target.port,
        )

    async def accept(self) -> TCPTransport:
        """Wait for and return the next accepted connection."""
        return await self.queue.get()

    def close(self) -> None:
        """Stop listening; a no-op before listen() or after a previous close()."""
        if self.server is not None:
            self.server.close()
            self.server = None
async def _echo_test(
    client: BaseTransport,
    server: BaseTransport,
    line: bytes,
    converter: Callable[[bytes], bytes] | None = None,
) -> None:
    """Round-trip ``line`` through a client/server pair and check both legs.

    ``converter`` maps ``line`` to the bytes the server is expected to see on
    the wire (e.g. when the client applies an encoding); without it the server
    must receive ``line`` unchanged. The server then sends that wire form back
    and the client must hand ``line`` itself to the caller.
    """
    expected_on_wire = line if converter is None else converter(line)

    # Leg 1: client -> server.
    await client.write(line)
    received_by_server = await server.read()
    assert expected_on_wire == received_by_server

    # Leg 2: server -> client.
    await server.write(expected_on_wire)
    received_by_client = await client.read()
    assert line == received_by_client
@pytest.fixture()
async def tcp_server() -> AsyncIterator[TCPServer]:
    """Provide a TCPServer listening on ``listen_target`` for a single test.

    No ``pytest.mark.asyncio`` here: marks applied to fixtures have no effect
    and are deprecated since pytest 8, so dropping it changes nothing at
    runtime beyond silencing that deprecation.
    """
    server = TCPServer()
    await server.listen(listen_target)
    yield server
    # Teardown: runs once the test body has finished, whatever its outcome.
    server.close()
@pytest.mark.asyncio
async def test_tcp_wrong_scheme(tcp_server: TCPServer) -> None:
    """Connecting a TCPTransport to a non-TCP URI scheme raises ValueError."""
    with pytest.raises(ValueError):
        bad_target = TargetURI("foo://123")
        await TCPTransport.connect(bad_target)
@pytest.mark.asyncio
async def test_tcp_reconnect(tcp_server: TCPServer) -> None:
    """reconnect() returns a client that opens a fresh connection.

    Both the initial connection and the one made by reconnect() must be
    accepted by the test server. The final client and both server-side ends
    are closed afterwards so no sockets outlive the test.
    """
    client = await TCPTransport.connect(listen_target)
    first = await tcp_server.accept()
    client = await client.reconnect()
    second = await tcp_server.accept()

    await client.close()
    await first.close()
    await second.close()
@pytest.mark.asyncio
async def test_tcp_echo(tcp_server: TCPServer) -> None:
    """Raw TCP transports carry every payload unchanged in both directions."""
    client = await TCPTransport.connect(listen_target)
    server = await tcp_server.accept()
    try:
        for line in test_data:
            await _echo_test(client, server, line)
    finally:
        # Close both ends so the sockets do not outlive the test.
        await client.close()
        await server.close()
@pytest.mark.asyncio
async def test_tcp_linesep_echo(tcp_server: TCPServer) -> None:
    """A tcp-lines client sends payloads hex-encoded and newline-terminated.

    The server end is a raw TCP transport, so it sees the encoded line and
    must answer in the same form for the client to decode it back.
    """
    client = await TCPLinesTransport.connect(TargetURI("tcp-lines://127.0.0.1:1234"))
    server = await tcp_server.accept()

    def converter(data: bytes) -> bytes:
        # Expected on-the-wire form of a payload written by the client.
        return binascii.hexlify(data) + b"\n"

    try:
        for line in test_data:
            await _echo_test(client, server, line, converter)
    finally:
        # Close both ends so the sockets do not outlive the test.
        await client.close()
        await server.close()
@pytest.mark.asyncio
async def test_tcp_close(tcp_server: TCPServer) -> None:
    """Both ends of an established connection close without error."""
    client = await TCPTransport.connect(listen_target)
    server = await tcp_server.accept()
    # Client first, then the server-side end.
    for transport in (client, server):
        await transport.close()
@pytest.mark.asyncio
async def test_tcp_linesep_request(tcp_server: TCPServer) -> None:
    """request() sends one encoded line and returns the decoded reply line.

    The reply is queued by the server before the request is issued, so
    request() can read it right away. The request bytes the server received
    are checked as well instead of being silently discarded.
    """
    client = await TCPLinesTransport.connect(TargetURI("tcp-lines://127.0.0.1:1234"))
    server = await tcp_server.accept()
    try:
        await server.write(binascii.hexlify(b"world") + b"\n")
        resp = await client.request(b"hello")
        req = await server.read()
        assert resp == b"world"
        # Same wire form the tcp-lines echo test expects for client writes.
        assert req == binascii.hexlify(b"hello") + b"\n"
    finally:
        await client.close()
        await server.close()
@pytest.mark.asyncio
async def test_tcp_timeout(tcp_server: TCPServer) -> None:
    """request() raises a timeout error when the peer never answers."""
    client = await TCPLinesTransport.connect(TargetURI("tcp-lines://127.0.0.1:1234"))
    # Keep the accepted end so it can be closed; it never writes a reply.
    server = await tcp_server.accept()
    try:
        with pytest.raises(asyncio.TimeoutError):
            await client.request(b"hello", timeout=0.5)
    finally:
        await client.close()
        await server.close()