From c3e666e93e910b2c4a87069942afe43c6f9f00f0 Mon Sep 17 00:00:00 2001
From: "assisted-by-ai (Bot Account)"
Date: Tue, 2 Dec 2025 04:13:25 -0500
Subject: [PATCH] Enforce byte-accurate stdin limits for strip-markup

Cap the untrusted input accepted by strip-markup at MAX_INPUT_BYTES
(1 MiB), measured in bytes rather than characters, for both the
command-line argument and stdin. When stdin exposes a binary buffer it is
read through that buffer, one byte past the limit, so oversized input is
rejected without being read in full.

Also drop every character that str.isprintable() rejects, except newline
and tab, from the stripped output, so terminal escape sequences and
carriage-return overprinting do not survive.

Add tests for control-character removal and for oversized argument,
stdin, and multibyte stdin input.
---
 .../strip_markup/strip_markup.py              |  40 +++-
 .../strip_markup/strip_markup_lib.py          |  18 +-
 .../strip_markup/tests/strip_markup.py        | 170 ++++++++++++++++--
 3 files changed, 213 insertions(+), 15 deletions(-)

diff --git a/usr/lib/python3/dist-packages/strip_markup/strip_markup.py b/usr/lib/python3/dist-packages/strip_markup/strip_markup.py
index ff8c4b5..7bfe2b4 100644
--- a/usr/lib/python3/dist-packages/strip_markup/strip_markup.py
+++ b/usr/lib/python3/dist-packages/strip_markup/strip_markup.py
@@ -24,6 +24,22 @@ def print_usage() -> None:
     )
 
 
+## Upper bound, in bytes, on input accepted from argv or stdin (1 MiB).
+MAX_INPUT_BYTES = 1024 * 1024
+
+
+def _reject_oversized_input() -> int:
+    """
+    Reports on stderr that the input is too large; returns exit code 1.
+    """
+
+    print(
+        f"strip-markup: input exceeds maximum size of {MAX_INPUT_BYTES} bytes.",
+        file=sys.stderr,
+    )
+    return 1
+
+
 def main() -> int:
     """
     Main function.
@@ -52,13 +68,31 @@ def main() -> int:
             print_usage()
             return 1
         untrusted_string = arg_list[0]
+        ## Undecodable argv bytes arrive as lone surrogates; "surrogateescape"
+        ## maps them back to the original bytes instead of raising.
+        if len(untrusted_string.encode(errors="surrogateescape")) > MAX_INPUT_BYTES:
+            return _reject_oversized_input()
 
     ## Read untrusted_string from stdin if needed
     if untrusted_string is None:
         if sys.stdin is not None:
-            if "pytest" not in sys.modules:
-                sys.stdin.reconfigure(errors="ignore")  # type: ignore
-            untrusted_string = sys.stdin.read()
+            if hasattr(sys.stdin, "buffer"):
+                ## Read one byte past the limit so oversized input is detected
+                ## without buffering more than MAX_INPUT_BYTES + 1 bytes.
+                raw_stdin: bytes = sys.stdin.buffer.read(MAX_INPUT_BYTES + 1)
+                if len(raw_stdin) > MAX_INPUT_BYTES:
+                    return _reject_oversized_input()
+                encoding: str = getattr(sys.stdin, "encoding", None) or "utf-8"
+                untrusted_string = raw_stdin.decode(encoding, errors="ignore")
+            else:
+                ## Text-only stream (no underlying byte buffer): read() counts
+                ## characters, so re-measure the result in bytes afterwards.
+                if "pytest" not in sys.modules and hasattr(sys.stdin, "reconfigure"):
+                    sys.stdin.reconfigure(errors="ignore")  # type: ignore
+                untrusted_string = sys.stdin.read(MAX_INPUT_BYTES + 1)
+                encoded_stdin: bytes = untrusted_string.encode(errors="surrogateescape")
+                if len(encoded_stdin) > MAX_INPUT_BYTES:
+                    return _reject_oversized_input()
         else:
             ## No way to get an untrusted string, print nothing and
             ## exit successfully
diff --git a/usr/lib/python3/dist-packages/strip_markup/strip_markup_lib.py b/usr/lib/python3/dist-packages/strip_markup/strip_markup_lib.py
index 29c3306..ea24a8e 100644
--- a/usr/lib/python3/dist-packages/strip_markup/strip_markup_lib.py
+++ b/usr/lib/python3/dist-packages/strip_markup/strip_markup_lib.py
@@ -42,6 +42,20 @@ def get_data(self) -> str:
         return self.text.getvalue()
 
 
+def _strip_control_characters(untrusted_string: str) -> str:
+    """
+    Remove characters that could be used for terminal escapes.
+
+    Keeps only characters for which str.isprintable() is true, plus newline
+    and tab. Besides control characters such as ESC and carriage return, this
+    also drops format characters and non-ASCII separators.
+    """
+
+    return "".join(
+        char for char in untrusted_string if char.isprintable() or char in "\n\t"
+    )
+
+
 def strip_markup(untrusted_string: str) -> str:
     """
     Stripping function.
