Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RPC] graduate tvm.contrib.rpc -> tvm.rpc #1410

Merged
merged 4 commits into from
Jul 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion apps/android_rpc/tests/android_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import tvm
import os
from tvm.contrib import rpc, util, ndk
from tvm import rpc
from tvm.contrib import util, ndk
import numpy as np

# Set to be address of tvm proxy.
Expand Down
3 changes: 2 additions & 1 deletion apps/ios_rpc/tests/ios_rpc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import tvm
import os
from tvm.contrib import rpc, util, xcode
from tvm import rpc
from tvm.contrib import util, xcode
import numpy as np

# Set to be address of tvm proxy.
Expand Down
4 changes: 2 additions & 2 deletions apps/ios_rpc/tvmrpc/TVMRuntime.mm
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,13 @@ void LaunchSyncServer() {
->ServerLoop();
}

TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.workpath")
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
.set_body([](TVMArgs args, TVMRetValue* rv) {
static RPCEnv env;
*rv = env.GetPath(args[0]);
});

TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.load_module")
TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string name = args[0];
std::string fmt = GetFileFormat(name, "");
Expand Down
18 changes: 9 additions & 9 deletions docs/api/python/rpc.rst
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
tvm.contrib.rpc
---------------
.. automodule:: tvm.contrib.rpc
tvm.rpc
-------
.. automodule:: tvm.rpc

.. autofunction:: tvm.contrib.rpc.connect
.. autofunction:: tvm.contrib.rpc.connect_tracker
.. autofunction:: tvm.rpc.connect
.. autofunction:: tvm.rpc.connect_tracker

.. autoclass:: tvm.contrib.rpc.TrackerSession
.. autoclass:: tvm.rpc.TrackerSession
:members:
:inherited-members:

.. autoclass:: tvm.contrib.rpc.RPCSession
.. autoclass:: tvm.rpc.RPCSession
:members:
:inherited-members:

.. autoclass:: tvm.contrib.rpc.LocalSession
.. autoclass:: tvm.rpc.LocalSession
:members:
:inherited-members:

.. autoclass:: tvm.contrib.rpc.Server
.. autoclass:: tvm.rpc.Server
:members:
:inherited-members:
6 changes: 3 additions & 3 deletions docs/install/docker.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ We can then use the following command to launch a `tvmai/demo-cpu` image.

.. code:: bash

/path/to/tvm/docker/bash.sh tvmai/demo_cpu
/path/to/tvm/docker/bash.sh tvmai/demo-cpu

.. note::
You can find all the prebuilt images in `<https://hub.docker.com/r/tvmai/>`_
You can also change `demo-cpu` to `demo-gpu` to get a CUDA enabled image.
You can find all the prebuilt images in `<https://hub.docker.com/r/tvmai/>`_


This auxiliary script does the following things:
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/ml/dmlc/tvm/rpc/NativeServerLoop.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,13 @@ private static File serverEnv() throws IOException {
throw new IOException("Couldn't create directory " + tempDir.getAbsolutePath());
}

Function.register("tvm.contrib.rpc.server.workpath", new Function.Callback() {
Function.register("tvm.rpc.server.workpath", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
return tempDir + File.separator + args[0].asString();
}
}, true);

