diff --git a/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml b/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml
index f5b6c1b1b9e0..404f935a1f3e 100644
--- a/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml
+++ b/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml
@@ -6,7 +6,7 @@
false
- ../../../lib/libtvm_runtime.so
+ ../../../lib/libtvm_runtime.dylib
lib/native
0644
diff --git a/jvm/core/pom.xml b/jvm/core/pom.xml
index 62d6a866ae5d..8aed78525189 100644
--- a/jvm/core/pom.xml
+++ b/jvm/core/pom.xml
@@ -20,18 +20,21 @@
osx-x86_64-cpu
osx-x86_64-cpu
+ libtvm_runtime.dylib
linux-x86_64-cpu
linux-x86_64-cpu
+ libtvm_runtime.so
linux-x86_64-gpu
linux-x86_64-gpu
+ libtvm_runtime.so
@@ -88,7 +91,7 @@
1
-Djava.library.path=${project.parent.basedir}/native/${platform}/target
- -Dlibtvm.so.path=${project.parent.basedir}/../lib/libtvm_runtime.so
+ -Dlibtvm.so.path=${project.parent.basedir}/../lib/${libtvm.so.filename}
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/Base.java b/jvm/core/src/main/java/ml/dmlc/tvm/Base.java
index ec88325b1a8b..09d1efe42097 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/Base.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/Base.java
@@ -80,7 +80,18 @@ public RefTVMValue() {
if (tvmLibFilename == null || !new File(tvmLibFilename).isFile()
|| _LIB.nativeLibInit(tvmLibFilename) != 0) {
try {
- NativeLibraryLoader.extractResourceFileToTempDir("libtvm_runtime.so", new Action() {
+ String runtimeLibname;
+ String os = System.getProperty("os.name");
+ // ref: http://lopica.sourceforge.net/os.html
+ if (os.startsWith("Linux")) {
+ runtimeLibname = "libtvm_runtime.so";
+ } else if (os.startsWith("Mac")) {
+ runtimeLibname = "libtvm_runtime.dylib";
+ } else {
+ // TODO(yizhi) support windows later
+ throw new UnsatisfiedLinkError("Windows not supported currently");
+ }
+ NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() {
@Override public void invoke(File target) {
System.err.println("Loading tvm runtime from " + target.getPath());
checkCall(_LIB.nativeLibInit(target.getPath()));
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java
index 3f38a5aba155..63602f3a14d0 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java
@@ -34,7 +34,7 @@ public class Function extends TVMValue {
* @param name full function name.
* @return TVM function.
*/
- static Function getFunction(final String name) {
+ public static Function getFunction(final String name) {
for (String fullName : listGlobalFuncNames()) {
if (fullName.equals(name)) {
return getGlobalFunc(fullName, true, false);
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java b/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java
index c073416ae0c4..396f740b914f 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java
@@ -30,7 +30,7 @@ class NativeLibraryLoader {
static {
try {
- tempDir = File.createTempFile("tvm", "");
+ tempDir = File.createTempFile("tvm4j", "");
if (!tempDir.delete() || !tempDir.mkdir()) {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java b/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java
index e91b90cce083..0d108e0a2943 100644
--- a/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java
@@ -17,12 +17,12 @@
package ml.dmlc.tvm;
+import ml.dmlc.tvm.rpc.RPC;
+
import java.util.HashMap;
import java.util.Map;
public class TVMContext {
- private static final int RPC_SESS_MASK = 128;
-
private static final Map MASK2STR = new HashMap();
private static final Map STR2MASK = new HashMap();
@@ -169,9 +169,9 @@ public void sync() {
}
@Override public String toString() {
- if (deviceType >= RPC_SESS_MASK) {
- int tblId = deviceType / RPC_SESS_MASK - 1;
- int devType = deviceType % RPC_SESS_MASK;
+ if (deviceType >= RPC.RPC_SESS_MASK) {
+ int tblId = deviceType / RPC.RPC_SESS_MASK - 1;
+ int devType = deviceType % RPC.RPC_SESS_MASK;
return String.format("remote[%d]:%s(%d)", tblId, MASK2STR.get(devType), deviceId);
}
return String.format("%s(%d)", MASK2STR.get(deviceType), deviceId);
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java
new file mode 100644
index 000000000000..20292b7a6f82
--- /dev/null
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.tvm.rpc;
+
+import ml.dmlc.tvm.Function;
+import ml.dmlc.tvm.TVMValue;
+
+public class Client {
+ /**
+ * Connect to RPC Server.
+ * @param url The url of the host.
+ * @param port The port to connect to.
+ * @param key Additional key to match server.
+ * @return The connected session.
+ */
+ public static RPCSession connect(String url, int port, String key) {
+ Function doConnect = RPC.getApi("_Connect");
+ if (doConnect == null) {
+ throw new RuntimeException("Please compile with USE_RPC=1");
+ }
+ TVMValue sess = doConnect.pushArg(url).pushArg(port).pushArg(key).invoke();
+ return new RPCSession(sess.asModule());
+ }
+
+ /**
+ * Connect to RPC Server.
+ * @param url The url of the host.
+ * @param port The port to connect to.
+ * @return The connected session.
+ */
+ public static RPCSession connect(String url, int port) {
+ return connect(url, port, "");
+ }
+}
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java
new file mode 100644
index 000000000000..e3b8b9366751
--- /dev/null
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.tvm.rpc;
+
+import ml.dmlc.tvm.Function;
+
+import java.util.HashMap;
+import java.util.Map;
+
+public class RPC {
+ public static final int RPC_MAGIC = 0xff271;
+ public static final int RPC_SESS_MASK = 128;
+
+ private static ThreadLocal> apiFuncs
+ = new ThreadLocal>() {
+ @Override
+ protected Map initialValue() {
+ return new HashMap();
+ }
+ };
+
+ static Function getApi(String name) {
+ Function func = apiFuncs.get().get(name);
+ if (func == null) {
+ func = Function.getFunction("contrib.rpc." + name);
+ if (func == null) {
+ return null;
+ }
+ apiFuncs.get().put(name, func);
+ }
+ return func;
+ }
+}
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
new file mode 100644
index 000000000000..cb4ccf49434b
--- /dev/null
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.tvm.rpc;
+
+import ml.dmlc.tvm.Function;
+import ml.dmlc.tvm.Module;
+import ml.dmlc.tvm.TVMContext;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * RPC Client session module.
+ * Do not directly create the object, use Client.connect.
+ */
+public class RPCSession {
+ private final Module session;
+ private final int tblIndex;
+ private final Map remoteFuncs = new HashMap();
+
+ RPCSession(Module sess) {
+ session = sess;
+ tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong();
+ }
+
+ /**
+ * Get function from the session.
+ * @param name The name of the function.
+ * @return The result function.
+ */
+ public Function getFunction(String name) {
+ return session.getFunction(name);
+ }
+
+ /**
+ * Construct a remote context.
+ * @param devType device type.
+ * @param devId device id.
+ * @return The corresponding encoded remote context.
+ */
+ public TVMContext context(String devType, int devId) {
+ TVMContext ctx = new TVMContext(devType, devId);
+ int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK;
+ return new TVMContext(ctx.deviceType + encode, devId);
+ }
+
+ /**
+ * Construct a remote context.
+ * @param devType device type.
+ * @return The corresponding encoded remote context.
+ */
+ public TVMContext context(String devType) {
+ return context(devType, 0);
+ }
+
+ /**
+ * Construct a remote context.
+ * @param devType device type.
+ * @param devId device id.
+ * @return The corresponding encoded remote context.
+ */
+ public TVMContext context(int devType, int devId) {
+ int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK;
+ return new TVMContext(devType + encode, devId);
+ }
+
+ /**
+ * Construct a remote context.
+ * @param devType device type.
+ * @return The corresponding encoded remote context.
+ */
+ public TVMContext context(int devType) {
+ return context(devType, 0);
+ }
+
+ /**
+ * Construct remote CPU device.
+ * @param devId device id.
+ * @return Remote CPU context.
+ */
+ public TVMContext cpu(int devId) {
+ return context(1, devId);
+ }
+
+ /**
+ * Construct remote CPU device.
+ * @return Remote CPU context.
+ */
+ public TVMContext cpu() {
+ return cpu(0);
+ }
+
+ /**
+ * Construct remote GPU device.
+ * @param devId device id.
+ * @return Remote GPU context.
+ */
+ public TVMContext gpu(int devId) {
+ return context(2, devId);
+ }
+
+ /**
+ * Construct remote GPU device.
+ * @return Remote GPU context.
+ */
+ public TVMContext gpu() {
+ return gpu(0);
+ }
+
+ /**
+ * Construct remote OpenCL device.
+ * @param devId device id.
+ * @return Remote OpenCL context.
+ */
+ public TVMContext cl(int devId) {
+ return context(4, devId);
+ }
+
+ /**
+ * Construct remote OpenCL device.
+ * @return Remote OpenCL context.
+ */
+ public TVMContext cl() {
+ return cl(0);
+ }
+
+ /**
+ * Construct remote Metal device.
+ * @param devId device id.
+ * @return Remote metal context.
+ */
+ public TVMContext metal(int devId) {
+ return context(8, devId);
+ }
+
+ /**
+ * Construct remote Metal device.
+ * @return Remote metal context.
+ */
+ public TVMContext metal() {
+ return metal(0);
+ }
+
+ /**
+ * Upload binary to remote runtime temp folder.
+ * @param data The binary in local to upload.
+ * @param target The path in remote, cannot be null.
+ */
+ public void upload(byte[] data, String target) {
+ if (target == null) {
+ throw new IllegalArgumentException("Please specify the upload target");
+ }
+ final String funcName = "upload";
+ Function remoteFunc = remoteFuncs.get(funcName);
+ if (remoteFunc == null) {
+ remoteFunc = getFunction("tvm.contrib.rpc.server.upload");
+ remoteFuncs.put(funcName, remoteFunc);
+ }
+ remoteFunc.pushArg(target).pushArg(data).invoke();
+ }
+
+ /**
+ * Upload file to remote runtime temp folder.
+ * @param data The file in local to upload.
+ * @param target The path in remote.
+ */
+ public void upload(File data, String target) throws IOException {
+ byte[] blob = getBytesFromFile(data);
+ upload(blob, target);
+ }
+
+ /**
+ * Upload file to remote runtime temp folder.
+ * @param data The file in local to upload.
+ */
+ public void upload(File data) throws IOException {
+ upload(data, data.getName());
+ }
+
+ /**
+ * Download file from remote temp folder.
+ * @param path The relative location to remote temp folder.
+ * @return The result blob from the file.
+ */
+ public byte[] download(String path) {
+ final String name = "download";
+ Function func = remoteFuncs.get(name);
+ if (func == null) {
+ func = getFunction("tvm.contrib.rpc.server.download");
+ remoteFuncs.put(name, func);
+ }
+ return func.pushArg(path).invoke().asBytes();
+ }
+
+ /**
+ * Load a remote module, the file need to be uploaded first.
+ * @param path The relative location to remote temp folder.
+ * @return The remote module containing remote function.
+ */
+ public Module loadModule(String path) {
+ return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule();
+ }
+
+
+ private static byte[] getBytesFromFile(File file) throws IOException {
+ // Get the size of the file
+ long length = file.length();
+
+ if (length > Integer.MAX_VALUE) {
+ throw new IOException("File " + file.getName() + " is too large!");
+ }
+
+ // cannot create an array using a long type.
+ byte[] bytes = new byte[(int)length];
+
+ // Read in the bytes
+ int offset = 0;
+ int numRead = 0;
+
+ InputStream is = new FileInputStream(file);
+ try {
+ while (offset < bytes.length
+ && (numRead = is.read(bytes, offset, bytes.length - offset)) >= 0) {
+ offset += numRead;
+ }
+ } finally {
+ is.close();
+ }
+
+ // Ensure all the bytes have been read in
+ if (offset < bytes.length) {
+ throw new IOException("Could not completely read file " + file.getName());
+ }
+ return bytes;
+ }
+}
diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java
new file mode 100644
index 000000000000..9be1859bb46e
--- /dev/null
+++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java
@@ -0,0 +1,361 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.tvm.rpc;
+
+import ml.dmlc.tvm.Function;
+import ml.dmlc.tvm.Module;
+import ml.dmlc.tvm.TVMValue;
+import sun.misc.SharedSecrets;
+
+import java.io.File;
+import java.io.FileDescriptor;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.net.SocketException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+
+/**
+ * RPC Server.
+ */
+public class Server {
+ private static SocketFileDescriptorGetter defaultSocketFdGetter
+ = new SocketFileDescriptorGetter() {
+ @Override public int get(Socket socket) {
+ try {
+ InputStream is = socket.getInputStream();
+ FileDescriptor fd = ((FileInputStream) is).getFD();
+ return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd);
+ } catch (IOException e) {
+ e.printStackTrace();
+ return -1;
+ }
+ }
+ };
+ private static final int DEFAULT_THREAD_NUMBER_IN_A_POOL = 20;
+
+ private final Loop serverLoop;
+ private final ExecutorService threadPool;
+
+ /**
+ * Start a standalone server.
+ * @param serverPort Port.
+ * @param socketFdGetter Method to get system file descriptor of the server socket.
+ * @throws IOException if failed to bind localhost:port.
+ */
+ public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException {
+ threadPool = setupThreadPool();
+ serverLoop = new ListenLoop(serverPort, threadPool, socketFdGetter);
+ }
+
+ /**
+ * Start a standalone server.
+ * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess
+ * to get file descriptor for the socket.
+ * @param serverPort Port.
+ * @throws IOException if failed to bind localhost:port.
+ */
+ public Server(int serverPort) throws IOException {
+ this(serverPort, defaultSocketFdGetter);
+ }
+
+ /**
+ * Start a server connected to proxy.
+ * @param proxyHost The proxy server host.
+ * @param proxyPort The proxy server port.
+ * @param key The key to identify the server.
+ * @param socketFdGetter Method to get system file descriptor of the server socket.
+ */
+ public Server(String proxyHost, int proxyPort, String key,
+ SocketFileDescriptorGetter socketFdGetter) {
+ threadPool = setupThreadPool();
+ serverLoop = new ConnectProxyLoop(proxyHost, proxyPort, key, threadPool, socketFdGetter);
+ }
+
+ /**
+ * Start a server connected to proxy.
+ * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess
+ * to get file descriptor for the socket.
+ * @param proxyHost The proxy server host.
+ * @param proxyPort The proxy server port.
+ * @param key The key to identify the server.
+ */
+ public Server(String proxyHost, int proxyPort, String key) {
+ this(proxyHost, proxyPort, key, defaultSocketFdGetter);
+ }
+
+ private ExecutorService setupThreadPool() {
+ final String workerThreadNumber = System.getProperty("rpc.server.thread.number");
+ final int numThread = (workerThreadNumber == null)
+ ? DEFAULT_THREAD_NUMBER_IN_A_POOL : Integer.parseInt(workerThreadNumber);
+ return Executors.newFixedThreadPool(numThread);
+ }
+
+ /**
+ * Start the server.
+ */
+ public void start() {
+ serverLoop.start();
+ }
+
+ /**
+ * Stop the server.
+ */
+ public void terminate() {
+ serverLoop.interrupt();
+ serverLoop.terminate();
+ threadPool.shutdown();
+ }
+
+ public static interface SocketFileDescriptorGetter {
+ public int get(Socket socket);
+ }
+
+ static class ServerLoop implements Runnable {
+ private final Socket socket;
+ private final SocketFileDescriptorGetter socketFdGetter;
+
+ ServerLoop(Socket socket, SocketFileDescriptorGetter fdGetter) {
+ this.socket = socket;
+ socketFdGetter = fdGetter;
+ }
+
+ @Override public void run() {
+ int sockFd = socketFdGetter.get(socket);
+ if (sockFd != -1) {
+ File tempDir = null;
+ try {
+ tempDir = serverEnv();
+ RPC.getApi("_ServerLoop").pushArg(sockFd).invoke();
+ System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString());
+ } catch (IOException e) {
+ e.printStackTrace();
+ } finally {
+ if (tempDir != null) {
+ if (!tempDir.delete()) {
+ System.err.println(
+ "[WARN] Couldn't delete temporary directory " + tempDir.getAbsolutePath());
+ }
+ }
+ closeQuietly(socket);
+ }
+ }
+ }
+
+ private File serverEnv() throws IOException {
+ // Server environment function return temp dir.
+ final File tempDir = File.createTempFile("tvm4j_rpc_", "");
+ if (!tempDir.delete() || !tempDir.mkdir()) {
+ throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
+ }
+
+ Function.register("tvm.contrib.rpc.server.workpath", new Function.Callback() {
+ @Override public Object invoke(TVMValue... args) {
+ return tempDir + File.separator + args[0].asString();
+ }
+ }, true);
+
+ Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() {
+ @Override public Object invoke(TVMValue... args) {
+ String filename = args[0].asString();
+ String path = tempDir + File.separator + filename;
+ System.err.println("Load module from " + path);
+ return Module.load(path);
+ }
+ }, true);
+
+ return tempDir;
+ }
+ }
+
+ abstract static class Loop extends Thread {
+ public abstract void terminate();
+ }
+
+ static class ConnectProxyLoop extends Loop {
+ private volatile boolean running = true;
+ private final String host;
+ private final int port;
+ private final String key;
+ private final ExecutorService workerPool;
+ private final SocketFileDescriptorGetter socketFileDescriptorGetter;
+ private Socket waitingSocket = null;
+
+ public ConnectProxyLoop(String host, int port, String key,
+ ExecutorService workerPool,
+ SocketFileDescriptorGetter sockFdGetter) {
+ this.host = host;
+ this.port = port;
+ this.key = "server:" + key;
+ this.workerPool = workerPool;
+ socketFileDescriptorGetter = sockFdGetter;
+ }
+
+ @Override public void terminate() {
+ running = false;
+ if (waitingSocket != null) {
+ try {
+ waitingSocket.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+ }
+
+ @Override public void run() {
+ while (running) {
+ try {
+ Socket socket = new Socket(host, port);
+ waitingSocket = socket;
+ InputStream in = socket.getInputStream();
+ OutputStream out = socket.getOutputStream();
+ out.write(toBytes(RPC.RPC_MAGIC));
+ out.write(toBytes(key.length()));
+ out.write(toBytes(key));
+ int magic = wrapBytes(recvAll(in, 4)).getInt();
+ final String address = host + ":" + port;
+ if (magic == RPC.RPC_MAGIC + 1) {
+ throw new RuntimeException(
+ String.format("key: %s has already been used in proxy", key));
+ } else if (magic == RPC.RPC_MAGIC + 2) {
+ System.err.println("RPCProxy do not have matching client key " + key);
+ } else if (magic != RPC.RPC_MAGIC) {
+ throw new RuntimeException(address + " is not RPC Proxy");
+ }
+ System.err.println("RPCProxy connected to " + address);
+
+ waitingSocket = null;
+ workerPool.execute(new ServerLoop(socket, socketFileDescriptorGetter));
+ } catch (SocketException e) {
+ // when terminates, this is what we expect, do nothing.
+ } catch (IOException e) {
+ e.printStackTrace();
+ terminate();
+ }
+ }
+ }
+ }
+
+ static class ListenLoop extends Loop {
+ private final ServerSocket server;
+ private final ExecutorService workerPool;
+ private final SocketFileDescriptorGetter socketFileDescriptorGetter;
+ private volatile boolean running = true;
+
+ public ListenLoop(int serverPort, ExecutorService workerPool,
+ SocketFileDescriptorGetter sockFdGetter) throws IOException {
+ this.server = new ServerSocket(serverPort);
+ this.workerPool = workerPool;
+ this.socketFileDescriptorGetter = sockFdGetter;
+ }
+
+ @Override public void terminate() {
+ this.running = false;
+ try {
+ server.close();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ @Override public void run() {
+ while (running) {
+ try {
+ Socket socket = server.accept();
+ InputStream in = socket.getInputStream();
+ OutputStream out = socket.getOutputStream();
+ int magic = wrapBytes(recvAll(in, 4)).getInt();
+ if (magic != RPC.RPC_MAGIC) {
+ closeQuietly(socket);
+ continue;
+ }
+ int keyLen = wrapBytes(recvAll(in, 4)).getInt();
+ String key = decodeToStr(recvAll(in, keyLen));
+ if (!key.startsWith("client:")) {
+ out.write(toBytes(RPC.RPC_MAGIC + 2));
+ } else {
+ out.write(toBytes(RPC.RPC_MAGIC));
+ }
+ System.err.println("Connection from " + socket.getRemoteSocketAddress().toString());
+ workerPool.execute(new ServerLoop(socket, socketFileDescriptorGetter));
+ } catch (SocketException e) {
+ // when terminates, this is what we expect, do nothing.
+ } catch (IOException e) {
+ e.printStackTrace();
+ terminate();
+ }
+ }
+ }
+ }
+
+ private static byte[] recvAll(final InputStream in, final int numBytes) throws IOException {
+ byte[] res = new byte[numBytes];
+ int numRead = 0;
+ while (numRead < numBytes) {
+ int chunk = in.read(res, numRead, Math.min(numBytes - numRead, 1024));
+ numRead += chunk;
+ }
+ return res;
+ }
+
+ private static void closeQuietly(Socket socket) {
+ if (socket != null) {
+ try {
+ socket.shutdownInput();
+ socket.shutdownOutput();
+ socket.close();
+ } catch (IOException ioe) {
+ // close quietly, do nothing.
+ }
+ }
+ }
+
+ private static ByteBuffer wrapBytes(byte[] bytes) {
+ ByteBuffer bb = ByteBuffer.wrap(bytes);
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ return bb;
+ }
+
+ private static byte[] toBytes(int number) {
+ ByteBuffer bb = ByteBuffer.allocate(4);
+ bb.order(ByteOrder.LITTLE_ENDIAN);
+ return bb.putInt(number).array();
+ }
+
+ private static byte[] toBytes(String str) {
+ byte[] bytes = new byte[str.length()];
+ for (int i = 0; i < str.length(); ++i) {
+ bytes[i] = (byte) str.charAt(i);
+ }
+ return bytes;
+ }
+
+ private static String decodeToStr(byte[] bytes) {
+ StringBuilder builder = new StringBuilder();
+ for (byte bt : bytes) {
+ builder.append((char) bt);
+ }
+ return builder.toString();
+ }
+}
diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java
new file mode 100644
index 000000000000..f4dd9eb9cb4e
--- /dev/null
+++ b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.tvm.rpc;
+
+import ml.dmlc.tvm.Function;
+import ml.dmlc.tvm.Module;
+import ml.dmlc.tvm.TVMValue;
+import org.junit.Test;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+
+public class RPCTest {
+ static class RefInt {
+ public int value;
+ }
+
+ private static Server startServer(RefInt portRef) {
+ Server server = null;
+ int port = 9981;
+ for (int i = 0; i < 10; ++i) {
+ try {
+ server = new Server(port + i);
+ server.start();
+ portRef.value = port + i;
+ return server;
+ } catch (IOException e) {
+ }
+ }
+ throw new RuntimeException("Cannot find an available port.");
+ }
+
+ @Test
+ public void test_addone() {
+ if (!Module.enabled("rpc")) {
+ return;
+ }
+ Function.register("test.rpc.addone", new Function.Callback() {
+ @Override public Object invoke(TVMValue... args) {
+ return args[0].asLong() + 1L;
+ }
+ });
+
+ RefInt port = new RefInt();
+ Server server = null;
+ try {
+ server = startServer(port);
+ RPCSession client = Client.connect("localhost", port.value);
+ Function func = client.getFunction("test.rpc.addone");
+ assertEquals(11L, func.call(10).asLong());
+ } finally {
+ if (server != null) {
+ server.terminate();
+ }
+ }
+ }
+
+ @Test
+ public void test_strcat() {
+ if (!Module.enabled("rpc")) {
+ return;
+ }
+ Function.register("test.rpc.strcat", new Function.Callback() {
+ @Override public Object invoke(TVMValue... args) {
+ return args[0].asString() + ":" + args[1].asLong();
+ }
+ });
+
+ RefInt port = new RefInt();
+ Server server = null;
+ try {
+ server = startServer(port);
+ RPCSession client = Client.connect("localhost", port.value);
+ Function func = client.getFunction("test.rpc.strcat");
+ assertEquals("abc:11", func.call("abc", 11L).asString());
+ } finally {
+ if (server != null) {
+ server.terminate();
+ }
+ }
+ }
+
+ @Test
+ public void test_connect_proxy_server() {
+ String proxyHost = System.getProperty("test.rpc.proxy.host");
+ int proxyPort = Integer.parseInt(System.getProperty("test.rpc.proxy.port"));
+
+ Function.register("test.rpc.proxy.addone", new Function.Callback() {
+ @Override public Object invoke(TVMValue... tvmValues) {
+ return tvmValues[0].asLong() + 1L;
+ }
+ });
+
+ Server server = null;
+ try {
+ server = new Server(proxyHost, proxyPort, "x1");
+ server.start();
+
+ RPCSession client = Client.connect(proxyHost, proxyPort, "x1");
+ Function f1 = client.getFunction("test.rpc.proxy.addone");
+ assertEquals(11L, f1.call(10L).asLong());
+ } finally {
+ if (server != null) {
+ server.terminate();
+ }
+ }
+ }
+}
diff --git a/jvm/core/src/test/scripts/test_rpc_proxy_server.py b/jvm/core/src/test/scripts/test_rpc_proxy_server.py
new file mode 100644
index 000000000000..3f1f6466c715
--- /dev/null
+++ b/jvm/core/src/test/scripts/test_rpc_proxy_server.py
@@ -0,0 +1,20 @@
+import time
+from tvm.contrib import rpc_proxy
+
+def start_proxy_server(port, timeout):
+ prox = rpc_proxy.Proxy("localhost", port=port, port_end=port+1)
+ if timeout > 0:
+ import time
+ time.sleep(timeout)
+ prox.terminate()
+ else:
+ prox.proc.join()
+
+if __name__ == "__main__":
+ import sys
+ if len(sys.argv) < 2:
+ sys.exit(-1)
+ port = int(sys.argv[1])
+ timeout = 0 if len(sys.argv) == 2 else float(sys.argv[2])
+ start_proxy_server(port, timeout)
+
diff --git a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc
index 2162da296ddc..32e4b7ca7da4 100644
--- a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc
+++ b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc
@@ -158,27 +158,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCall(
TVMValue retVal;
int retTypeCode;
+
+ // function can be invoked recursively,
+ // thus we copy the pushed arguments here.
+ auto argValues = e->tvmFuncArgValues;
+ auto argTypes = e->tvmFuncArgTypes;
+ auto pushedStrs = e->tvmFuncArgPushedStrs;
+ auto pushedBytes = e->tvmFuncArgPushedBytes;
+
+ e->tvmFuncArgPushedStrs.clear();
+ e->tvmFuncArgPushedBytes.clear();
+ e->tvmFuncArgTypes.clear();
+ e->tvmFuncArgValues.clear();
+
int ret = TVMFuncCall(reinterpret_cast(jhandle),
- &e->tvmFuncArgValues[0], &e->tvmFuncArgTypes[0], numArgs, &retVal, &retTypeCode);
+ &argValues[0], &argTypes[0], numArgs, &retVal, &retTypeCode);
- for (auto iter = e->tvmFuncArgPushedStrs.cbegin();
- iter != e->tvmFuncArgPushedStrs.cend(); iter++) {
+ for (auto iter = pushedStrs.cbegin(); iter != pushedStrs.cend(); iter++) {
env->ReleaseStringUTFChars(iter->first, iter->second);
env->DeleteGlobalRef(iter->first);
}
- for (auto iter = e->tvmFuncArgPushedBytes.cbegin();
- iter != e->tvmFuncArgPushedBytes.cend(); iter++) {
+ for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) {
env->ReleaseByteArrayElements(iter->first,
reinterpret_cast(const_cast(iter->second->data)), 0);
env->DeleteGlobalRef(iter->first);
delete iter->second;
}
- e->tvmFuncArgPushedStrs.clear();
- e->tvmFuncArgPushedBytes.clear();
- e->tvmFuncArgTypes.clear();
- e->tvmFuncArgValues.clear();
-
// return TVMValue object to Java
jclass refTVMValueCls = env->FindClass("ml/dmlc/tvm/Base$RefTVMValue");
jfieldID refTVMValueFid
diff --git a/tests/scripts/task_java_unittest.sh b/tests/scripts/task_java_unittest.sh
index bb1f98247f36..9017079237f1 100755
--- a/tests/scripts/task_java_unittest.sh
+++ b/tests/scripts/task_java_unittest.sh
@@ -10,7 +10,14 @@ TEMP_DIR=$(mktemp -d)
python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1
python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1
+# start rpc proxy server
+PORT=$(( ( RANDOM % 1000 ) + 9000 ))
+python $SCRIPT_DIR/test_rpc_proxy_server.py $PORT 30 &
+
make jvmpkg || exit -1
-make jvmpkg JVM_TEST_ARGS="-DskipTests=false -Dtest.tempdir=$TEMP_DIR" || exit -1
+make jvmpkg JVM_TEST_ARGS="-DskipTests=false \
+ -Dtest.tempdir=$TEMP_DIR \
+ -Dtest.rpc.proxy.host=localhost \
+ -Dtest.rpc.proxy.port=$PORT" || exit -1
rm -rf $TEMP_DIR