Skip to content

Commit

Permalink
Added grpc tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RajivChitale committed Jan 27, 2024
1 parent f9ab15a commit 152dc30
Show file tree
Hide file tree
Showing 19 changed files with 394 additions and 145 deletions.
8 changes: 7 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ add_subdirectory(MLModelRunner)
add_subdirectory(SerDes)
add_subdirectory(test)

add_custom_target(copy)
add_custom_command(TARGET copy PRE_BUILD COMMAND ${CMAKE_COMMAND} -E copy CompilerInterface ${CMAKE_CURRENT_BINARY_DIR}/MLModelRunner/CompilerInterface)

if(LLVM_MLBRIDGE)
include(AddLLVM)
include(HandleLLVMOptions)
Expand All @@ -51,6 +54,9 @@ if(LLVM_MLBRIDGE)
ADDITIONAL_HEADER_DIRS
${CMAKE_CURRENT_SOURCE_DIR}/include

DEPENDS
copy

LINK_LIBS
ModelRunnerLib
$<TARGET_OBJECTS:SerDesLib>
Expand Down Expand Up @@ -89,4 +95,4 @@ else()
endif(LLVM_MLBRIDGE)

install(DIRECTORY include/ DESTINATION include)
install(DIRECTORY CompilerInterface DESTINATION include/python/MLCompilerBridge)
install(DIRECTORY CompilerInterface DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/MLModelRunner/CompilerInterface)
2 changes: 1 addition & 1 deletion MLModelRunner/gRPCModelRunner/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ if(LLVM_MLBRIDGE)
${proto_python_srcs_list}
)
else()
add_library(gRPCModelRunnerLib OBJECT
add_library(gRPCModelRunnerLib
${cc_files}
${proto_srcs_list}
${grpc_srcs_list}
Expand Down
13 changes: 13 additions & 0 deletions SerDes/protobufSerDes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ void *ProtobufSerDes::deserializeUntyped(void *data) {
this->MessageLength = ref.size() * sizeof(int32_t);
return ret->data();
}
if (field->type() == FieldDescriptor::Type::TYPE_INT64) {
auto &ref = reflection->GetRepeatedField<int64_t>(*Response, field);
std::vector<int64_t> *ret =
new std::vector<int64_t>(ref.begin(), ref.end());
this->MessageLength = ref.size() * sizeof(int64_t);
return ret->data();
}
if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) {
auto ref = reflection->GetRepeatedField<float>(*Response, field);
std::vector<float> *ret = new std::vector<float>(ref.begin(), ref.end());
Expand Down Expand Up @@ -199,6 +206,12 @@ void *ProtobufSerDes::deserializeUntyped(void *data) {
this->MessageLength = sizeof(int32_t);
return ptr;
}
if (field->type() == FieldDescriptor::Type::TYPE_INT64) {
int64_t value = reflection->GetInt64(*Response, field);
int64_t *ptr = new int64_t(value);
this->MessageLength = sizeof(int64_t);
return ptr;
}
if (field->type() == FieldDescriptor::Type::TYPE_FLOAT) {
float value = reflection->GetFloat(*Response, field);
float *ptr = new float(value);
Expand Down
5 changes: 3 additions & 2 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ file(GLOB MODEL_OBJECTS ${CMAKE_CURRENT_SOURCE_DIR}/tf_models/*.o)
foreach(MODEL_OBJECT ${MODEL_OBJECTS})
target_link_libraries(MLBridgeCPPTest PRIVATE ${MODEL_OBJECT})
endforeach()
target_link_libraries(MLBridgeCPPTest PRIVATE ModelRunnerUtils)
target_include_directories(MLBridgeCPPTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include)
target_link_libraries(MLBridgeCPPTest PRIVATE MLCompilerBridge )
target_include_directories(MLBridgeCPPTest PRIVATE ${CMAKE_BINARY_DIR}/include ${TENSORFLOW_AOT_PATH}/include ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_link_libraries(MLBridgeCPPTest PRIVATE tf_xla_runtime)
171 changes: 112 additions & 59 deletions test/MLBridgeTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,16 @@
//
//===----------------------------------------------------------------------===//

#include "HelloMLBridge_Env.h"
#include "MLModelRunner/MLModelRunner.h"
#include "MLModelRunner/ONNXModelRunner/ONNXModelRunner.h"
#include "MLModelRunner/PipeModelRunner.h"
#include "MLModelRunner/TFModelRunner.h"
#include "MLModelRunner/Utils/DataTypes.h"
#include "MLModelRunner/Utils/MLConfig.h"
#include "MLModelRunner/gRPCModelRunner.h"
// #include "grpc/helloMLBridgeTest/helloMLBridgeTest.grpc.pb.h"
// #include "grpc/helloMLBridgeTest/helloMLBridgeTest.pb.h"
#include "grpcpp/impl/codegen/status.h"
#include "inference/HelloMLBridge_Env.h"
#include "ProtosInclude.h"
#include "llvm/Support/CommandLine.h"
// #include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <fstream>
#include <google/protobuf/text_format.h>
Expand All @@ -31,6 +28,20 @@
#define debug_out \
if (!silent) \
std::cout
using namespace grpc;

#define gRPCModelRunnerInit(datatype) \
increment_port(1); \
MLBridgeTestgRPC_##datatype::Reply response; \
MLBridgeTestgRPC_##datatype::Request request; \
MLRunner = std::make_unique< \
gRPCModelRunner<MLBridgeTestgRPC_##datatype::MLBridgeTestService, \
MLBridgeTestgRPC_##datatype::MLBridgeTestService::Stub, \
MLBridgeTestgRPC_##datatype::Request, \
MLBridgeTestgRPC_##datatype::Reply>>( \
server_address, &request, &response, nullptr); \
MLRunner->setRequest(&request); \
MLRunner->setResponse(&response)

static llvm::cl::opt<std::string>
cl_server_address("test-server-address", llvm::cl::Hidden,
Expand All @@ -55,7 +66,6 @@ std::string basename;
BaseSerDes::Kind SerDesType;

std::string test_config;
std::string data_format;
std::string pipe_name;
std::string server_address;

Expand All @@ -65,6 +75,7 @@ void testPrimitive(std::string label, T1 value, T2 expected) {
std::pair<std::string, T1> p("request_" + label, value);
MLRunner->populateFeatures(p);
T2 out = MLRunner->evaluate<T2>();

debug_out << " " << label << " reply: " << out << "\n";
if (std::abs(out - expected) > 10e-6) {
std::cerr << "Error: Expected " << label << " reply: " << expected
Expand Down Expand Up @@ -96,73 +107,117 @@ void testVector(std::string label, std::vector<T1> value,
debug_out << "\n";
}

void runTests() {
if (data_format != "json") {
testPrimitive<int, int>("int", 11, 12);
testPrimitive<long, long>("long", 1234567890, 1234567891);
testPrimitive<float, float>("float", 3.14, 4.14);
testPrimitive<double, double>("double", 0.123456789123456789,
1.123456789123456789);
testPrimitive<char, char>("char", 'a', 'b');
testPrimitive<bool, bool>("bool", true, false);
testVector<int, int>("vec_int", {11, 22, 33}, {12, 23, 34});
testVector<long, long>("vec_long", {123456780, 222, 333},
{123456780, 123456781, 123456782});
testVector<float, float>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
testVector<double, double>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
} else if (data_format == "json") {
testPrimitive<int, IntegerType>("int", 11, 12);
testPrimitive<long, IntegerType>("long", 1234567890, 12345);
testPrimitive<float, RealType>("float", 3.14, 4.14);
testPrimitive<double, RealType>("double", 0.123456789123456789,
1.123456789123456789);
testPrimitive<char, char>("char", 'a', 'b');
testPrimitive<bool, bool>("bool", true, false);
testVector<int, IntegerType>("vec_int", {11, 22, 33}, {12, 23, 34});
testVector<long, IntegerType>("vec_long", {123456780, 222, 333},
{6780, 6781, 6782});
testVector<float, RealType>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
testVector<double, RealType>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
int testPipeBytes() {
if (pipe_name == "") {
std::cerr
<< "Pipe name must be specified via --test-pipe-name=<filename>\n";
exit(1);
}
basename = "./" + pipe_name;
SerDesType = BaseSerDes::Kind::Bitstream;
MLRunner = std::make_unique<PipeModelRunner>(
basename + ".out", basename + ".in", SerDesType, nullptr);
testPrimitive<int, int>("int", 11, 12);
testPrimitive<long, long>("long", 1234567890, 1234567891);
testPrimitive<float, float>("float", 3.14, 4.14);
testPrimitive<double, double>("double", 0.123456789123456789,
1.123456789123456789);
testPrimitive<char, char>("char", 'a', 'b');
testPrimitive<bool, bool>("bool", true, false);
testVector<int, int>("vec_int", {11, 22, 33}, {12, 23, 34});
testVector<long, long>("vec_long", {123456780, 222, 333},
{123456780, 123456781, 123456782});
testVector<float, float>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
testVector<double, double>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
return 0;
}

int testPipes() {
int testPipeJSON() {
if (pipe_name == "") {
std::cerr
<< "Pipe name must be specified via --test-pipe-name=<filename>\n";
exit(1);
}
basename = "/tmp/" + pipe_name;
if (data_format == "json")
SerDesType = BaseSerDes::Kind::Json;
else if (data_format == "protobuf")
SerDesType = BaseSerDes::Kind::Protobuf;
else if (data_format == "bytes")
SerDesType = BaseSerDes::Kind::Bitstream;
else {
std::cout << "Invalid data format\n";
exit(1);
}

basename = "./" + pipe_name;
SerDesType = BaseSerDes::Kind::Json;
MLRunner = std::make_unique<PipeModelRunner>(
basename + ".out", basename + ".in", SerDesType, nullptr);

runTests();
testPrimitive<int, IntegerType>("int", 11, 12);
testPrimitive<long, IntegerType>("long", 1234567890, 12345);
testPrimitive<float, RealType>("float", 3.14, 4.14);
testPrimitive<double, RealType>("double", 0.123456789123456789,
1.123456789123456789);
testPrimitive<char, char>("char", 'a', 'b');
testPrimitive<bool, bool>("bool", true, false);
testVector<int, IntegerType>("vec_int", {11, 22, 33}, {12, 23, 34});
testVector<long, IntegerType>("vec_long", {123456780, 222, 333},
{6780, 6781, 6782});
testVector<float, RealType>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
testVector<double, RealType>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
return 0;
}

void increment_port(int delta) {
int split = server_address.find(":");
int port = stoi(server_address.substr(split + 1));
server_address =
server_address.substr(0, split) + ":" + to_string(port + delta);
}

int testGRPC() {
if (server_address == "") {
std::cerr << "Server Address must be specified via "
"--test-server-address=<ip>:<port>\n";
"--test-server-address=\"<ip>:<port>\"\n";
exit(1);
}
{
gRPCModelRunnerInit(int);
testPrimitive<int, int>("int", 11, 12);
}
{
gRPCModelRunnerInit(long);
testPrimitive<long, long>("long", 1234567890, 1234567891);
}
{
gRPCModelRunnerInit(float);
testPrimitive<float, float>("float", 3.14, 4.14);
}
{
gRPCModelRunnerInit(double);
testPrimitive<double, double>("double", 0.123456789123456789,
1.123456789123456789);
}
increment_port(1);
{
gRPCModelRunnerInit(bool);
testPrimitive<bool, bool>("bool", true, false);
}
{
gRPCModelRunnerInit(vec_int);
testVector<int, int>("vec_int", {11, 22, 33}, {12, 23, 34});
}
{
gRPCModelRunnerInit(vec_long);
testVector<long, long>("vec_long", {123456780, 222, 333},
{123456780, 123456781, 123456782});
}
{
gRPCModelRunnerInit(vec_float);
testVector<float, float>("vec_float", {11.1, 22.2, 33.3},
{1.11, 2.22, -3.33, 0});
}
{
gRPCModelRunnerInit(vec_double);
testVector<double, double>("vec_double",
{-1.1111111111, -2.2222222222, -3.3333333333},
{1.12345678912345670, -1.12345678912345671});
}
return 0;
}

Expand All @@ -176,12 +231,10 @@ int main(int argc, char **argv) {

if (test_config == "pipe-bytes") {
pipe_name = cl_pipe_name.getValue();
data_format = "bytes";
testPipes();
testPipeBytes();
} else if (test_config == "pipe-json") {
pipe_name = cl_pipe_name.getValue();
data_format = "json";
testPipes();
testPipeJSON();
} else if (test_config == "grpc") {
server_address = cl_server_address.getValue();
testGRPC();
Expand Down
40 changes: 40 additions & 0 deletions test/include/HelloMLBridge_Env.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===----------------------------------------------------------------------===//
//
// Part of the MLCompilerBridge Project, under the Apache 2.0 License.
// See the LICENSE file under home directory for license and copyright
// information.
//
//===----------------------------------------------------------------------===//

#include "MLModelRunner/ONNXModelRunner/environment.h"
#include "MLModelRunner/ONNXModelRunner/utils.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace MLBridge;
class MLBridgeTestEnv : public Environment {
Observation CurrObs;

public:
MLBridgeTestEnv() { setNextAgent("agent"); };
Observation &reset() override;
Observation &step(Action) override;

protected:
std::vector<float> FeatureVector;
};

Observation &MLBridgeTestEnv::step(Action Action) {
CurrObs.clear();
std::copy(FeatureVector.begin(), FeatureVector.end(),
std::back_inserter(CurrObs));
llvm::outs() << "Action: " << Action << "\n";
setDone();
return CurrObs;
}

Observation &MLBridgeTestEnv::reset() {
std::copy(FeatureVector.begin(), FeatureVector.end(),
std::back_inserter(CurrObs));
return CurrObs;
}
20 changes: 20 additions & 0 deletions test/include/ProtosInclude.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
#include "grpc/MLBridgeTest_bool/MLBridgeTest_bool.grpc.pb.h"
#include "grpc/MLBridgeTest_bool/MLBridgeTest_bool.pb.h"
#include "grpc/MLBridgeTest_char/MLBridgeTest_char.grpc.pb.h"
#include "grpc/MLBridgeTest_char/MLBridgeTest_char.pb.h"
#include "grpc/MLBridgeTest_double/MLBridgeTest_double.grpc.pb.h"
#include "grpc/MLBridgeTest_double/MLBridgeTest_double.pb.h"
#include "grpc/MLBridgeTest_float/MLBridgeTest_float.grpc.pb.h"
#include "grpc/MLBridgeTest_float/MLBridgeTest_float.pb.h"
#include "grpc/MLBridgeTest_int/MLBridgeTest_int.grpc.pb.h"
#include "grpc/MLBridgeTest_int/MLBridgeTest_int.pb.h"
#include "grpc/MLBridgeTest_long/MLBridgeTest_long.grpc.pb.h"
#include "grpc/MLBridgeTest_long/MLBridgeTest_long.pb.h"
#include "grpc/MLBridgeTest_vec_double/MLBridgeTest_vec_double.grpc.pb.h"
#include "grpc/MLBridgeTest_vec_double/MLBridgeTest_vec_double.pb.h"
#include "grpc/MLBridgeTest_vec_float/MLBridgeTest_vec_float.grpc.pb.h"
#include "grpc/MLBridgeTest_vec_float/MLBridgeTest_vec_float.pb.h"
#include "grpc/MLBridgeTest_vec_int/MLBridgeTest_vec_int.grpc.pb.h"
#include "grpc/MLBridgeTest_vec_int/MLBridgeTest_vec_int.pb.h"
#include "grpc/MLBridgeTest_vec_long/MLBridgeTest_vec_long.grpc.pb.h"
#include "grpc/MLBridgeTest_vec_long/MLBridgeTest_vec_long.pb.h"
Loading

0 comments on commit 152dc30

Please sign in to comment.