diff --git a/azure/functions/_http_wsgi.py b/azure/functions/_http_wsgi.py
index f51317e1..51fa7a9d 100644
--- a/azure/functions/_http_wsgi.py
+++ b/azure/functions/_http_wsgi.py
@@ -1,6 +1,8 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 import logging
+import re
+
 from io import BytesIO, StringIO
 from os import linesep
 from typing import Dict, List, Optional, Any
@@ -17,6 +19,7 @@ def wsgi_encoding_dance(value):
 
 class WsgiRequest:
     _environ_cache: Optional[Dict[str, Any]] = None
+    _logger = logging.getLogger('azure.functions.WsgiMiddleware')
 
     def __init__(self,
                  func_req: HttpRequest,
@@ -113,7 +116,18 @@ def to_environ(self, errors_buffer: StringIO) -> Dict[str, Any]:
     def _get_port(self, parsed_url, lowercased_headers: Dict[str, str]) -> int:
         port: int = 80
         if lowercased_headers.get('x-forwarded-port'):
-            return int(lowercased_headers['x-forwarded-port'])
+            header_value = lowercased_headers['x-forwarded-port']
+            # Split on commas in case of multiple proxy hops; the first hop
+            # that starts with digits wins.
+            for part in header_value.split(','):
+                # Extract leading number (port must start with digits)
+                match = re.match(r"(\d+)", part.strip())
+                if match:
+                    return int(match.group(1))
+            # No hop yielded a port: report the whole header value (not just
+            # the last hop) before falling back to the default.
+            self._logger.warning("Invalid X-Forwarded-Port header value: %s. "
+                                 "Using default port 80", header_value)
         elif getattr(parsed_url, 'port', None):
             return int(parsed_url.port)
         elif parsed_url.scheme == 'https':
diff --git a/tests/test_http_wsgi.py b/tests/test_http_wsgi.py
index 224d35dd..d3fecece 100644
--- a/tests/test_http_wsgi.py
+++ b/tests/test_http_wsgi.py
@@ -100,6 +100,62 @@ def test_request_protocol_by_header(self):
         self.assertEqual(environ['SERVER_PORT'], str(8081))
         self.assertEqual(environ['wsgi.url_scheme'], 'https')
 
+    def test_request_protocol_by_header_hostlike(self):
+        func_request = self._generate_func_request(headers={
+            "x-forwarded-port": "443.example.com"
+        })
+        error_buffer = StringIO()
+        environ = WsgiRequest(func_request).to_environ(error_buffer)
+        self.assertEqual(environ['SERVER_PORT'], str(443))
+        self.assertEqual(environ['wsgi.url_scheme'], 'https')
+
+    def test_request_protocol_by_header_unusual_tokens(self):
+        func_request = self._generate_func_request(headers={
+            "x-forwarded-port": "443;proto=https"
+        })
+        error_buffer = StringIO()
+        environ = WsgiRequest(func_request).to_environ(error_buffer)
+        self.assertEqual(environ['SERVER_PORT'], str(443))
+        self.assertEqual(environ['wsgi.url_scheme'], 'https')
+
+    def test_request_protocol_by_header_with_multiple_ports(self):
+        func_request = self._generate_func_request(headers={
+            "x-forwarded-port": "443,8080,433"
+        })
+        error_buffer = StringIO()
+        environ = WsgiRequest(func_request).to_environ(error_buffer)
+        self.assertEqual(environ['SERVER_PORT'], str(443))
+        self.assertEqual(environ['wsgi.url_scheme'], 'https')
+
+    def test_request_protocol_by_header_with_spaces(self):
+        func_request = self._generate_func_request(headers={
+            "x-forwarded-port": " 8443 , 8080 "
+        })
+        error_buffer = StringIO()
+        environ = WsgiRequest(func_request).to_environ(error_buffer)
+        self.assertEqual(environ['SERVER_PORT'], str(8443))
+        self.assertEqual(environ['wsgi.url_scheme'], 'https')
+
+    def test_request_protocol_by_header_invalid(self):
+        func_request = self._generate_func_request(headers={
+            "x-forwarded-port": "abc"
+        })
+        error_buffer = StringIO()
+        environ = WsgiRequest(func_request).to_environ(error_buffer)
+        self.assertEqual(environ['SERVER_PORT'], str(80))
+        self.assertEqual(environ['wsgi.url_scheme'], 'https')
+
+    def test_request_protocol_by_header_invalid_logs_full_value(self):
+        func_request = self._generate_func_request(headers={
+            "x-forwarded-port": "abc, def"
+        })
+        error_buffer = StringIO()
+        with self.assertLogs('azure.functions.WsgiMiddleware',
+                             level='WARNING') as logs:
+            environ = WsgiRequest(func_request).to_environ(error_buffer)
+        self.assertEqual(environ['SERVER_PORT'], str(80))
+        self.assertIn('abc, def', logs.output[0])
+
     def test_request_protocol_by_scheme(self):
         func_request = self._generate_func_request(url="http://a.b.com")
         error_buffer = StringIO()