Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions proxy/http/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,27 +189,20 @@ async def handle_events(
return False

def handle_data(self, data: memoryview) -> Optional[bool]:
"""Handles incoming data from client."""
if data is None:
logger.debug('Client closed connection, tearing down...')
self.work.closed = True
return True

try:
# HttpProtocolHandlerPlugin.on_client_data
# Can raise HttpProtocolException to tear down the connection
for plugin in self.plugins.values():
optional_data = plugin.on_client_data(data)
if optional_data is None:
break
data = optional_data
# Don't parse incoming data any further after 1st request has completed.
#
# This specially does happen for pipeline requests.
#
# Plugins can utilize on_client_data for such cases and
# apply custom logic to handle request data sent after 1st
# valid request.
if data and self.request.state != httpParserStates.COMPLETE:
if self.request.state != httpParserStates.COMPLETE:
# Parse http request
#
# TODO(abhinavsingh): Remove .tobytes after parser is
Expand All @@ -229,6 +222,14 @@ def handle_data(self, data: memoryview) -> Optional[bool]:
plugin_.client._conn = upgraded_sock
elif isinstance(upgraded_sock, bool) and upgraded_sock is True:
return True
else:
# HttpProtocolHandlerPlugin.on_client_data
# Can raise HttpProtocolException to tear down the connection
for plugin in self.plugins.values():
optional_data = plugin.on_client_data(data)
if optional_data is None:
break
data = optional_data
except HttpProtocolException as e:
logger.debug('HttpProtocolException raised')
response: Optional[memoryview] = e.response(self.request)
Expand Down
10 changes: 6 additions & 4 deletions proxy/http/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import TypeVar, Optional, Dict, Type, Tuple, List

from ...common.constants import DEFAULT_DISABLE_HEADERS, COLON, DEFAULT_ENABLE_PROXY_PROTOCOL
from ...common.constants import HTTP_1_1, HTTP_1_0, SLASH, CRLF
from ...common.constants import HTTP_1_1, SLASH, CRLF
from ...common.constants import WHITESPACE, DEFAULT_HTTP_PORT
from ...common.utils import build_http_request, build_http_response, find_http_line, text_
from ...common.flag import flags
Expand Down Expand Up @@ -271,7 +271,7 @@ def _process_body(self, raw: bytes) -> Tuple[bool, bytes]:
self.body = self.chunk.body
self.state = httpParserStates.COMPLETE
more = False
elif b'content-length' in self.headers:
elif self.content_expected:
self.state = httpParserStates.RCVING_BODY
if self.body is None:
self.body = b''
Expand All @@ -283,13 +283,15 @@ def _process_body(self, raw: bytes) -> Tuple[bool, bytes]:
self.state = httpParserStates.COMPLETE
more, raw = len(raw) > 0, raw[total_size - received_size:]
else:
# HTTP/1.0 scenario only
assert self.version == HTTP_1_0
self.state = httpParserStates.RCVING_BODY
# Received a packet without content-length header
# and no transfer-encoding specified.
#
# This can happen for both HTTP/1.0 and HTTP/1.1 scenarios.
# Currently, we consume the remaining buffer as body.
#
# Ref https://github.com/abhinavsingh/proxy.py/issues/398
#
# See TestHttpParser.test_issue_398 scenario
self.body = raw
more, raw = False, b''
Expand Down
1 change: 1 addition & 0 deletions proxy/http/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ async def read_from_descriptors(self, r: Readables) -> bool:

@abstractmethod
def on_client_data(self, raw: memoryview) -> Optional[memoryview]:
"""Called only after original request has been completely received."""
return raw # pragma: no cover

@abstractmethod
Expand Down
2 changes: 2 additions & 0 deletions proxy/http/proxy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def handle_client_data(
Essentially, if you return None from within before_upstream_connection,
be prepared to handle_client_data and not handle_client_request.

Only called after initial request from client has been received.

Raise HttpRequestRejected to tear down the connection
Return None to drop the connection
"""
Expand Down
5 changes: 0 additions & 5 deletions proxy/http/proxy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,6 @@ def wrap_client(self) -> bool:
def emit_request_complete(self) -> None:
if not self.flags.enable_events:
return

assert self.request.port
self.event_queue.publish(
request_id=self.uid.hex,
Expand All @@ -924,7 +923,6 @@ def emit_request_complete(self) -> None:
def emit_response_events(self, chunk_size: int) -> None:
if not self.flags.enable_events:
return

if self.response.state == httpParserStates.COMPLETE:
self.emit_response_complete()
elif self.response.state == httpParserStates.RCVING_BODY:
Expand All @@ -935,7 +933,6 @@ def emit_response_events(self, chunk_size: int) -> None:
def emit_response_headers_complete(self) -> None:
if not self.flags.enable_events:
return

self.event_queue.publish(
request_id=self.uid.hex,
event_name=eventNames.RESPONSE_HEADERS_COMPLETE,
Expand All @@ -948,7 +945,6 @@ def emit_response_headers_complete(self) -> None:
def emit_response_chunk_received(self, chunk_size: int) -> None:
if not self.flags.enable_events:
return

self.event_queue.publish(
request_id=self.uid.hex,
event_name=eventNames.RESPONSE_CHUNK_RECEIVED,
Expand All @@ -962,7 +958,6 @@ def emit_response_chunk_received(self, chunk_size: int) -> None:
def emit_response_complete(self) -> None:
if not self.flags.enable_events:
return

self.event_queue.publish(
request_id=self.uid.hex,
event_name=eventNames.RESPONSE_COMPLETE,
Expand Down
25 changes: 17 additions & 8 deletions proxy/plugin/proxy_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
"""
import random
import logging
import ipaddress

from typing import Dict, List, Optional, Any

from ..common.flag import flags
from ..common.utils import text_

from ..http import Url, httpMethods
from ..http.parser import HttpParser
Expand Down Expand Up @@ -78,15 +80,22 @@ def before_upstream_connection(
) -> Optional[HttpParser]:
"""Avoids establishing the default connection to upstream server
by returning None.

TODO(abhinavsingh): Ideally connection to upstream proxy endpoints
must be bootstrapped within it's own re-usable and garbage collected pool,
to avoid establishing a new upstream proxy connection for each client request.

See :class:`~proxy.core.connection.pool.ConnectionPool` which is a work
in progress for SSL cache handling.
"""
# TODO(abhinavsingh): Ideally connection to upstream proxy endpoints
# must be bootstrapped within it's own re-usable and gc'd pool, to avoid establishing
# a fresh upstream proxy connection for each client request.
#
# See :class:`~proxy.core.connection.pool.ConnectionPool` which is a work
# in progress for SSL cache handling.
#
# Implement your own logic here e.g. round-robin, least connection etc.
# We don't want to send private IP requests to remote proxies
try:
if ipaddress.ip_address(text_(request.host)).is_private:
return request
except ValueError:
pass
# Choose a random proxy from the pool
# TODO: Implement your own logic here e.g. round-robin, least connection etc.
endpoint = random.choice(self.flags.proxy_pool)[0].split(':', 1)
if endpoint[0] == 'localhost' and endpoint[1] == '8899':
return request
Expand Down
5 changes: 2 additions & 3 deletions tests/http/test_http_proxy_tls_interception.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,8 @@ async def asyncReturnBool(val: bool) -> bool:
# Assert our mocked plugins invocations
self.plugin.return_value.get_descriptors.assert_called()
self.plugin.return_value.write_to_descriptors.assert_called_with([])
self.plugin.return_value.on_client_data.assert_called_with(
connect_request,
)
# on_client_data is only called after initial request has completed
self.plugin.return_value.on_client_data.assert_not_called()
self.plugin.return_value.on_request_complete.assert_called()
self.plugin.return_value.read_from_descriptors.assert_called_with([
self._conn.fileno(),
Expand Down