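# NOTE: web-page scrape residue ("Skip to content" / "Permalink" /
# "Cannot retrieve contributors at this time") preceded the module docstring.
# It is not part of the tracker module and is kept here only as a comment.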
"""
Copied from: https://raw.githubusercontent.com/dmlc/dmlc-core/master/tracker/dmlc_tracker/tracker.py
Git version: https://github.com/dmlc/dmlc-core/tree/c56d3c7ffb1ea7a5ba1da08910b64fcbd38f3308
This has been slightly modified and is used to override the same file in XGBoost dmlc-core until
upstream contribution is approved and merged.
* Add timeout on sockets
* Reduce the thread.join() timeout to 1
* Add logs of debug statements to understand the node communication.
Tracker script for DMLC
Implements the tracker control protocol
- start dmlc jobs
- start ps scheduler and rabit tracker
- help nodes to establish links with each other
Tianqi Chen
"""
# pylint: disable=invalid-name, missing-docstring, too-many-arguments, too-many-locals
# pylint: disable=too-many-branches, too-many-statements
from __future__ import absolute_import
import os
import sys
import socket
import struct
import subprocess
import argparse
import time
import logging
from threading import Thread
# Module-level logger used by the tracker classes and helpers in this file.
logger = logging.getLogger('RabitTracker')
class ExSocket(object):
    """
    Extension of socket to handle recv and send of special data

    Wraps a connected socket and adds helpers for the tracker wire format:
    native-endian 4-byte integers (struct format ``@i``) and strings sent
    as an integer byte-length prefix followed by the encoded bytes.
    """

    def __init__(self, sock, timeout=30.0):
        """Wrap *sock* and apply *timeout* (seconds) to all its operations."""
        self.sock = sock
        self.sock.settimeout(timeout)

    def recvall(self, nbytes):
        """Read exactly *nbytes* bytes and return them as ``bytes``.

        Raises ``socket.timeout`` when the peer closes the connection before
        all bytes arrive (``recv`` returning an empty chunk), in addition to
        the timeout the underlying socket may raise by itself.
        """
        res = []
        nread = 0
        while nread < nbytes:
            # Read in pieces of at most 1024 bytes until the quota is met.
            chunk = self.sock.recv(min(nbytes - nread, 1024))
            if not chunk:
                # An empty read means the peer went away; it is reported with
                # the same exception type as a timeout.
                raise socket.timeout()
            nread += len(chunk)
            res.append(chunk)
        return b''.join(res)

    def recvint(self):
        """Receive one native 4-byte signed integer."""
        return struct.unpack('@i', self.recvall(4))[0]

    def sendint(self, n):
        """Send *n* as one native 4-byte signed integer."""
        self.sock.sendall(struct.pack('@i', n))

    def sendstr(self, s):
        """Send the string *s* as a byte-length prefix followed by its bytes.

        The prefix must be the length of the *encoded* payload, because
        ``recvstr`` reads exactly that many bytes; using ``len(s)`` would
        truncate any string containing multi-byte characters.
        """
        payload = s.encode()
        self.sendint(len(payload))
        self.sock.sendall(payload)

    def recvstr(self):
        """Receive a length-prefixed string and return it decoded."""
        slen = self.recvint()
        return self.recvall(slen).decode()
# Magic number used to verify existence of data: SlaveEntry.__init__ expects
# to receive this value first on a new connection and sends it back.
kMagic = 0xff99
def get_some_ip(host):
    """Resolve *host* and return the address string of the first result."""
    first_result = socket.getaddrinfo(host, None)[0]
    # Each getaddrinfo entry is (family, type, proto, canonname, sockaddr);
    # the address string is the first element of sockaddr.
    sockaddr = first_result[4]
    return sockaddr[0]
def get_family(addr):
    """Return the address family of the first getaddrinfo result for *addr*."""
    family, _socktype, _proto, _canonname, _sockaddr = socket.getaddrinfo(addr, None)[0]
    return family
class SlaveEntry(object):
    """Tracker-side record of one worker connection.

    Construction performs the opening handshake on the accepted socket:
    the magic number is exchanged, then the worker's rank, world size,
    job id and command are read, in that order.
    """

    def __init__(self, sock, s_addr):
        conn = ExSocket(sock)
        self.sock = conn
        self.host = get_some_ip(s_addr[0])

        # The worker speaks first; echo the magic number back once verified.
        received_magic = conn.recvint()
        assert received_magic == kMagic, \
            'invalid magic number=%d from %s' % (received_magic, self.host)
        conn.sendint(kMagic)

        self.rank = conn.recvint()
        logger.debug("{}:{} Rank {}".format(self.sock, s_addr, self.rank))
        self.world_size = conn.recvint()
        logger.debug("{}:{} world_size {}".format(self.sock, s_addr, self.world_size))
        self.jobid = conn.recvstr()
        logger.debug("{}:{} jobid {}".format(self.sock, s_addr, self.jobid))
        self.cmd = conn.recvstr()
        logger.debug("{}:{} cmd {}".format(self.sock, s_addr, self.cmd))

        # Filled in by assign_rank().
        self.wait_accept = 0
        self.port = None

    def decide_rank(self, job_map):
        """Return the rank this worker should get, or -1 if none is known.

        A non-negative rank reported by the worker wins; otherwise the rank
        previously recorded for its job id in *job_map* is used.
        """
        if self.rank >= 0:
            return self.rank
        has_known_job = self.jobid != 'NULL' and self.jobid in job_map
        return job_map[self.jobid] if has_known_job else -1

    def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
        """Send *rank* and its topology to the worker and broker its links.

        *wait_conn* maps rank -> SlaveEntry for workers still waiting to be
        connected to; it is updated in place. Returns the list of ranks
        removed from *wait_conn* because their ``wait_accept`` reached zero.
        """
        logger.debug("Assigning rank: {}".format(rank))
        self.rank = rank
        neighbours = set(tree_map[rank])
        ring_prev, ring_next = ring_map[rank]

        self.sock.sendint(rank)
        # parent rank
        parent = parent_map[rank]
        self.sock.sendint(parent)
        logger.debug("Parent rank: {}".format(parent))
        # world size
        world = len(tree_map)
        self.sock.sendint(world)
        logger.debug("World view: {}".format(world))
        # tree neighbours: count first, then each rank
        self.sock.sendint(len(neighbours))
        for link in neighbours:
            self.sock.sendint(link)
            logger.debug("Link: {}".format(link))

        # ring links: previous then next; -1 stands for "no such link"
        for label, link in (("Rprev: {}", ring_prev), ("Rnext: {}", ring_next)):
            if link == -1 or link == rank:
                self.sock.sendint(-1)
            else:
                neighbours.add(link)
                self.sock.sendint(link)
                logger.debug(label.format(link))

        # Repeat the negotiation round until the worker reports zero errors.
        while True:
            n_linked = self.sock.recvint()
            linked = {self.sock.recvint() for _ in range(n_linked)}
            logger.debug("In goodset: {}".format(linked))
            assert linked.issubset(neighbours)

            missing = neighbours - linked
            # Missing peers that are already waiting can be dialled now.
            dialable = [peer for peer in missing if peer in wait_conn]
            logger.debug("In conset: {}".format(dialable))
            n_to_accept = len(missing) - len(dialable)

            self.sock.sendint(len(dialable))
            self.sock.sendint(n_to_accept)
            for peer in dialable:
                target = wait_conn[peer]
                self.sock.sendstr(target.host)
                self.sock.sendint(target.port)
                self.sock.sendint(peer)

            if self.sock.recvint() != 0:
                continue

            self.port = self.sock.recvint()
            # All connections were set up: update the peers we dialled.
            finished = []
            for peer in dialable:
                wait_conn[peer].wait_accept -= 1
                if wait_conn[peer].wait_accept == 0:
                    finished.append(peer)
            for peer in finished:
                wait_conn.pop(peer, None)
            self.wait_accept = n_to_accept
            return finished
class RabitTracker(object):
    """
    tracker for rabit

    Binds a listening socket on the first free port in ``[port, port_end)``,
    then (via ``start``) accepts worker connections on a daemon thread,
    assigns ranks and hands every worker its tree/ring neighbours.
    """

    def __init__(self, hostIP, nslave, port=9091, port_end=9999):
        """Bind and listen on *hostIP*.

        Raises ``socket.error`` if no port in ``[port, port_end)`` can be
        bound, or if binding fails for a reason other than address-in-use.
        """
        sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
        try:
            for candidate in range(port, port_end):
                try:
                    sock.bind((hostIP, candidate))
                except socket.error as e:
                    # 98 and 48 are the EADDRINUSE values on Linux and
                    # macOS/BSD respectively: try the next port.
                    if e.errno in [98, 48]:
                        continue
                    raise
                self.port = candidate
                break
            else:
                # Without this, self.port would stay unset and the failure
                # would surface later as an unrelated AttributeError.
                raise socket.error(
                    'no free port for the tracker in range [%d, %d)' % (port, port_end))
            sock.listen()
        except Exception:
            # Do not leak the socket when construction fails.
            sock.close()
            raise
        self.sock = sock
        self.hostIP = hostIP
        self.thread = None
        self.start_time = None
        self.end_time = None
        self.nslave = nslave
        logger.info('start listen on %s:%d', hostIP, self.port)

    def __del__(self):
        # __init__ may have raised before self.sock was assigned.
        sock = getattr(self, 'sock', None)
        if sock is not None:
            sock.close()

    @staticmethod
    def get_neighbor(rank, nslave):
        """Return the tree neighbours (parent first, then children) of *rank*.

        Ranks are laid out as a binary heap: using 1-based positions, the
        parent of ``p`` is ``p // 2`` and its children are ``2p`` and
        ``2p + 1``; the result is converted back to 0-based ranks.
        """
        rank = rank + 1
        ret = []
        if rank > 1:
            ret.append(rank // 2 - 1)
        if rank * 2 - 1 < nslave:
            ret.append(rank * 2 - 1)
        if rank * 2 < nslave:
            ret.append(rank * 2)
        return ret

    def slave_envs(self):
        """
        get enviroment variables for slaves
        can be passed in as args or envs
        """
        return {'DMLC_TRACKER_URI': self.hostIP,
                'DMLC_TRACKER_PORT': self.port}

    def get_tree(self, nslave):
        """Return ``(tree_map, parent_map)`` for *nslave* ranks.

        ``tree_map[r]`` lists the tree neighbours of rank ``r`` and
        ``parent_map[r]`` is its parent rank (-1 for rank 0).
        """
        tree_map = {}
        parent_map = {}
        for r in range(nslave):
            tree_map[r] = self.get_neighbor(r, nslave)
            parent_map[r] = (r + 1) // 2 - 1
        return tree_map, parent_map

    def find_share_ring(self, tree_map, parent_map, r):
        """
        get a ring structure that tends to share nodes with the tree
        return a list starting from r
        """
        nset = set(tree_map[r])
        # children of r = neighbours minus the parent
        cset = nset - set([parent_map[r]])
        if not cset:
            return [r]
        rlst = [r]
        cnt = 0
        for v in cset:
            vlst = self.find_share_ring(tree_map, parent_map, v)
            cnt += 1
            # The last child's sub-list is reversed before being appended.
            if cnt == len(cset):
                vlst.reverse()
            rlst += vlst
        return rlst

    def get_ring(self, tree_map, parent_map):
        """
        get a ring connection used to recover local data

        Returns ``ring_map`` where ``ring_map[r] == (prev_rank, next_rank)``.
        """
        assert parent_map[0] == -1
        rlst = self.find_share_ring(tree_map, parent_map, 0)
        assert len(rlst) == len(tree_map)
        ring_map = {}
        nslave = len(tree_map)
        for r in range(nslave):
            rprev = (r + nslave - 1) % nslave
            rnext = (r + 1) % nslave
            ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
        return ring_map

    def get_link_map(self, nslave):
        """
        get the link map, this is a bit hacky, call for better algorithm
        to place similar nodes together

        Returns ``(tree_map, parent_map, ring_map)`` with ranks renumbered
        so that following the ring from rank 0 visits 0, 1, 2, ...
        """
        tree_map, parent_map = self.get_tree(nslave)
        ring_map = self.get_ring(tree_map, parent_map)
        # rmap: old rank -> new rank, numbered in ring order from rank 0
        rmap = {0: 0}
        k = 0
        for i in range(nslave - 1):
            k = ring_map[k][1]
            rmap[k] = i + 1

        ring_map_ = {}
        tree_map_ = {}
        parent_map_ = {}
        for k, v in ring_map.items():
            ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
        for k, v in tree_map.items():
            tree_map_[rmap[k]] = [rmap[x] for x in v]
        for k, v in parent_map.items():
            if k != 0:
                parent_map_[rmap[k]] = rmap[v]
            else:
                parent_map_[rmap[k]] = -1
        return tree_map_, parent_map_, ring_map_

    def accept_slaves(self, nslave):
        """Serve worker connections until *nslave* workers have shut down.

        Handles the 'print', 'shutdown', 'start' and 'recover' commands.
        Ranks for workers without one are assigned in a single batch once
        as many workers are pending as there are unassigned ranks.
        """
        # set of nodes that finishs the job
        shutdown = {}
        # set of nodes that is waiting for connections
        wait_conn = {}
        # maps job id to rank
        job_map = {}
        # list of workers that is pending to be assigned rank
        pending = []
        # lazy initialize tree_map
        tree_map = None
        logger.debug("Looking for {} connections.".format(nslave))
        while len(shutdown) != nslave:
            fd, s_addr = self.sock.accept()
            logger.debug("Accepted connection")
            logger.debug(fd)
            logger.debug(s_addr)
            try:
                s = SlaveEntry(fd, s_addr)
            except socket.timeout:
                logger.info("No data received from connection {}. Closing.".format(s_addr))
                # Actually close it, as the message says, instead of
                # leaving the descriptor to the garbage collector.
                fd.close()
                continue
            logger.debug("Slave command is: {}".format(s.cmd))
            if s.cmd == 'print':
                msg = s.sock.recvstr()
                logger.debug("PRINTING FROM {}:{}".format(fd, s_addr))
                logger.info(msg.strip())
                continue
            if s.cmd == 'shutdown':
                assert s.rank >= 0 and s.rank not in shutdown
                assert s.rank not in wait_conn
                shutdown[s.rank] = s
                logger.debug('Recieve %s signal from %d', s.cmd, s.rank)
                continue
            assert s.cmd == 'start' or s.cmd == 'recover'
            # lazily initialize the slaves
            if tree_map is None:
                assert s.cmd == 'start'
                # The first worker may override the expected world size.
                if s.world_size > 0:
                    nslave = s.world_size
                tree_map, parent_map, ring_map = self.get_link_map(nslave)
                # set of nodes that is pending for getting up
                todo_nodes = list(range(nslave))
            else:
                assert s.world_size == -1 or s.world_size == nslave
            if s.cmd == 'recover':
                assert s.rank >= 0

            rank = s.decide_rank(job_map)
            # batch assignment of ranks
            if rank == -1:
                assert len(todo_nodes) != 0
                pending.append(s)
                logger.debug("Pending slaves: {}".format(pending))
                logger.debug("TO-do slaves: {}".format(todo_nodes))
                if len(pending) == len(todo_nodes):
                    # Sort by host before handing out the remaining ranks.
                    pending.sort(key=lambda x: x.host)
                    for entry in pending:
                        rank = todo_nodes.pop(0)
                        if entry.jobid != 'NULL':
                            job_map[entry.jobid] = rank
                        entry.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                        if entry.wait_accept > 0:
                            wait_conn[rank] = entry
                        logger.info('Recieve %s signal from %s; assign rank %d',
                                    entry.cmd, entry.host, entry.rank)
                if len(todo_nodes) == 0:
                    logger.info('@tracker All of %d nodes getting started', nslave)
                    self.start_time = time.time()
            else:
                s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                logger.debug('Recieve %s signal from %d', s.cmd, s.rank)
                if s.wait_accept > 0:
                    wait_conn[rank] = s
        logger.info('@tracker All nodes finishes job')
        self.end_time = time.time()
        # start_time is only set once every rank has been handed out; skip
        # the report rather than fail on "float - None" if that never happened.
        if self.start_time is not None:
            logger.info('@tracker %s secs between node start and job finish',
                        str(self.end_time - self.start_time))

    def start(self, nslave):
        """Run ``accept_slaves(nslave)`` on a daemon thread."""
        def run():
            self.accept_slaves(nslave)
        self.thread = Thread(target=run, args=())
        # The ``daemon`` attribute replaces the deprecated setDaemon().
        self.thread.daemon = True
        self.thread.start()

    def join(self):
        """Block until the tracker thread ends, waking up every second."""
        # Thread.isAlive() was removed in Python 3.9; is_alive() is the
        # supported spelling.
        while self.thread.is_alive():
            self.thread.join(1)

    def alive(self):
        """Return True while the tracker thread is running."""
        return self.thread.is_alive()
class PSTracker(object):
    """
    Tracker module for PS

    When constructed with a command, picks a free port and runs that
    command as the scheduler in a shell on a daemon thread. With
    ``cmd=None`` the tracker is inert.
    """

    def __init__(self, hostIP, cmd, port=9091, port_end=9999, envs=None):
        """
        Starts the PS scheduler

        Raises ``socket.error`` if no port in ``[port, port_end)`` can be
        bound. *envs* entries are stringified and added to the scheduler's
        environment.
        """
        self.cmd = cmd
        if cmd is None:
            return
        envs = {} if envs is None else envs
        self.hostIP = hostIP
        # Probe for a free port: bind, remember the number, release it.
        sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
        try:
            for candidate in range(port, port_end):
                try:
                    sock.bind(('', candidate))
                except socket.error:
                    continue
                self.port = candidate
                break
            else:
                # Otherwise self.port would be unset and the failure would
                # surface below as an unrelated AttributeError.
                raise socket.error(
                    'no free port for the PS scheduler in range [%d, %d)' % (port, port_end))
        finally:
            # The probe socket is closed on success and on failure alike.
            sock.close()
        env = os.environ.copy()
        env['DMLC_ROLE'] = 'scheduler'
        env['DMLC_PS_ROOT_URI'] = str(self.hostIP)
        env['DMLC_PS_ROOT_PORT'] = str(self.port)
        for k, v in envs.items():
            env[k] = str(v)
        # NOTE(review): the command runs through a shell (shell=True);
        # it must come from a trusted source.
        self.thread = Thread(
            target=(lambda: subprocess.check_call(self.cmd, env=env, shell=True, executable='/bin/bash')), args=())
        # The ``daemon`` attribute replaces the deprecated setDaemon().
        self.thread.daemon = True
        self.thread.start()

    def join(self):
        """Block until the scheduler thread ends; no-op without a command."""
        if self.cmd is not None:
            # Thread.isAlive() was removed in Python 3.9.
            while self.thread.is_alive():
                self.thread.join(100)

    def slave_envs(self):
        """Return the environment variables workers need to find the scheduler."""
        if self.cmd is None:
            return {}
        return {'DMLC_PS_ROOT_URI': self.hostIP,
                'DMLC_PS_ROOT_PORT': self.port}

    def alive(self):
        """Return True while the scheduler thread runs; False without a command."""
        if self.cmd is not None:
            return self.thread.is_alive()
        return False
def get_host_ip(hostIP=None):
    """Return the host address the tracker should advertise.

    *hostIP* may be:
      - ``None`` or ``'auto'``: treated as ``'ip'``;
      - ``'dns'``: the fully qualified domain name is returned;
      - ``'ip'``: the host name is resolved to an address; if that yields
        a ``127.*`` address, the local address of a UDP socket connected
        towards 10.255.255.255 is used instead;
      - anything else: returned unchanged.
    """
    if hostIP is None or hostIP == 'auto':
        hostIP = 'ip'

    if hostIP == 'dns':
        hostIP = socket.getfqdn()
    elif hostIP == 'ip':
        try:
            hostIP = socket.gethostbyname(socket.getfqdn())
        except socket.gaierror:
            logger.warning('gethostbyname(socket.getfqdn()) failed... trying on hostname()')
            hostIP = socket.gethostbyname(socket.gethostname())
        if hostIP.startswith("127."):
            probe = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                # doesn't have to be reachable
                probe.connect(('10.255.255.255', 1))
                hostIP = probe.getsockname()[0]
            finally:
                # The probe socket was previously never closed.
                probe.close()
    return hostIP
def submit(nworker, nserver, fun_submit, hostIP='auto', pscmd=None):
    """Start the appropriate tracker, launch the job and wait for it.

    With ``nserver == 0`` a RabitTracker is started; otherwise a PSTracker
    running *pscmd*. ``fun_submit(nworker, nserver, envs)`` is called only
    if the tracker reports itself alive, and the tracker is then joined.
    """
    use_rabit = nserver == 0
    if use_rabit:
        pscmd = None

    envs = {'DMLC_NUM_WORKER' : nworker,
            'DMLC_NUM_SERVER' : nserver}
    hostIP = get_host_ip(hostIP)

    if use_rabit:
        tracker = RabitTracker(hostIP=hostIP, nslave=nworker)
        envs.update(tracker.slave_envs())
        tracker.start(nworker)
    else:
        tracker = PSTracker(hostIP=hostIP, cmd=pscmd, envs=envs)
        envs.update(tracker.slave_envs())

    if tracker.alive():
        fun_submit(nworker, nserver, envs)
    tracker.join()
def start_rabit_tracker(args):
    """Standalone function to start rabit tracker.

    Starts the tracker, prints the worker environment as ``KEY=VALUE``
    lines between two marker lines on stdout, then waits for the tracker.

    Parameters
    ----------
    args: arguments to start the rabit tracker.
    """
    tracker = RabitTracker(hostIP=get_host_ip(args.host_ip), nslave=args.num_workers)
    envs = {'DMLC_NUM_WORKER' : args.num_workers,
            'DMLC_NUM_SERVER' : args.num_servers}
    envs.update(tracker.slave_envs())
    tracker.start(args.num_workers)

    # simply write configuration to stdout
    out = sys.stdout
    out.write('DMLC_TRACKER_ENV_START\n')
    out.writelines('%s=%s\n' % (key, str(value)) for key, value in envs.items())
    out.write('DMLC_TRACKER_ENV_END\n')
    out.flush()

    tracker.join()
def main():
    """Main function if tracker is executed in standalone mode."""
    parser = argparse.ArgumentParser(description='Rabit Tracker start.')
    parser.add_argument('--num-workers', required=True, type=int,
                        help='Number of worker proccess to be launched.')
    parser.add_argument('--num-servers', default=0, type=int,
                        help='Number of server process to be launched. Only used in PS jobs.')
    parser.add_argument('--host-ip', default=None, type=str,
                        help=('Host IP addressed, this is only needed ' +
                              'if the host IP cannot be automatically guessed.'))
    parser.add_argument('--log-level', default='INFO', type=str,
                        choices=['INFO', 'DEBUG'],
                        help='Logging level of the logger.')
    args = parser.parse_args()

    # Translate the textual level into the logging constant.
    known_levels = {'INFO': logging.INFO, 'DEBUG': logging.DEBUG}
    if args.log_level not in known_levels:
        raise RuntimeError("Unknown logging level %s" % args.log_level)
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s',
                        level=known_levels[args.log_level])

    # Only the rabit tracker can be run standalone.
    if args.num_servers != 0:
        raise RuntimeError("Do not yet support start ps tracker in standalone mode.")
    start_rabit_tracker(args)
# Run the standalone tracker only when this file is executed as a script.
if __name__ == "__main__":
    main()