In [1]:
import os
import json
import tensorflow as tf
# Raise the TensorFlow logger threshold to ERROR so lower-severity messages
# (info, warnings) do not clutter the notebook output.
tf.get_logger().setLevel('ERROR')

# Cluster setup

In [2]:
# Static description of the cluster (job name -> list of host:port addresses)
# together with the role this particular machine plays in it.
# To run this notebook on a parameter server instead, change the task entry
# to dict(type='ps', index=0).
tf_config = dict(
    cluster=dict(
        worker=['192.168.1.1:12345', '192.168.1.2:12345'],
        ps=['192.168.1.3:12345', '192.168.1.4:12345'],
        chief=['192.168.1.5:12345'],
    ),
    task=dict(type='worker', index=0),
)

# Publish the configuration as JSON through the TF_CONFIG environment
# variable; plain assignment replaces any value left over from an earlier run.
os.environ['TF_CONFIG'] = json.dumps(tf_config)

In [3]:
# Build a resolver from the TF_CONFIG environment variable, then prepare this
# machine's devices according to the role (task type) the resolver reports.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

# This notebook only serves parameter servers and workers; any other role
# (for example the chief) is rejected up front.
if cluster_resolver.task_type not in ('ps', 'worker'):
    raise SystemError('Machine in wrong role')

if cluster_resolver.task_type == 'ps':
    # A CUDA_VISIBLE_DEVICES value of -1 is the conventional way to hide all
    # GPUs from the process, so the parameter server does not use them.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    print('Parameter server detected')
else:
    # Workers must have at least one GPU available.
    gpu_devices = tf.config.list_physical_devices('GPU')
    if not gpu_devices:
        raise SystemError('GPU device not found')
    # Turn on memory growth for every GPU that was found.
    for gpu in gpu_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
    print('Worker detected with GPU(s):', gpu_devices)

Worker detected with GPU(s): [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [None]:
# gRPC setting that lets worker and parameter-server failures be reported
# to the coordinator.
os.environ['GRPC_FAIL_FAST'] = 'use_caller'

# Launch the TensorFlow server for this machine's job and task index, using
# the cluster layout from the resolver. When the resolver does not specify an
# RPC layer, plain 'grpc' is used.
server = tf.distribute.Server(
    cluster_resolver.cluster_spec(),
    job_name=cluster_resolver.task_type,
    task_index=cluster_resolver.task_id,
    protocol=cluster_resolver.rpc_layer or 'grpc',
    start=True,
)

# Wait on the running server; this cell is expected to keep running for as
# long as the machine serves the cluster.
server.join()