From 7747911e26d002051b75d6fa467d80878e18b1d1 Mon Sep 17 00:00:00 2001
From: Varun Sharma
Date: Mon, 18 Apr 2022 18:08:11 -0700
Subject: [PATCH] Add ServerMetadata API

Update the tfzendnn examples
---
 examples/cpp/tf_zendnn_client.cpp    | 26 ++++++++---------
 examples/python/tf_zendnn.py         | 43 ++++++++++++++++------------
 include/proteus/clients/client.hpp   |  1 +
 include/proteus/clients/grpc.hpp     |  1 +
 include/proteus/clients/native.hpp   |  1 +
 include/proteus/core/predict_api.hpp |  7 +++++
 src/proteus/clients/grpc.cpp         | 19 ++++++++++++
 src/proteus/clients/native.cpp       | 17 +++++++++++
 src/proteus/servers/grpc_server.cpp  | 17 ++++++-----
 src/proteus/servers/http_server.cpp  | 19 ++++++------
 src/python/src/proteus/rest.py       | 18 ++++++++++++
 11 files changed, 117 insertions(+), 52 deletions(-)

diff --git a/examples/cpp/tf_zendnn_client.cpp b/examples/cpp/tf_zendnn_client.cpp
index b7e7c6e58..1a21b9003 100644
--- a/examples/cpp/tf_zendnn_client.cpp
+++ b/examples/cpp/tf_zendnn_client.cpp
@@ -248,6 +248,14 @@ int main() {
   // initialize the server
   proteus::initialize();
 
+  auto client = proteus::NativeClient();
+  auto metadata = client.serverMetadata();
+  if (metadata.extensions.find("tfzendnn") == metadata.extensions.end()) {
+    std::cout << "TFZenDNN support required but not found.\n";
+    proteus::terminate();
+    exit(0);
+  }
+
   // load worker with required parameters
   proteus::RequestParameters parameters;
   // parameters.put("max_buffer_num", options.batch_size);
@@ -257,7 +265,7 @@ int main() {
   parameters.put("input_size", options.input_size);
   parameters.put("inter_op", options.inter_op);
   parameters.put("intra_op", options.intra_op);
-  auto workerName = proteus::load("TfZendnn", &parameters);
+  auto workerName = client.modelLoad("TfZendnn", &parameters);
 
   float time_tmp = 0.f;
   // prepare images for inference
@@ -274,11 +282,7 @@ int main() {
     auto start = std::chrono::high_resolution_clock::now();  // Timing the start
     request.addInputTensor(static_cast<void*>(images[0].data()), shape,
                            proteus::types::DataType::FP32);
-    queue.push(proteus::enqueue(workerName, request));
-
-    auto front = std::move(queue.front());
-    queue.pop();
-    auto results = front.get();
+    auto results = client.modelInfer(workerName, request);
     // Timing the prediction
     auto stop = std::chrono::high_resolution_clock::now();
     auto duration =
@@ -323,10 +327,7 @@ int main() {
       request.addInputTensor(static_cast<void*>(images[i].data()), shape,
                              proteus::types::DataType::FP32);
     }
-    queue.push(proteus::enqueue(workerName, request));
-    auto front = std::move(queue.front());
-    queue.pop();
-    auto results = front.get();
+    auto results = client.modelInfer(workerName, request);
   }
 
   // Running for `steps` number of time for proper benchmarking
@@ -338,10 +339,7 @@ int main() {
      request.addInputTensor(static_cast<void*>(images[i].data()), shape,
                             proteus::types::DataType::FP32);
     }
-    queue.push(proteus::enqueue(workerName, request));
-    auto front = std::move(queue.front());
-    queue.pop();
-    auto results = front.get();
+    auto results = client.modelInfer(workerName, request);
   }
   // Timing the prediction
   auto stop = std::chrono::high_resolution_clock::now();
diff --git a/examples/python/tf_zendnn.py b/examples/python/tf_zendnn.py
index 2c2c05e74..589b1af39 100644
--- a/examples/python/tf_zendnn.py
+++ b/examples/python/tf_zendnn.py
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# fmt: off
+import argparse
 import os
+import sys
 import time
-import argparse
 
 import numpy as np
 
 import proteus
 from utils.utils import preprocess, postprocess
-# fmt: on
 
 
 def main(args):
@@ -41,6 +40,29 @@ def main(args):
         FileNotFoundError: If image_location is given and the file is not found
     """
 
+    # Create server objects
+    server = proteus.Server()
+    client = proteus.RestClient("0.0.0.0:8998", None)
+
+    # Start server: if it's not already started, start it here
+    try:
+        start_server = not client.server_live()
+        print("Server already up")
+    except proteus.ConnectionError:
+        start_server = True
+    if start_server:
+        print("Starting server")
+        server.start(quiet=True)
+        client.wait_until_live()
+
+    if not client.has_extension("tfzendnn"):
+        print("TFZenDNN support required but not found.")
+        if start_server:
+            print("Closing server")
+            server.stop()
+            client.wait_until_stop()
+        sys.exit(0)
+
     # Argument parsing
     real_data = True if args.image_location else False
     batch_size = args.batch_size
@@ -70,21 +92,6 @@ def main(args):
     classes = list(range(1000))
     classes = np.asarray(classes)
 
-    # Create server objects
-    server = proteus.Server()
-    client = proteus.RestClient("0.0.0.0:8998", None)
-
-    # Start server: if it's not already started, start it here
-    try:
-        start_server = not client.server_live()
-        print("Server already up")
-    except proteus.ConnectionError:
-        start_server = True
-    if start_server:
-        print("Starting server")
-        server.start(quiet=True)
-        client.wait_until_live()
-
     # Load the worker with appropriate parameters
     parameters = {
         "model": args.graph,
diff --git a/include/proteus/clients/client.hpp b/include/proteus/clients/client.hpp
index b174dd9aa..36e7eebed 100644
--- a/include/proteus/clients/client.hpp
+++ b/include/proteus/clients/client.hpp
@@ -30,6 +30,7 @@ class Client {
  public:
   virtual ~Client() = default;
 
+  virtual ServerMetadata serverMetadata() = 0;
   virtual bool serverLive() = 0;
   virtual bool serverReady() = 0;
   virtual bool modelReady(const std::string& model) = 0;
diff --git a/include/proteus/clients/grpc.hpp b/include/proteus/clients/grpc.hpp
index a2649ae29..615c1532b 100644
--- a/include/proteus/clients/grpc.hpp
+++ b/include/proteus/clients/grpc.hpp
@@ -49,6 +49,7 @@ class GrpcClient : public Client {
   GrpcClient(const std::shared_ptr<::grpc::Channel>& channel);
   ~GrpcClient();
 
+  ServerMetadata serverMetadata() override;
   bool serverLive() override;
   bool serverReady() override;
   bool modelReady(const std::string& model) override;
diff --git a/include/proteus/clients/native.hpp b/include/proteus/clients/native.hpp
index 748c5bdbd..06e4f1c9b 100644
--- a/include/proteus/clients/native.hpp
+++ b/include/proteus/clients/native.hpp
@@ -54,6 +54,7 @@ class NativeClient : public Client {
  public:
   ~NativeClient();
 
+  ServerMetadata serverMetadata() override;
   bool serverLive() override;
   bool serverReady() override;
 
diff --git a/include/proteus/core/predict_api.hpp b/include/proteus/core/predict_api.hpp
index ee8939718..313d787c8 100644
--- a/include/proteus/core/predict_api.hpp
+++ b/include/proteus/core/predict_api.hpp
@@ -28,6 +28,7 @@
 #include <map>          // for map, operator==, map<>::...
 #include <memory>       // for shared_ptr, allocator
 #include <ostream>      // for operator<<, ostream, bas...
+#include <set>          // for set
 #include <string>       // for string, operator<<, char...
 #include <string_view>  // for string_view
 #include <utility>      // for move
@@ -137,6 +138,12 @@ class RequestParameters {
 
 using RequestParametersPtr = std::shared_ptr<RequestParameters>;
 
+struct ServerMetadata {
+  std::string name;
+  std::string version;
+  std::set<std::string> extensions;
+};
+
 /**
  * @brief Holds an inference request's input data
  */
diff --git a/src/proteus/clients/grpc.cpp b/src/proteus/clients/grpc.cpp
index a7f24654a..f85a541d8 100644
--- a/src/proteus/clients/grpc.cpp
+++ b/src/proteus/clients/grpc.cpp
@@ -29,6 +29,7 @@
 #include <iostream>   // for operator<<, cout
 #include <map>        // for map
 #include <memory>     // for make_shared, reinter...
+#include <set>        // for set
 #include <stdexcept>  // for runtime_error
 #include <string>     // for string, operator+
 #include <utility>    // for move
@@ -71,6 +72,24 @@ GrpcClient::GrpcClient(const std::shared_ptr<::grpc::Channel>& channel) {
 
 GrpcClient::~GrpcClient() = default;
 
+ServerMetadata GrpcClient::serverMetadata() {
+  inference::ServerMetadataRequest request;
+  inference::ServerMetadataResponse reply;
+
+  ClientContext context;
+
+  auto* stub = this->impl_->getStub();
+  Status status = stub->ServerMetadata(&context, request, &reply);
+
+  if (status.ok()) {
+    auto ext = reply.extensions();
+    std::set<std::string> extensions(ext.begin(), ext.end());
+    ServerMetadata metadata{reply.name(), reply.version(), extensions};
+    return metadata;
+  }
+  throw std::runtime_error(status.error_message());
+}
+
 bool GrpcClient::serverLive() {
   inference::ServerLiveRequest request;
   inference::ServerLiveResponse reply;
diff --git a/src/proteus/clients/native.cpp b/src/proteus/clients/native.cpp
index f0e5b9948..600feacc6 100644
--- a/src/proteus/clients/native.cpp
+++ b/src/proteus/clients/native.cpp
@@ -23,6 +23,7 @@
 #include <cstdlib>        // for getenv
 #include <future>         // for promise
 #include <memory>         // for unique_ptr, make_unique
+#include <set>            // for set
 #include <stdexcept>      // for invalid_argument
 #include <string>         // for string, basic_string
 #include <unordered_map>  // for unordered_map, operat...
@@ -37,6 +38,7 @@
 #include "proteus/observation/logging.hpp"  // for initLogging
 #include "proteus/observation/metrics.hpp"  // for Metrics, MetricCounte...
 #include "proteus/observation/tracing.hpp"  // for startTrace, startTracer
+#include "proteus/version.hpp"              // for kProteusVersion
 
 #ifdef PROTEUS_ENABLE_AKS
 #include <aks/AksSysManagerExt.h>  // for SysManagerExt
@@ -76,6 +78,21 @@ void terminate() {
 
 NativeClient::~NativeClient() = default;
 
+ServerMetadata NativeClient::serverMetadata() {
+  std::set<std::string> extensions;
+  ServerMetadata metadata{"proteus", kProteusVersion, extensions};
+
+#ifdef PROTEUS_ENABLE_AKS
+  metadata.extensions.insert("aks");
+#endif
+#ifdef PROTEUS_ENABLE_VITIS
+  metadata.extensions.insert("vitis");
+#endif
+#ifdef PROTEUS_ENABLE_TFZENDNN
+  metadata.extensions.insert("tfzendnn");
+#endif
+  return metadata;
+}
 bool NativeClient::serverLive() { return true; }
 
 bool NativeClient::serverReady() { return true; }
diff --git a/src/proteus/servers/grpc_server.cpp b/src/proteus/servers/grpc_server.cpp
index 843730e5b..8e60357f5 100644
--- a/src/proteus/servers/grpc_server.cpp
+++ b/src/proteus/servers/grpc_server.cpp
@@ -42,6 +42,7 @@
 #include "proteus/buffers/buffer.hpp"         // for Buffer
 #include "proteus/build_options.hpp"          // for PROTEUS_ENABLE_TRACING
 #include "proteus/clients/grpc_internal.hpp"  // for mapProtoToParameters
+#include "proteus/clients/native.hpp"         // for NativeClient
 #include "proteus/core/data_types.hpp"        // for DataType, mapStrToType
 #include "proteus/core/interface.hpp"         // for Interface, Interfac...
 #include "proteus/core/manager.hpp"           // for Manager
@@ -50,7 +51,6 @@
 #include "proteus/helpers/declarations.hpp"   // for BufferRawPtrs, Infe...
#include "proteus/observation/logging.hpp" // for SPDLOG_INFO, SPDLOG... #include "proteus/observation/tracing.hpp" // for Trace, startTrace -#include "proteus/version.hpp" // for kProteusVersion namespace proteus { class CallDataModelInfer; @@ -535,14 +535,13 @@ CALLDATA_IMPL(ModelReady, Unary) { CALLDATA_IMPL_END CALLDATA_IMPL(ServerMetadata, Unary) { - reply_.set_name("proteus"); - reply_.set_version(kProteusVersion); -#ifdef PROTEUS_ENABLE_AKS - reply_.add_extensions("aks"); -#endif -#ifdef PROTEUS_ENABLE_VITIS - reply_.add_extensions("vitis"); -#endif + NativeClient client; + auto metadata = client.serverMetadata(); + reply_.set_name(metadata.name); + reply_.set_version(metadata.version); + for (const auto& extension : metadata.extensions) { + reply_.add_extensions(extension); + } finish(); } CALLDATA_IMPL_END diff --git a/src/proteus/servers/http_server.cpp b/src/proteus/servers/http_server.cpp index 16339460e..4cbec7268 100644 --- a/src/proteus/servers/http_server.cpp +++ b/src/proteus/servers/http_server.cpp @@ -151,19 +151,16 @@ void v2::ProteusHttpServer::getServerMetadata( #endif (void)req; // suppress unused variable warning + NativeClient client; + auto metadata = client.serverMetadata(); + Json::Value ret; - ret["name"] = "proteus"; - ret["version"] = kProteusVersion; + ret["name"] = metadata.name; + ret["version"] = metadata.version; ret["extensions"] = Json::arrayValue; -#ifdef PROTEUS_ENABLE_AKS - ret["extensions"].append("aks"); -#endif -#ifdef PROTEUS_ENABLE_VITIS - ret["extensions"].append("vitis"); -#endif -#ifdef PROTEUS_ENABLE_TFZENDNN - ret["extensions"].append("tfzendnn"); -#endif + for (const auto &extension : metadata.extensions) { + ret["extensions"].append(extension); + } auto resp = HttpResponse::newHttpJsonResponse(ret); callback(resp); } diff --git a/src/python/src/proteus/rest.py b/src/python/src/proteus/rest.py index 0452cf032..22a904e24 100644 --- a/src/python/src/proteus/rest.py +++ b/src/python/src/proteus/rest.py @@ -287,6 +287,19 @@ def server_live(self): response = self._get(endpoint, error_str) return response.status_code == 200 + def server_metadata(self): + """ + Get the server's metadata + + Returns: + dict: metadata + """ + error_str = "Failed to get the metadata" + endpoint = self.get_endpoint("metadata") + response = self._get(endpoint, error_str) + assert response.status_code == 200 + return response.json() + def wait_until_live(self): """ Block until the server is live @@ -326,6 +339,7 @@ def get_endpoint(self, command, *args): arg_0 = args[0] commands = { + "metadata": "v2", "server_live": "v2/health/live", "infer": f"v2/models/{arg_0}/infer", "model_ready": f"v2/models/{arg_0}/ready", @@ -347,3 +361,7 @@ def get_address(self, command, *args): """ url = self.get_endpoint(command, *args) return f"{self.http_addr}/{url}" + + def has_extension(self, extension): + metadata = self.server_metadata() + return extension in metadata["extensions"]