@@ -49,10 +63,10 @@ def strip_markup(untrusted_string: str) -> str:
 
     markup_stripper: StripMarkupEngine = StripMarkupEngine()
     markup_stripper.feed(untrusted_string)
-    strip_one_string: str = markup_stripper.get_data()
+    strip_one_string: str = _strip_control_characters(markup_stripper.get_data())
     markup_stripper = StripMarkupEngine()
     markup_stripper.feed(strip_one_string)
-    strip_two_string: str = markup_stripper.get_data()
+    strip_two_string: str = _strip_control_characters(markup_stripper.get_data())
     if strip_one_string == strip_two_string:
         return strip_one_string
 
diff --git a/usr/lib/python3/dist-packages/strip_markup/tests/strip_markup.py b/usr/lib/python3/dist-packages/strip_markup/tests/strip_markup.py
index a97b1f6..426cce4 100644
--- a/usr/lib/python3/dist-packages/strip_markup/tests/strip_markup.py
+++ b/usr/lib/python3/dist-packages/strip_markup/tests/strip_markup.py
@@ -5,12 +5,12 @@
 
 # pylint: disable=missing-module-docstring,fixme
 
-import unittest
+import io
 import sys
-from io import StringIO
+import unittest
 from typing import Callable
 from unittest import mock
-from strip_markup.strip_markup import main as strip_markup_main
+from strip_markup.strip_markup import MAX_INPUT_BYTES, main as strip_markup_main
 
 
 class TestStripMarkupBase(unittest.TestCase):
