-
Notifications
You must be signed in to change notification settings - Fork 3.4k
/
rpc_server.py
146 lines (132 loc) · 6.02 KB
/
rpc_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=redefined-outer-name, invalid-name
"""Start an RPC server"""
from __future__ import absolute_import
import argparse
import ast
import json
import multiprocessing
import sys
import logging
import tvm
from tvm import micro
from .. import rpc
def main(args):
    """Start the RPC server described by the parsed command-line options.

    The call blocks until the server process exits.

    Parameters
    ----------
    args : argparse.Namespace
        parsed args from command-line invocation

    Raises
    ------
    RuntimeError
        If a tracker address is supplied without a key.
    """
    tracker_addr = None
    if args.tracker:
        # Split on the last colon only; everything before it is the host part.
        tracker_host, tracker_port = args.tracker.rsplit(":", 1)
        tracker_addr = (tracker_host, int(tracker_port))
        if not args.key:
            raise RuntimeError(
                'Need key to present type of resource when tracker is available')

    # MicroTVM hooks must be registered before the server is created.
    wants_micro_tvm = args.utvm_dev_config or args.utvm_dev_id
    if wants_micro_tvm:
        init_utvm(args)

    server = rpc.Server(
        args.host,
        args.port,
        args.port_end,
        key=args.key,
        tracker_addr=tracker_addr,
        load_library=args.load_library,
        custom_addr=args.custom_addr,
        silent=args.silent,
    )
    server.proc.join()
def init_utvm(args):
    """MicroTVM-specific RPC initialization

    Obtains a device configuration, either by loading the JSON file named by
    ``--utvm-dev-config`` or by calling the ``generate_config`` function of
    the device module named by ``--utvm-dev-id`` with the arguments from
    ``--utvm-dev-config-args``. It then overrides the
    ``tvm.rpc.server.start`` and ``tvm.rpc.server.shutdown`` functions so
    that a ``micro.Session`` built from that configuration is entered on
    start and exited on shutdown.

    Parameters
    ----------
    args : argparse.Namespace
        parsed args from command-line invocation

    Raises
    ------
    RuntimeError
        If both ``--utvm-dev-config`` and ``--utvm-dev-id`` are given, or if
        the configuration has to be generated but ``--utvm-dev-config-args``
        is missing.
    """
    if args.utvm_dev_config and args.utvm_dev_id:
        raise RuntimeError('only one of --utvm-dev-config and --utvm-dev-id allowed')

    if args.utvm_dev_config:
        with open(args.utvm_dev_config, 'r') as dev_conf_file:
            dev_config = json.load(dev_conf_file)
    else:
        if args.utvm_dev_config_args is None:
            # Without this check literal_eval(None) fails with an opaque
            # "malformed node or string" error.
            raise RuntimeError(
                '--utvm-dev-config-args is required when --utvm-dev-id is specified')
        # The literal is unpacked as positional arguments, so it is expected
        # to evaluate to a sequence such as a tuple or list.
        dev_config_args = ast.literal_eval(args.utvm_dev_config_args)
        generate_config_func = micro.device.get_device_funcs(args.utvm_dev_id)['generate_config']
        dev_config = generate_config_func(*dev_config_args)

    if args.utvm_dev_config or args.utvm_dev_id:
        # The session is created in one hook and torn down in the other, so it
        # has to live in the enclosing scope. A local variable inside
        # server_start is not visible from server_shutdown. A dict is used
        # instead of ``nonlocal`` to stay valid on Python 2 as well.
        active = {'session': None}

        # add MicroTVM overrides
        @tvm.register_func('tvm.rpc.server.start', override=True)
        def server_start():
            # pylint: disable=unused-variable
            session = micro.Session(dev_config)
            session._enter()
            # Only remember the session once it has been entered successfully.
            active['session'] = session

        @tvm.register_func('tvm.rpc.server.shutdown', override=True)
        def server_shutdown():
            # pylint: disable=unused-variable
            session = active['session']
            if session is not None:
                active['session'] = None
                session._exit()
# Command-line entry point: parse the options, choose the multiprocessing
# start method, then hand over to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--host', type=str, default="0.0.0.0",
                        help='the hostname of the server')
    parser.add_argument('--port', type=int, default=9090,
                        help='The port of the RPC')
    parser.add_argument('--port-end', type=int, default=9199,
                        help='The end search port of the RPC')
    parser.add_argument('--tracker', type=str,
                        help=("The address of RPC tracker in host:port format. "
                              "e.g. (10.77.1.234:9190)"))
    parser.add_argument('--key', type=str, default="",
                        help="The key used to identify the device type in tracker.")
    parser.add_argument('--silent', action='store_true',
                        help="Whether run in silent mode.")
    parser.add_argument('--load-library', type=str,
                        help="Additional library to load")
    # Adjacent string literals are used instead of backslash continuations so
    # that source indentation does not end up inside the help text.
    parser.add_argument('--no-fork', dest='fork', action='store_false',
                        help=("Use spawn mode to avoid fork. This option "
                              "is able to avoid potential fork problems with Metal, OpenCL "
                              "and ROCM compilers."))
    parser.add_argument('--custom-addr', type=str,
                        help="Custom IP Address to Report to RPC Tracker")
    # --utvm-dev-config and --utvm-dev-id are the two alternative ways of
    # obtaining a MicroTVM device configuration; init_utvm rejects the
    # combination of both.
    parser.add_argument('--utvm-dev-config', type=str,
                        help=('JSON config file for the target device (if using MicroTVM). '
                              'This file should contain serialized output similar to that returned '
                              "from the device module's generate_config. Can't be specified when "
                              '--utvm-dev-id or --utvm-dev-config-args is specified.'))
    parser.add_argument('--utvm-dev-config-args', type=str,
                        help=("Arguments to the device module's generate_config function. "
                              'Must be a python literal parseable by literal_eval. If specified, '
                              "the device configuration is generated using the device module's "
                              "generate_config. Only takes effect together with --utvm-dev-id. "
                              "Can't be specified when --utvm-dev-config is specified."))
    parser.add_argument('--utvm-dev-id', type=str,
                        help=('Unique ID for the target device (if using MicroTVM). Should '
                              'match the name of a module underneath tvm.micro.device.'))
    parser.set_defaults(fork=True)
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if not args.fork:
        # The spawn start method is only offered by Python 3's multiprocessing.
        if sys.version_info[0] < 3:
            raise RuntimeError(
                "Python3 is required for spawn mode."
            )
        multiprocessing.set_start_method('spawn')
    else:
        if not args.silent:
            logging.info("If you are running ROCM/Metal, fork will cause "
                         "compiler internal error. Try to launch with arg ```--no-fork```")
    main(args)