In [6]:
import json
import os
import sys

In [7]:
if '.' not in sys.path:
  sys.path.insert(0, '.')
import tensorflow as tf

In [8]:
%%writefile mnist.py

import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
  # The `x` arrays are in uint8 and have values in the [0, 255] range.
  # You need to convert them to float32 with values in the [0, 1] range.
  x_train = x_train / np.float32(255)
  y_train = y_train.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices(
      (x_train, y_train)).shuffle(60000).repeat().batch(batch_size)
  return train_dataset

def build_and_compile_cnn_model():
  model = tf.keras.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28)),
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dense(10)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])
  return model

Overwriting mnist.py


In [9]:
IP_SAM_MAC = '10.0.0.89'
PORT_SAM_MAC = 2002
IP_SAM_UBUNTU = '10.0.0.179'
PORT_SAM_UBUNTU = 2002

tf_config = {
    'cluster': {
        'worker': [f'{IP_SAM_MAC}:{PORT_SAM_MAC}', f'{IP_SAM_UBUNTU}:{PORT_SAM_UBUNTU}']
    },
    'task': {'type': 'worker', 'index': 0}
}

json.dumps(tf_config)

'{"cluster": {"worker": ["10.0.0.89:2002", "10.0.0.179:2002"]}, "task": {"type": "worker", "index": 0}}'

In [10]:
# set environment variable to tf_config
os.environ['TF_CONFIG'] = json.dumps(tf_config)

In [11]:
%%writefile main.py

import os
import json

import tensorflow as tf
import mnist

per_worker_batch_size = 64
tf_config = json.loads(os.environ['TF_CONFIG'])
num_workers = len(tf_config['cluster']['worker'])

strategy = tf.distribute.MultiWorkerMirroredStrategy()

global_batch_size = per_worker_batch_size * num_workers
multi_worker_dataset = mnist.mnist_dataset(global_batch_size)

with strategy.scope():
  # Model building/compiling need to be within `strategy.scope()`.
  multi_worker_model = mnist.build_and_compile_cnn_model()


multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)

Writing main.py


In [12]:
# first kill any previous runs
%killbgscripts

All background processes were killed.


In [13]:
%%bash
python main.py &> job_0.log

Process is interrupted.
