diff --git a/compiler/cpp/src/thrift/generate/t_cpp_generator.cc b/compiler/cpp/src/thrift/generate/t_cpp_generator.cc index c7702a570c7..188a06b5165 100644 --- a/compiler/cpp/src/thrift/generate/t_cpp_generator.cc +++ b/compiler/cpp/src/thrift/generate/t_cpp_generator.cc @@ -3129,6 +3129,8 @@ void t_cpp_generator::generate_service_client(t_service* tservice, string style) } else { out << "recv_" << funname << "(" << seqIdUse << ");" << '\n'; } + } else { + out << indent() << _this << "oprot_->getTransport()->onewayComplete();" << '\n'; } } else { if (!(*f_iter)->is_oneway()) { @@ -3889,7 +3891,7 @@ void t_cpp_generator::generate_process_function(t_service* tservice, if (gen_templates_) { out << indent() << "template " << '\n'; } - const bool unnamed_oprot_seqid = tfunction->is_oneway() && !(gen_templates_ && !specialized); + const bool unnamed_oprot_seqid = false; out << "void " << tservice->get_name() << "Processor" << class_suffix << "::" << "process_" << tfunction->get_name() << "(" << "int32_t" << (unnamed_oprot_seqid ? ", " : " seqid, ") << prot_type << "* iprot, " @@ -3901,7 +3903,7 @@ void t_cpp_generator::generate_process_function(t_service* tservice, string resultname = tservice->get_name() + "_" + tfunction->get_name() + "_result"; if (tfunction->is_oneway() && !unnamed_oprot_seqid) { - out << indent() << "(void) seqid;" << '\n' << indent() << "(void) oprot;" << '\n'; + out << indent() << "(void) seqid;" << '\n'; } out << indent() << "void* ctx = nullptr;" << '\n' << indent() @@ -3994,6 +3996,8 @@ void t_cpp_generator::generate_process_function(t_service* tservice, << "x.write(oprot);" << '\n' << indent() << "oprot->writeMessageEnd();" << '\n' << indent() << "oprot->getTransport()->writeEnd();" << '\n' << indent() << "oprot->getTransport()->flush();" << '\n'; + } else { + out << '\n' << indent() << "oprot->getTransport()->onewayComplete();" << '\n'; } out << indent() << "return;" << '\n'; indent_down(); @@ -4003,7 +4007,8 @@ void t_cpp_generator::generate_process_function(t_service* tservice, if (tfunction->is_oneway()) { out << indent() << "if (this->eventHandler_.get() != nullptr) {" << '\n' << indent() << " this->eventHandler_->asyncComplete(ctx, " << service_func_name << ");" << '\n' - << indent() << "}" << '\n' << '\n' << indent() << "return;" << '\n'; + << indent() << "}" << '\n' << '\n' << indent() + << "oprot->getTransport()->onewayComplete();" << '\n' << indent() << "return;" << '\n'; indent_down(); out << "}" << '\n' << '\n'; return; diff --git a/lib/cpp/src/thrift/transport/TBufferTransports.h b/lib/cpp/src/thrift/transport/TBufferTransports.h index 40e2e6b6162..09633c19ba4 100644 --- a/lib/cpp/src/thrift/transport/TBufferTransports.h +++ b/lib/cpp/src/thrift/transport/TBufferTransports.h @@ -243,6 +243,8 @@ class TBufferedTransport : public TVirtualTransportonewayComplete(); } + /** * Returns the origin of the underlying transport */ @@ -375,6 +377,8 @@ class TFramedTransport : public TVirtualTransport void flush() override; + void onewayComplete() override { transport_->onewayComplete(); } + uint32_t readEnd() override; uint32_t writeEnd() override; diff --git a/lib/cpp/src/thrift/transport/THeaderTransport.h b/lib/cpp/src/thrift/transport/THeaderTransport.h index a6ffd09a157..24f7573916a 100644 --- a/lib/cpp/src/thrift/transport/THeaderTransport.h +++ b/lib/cpp/src/thrift/transport/THeaderTransport.h @@ -105,6 +105,8 @@ class THeaderTransport : public TVirtualTransportonewayComplete(); } + void resizeTransformBuffer(uint32_t additionalSize = 0); uint16_t getProtocolId() const; diff --git a/lib/cpp/src/thrift/transport/THttpClient.cpp b/lib/cpp/src/thrift/transport/THttpClient.cpp index ea2eb99af6e..3cc83345018 100644 --- a/lib/cpp/src/thrift/transport/THttpClient.cpp +++ b/lib/cpp/src/thrift/transport/THttpClient.cpp @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -38,18 +39,25 @@ THttpClient::THttpClient(std::shared_ptr transport, std::shared_ptr config) : THttpTransport(transport, config), host_(host), - path_(path) { + path_(path), + onewayResponsePending_(false) { } THttpClient::THttpClient(string host, int port, string path, std::shared_ptr config) : THttpTransport(std::shared_ptr(new TSocket(host, port)), config), host_(host), - path_(path) { + path_(path), + onewayResponsePending_(false) { } THttpClient::~THttpClient() = default; +void THttpClient::close() { + onewayResponsePending_ = false; + THttpTransport::close(); +} + void THttpClient::parseHeader(char* header) { char* colon = strchr(header, ':'); if (colon == nullptr) { @@ -98,6 +106,19 @@ bool THttpClient::parseStatusLine(char* status) { void THttpClient::flush() { resetConsumedMessageSize(); + + if (onewayResponsePending_) { + if (transport_->isOpen()) { + drainPendingOnewayResponse(); + } else { + onewayResponsePending_ = false; + } + } + + if (!transport_->isOpen()) { + transport_->open(); + } + // Fetch the contents of the write buffer uint8_t* buf; uint32_t len; @@ -123,6 +144,54 @@ void THttpClient::flush() { readHeaders_ = true; } +void THttpClient::onewayComplete() { + onewayResponsePending_ = true; +} + +void THttpClient::drainPendingOnewayResponse() { + readBuffer_.resetBuffer(); + + if (readHeaders_) { + readHeaders(); + } + + if (chunked_) { + while (!chunkedDone_) { + char* line = readLine(); + uint32_t chunkSize = parseChunkSize(line); + if (chunkSize == 0) { + readChunkedFooters(); + break; + } + + discardResponseBody(chunkSize); + readLine(); + } + } else { + discardResponseBody(contentLength_); + } + + readHeaders_ = true; + onewayResponsePending_ = false; +} + +void THttpClient::discardResponseBody(uint32_t size) { + uint32_t remaining = size; + while (remaining > 0) { + uint32_t avail = httpBufLen_ - httpPos_; + if (avail == 0) { + httpPos_ = 0; + httpBufLen_ = 0; + refill(); + avail = httpBufLen_; + } + + uint32_t give = (std::min)(remaining, avail); + httpPos_ += give; + remaining -= give; + } +} + void THttpClient::setPath(std::string path) { path_ = path; } diff --git a/lib/cpp/src/thrift/transport/THttpClient.h b/lib/cpp/src/thrift/transport/THttpClient.h index f0d7e8b2706..b8c14c4a6e2 100644 --- a/lib/cpp/src/thrift/transport/THttpClient.h +++ b/lib/cpp/src/thrift/transport/THttpClient.h @@ -53,16 +53,24 @@ class THttpClient : public THttpTransport { ~THttpClient() override; + void close() override; + void flush() override; + void onewayComplete() override; + void setPath(std::string path); protected: std::string host_; std::string path_; + bool onewayResponsePending_; void parseHeader(char* header) override; bool parseStatusLine(char* status) override; + + void drainPendingOnewayResponse(); + void discardResponseBody(uint32_t size); }; } } diff --git a/lib/cpp/src/thrift/transport/THttpServer.cpp b/lib/cpp/src/thrift/transport/THttpServer.cpp index ace59cc9636..b6a9908c440 100644 --- a/lib/cpp/src/thrift/transport/THttpServer.cpp +++ b/lib/cpp/src/thrift/transport/THttpServer.cpp @@ -140,6 +140,10 @@ void THttpServer::flush() { readHeaders_ = true; } +void THttpServer::onewayComplete() { + flush(); +} + std::string THttpServer::getHeader(uint32_t len) { std::ostringstream h; h << "HTTP/1.1 200 OK" << CRLF << "Date: " << getTimeRFC1123() << CRLF << "Server: Thrift/" diff --git a/lib/cpp/src/thrift/transport/THttpServer.h b/lib/cpp/src/thrift/transport/THttpServer.h index bc98986d7cf..ad7a67b7d71 100644 --- a/lib/cpp/src/thrift/transport/THttpServer.h +++ b/lib/cpp/src/thrift/transport/THttpServer.h @@ -34,6 +34,8 @@ class THttpServer : public THttpTransport { void flush() override; + void onewayComplete() override; + protected: virtual std::string getHeader(uint32_t len); void readHeaders(); diff --git a/lib/cpp/src/thrift/transport/TTransport.h b/lib/cpp/src/thrift/transport/TTransport.h index 1158bcf04d2..02385418d7e 100644 --- a/lib/cpp/src/thrift/transport/TTransport.h +++ b/lib/cpp/src/thrift/transport/TTransport.h @@ -188,6 +188,16 @@ class TTransport { // default behaviour is to do nothing } + /** + * Called by generated code after a one-way method has been sent or handled. + * Transports that receive an out-of-band response can override this to mark + * it for discard without making one-way calls wait for a reply. Transports + * that must emit an out-of-band response can override this to send it. + */ + virtual void onewayComplete() { + // default behaviour is to do nothing + } + /** * Attempts to return a pointer to \c len bytes, possibly copied into \c buf. * Does not consume the bytes read (i.e.: a later read will return the same diff --git a/lib/cpp/src/thrift/transport/TZlibTransport.h b/lib/cpp/src/thrift/transport/TZlibTransport.h index 85765e6be74..f3a41147693 100644 --- a/lib/cpp/src/thrift/transport/TZlibTransport.h +++ b/lib/cpp/src/thrift/transport/TZlibTransport.h @@ -153,6 +153,8 @@ class TZlibTransport : public TVirtualTransport { void flush() override; + void onewayComplete() override { transport_->onewayComplete(); } + /** * Finalize the zlib stream. * diff --git a/lib/cpp/test/OneWayHTTPTest.cpp b/lib/cpp/test/OneWayHTTPTest.cpp index 8789c2c8ffb..2c545678d59 100644 --- a/lib/cpp/test/OneWayHTTPTest.cpp +++ b/lib/cpp/test/OneWayHTTPTest.cpp @@ -222,6 +222,9 @@ BOOST_AUTO_TEST_CASE( JSON_BufferedHTTP ) blockable_transport->unblock() ; client.send_roundTripRPC() ; blockable_transport->flush() ; + uint8_t discard; + BOOST_CHECK_EQUAL(transport->read(&discard, 1), 0U); + BOOST_CHECK_EQUAL(transport->read(&discard, 1), 0U); try { client.recv_roundTripRPC() ; } catch (const TTransportException &e) { @@ -236,4 +239,51 @@ BOOST_AUTO_TEST_CASE( JSON_BufferedHTTP ) #endif } +BOOST_AUTO_TEST_CASE( JSON_HTTP_OneWayWrapperDoesNotPoisonNextCall ) +{ + std::shared_ptr ss = std::make_shared(0); + TThreadedServer server( + std::make_shared( + std::make_shared()), + ss, + std::make_shared(), + std::make_shared()); + + std::shared_ptr pEventHandler(new TServerReadyEventHandler); + server.setServerEventHandler(pEventHandler); + + RPC0ThreadClass t(server); + boost::thread thread(&RPC0ThreadClass::Run, &t); + + { + Synchronized sync(*(pEventHandler.get())); + while (!pEventHandler->isListening()) { + pEventHandler->wait(); + } + } + + { + std::shared_ptr socket(new TSocket("localhost", ss->getPort())); + socket->setRecvTimeout(10000); + std::shared_ptr transport(new THttpClient(socket, "localhost", "/service")); + std::shared_ptr protocol(new TJSONProtocol(transport)); + onewaytest::OneWayServiceClient client(protocol); + + transport->open(); + client.roundTripRPC(); + BOOST_CHECK_EQUAL(pEventHandler->acceptedCount(), 1U); + client.oneWayRPC(); + try { + client.roundTripRPC(); + } catch (const std::exception& e) { + BOOST_ERROR("roundTripRPC after oneWayRPC failed: " + std::string(e.what())); + } + BOOST_CHECK_EQUAL(pEventHandler->acceptedCount(), 1U); + transport->close(); + } + + server.stop(); + thread.join(); +} + BOOST_AUTO_TEST_SUITE_END() diff --git a/test/known_failures_Linux.json b/test/known_failures_Linux.json index e4cd98bd963..c85c25bc185 100644 --- a/test/known_failures_Linux.json +++ b/test/known_failures_Linux.json @@ -29,21 +29,13 @@ "cpp-dart_multi-binary_http-ip", "cpp-dart_multic-compact_http-ip", "cpp-dart_multij-json_http-ip", - "cpp-go_binary_http-ip", "cpp-go_binary_http-ip-ssl", - "cpp-go_compact_http-ip", "cpp-go_compact_http-ip-ssl", - "cpp-go_header_http-ip", "cpp-go_header_http-ip-ssl", - "cpp-go_json_http-ip", "cpp-go_json_http-ip-ssl", - "cpp-go_multi-binary_http-ip", "cpp-go_multi-binary_http-ip-ssl", - "cpp-go_multic-compact_http-ip", "cpp-go_multic-compact_http-ip-ssl", - "cpp-go_multih-header_http-ip", "cpp-go_multih-header_http-ip-ssl", - "cpp-go_multij-json_http-ip", "cpp-go_multij-json_http-ip-ssl", "cpp-java_binary_http-ip", "cpp-java_binary_http-ip-ssl", @@ -124,58 +116,40 @@ "cpp-nodejs_multij-json_http-ip-ssl", "cpp-nodejs_multij-json_websocket-domain", "cpp-py_binary-accel_http-domain", - "cpp-py_binary-accel_http-ip", "cpp-py_binary-accel_http-ip-ssl", "cpp-py_binary_http-domain", - "cpp-py_binary_http-ip", "cpp-py_binary_http-ip-ssl", "cpp-py_compact-accelc_http-domain", - "cpp-py_compact-accelc_http-ip", "cpp-py_compact-accelc_http-ip-ssl", "cpp-py_compact_http-domain", - "cpp-py_compact_http-ip", "cpp-py_compact_http-ip-ssl", "cpp-py_header_http-domain", - "cpp-py_header_http-ip", "cpp-py_header_http-ip-ssl", "cpp-py_json_http-domain", - "cpp-py_json_http-ip", "cpp-py_json_http-ip-ssl", "cpp-py_multi-accel_http-domain", - "cpp-py_multi-accel_http-ip", "cpp-py_multi-accel_http-ip-ssl", "cpp-py_multi-binary_http-domain", - "cpp-py_multi-binary_http-ip", "cpp-py_multi-binary_http-ip-ssl", "cpp-py_multi-multia_http-domain", - "cpp-py_multi-multia_http-ip", "cpp-py_multi-multia_http-ip-ssl", "cpp-py_multi_http-domain", - "cpp-py_multi_http-ip", "cpp-py_multi_http-ip-ssl", "cpp-py_multic-accelc_http-domain", - "cpp-py_multic-accelc_http-ip", "cpp-py_multic-accelc_http-ip-ssl", "cpp-py_multic-compact_http-domain", - "cpp-py_multic-compact_http-ip", "cpp-py_multic-compact_http-ip-ssl", "cpp-py_multic-multiac_http-domain", - "cpp-py_multic-multiac_http-ip", "cpp-py_multic-multiac_http-ip-ssl", "cpp-py_multic_http-domain", - "cpp-py_multic_http-ip", "cpp-py_multic_http-ip-ssl", "cpp-py_multih-header_http-domain", - "cpp-py_multih-header_http-ip", "cpp-py_multih-header_http-ip-ssl", "cpp-py_multih_http-domain", - "cpp-py_multih_http-ip", "cpp-py_multih_http-ip-ssl", "cpp-py_multij-json_http-domain", - "cpp-py_multij-json_http-ip", "cpp-py_multij-json_http-ip-ssl", "cpp-py_multij_http-domain", - "cpp-py_multij_http-ip", "cpp-py_multij_http-ip-ssl", "d-cl_binary_buffered-ip", "d-cl_binary_framed-ip", @@ -306,13 +280,10 @@ "erl-nodejs_binary_buffered-ip", "erl-nodejs_compact_buffered-ip", "erl-nodets_binary_buffered-ip", - "go-cpp_binary_http-ip", "go-cpp_binary_http-ip-ssl", - "go-cpp_compact_http-ip", "go-cpp_compact_http-ip-ssl", "go-cpp_header_http-ip", "go-cpp_header_http-ip-ssl", - "go-cpp_json_http-ip", "go-cpp_json_http-ip-ssl", "go-d_binary_http-ip", "go-d_binary_http-ip-ssl",