From 824037de2fcfb1994e541e53bc28556d3aa2abf0 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Tue, 25 Jul 2017 01:11:32 +0800 Subject: [PATCH 1/3] [tvm4j] RPC Server --- .../src/main/java/ml/dmlc/tvm/Function.java | 2 +- .../java/ml/dmlc/tvm/NativeLibraryLoader.java | 2 +- .../src/main/java/ml/dmlc/tvm/TVMContext.java | 10 +- .../src/main/java/ml/dmlc/tvm/rpc/Client.java | 32 +++ .../src/main/java/ml/dmlc/tvm/rpc/RPC.java | 31 +++ .../main/java/ml/dmlc/tvm/rpc/RPCSession.java | 238 ++++++++++++++++++ .../src/main/java/ml/dmlc/tvm/rpc/Server.java | 207 +++++++++++++++ 7 files changed, 515 insertions(+), 7 deletions(-) create mode 100644 jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java create mode 100644 jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java create mode 100644 jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java create mode 100644 jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java index 3f38a5aba155..63602f3a14d0 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java @@ -34,7 +34,7 @@ public class Function extends TVMValue { * @param name full function name. * @return TVM function. */ - static Function getFunction(final String name) { + public static Function getFunction(final String name) { for (String fullName : listGlobalFuncNames()) { if (fullName.equals(name)) { return getGlobalFunc(fullName, true, false); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java b/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java index c073416ae0c4..396f740b914f 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java @@ -30,7 +30,7 @@ class NativeLibraryLoader { static { try { - tempDir = File.createTempFile("tvm", ""); + tempDir = File.createTempFile("tvm4j", ""); if (!tempDir.delete() || !tempDir.mkdir()) { throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath()); } diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java b/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java index e91b90cce083..0d108e0a2943 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java @@ -17,12 +17,12 @@ package ml.dmlc.tvm; +import ml.dmlc.tvm.rpc.RPC; + import java.util.HashMap; import java.util.Map; public class TVMContext { - private static final int RPC_SESS_MASK = 128; - private static final Map MASK2STR = new HashMap(); private static final Map STR2MASK = new HashMap(); @@ -169,9 +169,9 @@ public void sync() { } @Override public String toString() { - if (deviceType >= RPC_SESS_MASK) { - int tblId = deviceType / RPC_SESS_MASK - 1; - int devType = deviceType % RPC_SESS_MASK; + if (deviceType >= RPC.RPC_SESS_MASK) { + int tblId = deviceType / RPC.RPC_SESS_MASK - 1; + int devType = deviceType % RPC.RPC_SESS_MASK; return String.format("remote[%d]:%s(%d)", tblId, MASK2STR.get(devType), deviceId); } return String.format("%s(%d)", MASK2STR.get(deviceType), deviceId); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java new file mode 100644 index 000000000000..db4bcc50e52d --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java @@ -0,0 +1,32 @@ +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.TVMValue; + +public class Client { + /** + * Connect to RPC Server. + * @param url The url of the host. + * @param port The port to connect to. + * @param key Additional key to match server. + * @return The connected session. + */ + public static RPCSession connect(String url, int port, String key) { + Function doConnect = RPC.getApi("_Connect"); + if (doConnect == null) { + throw new RuntimeException("Please compile with USE_RPC=1"); + } + TVMValue sess = doConnect.pushArg(url).pushArg(port).pushArg(key).invoke(); + return new RPCSession(sess.asModule()); + } + + /** + * Connect to RPC Server. + * @param url The url of the host. + * @param port The port to connect to. + * @return The connected session. + */ + public static RPCSession connect(String url, int port) { + return connect(url, port, ""); + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java new file mode 100644 index 000000000000..1de6239614b7 --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java @@ -0,0 +1,31 @@ +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; + +import java.util.HashMap; +import java.util.Map; + +public class RPC { + public static final int RPC_MAGIC = 0xff271; + public static final int RPC_SESS_MASK = 128; + + private static ThreadLocal> apiFuncs + = new ThreadLocal>() { + @Override + protected Map initialValue() { + return new HashMap(); + } + }; + + static Function getApi(String name) { + Function func = apiFuncs.get().get(name); + if (func == null) { + func = Function.getFunction("contrib.rpc." + name); + if (func == null) { + return null; + } + apiFuncs.get().put(name, func); + } + return func; + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java new file mode 100644 index 000000000000..3ef6b47a50ba --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java @@ -0,0 +1,238 @@ +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.TVMContext; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.HashMap; +import java.util.Map; + +/** + * RPC Client session module. + * Do not directly create the object, use Client.connect. + */ +public class RPCSession { + private final Module session; + private final int tblIndex; + private final Map remoteFuncs = new HashMap(); + + RPCSession(Module sess) { + session = sess; + tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong(); + } + + /** + * Get function from the session. + * @param name The name of the function. + * @return The result function. + */ + public Function getFunction(String name) { + return session.getFunction(name); + } + + /** + * Construct a remote context. + * @param devType device type. + * @param devId device id. + * @return The corresponding encoded remote context. + */ + public TVMContext context(String devType, int devId) { + TVMContext ctx = new TVMContext(devType, devId); + int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK; + return new TVMContext(ctx.deviceType + encode, devId); + } + + /** + * Construct a remote context. + * @param devType device type. + * @return The corresponding encoded remote context. + */ + public TVMContext context(String devType) { + return context(devType, 0); + } + + /** + * Construct a remote context. + * @param devType device type. + * @param devId device id. + * @return The corresponding encoded remote context. + */ + public TVMContext context(int devType, int devId) { + int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK; + return new TVMContext(devType + encode, devId); + } + + /** + * Construct a remote context. + * @param devType device type. + * @return The corresponding encoded remote context. + */ + public TVMContext context(int devType) { + return context(devType, 0); + } + + /** + * Construct remote CPU device. + * @param devId device id. + * @return Remote CPU context. + */ + public TVMContext cpu(int devId) { + return context(1, devId); + } + + /** + * Construct remote CPU device. + * @return Remote CPU context. + */ + public TVMContext cpu() { + return cpu(0); + } + + /** + * Construct remote GPU device. + * @param devId device id. + * @return Remote GPU context. + */ + public TVMContext gpu(int devId) { + return context(2, devId); + } + + /** + * Construct remote GPU device. + * @return Remote GPU context. + */ + public TVMContext gpu() { + return gpu(0); + } + + /** + * Construct remote OpenCL device. + * @param devId device id. + * @return Remote OpenCL context. + */ + public TVMContext cl(int devId) { + return context(4, devId); + } + + /** + * Construct remote OpenCL device. + * @return Remote OpenCL context. + */ + public TVMContext cl() { + return cl(0); + } + + /** + * Construct remote Metal device. + * @param devId device id. + * @return Remote metal context. + */ + public TVMContext metal(int devId) { + return context(8, devId); + } + + /** + * Construct remote Metal device. + * @return Remote metal context. + */ + public TVMContext metal() { + return metal(0); + } + + /** + * Upload binary to remote runtime temp folder. + * @param data The binary in local to upload. + * @param target The path in remote, cannot be null. + */ + public void upload(byte[] data, String target) { + if (target == null) { + throw new IllegalArgumentException("Please specify the upload target"); + } + final String funcName = "upload"; + Function remoteFunc = remoteFuncs.get(funcName); + if (remoteFunc == null) { + remoteFunc = getFunction("tvm.contrib.rpc.server.upload"); + remoteFuncs.put(funcName, remoteFunc); + } + remoteFunc.pushArg(target).pushArg(data).invoke(); + } + + /** + * Upload file to remote runtime temp folder. + * @param data The file in local to upload. + * @param target The path in remote. + */ + public void upload(File data, String target) throws IOException { + byte[] blob = getBytesFromFile(data); + upload(blob, target); + } + + /** + * Upload file to remote runtime temp folder. + * @param data The file in local to upload. + */ + public void upload(File data) throws IOException { + upload(data, data.getName()); + } + + /** + * Download file from remote temp folder. + * @param path The relative location to remote temp folder. + * @return The result blob from the file. + */ + public byte[] download(String path) { + final String name = "download"; + Function func = remoteFuncs.get(name); + if (func == null) { + func = getFunction("tvm.contrib.rpc.server.download"); + remoteFuncs.put(name, func); + } + return func.pushArg(path).invoke().asBytes(); + } + + /** + * Load a remote module, the file need to be uploaded first. + * @param path The relative location to remote temp folder. + * @return The remote module containing remote function. + */ + public Module loadModule(String path) { + return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); + } + + + private static byte[] getBytesFromFile(File file) throws IOException { + // Get the size of the file + long length = file.length(); + + if (length > Integer.MAX_VALUE) { + throw new IOException("File " + file.getName() + " is too large!"); + } + + // cannot create an array using a long type. + byte[] bytes = new byte[(int)length]; + + // Read in the bytes + int offset = 0; + int numRead = 0; + + InputStream is = new FileInputStream(file); + try { + while (offset < bytes.length + && (numRead=is.read(bytes, offset, bytes.length-offset)) >= 0) { + offset += numRead; + } + } finally { + is.close(); + } + + // Ensure all the bytes have been read in + if (offset < bytes.length) { + throw new IOException("Could not completely read file " + file.getName()); + } + return bytes; + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java new file mode 100644 index 000000000000..06eb5bdbc44d --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java @@ -0,0 +1,207 @@ +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.TVMValue; +import sun.misc.SharedSecrets; + +import java.io.File; +import java.io.FileDescriptor; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +public class Server extends ServerSocket { + private final ConnectionThread connectionThread; + private SocketFileDescriptorGetter socketFdGetter = new SocketFileDescriptorGetter() { + @Override public int get(Socket socket) { + try { + InputStream is = socket.getInputStream(); + FileDescriptor fd = ((FileInputStream) is).getFD(); + return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd); + } catch (IOException e) { + e.printStackTrace(); + return -1; + } + } + }; + + public Server(int serverPort) throws IOException { + super(serverPort); + connectionThread = new ConnectionThread(this, socketFdGetter); + } + + public void start() { + connectionThread.start(); + } + + public void terminate() { + connectionThread.terminate(); + } + + public void registerFilDescriptorGetter(SocketFileDescriptorGetter getter) { + socketFdGetter = getter; + } + + public static interface SocketFileDescriptorGetter { + public int get(Socket socket); + } + + static class ServerLoop implements Runnable { + private final Socket socket; + private final SocketFileDescriptorGetter socketFdGetter; + + ServerLoop(Socket socket, SocketFileDescriptorGetter fdGetter) { + this.socket = socket; + socketFdGetter = fdGetter; + } + + @Override public void run() { + int sockFd = socketFdGetter.get(socket); + System.err.println("Socket fd = " + sockFd); + if (sockFd != -1) { + File tempDir = null; + try { + tempDir = serverEnv(); + RPC.getApi("_ServerLoop").pushArg(sockFd).invoke(); + System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); + } catch (IOException e) { + e.printStackTrace(); + } finally { + if (tempDir != null) { + if (!tempDir.delete()) { + System.err.println( + "[WARN] Couldn't delete temporary directory " + tempDir.getAbsolutePath()); + } + } + closeQuietly(socket); + } + } + } + + private File serverEnv() throws IOException { + // Server environment function return temp dir. + final File tempDir = File.createTempFile("tvm4j_rpc_", ""); + if (!tempDir.delete() || !tempDir.mkdir()) { + throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath()); + } + + Function.register("tvm.contrib.rpc.server.workpath", new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + return tempDir + File.separator + args[0].asString(); + } + }); + + Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + String filename = args[0].asString(); + String path = tempDir + File.separator + filename; + // Try create a shared library in remote. + if (path.endsWith(".o")) { + System.err.println("Create shared library based on " + path); + // TODO(yizhi): create .so + } + System.err.println("Load module from " + path); + return Module.load(path); + } + }, true); + + return tempDir; + } + } + + static class ConnectionThread extends Thread { + private final ServerSocket server; + private final SocketFileDescriptorGetter socketFileDescriptorGetter; + private volatile boolean running = true; + + public ConnectionThread(Server server, SocketFileDescriptorGetter sockFdGetter) + throws IOException { + this.server = server; + this.socketFileDescriptorGetter = sockFdGetter; + } + + public void terminate() { + this.running = false; + } + + @Override public void run() { + while (running) { + Socket socket = null; + try { + socket = server.accept(); + InputStream in = socket.getInputStream(); + OutputStream out = socket.getOutputStream(); + int magic = wrapBytes(recvAll(in, 4)).getInt(); + if (magic != RPC.RPC_MAGIC) { + closeQuietly(socket); + continue; + } + int keyLen = wrapBytes(recvAll(in, 4)).getInt(); + String key = decodeToStr(recvAll(in, keyLen)); + if (!key.startsWith("client:")) { + out.write(toBytes(RPC.RPC_MAGIC + 2)); + } else { + out.write(toBytes(RPC.RPC_MAGIC)); + } + System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); + + // TODO(yizhi): use ExecutorService, i.e., ThreadPool + Thread processThread = new Thread(new ServerLoop(socket, socketFileDescriptorGetter)); + processThread.setDaemon(true); + processThread.start(); + } catch (IOException e) { + e.printStackTrace(); + terminate(); + } + } + } + + private byte[] recvAll(final InputStream in, final int nBytes) throws IOException { + byte[] res = new byte[nBytes]; + int nRead = 0; + while (nRead < nBytes) { + int chunk = in.read(res, nRead, Math.min(nBytes - nRead, 1024)); + nRead += chunk; + } + return res; + } + } + + private static void closeQuietly(Socket socket) { + if (socket != null) { + try { + socket.shutdownInput(); + socket.shutdownOutput(); + socket.close(); + } catch (IOException ioe) { + // close quietly, do nothing. + } + } + } + + private static ByteBuffer wrapBytes(byte[] bytes) { + ByteBuffer bb = ByteBuffer.wrap(bytes); + bb.order(ByteOrder.LITTLE_ENDIAN); + return bb; + } + + private static byte[] toBytes(int number) { + ByteBuffer bb = ByteBuffer.allocate(4); + bb.order(ByteOrder.LITTLE_ENDIAN); + return bb.putInt(number).array(); + } + + private static String decodeToStr(byte[] bytes) { + StringBuilder builder = new StringBuilder(); + for (byte bt : bytes) { + builder.append((char) bt); + } + return builder.toString(); + } +} From 0cf97a0e755b2585eb1edbb07c026cadb370d1e0 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Sat, 5 Aug 2017 01:25:50 +0800 Subject: [PATCH 2/3] [tvm4j] fix recursively function calling; connect to proxy server; osx rename .so to .dylib --- .../src/main/assembly/assembly.xml | 2 +- jvm/core/pom.xml | 5 +- jvm/core/src/main/java/ml/dmlc/tvm/Base.java | 13 +- .../src/main/java/ml/dmlc/tvm/rpc/Client.java | 17 ++ .../src/main/java/ml/dmlc/tvm/rpc/RPC.java | 27 ++- .../main/java/ml/dmlc/tvm/rpc/RPCSession.java | 19 +- .../src/main/java/ml/dmlc/tvm/rpc/Server.java | 228 ++++++++++++++---- .../test/java/ml/dmlc/tvm/rpc/RPCTest.java | 98 ++++++++ .../main/native/ml_dmlc_tvm_native_c_api.cc | 26 +- 9 files changed, 373 insertions(+), 62 deletions(-) create mode 100644 jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java diff --git a/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml b/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml index f5b6c1b1b9e0..404f935a1f3e 100644 --- a/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml +++ b/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml @@ -6,7 +6,7 @@ false - ../../../lib/libtvm_runtime.so + ../../../lib/libtvm_runtime.dylib lib/native 0644 diff --git a/jvm/core/pom.xml b/jvm/core/pom.xml index 62d6a866ae5d..8aed78525189 100644 --- a/jvm/core/pom.xml +++ b/jvm/core/pom.xml @@ -20,18 +20,21 @@ osx-x86_64-cpu osx-x86_64-cpu + libtvm_runtime.dylib linux-x86_64-cpu linux-x86_64-cpu + libtvm_runtime.so linux-x86_64-gpu linux-x86_64-gpu + libtvm_runtime.so @@ -88,7 +91,7 @@ 1 -Djava.library.path=${project.parent.basedir}/native/${platform}/target - -Dlibtvm.so.path=${project.parent.basedir}/../lib/libtvm_runtime.so + -Dlibtvm.so.path=${project.parent.basedir}/../lib/${libtvm.so.filename} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/Base.java b/jvm/core/src/main/java/ml/dmlc/tvm/Base.java index ec88325b1a8b..09d1efe42097 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/Base.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/Base.java @@ -80,7 +80,18 @@ public RefTVMValue() { if (tvmLibFilename == null || !new File(tvmLibFilename).isFile() || _LIB.nativeLibInit(tvmLibFilename) != 0) { try { - NativeLibraryLoader.extractResourceFileToTempDir("libtvm_runtime.so", new Action() { + String runtimeLibname; + String os = System.getProperty("os.name"); + // ref: http://lopica.sourceforge.net/os.html + if (os.startsWith("Linux")) { + runtimeLibname = "libtvm_runtime.so"; + } else if (os.startsWith("Mac")) { + runtimeLibname = "libtvm_runtime.dylib"; + } else { + // TODO(yizhi) support windows later + throw new UnsatisfiedLinkError("Windows not supported currently"); + } + NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() { @Override public void invoke(File target) { System.err.println("Loading tvm runtime from " + target.getPath()); checkCall(_LIB.nativeLibInit(target.getPath())); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java index db4bcc50e52d..20292b7a6f82 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ml.dmlc.tvm.rpc; import ml.dmlc.tvm.Function; diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java index 1de6239614b7..e3b8b9366751 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ml.dmlc.tvm.rpc; import ml.dmlc.tvm.Function; @@ -11,11 +28,11 @@ public class RPC { private static ThreadLocal> apiFuncs = new ThreadLocal>() { - @Override - protected Map initialValue() { - return new HashMap(); - } - }; + @Override + protected Map initialValue() { + return new HashMap(); + } + }; static Function getApi(String name) { Function func = apiFuncs.get().get(name); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java index 3ef6b47a50ba..cb4ccf49434b 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ml.dmlc.tvm.rpc; import ml.dmlc.tvm.Function; @@ -222,7 +239,7 @@ private static byte[] getBytesFromFile(File file) throws IOException { InputStream is = new FileInputStream(file); try { while (offset < bytes.length - && (numRead=is.read(bytes, offset, bytes.length-offset)) >= 0) { + && (numRead = is.read(bytes, offset, bytes.length - offset)) >= 0) { offset += numRead; } } finally { diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java index 06eb5bdbc44d..493ebb1b2dbb 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package ml.dmlc.tvm.rpc; import ml.dmlc.tvm.Function; @@ -13,39 +30,88 @@ import java.io.OutputStream; import java.net.ServerSocket; import java.net.Socket; +import java.net.SocketException; import java.nio.ByteBuffer; import java.nio.ByteOrder; -public class Server extends ServerSocket { - private final ConnectionThread connectionThread; - private SocketFileDescriptorGetter socketFdGetter = new SocketFileDescriptorGetter() { - @Override public int get(Socket socket) { - try { - InputStream is = socket.getInputStream(); - FileDescriptor fd = ((FileInputStream) is).getFD(); - return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd); - } catch (IOException e) { - e.printStackTrace(); - return -1; - } - } - }; +/** + * RPC Server. + */ +public class Server { + private final Loop serverLoop; + private static SocketFileDescriptorGetter defaultSocketFdGetter + = new SocketFileDescriptorGetter() { + @Override public int get(Socket socket) { + try { + InputStream is = socket.getInputStream(); + FileDescriptor fd = ((FileInputStream) is).getFD(); + return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd); + } catch (IOException e) { + e.printStackTrace(); + return -1; + } + } + }; + /** + * Start a standalone server. + * @param serverPort Port. + * @param socketFdGetter Method to get system file descriptor of the server socket. + * @throws IOException if failed to bind localhost:port. + */ + public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException { + serverLoop = new ListenLoop(serverPort, socketFdGetter); + } + + + /** + * Start a standalone server. + * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess + * to get file descriptor for the socket. + * @param serverPort Port. + * @throws IOException if failed to bind localhost:port. + */ public Server(int serverPort) throws IOException { - super(serverPort); - connectionThread = new ConnectionThread(this, socketFdGetter); + this(serverPort, defaultSocketFdGetter); } - public void start() { - connectionThread.start(); + /** + * Start a server connected to proxy. + * @param proxyHost The proxy server host. + * @param proxyPort The proxy server port. + * @param key The key to identify the server. + * @param socketFdGetter Method to get system file descriptor of the server socket. + */ + public Server(String proxyHost, int proxyPort, String key, + SocketFileDescriptorGetter socketFdGetter) { + serverLoop = new ConnectProxyLoop(proxyHost, proxyPort, key, socketFdGetter); } - public void terminate() { - connectionThread.terminate(); + /** + * Start a server connected to proxy. + * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess + * to get file descriptor for the socket. + * @param proxyHost The proxy server host. + * @param proxyPort The proxy server port. + * @param key The key to identify the server. + */ + public Server(String proxyHost, int proxyPort, String key) { + this(proxyHost, proxyPort, key, defaultSocketFdGetter); } - public void registerFilDescriptorGetter(SocketFileDescriptorGetter getter) { - socketFdGetter = getter; + /** + * Start the server. + */ + public void start() { + serverLoop.start(); + } + + /** + * Stop the server. + */ + public void terminate() { + serverLoop.interrupt(); + serverLoop.terminate(); } public static interface SocketFileDescriptorGetter { @@ -63,7 +129,6 @@ static class ServerLoop implements Runnable { @Override public void run() { int sockFd = socketFdGetter.get(socket); - System.err.println("Socket fd = " + sockFd); if (sockFd != -1) { File tempDir = null; try { @@ -95,17 +160,12 @@ private File serverEnv() throws IOException { @Override public Object invoke(TVMValue... args) { return tempDir + File.separator + args[0].asString(); } - }); + }, true); Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() { @Override public Object invoke(TVMValue... args) { String filename = args[0].asString(); String path = tempDir + File.separator + filename; - // Try create a shared library in remote. - if (path.endsWith(".o")) { - System.err.println("Create shared library based on " + path); - // TODO(yizhi): create .so - } System.err.println("Load module from " + path); return Module.load(path); } @@ -115,26 +175,98 @@ private File serverEnv() throws IOException { } } - static class ConnectionThread extends Thread { + abstract static class Loop extends Thread { + public abstract void terminate(); + } + + static class ConnectProxyLoop extends Loop { + private volatile boolean running = true; + private final String host; + private final int port; + private final String key; + private final SocketFileDescriptorGetter socketFileDescriptorGetter; + private Socket waitingSocket = null; + + public ConnectProxyLoop(String host, int port, String key, + SocketFileDescriptorGetter sockFdGetter) { + this.host = host; + this.port = port; + this.key = "server:" + key; + socketFileDescriptorGetter = sockFdGetter; + } + + @Override public void terminate() { + running = false; + if (waitingSocket != null) { + try { + waitingSocket.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + + @Override public void run() { + while (running) { + try { + Socket socket = new Socket(host, port); + waitingSocket = socket; + InputStream in = socket.getInputStream(); + OutputStream out = socket.getOutputStream(); + out.write(toBytes(RPC.RPC_MAGIC)); + out.write(toBytes(key.length())); + out.write(toBytes(key)); + int magic = wrapBytes(recvAll(in, 4)).getInt(); + final String address = host + ":" + port; + if (magic == RPC.RPC_MAGIC + 1) { + throw new RuntimeException( + String.format("key: %s has already been used in proxy", key)); + } else if (magic == RPC.RPC_MAGIC + 2) { + System.err.println("RPCProxy do not have matching client key " + key); + } else if (magic != RPC.RPC_MAGIC) { + throw new RuntimeException(address + " is not RPC Proxy"); + } + System.err.println("RPCProxy connected to " + address); + + waitingSocket = null; + // TODO(yizhi): use ExecutorService, i.e., ThreadPool + Thread processThread = new Thread(new ServerLoop(socket, socketFileDescriptorGetter)); + processThread.setDaemon(true); + processThread.start(); + } catch (SocketException e) { + // when terminates, this is what we expect, do nothing. + } catch (IOException e) { + e.printStackTrace(); + terminate(); + } + } + } + } + + static class ListenLoop extends Loop { private final ServerSocket server; private final SocketFileDescriptorGetter socketFileDescriptorGetter; private volatile boolean running = true; - public ConnectionThread(Server server, SocketFileDescriptorGetter sockFdGetter) + public ListenLoop(int serverPort, SocketFileDescriptorGetter sockFdGetter) throws IOException { - this.server = server; + this.server = new ServerSocket(serverPort); this.socketFileDescriptorGetter = sockFdGetter; } - public void terminate() { + @Override public void terminate() { this.running = false; + try { + server.close(); + } catch (IOException e) { + e.printStackTrace(); + } } @Override public void run() { while (running) { - Socket socket = null; try { - socket = server.accept(); + Socket socket = server.accept(); InputStream in = socket.getInputStream(); OutputStream out = socket.getOutputStream(); int magic = wrapBytes(recvAll(in, 4)).getInt(); @@ -155,22 +287,24 @@ public void terminate() { Thread processThread = new Thread(new ServerLoop(socket, socketFileDescriptorGetter)); processThread.setDaemon(true); processThread.start(); + } catch (SocketException e) { + // when terminates, this is what we expect, do nothing. } catch (IOException e) { e.printStackTrace(); terminate(); } } } + } - private byte[] recvAll(final InputStream in, final int nBytes) throws IOException { - byte[] res = new byte[nBytes]; - int nRead = 0; - while (nRead < nBytes) { - int chunk = in.read(res, nRead, Math.min(nBytes - nRead, 1024)); - nRead += chunk; - } - return res; + private static byte[] recvAll(final InputStream in, final int numBytes) throws IOException { + byte[] res = new byte[numBytes]; + int numRead = 0; + while (numRead < numBytes) { + int chunk = in.read(res, numRead, Math.min(numBytes - numRead, 1024)); + numRead += chunk; } + return res; } private static void closeQuietly(Socket socket) { @@ -197,6 +331,14 @@ private static byte[] toBytes(int number) { return bb.putInt(number).array(); } + private static byte[] toBytes(String str) { + byte[] bytes = new byte[str.length()]; + for (int i = 0; i < str.length(); ++i) { + bytes[i] = (byte) str.charAt(i); + } + return bytes; + } + private static String decodeToStr(byte[] bytes) { StringBuilder builder = new StringBuilder(); for (byte bt : bytes) { diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java new file mode 100644 index 000000000000..935a2f77702d --- /dev/null +++ b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.TVMValue; +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class RPCTest { + static class RefInt { + public int value; + } + + private static Server startServer(RefInt portRef) { + Server server = null; + int port = 9981; + for (int i = 0; i < 10; ++i) { + try { + server = new Server(port + i); + server.start(); + portRef.value = port + i; + return server; + } catch (IOException e) { + } + } + throw new RuntimeException("Cannot find an available port."); + } + + @Test + public void test_rpc_addone() { + if (!Module.enabled("rpc")) { + return; + } + Function.register("rpc.test.addone", new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + return args[0].asLong() + 1L; + } + }); + + RefInt port = new RefInt(); + Server server = null; + try { + server = startServer(port); + RPCSession client = Client.connect("localhost", port.value); + Function func = client.getFunction("rpc.test.addone"); + assertEquals(11L, func.call(10).asLong()); + } finally { + if (server != null) { + server.terminate(); + } + } + } + + @Test + public void test_rpc_strcat() { + if (!Module.enabled("rpc")) { + return; + } + Function.register("rpc.test.strcat", new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + return args[0].asString() + ":" + args[1].asLong(); + } + }); + + RefInt port = new RefInt(); + Server server = null; + try { + server = startServer(port); + RPCSession client = Client.connect("localhost", port.value); + Function func = client.getFunction("rpc.test.strcat"); + assertEquals("abc:11", func.call("abc", 11L).asString()); + } finally { + if (server != null) { + server.terminate(); + } + } + } +} diff --git a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc index 2162da296ddc..32e4b7ca7da4 100644 --- a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc @@ -158,27 +158,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCall( TVMValue retVal; int retTypeCode; + + // function can be invoked recursively, + // thus we copy the pushed arguments here. + auto argValues = e->tvmFuncArgValues; + auto argTypes = e->tvmFuncArgTypes; + auto pushedStrs = e->tvmFuncArgPushedStrs; + auto pushedBytes = e->tvmFuncArgPushedBytes; + + e->tvmFuncArgPushedStrs.clear(); + e->tvmFuncArgPushedBytes.clear(); + e->tvmFuncArgTypes.clear(); + e->tvmFuncArgValues.clear(); + int ret = TVMFuncCall(reinterpret_cast(jhandle), - &e->tvmFuncArgValues[0], &e->tvmFuncArgTypes[0], numArgs, &retVal, &retTypeCode); + &argValues[0], &argTypes[0], numArgs, &retVal, &retTypeCode); - for (auto iter = e->tvmFuncArgPushedStrs.cbegin(); - iter != e->tvmFuncArgPushedStrs.cend(); iter++) { + for (auto iter = pushedStrs.cbegin(); iter != pushedStrs.cend(); iter++) { env->ReleaseStringUTFChars(iter->first, iter->second); env->DeleteGlobalRef(iter->first); } - for (auto iter = e->tvmFuncArgPushedBytes.cbegin(); - iter != e->tvmFuncArgPushedBytes.cend(); iter++) { + for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) { env->ReleaseByteArrayElements(iter->first, reinterpret_cast(const_cast(iter->second->data)), 0); env->DeleteGlobalRef(iter->first); delete iter->second; } - e->tvmFuncArgPushedStrs.clear(); - e->tvmFuncArgPushedBytes.clear(); - e->tvmFuncArgTypes.clear(); - e->tvmFuncArgValues.clear(); - // return TVMValue object to Java jclass refTVMValueCls = env->FindClass("ml/dmlc/tvm/Base$RefTVMValue"); jfieldID refTVMValueFid From 412a41d3797fc955ba52460b417418f81d9b2204 Mon Sep 17 00:00:00 2001 From: Yizhi Liu Date: Wed, 9 Aug 2017 00:18:02 +0800 Subject: [PATCH 3/3] [tvm4j] test case for proxy connection; thread pool for serving --- .../src/main/java/ml/dmlc/tvm/rpc/Server.java | 42 ++++++++++++------- .../test/java/ml/dmlc/tvm/rpc/RPCTest.java | 38 ++++++++++++++--- .../src/test/scripts/test_rpc_proxy_server.py | 20 +++++++++ tests/scripts/task_java_unittest.sh | 9 +++- 4 files changed, 87 insertions(+), 22 deletions(-) create mode 100644 jvm/core/src/test/scripts/test_rpc_proxy_server.py diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java index 493ebb1b2dbb..9be1859bb46e 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java @@ -33,12 +33,13 @@ import java.net.SocketException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; /** * RPC Server. */ public class Server { - private final Loop serverLoop; private static SocketFileDescriptorGetter defaultSocketFdGetter = new SocketFileDescriptorGetter() { @Override public int get(Socket socket) { @@ -52,6 +53,10 @@ public class Server { } } }; + private static final int DEFAULT_THREAD_NUMBER_IN_A_POOL = 20; + + private final Loop serverLoop; + private final ExecutorService threadPool; /** * Start a standalone server. @@ -60,10 +65,10 @@ public class Server { * @throws IOException if failed to bind localhost:port. */ public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException { - serverLoop = new ListenLoop(serverPort, socketFdGetter); + threadPool = setupThreadPool(); + serverLoop = new ListenLoop(serverPort, threadPool, socketFdGetter); } - /** * Start a standalone server. * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess @@ -84,7 +89,8 @@ public Server(int serverPort) throws IOException { */ public Server(String proxyHost, int proxyPort, String key, SocketFileDescriptorGetter socketFdGetter) { - serverLoop = new ConnectProxyLoop(proxyHost, proxyPort, key, socketFdGetter); + threadPool = setupThreadPool(); + serverLoop = new ConnectProxyLoop(proxyHost, proxyPort, key, threadPool, socketFdGetter); } /** @@ -99,6 +105,13 @@ public Server(String proxyHost, int proxyPort, String key) { this(proxyHost, proxyPort, key, defaultSocketFdGetter); } + private ExecutorService setupThreadPool() { + final String workerThreadNumber = System.getProperty("rpc.server.thread.number"); + final int numThread = (workerThreadNumber == null) + ? DEFAULT_THREAD_NUMBER_IN_A_POOL : Integer.parseInt(workerThreadNumber); + return Executors.newFixedThreadPool(numThread); + } + /** * Start the server. */ @@ -112,6 +125,7 @@ public void start() { public void terminate() { serverLoop.interrupt(); serverLoop.terminate(); + threadPool.shutdown(); } public static interface SocketFileDescriptorGetter { @@ -184,14 +198,17 @@ static class ConnectProxyLoop extends Loop { private final String host; private final int port; private final String key; + private final ExecutorService workerPool; private final SocketFileDescriptorGetter socketFileDescriptorGetter; private Socket waitingSocket = null; public ConnectProxyLoop(String host, int port, String key, + ExecutorService workerPool, SocketFileDescriptorGetter sockFdGetter) { this.host = host; this.port = port; this.key = "server:" + key; + this.workerPool = workerPool; socketFileDescriptorGetter = sockFdGetter; } @@ -229,10 +246,7 @@ public ConnectProxyLoop(String host, int port, String key, System.err.println("RPCProxy connected to " + address); waitingSocket = null; - // TODO(yizhi): use ExecutorService, i.e., ThreadPool - Thread processThread = new Thread(new ServerLoop(socket, socketFileDescriptorGetter)); - processThread.setDaemon(true); - processThread.start(); + workerPool.execute(new ServerLoop(socket, socketFileDescriptorGetter)); } catch (SocketException e) { // when terminates, this is what we expect, do nothing. } catch (IOException e) { @@ -245,12 +259,14 @@ public ConnectProxyLoop(String host, int port, String key, static class ListenLoop extends Loop { private final ServerSocket server; + private final ExecutorService workerPool; private final SocketFileDescriptorGetter socketFileDescriptorGetter; private volatile boolean running = true; - public ListenLoop(int serverPort, SocketFileDescriptorGetter sockFdGetter) - throws IOException { + public ListenLoop(int serverPort, ExecutorService workerPool, + SocketFileDescriptorGetter sockFdGetter) throws IOException { this.server = new ServerSocket(serverPort); + this.workerPool = workerPool; this.socketFileDescriptorGetter = sockFdGetter; } @@ -282,11 +298,7 @@ public ListenLoop(int serverPort, SocketFileDescriptorGetter sockFdGetter) out.write(toBytes(RPC.RPC_MAGIC)); } System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); - - // TODO(yizhi): use ExecutorService, i.e., ThreadPool - Thread processThread = new Thread(new ServerLoop(socket, socketFileDescriptorGetter)); - processThread.setDaemon(true); - processThread.start(); + workerPool.execute(new ServerLoop(socket, socketFileDescriptorGetter)); } catch (SocketException e) { // when terminates, this is what we expect, do nothing. } catch (IOException e) { diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java index 935a2f77702d..f4dd9eb9cb4e 100644 --- a/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java +++ b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java @@ -47,11 +47,11 @@ private static Server startServer(RefInt portRef) { } @Test - public void test_rpc_addone() { + public void test_addone() { if (!Module.enabled("rpc")) { return; } - Function.register("rpc.test.addone", new Function.Callback() { + Function.register("test.rpc.addone", new Function.Callback() { @Override public Object invoke(TVMValue... args) { return args[0].asLong() + 1L; } @@ -62,7 +62,7 @@ public void test_rpc_addone() { try { server = startServer(port); RPCSession client = Client.connect("localhost", port.value); - Function func = client.getFunction("rpc.test.addone"); + Function func = client.getFunction("test.rpc.addone"); assertEquals(11L, func.call(10).asLong()); } finally { if (server != null) { @@ -72,11 +72,11 @@ public void test_rpc_addone() { } @Test - public void test_rpc_strcat() { + public void test_strcat() { if (!Module.enabled("rpc")) { return; } - Function.register("rpc.test.strcat", new Function.Callback() { + Function.register("test.rpc.strcat", new Function.Callback() { @Override public Object invoke(TVMValue... args) { return args[0].asString() + ":" + args[1].asLong(); } @@ -87,7 +87,7 @@ public void test_rpc_strcat() { try { server = startServer(port); RPCSession client = Client.connect("localhost", port.value); - Function func = client.getFunction("rpc.test.strcat"); + Function func = client.getFunction("test.rpc.strcat"); assertEquals("abc:11", func.call("abc", 11L).asString()); } finally { if (server != null) { @@ -95,4 +95,30 @@ public void test_rpc_strcat() { } } } + + @Test + public void test_connect_proxy_server() { + String proxyHost = System.getProperty("test.rpc.proxy.host"); + int proxyPort = Integer.parseInt(System.getProperty("test.rpc.proxy.port")); + + Function.register("test.rpc.proxy.addone", new Function.Callback() { + @Override public Object invoke(TVMValue... tvmValues) { + return tvmValues[0].asLong() + 1L; + } + }); + + Server server = null; + try { + server = new Server(proxyHost, proxyPort, "x1"); + server.start(); + + RPCSession client = Client.connect(proxyHost, proxyPort, "x1"); + Function f1 = client.getFunction("test.rpc.proxy.addone"); + assertEquals(11L, f1.call(10L).asLong()); + } finally { + if (server != null) { + server.terminate(); + } + } + } } diff --git a/jvm/core/src/test/scripts/test_rpc_proxy_server.py b/jvm/core/src/test/scripts/test_rpc_proxy_server.py new file mode 100644 index 000000000000..3f1f6466c715 --- /dev/null +++ b/jvm/core/src/test/scripts/test_rpc_proxy_server.py @@ -0,0 +1,20 @@ +import time +from tvm.contrib import rpc_proxy + +def start_proxy_server(port, timeout): + prox = rpc_proxy.Proxy("localhost", port=port, port_end=port+1) + if timeout > 0: + import time + time.sleep(timeout) + prox.terminate() + else: + prox.proc.join() + +if __name__ == "__main__": + import sys + if len(sys.argv) < 2: + sys.exit(-1) + port = int(sys.argv[1]) + timeout = 0 if len(sys.argv) == 2 else float(sys.argv[2]) + start_proxy_server(port, timeout) + diff --git a/tests/scripts/task_java_unittest.sh b/tests/scripts/task_java_unittest.sh index bb1f98247f36..9017079237f1 100755 --- a/tests/scripts/task_java_unittest.sh +++ b/tests/scripts/task_java_unittest.sh @@ -10,7 +10,14 @@ TEMP_DIR=$(mktemp -d) python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1 python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1 +# start rpc proxy server +PORT=$(( ( RANDOM % 1000 ) + 9000 )) +python $SCRIPT_DIR/test_rpc_proxy_server.py $PORT 30 & + make jvmpkg || exit -1 -make jvmpkg JVM_TEST_ARGS="-DskipTests=false -Dtest.tempdir=$TEMP_DIR" || exit -1 +make jvmpkg JVM_TEST_ARGS="-DskipTests=false \ + -Dtest.tempdir=$TEMP_DIR \ + -Dtest.rpc.proxy.host=localhost \ + -Dtest.rpc.proxy.port=$PORT" || exit -1 rm -rf $TEMP_DIR