<a href="https://colab.research.google.com/github/Yug-Oswal/Custom-DistributedTraining-TF/blob/main/CustomModels%26DistributedTraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [41]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import os

In [42]:
# Report the TensorFlow build this notebook is running against.
print(f"Tensorflow version: {tf.__version__}")
# Sentinel that lets tf.data choose buffer sizes / parallelism dynamically.
AUTO = tf.data.experimental.AUTOTUNE

Tensorflow version: 2.12.0


In [43]:
# Finding, connecting, and initializing the TPU Cluster.
# `strategy` is always bound after this cell: if no TPU can be reached we fall
# back to TensorFlow's default (single-device) strategy so that later cells do
# not fail with a NameError.
try:
  # Raises KeyError when the notebook is not attached to a Colab TPU runtime.
  tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)
  tf.config.experimental_connect_to_cluster(tpu)
  tf.tpu.experimental.initialize_tpu_system(tpu)
  strategy = tf.distribute.TPUStrategy(tpu)
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
  print('Number of accelerators: ', strategy.num_replicas_in_sync)
except (KeyError, ValueError) as err:
  print('TPU failed to initialize.')
  print('Reason: {!r}. Falling back to the default strategy.'.format(err))
  strategy = tf.distribute.get_strategy()
  print('Number of accelerators: ', strategy.num_replicas_in_sync)



Running on TPU  ['10.37.197.202:8470']
Number of accelerators:  8


In [None]:
# Locations of the MNIST image/label files read by the input pipeline
# (train split first, then the 10k-example split used for validation).
(
    training_images_file,
    training_labels_file,
    validation_images_file,
    validation_labels_file,
) = (
    'gs://mnist-public/train-images-idx3-ubyte',
    'gs://mnist-public/train-labels-idx1-ubyte',
    'gs://mnist-public/t10k-images-idx3-ubyte',
    'gs://mnist-public/t10k-labels-idx1-ubyte',
)

In [None]:
def read_label(tf_bytestring):
    """Turn one raw label record into a one-hot vector of depth 10.

    Args:
        tf_bytestring: scalar string tensor holding a single label record.

    Returns:
        The one-hot encoding (depth 10) of the decoded uint8 class index.
    """
    raw_bytes = tf.io.decode_raw(tf_bytestring, tf.uint8)
    # Collapse the one-element vector into a scalar class index.
    class_index = tf.reshape(raw_bytes, [])
    return tf.one_hot(class_index, 10)

def read_image(tf_bytestring):
    """Turn one raw image record into a flat float32 vector of 28*28 values.

    Args:
        tf_bytestring: scalar string tensor holding a single image record.

    Returns:
        A float32 tensor of shape [28*28]; each uint8 byte is divided by 255.0.
    """
    pixels = tf.cast(tf.io.decode_raw(tf_bytestring, tf.uint8), tf.float32)
    scaled = pixels / 255.0
    return tf.reshape(scaled, [28*28])

def load_dataset(image_file, label_file):
    """Build a dataset of (image, label) pairs from fixed-length record files.

    Args:
        image_file: path of the image file; records are 28*28 bytes long and
            follow a 16-byte header.
        label_file: path of the label file; records are 1 byte long and follow
            an 8-byte header.

    Returns:
        A tf.data.Dataset zipping the outputs of read_image and read_label.
    """
    images = tf.data.FixedLengthRecordDataset(
        image_file, 28*28, header_bytes=16
    ).map(read_image, num_parallel_calls=16)
    labels = tf.data.FixedLengthRecordDataset(
        label_file, 1, header_bytes=8
    ).map(read_label, num_parallel_calls=16)
    return tf.data.Dataset.zip((images, labels))

In [None]:
# Build the (image, label) datasets for the training and validation files.
train_data, test_data = (
    load_dataset(images_path, labels_path)
    for images_path, labels_path in (
        (training_images_file, training_labels_file),
        (validation_images_file, validation_labels_file),
    )
)

In [47]:
# Setting up appropriate batch size for sharding
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

In [48]:
# Shuffling, batching, and prefetching the dataset
train_dataset = train_data.shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE).prefetch(AUTO)
test_dataset = test_data.batch(GLOBAL_BATCH_SIZE)
# Sharding dataset
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

