Skip to content

Commit

Permalink
Update BaseSerDes::Kind to SerDesKind
Browse files Browse the repository at this point in the history
  • Loading branch information
Umesh-k26 committed May 7, 2024
1 parent b384b49 commit 0767f69
Show file tree
Hide file tree
Showing 10 changed files with 30 additions and 30 deletions.
2 changes: 1 addition & 1 deletion MLModelRunner/PipeModelRunner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ using namespace llvm;

namespace MLBridge {
PipeModelRunner::PipeModelRunner(StringRef OutboundName, StringRef InboundName,
BaseSerDes::Kind SerDesType, LLVMContext *Ctx)
SerDesKind SerDesType, LLVMContext *Ctx)
: MLModelRunner(Kind::Pipe, SerDesType, Ctx),
InEC(sys::fs::openFileForRead(InboundName, Inbound)) {
this->InboundName = InboundName.str();
Expand Down
18 changes: 9 additions & 9 deletions include/MLModelRunner/MLModelRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class MLModelRunner {
enum class Kind : int { Unknown, Pipe, gRPC, ONNX, TFAOT };

Kind getKind() const { return Type; }
BaseSerDes::Kind getSerDesKind() const { return SerDesType; }
SerDesKind getSerDesKind() const { return SerDesType; }

virtual void requestExit() = 0;

Expand Down Expand Up @@ -118,15 +118,15 @@ class MLModelRunner {
void setResponse(void *response) { SerDes->setResponse(response); }

protected:
MLModelRunner(Kind Type, BaseSerDes::Kind SerDesType,
MLModelRunner(Kind Type, SerDesKind SerDesType,
llvm::LLVMContext *Ctx = nullptr)
: Ctx(Ctx), Type(Type), SerDesType(SerDesType) {
assert(Type != Kind::Unknown);
initSerDes();
}

MLModelRunner(Kind Type, llvm::LLVMContext *Ctx = nullptr)
: Ctx(Ctx), Type(Type), SerDesType(BaseSerDes::Kind::Unknown) {
: Ctx(Ctx), Type(Type), SerDesType(SerDesKind::Unknown) {
SerDes = nullptr;
};

Expand All @@ -136,29 +136,29 @@ class MLModelRunner {

llvm::LLVMContext *Ctx;
const Kind Type;
const BaseSerDes::Kind SerDesType;
const SerDesKind SerDesType;

protected:
std::unique_ptr<BaseSerDes> SerDes;

private:
void initSerDes() {
switch (SerDesType) {
case BaseSerDes::Kind::Json:
case SerDesKind::Json:
SerDes = std::make_unique<JsonSerDes>();
break;
case BaseSerDes::Kind::Bitstream:
case SerDesKind::Bitstream:
SerDes = std::make_unique<BitstreamSerDes>();
break;
#ifndef C_LIBRARY
case BaseSerDes::Kind::Protobuf:
case SerDesKind::Protobuf:
SerDes = std::make_unique<ProtobufSerDes>();
break;
case BaseSerDes::Kind::Tensorflow:
case SerDesKind::Tensorflow:
SerDes = std::make_unique<TensorflowSerDes>();
break;
#endif
case BaseSerDes::Kind::Unknown:
case SerDesKind::Unknown:
SerDes = nullptr;
break;
}
Expand Down
2 changes: 1 addition & 1 deletion include/MLModelRunner/PipeModelRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ namespace MLBridge {
class PipeModelRunner : public MLModelRunner {
public:
PipeModelRunner(llvm::StringRef OutboundName, llvm::StringRef InboundName,
BaseSerDes::Kind Kind, llvm::LLVMContext *Ctx = nullptr);
SerDesKind Kind, llvm::LLVMContext *Ctx = nullptr);

static bool classof(const MLModelRunner *R) {
return R->getKind() == MLModelRunner::Kind::Pipe;
Expand Down
5 changes: 2 additions & 3 deletions include/MLModelRunner/TFModelRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,7 @@ template <class TGen> class TFModelRunner final : public MLModelRunner {
TFModelRunner(llvm::StringRef DecisionName, llvm::LLVMContext &Ctx,
llvm::StringRef FeedPrefix = "feed_",
llvm::StringRef FetchPrefix = "fetch_")
: MLModelRunner(MLModelRunner::Kind::TFAOT, BaseSerDes::Kind::Tensorflow,
&Ctx),
: MLModelRunner(MLModelRunner::Kind::TFAOT, SerDesKind::Tensorflow, &Ctx),
CompiledModel(std::make_unique<TGen>()) {

SerDes->setRequest(CompiledModel.get());
Expand All @@ -49,7 +48,7 @@ template <class TGen> class TFModelRunner final : public MLModelRunner {
TFModelRunner(llvm::StringRef DecisionName,
llvm::StringRef FeedPrefix = "feed_",
llvm::StringRef FetchPrefix = "fetch_")
: MLModelRunner(MLModelRunner::Kind::TFAOT, BaseSerDes::Kind::Tensorflow),
: MLModelRunner(MLModelRunner::Kind::TFAOT, SerDesKind::Tensorflow),
CompiledModel(std::make_unique<TGen>()) {

SerDes->setRequest(CompiledModel.get());
Expand Down
7 changes: 3 additions & 4 deletions include/MLModelRunner/gRPCModelRunner.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include <memory>
#include <thread>

namespace MLBridge {
/// This class is used to create the grpc model runner object. grpc model runner
Expand All @@ -81,8 +82,7 @@ class gRPCModelRunner : public MLModelRunner {
/// For server mode
gRPCModelRunner(std::string server_address, grpc::Service *s,
llvm::LLVMContext *Ctx = nullptr)
: MLModelRunner(MLModelRunner::Kind::gRPC, BaseSerDes::Kind::Protobuf,
Ctx),
: MLModelRunner(MLModelRunner::Kind::gRPC, SerDesKind::Protobuf, Ctx),
server_address(server_address), request(nullptr), response(nullptr),
server_mode(true) {
RunService(s);
Expand All @@ -91,8 +91,7 @@ class gRPCModelRunner : public MLModelRunner {
/// For client mode
gRPCModelRunner(std::string server_address, Request *request,
Response *response, llvm::LLVMContext *Ctx = nullptr)
: MLModelRunner(MLModelRunner::Kind::gRPC, BaseSerDes::Kind::Protobuf,
Ctx),
: MLModelRunner(MLModelRunner::Kind::gRPC, SerDesKind::Protobuf, Ctx),
server_address(server_address), request(request), response(response),
server_mode(false) {
SetStub();
Expand Down
12 changes: 7 additions & 5 deletions include/SerDes/baseSerDes.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
/// 1. Create a new class which inherits from BaseSerDes.
/// 2. Implement the setFeature(), getSerializedData(), cleanDataStructures()
/// and deserializeUntyped() methods.
/// 3. Add the new SerDes to the enum class Kind in this class.
/// 3. Add the new SerDes to the enum class SerDesKind in this class.
///
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -43,10 +43,10 @@ namespace MLBridge {
/// communication by the MLModelRunner.
/// Currently, (int, float) or (long, double), char and bool are supported.
/// Vectors of these types are supported as well.
enum class SerDesKind : int { Unknown, Json, Bitstream, Protobuf, Tensorflow };
class BaseSerDes {
public:
enum class Kind : int { Unknown, Json, Bitstream, Protobuf, Tensorflow };
Kind getKind() const { return Type; }
SerDesKind getKind() const { return Type; }

/// setFeature() is used to set the features of the data structure used for
/// communication. The features are set as key-value pairs. The key is a
Expand Down Expand Up @@ -75,9 +75,11 @@ class BaseSerDes {
virtual void *getResponse() { return nullptr; };

protected:
BaseSerDes(Kind Type) : Type(Type) { assert(Type != Kind::Unknown); }
BaseSerDes(SerDesKind Type) : Type(Type) {
assert(Type != SerDesKind::Unknown);
}
virtual void cleanDataStructures() = 0;
const Kind Type;
const SerDesKind Type;
void *RequestVoid;
void *ResponseVoid;
size_t MessageLength;
Expand Down
2 changes: 1 addition & 1 deletion include/SerDes/bitstreamSerDes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace MLBridge {
/// information followed by the raw data.
class BitstreamSerDes : public BaseSerDes {
public:
BitstreamSerDes() : BaseSerDes(Kind::Bitstream) {
BitstreamSerDes() : BaseSerDes(SerDesKind::Bitstream) {
Buffer = "";
tensorSpecs = std::vector<TensorSpec>();
rawData = std::vector<const void *>();
Expand Down
4 changes: 2 additions & 2 deletions include/SerDes/jsonSerDes.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ namespace MLBridge {
/// JsonSerDes - Json Serialization/Deserialization using LLVM's json library.
class JsonSerDes : public BaseSerDes {
public:
JsonSerDes() : BaseSerDes(BaseSerDes::Kind::Json){};
JsonSerDes() : BaseSerDes(SerDesKind::Json){};

static bool classof(const BaseSerDes *S) {
return S->getKind() == BaseSerDes::Kind::Json;
return S->getKind() == SerDesKind::Json;
}

#define SET_FEATURE(TYPE, _) \
Expand Down
4 changes: 2 additions & 2 deletions include/SerDes/protobufSerDes.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ namespace MLBridge {
/// communication.
class ProtobufSerDes : public BaseSerDes {
public:
ProtobufSerDes() : BaseSerDes(Kind::Protobuf){};
ProtobufSerDes() : BaseSerDes(SerDesKind::Protobuf){};

static bool classof(const BaseSerDes *S) {
return S->getKind() == BaseSerDes::Kind::Protobuf;
return S->getKind() == SerDesKind::Protobuf;
}

void setRequest(void *Request) override;
Expand Down
4 changes: 2 additions & 2 deletions include/SerDes/tensorflowSerDes.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ namespace MLBridge {
/// TensorflowSerDes - Serialization/Deserialization to support TF AOT models.
class TensorflowSerDes : public BaseSerDes {
public:
TensorflowSerDes() : BaseSerDes(Kind::Tensorflow) {}
TensorflowSerDes() : BaseSerDes(SerDesKind::Tensorflow) {}

static bool classof(const BaseSerDes *S) {
return S->getKind() == BaseSerDes::Kind::Tensorflow;
return S->getKind() == SerDesKind::Tensorflow;
}

#define SET_FEATURE(TYPE, _) \
Expand Down

0 comments on commit 0767f69

Please sign in to comment.