From 95ba7e767e89ea5d090a8a581fec23bf7cca4618 Mon Sep 17 00:00:00 2001 From: Jens Geyer Date: Wed, 22 Apr 2026 01:05:00 +0200 Subject: [PATCH] Add default recursion depth limit to TProtocol.skip() Client: py C++ and Go already enforce a limit of 64 recursive skip() calls via TInputRecursionTracker and maxDepth respectively. This change adds an equivalent max_depth=64 parameter to the Python skip() implementation in TProtocolBase, raising TProtocolException(DEPTH_LIMIT) when the limit is exceeded. Co-Authored-By: Claude Sonnet 4.6 --- lib/py/src/protocol/TProtocol.py | 15 +++++++----- lib/py/test/thrift_TBinaryProtocol.py | 33 +++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 6 deletions(-) diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py index 975cbf5915e..a32e7778721 100644 --- a/lib/py/src/protocol/TProtocol.py +++ b/lib/py/src/protocol/TProtocol.py @@ -186,7 +186,10 @@ def readBinary(self): def readUuid(self): pass - def skip(self, ttype): + def skip(self, ttype, max_depth=64): + if max_depth <= 0: + raise TProtocolException(TProtocolException.DEPTH_LIMIT, + "Maximum skip depth exceeded") if ttype == TType.BOOL: self.readBool() elif ttype == TType.BYTE: @@ -207,24 +210,24 @@ def skip(self, ttype): (name, ttype, id) = self.readFieldBegin() if ttype == TType.STOP: break - self.skip(ttype) + self.skip(ttype, max_depth - 1) self.readFieldEnd() self.readStructEnd() elif ttype == TType.MAP: (ktype, vtype, size) = self.readMapBegin() for i in range(size): - self.skip(ktype) - self.skip(vtype) + self.skip(ktype, max_depth - 1) + self.skip(vtype, max_depth - 1) self.readMapEnd() elif ttype == TType.SET: (etype, size) = self.readSetBegin() for i in range(size): - self.skip(etype) + self.skip(etype, max_depth - 1) self.readSetEnd() elif ttype == TType.LIST: (etype, size) = self.readListBegin() for i in range(size): - self.skip(etype) + self.skip(etype, max_depth - 1) self.readListEnd() elif ttype == TType.UUID: self.readUuid() diff --git a/lib/py/test/thrift_TBinaryProtocol.py b/lib/py/test/thrift_TBinaryProtocol.py index e84bfe1e846..d4269eb6175 100644 --- a/lib/py/test/thrift_TBinaryProtocol.py +++ b/lib/py/test/thrift_TBinaryProtocol.py @@ -17,11 +17,13 @@ # under the License. # +import struct import unittest import uuid import _import_local_thrift # noqa from thrift.protocol.TBinaryProtocol import TBinaryProtocol +from thrift.protocol.TProtocol import TProtocolException from thrift.transport import TTransport @@ -297,5 +299,36 @@ def test_TBinaryProtocol_no_strict_write_read(self): raise e +def _craft_nested_structs(depth): + buf = bytearray() + for _ in range(depth): + buf += bytes([0x0c]) # TType.STRUCT = 12 + buf += struct.pack('>h', 1) # field ID 1 + for _ in range(depth + 1): + buf += bytes([0x00]) # STOP per level + innermost + return bytes(buf) + + +class TestSkipDepthLimit(unittest.TestCase): + + def _make_proto(self, payload): + trans = TTransport.TMemoryBuffer(payload) + return TBinaryProtocol(trans) + + def test_skip_rejects_deeply_nested_struct(self): + from thrift.Thrift import TType + payload = _craft_nested_structs(64) + proto = self._make_proto(payload) + with self.assertRaises(TProtocolException) as ctx: + proto.skip(TType.STRUCT) + self.assertEqual(ctx.exception.type, TProtocolException.DEPTH_LIMIT) + + def test_skip_accepts_struct_within_depth_limit(self): + from thrift.Thrift import TType + payload = _craft_nested_structs(63) + proto = self._make_proto(payload) + proto.skip(TType.STRUCT) # must not raise + + if __name__ == '__main__': unittest.main()