Function.register("tvm.contrib.rpc.server.load_module", new Function.Callback() {
Function.register("tvm.rpc.server.load_module", new Function.Callback() {
@Override public Object invoke(TVMValue... args) {
String filename = args[0].asString();
String path = tempDir + File.separator + filename;
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPC.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ protected Map<String, Function> initialValue() {
static Function getApi(String name) {
Function func = apiFuncs.get().get(name);
if (func == null) {
func = Function.getFunction("contrib.rpc." + name);
func = Function.getFunction("rpc." + name);
if (func == null) {
return null;
}
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/ml/dmlc/tvm/rpc/RPCSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public void upload(byte[] data, String target) {
final String funcName = "upload";
Function remoteFunc = remoteFuncs.get(funcName);
if (remoteFunc == null) {
remoteFunc = getFunction("tvm.contrib.rpc.server.upload");
remoteFunc = getFunction("tvm.rpc.server.upload");
remoteFuncs.put(funcName, remoteFunc);
}
remoteFunc.pushArg(target).pushArg(data).invoke();
Expand Down Expand Up @@ -205,7 +205,7 @@ public byte[] download(String path) {
final String name = "download";
Function func = remoteFuncs.get(name);
if (func == null) {
func = getFunction("tvm.contrib.rpc.server.download");
func = getFunction("tvm.rpc.server.download");
remoteFuncs.put(name, func);
}
return func.pushArg(path).invoke().asBytes();
Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/test/scripts/test_rpc_proxy_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from tvm.contrib.rpc import proxy
from tvm.rpc import proxy

def start_proxy_server(port, timeout):
prox = proxy.Proxy("localhost", port=port, port_end=port+1)
Expand Down
3 changes: 2 additions & 1 deletion nnvm/tests/python/compiler/test_param_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np
import nnvm.compiler
import tvm
from tvm.contrib import rpc, util, graph_runtime
from tvm import rpc
from tvm.contrib import util, graph_runtime


def test_save_load():
Expand Down
3 changes: 2 additions & 1 deletion nnvm/tests/python/compiler/test_rpc_exec.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tvm
from tvm.contrib import util, rpc, graph_runtime
from tvm import rpc
from tvm.contrib import util, graph_runtime
import nnvm.symbol as sym
import nnvm.compiler
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/contrib/graph_runtime.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Minimum graph runtime that executes graph containing TVM PackedFunc."""
from .._ffi.base import string_types
from .._ffi.function import get_global_func
from .rpc import base as rpc_base
from ..rpc import base as rpc_base
from .. import ndarray as nd


Expand Down
12 changes: 6 additions & 6 deletions python/tvm/contrib/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
"""measure bandwidth and compute peak"""

import logging

import tvm
from tvm.contrib import rpc, util
from . import util
from .. import rpc

def _convert_to_remote(func, remote):
""" convert module function to remote rpc function"""
Expand Down Expand Up @@ -47,7 +47,7 @@ def measure_bandwidth_sum(total_item, item_per_thread, stride,
host compilation target
ctx: TVMcontext
the context of array
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
remote rpc session
n_times: int
number of runs for taking mean
Expand Down Expand Up @@ -107,7 +107,7 @@ def measure_bandwidth_all_types(total_item, item_per_thread, n_times,
the target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
host compilation target
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
remote rpc session
ctx: TVMcontext
the context of array
Expand Down Expand Up @@ -165,7 +165,7 @@ def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
the target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
host compilation target
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
if it is not None, use remote rpc session
ctx: TVMcontext
the context of array
Expand Down Expand Up @@ -250,7 +250,7 @@ def measure_compute_all_types(total_item, item_per_thread, n_times,
the target and option of the compilation.
target_host : str or :any:`tvm.target.Target`
host compilation target
remote: tvm.contrib.rpc.RPCSession
remote: tvm.rpc.RPCSession
remote rpc session
ctx: TVMcontext
the context of array
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/contrib/rpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Deprecation RPC module"""
# pylint: disable=unused-import
from __future__ import absolute_import as _abs
import warnings
from ..rpc import Server, RPCSession, LocalSession, TrackerSession, connect, connect_tracker

warnings.warn(
"Please use tvm.rpc instead of tvm.conrtib.rpc. tvm.contrib.rpc is going to be removed in 0.5",
DeprecationWarning)
2 changes: 1 addition & 1 deletion python/tvm/exec/query_rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import argparse
import os
from ..contrib import rpc
from .. import rpc

def main():
"""Main funciton"""
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/exec/rpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import multiprocessing
import sys
import os
from ..contrib.rpc.proxy import Proxy
from ..rpc.proxy import Proxy


def find_example_resource():
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/exec/rpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import multiprocessing
import sys
import logging
from ..contrib import rpc
from .. import rpc

def main(args):
"""Main function"""
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/exec/rpc_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import argparse
import multiprocessing
import sys
from ..contrib.rpc.tracker import Tracker
from ..rpc.tracker import Tracker


def main(args):
Expand Down
File renamed without changes.
8 changes: 4 additions & 4 deletions python/tvm/contrib/rpc/base.py → python/tvm/rpc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import random
import logging

from ..._ffi.function import _init_api
from ..._ffi.base import py_str
from .._ffi.function import _init_api
from .._ffi.base import py_str

# Magic header for RPC data plane
RPC_MAGIC = 0xff271
Expand Down Expand Up @@ -158,5 +158,5 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
time.sleep(retry_period)


# Still use tvm.contrib.rpc for the foreign functions
_init_api("tvm.contrib.rpc", "tvm.contrib.rpc.base")
# Still use tvm.rpc for the foreign functions
_init_api("tvm.rpc", "tvm.rpc.base")
14 changes: 7 additions & 7 deletions python/tvm/contrib/rpc/client.py → python/tvm/rpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import time

from . import base
from .. import util
from ..._ffi.base import TVMError
from ..._ffi import function as function
from ..._ffi import ndarray as nd
from ...module import load as _load_module
from ..contrib import util
from .._ffi.base import TVMError
from .._ffi import function as function
from .._ffi import ndarray as nd
from ..module import load as _load_module


class RPCSession(object):
Expand Down Expand Up @@ -82,7 +82,7 @@ def upload(self, data, target=None):

if "upload" not in self._remote_funcs:
self._remote_funcs["upload"] = self.get_function(
"tvm.contrib.rpc.server.upload")
"tvm.rpc.server.upload")
self._remote_funcs["upload"](target, blob)

def download(self, path):
Expand All @@ -100,7 +100,7 @@ def download(self, path):
"""
if "download" not in self._remote_funcs:
self._remote_funcs["download"] = self.get_function(
"tvm.contrib.rpc.server.download")
"tvm.rpc.server.download")
return self._remote_funcs["download"](path)

def load_module(self, path):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from . import base
from .base import TrackerCode
from .server import _server_env
from ..._ffi.base import py_str
from .._ffi.base import py_str


class ForwardHandler(object):
Expand Down
14 changes: 7 additions & 7 deletions python/tvm/contrib/rpc/server.py → python/tvm/rpc/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
import time
import sys

from ..._ffi.function import register_func
from ..._ffi.base import py_str
from ..._ffi.libinfo import find_lib_path
from ...module import load as _load_module
from .. import util
from .._ffi.function import register_func
from .._ffi.base import py_str
from .._ffi.libinfo import find_lib_path
from ..module import load as _load_module
from ..contrib import util
from . import base
from . base import TrackerCode

Expand All @@ -36,11 +36,11 @@ def _server_env(load_library, logger):
logger = logging.getLogger()

# pylint: disable=unused-variable
@register_func("tvm.contrib.rpc.server.workpath")
@register_func("tvm.rpc.server.workpath")
def get_workpath(path):
return temp.relpath(path)

@register_func("tvm.contrib.rpc.server.load_module", override=True)
@register_func("tvm.rpc.server.load_module", override=True)
def load_module(file_name):
"""Load module from remote side."""
path = temp.relpath(file_name)
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
raise ImportError(
"RPCTracker module requires tornado package %s" % error_msg)

from ..._ffi.base import py_str
from .._ffi.base import py_str
from . import base
from .base import RPC_TRACKER_MAGIC, TrackerCode

Expand Down
2 changes: 1 addition & 1 deletion src/runtime/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void Module::Import(Module other) {
if (!std::strcmp((*this)->type_key(), "rpc")) {
static const PackedFunc* fimport_ = nullptr;
if (fimport_ == nullptr) {
fimport_ = runtime::Registry::Get("contrib.rpc._ImportRemoteModule");
fimport_ = runtime::Registry::Get("rpc._ImportRemoteModule");
CHECK(fimport_ != nullptr);
}
(*fimport_)(*this, other);
Expand Down
2 changes: 1 addition & 1 deletion src/runtime/rpc/rpc_event_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ PackedFunc CreateEventDrivenServer(PackedFunc fsend,
});
}

TVM_REGISTER_GLOBAL("contrib.rpc._CreateEventDrivenServer")
TVM_REGISTER_GLOBAL("rpc._CreateEventDrivenServer")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CreateEventDrivenServer(args[0], args[1], args[2]);
});
Expand Down
Loading