From 328456a291b279fd1379ede37db98310ef7ee1a3 Mon Sep 17 00:00:00 2001 From: Umesh-k26 Date: Mon, 13 May 2024 13:06:12 +0530 Subject: [PATCH] Modified MLBridgeTest.cpp --- test/MLBridgeTest.cpp | 111 +++++++++++++++++++++--------------------- 1 file changed, 56 insertions(+), 55 deletions(-) diff --git a/test/MLBridgeTest.cpp b/test/MLBridgeTest.cpp index 1915383..6b43cd5 100644 --- a/test/MLBridgeTest.cpp +++ b/test/MLBridgeTest.cpp @@ -29,19 +29,6 @@ std::cout using namespace grpc; -#define gRPCModelRunnerInit(datatype) \ - increment_port(1); \ - MLBridgeTestgRPC_##datatype::Reply response; \ - MLBridgeTestgRPC_##datatype::Request request; \ - MLRunner = std::make_unique< \ - gRPCModelRunner>( \ - server_address, &request, &response, nullptr); \ - MLRunner->setRequest(&request); \ - MLRunner->setResponse(&response) - static llvm::cl::opt cl_server_address("test-server-address", llvm::cl::Hidden, llvm::cl::desc("Server address, format :"), @@ -64,9 +51,8 @@ static llvm::cl::opt llvm::cl::desc("Only print errors when if set to true")); namespace { -std::unique_ptr MLRunner; std::string basename; -BaseSerDes::Kind SerDesType; +SerDesKind SerDesType; std::string test_config; std::string pipe_name; @@ -74,11 +60,12 @@ std::string server_address; std::string onnx_path; // send value of type T1. Test received value of type T2 against expected value -template -void testPrimitive(std::string label, T1 value, T2 expected) { +template +void testPrimitive(MLRunnerTy &MLRunner, std::string label, T1 value, + T2 expected) { std::pair p("request_" + label, value); MLRunner->populateFeatures(p); - T2 out = MLRunner->evaluate(); + T2 out = MLRunner->template evaluate(); debug_out << " " << label << " reply: " << out << "\n"; if (std::abs(out - expected) > 10e-6) { std::cerr << "Error: Expected " << label << " reply: " << expected @@ -87,14 +74,14 @@ void testPrimitive(std::string label, T1 value, T2 expected) { } } -template -void testVector(std::string label, std::vector value, +template +void testVector(MLRunnerTy &MLRunner, std::string label, std::vector value, std::vector expected) { std::pair> p("request_" + label, value); MLRunner->populateFeatures(p); T2 *out; size_t size; - MLRunner->evaluate(out, size); + MLRunner->template evaluate(out, size); std::vector reply(out, out + size); debug_out << " " << label << " reply: "; int i = 0; @@ -117,22 +104,22 @@ int testPipeBytes() { exit(1); } basename = "./" + pipe_name; - SerDesType = BaseSerDes::Kind::Bitstream; - MLRunner = std::make_unique( + SerDesType = SerDesKind::Bitstream; + auto MLRunner = std::make_unique( basename + ".out", basename + ".in", SerDesType, nullptr); - testPrimitive("int", 11, 12); - testPrimitive("long", 1234567890, 1234567891); - testPrimitive("float", 3.14, 4.14); - testPrimitive("double", 0.123456789123456789, + testPrimitive(MLRunner, "int", 11, 12); + testPrimitive(MLRunner, "long", 1234567890, 1234567891); + testPrimitive(MLRunner, "float", 3.14, 4.14); + testPrimitive(MLRunner, "double", 0.123456789123456789, 1.123456789123456789); - testPrimitive("char", 'a', 'b'); - testPrimitive("bool", true, false); - testVector("vec_int", {11, 22, 33}, {12, 23, 34}); - testVector("vec_long", {123456780, 222, 333}, + testPrimitive(MLRunner, "char", 'a', 'b'); + testPrimitive(MLRunner, "bool", true, false); + testVector(MLRunner, "vec_int", {11, 22, 33}, {12, 23, 34}); + testVector(MLRunner, "vec_long", {123456780, 222, 333}, {123456780, 123456781, 123456782}); - testVector("vec_float", {11.1, 22.2, 33.3}, + testVector(MLRunner, "vec_float", {11.1, 22.2, 33.3}, {1.11, 2.22, -3.33, 0}); - testVector("vec_double", + testVector(MLRunner, "vec_double", {-1.1111111111, -2.2222222222, -3.3333333333}, {1.12345678912345670, -1.12345678912345671}); return 0; @@ -145,22 +132,22 @@ int testPipeJSON() { exit(1); } basename = "./" + pipe_name; - SerDesType = BaseSerDes::Kind::Json; - MLRunner = std::make_unique( + SerDesType = SerDesKind::Json; + auto MLRunner = std::make_unique( basename + ".out", basename + ".in", SerDesType, nullptr); - testPrimitive("int", 11, 12); - testPrimitive("long", 1234567890, 12345); - testPrimitive("float", 3.14, 4.14); - testPrimitive("double", 0.123456789123456789, + testPrimitive(MLRunner, "int", 11, 12); + testPrimitive(MLRunner, "long", 1234567890, 12345); + testPrimitive(MLRunner, "float", 3.14, 4.14); + testPrimitive(MLRunner, "double", 0.123456789123456789, 1.123456789123456789); - testPrimitive("char", 'a', 'b'); - testPrimitive("bool", true, false); - testVector("vec_int", {11, 22, 33}, {12, 23, 34}); - testVector("vec_long", {123456780, 222, 333}, + testPrimitive(MLRunner, "char", 'a', 'b'); + testPrimitive(MLRunner, "bool", true, false); + testVector(MLRunner, "vec_int", {11, 22, 33}, {12, 23, 34}); + testVector(MLRunner, "vec_long", {123456780, 222, 333}, {6780, 6781, 6782}); - testVector("vec_float", {11.1, 22.2, 33.3}, + testVector(MLRunner, "vec_float", {11.1, 22.2, 33.3}, {1.11, 2.22, -3.33, 0}); - testVector("vec_double", + testVector(MLRunner, "vec_double", {-1.1111111111, -2.2222222222, -3.3333333333}, {1.12345678912345670, -1.12345678912345671}); return 0; @@ -174,6 +161,19 @@ void increment_port(int delta) { } int testGRPC() { +#define gRPCModelRunnerInit(datatype) \ + increment_port(1); \ + MLBridgeTestgRPC_##datatype::Reply response; \ + MLBridgeTestgRPC_##datatype::Request request; \ + auto MLRunner = std::make_unique< \ + gRPCModelRunner>( \ + server_address, &request, &response, nullptr); \ + MLRunner->setRequest(&request); \ + MLRunner->setResponse(&response) + if (server_address == "") { std::cerr << "Server Address must be specified via " "--test-server-address=\":\"\n"; @@ -181,46 +181,47 @@ int testGRPC() { } { gRPCModelRunnerInit(int); - testPrimitive("int", 11, 12); + testPrimitive(MLRunner, "int", 11, 12); } { gRPCModelRunnerInit(long); - testPrimitive("long", 1234567890, 1234567891); + testPrimitive(MLRunner, "long", 1234567890, 1234567891); } { gRPCModelRunnerInit(float); - testPrimitive("float", 3.14, 4.14); + testPrimitive(MLRunner, "float", 3.14, 4.14); } { gRPCModelRunnerInit(double); - testPrimitive("double", 0.123456789123456789, + testPrimitive(MLRunner, "double", 0.123456789123456789, 1.123456789123456789); } increment_port(1); { gRPCModelRunnerInit(bool); - testPrimitive("bool", true, false); + testPrimitive(MLRunner, "bool", true, false); } { gRPCModelRunnerInit(vec_int); - testVector("vec_int", {11, 22, 33}, {12, 23, 34}); + testVector(MLRunner, "vec_int", {11, 22, 33}, {12, 23, 34}); } { gRPCModelRunnerInit(vec_long); - testVector("vec_long", {123456780, 222, 333}, + testVector(MLRunner, "vec_long", {123456780, 222, 333}, {123456780, 123456781, 123456782}); } { gRPCModelRunnerInit(vec_float); - testVector("vec_float", {11.1, 22.2, 33.3}, + testVector(MLRunner, "vec_float", {11.1, 22.2, 33.3}, {1.11, 2.22, -3.33, 0}); } { gRPCModelRunnerInit(vec_double); - testVector("vec_double", + testVector(MLRunner, "vec_double", {-1.1111111111, -2.2222222222, -3.3333333333}, {1.12345678912345670, -1.12345678912345671}); } +#undef gRPCModelRunnerInit return 0; } @@ -243,7 +244,7 @@ class ONNXTest : public MLBridgeTestEnv { Agent *agent = new Agent(onnx_path); std::map agents; agents["agent"] = agent; - MLRunner = std::make_unique(this, agents, nullptr); + auto MLRunner = std::make_unique(this, agents, nullptr); MLRunner->evaluate(); if (lastAction != expectedAction) { std::cerr << "Error: Expected action: " << expectedAction