In [49]:
class CustomModel(tf.keras.Model):
  """Small CNN classifier: Conv2D -> MaxPool -> Flatten -> Dense -> logits.

  Args:
    filters: number of convolution filters.
    kernel: convolution kernel size.
    units: width of the hidden dense layer.
    output_units: number of output logits (one per class).
  """

  def __init__(self, filters=32, kernel=3, units=128, output_units=10):
    super().__init__()
    self.conv = tf.keras.layers.Conv2D(filters, kernel, input_shape=(28, 28, 1))
    self.max_pool = tf.keras.layers.MaxPooling2D()
    self.flatten = tf.keras.layers.Flatten()
    self.hidden1 = tf.keras.layers.Dense(units, activation='relu')
    # No activation: the model emits raw logits.
    self.out = tf.keras.layers.Dense(output_units)

  def call(self, inputs):
    """Run the forward pass and return unnormalised logits.

    Args:
      inputs: batch of images, either flattened to 28*28 values per example
        (as produced by read_image) or already shaped (batch, 28, 28, 1).

    Returns:
      Logits tensor with `output_units` values per example.
    """
    # Conv2D needs rank-4 input; the input pipeline yields flat vectors, so
    # restore the (height, width, channels) layout here. Inputs that already
    # have this shape pass through unchanged.
    x = tf.reshape(inputs, [-1, 28, 28, 1])
    x = self.conv(x)
    x = self.max_pool(x)
    x = self.flatten(x)
    x = self.hidden1(x)
    predictions = self.out(x)
    return predictions

In [50]:
with strategy.scope():
  # Model, optimizer and metrics are created inside the strategy scope so the
  # strategy controls how their variables are placed.
  model = CustomModel()
  # read_label yields one-hot labels, so the categorical (non-sparse) loss and
  # accuracy are the matching choices. Reduction is disabled because the
  # averaging over the *global* batch is done explicitly in compute_loss.
  loss_object = tf.keras.losses.CategoricalCrossentropy(
      from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
  optimizer = tf.keras.optimizers.Adam()
  train_accuracy = tf.keras.metrics.CategoricalAccuracy()
  test_accuracy = tf.keras.metrics.CategoricalAccuracy()

  def compute_loss(labels, logits):
    """Return this replica's share of the global-batch mean loss.

    The per-example losses are summed and divided by GLOBAL_BATCH_SIZE, so
    adding the results of all replicas gives the mean over the global batch.
    """
    per_example_loss = loss_object(labels, logits)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size = GLOBAL_BATCH_SIZE)

  def train_step(inputs):
    """Run one optimisation step on a single replica and return its loss."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      # Only the images go through the model; the labels are used by the loss.
      logits = model(images)
      loss = compute_loss(labels, logits)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # Accuracy compares argmax positions, so raw logits can be used directly.
    train_accuracy.update_state(labels, logits)

    return loss

  @tf.function
  def distributed_train_step(dataset_inputs):
    """Run train_step on every replica and sum the per-replica losses."""
    per_replica_losses = strategy.run(train_step, args=(dataset_inputs, ))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

  def test_step(test_inputs):
    """Evaluate one batch on a single replica and return its loss."""
    test_images, test_labels = test_inputs
    logits = model(test_images)
    test_loss = compute_loss(test_labels, logits)

    test_accuracy.update_state(test_labels, logits)
    return test_loss

  @tf.function
  def distributed_test_step(dataset_inputs):
    """Run test_step on every replica and sum the per-replica losses."""
    per_replica_losses = strategy.run(test_step, args=(dataset_inputs, ))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

In [51]:
# Custom Training Loop
# For every epoch: train over all distributed training batches, evaluate over
# all distributed test batches, report the averages, then reset the metrics.
with strategy.scope():
  EPOCHS = 20
  losses = []  # mean training loss of each epoch
  for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    for batch in train_dist_dataset:
      total_loss += distributed_train_step(batch)
      num_batches += 1
    train_loss = total_loss / num_batches
    losses.append(train_loss)

    total_test_loss = 0.0
    num_test_batches = 0
    for batch in test_dist_dataset:
      total_test_loss += distributed_test_step(batch)
      num_test_batches += 1
    # Average over the number of *test* batches, not the training batch count.
    test_loss = total_test_loss / num_test_batches

    print("Epoch {}:\nLoss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}".format(
        epoch + 1, train_loss, train_accuracy.result() * 100, test_loss, test_accuracy.result() * 100
    ))

    # Metrics accumulate across calls; clear them so each epoch is reported
    # independently.
    train_accuracy.reset_state()
    test_accuracy.reset_state()

UnimplementedError: File system scheme '[local]' not implemented (file: 'data/mnist/3.0.1/mnist-train.tfrecord-00000-of-00001')
	 [[{{node MultiDeviceIteratorGetNextFromShard}}]]
	 [[RemoteCall]]