<a href="https://colab.research.google.com/github/TheRadDani/TF-default-strategy/blob/main/TfDefaultStrategy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import tensorflow as tf
import tensorflow_datasets as tfds

# Load the MNIST dataset
mnist_dataset = tfds.load(name='mnist', split='train', as_supervised=True)

# Define the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Compile the model with a loss function, optimizer, and metrics
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# Create a distribution strategy
strategy = tf.distribute.get_strategy()

# Define the batch size and number of epochs
batch_size = 64
num_epochs = 10

# Create a distributed dataset from the MNIST dataset
distributed_dataset = strategy.experimental_distribute_dataset(mnist_dataset.batch(batch_size))

# Define a training step function that runs on each replica
@tf.function
def train_step(inputs):
    features, labels = inputs
    
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(labels, predictions))
        
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
    accuracy = tf.reduce_mean(tf.keras.metrics.sparse_categorical_accuracy(labels, predictions))
    
    return loss, accuracy

optimizer = tf.keras.optimizers.Adam()

# Define a distributed training loop
for epoch in range(num_epochs):
    epoch_loss = strategy.reduce(tf.distribute.ReduceOp.SUM, 0.0, axis=None)
    epoch_accuracy = strategy.reduce(tf.distribute.ReduceOp.SUM, 0.0, axis=None)
    
    num_batches = 0
    
    for inputs in distributed_dataset:
        per_replica_loss, per_replica_accuracy = strategy.run(train_step, args=(inputs,))
        
        epoch_loss += strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
        epoch_accuracy += strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_accuracy, axis=None)
        
        num_batches += 1
    
    epoch_loss /= num_batches
    epoch_accuracy /= num_batches
    
    print(f"Epoch {epoch + 1}: Loss = {epoch_loss}, Accuracy = {epoch_accuracy}")


Epoch 1: Loss = 2.7783584594726562, Accuracy = 0.8604244589805603
Epoch 2: Loss = 0.4168095886707306, Accuracy = 0.9084655046463013
Epoch 3: Loss = 0.2855697572231293, Accuracy = 0.930420458316803
Epoch 4: Loss = 0.23605339229106903, Accuracy = 0.9390324950218201
Epoch 5: Loss = 0.20447565615177155, Accuracy = 0.9455457329750061
Epoch 6: Loss = 0.1892833113670349, Accuracy = 0.9489772319793701
Epoch 7: Loss = 0.16988526284694672, Accuracy = 0.9548407793045044
Epoch 8: Loss = 0.16040684282779694, Accuracy = 0.9573394060134888
Epoch 9: Loss = 0.1529911756515503, Accuracy = 0.9599047303199768
Epoch 10: Loss = 0.1458587348461151, Accuracy = 0.9605876803398132
