From d2fc0252ba036a23a60975e8754523830f403043 Mon Sep 17 00:00:00 2001 From: Zhao Wu Date: Mon, 11 Nov 2019 06:56:44 +0800 Subject: [PATCH] [RUTNIME] Support C++ RPC (#4281) --- apps/cpp_rpc/Makefile | 53 +++++ apps/cpp_rpc/README.md | 56 +++++ apps/cpp_rpc/main.cc | 265 ++++++++++++++++++++++ apps/cpp_rpc/rpc_env.cc | 254 +++++++++++++++++++++ apps/cpp_rpc/rpc_env.h | 80 +++++++ apps/cpp_rpc/rpc_server.cc | 359 ++++++++++++++++++++++++++++++ apps/cpp_rpc/rpc_server.h | 52 +++++ apps/cpp_rpc/rpc_tracker_client.h | 246 ++++++++++++++++++++ src/common/socket.h | 187 +++++++++++++++- src/common/util.h | 158 +++++++++++++ src/runtime/rpc/rpc_session.h | 19 ++ src/runtime/rpc/rpc_socket_impl.h | 39 ++++ 12 files changed, 1766 insertions(+), 2 deletions(-) create mode 100644 apps/cpp_rpc/Makefile create mode 100644 apps/cpp_rpc/README.md create mode 100644 apps/cpp_rpc/main.cc create mode 100644 apps/cpp_rpc/rpc_env.cc create mode 100644 apps/cpp_rpc/rpc_env.h create mode 100644 apps/cpp_rpc/rpc_server.cc create mode 100644 apps/cpp_rpc/rpc_server.h create mode 100644 apps/cpp_rpc/rpc_tracker_client.h create mode 100644 src/common/util.h create mode 100644 src/runtime/rpc/rpc_socket_impl.h diff --git a/apps/cpp_rpc/Makefile b/apps/cpp_rpc/Makefile new file mode 100644 index 000000000000..9cd39b446acc --- /dev/null +++ b/apps/cpp_rpc/Makefile @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Makefile to compile RPC Server. +TVM_ROOT=$(shell cd ../..; pwd) +DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core +TVM_RUNTIME_DIR?= +OS?= + +# Android can not link pthrad, but Linux need. +ifeq ($(OS), Linux) +LINK_PTHREAD=-lpthread +else +LINK_PTHREAD= +endif + +PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\ + -I${TVM_ROOT}/include\ + -I${DMLC_CORE}/include\ + -I${TVM_ROOT}/3rdparty/dlpack/include + +PKG_LDFLAGS = -L$(TVM_RUNTIME_DIR) $(LINK_PTHREAD) -ltvm_runtime -ldl -Wl,-R$(TVM_RUNTIME_DIR) + +ifeq ($(USE_GLOG), 1) + PKG_CFLAGS += -DDMLC_USE_GLOG=1 + PKG_LDFLAGS += -lglog +endif + +.PHONY: clean all + +all: tvm_rpc + +# Build rule for all in one TVM package library +tvm_rpc: *.cc + @mkdir -p $(@D) + $(CXX) $(PKG_CFLAGS) -o $@ $(filter %.cc %.o %.a, $^) $(PKG_LDFLAGS) + +clean: + -rm -f tvm_rpc \ No newline at end of file diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md new file mode 100644 index 000000000000..4baecaf25150 --- /dev/null +++ b/apps/cpp_rpc/README.md @@ -0,0 +1,56 @@ + + + + + + + + + + + + + + + + + +# TVM RPC Server +This folder contains a simple recipe to make RPC server in c++. + +## Usage +- Build tvm runtime +- Make the rpc executable [Makefile](Makefile). + `make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux` + if you want to compile it for embedded Linux, you should add `OS=Linux`. + if the target os is Android, you doesn't need to pass OS argument. + You could cross compile the TVM runtime like this: +``` + cd tvm + mkdir arm_runtime + cp cmake/config.cmake arm_runtime + cd arm_runtime + cmake .. -DCMAKE_CXX_COMPILER="/path/to/cross compiler g++/" + make runtime +``` +- Use `./tvm_rpc server` to start the RPC server + +## How it works +- The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library. + +``` +Command line usage + server - Start the server +--host - The hostname of the server, Default=0.0.0.0 +--port - The port of the RPC, Default=9090 +--port-end - The end search port of the RPC, Default=9199 +--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default="" +--key - The key used to identify the device type in tracker. Default="" +--custom-addr - Custom IP Address to Report to RPC Tracker. Default="" +--silent - Whether to run in silent mode. Default=False + Example + ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 --tracker=127.0.0.1:9190 --key=rasp +``` + +## Note +Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently. \ No newline at end of file diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc new file mode 100644 index 000000000000..3cf2ed6a5d59 --- /dev/null +++ b/apps/cpp_rpc/main.cc @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_server.cc + * \brief RPC Server for TVM. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../src/common/util.h" +#include "../../src/common/socket.h" +#include "rpc_server.h" + +using namespace std; +using namespace tvm::runtime; +using namespace tvm::common; + +static const string kUSAGE = \ +"Command line usage\n" \ +" server - Start the server\n" \ +"--host - The hostname of the server, Default=0.0.0.0\n" \ +"--port - The port of the RPC, Default=9090\n" \ +"--port-end - The end search port of the RPC, Default=9199\n" \ +"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n" \ +"--key - The key used to identify the device type in tracker. Default=\"\"\n" \ +"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n" \ +"--silent - Whether to run in silent mode. Default=False\n" \ +"\n" \ +" Example\n" \ +" ./tvm_rpc server --host=0.0.0.0 --port=9000 --port-end=9090 " +" --tracker=127.0.0.1:9190 --key=rasp" \ +"\n"; + +/*! + * \brief RpcServerArgs. + * \arg host The hostname of the server, Default=0.0.0.0 + * \arg port The port of the RPC, Default=9090 + * \arg port_end The end search port of the RPC, Default=9199 + * \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \arg key The key used to identify the device type in tracker. Default="" + * \arg custom_addr Custom IP Address to Report to RPC Tracker. Default="" + * \arg silent Whether run in silent mode. Default=False + */ +struct RpcServerArgs { + string host = "0.0.0.0"; + int port = 9090; + int port_end = 9099; + string tracker; + string key; + string custom_addr; + bool silent = false; +}; + +/*! + * \brief PrintArgs print the contents of RpcServerArgs + * \param args RpcServerArgs structure + */ +void PrintArgs(struct RpcServerArgs args) { + LOG(INFO) << "host = " << args.host; + LOG(INFO) << "port = " << args.port; + LOG(INFO) << "port_end = " << args.port_end; + LOG(INFO) << "tracker = " << args.tracker; + LOG(INFO) << "key = " << args.key; + LOG(INFO) << "custom_addr = " << args.custom_addr; + LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); +} + +/*! + * \brief CtrlCHandler, exits if Ctrl+C is pressed + * \param s signal + */ +void CtrlCHandler(int s) { + LOG(INFO) << "\nUser pressed Ctrl+C, Exiting"; + exit(1); +} + +/*! + * \brief HandleCtrlC Register for handling Ctrl+C event. + */ +void HandleCtrlC() { + // Ctrl+C handler + struct sigaction sigIntHandler; + sigIntHandler.sa_handler = CtrlCHandler; + sigemptyset(&sigIntHandler.sa_mask); + sigIntHandler.sa_flags = 0; + sigaction(SIGINT, &sigIntHandler, nullptr); +} + +/*! + * \brief GetCmdOption Parse and find the command option. + * \param argc arg counter + * \param argv arg values + * \param option command line option to search for. + * \param key whether the option itself is key + * \return value corresponding to option. + */ +string GetCmdOption(int argc, char* argv[], string option, bool key = false) { + string cmd; + for (int i = 1; i < argc; ++i) { + string arg = argv[i]; + if (arg.find(option) == 0) { + if (key) { + cmd = argv[i]; + return cmd; + } + // We assume "=" is the end of option. + CHECK_EQ(*option.rbegin(), '='); + cmd = arg.substr(arg.find("=") + 1); + return cmd; + } + } + return cmd; +} + +/*! + * \brief ValidateTracker Check the tracker address format is correct and changes the format. + * \param tracker The tracker input. + * \return result of operation. + */ +bool ValidateTracker(string &tracker) { + vector list = Split(tracker, ':'); + if ((list.size() != 2) || (!ValidateIP(list[0])) || (!IsNumber(list[1]))) { + return false; + } + ostringstream ss; + ss << "('" << list[0] << "', " << list[1] << ")"; + tracker = ss.str(); + return true; +} + +/*! + * \brief ParseCmdArgs parses the command line arguments. + * \param argc arg counter + * \param argv arg values + * \param args, the output structure which holds the parsed values + */ +void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { + string silent = GetCmdOption(argc, argv, "--silent", true); + if (!silent.empty()) { + args.silent = true; + // Only errors and fatal is logged + dmlc::InitLogging("--minloglevel=2"); + } + + string host = GetCmdOption(argc, argv, "--host="); + if (!host.empty()) { + if (!ValidateIP(host)) { + LOG(WARNING) << "Wrong host address format."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.host = host; + } + + string port = GetCmdOption(argc, argv, "--port="); + if (!port.empty()) { + if (!IsNumber(port) || stoi(port) > 65535) { + LOG(WARNING) << "Wrong port number."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.port = stoi(port); + } + + string port_end = GetCmdOption(argc, argv, "--port_end="); + if (!port_end.empty()) { + if (!IsNumber(port_end) || stoi(port_end) > 65535) { + LOG(WARNING) << "Wrong port_end number."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.port_end = stoi(port_end); + } + + string tracker = GetCmdOption(argc, argv, "--tracker="); + if (!tracker.empty()) { + if (!ValidateTracker(tracker)) { + LOG(WARNING) << "Wrong tracker address format."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.tracker = tracker; + } + + string key = GetCmdOption(argc, argv, "--key="); + if (!key.empty()) { + args.key = key; + } + + string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); + if (!custom_addr.empty()) { + if (!ValidateIP(custom_addr)) { + LOG(WARNING) << "Wrong custom address format."; + LOG(INFO) << kUSAGE; + exit(1); + } + args.custom_addr = custom_addr; + } +} + +/*! + * \brief RpcServer Starts the RPC server. + * \param argc arg counter + * \param argv arg values + * \return result of operation. + */ +int RpcServer(int argc, char * argv[]) { + struct RpcServerArgs args; + + /* parse the command line args */ + ParseCmdArgs(argc, argv, args); + PrintArgs(args); + + // Ctrl+C handler + LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop."; + HandleCtrlC(); + tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker, + args.key, args.custom_addr, args.silent); + return 0; +} + +/*! + * \brief main The main function. + * \param argc arg counter + * \param argv arg values + * \return result of operation. + */ +int main(int argc, char * argv[]) { + if (argc <= 1) { + LOG(INFO) << kUSAGE; + return 0; + } + + if (0 == strcmp(argv[1], "server")) { + RpcServer(argc, argv); + } else { + LOG(INFO) << kUSAGE; + } + + return 0; +} diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc new file mode 100644 index 000000000000..44f848dc749e --- /dev/null +++ b/apps/cpp_rpc/rpc_env.cc @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file rpc_env.cc + * \brief Server environment of the RPC. + */ +#include +#include +#ifndef _MSC_VER +#include +#include +#include +#else +#include +#endif +#include +#include +#include +#include +#include + +#include "rpc_env.h" +#include "../../src/common/util.h" +#include "../../src/runtime/file_util.h" + +namespace tvm { +namespace runtime { + +RPCEnv::RPCEnv() { + #if defined(__linux__) || defined(__ANDROID__) + base_ = "./rpc"; + mkdir(&base_[0], 0777); + + TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") + .set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + *rv = env.GetPath(args[0]); + }); + + TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") + .set_body([](TVMArgs args, TVMRetValue *rv) { + static RPCEnv env; + std::string file_name = env.GetPath(args[0]); + *rv = Load(&file_name, ""); + LOG(INFO) << "Load module from " << file_name << " ..."; + }); + #else + LOG(FATAL) << "Only support RPC in linux environment"; + #endif +} +/*! + * \brief GetPath To get the workpath from packed function + * \param name The file name + * \return The full path of file. + */ +std::string RPCEnv::GetPath(std::string file_name) { + // we assume file_name has "/" means file_name is the exact path + // and does not create /.rpc/ + if (file_name.find("/") != std::string::npos) { + return file_name; + } else { + return base_ + "/" + file_name; + } +} +/*! + * \brief Remove The RPC Environment cleanup function + */ +void RPCEnv::CleanUp() { + #if defined(__linux__) || defined(__ANDROID__) + CleanDir(&base_[0]); + int ret = rmdir(&base_[0]); + if (ret != 0) { + LOG(WARNING) << "Remove directory " << base_ << " failed"; + } + #else + LOG(FATAL) << "Only support RPC in linux environment"; + #endif +} + +/*! + * \brief ListDir get the list of files in a directory + * \param dirname The root directory name + * \return vector Files in directory. + */ +std::vector ListDir(const std::string &dirname) { + std::vector vec; + #ifndef _MSC_VER + DIR *dp = opendir(dirname.c_str()); + if (dp == nullptr) { + int errsv = errno; + LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv); + } + dirent *d; + while ((d = readdir(dp)) != nullptr) { + std::string filename = d->d_name; + if (filename != "." && filename != "..") { + std::string f = dirname; + if (f[f.length() - 1] != '/') { + f += '/'; + } + f += d->d_name; + vec.push_back(f); + } + } + closedir(dp); + #else + WIN32_FIND_DATA fd; + std::string pattern = dirname + "/*"; + HANDLE handle = FindFirstFile(pattern.c_str(), &fd); + if (handle == INVALID_HANDLE_VALUE) { + int errsv = GetLastError(); + LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + } + do { + if (fd.cFileName != "." && fd.cFileName != "..") { + std::string f = dirname; + char clast = f[f.length() - 1]; + if (f == ".") { + f = fd.cFileName; + } else if (clast != '/' && clast != '\\') { + f += '/'; + f += fd.cFileName; + } + vec.push_back(f); + } + } while (FindNextFile(handle, &fd)); + FindClose(handle); + #endif + return vec; +} + +/*! + * \brief LinuxShared Creates a linux shared library + * \param output The output file name + * \param files The files for building + * \param options The compiler options + * \param cc The compiler + */ +void LinuxShared(const std::string output, + const std::vector &files, + std::string options = "", + std::string cc = "g++") { + std::string cmd = cc; + cmd += " -shared -fPIC "; + cmd += " -o " + output; + for (auto f = files.begin(); f != files.end(); ++f) { + cmd += " " + *f; + } + cmd += " " + options; + std::string err_msg; + auto executed_status = common::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; + } +} + +/*! + * \brief CreateShared Creates a shared library + * \param output The output file name + * \param files The files for building + */ +void CreateShared(const std::string output, const std::vector &files) { + #if defined(__linux__) || defined(__ANDROID__) + LinuxShared(output, files); + #else + LOG(FATAL) << "Do not support creating shared library"; + #endif +} + +/*! + * \brief Load Load module from file + This function will automatically call + cc.create_shared if the path is in format .o or .tar + High level handling for .o and .tar file. + We support this to be consistent with RPC module load. + * \param fileIn The input file, file name will be updated + * \param fmt The format of file + * \return Module The loaded module + */ +Module Load(std::string *fileIn, const std::string fmt) { + std::string file = *fileIn; + if (common::EndsWith(file, ".so")) { + return Module::LoadFromFile(file, fmt); + } + + #if defined(__linux__) || defined(__ANDROID__) + std::string file_name = file + ".so"; + if (common::EndsWith(file, ".o")) { + std::vector files; + files.push_back(file); + CreateShared(file_name, files); + } else if (common::EndsWith(file, ".tar")) { + std::string tmp_dir = "./rpc/tmp/"; + mkdir(&tmp_dir[0], 0777); + std::string cmd = "tar -C " + tmp_dir + " -zxf " + file; + std::string err_msg; + int executed_status = common::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; + } + CreateShared(file_name, ListDir(tmp_dir)); + CleanDir(tmp_dir); + rmdir(&tmp_dir[0]); + } else { + file_name = file; + } + *fileIn = file_name; + return Module::LoadFromFile(file_name, fmt); + #else + LOG(FATAL) << "Do not support creating shared library"; + #endif +} + +/*! + * \brief CleanDir Removes the files from the directory + * \param dirname The name of the directory + */ +void CleanDir(const std::string &dirname) { + #if defined(__linux__) || defined(__ANDROID__) + DIR *dp = opendir(dirname.c_str()); + dirent *d; + while ((d = readdir(dp)) != nullptr) { + std::string filename = d->d_name; + if (filename != "." && filename != "..") { + filename = dirname + "/" + d->d_name; + int ret = std::remove(&filename[0]); + if (ret != 0) { + LOG(WARNING) << "Remove file " << filename << " failed"; + } + } + } + #else + LOG(FATAL) << "Only support RPC in linux environment"; + #endif +} + +} // namespace runtime +} // namespace tvm diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h new file mode 100644 index 000000000000..82409bae81a1 --- /dev/null +++ b/apps/cpp_rpc/rpc_env.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_env.h + * \brief Server environment of the RPC. + */ +#ifndef TVM_APPS_CPP_RPC_ENV_H_ +#define TVM_APPS_CPP_RPC_ENV_H_ + +#include +#include + +namespace tvm { +namespace runtime { + +/*! + * \brief Load Load module from file + This function will automatically call + cc.create_shared if the path is in format .o or .tar + High level handling for .o and .tar file. + We support this to be consistent with RPC module load. + * \param file The input file + * \param file The format of file + * \return Module The loaded module + */ +Module Load(std::string *path, const std::string fmt = ""); + +/*! + * \brief CleanDir Removes the files from the directory + * \param dirname THe name of the directory + */ +void CleanDir(const std::string &dirname); + +/*! + * \brief RPCEnv The RPC Environment parameters for c++ rpc server + */ +struct RPCEnv { + public: + /*! + * \brief Constructor Init The RPC Environment initialize function + */ + RPCEnv(); + /*! + * \brief GetPath To get the workpath from packed function + * \param name The file name + * \return The full path of file. + */ + std::string GetPath(std::string file_name); + /*! + * \brief The RPC Environment cleanup function + */ + void CleanUp(); + + private: + /*! + * \brief Holds the environment path. + */ + std::string base_; +}; // RPCEnv + +} // namespace runtime +} // namespace tvm +#endif // TVM_APPS_CPP_RPC_ENV_H_ diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc new file mode 100644 index 000000000000..b35a63bd67dc --- /dev/null +++ b/apps/cpp_rpc/rpc_server.cc @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_server.cc + * \brief RPC Server implementation. + */ +#include + +#if defined(__linux__) || defined(__ANDROID__) +#include +#include +#endif +#include +#include +#include +#include +#include +#include + +#include "rpc_server.h" +#include "rpc_env.h" +#include "rpc_tracker_client.h" +#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/runtime/rpc/rpc_socket_impl.h" +#include "../../src/common/socket.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief wait the child process end. + * \param status status value + */ +#if defined(__linux__) || defined(__ANDROID__) +static pid_t waitPidEintr(int *status) { + pid_t pid = 0; + while ((pid = waitpid(-1, status, 0)) == -1) { + if (errno == EINTR) { + continue; + } else { + perror("waitpid"); + abort(); + } + } + return pid; +} +#endif + +/*! + * \brief RPCServer RPC Server class. + * \param host The hostname of the server, Default=0.0.0.0 + * \param port The port of the RPC, Default=9090 + * \param port_end The end search port of the RPC, Default=9199 + * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param key The key used to identify the device type in tracker. Default="" + * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" + */ +class RPCServer { + public: + /*! + * \brief Constructor. + */ + RPCServer(const std::string &host, + int port, + int port_end, + const std::string &tracker_addr, + const std::string &key, + const std::string &custom_addr) { + // Init the values + host_ = host; + port_ = port; + port_end_ = port_end; + tracker_addr_ = tracker_addr; + key_ = key; + custom_addr_ = custom_addr; + } + + /*! + * \brief Destructor. + */ + ~RPCServer() { + // Free the resources + tracker_sock_.Close(); + listen_sock_.Close(); + } + + /*! + * \brief Start Creates the RPC listen process and execution. + */ + void Start() { + listen_sock_.Create(); + my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_); + LOG(INFO) << "bind to " << host_ << ":" << my_port_; + listen_sock_.Listen(1); + std::future proc(std::async(std::launch::async, &RPCServer::ListenLoopProc, this)); + proc.get(); + // Close the listen socket + listen_sock_.Close(); + } + + private: + /*! + * \brief ListenLoopProc The listen process. + */ + void ListenLoopProc() { + TrackerClient tracker(tracker_addr_, key_, custom_addr_); + while (true) { + common::TCPSocket conn; + common::SockAddr addr("0.0.0.0", 0); + std::string opts; + try { + // step 1: setup tracker and report to tracker + tracker.TryConnect(); + // step 2: wait for in-coming connections + AcceptConnection(&tracker, &conn, &addr, &opts); + } + catch (const char* msg) { + LOG(WARNING) << "Socket exception: " << msg; + // close tracker resource + tracker.Close(); + continue; + } + catch (std::exception& e) { + // Other errors + LOG(WARNING) << "Exception standard: " << e.what(); + continue; + } + + int timeout = GetTimeOutFromOpts(opts); + #if defined(__linux__) || defined(__ANDROID__) + // step 3: serving + if (timeout != 0) { + const pid_t timer_pid = fork(); + if (timer_pid == 0) { + // Timer process + sleep(timeout); + exit(0); + } + + const pid_t worker_pid = fork(); + if (worker_pid == 0) { + // Worker process + ServerLoopProc(conn, addr); + exit(0); + } + + int status = 0; + const pid_t finished_first = waitPidEintr(&status); + if (finished_first == timer_pid) { + kill(worker_pid, SIGKILL); + } else if (finished_first == worker_pid) { + kill(timer_pid, SIGKILL); + } else { + LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; + } + + int status_second = 0; + waitPidEintr(&status_second); + + // Logging. + if (finished_first == timer_pid) { + LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout + << "), Process status = " << status_second; + } else if (finished_first == worker_pid) { + LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; + } + } else { + auto pid = fork(); + if (pid == 0) { + ServerLoopProc(conn, addr); + exit(0); + } + // Wait for the result + int status = 0; + wait(&status); + LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; + } + #else + // step 3: serving + std::future proc(std::async(std::launch::async, + &RPCServer::ServerLoopProc, this, conn, addr)); + // wait until server process finish or timeout + if (timeout != 0) { + // Autoterminate after timeout + proc.wait_for(std::chrono::seconds(timeout)); + } else { + // Wait for the result + proc.get(); + } + #endif + // close from our side. + LOG(INFO) << "Socket Connection Closed"; + conn.Close(); + } + } + + + /*! + * \brief AcceptConnection Accepts the RPC Server connection. + * \param tracker Tracker details. + * \param conn New connection information. + * \param addr New connection address information. + * \param opts Parsed options for socket + * \param ping_period Timeout for select call waiting + */ + void AcceptConnection(TrackerClient* tracker, + common::TCPSocket* conn_sock, + common::SockAddr* addr, + std::string* opts, + int ping_period = 2) { + std::set old_keyset; + std::string matchkey; + + // Report resource to tracker and get key + tracker->ReportResourceAndGetKey(my_port_, &matchkey); + + while (true) { + tracker->WaitConnectionAndUpdateKey(listen_sock_, my_port_, ping_period, &matchkey); + common::TCPSocket conn = listen_sock_.Accept(addr); + + int code = kRPCMagic; + CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); + if (code != kRPCMagic) { + conn.Close(); + LOG(FATAL) << "Client connected is not TVM RPC server"; + continue; + } + + int keylen = 0; + CHECK_EQ(conn.RecvAll(&keylen, sizeof(keylen)), sizeof(keylen)); + + const char* CLIENT_HEADER = "client:"; + const char* SERVER_HEADER = "server:"; + std::string expect_header = CLIENT_HEADER + matchkey; + std::string server_key = SERVER_HEADER + key_; + if (size_t(keylen) < expect_header.length()) { + conn.Close(); + LOG(INFO) << "Wrong client header length"; + continue; + } + + CHECK_NE(keylen, 0); + std::string remote_key; + remote_key.resize(keylen); + CHECK_EQ(conn.RecvAll(&remote_key[0], keylen), keylen); + + std::stringstream ssin(remote_key); + std::string arg0; + ssin >> arg0; + if (arg0 != expect_header) { + code = kRPCMismatch; + CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); + conn.Close(); + LOG(WARNING) << "Mismatch key from" << addr->AsString(); + continue; + } else { + code = kRPCSuccess; + CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); + keylen = server_key.length(); + CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); + CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); + LOG(INFO) << "Connection success " << addr->AsString(); + ssin >> *opts; + *conn_sock = conn; + return; + } + } + } + + /*! + * \brief ServerLoopProc The Server loop process. + * \param sock The socket information + * \param addr The socket address information + */ + void ServerLoopProc(common::TCPSocket sock, common::SockAddr addr) { + // Server loop + auto env = RPCEnv(); + RPCServerLoop(sock.sockfd); + LOG(INFO) << "Finish serving " << addr.AsString(); + env.CleanUp(); + } + + /*! + * \brief GetTimeOutFromOpts Parse and get the timeout option. + * \param opts The option string + * \param timeout value after parsing. + */ + int GetTimeOutFromOpts(std::string opts) { + std::string cmd; + std::string option = "-timeout="; + + if (opts.find(option) == 0) { + cmd = opts.substr(opts.find_last_of(option) + 1); + CHECK(common::IsNumber(cmd)) << "Timeout is not valid"; + return std::stoi(cmd); + } + return 0; + } + + std::string host_; + int port_; + int my_port_; + int port_end_; + std::string tracker_addr_; + std::string key_; + std::string custom_addr_; + common::TCPSocket listen_sock_; + common::TCPSocket tracker_sock_; +}; + +/*! + * \brief RPCServerCreate Creates the RPC Server. + * \param host The hostname of the server, Default=0.0.0.0 + * \param port The port of the RPC, Default=9090 + * \param port_end The end search port of the RPC, Default=9199 + * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param key The key used to identify the device type in tracker. Default="" + * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" + * \param silent Whether run in silent mode. Default=True + */ +void RPCServerCreate(std::string host, + int port, + int port_end, + std::string tracker_addr, + std::string key, + std::string custom_addr, + bool silent) { + if (silent) { + // Only errors and fatal is logged + dmlc::InitLogging("--minloglevel=2"); + } + // Start the rpc server + RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr); + rpc.Start(); +} + +TVM_REGISTER_GLOBAL("rpc._ServerCreate") +.set_body([](TVMArgs args, TVMRetValue* rv) { + RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); + }); +} // namespace runtime +} // namespace tvm diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h new file mode 100644 index 000000000000..205182e4449a --- /dev/null +++ b/apps/cpp_rpc/rpc_server.h @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_server.h + * \brief RPC Server implementation. + */ +#ifndef TVM_APPS_CPP_RPC_SERVER_H_ +#define TVM_APPS_CPP_RPC_SERVER_H_ + +#include +#include "tvm/runtime/c_runtime_api.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief RPCServerCreate Creates the RPC Server. + * \param host The hostname of the server, Default=0.0.0.0 + * \param port The port of the RPC, Default=9090 + * \param port_end The end search port of the RPC, Default=9199 + * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param key The key used to identify the device type in tracker. Default="" + * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" + * \param silent Whether run in silent mode. Default=True + */ +TVM_DLL void RPCServerCreate(std::string host = "", + int port = 9090, + int port_end = 9099, + std::string tracker_addr = "", + std::string key = "", + std::string custom_addr = "", + bool silent = true); +} // namespace runtime +} // namespace tvm +#endif // TVM_APPS_CPP_RPC_SERVER_H_ diff --git a/apps/cpp_rpc/rpc_tracker_client.h b/apps/cpp_rpc/rpc_tracker_client.h new file mode 100644 index 000000000000..89424c7511f0 --- /dev/null +++ b/apps/cpp_rpc/rpc_tracker_client.h @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_tracker_client.h + * \brief RPC Tracker client to report resources. + */ +#ifndef TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ +#define TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ + +#include +#include +#include +#include +#include +#include + +#include "../../src/runtime/rpc/rpc_session.h" +#include "../../src/common/socket.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief TrackerClient Tracker client class. + * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param key The key used to identify the device type in tracker. Default="" + * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" + */ +class TrackerClient { + public: + /*! + * \brief Constructor. + */ + TrackerClient(const std::string& tracker_addr, + const std::string& key, + const std::string& custom_addr) + : tracker_addr_(tracker_addr), key_(key), custom_addr_(custom_addr), + gen_(std::random_device{}()), dis_(0.0, 1.0) { + } + /*! + * \brief Destructor. + */ + ~TrackerClient() { + // Free the resources + Close(); + } + /*! + * \brief IsValid Check tracker is valid. + */ + bool IsValid() { + return (!tracker_addr_.empty() && !tracker_sock_.IsClosed()); + } + /*! + * \brief TryConnect Connect to tracker if the tracker address is valid. + */ + void TryConnect() { + if (!tracker_addr_.empty() && (tracker_sock_.IsClosed())) { + tracker_sock_ = ConnectWithRetry(); + + int code = kRPCTrackerMagic; + CHECK_EQ(tracker_sock_.SendAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(tracker_sock_.RecvAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(code, kRPCTrackerMagic) << tracker_addr_.c_str() << " is not RPC Tracker"; + + std::ostringstream ss; + ss << "[" << static_cast(TrackerCode::kUpdateInfo) + << ", {\"key\": \"server:"<< key_ << "\"}]"; + tracker_sock_.SendBytes(ss.str()); + + // Receive status and validate + std::string remote_status = tracker_sock_.RecvBytes(); + CHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); + } + } + /*! + * \brief Close Clean up tracker resources. + */ + void Close() { + // close tracker resource + if (!tracker_sock_.IsClosed()) { + tracker_sock_.Close(); + } + } + /*! + * \brief ReportResourceAndGetKey Report resource to tracker. + * \param port listening port. + * \param matchkey Random match key output. + */ + void ReportResourceAndGetKey(int port, + std::string *matchkey) { + if (!tracker_sock_.IsClosed()) { + *matchkey = RandomKey(key_ + ":", old_keyset_); + if (custom_addr_.empty()) { + custom_addr_ = "null"; + } + + std::ostringstream ss; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" + << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + + tracker_sock_.SendBytes(ss.str()); + + // Receive status and validate + std::string remote_status = tracker_sock_.RecvBytes(); + CHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); + } else { + *matchkey = key_; + } + } + + /*! + * \brief ReportResourceAndGetKey Report resource to tracker. + * \param listen_sock Listen socket details for select. + * \param port listening port. + * \param ping_period Select wait time. + * \param matchkey Random match key output. + */ + void WaitConnectionAndUpdateKey(common::TCPSocket listen_sock, + int port, + int ping_period, + std::string *matchkey) { + int unmatch_period_count = 0; + int unmatch_timeout = 4; + while (true) { + if (!tracker_sock_.IsClosed()) { + common::PollHelper poller; + poller.WatchRead(listen_sock.sockfd); + poller.Poll(ping_period * 1000); + if (!poller.CheckRead(listen_sock.sockfd)) { + std::ostringstream ss; + ss << "[" << int(TrackerCode::kGetPendingMatchKeys) << "]"; + tracker_sock_.SendBytes(ss.str()); + + // Receive status and validate + std::string pending_keys = tracker_sock_.RecvBytes(); + old_keyset_.insert(*matchkey); + + // if match key not in pending key set + // it means the key is acquired by a client but not used. + if (pending_keys.find(*matchkey) == std::string::npos) { + unmatch_period_count += 1; + } else { + unmatch_period_count = 0; + } + // regenerate match key if key is acquired but not used for a while + if (unmatch_period_count * ping_period > unmatch_timeout + ping_period) { + LOG(INFO) << "no incoming connections, regenerate key ..."; + + *matchkey = RandomKey(key_ + ":", old_keyset_); + + std::ostringstream ss; + ss << "[" << static_cast(TrackerCode::kPut) << ", \"" << key_ << "\", [" + << port << ", \"" << *matchkey << "\"], " << custom_addr_ << "]"; + tracker_sock_.SendBytes(ss.str()); + + std::string remote_status = tracker_sock_.RecvBytes(); + CHECK_EQ(std::stoi(remote_status), static_cast(TrackerCode::kSuccess)); + unmatch_period_count = 0; + } + continue; + } + } + break; + } + } + + private: + /*! + * \brief Connect to a RPC address with retry. + This function is only reliable to short period of server restart. + * \param timeout Timeout during retry + * \param retry_period Number of seconds before we retry again. + * \return TCPSocket The socket information if connect is success. + */ + common::TCPSocket ConnectWithRetry(int timeout = 60, int retry_period = 5) { + auto tbegin = std::chrono::system_clock::now(); + while (true) { + common::SockAddr addr(tracker_addr_); + common::TCPSocket sock; + sock.Create(); + LOG(INFO) << "Tracker connecting to " << addr.AsString(); + if (sock.Connect(addr)) { + return sock; + } + + auto period = (std::chrono::duration_cast( + std::chrono::system_clock::now() - tbegin)).count(); + CHECK(period < timeout) << "Failed to connect to server" << addr.AsString(); + LOG(WARNING) << "Cannot connect to tracker " << addr.AsString() + << " retry in " << retry_period << " seconds."; + std::this_thread::sleep_for(std::chrono::seconds(retry_period)); + } + } + /*! + * \brief Random Generate a random number between 0 and 1. + * \return random float value. + */ + float Random() { + return dis_(gen_); + } + /*! + * \brief Generate a random key. + * \param prefix The string prefix. + * \return cmap The conflict map set. + */ + std::string RandomKey(const std::string& prefix, const std::set &cmap) { + if (!cmap.empty()) { + while (true) { + std::string key = prefix + std::to_string(Random()); + if (cmap.find(key) == cmap.end()) { + return key; + } + } + } + return prefix + std::to_string(Random()); + } + + std::string tracker_addr_; + std::string key_; + std::string custom_addr_; + common::TCPSocket tracker_sock_; + std::set old_keyset_; + std::mt19937 gen_; + std::uniform_real_distribution dis_; + +}; +} // namespace runtime +} // namespace tvm +#endif // TVM_APPS_CPP_RPC_TRACKER_CLIENT_H_ diff --git a/src/common/socket.h b/src/common/socket.h index 39bcff863c10..616991d021d1 100644 --- a/src/common/socket.h +++ b/src/common/socket.h @@ -43,12 +43,27 @@ using ssize_t = int; #include #include #include +#include #include #endif #include #include #include +#include +#include +#include "../common/util.h" +#if defined(_WIN32) +static inline int poll(struct pollfd *pfd, int nfds, + int timeout) { + return WSAPoll(pfd, nfds, timeout); +} +static inline int inet_pton(int family, const char* addr_str, void* addr_buf) { + return InetPton(family, addr_str, addr_buf); +} +#else +#include +#endif // defined(_WIN32) namespace tvm { namespace common { @@ -62,6 +77,22 @@ inline std::string GetHostName() { return std::string(buf.c_str()); } +/*! + * \brief ValidateIP validates an ip address. + * \param ip The ip address in string format localhost or x.x.x.x format + * \return result of operation. + */ +inline bool ValidateIP(std::string ip) { + if (ip == "localhost") { + return true; + } + struct sockaddr_in sa_ipv4; + struct sockaddr_in6 sa_ipv6; + bool is_ipv4 = inet_pton(AF_INET, ip.c_str(), &(sa_ipv4.sin_addr)); + bool is_ipv6 = inet_pton(AF_INET6, ip.c_str(), &(sa_ipv6.sin6_addr)); + return is_ipv4 || is_ipv6; +} + /*! * \brief Common data structure for network address. */ @@ -76,6 +107,23 @@ struct SockAddr { SockAddr(const char *url, int port) { this->Set(url, port); } + + /*! + * \brief SockAddr Get the socket address from tracker. + * \param tracker The url containing the ip and port number. Format is ('192.169.1.100', 9090) + * \return SockAddr parsed from url. + */ + explicit SockAddr(const std::string &url) { + size_t sep = url.find(","); + std::string host = url.substr(2, sep - 3); + std::string port = url.substr(sep + 1, url.length() - 1); + CHECK(ValidateIP(host)) << "Url address is not valid " << url; + if (host == "localhost") { + host = "127.0.0.1"; + } + this->Set(host.c_str(), std::stoi(port)); + } + /*! * \brief set the address * \param host the url of the address @@ -203,17 +251,20 @@ class Socket { } /*! * \brief try bind the socket to host, from start_port to end_port + * \param host host address to bind the socket * \param start_port starting port number to try * \param end_port ending port number to try * \return the port successfully bind to, return -1 if failed to bind any port */ - inline int TryBindHost(int start_port, int end_port) { + inline int TryBindHost(std::string host, int start_port, int end_port) { for (int port = start_port; port < end_port; ++port) { - SockAddr addr("0.0.0.0", port); + SockAddr addr(host.c_str(), port); if (bind(sockfd, reinterpret_cast(&addr.addr), (addr.addr.ss_family == AF_INET6 ? sizeof(sockaddr_in6) : sizeof(sockaddr_in))) == 0) { return port; + } else { + LOG(WARNING) << "Bind failed to " << host << ":" << port; } #if defined(_WIN32) if (WSAGetLastError() != WSAEADDRINUSE) { @@ -373,6 +424,20 @@ class TCPSocket : public Socket { } return TCPSocket(newfd); } + /*! + * \brief get a new connection + * \param addr client address from which connection accepted + * \return The accepted socket connection. + */ + TCPSocket Accept(SockAddr *addr) { + socklen_t addrlen = sizeof(addr->addr); + SockType newfd = accept(sockfd, reinterpret_cast(&addr->addr), + &addrlen); + if (newfd == INVALID_SOCKET) { + Socket::Error("Accept"); + } + return TCPSocket(newfd); + } /*! * \brief decide whether the socket is at OOB mark * \return 1 if at mark, 0 if not, -1 if an error occurred @@ -468,7 +533,125 @@ class TCPSocket : public Socket { } return ndone; } + /*! + * \brief Send the data to remote. + * \param data The data to be sent. + */ + void SendBytes(std::string data) { + int datalen = data.length(); + CHECK_EQ(SendAll(&datalen, sizeof(datalen)), sizeof(datalen)); + CHECK_EQ(SendAll(data.c_str(), datalen), datalen); + } + /*! + * \brief Receive the data to remote. + * \return The data received. + */ + std::string RecvBytes() { + int datalen = 0; + CHECK_EQ(RecvAll(&datalen, sizeof(datalen)), sizeof(datalen)); + std::string data; + data.resize(datalen); + CHECK_EQ(RecvAll(&data[0], datalen), datalen); + return data; + } }; + +/*! \brief helper data structure to perform poll */ +struct PollHelper { + public: + /*! + * \brief add file descriptor to watch for read + * \param fd file descriptor to be watched + */ + inline void WatchRead(TCPSocket::SockType fd) { + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLIN; + } + /*! + * \brief add file descriptor to watch for write + * \param fd file descriptor to be watched + */ + inline void WatchWrite(TCPSocket::SockType fd) { + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLOUT; + } + /*! + * \brief add file descriptor to watch for exception + * \param fd file descriptor to be watched + */ + inline void WatchException(TCPSocket::SockType fd) { + auto& pfd = fds[fd]; + pfd.fd = fd; + pfd.events |= POLLPRI; + } + /*! + * \brief Check if the descriptor is ready for read + * \param fd file descriptor to check status + */ + inline bool CheckRead(TCPSocket::SockType fd) const { + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLIN) != 0); + } + /*! + * \brief Check if the descriptor is ready for write + * \param fd file descriptor to check status + */ + inline bool CheckWrite(TCPSocket::SockType fd) const { + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLOUT) != 0); + } + /*! + * \brief Check if the descriptor has any exception + * \param fd file descriptor to check status + */ + inline bool CheckExcept(TCPSocket::SockType fd) const { + const auto& pfd = fds.find(fd); + return pfd != fds.end() && ((pfd->second.events & POLLPRI) != 0); + } + /*! + * \brief wait for exception event on a single descriptor + * \param fd the file descriptor to wait the event for + * \param timeout the timeout counter, can be negative, which means wait until the event happen + * \return 1 if success, 0 if timeout, and -1 if error occurs + */ + inline static int WaitExcept(TCPSocket::SockType fd, long timeout = -1) { // NOLINT(*) + pollfd pfd; + pfd.fd = fd; + pfd.events = POLLPRI; + return poll(&pfd, 1, timeout); + } + + /*! + * \brief peform poll on the set defined, read, write, exception + * \param timeout specify timeout in milliseconds(ms) if negative, means poll will block + * \return + */ + inline void Poll(long timeout = -1) { // NOLINT(*) + std::vector fdset; + fdset.reserve(fds.size()); + for (auto kv : fds) { + fdset.push_back(kv.second); + } + int ret = poll(fdset.data(), fdset.size(), timeout); + if (ret == -1) { + Socket::Error("Poll"); + } else { + for (auto& pfd : fdset) { + auto revents = pfd.revents & pfd.events; + if (!revents) { + fds.erase(pfd.fd); + } else { + fds[pfd.fd].events = revents; + } + } + } + } + + std::unordered_map fds; +}; + } // namespace common } // namespace tvm #endif // TVM_COMMON_SOCKET_H_ diff --git a/src/common/util.h b/src/common/util.h new file mode 100644 index 000000000000..93f32f48a2a6 --- /dev/null +++ b/src/common/util.h @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file util.h + * \brief Defines some common utility function.. + */ +#ifndef TVM_COMMON_UTIL_H_ +#define TVM_COMMON_UTIL_H_ + +#include +#ifndef _WIN32 +#include +#include +#endif +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace common { +/*! + * \brief TVMPOpen wrapper of popen between windows / unix. + * \param command executed command + * \param type "r" is for reading or "w" for writing. + * \return normal standard stream + */ +inline FILE* TVMPOpen(const char* command, const char* type) { +#if defined(_WIN32) + return _popen(command, type); +#else + return popen(command, type); +#endif +} + +/*! + * \brief TVMPClose wrapper of pclose between windows / linux + * \param stream the stream needed to be close. + * \return exit status + */ +inline int TVMPClose(FILE* stream) { +#if defined(_WIN32) + return _pclose(stream); +#else + return pclose(stream); +#endif +} + +/*! + * \brief TVMWifexited wrapper of WIFEXITED between windows / linux + * \param status The status field that was filled in by the wait or waitpid function + * \return the exit code of the child process + */ +inline int TVMWifexited(int status) { +#if defined(_WIN32) + return (status != 3); +#else + return WIFEXITED(status); +#endif +} + +/*! + * \brief TVMWexitstatus wrapper of WEXITSTATUS between windows / linux + * \param status The status field that was filled in by the wait or waitpid function. + * \return the child process exited normally or not + */ +inline int TVMWexitstatus(int status) { +#if defined(_WIN32) + return status; +#else + return WEXITSTATUS(status); +#endif +} + + +/*! + * \brief IsNumber check whether string is a number. + * \param str input string + * \return result of operation. + */ +inline bool IsNumber(const std::string& str) { + return !str.empty() && std::find_if(str.begin(), + str.end(), [](char c) { return !std::isdigit(c); }) == str.end(); +} + +/*! + * \brief split Split the string based on delimiter + * \param str Input string + * \param delim The delimiter. + * \return vector of strings which are splitted. + */ +inline std::vector Split(const std::string& str, char delim) { + std::string item; + std::istringstream is(str); + std::vector ret; + while (std::getline(is, item, delim)) { + ret.push_back(item); + } + return ret; +} + +/*! + * \brief EndsWith check whether the strings ends with + * \param value The full string + * \param end The end substring + * \return bool The result. + */ +inline bool EndsWith(std::string const& value, std::string const& end) { + if (end.size() <= value.size()) { + return std::equal(end.rbegin(), end.rend(), value.rbegin()); + } + return false; +} + +/*! + * \brief Execute the command + * \param cmd The command we want to execute + * \param err_msg The error message if we have + * \return executed output status + */ +inline int Execute(std::string cmd, std::string* err_msg) { + std::array buffer; + std::string result; + cmd += " 2>&1"; + FILE* fd = TVMPOpen(cmd.c_str(), "r"); + while (fgets(buffer.data(), buffer.size(), fd) != nullptr) { + *err_msg += buffer.data(); + } + int status = TVMPClose(fd); + if (TVMWifexited(status)) { + return TVMWexitstatus(status); + } + return 255; +} + +} // namespace common +} // namespace tvm +#endif // TVM_COMMON_UTIL_H_ diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index d982f68bcb6e..3518455c83d1 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -36,8 +36,27 @@ namespace tvm { namespace runtime { +// Magic header for RPC data plane const int kRPCMagic = 0xff271; +// magic header for RPC tracker(control plane) +const int kRPCTrackerMagic = 0x2f271; +// sucess response +const int kRPCSuccess = kRPCMagic + 0; +// cannot found matched key in server +const int kRPCMismatch = kRPCMagic + 2; +/*! \brief Enumeration code for the RPC tracker */ +enum class TrackerCode : int { + kFail = -1, + kSuccess = 0, + kPing = 1, + kStop = 2, + kPut = 3, + kRequest = 4, + kUpdateInfo = 5, + kSummary = 6, + kGetPendingMatchKeys = 7 +}; /*! \brief The remote functio handle */ using RPCFuncHandle = void*; diff --git a/src/runtime/rpc/rpc_socket_impl.h b/src/runtime/rpc/rpc_socket_impl.h new file mode 100644 index 000000000000..ea7c8394bff8 --- /dev/null +++ b/src/runtime/rpc/rpc_socket_impl.h @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file rpc_socket_impl.h + * \brief Socket based RPC implementation. + */ +#ifndef TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_ +#define TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_ + +namespace tvm { +namespace runtime { + +/*! + * \brief RPCServerLoop Start the rpc server loop. + * \param sockfd Socket file descriptor + */ +void RPCServerLoop(int sockfd); + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_SOCKET_IMPL_H_