Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions compiler/cpp/src/thrift/generate/t_cpp_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3129,6 +3129,8 @@ void t_cpp_generator::generate_service_client(t_service* tservice, string style)
} else {
out << "recv_" << funname << "(" << seqIdUse << ");" << '\n';
}
} else {
out << indent() << _this << "oprot_->getTransport()->onewayComplete();" << '\n';
}
} else {
if (!(*f_iter)->is_oneway()) {
Expand Down Expand Up @@ -3889,7 +3891,7 @@ void t_cpp_generator::generate_process_function(t_service* tservice,
if (gen_templates_) {
out << indent() << "template <class Protocol_>" << '\n';
}
const bool unnamed_oprot_seqid = tfunction->is_oneway() && !(gen_templates_ && !specialized);
const bool unnamed_oprot_seqid = false;
out << "void " << tservice->get_name() << "Processor" << class_suffix << "::"
<< "process_" << tfunction->get_name() << "("
<< "int32_t" << (unnamed_oprot_seqid ? ", " : " seqid, ") << prot_type << "* iprot, "
Expand All @@ -3901,7 +3903,7 @@ void t_cpp_generator::generate_process_function(t_service* tservice,
string resultname = tservice->get_name() + "_" + tfunction->get_name() + "_result";

if (tfunction->is_oneway() && !unnamed_oprot_seqid) {
out << indent() << "(void) seqid;" << '\n' << indent() << "(void) oprot;" << '\n';
out << indent() << "(void) seqid;" << '\n';
}

out << indent() << "void* ctx = nullptr;" << '\n' << indent()
Expand Down Expand Up @@ -3994,6 +3996,8 @@ void t_cpp_generator::generate_process_function(t_service* tservice,
<< "x.write(oprot);" << '\n' << indent() << "oprot->writeMessageEnd();" << '\n'
<< indent() << "oprot->getTransport()->writeEnd();" << '\n' << indent()
<< "oprot->getTransport()->flush();" << '\n';
} else {
out << '\n' << indent() << "oprot->getTransport()->onewayComplete();" << '\n';
}
out << indent() << "return;" << '\n';
indent_down();
Expand All @@ -4003,7 +4007,8 @@ void t_cpp_generator::generate_process_function(t_service* tservice,
if (tfunction->is_oneway()) {
out << indent() << "if (this->eventHandler_.get() != nullptr) {" << '\n' << indent()
<< " this->eventHandler_->asyncComplete(ctx, " << service_func_name << ");" << '\n'
<< indent() << "}" << '\n' << '\n' << indent() << "return;" << '\n';
<< indent() << "}" << '\n' << '\n' << indent()
<< "oprot->getTransport()->onewayComplete();" << '\n' << indent() << "return;" << '\n';
indent_down();
out << "}" << '\n' << '\n';
return;
Expand Down
4 changes: 4 additions & 0 deletions lib/cpp/src/thrift/transport/TBufferTransports.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ class TBufferedTransport : public TVirtualTransport<TBufferedTransport, TBufferB

void flush() override;

void onewayComplete() override { transport_->onewayComplete(); }

/**
* Returns the origin of the underlying transport
*/
Expand Down Expand Up @@ -375,6 +377,8 @@ class TFramedTransport : public TVirtualTransport<TFramedTransport, TBufferBase>

void flush() override;

void onewayComplete() override { transport_->onewayComplete(); }

uint32_t readEnd() override;

uint32_t writeEnd() override;
Expand Down
2 changes: 2 additions & 0 deletions lib/cpp/src/thrift/transport/THeaderTransport.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ class THeaderTransport : public TVirtualTransport<THeaderTransport, TFramedTrans
uint32_t readSlow(uint8_t* buf, uint32_t len) override;
void flush() override;

void onewayComplete() override { outTransport_->onewayComplete(); }

void resizeTransformBuffer(uint32_t additionalSize = 0);

uint16_t getProtocolId() const;
Expand Down
73 changes: 71 additions & 2 deletions lib/cpp/src/thrift/transport/THttpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* under the License.
*/

#include <algorithm>
#include <limits>
#include <cstdlib>
#include <sstream>
Expand All @@ -38,18 +39,25 @@ THttpClient::THttpClient(std::shared_ptr<TTransport> transport,
std::shared_ptr<TConfiguration> config)
: THttpTransport(transport, config),
host_(host),
path_(path) {
path_(path),
onewayResponsePending_(false) {
}

THttpClient::THttpClient(string host, int port, string path,
std::shared_ptr<TConfiguration> config)
: THttpTransport(std::shared_ptr<TTransport>(new TSocket(host, port)), config),
host_(host),
path_(path) {
path_(path),
onewayResponsePending_(false) {
}

THttpClient::~THttpClient() = default;

void THttpClient::close() {
onewayResponsePending_ = false;
THttpTransport::close();
}

void THttpClient::parseHeader(char* header) {
char* colon = strchr(header, ':');
if (colon == nullptr) {
Expand Down Expand Up @@ -98,6 +106,19 @@ bool THttpClient::parseStatusLine(char* status) {

void THttpClient::flush() {
resetConsumedMessageSize();

if (onewayResponsePending_) {
if (transport_->isOpen()) {
drainPendingOnewayResponse();
} else {
onewayResponsePending_ = false;
}
}

if (!transport_->isOpen()) {
transport_->open();
}

// Fetch the contents of the write buffer
uint8_t* buf;
uint32_t len;
Expand All @@ -123,6 +144,54 @@ void THttpClient::flush() {
readHeaders_ = true;
}

void THttpClient::onewayComplete() {
onewayResponsePending_ = true;
}

void THttpClient::drainPendingOnewayResponse() {
readBuffer_.resetBuffer();

if (readHeaders_) {
readHeaders();
}

if (chunked_) {
while (!chunkedDone_) {
char* line = readLine();
uint32_t chunkSize = parseChunkSize(line);
if (chunkSize == 0) {
readChunkedFooters();
break;
}

discardResponseBody(chunkSize);
readLine();
}
} else {
discardResponseBody(contentLength_);
}

readHeaders_ = true;
onewayResponsePending_ = false;
}

void THttpClient::discardResponseBody(uint32_t size) {
uint32_t remaining = size;
while (remaining > 0) {
uint32_t avail = httpBufLen_ - httpPos_;
if (avail == 0) {
httpPos_ = 0;
httpBufLen_ = 0;
refill();
avail = httpBufLen_;
}

uint32_t give = (std::min)(remaining, avail);
httpPos_ += give;
remaining -= give;
}
}

void THttpClient::setPath(std::string path) {
path_ = path;
}
Expand Down
8 changes: 8 additions & 0 deletions lib/cpp/src/thrift/transport/THttpClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,24 @@ class THttpClient : public THttpTransport {

~THttpClient() override;

void close() override;

void flush() override;

void onewayComplete() override;

void setPath(std::string path);

protected:
std::string host_;
std::string path_;
bool onewayResponsePending_;

void parseHeader(char* header) override;
bool parseStatusLine(char* status) override;

void drainPendingOnewayResponse();
void discardResponseBody(uint32_t size);
};
}
}
Expand Down
4 changes: 4 additions & 0 deletions lib/cpp/src/thrift/transport/THttpServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ void THttpServer::flush() {
readHeaders_ = true;
}

void THttpServer::onewayComplete() {
flush();
}

std::string THttpServer::getHeader(uint32_t len) {
std::ostringstream h;
h << "HTTP/1.1 200 OK" << CRLF << "Date: " << getTimeRFC1123() << CRLF << "Server: Thrift/"
Expand Down
2 changes: 2 additions & 0 deletions lib/cpp/src/thrift/transport/THttpServer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class THttpServer : public THttpTransport {

void flush() override;

void onewayComplete() override;

protected:
virtual std::string getHeader(uint32_t len);
void readHeaders();
Expand Down
10 changes: 10 additions & 0 deletions lib/cpp/src/thrift/transport/TTransport.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ class TTransport {
// default behaviour is to do nothing
}

/**
* Called by generated code after a one-way method has been sent or handled.
* Transports that receive an out-of-band response can override this to mark
* it for discard without making one-way calls wait for a reply. Transports
* that must emit an out-of-band response can override this to send it.
*/
virtual void onewayComplete() {
// default behaviour is to do nothing
}

/**
* Attempts to return a pointer to \c len bytes, possibly copied into \c buf.
* Does not consume the bytes read (i.e.: a later read will return the same
Expand Down
2 changes: 2 additions & 0 deletions lib/cpp/src/thrift/transport/TZlibTransport.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ class TZlibTransport : public TVirtualTransport<TZlibTransport> {

void flush() override;

void onewayComplete() override { transport_->onewayComplete(); }

/**
* Finalize the zlib stream.
*
Expand Down
50 changes: 50 additions & 0 deletions lib/cpp/test/OneWayHTTPTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,9 @@ BOOST_AUTO_TEST_CASE( JSON_BufferedHTTP )
blockable_transport->unblock() ;
client.send_roundTripRPC() ;
blockable_transport->flush() ;
uint8_t discard;
BOOST_CHECK_EQUAL(transport->read(&discard, 1), 0U);
BOOST_CHECK_EQUAL(transport->read(&discard, 1), 0U);
try {
client.recv_roundTripRPC() ;
} catch (const TTransportException &e) {
Expand All @@ -236,4 +239,51 @@ BOOST_AUTO_TEST_CASE( JSON_BufferedHTTP )
#endif
}

BOOST_AUTO_TEST_CASE( JSON_HTTP_OneWayWrapperDoesNotPoisonNextCall )
{
std::shared_ptr<TServerSocket> ss = std::make_shared<TServerSocket>(0);
TThreadedServer server(
std::make_shared<onewaytest::OneWayServiceProcessorFactory>(
std::make_shared<OneWayServiceCloneFactory>()),
ss,
std::make_shared<THttpServerTransportFactory>(),
std::make_shared<TJSONProtocolFactory>());

std::shared_ptr<TServerReadyEventHandler> pEventHandler(new TServerReadyEventHandler);
server.setServerEventHandler(pEventHandler);

RPC0ThreadClass t(server);
boost::thread thread(&RPC0ThreadClass::Run, &t);

{
Synchronized sync(*(pEventHandler.get()));
while (!pEventHandler->isListening()) {
pEventHandler->wait();
}
}

{
std::shared_ptr<TSocket> socket(new TSocket("localhost", ss->getPort()));
socket->setRecvTimeout(10000);
std::shared_ptr<TTransport> transport(new THttpClient(socket, "localhost", "/service"));
std::shared_ptr<TProtocol> protocol(new TJSONProtocol(transport));
onewaytest::OneWayServiceClient client(protocol);

transport->open();
client.roundTripRPC();
BOOST_CHECK_EQUAL(pEventHandler->acceptedCount(), 1U);
client.oneWayRPC();
try {
client.roundTripRPC();
} catch (const std::exception& e) {
BOOST_ERROR("roundTripRPC after oneWayRPC failed: " + std::string(e.what()));
}
BOOST_CHECK_EQUAL(pEventHandler->acceptedCount(), 1U);
transport->close();
}

server.stop();
thread.join();
}

BOOST_AUTO_TEST_SUITE_END()
Loading
Loading