diff --git a/MLModelRunner/PipeModelRunner.cpp b/MLModelRunner/PipeModelRunner.cpp index 55dd940..e608d8a 100644 --- a/MLModelRunner/PipeModelRunner.cpp +++ b/MLModelRunner/PipeModelRunner.cpp @@ -31,7 +31,7 @@ using namespace llvm; namespace MLBridge { PipeModelRunner::PipeModelRunner(StringRef OutboundName, StringRef InboundName, - BaseSerDes::Kind SerDesType, LLVMContext *Ctx) + SerDesKind SerDesType, LLVMContext *Ctx) : MLModelRunner(Kind::Pipe, SerDesType, Ctx), InEC(sys::fs::openFileForRead(InboundName, Inbound)) { this->InboundName = InboundName.str(); diff --git a/include/MLModelRunner/MLModelRunner.h b/include/MLModelRunner/MLModelRunner.h index 9c3f0c2..77892ec 100644 --- a/include/MLModelRunner/MLModelRunner.h +++ b/include/MLModelRunner/MLModelRunner.h @@ -84,7 +84,7 @@ class MLModelRunner { enum class Kind : int { Unknown, Pipe, gRPC, ONNX, TFAOT }; Kind getKind() const { return Type; } - BaseSerDes::Kind getSerDesKind() const { return SerDesType; } + SerDesKind getSerDesKind() const { return SerDesType; } virtual void requestExit() = 0; @@ -118,7 +118,7 @@ class MLModelRunner { void setResponse(void *response) { SerDes->setResponse(response); } protected: - MLModelRunner(Kind Type, BaseSerDes::Kind SerDesType, + MLModelRunner(Kind Type, SerDesKind SerDesType, llvm::LLVMContext *Ctx = nullptr) : Ctx(Ctx), Type(Type), SerDesType(SerDesType) { assert(Type != Kind::Unknown); @@ -126,7 +126,7 @@ class MLModelRunner { } MLModelRunner(Kind Type, llvm::LLVMContext *Ctx = nullptr) - : Ctx(Ctx), Type(Type), SerDesType(BaseSerDes::Kind::Unknown) { + : Ctx(Ctx), Type(Type), SerDesType(SerDesKind::Unknown) { SerDes = nullptr; }; @@ -136,7 +136,7 @@ class MLModelRunner { llvm::LLVMContext *Ctx; const Kind Type; - const BaseSerDes::Kind SerDesType; + const SerDesKind SerDesType; protected: std::unique_ptr SerDes; @@ -144,21 +144,21 @@ class MLModelRunner { private: void initSerDes() { switch (SerDesType) { - case BaseSerDes::Kind::Json: + case SerDesKind::Json: SerDes = std::make_unique(); break; - case BaseSerDes::Kind::Bitstream: + case SerDesKind::Bitstream: SerDes = std::make_unique(); break; #ifndef C_LIBRARY - case BaseSerDes::Kind::Protobuf: + case SerDesKind::Protobuf: SerDes = std::make_unique(); break; - case BaseSerDes::Kind::Tensorflow: + case SerDesKind::Tensorflow: SerDes = std::make_unique(); break; #endif - case BaseSerDes::Kind::Unknown: + case SerDesKind::Unknown: SerDes = nullptr; break; } diff --git a/include/MLModelRunner/PipeModelRunner.h b/include/MLModelRunner/PipeModelRunner.h index 4bddc0c..b928cb7 100755 --- a/include/MLModelRunner/PipeModelRunner.h +++ b/include/MLModelRunner/PipeModelRunner.h @@ -55,7 +55,7 @@ namespace MLBridge { class PipeModelRunner : public MLModelRunner { public: PipeModelRunner(llvm::StringRef OutboundName, llvm::StringRef InboundName, - BaseSerDes::Kind Kind, llvm::LLVMContext *Ctx = nullptr); + SerDesKind Kind, llvm::LLVMContext *Ctx = nullptr); static bool classof(const MLModelRunner *R) { return R->getKind() == MLModelRunner::Kind::Pipe; diff --git a/include/MLModelRunner/TFModelRunner.h b/include/MLModelRunner/TFModelRunner.h index 755a8e2..012ac3f 100644 --- a/include/MLModelRunner/TFModelRunner.h +++ b/include/MLModelRunner/TFModelRunner.h @@ -34,8 +34,7 @@ template class TFModelRunner final : public MLModelRunner { TFModelRunner(llvm::StringRef DecisionName, llvm::LLVMContext &Ctx, llvm::StringRef FeedPrefix = "feed_", llvm::StringRef FetchPrefix = "fetch_") - : MLModelRunner(MLModelRunner::Kind::TFAOT, BaseSerDes::Kind::Tensorflow, - &Ctx), + : MLModelRunner(MLModelRunner::Kind::TFAOT, SerDesKind::Tensorflow, &Ctx), CompiledModel(std::make_unique()) { SerDes->setRequest(CompiledModel.get()); @@ -49,7 +48,7 @@ template class TFModelRunner final : public MLModelRunner { TFModelRunner(llvm::StringRef DecisionName, llvm::StringRef FeedPrefix = "feed_", llvm::StringRef FetchPrefix = "fetch_") - : MLModelRunner(MLModelRunner::Kind::TFAOT, BaseSerDes::Kind::Tensorflow), + : MLModelRunner(MLModelRunner::Kind::TFAOT, SerDesKind::Tensorflow), CompiledModel(std::make_unique()) { SerDes->setRequest(CompiledModel.get()); diff --git a/include/MLModelRunner/gRPCModelRunner.h b/include/MLModelRunner/gRPCModelRunner.h index 3582052..9414441 100644 --- a/include/MLModelRunner/gRPCModelRunner.h +++ b/include/MLModelRunner/gRPCModelRunner.h @@ -70,6 +70,7 @@ #include #include #include +#include namespace MLBridge { /// This class is used to create the grpc model runner object. grpc model runner @@ -81,8 +82,7 @@ class gRPCModelRunner : public MLModelRunner { /// For server mode gRPCModelRunner(std::string server_address, grpc::Service *s, llvm::LLVMContext *Ctx = nullptr) - : MLModelRunner(MLModelRunner::Kind::gRPC, BaseSerDes::Kind::Protobuf, - Ctx), + : MLModelRunner(MLModelRunner::Kind::gRPC, SerDesKind::Protobuf, Ctx), server_address(server_address), request(nullptr), response(nullptr), server_mode(true) { RunService(s); @@ -91,8 +91,7 @@ class gRPCModelRunner : public MLModelRunner { /// For client mode gRPCModelRunner(std::string server_address, Request *request, Response *response, llvm::LLVMContext *Ctx = nullptr) - : MLModelRunner(MLModelRunner::Kind::gRPC, BaseSerDes::Kind::Protobuf, - Ctx), + : MLModelRunner(MLModelRunner::Kind::gRPC, SerDesKind::Protobuf, Ctx), server_address(server_address), request(request), response(response), server_mode(false) { SetStub(); diff --git a/include/SerDes/baseSerDes.h b/include/SerDes/baseSerDes.h index 95e7fb6..977143a 100644 --- a/include/SerDes/baseSerDes.h +++ b/include/SerDes/baseSerDes.h @@ -11,7 +11,7 @@ /// 1. Create a new class which inherits from BaseSerDes. /// 2. Implement the setFeature(), getSerializedData(), cleanDataStructures() /// and deserializeUntyped() methods. -/// 3. Add the new SerDes to the enum class Kind in this class. +/// 3. Add the new SerDes to the enum class SerDesKind in this class. /// //===----------------------------------------------------------------------===// @@ -43,10 +43,10 @@ namespace MLBridge { /// communication by the MLModelRunner. /// Currently, (int, float) or (long, double), char and bool are supported. /// Vectors of these types are supported as well. +enum class SerDesKind : int { Unknown, Json, Bitstream, Protobuf, Tensorflow }; class BaseSerDes { public: - enum class Kind : int { Unknown, Json, Bitstream, Protobuf, Tensorflow }; - Kind getKind() const { return Type; } + SerDesKind getKind() const { return Type; } /// setFeature() is used to set the features of the data structure used for /// communication. The features are set as key-value pairs. The key is a @@ -75,9 +75,11 @@ class BaseSerDes { virtual void *getResponse() { return nullptr; }; protected: - BaseSerDes(Kind Type) : Type(Type) { assert(Type != Kind::Unknown); } + BaseSerDes(SerDesKind Type) : Type(Type) { + assert(Type != SerDesKind::Unknown); + } virtual void cleanDataStructures() = 0; - const Kind Type; + const SerDesKind Type; void *RequestVoid; void *ResponseVoid; size_t MessageLength; diff --git a/include/SerDes/bitstreamSerDes.h b/include/SerDes/bitstreamSerDes.h index cd6d525..cd3037f 100644 --- a/include/SerDes/bitstreamSerDes.h +++ b/include/SerDes/bitstreamSerDes.h @@ -28,7 +28,7 @@ namespace MLBridge { /// information followed by the raw data. class BitstreamSerDes : public BaseSerDes { public: - BitstreamSerDes() : BaseSerDes(Kind::Bitstream) { + BitstreamSerDes() : BaseSerDes(SerDesKind::Bitstream) { Buffer = ""; tensorSpecs = std::vector(); rawData = std::vector(); diff --git a/include/SerDes/jsonSerDes.h b/include/SerDes/jsonSerDes.h index 6df810d..c20f87f 100644 --- a/include/SerDes/jsonSerDes.h +++ b/include/SerDes/jsonSerDes.h @@ -23,10 +23,10 @@ namespace MLBridge { /// JsonSerDes - Json Serialization/Deserialization using LLVM's json library. class JsonSerDes : public BaseSerDes { public: - JsonSerDes() : BaseSerDes(BaseSerDes::Kind::Json){}; + JsonSerDes() : BaseSerDes(SerDesKind::Json){}; static bool classof(const BaseSerDes *S) { - return S->getKind() == BaseSerDes::Kind::Json; + return S->getKind() == SerDesKind::Json; } #define SET_FEATURE(TYPE, _) \ diff --git a/include/SerDes/protobufSerDes.h b/include/SerDes/protobufSerDes.h index 0cde21c..25b7e3f 100644 --- a/include/SerDes/protobufSerDes.h +++ b/include/SerDes/protobufSerDes.h @@ -25,10 +25,10 @@ namespace MLBridge { /// communication. class ProtobufSerDes : public BaseSerDes { public: - ProtobufSerDes() : BaseSerDes(Kind::Protobuf){}; + ProtobufSerDes() : BaseSerDes(SerDesKind::Protobuf){}; static bool classof(const BaseSerDes *S) { - return S->getKind() == BaseSerDes::Kind::Protobuf; + return S->getKind() == SerDesKind::Protobuf; } void setRequest(void *Request) override; diff --git a/include/SerDes/tensorflowSerDes.h b/include/SerDes/tensorflowSerDes.h index f2b3909..02d37ae 100644 --- a/include/SerDes/tensorflowSerDes.h +++ b/include/SerDes/tensorflowSerDes.h @@ -21,10 +21,10 @@ namespace MLBridge { /// TensorflowSerDes - Serialization/Deserialization to support TF AOT models. class TensorflowSerDes : public BaseSerDes { public: - TensorflowSerDes() : BaseSerDes(Kind::Tensorflow) {} + TensorflowSerDes() : BaseSerDes(SerDesKind::Tensorflow) {} static bool classof(const BaseSerDes *S) { - return S->getKind() == BaseSerDes::Kind::Tensorflow; + return S->getKind() == SerDesKind::Tensorflow; } #define SET_FEATURE(TYPE, _) \