<a href="https://colab.research.google.com/github/Yug-Oswal/Custom-DistributedTraining-TF/blob/main/CustomModels%26DistributedTraining.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import os

In [None]:
# Let tf.data choose prefetch buffer sizes dynamically at runtime.
AUTO = tf.data.experimental.AUTOTUNE

# Report the installed TensorFlow version so the run is reproducible.
print(
    "Tensorflow version: {}".format(tf.__version__)
)

Tensorflow version: 2.12.0


In [None]:
# Finding, connecting, and initializing the TPU cluster.
# If no TPU is available (COLAB_TPU_ADDR unset -> KeyError) or the resolver
# rejects the address (ValueError), fall back to the default strategy so that
# `strategy` is always defined for the cells below.
try:
  tpu_address = 'grpc://' + os.environ['COLAB_TPU_ADDR']
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_address)
  tf.config.experimental_connect_to_cluster(tpu)
  tf.tpu.experimental.initialize_tpu_system(tpu)
  strategy = tf.distribute.TPUStrategy(tpu)
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except (KeyError, ValueError) as tpu_error:
  print('TPU failed to initialize.')
  print('Falling back to the default strategy: {!r}'.format(tpu_error))
  strategy = tf.distribute.get_strategy()

print('Number of accelerators: ', strategy.num_replicas_in_sync)

Running on TPU  ['10.37.197.202:8470']
Number of accelerators:  8


In [None]:
# Loading the mnist dataset.
# as_supervised=True makes every element an (image, label) tuple instead of a
# {'image': ..., 'label': ...} dict, which is the structure the training step
# unpacks with `images, labels = inputs`.
train_data = tfds.load('mnist', split='train', as_supervised=True)
test_data = tfds.load('mnist', split='test', as_supervised=True)

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...


Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]

Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


In [None]:
# Batch sizing for sharding: each replica processes BATCH_SIZE_PER_REPLICA
# examples, so one global batch holds that many examples per replica in sync.
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = strategy.num_replicas_in_sync * BATCH_SIZE_PER_REPLICA

# Size of the shuffle buffer used when building the training pipeline.
BUFFER_SIZE = 10000

In [None]:
# Training pipeline: shuffle -> global batches -> prefetch (buffer autotuned).
train_dataset = (
    train_data
    .shuffle(buffer_size=BUFFER_SIZE)
    .batch(batch_size=GLOBAL_BATCH_SIZE)
    .prefetch(buffer_size=AUTO)
)

# Evaluation pipeline: only batched; order is left untouched.
test_dataset = test_data.batch(batch_size=GLOBAL_BATCH_SIZE)

In [None]:
class CustomModel(tf.keras.Model):
  """Small convolutional classifier built by subclassing `tf.keras.Model`.

  Layer stack, in order: Conv2D -> MaxPooling2D -> Flatten ->
  Dense(units, relu) -> Dense(output_units). The last layer has no
  activation, so the model emits raw logits.

  Args:
    filters: Number of convolution filters.
    kernel: Convolution kernel size.
    units: Width of the hidden dense layer.
    output_units: Number of output logits.
  """

  def __init__(self, filters=32, kernel=3, units=128, output_units=10):
    super().__init__()
    layers = tf.keras.layers
    self.conv = layers.Conv2D(filters, kernel, input_shape=(28, 28, 1))
    self.max_pool = layers.MaxPooling2D()
    self.flatten = layers.Flatten()
    self.hidden1 = layers.Dense(units, activation='relu')
    self.out = layers.Dense(output_units)

  def call(self, inputs):
    """Feed `inputs` through every layer in sequence and return the logits."""
    pipeline = (self.conv, self.max_pool, self.flatten, self.hidden1, self.out)
    activations = inputs
    for layer in pipeline:
      activations = layer(activations)
    return activations

In [None]:
with strategy.scope():
  # Everything that owns variables (model, optimizer, metric) is created
  # inside the strategy scope.
  model = CustomModel()

  # Reduction.NONE keeps one loss value per example; the model outputs logits.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      reduction=tf.keras.losses.Reduction.NONE,
      from_logits=True,
  )

  optimizer = tf.keras.optimizers.Adam()
  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

  def compute_loss(labels, logits):
    """Return this replica's contribution to the global mean loss.

    `loss_object` uses `Reduction.NONE`, so it yields one loss per example.
    Those values are summed and divided by GLOBAL_BATCH_SIZE (not the
    per-replica batch size), so that summing the result across replicas
    gives the mean loss over the whole global batch and gradients are
    averaged rather than summed.

    Args:
      labels: Integer class labels for the replica's examples.
      logits: Unnormalized model outputs for the same examples.

    Returns:
      A scalar tensor: sum(per-example loss) / GLOBAL_BATCH_SIZE.
    """
    per_example_loss = loss_object(labels, logits)
    return tf.nn.compute_average_loss(
        per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)

  def train_step(inputs):
    """Run one optimisation step on a single replica's share of a batch.

    Args:
      inputs: Either an `(images, labels)` pair or a dict with 'image' and
        'label' keys (the element structure tfds yields when
        `as_supervised` is not set).

    Returns:
      The loss value produced by `compute_loss` for this replica.
    """
    if isinstance(inputs, dict):
      images, labels = inputs['image'], inputs['label']
    else:
      images, labels = inputs

    # Conv2D needs floating-point input. Integer pixels (tfds MNIST images are
    # uint8) are cast and scaled to [0, 1]; float inputs are left untouched.
    if not images.dtype.is_floating:
      images = tf.cast(images, tf.float32) / 255.0

    with tf.GradientTape() as tape:
      # Only the images go through the model, not the whole (images, labels)
      # element.
      logits = model(images)
      loss = compute_loss(labels, logits)

    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Softmax does not change the argmax, so the accuracy metric can consume
    # the logits directly.
    train_accuracy.update_state(labels, logits)

    return loss

  @tf.function
  def distributed_train_step(dataset_inputs):
    """Run `train_step` on every replica and combine the resulting losses.

    Args:
      dataset_inputs: One batch element, forwarded unchanged to `train_step`.

    Returns:
      The per-replica losses summed across replicas (ReduceOp.SUM); with
      axis=None no additional reduction is applied within a replica's value.
    """
    replica_losses = strategy.run(train_step, args=(dataset_inputs,))
    summed_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_losses, axis=None)
    return summed_loss