In [None]:
import os
import sys
import json
import time

# Ask TensorFlow's C++ backend to emit its additional log output.
# This is set here, ahead of the `import tensorflow` cell further down, so the
# variable is already in the environment when the library is loaded.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'

In [None]:
# Hide all CUDA devices from this process (disables GPUs)
os.environ.update({"CUDA_VISIBLE_DEVICES": "-1"})

# Make modules in the working directory importable, without adding a duplicate entry
if sys.path.count('.') == 0:
  sys.path.insert(0, '.')

In [None]:
import tensorflow as tf

# Raise the threshold of TensorFlow's Python logger to ERROR so that
# warning- and info-level messages are not shown in the notebook output
tf.get_logger().setLevel('ERROR')

In [None]:
%%writefile mnist.py

# import os
import tensorflow as tf
import numpy as np

def mnist_dataset(batch_size):
  """Build the MNIST training set as a shuffled, endlessly repeating dataset.

  Args:
    batch_size: Number of examples placed in each batch.

  Returns:
    A `tf.data.Dataset` yielding `(images, labels)` batches, where the images
    are float32 pixel values divided by 255 and the labels are int64. The
    test split returned by the loader is discarded.
  """
  # Only the training split is used; the second tuple (test split) is dropped
  (images, labels), _ = tf.keras.datasets.mnist.load_data()

  # Dividing by a float32 scalar normalizes the pixels and yields float32 values
  images = images / np.float32(255)
  labels = labels.astype(np.int64)

  # Pipeline order matters: shuffle, then repeat forever, then batch
  dataset = tf.data.Dataset.from_tensor_slices((images, labels))
  dataset = dataset.shuffle(60000)
  dataset = dataset.repeat()
  dataset = dataset.batch(batch_size)
  return dataset


def build_and_compile_cnn_model():
  """Create and compile a small CNN classifier for 28x28 inputs.

  The network reshapes each image to add a channel axis, applies one 3x3
  convolution with 32 filters, flattens the result, and feeds it through a
  128-unit ReLU layer into a 10-unit output layer that produces raw logits.

  Returns:
    A compiled `tf.keras.Sequential` model using sparse categorical
    cross-entropy computed from logits, SGD with a learning rate of 0.001,
    and accuracy as the reported metric.
  """
  model = tf.keras.Sequential([
      # Declare the input with `tf.keras.Input(shape=...)`; the `input_shape`
      # argument of `InputLayer` is deprecated in newer Keras releases
      tf.keras.Input(shape=(28, 28)),
      # Add the trailing channel axis that Conv2D expects
      tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(128, activation='relu'),
      # No softmax here: the loss below is configured with from_logits=True
      tf.keras.layers.Dense(10)
  ])

  # Compile model
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
      metrics=['accuracy'])

  return model

In [None]:
# List the Python files in the working directory; mnist.py, written by the previous cell, should appear
!ls *.py

In [None]:
# Import the module written to mnist.py by the %%writefile cell above
import mnist

# Number of examples per training batch
batch_size = 64

# Build the shuffled, repeating training dataset
single_worker_dataset = mnist.mnist_dataset(batch_size)

# Build and compile the CNN (no distribution strategy is involved yet)
single_worker_model = mnist.build_and_compile_cnn_model()

# The dataset returned by mnist_dataset repeats indefinitely, so steps_per_epoch
# bounds each epoch: 3 epochs of 70 batches each.
# As training progresses, the loss should drop and the accuracy should increase.
single_worker_model.fit(single_worker_dataset, epochs=3, steps_per_epoch=70)

In [None]:
tf_config = {}
# Cluster membership: two workers, both on this machine, on separate ports
tf_config['cluster'] = {'worker': ['localhost:12345', 'localhost:23456']}
# Role of the current process: the worker at index 0 of the list above
tf_config['task'] = {'type': 'worker', 'index': 0}

In [None]:
# Preview the JSON string form of the config (the form later stored in TF_CONFIG)
json.dumps(tf_config)

In [None]:
# Instantiate the strategy inside the notebook process.
# NOTE(review): TF_CONFIG is only exported to os.environ in a later cell, so this
# instance is created without a cluster description from this notebook. It
# presumably behaves as a single-worker strategy here; confirm against the
# TensorFlow multi-worker documentation.
strategy = tf.distribute.MultiWorkerMirroredStrategy()

In [None]:
# Build and compile the model inside the strategy's scope (a context manager)
# so that it is created under this distribution strategy
with strategy.scope():
  multi_worker_model = mnist.build_and_compile_cnn_model()

In [None]:
%%writefile main.py

import os
import json

import tensorflow as tf
import mnist # Your module

# Batch size handled by each individual worker
per_worker_batch_size = 64

# Parse the TF_CONFIG environment variable (a JSON string) into a dict.
# A KeyError is raised here if TF_CONFIG is not set.
tf_config = json.loads(os.environ['TF_CONFIG'])

# Number of workers = number of addresses listed under cluster -> worker
num_workers = len(tf_config['cluster']['worker'])

# Create the strategy before the dataset and the model are built
strategy = tf.distribute.MultiWorkerMirroredStrategy()

# Global batch size = per-worker batch size times the number of workers
global_batch_size = per_worker_batch_size * num_workers

# Load the dataset, batched at the global batch size
multi_worker_dataset = mnist.mnist_dataset(global_batch_size)

# Create and compile model following the distributed strategy
with strategy.scope():
  multi_worker_model = mnist.build_and_compile_cnn_model()

# Train for 3 epochs of 70 steps each; steps_per_epoch bounds each epoch
# because the dataset from mnist.mnist_dataset repeats indefinitely
multi_worker_model.fit(multi_worker_dataset, epochs=3, steps_per_epoch=70)

In [None]:
# List the Python files again; main.py should now appear alongside mnist.py
!ls *.py

In [None]:
# Serialize tf_config to JSON and export it as the TF_CONFIG environment
# variable; main.py reads this variable when it starts
os.environ['TF_CONFIG'] = json.dumps(tf_config)

In [None]:
# Kill any background scripts left running by earlier executions of this notebook
%killbgscripts

In [None]:
# Port 12345 belongs to the first worker address in tf_config. No worker has
# been launched yet, so this should not print anything at the moment
!lsof -i :12345

In [None]:
%%bash --bg
# Launch main.py as a background process, sending both stdout and stderr to job_0.log.
# It reads the TF_CONFIG exported earlier, which at this point carries task index 0.
python main.py &> job_0.log

In [None]:
# Fixed 10-second pause so the background worker has time to start and write
# to job_0.log before the following cells inspect it
time.sleep(10)

In [None]:
# Check port 12345 again now that the first worker has been started
!lsof -i :12345

In [None]:
%%bash
# Show what the background worker has logged so far
cat job_0.log

In [None]:
# Switch the task index to 1 (same cluster definition) and re-export TF_CONFIG
# so that the next main.py process starts as the second worker.
# Note: tf_config is mutated in place; re-running the earlier export cell after
# this one, without re-creating tf_config first, would export index 1 instead of 0.
tf_config['task']['index'] = 1
os.environ['TF_CONFIG'] = json.dumps(tf_config)

In [None]:
%%bash
# Run main.py in the foreground; it picks up the updated TF_CONFIG (task index 1)
python main.py

In [None]:
%%bash
# Print the first worker's log again now that the second worker has run
cat job_0.log