Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: proxy-protocol support was broken #1

Merged
merged 1 commit into from Jul 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 12 additions & 5 deletions openttd_protocol/wire/tcp.py
Expand Up @@ -92,16 +92,23 @@ def _detect_source_ip_port(self, data):

# This message arrived via the proxy protocol; use the information
# from this to figure out the real ip and port.
proxy_end = data.find(b"\r\n")
proxy = data[0:proxy_end].decode()
data = data[proxy_end + 2 :]

# Example how 'proxy' looks:
# PROXY TCP4 127.0.0.1 127.0.0.1 33487 12345

# Search for \r\n, marking the end of the proxy protocol header.
for i in range(len(data) - 1):
if data[i] == 13 and data[i + 1] == 10:
proxy_end = i
break
else:
log.warning("Receive proxy protocol header without end from %s:%d", self.source.ip, self.source.port)
return data

proxy = data[0:proxy_end].tobytes().decode()
(_, _, ip, _, port, _) = proxy.split(" ")
self.source = Source(self, self.source.addr, ip, int(port))
return data

return data[proxy_end + 2 :]

def data_received(self, data):
data = memoryview(data)
Expand Down
12 changes: 6 additions & 6 deletions openttd_protocol/wire/test_tcp.py
Expand Up @@ -56,7 +56,7 @@ async def test_detect_source_ip_port(proxy_protocol, data, result, ip, port):
test.source = Source(test, None, "127.0.0.2", 54321)
test.proxy_protocol = proxy_protocol

assert test._detect_source_ip_port(data) == result
assert test._detect_source_ip_port(memoryview(data)) == result
assert str(test.source.ip) == ip
assert test.source.port == port

Expand All @@ -74,7 +74,7 @@ async def test_data_received(data, data_left, result):
test = OpenTTDProtocolTest(None)
test.task.cancel()

test.data_received(data)
test.data_received(memoryview(data))
assert test._data == data_left
try:
assert test._queue.get_nowait() == result
Expand All @@ -94,7 +94,7 @@ async def test_receive_packet(data, result):
test = OpenTTDProtocolTest(None)
test.task.cancel()

assert test.receive_packet(None, data) == result
assert test.receive_packet(None, memoryview(data)) == result


@pytest.mark.parametrize(
Expand All @@ -112,7 +112,7 @@ async def test_receive_packet_failure(data, failure):
test.task.cancel()

with pytest.raises(failure):
test.receive_packet(None, data)
test.receive_packet(None, memoryview(data))


@pytest.mark.parametrize(
Expand Down Expand Up @@ -141,7 +141,7 @@ async def receive_PACKET_TWO(source, **kwargs):
test.source = Source(test, None, "127.0.0.1", 12345)
test.transport = FakeTransport()

test._queue.put_nowait(data)
test._queue.put_nowait(b"\x04\x00\x00") # Force an exception
test._queue.put_nowait(memoryview(data))
test._queue.put_nowait(memoryview(b"\x04\x00\x00")) # Force an exception
await test._process_queue()
assert seen_packet[0] is True