@@ -36,8 +36,8 @@ def _test_args(
         """
 
         args_arr: list[str] = [argv0, *args]
-        stdout_buf: StringIO = StringIO()
-        stderr_buf: StringIO = StringIO()
+        stdout_buf: io.StringIO = io.StringIO()
+        stderr_buf: io.StringIO = io.StringIO()
         with (
             mock.patch.object(sys, "argv", args_arr),
             mock.patch.object(sys, "stdout", stdout_buf),
@@ -50,6 +50,35 @@ def _test_args(
         stdout_buf.close()
         stderr_buf.close()
 
+    # pylint: disable=too-many-arguments,too-many-positional-arguments
+    def _test_args_failure(
+        self,
+        main_func: Callable[[], int],
+        argv0: str,
+        stdout_string: str,
+        stderr_string: str,
+        args: list[str],
+        exit_code: int = 1,
+    ) -> None:
+        """
+        Executes the provided main function expecting failure output.
+        """
+
+        args_arr: list[str] = [argv0, *args]
+        stdout_buf: io.StringIO = io.StringIO()
+        stderr_buf: io.StringIO = io.StringIO()
+        with (
+            mock.patch.object(sys, "argv", args_arr),
+            mock.patch.object(sys, "stdout", stdout_buf),
+            mock.patch.object(sys, "stderr", stderr_buf),
+        ):
+            returned_exit_code: int = main_func()
+        self.assertEqual(stdout_buf.getvalue(), stdout_string)
+        self.assertEqual(stderr_buf.getvalue(), stderr_string)
+        self.assertEqual(returned_exit_code, exit_code)
+        stdout_buf.close()
+        stderr_buf.close()
+
     # pylint: disable=too-many-arguments,too-many-positional-arguments
     def _test_stdin(
         self,
@@ -65,11 +94,12 @@ def _test_stdin(
         ensures its output matches an expected value.
         """
 
-        stdout_buf: StringIO = StringIO()
-        stderr_buf: StringIO = StringIO()
-        stdin_buf: StringIO = StringIO()
-        stdin_buf.write(stdin_string)
-        stdin_buf.seek(0)
+        stdout_buf: io.StringIO = io.StringIO()
+        stderr_buf: io.StringIO = io.StringIO()
+        stdin_bytes: io.BytesIO = io.BytesIO()
+        stdin_bytes.write(stdin_string.encode("utf-8"))
+        stdin_bytes.seek(0)
+        stdin_buf: io.TextIOWrapper = io.TextIOWrapper(stdin_bytes, encoding="utf-8")
         args_arr: list[str] = [argv0, *args]
         with (
             mock.patch.object(sys, "argv", args_arr),
@@ -84,6 +114,44 @@ def _test_stdin(
         stdout_buf.close()
         stderr_buf.close()
         stdin_buf.close()
+        stdin_bytes.close()
+
+    # pylint: disable=too-many-arguments,too-many-positional-arguments
+    def _test_stdin_failure(
+        self,
+        main_func: Callable[[], int],
+        argv0: str,
+        stdout_string: str,
+        stderr_string: str,
+        args: list[str],
+        stdin_string: str,
+        exit_code: int = 1,
+    ) -> None:
+        """
+        Executes the provided main function expecting failure output when using stdin.
+        """
+
+        stdout_buf: io.StringIO = io.StringIO()
+        stderr_buf: io.StringIO = io.StringIO()
+        stdin_bytes: io.BytesIO = io.BytesIO()
+        stdin_bytes.write(stdin_string.encode("utf-8"))
+        stdin_bytes.seek(0)
+        stdin_buf: io.TextIOWrapper = io.TextIOWrapper(stdin_bytes, encoding="utf-8")
+        args_arr: list[str] = [argv0, *args]
+        with (
+            mock.patch.object(sys, "argv", args_arr),
+            mock.patch.object(sys, "stdin", stdin_buf),
+            mock.patch.object(sys, "stdout", stdout_buf),
+            mock.patch.object(sys, "stderr", stderr_buf),
+        ):
+            returned_exit_code: int = main_func()
+        self.assertEqual(stdout_buf.getvalue(), stdout_string)
+        self.assertEqual(stderr_buf.getvalue(), stderr_string)
+        self.assertEqual(returned_exit_code, exit_code)
+        stdout_buf.close()
+        stderr_buf.close()
+        stdin_buf.close()
+        stdin_bytes.close()
 
     def _test_safe_strings(
         self,
@@ -369,3 +437,85 @@ def test_malicious_markup_strings(self) -> None:
         """
 
         self._test_malicious_markup_strings(strip_markup_main, self.argv0)
+
+    def test_control_characters_are_removed(self) -> None:
+        """
+        Ensure decoded control characters are stripped from output.
+        """
+
+        self._test_args(
+            main_func=strip_markup_main,
+            argv0=self.argv0,
+            stdout_string="control sequence",
+            stderr_string="",
+            args=["control\x1b sequence"],
+        )
+
+    def test_carriage_return_is_removed(self) -> None:
+        """
+        Ensure carriage returns cannot be used for overprinting attacks.
+        """
+
+        self._test_args(
+            main_func=strip_markup_main,
+            argv0=self.argv0,
+            stdout_string="line one overwritten",
+            stderr_string="",
+            args=["line one\r overwritten"],
+        )
+
+    def test_rejects_oversized_argument(self) -> None:
+        """
+        Ensure overly large arguments are rejected to avoid excessive memory use.
+        """
+
+        oversized_input: str = "x" * (MAX_INPUT_BYTES + 1)
+        error_msg: str = (
+            f"strip-markup: input exceeds maximum size of {MAX_INPUT_BYTES} bytes.\n"
+        )
+        self._test_args_failure(
+            main_func=strip_markup_main,
+            argv0=self.argv0,
+            stdout_string="",
+            stderr_string=error_msg,
+            args=[oversized_input],
+        )
+
+    def test_rejects_oversized_stdin(self) -> None:
+        """
+        Ensure overly large stdin payloads are rejected.
+        """
+
+        oversized_input: str = "y" * (MAX_INPUT_BYTES + 1)
+        error_msg: str = (
+            f"strip-markup: input exceeds maximum size of {MAX_INPUT_BYTES} bytes.\n"
+        )
+        self._test_stdin_failure(
+            main_func=strip_markup_main,
+            argv0=self.argv0,
+            stdout_string="",
+            stderr_string=error_msg,
+            args=[],
+            stdin_string=oversized_input,
+        )
+
+    def test_rejects_multibyte_oversized_stdin(self) -> None:
+        """
+        Ensure byte-size enforcement applies when stdin includes multibyte characters.
+        """
+
+        multibyte_char: str = "€"
+        multibyte_oversized_input: str = multibyte_char * (
+            (MAX_INPUT_BYTES // len(multibyte_char.encode("utf-8"))) + 1
+        )
+        error_msg: str = (
+            f"strip-markup: input exceeds maximum size of {MAX_INPUT_BYTES} bytes.\n"
+        )
+        self._test_stdin_failure(
+            main_func=strip_markup_main,
+            argv0=self.argv0,
+            stdout_string="",
+            stderr_string=error_msg,
+            args=[],
+            stdin_string=multibyte_oversized_input,
+        )