Skip to content

Commit

Permalink
THRIFT-3231 CPP: Limit recursion depth to 64
Browse files Browse the repository at this point in the history
Client: cpp
Patch: Ben Craig <bencraig@apache.org>
  • Loading branch information
Ben Craig committed Jul 9, 2015
1 parent 262cfb4 commit cfaadcc
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 112 deletions.
18 changes: 12 additions & 6 deletions compiler/cpp/src/generate/t_cpp_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1367,10 +1367,16 @@ void t_cpp_generator::generate_struct_reader(ofstream& out, t_struct* tstruct, b
vector<t_field*>::const_iterator f_iter;

// Declare stack tmp variables
out << endl << indent() << "uint32_t xfer = 0;" << endl << indent() << "std::string fname;"
<< endl << indent() << "::apache::thrift::protocol::TType ftype;" << endl << indent()
<< "int16_t fid;" << endl << endl << indent() << "xfer += iprot->readStructBegin(fname);"
<< endl << endl << indent() << "using ::apache::thrift::protocol::TProtocolException;" << endl
out << endl
<< indent() << "apache::thrift::protocol::TRecursionTracker tracker(*iprot);" << endl
<< indent() << "uint32_t xfer = 0;" << endl
<< indent() << "std::string fname;" << endl
<< indent() << "::apache::thrift::protocol::TType ftype;" << endl
<< indent() << "int16_t fid;" << endl
<< endl
<< indent() << "xfer += iprot->readStructBegin(fname);" << endl
<< endl
<< indent() << "using ::apache::thrift::protocol::TProtocolException;" << endl
<< endl;

// Required variables aren't in __isset, so we need tmp vars to check them.
Expand Down Expand Up @@ -1486,7 +1492,7 @@ void t_cpp_generator::generate_struct_writer(ofstream& out, t_struct* tstruct, b

out << indent() << "uint32_t xfer = 0;" << endl;

indent(out) << "oprot->incrementRecursionDepth();" << endl;
indent(out) << "apache::thrift::protocol::TRecursionTracker tracker(*oprot);" << endl;
indent(out) << "xfer += oprot->writeStructBegin(\"" << name << "\");" << endl;

for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
Expand Down Expand Up @@ -1522,7 +1528,7 @@ void t_cpp_generator::generate_struct_writer(ofstream& out, t_struct* tstruct, b
// Write the struct map
out << indent() << "xfer += oprot->writeFieldStop();" << endl << indent()
<< "xfer += oprot->writeStructEnd();" << endl << indent()
<< "oprot->decrementRecursionDepth();" << endl << indent() << "return xfer;" << endl;
<< "return xfer;" << endl;

indent_down();
indent(out) << "}" << endl << endl;
Expand Down
3 changes: 2 additions & 1 deletion lib/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ set( thriftcpp_SOURCES
src/thrift/concurrency/TimerManager.cpp
src/thrift/concurrency/Util.cpp
src/thrift/processor/PeekProcessor.cpp
src/thrift/protocol/TBase64Utils.cpp
src/thrift/protocol/TDebugProtocol.cpp
src/thrift/protocol/TDenseProtocol.cpp
src/thrift/protocol/TJSONProtocol.cpp
src/thrift/protocol/TBase64Utils.cpp
src/thrift/protocol/TMultiplexedProtocol.cpp
src/thrift/protocol/TProtocol.cpp
src/thrift/transport/TTransportException.cpp
src/thrift/transport/TFDTransport.cpp
src/thrift/transport/TSimpleFileTransport.cpp
Expand Down
1 change: 1 addition & 0 deletions lib/cpp/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ libthrift_la_SOURCES = src/thrift/TApplicationException.cpp \
src/thrift/protocol/TJSONProtocol.cpp \
src/thrift/protocol/TBase64Utils.cpp \
src/thrift/protocol/TMultiplexedProtocol.cpp \
src/thrift/protocol/TProtocol.cpp \
src/thrift/transport/TTransportException.cpp \
src/thrift/transport/TFDTransport.cpp \
src/thrift/transport/TFileTransport.cpp \
Expand Down
33 changes: 33 additions & 0 deletions lib/cpp/src/thrift/protocol/TProtocol.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <thrift/protocol/TProtocol.h>

namespace apache {
namespace thrift {
namespace protocol {

TProtocol::~TProtocol() {}
uint32_t TProtocol::skip_virt(TType type) {
return ::apache::thrift::protocol::skip(*this, type);
}

TProtocolFactory::~TProtocolFactory() {}

}}} // apache::thrift::protocol
223 changes: 118 additions & 105 deletions lib/cpp/src/thrift/protocol/TProtocol.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <string>
#include <map>
#include <vector>
#include <climits>

// Use this to get around strict aliasing rules.
// For example, uint64_t i = bitwise_cast<uint64_t>(returns_double());
Expand Down Expand Up @@ -199,105 +200,6 @@ enum TMessageType {
T_ONEWAY = 4
};


/**
* Helper template for implementing TProtocol::skip().
*
* Templatized to avoid having to make virtual function calls.
*/
template <class Protocol_>
uint32_t skip(Protocol_& prot, TType type) {
switch (type) {
case T_BOOL: {
bool boolv;
return prot.readBool(boolv);
}
case T_BYTE: {
int8_t bytev;
return prot.readByte(bytev);
}
case T_I16: {
int16_t i16;
return prot.readI16(i16);
}
case T_I32: {
int32_t i32;
return prot.readI32(i32);
}
case T_I64: {
int64_t i64;
return prot.readI64(i64);
}
case T_DOUBLE: {
double dub;
return prot.readDouble(dub);
}
case T_STRING: {
std::string str;
return prot.readBinary(str);
}
case T_STRUCT: {
uint32_t result = 0;
std::string name;
int16_t fid;
TType ftype;
result += prot.readStructBegin(name);
while (true) {
result += prot.readFieldBegin(name, ftype, fid);
if (ftype == T_STOP) {
break;
}
result += skip(prot, ftype);
result += prot.readFieldEnd();
}
result += prot.readStructEnd();
return result;
}
case T_MAP: {
uint32_t result = 0;
TType keyType;
TType valType;
uint32_t i, size;
result += prot.readMapBegin(keyType, valType, size);
for (i = 0; i < size; i++) {
result += skip(prot, keyType);
result += skip(prot, valType);
}
result += prot.readMapEnd();
return result;
}
case T_SET: {
uint32_t result = 0;
TType elemType;
uint32_t i, size;
result += prot.readSetBegin(elemType, size);
for (i = 0; i < size; i++) {
result += skip(prot, elemType);
}
result += prot.readSetEnd();
return result;
}
case T_LIST: {
uint32_t result = 0;
TType elemType;
uint32_t i, size;
result += prot.readListBegin(elemType, size);
for (i = 0; i < size; i++) {
result += skip(prot, elemType);
}
result += prot.readListEnd();
return result;
}
case T_STOP:
case T_VOID:
case T_U64:
case T_UTF8:
case T_UTF16:
break;
}
return 0;
}

static const uint32_t DEFAULT_RECURSION_LIMIT = 64;

/**
Expand All @@ -316,7 +218,7 @@ static const uint32_t DEFAULT_RECURSION_LIMIT = 64;
*/
class TProtocol {
public:
virtual ~TProtocol() {}
virtual ~TProtocol();

/**
* Writing functions.
Expand Down Expand Up @@ -641,7 +543,7 @@ class TProtocol {
T_VIRTUAL_CALL();
return skip_virt(type);
}
virtual uint32_t skip_virt(TType type) { return ::apache::thrift::protocol::skip(*this, type); }
virtual uint32_t skip_virt(TType type);

inline boost::shared_ptr<TTransport> getTransport() { return ptrans_; }

Expand All @@ -657,10 +559,13 @@ class TProtocol {
}

void decrementRecursionDepth() { --recursion_depth_; }
uint32_t getRecursionLimit() const {return recursion_limit_;}
void setRecurisionLimit(uint32_t depth) {recursion_limit_ = depth;}

protected:
TProtocol(boost::shared_ptr<TTransport> ptrans)
: ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT) {}
: ptrans_(ptrans), recursion_depth_(0), recursion_limit_(DEFAULT_RECURSION_LIMIT)
{}

boost::shared_ptr<TTransport> ptrans_;

Expand All @@ -677,7 +582,7 @@ class TProtocolFactory {
public:
TProtocolFactory() {}

virtual ~TProtocolFactory() {}
virtual ~TProtocolFactory();

virtual boost::shared_ptr<TProtocol> getProtocol(boost::shared_ptr<TTransport> trans) = 0;
};
Expand Down Expand Up @@ -712,8 +617,116 @@ struct TNetworkLittleEndian
static uint64_t fromWire64(uint64_t x) {return letohll(x);}
};

struct TRecursionTracker {
TProtocol &prot_;
TRecursionTracker(TProtocol &prot) : prot_(prot) {
prot_.incrementRecursionDepth();
}
~TRecursionTracker() {
prot_.decrementRecursionDepth();
}
};

/**
* Helper template for implementing TProtocol::skip().
*
* Templatized to avoid having to make virtual function calls.
*/
template <class Protocol_>
uint32_t skip(Protocol_& prot, TType type) {
TRecursionTracker tracker(prot);

switch (type) {
case T_BOOL: {
bool boolv;
return prot.readBool(boolv);
}
case T_BYTE: {
int8_t bytev;
return prot.readByte(bytev);
}
case T_I16: {
int16_t i16;
return prot.readI16(i16);
}
case T_I32: {
int32_t i32;
return prot.readI32(i32);
}
case T_I64: {
int64_t i64;
return prot.readI64(i64);
}
case T_DOUBLE: {
double dub;
return prot.readDouble(dub);
}
case T_STRING: {
std::string str;
return prot.readBinary(str);
}
case T_STRUCT: {
uint32_t result = 0;
std::string name;
int16_t fid;
TType ftype;
result += prot.readStructBegin(name);
while (true) {
result += prot.readFieldBegin(name, ftype, fid);
if (ftype == T_STOP) {
break;
}
result += skip(prot, ftype);
result += prot.readFieldEnd();
}
result += prot.readStructEnd();
return result;
}
case T_MAP: {
uint32_t result = 0;
TType keyType;
TType valType;
uint32_t i, size;
result += prot.readMapBegin(keyType, valType, size);
for (i = 0; i < size; i++) {
result += skip(prot, keyType);
result += skip(prot, valType);
}
result += prot.readMapEnd();
return result;
}
case T_SET: {
uint32_t result = 0;
TType elemType;
uint32_t i, size;
result += prot.readSetBegin(elemType, size);
for (i = 0; i < size; i++) {
result += skip(prot, elemType);
}
result += prot.readSetEnd();
return result;
}
case T_LIST: {
uint32_t result = 0;
TType elemType;
uint32_t i, size;
result += prot.readListBegin(elemType, size);
for (i = 0; i < size; i++) {
result += skip(prot, elemType);
}
result += prot.readListEnd();
return result;
}
case T_STOP:
case T_VOID:
case T_U64:
case T_UTF8:
case T_UTF16:
break;
}
return 0;
}
}
} // apache::thrift::protocol

}}} // apache::thrift::protocol

#endif // #define _THRIFT_PROTOCOL_TPROTOCOL_H_ 1

0 comments on commit cfaadcc

Please sign in to comment.