diff --git a/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml b/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml index f5b6c1b1b9e0..404f935a1f3e 100644 --- a/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml +++ b/jvm/assembly/osx-x86_64-cpu/src/main/assembly/assembly.xml @@ -6,7 +6,7 @@ false - ../../../lib/libtvm_runtime.so + ../../../lib/libtvm_runtime.dylib lib/native 0644 diff --git a/jvm/core/pom.xml b/jvm/core/pom.xml index 62d6a866ae5d..8aed78525189 100644 --- a/jvm/core/pom.xml +++ b/jvm/core/pom.xml @@ -20,18 +20,21 @@ osx-x86_64-cpu osx-x86_64-cpu + libtvm_runtime.dylib linux-x86_64-cpu linux-x86_64-cpu + libtvm_runtime.so linux-x86_64-gpu linux-x86_64-gpu + libtvm_runtime.so @@ -88,7 +91,7 @@ 1 -Djava.library.path=${project.parent.basedir}/native/${platform}/target - -Dlibtvm.so.path=${project.parent.basedir}/../lib/libtvm_runtime.so + -Dlibtvm.so.path=${project.parent.basedir}/../lib/${libtvm.so.filename} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/Base.java b/jvm/core/src/main/java/ml/dmlc/tvm/Base.java index ec88325b1a8b..09d1efe42097 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/Base.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/Base.java @@ -80,7 +80,18 @@ public RefTVMValue() { if (tvmLibFilename == null || !new File(tvmLibFilename).isFile() || _LIB.nativeLibInit(tvmLibFilename) != 0) { try { - NativeLibraryLoader.extractResourceFileToTempDir("libtvm_runtime.so", new Action() { + String runtimeLibname; + String os = System.getProperty("os.name"); + // ref: http://lopica.sourceforge.net/os.html + if (os.startsWith("Linux")) { + runtimeLibname = "libtvm_runtime.so"; + } else if (os.startsWith("Mac")) { + runtimeLibname = "libtvm_runtime.dylib"; + } else { + // TODO(yizhi) support windows later + throw new UnsatisfiedLinkError("Windows not supported currently"); + } + NativeLibraryLoader.extractResourceFileToTempDir(runtimeLibname, new Action() { @Override public void invoke(File target) { System.err.println("Loading tvm runtime from " + target.getPath()); checkCall(_LIB.nativeLibInit(target.getPath())); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java index 3f38a5aba155..63602f3a14d0 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/Function.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/Function.java @@ -34,7 +34,7 @@ public class Function extends TVMValue { * @param name full function name. * @return TVM function. */ - static Function getFunction(final String name) { + public static Function getFunction(final String name) { for (String fullName : listGlobalFuncNames()) { if (fullName.equals(name)) { return getGlobalFunc(fullName, true, false); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java b/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java index c073416ae0c4..396f740b914f 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/NativeLibraryLoader.java @@ -30,7 +30,7 @@ class NativeLibraryLoader { static { try { - tempDir = File.createTempFile("tvm", ""); + tempDir = File.createTempFile("tvm4j", ""); if (!tempDir.delete() || !tempDir.mkdir()) { throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath()); } diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java b/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java index e91b90cce083..0d108e0a2943 100644 --- a/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java +++ b/jvm/core/src/main/java/ml/dmlc/tvm/TVMContext.java @@ -17,12 +17,12 @@ package ml.dmlc.tvm; +import ml.dmlc.tvm.rpc.RPC; + import java.util.HashMap; import java.util.Map; public class TVMContext { - private static final int RPC_SESS_MASK = 128; - private static final Map MASK2STR = new HashMap(); private static final Map STR2MASK = new HashMap(); @@ -169,9 +169,9 @@ public void sync() { } @Override public String toString() { - if (deviceType >= RPC_SESS_MASK) { - int tblId = deviceType / RPC_SESS_MASK - 1; - int devType = deviceType % RPC_SESS_MASK; + if (deviceType >= RPC.RPC_SESS_MASK) { + int tblId = deviceType / RPC.RPC_SESS_MASK - 1; + int devType = deviceType % RPC.RPC_SESS_MASK; return String.format("remote[%d]:%s(%d)", tblId, MASK2STR.get(devType), deviceId); } return String.format("%s(%d)", MASK2STR.get(deviceType), deviceId); diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java new file mode 100644 index 000000000000..20292b7a6f82 --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Client.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.TVMValue; + +public class Client { + /** + * Connect to RPC Server. + * @param url The url of the host. + * @param port The port to connect to. + * @param key Additional key to match server. + * @return The connected session. + */ + public static RPCSession connect(String url, int port, String key) { + Function doConnect = RPC.getApi("_Connect"); + if (doConnect == null) { + throw new RuntimeException("Please compile with USE_RPC=1"); + } + TVMValue sess = doConnect.pushArg(url).pushArg(port).pushArg(key).invoke(); + return new RPCSession(sess.asModule()); + } + + /** + * Connect to RPC Server. + * @param url The url of the host. + * @param port The port to connect to. + * @return The connected session. + */ + public static RPCSession connect(String url, int port) { + return connect(url, port, ""); + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java new file mode 100644 index 000000000000..e3b8b9366751 --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; + +import java.util.HashMap; +import java.util.Map; + +public class RPC { + public static final int RPC_MAGIC = 0xff271; + public static final int RPC_SESS_MASK = 128; + + private static ThreadLocal> apiFuncs + = new ThreadLocal>() { + @Override + protected Map initialValue() { + return new HashMap(); + } + }; + + static Function getApi(String name) { + Function func = apiFuncs.get().get(name); + if (func == null) { + func = Function.getFunction("contrib.rpc." + name); + if (func == null) { + return null; + } + apiFuncs.get().put(name, func); + } + return func; + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java new file mode 100644 index 000000000000..cb4ccf49434b --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.TVMContext; + +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.HashMap; +import java.util.Map; + +/** + * RPC Client session module. + * Do not directly create the object, use Client.connect. + */ +public class RPCSession { + private final Module session; + private final int tblIndex; + private final Map remoteFuncs = new HashMap(); + + RPCSession(Module sess) { + session = sess; + tblIndex = (int) RPC.getApi("_SessTableIndex").pushArg(session).invoke().asLong(); + } + + /** + * Get function from the session. + * @param name The name of the function. + * @return The result function. + */ + public Function getFunction(String name) { + return session.getFunction(name); + } + + /** + * Construct a remote context. + * @param devType device type. + * @param devId device id. + * @return The corresponding encoded remote context. + */ + public TVMContext context(String devType, int devId) { + TVMContext ctx = new TVMContext(devType, devId); + int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK; + return new TVMContext(ctx.deviceType + encode, devId); + } + + /** + * Construct a remote context. + * @param devType device type. + * @return The corresponding encoded remote context. + */ + public TVMContext context(String devType) { + return context(devType, 0); + } + + /** + * Construct a remote context. + * @param devType device type. + * @param devId device id. + * @return The corresponding encoded remote context. + */ + public TVMContext context(int devType, int devId) { + int encode = (tblIndex + 1) * RPC.RPC_SESS_MASK; + return new TVMContext(devType + encode, devId); + } + + /** + * Construct a remote context. + * @param devType device type. + * @return The corresponding encoded remote context. + */ + public TVMContext context(int devType) { + return context(devType, 0); + } + + /** + * Construct remote CPU device. + * @param devId device id. + * @return Remote CPU context. + */ + public TVMContext cpu(int devId) { + return context(1, devId); + } + + /** + * Construct remote CPU device. + * @return Remote CPU context. + */ + public TVMContext cpu() { + return cpu(0); + } + + /** + * Construct remote GPU device. + * @param devId device id. + * @return Remote GPU context. + */ + public TVMContext gpu(int devId) { + return context(2, devId); + } + + /** + * Construct remote GPU device. + * @return Remote GPU context. + */ + public TVMContext gpu() { + return gpu(0); + } + + /** + * Construct remote OpenCL device. + * @param devId device id. + * @return Remote OpenCL context. + */ + public TVMContext cl(int devId) { + return context(4, devId); + } + + /** + * Construct remote OpenCL device. + * @return Remote OpenCL context. + */ + public TVMContext cl() { + return cl(0); + } + + /** + * Construct remote Metal device. + * @param devId device id. + * @return Remote metal context. + */ + public TVMContext metal(int devId) { + return context(8, devId); + } + + /** + * Construct remote Metal device. + * @return Remote metal context. + */ + public TVMContext metal() { + return metal(0); + } + + /** + * Upload binary to remote runtime temp folder. + * @param data The binary in local to upload. + * @param target The path in remote, cannot be null. + */ + public void upload(byte[] data, String target) { + if (target == null) { + throw new IllegalArgumentException("Please specify the upload target"); + } + final String funcName = "upload"; + Function remoteFunc = remoteFuncs.get(funcName); + if (remoteFunc == null) { + remoteFunc = getFunction("tvm.contrib.rpc.server.upload"); + remoteFuncs.put(funcName, remoteFunc); + } + remoteFunc.pushArg(target).pushArg(data).invoke(); + } + + /** + * Upload file to remote runtime temp folder. + * @param data The file in local to upload. + * @param target The path in remote. + */ + public void upload(File data, String target) throws IOException { + byte[] blob = getBytesFromFile(data); + upload(blob, target); + } + + /** + * Upload file to remote runtime temp folder. + * @param data The file in local to upload. + */ + public void upload(File data) throws IOException { + upload(data, data.getName()); + } + + /** + * Download file from remote temp folder. + * @param path The relative location to remote temp folder. + * @return The result blob from the file. + */ + public byte[] download(String path) { + final String name = "download"; + Function func = remoteFuncs.get(name); + if (func == null) { + func = getFunction("tvm.contrib.rpc.server.download"); + remoteFuncs.put(name, func); + } + return func.pushArg(path).invoke().asBytes(); + } + + /** + * Load a remote module, the file need to be uploaded first. + * @param path The relative location to remote temp folder. + * @return The remote module containing remote function. + */ + public Module loadModule(String path) { + return RPC.getApi("_LoadRemoteModule").pushArg(session).pushArg(path).invoke().asModule(); + } + + + private static byte[] getBytesFromFile(File file) throws IOException { + // Get the size of the file + long length = file.length(); + + if (length > Integer.MAX_VALUE) { + throw new IOException("File " + file.getName() + " is too large!"); + } + + // cannot create an array using a long type. + byte[] bytes = new byte[(int)length]; + + // Read in the bytes + int offset = 0; + int numRead = 0; + + InputStream is = new FileInputStream(file); + try { + while (offset < bytes.length + && (numRead = is.read(bytes, offset, bytes.length - offset)) >= 0) { + offset += numRead; + } + } finally { + is.close(); + } + + // Ensure all the bytes have been read in + if (offset < bytes.length) { + throw new IOException("Could not completely read file " + file.getName()); + } + return bytes; + } +} diff --git a/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java new file mode 100644 index 000000000000..9be1859bb46e --- /dev/null +++ b/jvm/core/src/main/java/ml/dmlc/tvm/rpc/Server.java @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.TVMValue; +import sun.misc.SharedSecrets; + +import java.io.File; +import java.io.FileDescriptor; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * RPC Server. + */ +public class Server { + private static SocketFileDescriptorGetter defaultSocketFdGetter + = new SocketFileDescriptorGetter() { + @Override public int get(Socket socket) { + try { + InputStream is = socket.getInputStream(); + FileDescriptor fd = ((FileInputStream) is).getFD(); + return SharedSecrets.getJavaIOFileDescriptorAccess().get(fd); + } catch (IOException e) { + e.printStackTrace(); + return -1; + } + } + }; + private static final int DEFAULT_THREAD_NUMBER_IN_A_POOL = 20; + + private final Loop serverLoop; + private final ExecutorService threadPool; + + /** + * Start a standalone server. + * @param serverPort Port. + * @param socketFdGetter Method to get system file descriptor of the server socket. + * @throws IOException if failed to bind localhost:port. + */ + public Server(int serverPort, SocketFileDescriptorGetter socketFdGetter) throws IOException { + threadPool = setupThreadPool(); + serverLoop = new ListenLoop(serverPort, threadPool, socketFdGetter); + } + + /** + * Start a standalone server. + * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess + * to get file descriptor for the socket. + * @param serverPort Port. + * @throws IOException if failed to bind localhost:port. + */ + public Server(int serverPort) throws IOException { + this(serverPort, defaultSocketFdGetter); + } + + /** + * Start a server connected to proxy. + * @param proxyHost The proxy server host. + * @param proxyPort The proxy server port. + * @param key The key to identify the server. + * @param socketFdGetter Method to get system file descriptor of the server socket. + */ + public Server(String proxyHost, int proxyPort, String key, + SocketFileDescriptorGetter socketFdGetter) { + threadPool = setupThreadPool(); + serverLoop = new ConnectProxyLoop(proxyHost, proxyPort, key, threadPool, socketFdGetter); + } + + /** + * Start a server connected to proxy. + * Use sun.misc.SharedSecrets.getJavaIOFileDescriptorAccess + * to get file descriptor for the socket. + * @param proxyHost The proxy server host. + * @param proxyPort The proxy server port. + * @param key The key to identify the server. + */ + public Server(String proxyHost, int proxyPort, String key) { + this(proxyHost, proxyPort, key, defaultSocketFdGetter); + } + + private ExecutorService setupThreadPool() { + final String workerThreadNumber = System.getProperty("rpc.server.thread.number"); + final int numThread = (workerThreadNumber == null) + ? DEFAULT_THREAD_NUMBER_IN_A_POOL : Integer.parseInt(workerThreadNumber); + return Executors.newFixedThreadPool(numThread); + } + + /** + * Start the server. + */ + public void start() { + serverLoop.start(); + } + + /** + * Stop the server. + */ + public void terminate() { + serverLoop.interrupt(); + serverLoop.terminate(); + threadPool.shutdown(); + } + + public static interface SocketFileDescriptorGetter { + public int get(Socket socket); + } + + static class ServerLoop implements Runnable { + private final Socket socket; + private final SocketFileDescriptorGetter socketFdGetter; + + ServerLoop(Socket socket, SocketFileDescriptorGetter fdGetter) { + this.socket = socket; + socketFdGetter = fdGetter; + } + + @Override public void run() { + int sockFd = socketFdGetter.get(socket); + if (sockFd != -1) { + File tempDir = null; + try { + tempDir = serverEnv(); + RPC.getApi("_ServerLoop").pushArg(sockFd).invoke(); + System.err.println("Finish serving " + socket.getRemoteSocketAddress().toString()); + } catch (IOException e) { + e.printStackTrace(); + } finally { + if (tempDir != null) { + if (!tempDir.delete()) { + System.err.println( + "[WARN] Couldn't delete temporary directory " + tempDir.getAbsolutePath()); + } + } + closeQuietly(socket); + } + } + } + + private File serverEnv() throws IOException { + // Server environment function return temp dir. + final File tempDir = File.createTempFile("tvm4j_rpc_", ""); + if (!tempDir.delete() || !tempDir.mkdir()) { + throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath()); + } + + Function.register("tvm.contrib.rpc.server.workpath", new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + return tempDir + File.separator + args[0].asString(); + } + }, true); + + Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + String filename = args[0].asString(); + String path = tempDir + File.separator + filename; + System.err.println("Load module from " + path); + return Module.load(path); + } + }, true); + + return tempDir; + } + } + + abstract static class Loop extends Thread { + public abstract void terminate(); + } + + static class ConnectProxyLoop extends Loop { + private volatile boolean running = true; + private final String host; + private final int port; + private final String key; + private final ExecutorService workerPool; + private final SocketFileDescriptorGetter socketFileDescriptorGetter; + private Socket waitingSocket = null; + + public ConnectProxyLoop(String host, int port, String key, + ExecutorService workerPool, + SocketFileDescriptorGetter sockFdGetter) { + this.host = host; + this.port = port; + this.key = "server:" + key; + this.workerPool = workerPool; + socketFileDescriptorGetter = sockFdGetter; + } + + @Override public void terminate() { + running = false; + if (waitingSocket != null) { + try { + waitingSocket.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + } + + @Override public void run() { + while (running) { + try { + Socket socket = new Socket(host, port); + waitingSocket = socket; + InputStream in = socket.getInputStream(); + OutputStream out = socket.getOutputStream(); + out.write(toBytes(RPC.RPC_MAGIC)); + out.write(toBytes(key.length())); + out.write(toBytes(key)); + int magic = wrapBytes(recvAll(in, 4)).getInt(); + final String address = host + ":" + port; + if (magic == RPC.RPC_MAGIC + 1) { + throw new RuntimeException( + String.format("key: %s has already been used in proxy", key)); + } else if (magic == RPC.RPC_MAGIC + 2) { + System.err.println("RPCProxy do not have matching client key " + key); + } else if (magic != RPC.RPC_MAGIC) { + throw new RuntimeException(address + " is not RPC Proxy"); + } + System.err.println("RPCProxy connected to " + address); + + waitingSocket = null; + workerPool.execute(new ServerLoop(socket, socketFileDescriptorGetter)); + } catch (SocketException e) { + // when terminates, this is what we expect, do nothing. + } catch (IOException e) { + e.printStackTrace(); + terminate(); + } + } + } + } + + static class ListenLoop extends Loop { + private final ServerSocket server; + private final ExecutorService workerPool; + private final SocketFileDescriptorGetter socketFileDescriptorGetter; + private volatile boolean running = true; + + public ListenLoop(int serverPort, ExecutorService workerPool, + SocketFileDescriptorGetter sockFdGetter) throws IOException { + this.server = new ServerSocket(serverPort); + this.workerPool = workerPool; + this.socketFileDescriptorGetter = sockFdGetter; + } + + @Override public void terminate() { + this.running = false; + try { + server.close(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + @Override public void run() { + while (running) { + try { + Socket socket = server.accept(); + InputStream in = socket.getInputStream(); + OutputStream out = socket.getOutputStream(); + int magic = wrapBytes(recvAll(in, 4)).getInt(); + if (magic != RPC.RPC_MAGIC) { + closeQuietly(socket); + continue; + } + int keyLen = wrapBytes(recvAll(in, 4)).getInt(); + String key = decodeToStr(recvAll(in, keyLen)); + if (!key.startsWith("client:")) { + out.write(toBytes(RPC.RPC_MAGIC + 2)); + } else { + out.write(toBytes(RPC.RPC_MAGIC)); + } + System.err.println("Connection from " + socket.getRemoteSocketAddress().toString()); + workerPool.execute(new ServerLoop(socket, socketFileDescriptorGetter)); + } catch (SocketException e) { + // when terminates, this is what we expect, do nothing. + } catch (IOException e) { + e.printStackTrace(); + terminate(); + } + } + } + } + + private static byte[] recvAll(final InputStream in, final int numBytes) throws IOException { + byte[] res = new byte[numBytes]; + int numRead = 0; + while (numRead < numBytes) { + int chunk = in.read(res, numRead, Math.min(numBytes - numRead, 1024)); + numRead += chunk; + } + return res; + } + + private static void closeQuietly(Socket socket) { + if (socket != null) { + try { + socket.shutdownInput(); + socket.shutdownOutput(); + socket.close(); + } catch (IOException ioe) { + // close quietly, do nothing. + } + } + } + + private static ByteBuffer wrapBytes(byte[] bytes) { + ByteBuffer bb = ByteBuffer.wrap(bytes); + bb.order(ByteOrder.LITTLE_ENDIAN); + return bb; + } + + private static byte[] toBytes(int number) { + ByteBuffer bb = ByteBuffer.allocate(4); + bb.order(ByteOrder.LITTLE_ENDIAN); + return bb.putInt(number).array(); + } + + private static byte[] toBytes(String str) { + byte[] bytes = new byte[str.length()]; + for (int i = 0; i < str.length(); ++i) { + bytes[i] = (byte) str.charAt(i); + } + return bytes; + } + + private static String decodeToStr(byte[] bytes) { + StringBuilder builder = new StringBuilder(); + for (byte bt : bytes) { + builder.append((char) bt); + } + return builder.toString(); + } +} diff --git a/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java new file mode 100644 index 000000000000..f4dd9eb9cb4e --- /dev/null +++ b/jvm/core/src/test/java/ml/dmlc/tvm/rpc/RPCTest.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ml.dmlc.tvm.rpc; + +import ml.dmlc.tvm.Function; +import ml.dmlc.tvm.Module; +import ml.dmlc.tvm.TVMValue; +import org.junit.Test; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; + +public class RPCTest { + static class RefInt { + public int value; + } + + private static Server startServer(RefInt portRef) { + Server server = null; + int port = 9981; + for (int i = 0; i < 10; ++i) { + try { + server = new Server(port + i); + server.start(); + portRef.value = port + i; + return server; + } catch (IOException e) { + } + } + throw new RuntimeException("Cannot find an available port."); + } + + @Test + public void test_addone() { + if (!Module.enabled("rpc")) { + return; + } + Function.register("test.rpc.addone", new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + return args[0].asLong() + 1L; + } + }); + + RefInt port = new RefInt(); + Server server = null; + try { + server = startServer(port); + RPCSession client = Client.connect("localhost", port.value); + Function func = client.getFunction("test.rpc.addone"); + assertEquals(11L, func.call(10).asLong()); + } finally { + if (server != null) { + server.terminate(); + } + } + } + + @Test + public void test_strcat() { + if (!Module.enabled("rpc")) { + return; + } + Function.register("test.rpc.strcat", new Function.Callback() { + @Override public Object invoke(TVMValue... args) { + return args[0].asString() + ":" + args[1].asLong(); + } + }); + + RefInt port = new RefInt(); + Server server = null; + try { + server = startServer(port); + RPCSession client = Client.connect("localhost", port.value); + Function func = client.getFunction("test.rpc.strcat"); + assertEquals("abc:11", func.call("abc", 11L).asString()); + } finally { + if (server != null) { + server.terminate(); + } + } + } + + @Test + public void test_connect_proxy_server() { + String proxyHost = System.getProperty("test.rpc.proxy.host"); + int proxyPort = Integer.parseInt(System.getProperty("test.rpc.proxy.port")); + + Function.register("test.rpc.proxy.addone", new Function.Callback() { + @Override public Object invoke(TVMValue... tvmValues) { + return tvmValues[0].asLong() + 1L; + } + }); + + Server server = null; + try { + server = new Server(proxyHost, proxyPort, "x1"); + server.start(); + + RPCSession client = Client.connect(proxyHost, proxyPort, "x1"); + Function f1 = client.getFunction("test.rpc.proxy.addone"); + assertEquals(11L, f1.call(10L).asLong()); + } finally { + if (server != null) { + server.terminate(); + } + } + } +} diff --git a/jvm/core/src/test/scripts/test_rpc_proxy_server.py b/jvm/core/src/test/scripts/test_rpc_proxy_server.py new file mode 100644 index 000000000000..3f1f6466c715 --- /dev/null +++ b/jvm/core/src/test/scripts/test_rpc_proxy_server.py @@ -0,0 +1,20 @@ +import time +from tvm.contrib import rpc_proxy + +def start_proxy_server(port, timeout): + prox = rpc_proxy.Proxy("localhost", port=port, port_end=port+1) + if timeout > 0: + import time + time.sleep(timeout) + prox.terminate() + else: + prox.proc.join() + +if __name__ == "__main__": + import sys + if len(sys.argv) < 2: + sys.exit(-1) + port = int(sys.argv[1]) + timeout = 0 if len(sys.argv) == 2 else float(sys.argv[2]) + start_proxy_server(port, timeout) + diff --git a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc index 2162da296ddc..32e4b7ca7da4 100644 --- a/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc +++ b/jvm/native/src/main/native/ml_dmlc_tvm_native_c_api.cc @@ -158,27 +158,33 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_tvm_LibInfo_tvmFuncCall( TVMValue retVal; int retTypeCode; + + // function can be invoked recursively, + // thus we copy the pushed arguments here. + auto argValues = e->tvmFuncArgValues; + auto argTypes = e->tvmFuncArgTypes; + auto pushedStrs = e->tvmFuncArgPushedStrs; + auto pushedBytes = e->tvmFuncArgPushedBytes; + + e->tvmFuncArgPushedStrs.clear(); + e->tvmFuncArgPushedBytes.clear(); + e->tvmFuncArgTypes.clear(); + e->tvmFuncArgValues.clear(); + int ret = TVMFuncCall(reinterpret_cast(jhandle), - &e->tvmFuncArgValues[0], &e->tvmFuncArgTypes[0], numArgs, &retVal, &retTypeCode); + &argValues[0], &argTypes[0], numArgs, &retVal, &retTypeCode); - for (auto iter = e->tvmFuncArgPushedStrs.cbegin(); - iter != e->tvmFuncArgPushedStrs.cend(); iter++) { + for (auto iter = pushedStrs.cbegin(); iter != pushedStrs.cend(); iter++) { env->ReleaseStringUTFChars(iter->first, iter->second); env->DeleteGlobalRef(iter->first); } - for (auto iter = e->tvmFuncArgPushedBytes.cbegin(); - iter != e->tvmFuncArgPushedBytes.cend(); iter++) { + for (auto iter = pushedBytes.cbegin(); iter != pushedBytes.cend(); iter++) { env->ReleaseByteArrayElements(iter->first, reinterpret_cast(const_cast(iter->second->data)), 0); env->DeleteGlobalRef(iter->first); delete iter->second; } - e->tvmFuncArgPushedStrs.clear(); - e->tvmFuncArgPushedBytes.clear(); - e->tvmFuncArgTypes.clear(); - e->tvmFuncArgValues.clear(); - // return TVMValue object to Java jclass refTVMValueCls = env->FindClass("ml/dmlc/tvm/Base$RefTVMValue"); jfieldID refTVMValueFid diff --git a/tests/scripts/task_java_unittest.sh b/tests/scripts/task_java_unittest.sh index bb1f98247f36..9017079237f1 100755 --- a/tests/scripts/task_java_unittest.sh +++ b/tests/scripts/task_java_unittest.sh @@ -10,7 +10,14 @@ TEMP_DIR=$(mktemp -d) python $SCRIPT_DIR/test_add_cpu.py $TEMP_DIR || exit -1 python $SCRIPT_DIR/test_add_gpu.py $TEMP_DIR || exit -1 +# start rpc proxy server +PORT=$(( ( RANDOM % 1000 ) + 9000 )) +python $SCRIPT_DIR/test_rpc_proxy_server.py $PORT 30 & + make jvmpkg || exit -1 -make jvmpkg JVM_TEST_ARGS="-DskipTests=false -Dtest.tempdir=$TEMP_DIR" || exit -1 +make jvmpkg JVM_TEST_ARGS="-DskipTests=false \ + -Dtest.tempdir=$TEMP_DIR \ + -Dtest.rpc.proxy.host=localhost \ + -Dtest.rpc.proxy.port=$PORT" || exit -1 rm -rf $TEMP_DIR