In [None]:
import os
import json
import multiprocessing
import tensorflow as tf
# Only let TensorFlow's logger emit messages at ERROR level or above.
tf.get_logger().setLevel('ERROR')

In [None]:
# Ask TensorFlow which physical GPUs it can see; stop right away if none.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
    raise SystemError('GPU device not found')

# Enable memory growth on every GPU that was found.
for device in gpu_devices:
    tf.config.experimental.set_memory_growth(device, True)

# Last expression of the cell: show the detected devices as the cell output.
gpu_devices

# Cluster setup

In [None]:
# Cluster layout plus this machine's task within it. Key order is kept
# stable because the dict is serialized as-is into TF_CONFIG below.
tf_config = dict(
    cluster=dict(
        worker=['192.168.1.1:12345', '192.168.1.2:12345'],
        ps=['192.168.1.3:12345', '192.168.1.4:12345'],
        chief=['192.168.1.5:12345'],
    ),
    task=dict(type='worker', index=0),
)

# Publish the settings through the process environment. Assigning TF_CONFIG
# replaces any value that was already present, so no separate removal is
# needed. GRPC_FAIL_FAST='use_caller' allows reporting worker and ps failure
# to the coordinator.
os.environ.update({
    'TF_CONFIG': json.dumps(tf_config),
    'GRPC_FAIL_FAST': 'use_caller',
})

In [None]:
# Session configuration for this worker. Workers need some inter-op threads
# to work properly, so when the machine has fewer CPUs than
# (number of workers + 1) the inter-op thread count is set explicitly;
# otherwise the ConfigProto is left at its defaults.
worker_config = tf.compat.v1.ConfigProto()
num_workers = len(tf_config['cluster']['worker'])
min_inter_op_threads = num_workers + 1
if min_inter_op_threads > multiprocessing.cpu_count():
    worker_config.inter_op_parallelism_threads = min_inter_op_threads

In [None]:
# Create the TFConfigClusterResolver and refuse to go any further unless
# it reports this machine's task type as "worker".
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
is_worker = cluster_resolver.task_type == "worker"
if not is_worker:
    raise SystemError('Machine is in wrong role')

In [None]:
# Gather the server arguments from the resolver, in the same order the
# values are read when passed inline: cluster spec first, then the task
# identity, then the RPC protocol (falling back to 'grpc' when the
# resolver does not specify one).
cluster_spec = cluster_resolver.cluster_spec()
server_options = {
    'job_name': cluster_resolver.task_type,
    'task_index': cluster_resolver.task_id,
    'config': worker_config,
    'protocol': cluster_resolver.rpc_layer or 'grpc',
    'start': True,
}

# Construct the server with start=True, then call join() on it.
# NOTE(review): join() is expected to block this cell for as long as the
# server runs; confirm against the tf.distribute.Server documentation.
server = tf.distribute.Server(cluster_spec, **server_options)
server.join()