Add ServerMetadata API
Update the tfzendnn examples
varunsh-xilinx committed Apr 19, 2022
1 parent 8d86881 commit 7747911
Showing 11 changed files with 117 additions and 52 deletions.
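Taken together, the diffs below add a serverMetadata() call to every client (native, gRPC, and the Python REST client) and use it in the TFZenDNN examples to verify support before loading a worker. As a reading aid, here is a minimal C++ sketch of the new API distilled from the diffs; every name comes from this commit, but the sketch assumes the same includes the C++ example already has in scope for proteus::initialize() and proteus::terminate():

#include <iostream>

#include "proteus/clients/native.hpp"

int main() {
  proteus::initialize();  // start the in-process server, as the example does

  // Query the server's name, version, and compiled-in extensions
  auto client = proteus::NativeClient();
  auto metadata = client.serverMetadata();
  std::cout << metadata.name << " " << metadata.version << "\n";
  for (const auto& extension : metadata.extensions) {
    std::cout << "  extension: " << extension << "\n";
  }

  proteus::terminate();
  return 0;
}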
26 changes: 12 additions & 14 deletions examples/cpp/tf_zendnn_client.cpp
@@ -248,6 +248,14 @@ int main() {
   // initialize the server
   proteus::initialize();
 
+  auto client = proteus::NativeClient();
+  auto metadata = client.serverMetadata();
+  if (metadata.extensions.find("tfzendnn") == metadata.extensions.end()) {
+    std::cout << "TFZenDNN support required but not found.\n";
+    proteus::terminate();
+    exit(0);
+  }
+
   // load worker with required parameters
   proteus::RequestParameters parameters;
   // parameters.put("max_buffer_num", options.batch_size);
@@ -257,7 +265,7 @@ int main() {
   parameters.put("input_size", options.input_size);
   parameters.put("inter_op", options.inter_op);
   parameters.put("intra_op", options.intra_op);
-  auto workerName = proteus::load("TfZendnn", &parameters);
+  auto workerName = client.modelLoad("TfZendnn", &parameters);
 
   float time_tmp = 0.f;
   // prepare images for inference
@@ -274,11 +282,7 @@ int main() {
    auto start = std::chrono::high_resolution_clock::now();  // Timing the start
    request.addInputTensor(static_cast<void*>(images[0].data()), shape,
                           proteus::types::DataType::FP32);
-   queue.push(proteus::enqueue(workerName, request));
-
-   auto front = std::move(queue.front());
-   queue.pop();
-   auto results = front.get();
+   auto results = client.modelInfer(workerName, request);
    // Timing the prediction
    auto stop = std::chrono::high_resolution_clock::now();
    auto duration =
@@ -323,10 +327,7 @@ int main() {
      request.addInputTensor(static_cast<void*>(images[i].data()), shape,
                             proteus::types::DataType::FP32);
    }
-   queue.push(proteus::enqueue(workerName, request));
-   auto front = std::move(queue.front());
-   queue.pop();
-   auto results = front.get();
+   auto results = client.modelInfer(workerName, request);
  }
 
  // Running for `steps` number of time for proper benchmarking
@@ -338,10 +339,7 @@ int main() {
      request.addInputTensor(static_cast<void*>(images[i].data()), shape,
                             proteus::types::DataType::FP32);
    }
-   queue.push(proteus::enqueue(workerName, request));
-   auto front = std::move(queue.front());
-   queue.pop();
-   auto results = front.get();
+   auto results = client.modelInfer(workerName, request);
  }
  // Timing the prediction
  auto stop = std::chrono::high_resolution_clock::now();
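Two things to note in this example diff: worker loading moves from the free function proteus::load() to the client method client.modelLoad(), and the hand-rolled asynchronous pattern, proteus::enqueue() plus a std::queue of futures that are popped and get()-ed, collapses into the single synchronous call client.modelInfer(workerName, request).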
43 changes: 25 additions & 18 deletions examples/python/tf_zendnn.py
@@ -12,16 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-# fmt: off
+import argparse
 import os
 import sys
 import time
-import argparse
 
 import numpy as np
 
 import proteus
 from utils.utils import preprocess, postprocess
-# fmt: on
 
 
 def main(args):
@@ -41,6 +40,29 @@ def main(args):
         FileNotFoundError: If image_location is given and the file is not found
     """
 
+    # Create server objects
+    server = proteus.Server()
+    client = proteus.RestClient("0.0.0.0:8998", None)
+
+    # Start server: if it's not already started, start it here
+    try:
+        start_server = not client.server_live()
+        print("Server already up")
+    except proteus.ConnectionError:
+        start_server = True
+    if start_server:
+        print("Starting server")
+        server.start(quiet=True)
+        client.wait_until_live()
+
+    if not client.has_extension("tfzendnn"):
+        print("TFZenDNN support required but not found.")
+        if start_server:
+            print("Closing server")
+            server.stop()
+            client.wait_until_stop()
+        sys.exit(0)
+
     # Argument parsing
     real_data = True if args.image_location else False
     batch_size = args.batch_size
@@ -70,21 +92,6 @@ def main(args):
     classes = list(range(1000))
     classes = np.asarray(classes)
 
-    # Create server objects
-    server = proteus.Server()
-    client = proteus.RestClient("0.0.0.0:8998", None)
-
-    # Start server: if it's not already started, start it here
-    try:
-        start_server = not client.server_live()
-        print("Server already up")
-    except proteus.ConnectionError:
-        start_server = True
-    if start_server:
-        print("Starting server")
-        server.start(quiet=True)
-        client.wait_until_live()
-
     # Load the worker with appropriate paramaters
     parameters = {
         "model": args.graph,
1 change: 1 addition & 0 deletions include/proteus/clients/client.hpp
@@ -30,6 +30,7 @@ class Client {
  public:
   virtual ~Client() = default;
 
+  virtual ServerMetadata serverMetadata() = 0;
   virtual bool serverLive() = 0;
   virtual bool serverReady() = 0;
   virtual bool modelReady(const std::string& model) = 0;
1 change: 1 addition & 0 deletions include/proteus/clients/grpc.hpp
@@ -49,6 +49,7 @@ class GrpcClient : public Client {
   GrpcClient(const std::shared_ptr<::grpc::Channel>& channel);
   ~GrpcClient();
 
+  ServerMetadata serverMetadata() override;
   bool serverLive() override;
   bool serverReady() override;
   bool modelReady(const std::string& model) override;
1 change: 1 addition & 0 deletions include/proteus/clients/native.hpp
@@ -54,6 +54,7 @@ class NativeClient : public Client {
  public:
   ~NativeClient();
 
+  ServerMetadata serverMetadata() override;
   bool serverLive() override;
   bool serverReady() override;
 
7 changes: 7 additions & 0 deletions include/proteus/core/predict_api.hpp
@@ -28,6 +28,7 @@
 #include <map>          // for map, operator==, map<>::...
 #include <memory>       // for shared_ptr, allocator
 #include <ostream>      // for operator<<, ostream, bas...
+#include <set>          // for set
 #include <string>       // for string, operator<<, char...
 #include <string_view>  // for string_view
 #include <utility>      // for move
@@ -137,6 +138,12 @@ class RequestParameters {
 
 using RequestParametersPtr = std::shared_ptr<RequestParameters>;
 
+struct ServerMetadata {
+  std::string name;
+  std::string version;
+  std::set<std::string> extensions;
+};
+
 /**
  * @brief Holds an inference request's input data
  */
19 changes: 19 additions & 0 deletions src/proteus/clients/grpc.cpp
@@ -29,6 +29,7 @@
 #include <iostream>   // for operator<<, cout
 #include <map>        // for map
 #include <memory>     // for make_shared, reinter...
+#include <set>        // for set
 #include <stdexcept>  // for runtime_error
 #include <string>     // for string, operator+
 #include <utility>    // for move
@@ -71,6 +72,24 @@ GrpcClient::GrpcClient(const std::shared_ptr<::grpc::Channel>& channel) {
 
 GrpcClient::~GrpcClient() = default;
 
+ServerMetadata GrpcClient::serverMetadata() {
+  inference::ServerMetadataRequest request;
+  inference::ServerMetadataResponse reply;
+
+  ClientContext context;
+
+  auto* stub = this->impl_->getStub();
+  Status status = stub->ServerMetadata(&context, request, &reply);
+
+  if (status.ok()) {
+    auto ext = reply.extensions();
+    std::set<std::string> extensions(ext.begin(), ext.end());
+    ServerMetadata metadata{reply.name(), reply.version(), extensions};
+    return metadata;
+  }
+  throw std::runtime_error(status.error_message());
+}
+
 bool GrpcClient::serverLive() {
   inference::ServerLiveRequest request;
   inference::ServerLiveResponse reply;
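Because GrpcClient::serverMetadata() throws std::runtime_error when the RPC fails, remote callers will typically wrap the call. A sketch under standard gRPC assumptions; the address is a placeholder, not something this commit defines:

#include <iostream>
#include <memory>
#include <stdexcept>

#include <grpcpp/grpcpp.h>

#include "proteus/clients/grpc.hpp"

int main() {
  // Placeholder address; point this at a running server
  auto channel = ::grpc::CreateChannel("localhost:50051",
                                       ::grpc::InsecureChannelCredentials());
  proteus::GrpcClient client{channel};
  try {
    auto metadata = client.serverMetadata();
    std::cout << metadata.name << " " << metadata.version << "\n";
  } catch (const std::runtime_error& e) {
    std::cerr << "ServerMetadata RPC failed: " << e.what() << "\n";
  }
  return 0;
}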
17 changes: 17 additions & 0 deletions src/proteus/clients/native.cpp
@@ -23,6 +23,7 @@
 #include <cstdlib>        // for getenv
 #include <future>         // for promise
 #include <memory>         // for unique_ptr, make_unique
+#include <set>            // for set
 #include <stdexcept>      // for invalid_argument
 #include <string>         // for string, basic_string
 #include <unordered_map>  // for unordered_map, operat...
@@ -37,6 +38,7 @@
 #include "proteus/observation/logging.hpp"  // for initLogging
 #include "proteus/observation/metrics.hpp"  // for Metrics, MetricCounte...
 #include "proteus/observation/tracing.hpp"  // for startTrace, startTracer
+#include "proteus/version.hpp"              // for kProteusVersion
 
 #ifdef PROTEUS_ENABLE_AKS
 #include <aks/AksSysManagerExt.h>  // for SysManagerExt
@@ -76,6 +78,21 @@ void terminate() {
 
 NativeClient::~NativeClient() = default;
 
+ServerMetadata NativeClient::serverMetadata() {
+  std::set<std::string> extensions;
+  ServerMetadata metadata{"proteus", kProteusVersion, extensions};
+
+#ifdef PROTEUS_ENABLE_AKS
+  metadata.extensions.insert("aks");
+#endif
+#ifdef PROTEUS_ENABLE_VITIS
+  metadata.extensions.insert("vitis");
+#endif
+#ifdef PROTEUS_ENABLE_TFZENDNN
+  metadata.extensions.insert("tfzendnn");
+#endif
+  return metadata;
+}
 bool NativeClient::serverLive() { return true; }
 bool NativeClient::serverReady() { return true; }
 
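Extension discovery in the native client is resolved at compile time: each PROTEUS_ENABLE_* build flag contributes one tag to the set, so advertising a new extension is a three-line change here. For a hypothetical PROTEUS_ENABLE_FOO flag it would look like:

#ifdef PROTEUS_ENABLE_FOO
  metadata.extensions.insert("foo");
#endif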
17 changes: 8 additions & 9 deletions src/proteus/servers/grpc_server.cpp
@@ -42,6 +42,7 @@
 #include "proteus/buffers/buffer.hpp"         // for Buffer
 #include "proteus/build_options.hpp"          // for PROTEUS_ENABLE_TRACING
 #include "proteus/clients/grpc_internal.hpp"  // for mapProtoToParameters
+#include "proteus/clients/native.hpp"         // for NativeClient
 #include "proteus/core/data_types.hpp"        // for DataType, mapStrToType
 #include "proteus/core/interface.hpp"         // for Interface, Interfac...
 #include "proteus/core/manager.hpp"           // for Manager
@@ -50,7 +51,6 @@
 #include "proteus/helpers/declarations.hpp"  // for BufferRawPtrs, Infe...
 #include "proteus/observation/logging.hpp"   // for SPDLOG_INFO, SPDLOG...
 #include "proteus/observation/tracing.hpp"   // for Trace, startTrace
-#include "proteus/version.hpp"               // for kProteusVersion
 
 namespace proteus {
 class CallDataModelInfer;
@@ -535,14 +535,13 @@ CALLDATA_IMPL(ModelReady, Unary) {
 CALLDATA_IMPL_END
 
 CALLDATA_IMPL(ServerMetadata, Unary) {
-  reply_.set_name("proteus");
-  reply_.set_version(kProteusVersion);
-#ifdef PROTEUS_ENABLE_AKS
-  reply_.add_extensions("aks");
-#endif
-#ifdef PROTEUS_ENABLE_VITIS
-  reply_.add_extensions("vitis");
-#endif
+  NativeClient client;
+  auto metadata = client.serverMetadata();
+  reply_.set_name(metadata.name);
+  reply_.set_version(metadata.version);
+  for (const auto& extension : metadata.extensions) {
+    reply_.add_extensions(extension);
+  }
   finish();
 }
 CALLDATA_IMPL_END
19 changes: 8 additions & 11 deletions src/proteus/servers/http_server.cpp
@@ -151,19 +151,16 @@ void v2::ProteusHttpServer::getServerMetadata(
 #endif
   (void)req;  // suppress unused variable warning
 
+  NativeClient client;
+  auto metadata = client.serverMetadata();
+
   Json::Value ret;
-  ret["name"] = "proteus";
-  ret["version"] = kProteusVersion;
+  ret["name"] = metadata.name;
+  ret["version"] = metadata.version;
   ret["extensions"] = Json::arrayValue;
-#ifdef PROTEUS_ENABLE_AKS
-  ret["extensions"].append("aks");
-#endif
-#ifdef PROTEUS_ENABLE_VITIS
-  ret["extensions"].append("vitis");
-#endif
-#ifdef PROTEUS_ENABLE_TFZENDNN
-  ret["extensions"].append("tfzendnn");
-#endif
+  for (const auto &extension : metadata.extensions) {
+    ret["extensions"].append(extension);
+  }
   auto resp = HttpResponse::newHttpJsonResponse(ret);
   callback(resp);
 }
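With both the gRPC and HTTP handlers delegating to NativeClient::serverMetadata(), the advertised extensions are now defined in exactly one place (native.cpp); previously the two transports had drifted, with the HTTP handler advertising tfzendnn but the gRPC handler only aks and vitis. Given the rest.py diff below, which maps the "metadata" command to the v2 endpoint, a TFZenDNN-enabled build would answer GET /v2 with JSON along these lines (values illustrative):

{
  "name": "proteus",
  "version": "<kProteusVersion>",
  "extensions": ["tfzendnn"]
}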
18 changes: 18 additions & 0 deletions src/python/src/proteus/rest.py
@@ -287,6 +287,19 @@ def server_live(self):
         response = self._get(endpoint, error_str)
         return response.status_code == 200
 
+    def server_metadata(self):
+        """
+        Get the server's metadata
+
+        Returns:
+            dict: metadata
+        """
+        error_str = "Failed to get the metadata"
+        endpoint = self.get_endpoint("metadata")
+        response = self._get(endpoint, error_str)
+        assert response.status_code == 200
+        return response.json()
+
     def wait_until_live(self):
         """
         Block until the server is live
@@ -326,6 +339,7 @@ def get_endpoint(self, command, *args):
             arg_0 = args[0]
 
         commands = {
+            "metadata": "v2",
             "server_live": "v2/health/live",
             "infer": f"v2/models/{arg_0}/infer",
             "model_ready": f"v2/models/{arg_0}/ready",
@@ -347,3 +361,7 @@ def get_address(self, command, *args):
         """
         url = self.get_endpoint(command, *args)
         return f"{self.http_addr}/{url}"
+
+    def has_extension(self, extension):
+        metadata = self.server_metadata()
+        return extension in metadata["extensions"]
