diff --git a/openttd_protocol/wire/tcp.py b/openttd_protocol/wire/tcp.py index 5281692..2b2df27 100644 --- a/openttd_protocol/wire/tcp.py +++ b/openttd_protocol/wire/tcp.py @@ -92,16 +92,23 @@ def _detect_source_ip_port(self, data): # This message arrived via the proxy protocol; use the information # from this to figure out the real ip and port. - proxy_end = data.find(b"\r\n") - proxy = data[0:proxy_end].decode() - data = data[proxy_end + 2 :] - # Example how 'proxy' looks: # PROXY TCP4 127.0.0.1 127.0.0.1 33487 12345 + # Search for \r\n, marking the end of the proxy protocol header. + for i in range(len(data) - 1): + if data[i] == 13 and data[i + 1] == 10: + proxy_end = i + break + else: + log.warning("Receive proxy protocol header without end from %s:%d", self.source.ip, self.source.port) + return data + + proxy = data[0:proxy_end].tobytes().decode() (_, _, ip, _, port, _) = proxy.split(" ") self.source = Source(self, self.source.addr, ip, int(port)) - return data + + return data[proxy_end + 2 :] def data_received(self, data): data = memoryview(data) diff --git a/openttd_protocol/wire/test_tcp.py b/openttd_protocol/wire/test_tcp.py index b8e25c9..fe34bb7 100644 --- a/openttd_protocol/wire/test_tcp.py +++ b/openttd_protocol/wire/test_tcp.py @@ -56,7 +56,7 @@ async def test_detect_source_ip_port(proxy_protocol, data, result, ip, port): test.source = Source(test, None, "127.0.0.2", 54321) test.proxy_protocol = proxy_protocol - assert test._detect_source_ip_port(data) == result + assert test._detect_source_ip_port(memoryview(data)) == result assert str(test.source.ip) == ip assert test.source.port == port @@ -74,7 +74,7 @@ async def test_data_received(data, data_left, result): test = OpenTTDProtocolTest(None) test.task.cancel() - test.data_received(data) + test.data_received(memoryview(data)) assert test._data == data_left try: assert test._queue.get_nowait() == result @@ -94,7 +94,7 @@ async def test_receive_packet(data, result): test = OpenTTDProtocolTest(None) test.task.cancel() - assert test.receive_packet(None, data) == result + assert test.receive_packet(None, memoryview(data)) == result @pytest.mark.parametrize( @@ -112,7 +112,7 @@ async def test_receive_packet_failure(data, failure): test.task.cancel() with pytest.raises(failure): - test.receive_packet(None, data) + test.receive_packet(None, memoryview(data)) @pytest.mark.parametrize( @@ -141,7 +141,7 @@ async def receive_PACKET_TWO(source, **kwargs): test.source = Source(test, None, "127.0.0.1", 12345) test.transport = FakeTransport() - test._queue.put_nowait(data) - test._queue.put_nowait(b"\x04\x00\x00") # Force an exception + test._queue.put_nowait(memoryview(data)) + test._queue.put_nowait(memoryview(b"\x04\x00\x00")) # Force an exception await test._process_queue() assert seen_packet[0] is True