# Distributed Training

TensorFlow 2 provides the `tf.distribute.Strategy` API for distributed training. Using it takes two steps:
1. Wrapping components with a distributed strategy
2. Defining the training procedure

Note: TensorFlow 2.0 is currently very unstable when handling variable-length inputs, so the code needs some awkward workarounds to avoid errors.

## 1. Wrapping models with a distributed strategy
TensorFlow provides a range of distributed strategies, such as the mirrored strategy and the TPU strategy. One of the most widely used is the mirrored strategy, which is essentially the same as data parallelism in PyTorch. It places a replica of the model on each GPU and distributes the inputs evenly among them. Gradients are computed on each GPU, merged, and then used to update every replica. We apply the mirrored strategy by `creating a strategy instance` and `wrapping the components (model, optimizer, dataset)` with it.
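
To make the idea concrete, here is a minimal pure-Python sketch (no TensorFlow) of synchronous data parallelism. The one-parameter model, the toy data, and the helper names `example_grad`, `split_batch`, and `replica_grad` are all hypothetical and chosen only for illustration. A real mirrored strategy performs the merge with an all-reduce across devices.

```python
# Toy model: y_hat = w * x, with per-example loss (w * x - y) ** 2.

def example_grad(w, x, y):
    """Gradient of (w * x - y) ** 2 with respect to w for one example."""
    return 2.0 * (w * x - y) * x

def split_batch(batch, num_replicas):
    """Split a global batch into equally sized per-replica shards."""
    shard_size = len(batch) // num_replicas
    return [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_replicas)]

def replica_grad(w, shard, global_batch_size):
    """Each replica sums its example gradients and scales by the GLOBAL batch size."""
    return sum(example_grad(w, x, y) for x, y in shard) / global_batch_size

w = 0.5
global_batch = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 9.0)]
num_replicas = 2

shards = split_batch(global_batch, num_replicas)

# "Merging" step: add up the per-replica gradients.
merged_grad = sum(replica_grad(w, shard, len(global_batch)) for shard in shards)

# Reference: gradient of the mean loss computed on a single device.
single_device_grad = sum(example_grad(w, x, y) for x, y in global_batch) / len(global_batch)

print(merged_grad, single_device_grad)
```

With equally sized shards, the merged gradient matches what a single device would compute on the whole batch, which is why every replica can apply the same update and stay in sync.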



In [56]:
import tensorflow as tf
from tensorflow import keras

#create distributed strategy instance
strategy = tf.distribute.MirroredStrategy()


In [57]:
#model
class SimpleModel(keras.Model):
    def __init__(self, n_class):
        super(SimpleModel,self).__init__()
        self.output_layer = keras.layers.Dense(n_class)
    
    def call(self,x):
        return self.output_layer(x)

#dataset

X = tf.random.normal((10000,100))
Y = tf.random.uniform((10000,),0,5,dtype=tf.int64)

batch_size = 8
dataset = tf.data.Dataset.from_tensor_slices((X,Y)).batch(batch_size)

In [58]:
# wrap the model and optimizer by creating them inside the strategy scope
with strategy.scope():
    model = SimpleModel(5)
    optimizer = keras.optimizers.Adam(0.001)
    
# create the distributed dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)

# we have to wrap the distributed dataset in an iterator to prevent errors
dist_dataset = iter(dist_dataset)

## 2. Defining training procedure

The idea is intuitive: we first define the training step for a single device, then distribute it across multiple GPUs.

However, TensorFlow 2.0 is currently unstable when handling variable-length inputs, so we must work around this to avoid potential errors.

In [59]:
# in TensorFlow 2.0, the training procedure only works when decorated with @tf.function
@tf.function
def train_step():
    # TensorFlow 2.0 cannot handle variable-length inputs when they are passed as arguments.
    # Instead, we fetch the inputs inside the function by calling next() on the iterator.
    x,y = next(dist_dataset)
    def step_fn(x,y):
        # this inner function is the training step run on each device
        with tf.GradientTape() as tape:
            y_ = model(x)
            per_device_loss = keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_, from_logits=True)
            
            # Scale the summed loss by the global batch size (`batch_size` above), so that
            # adding up the per-replica losses gives the mean over the global batch.
            # This is important because all the replicas train in sync and their gradients are summed.
            losses = tf.reduce_sum(per_device_loss) * (1.0 / batch_size)
        step_grad = tape.gradient(losses, model.trainable_variables)
        optimizer.apply_gradients(zip(step_grad, model.trainable_variables))
        
        # return the loss with a leading axis so it can be reduced along axis 0
        return tf.reduce_mean(losses)[None]

    # experimental_run_v2 is the TensorFlow 2.0/2.1 name; it was renamed to strategy.run in 2.2.
    example_loss = strategy.experimental_run_v2(step_fn, args=(x, y))
    
    # Normally, we would record the mean of the batch losses,
    # but TensorFlow 2.0 currently raises an error when averaging the losses.
    # To get the mean loss over the global batch, each device returns its scaled loss and we add them up.
    # This is rather complicated. Hopefully it will be fixed in the next version.
    losses_sum = strategy.reduce(
        tf.distribute.ReduceOp.SUM, example_loss,axis=0)
    return losses_sum
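
The loss bookkeeping rests on a small piece of arithmetic, sketched below in plain Python under the assumption of two replicas with equally sized shards. The loss values and variable names are made up for illustration. If each replica divides the sum of its per-example losses by the global batch size, adding the per-replica values recovers the mean over the global batch. Averaging within each replica first and then adding overcounts by the number of replicas.

```python
# Hypothetical per-example losses for one global batch of 8 examples.
per_example_losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
num_replicas = 2
global_batch_size = len(per_example_losses)

# Each replica sees an equal shard of the global batch.
shard_size = global_batch_size // num_replicas
per_replica = [per_example_losses[i * shard_size:(i + 1) * shard_size]
               for i in range(num_replicas)]

# Scale each replica's summed loss by the GLOBAL batch size, then add up (a SUM reduction).
scaled = [sum(shard) / global_batch_size for shard in per_replica]
reduced_sum = sum(scaled)

# Reference: plain mean over the whole batch.
global_mean = sum(per_example_losses) / global_batch_size

# Pitfall: per-replica means added together are num_replicas times too large.
naive_sum_of_means = sum(sum(shard) / len(shard) for shard in per_replica)

print(reduced_sum, global_mean, naive_sum_of_means)
```

The same reasoning applies to gradients, since they are linear in the loss.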

## 3. Training Epoch
These steps can now be applied to the entire dataset.

In [60]:
def train_epoch(dataset,model,optimizer):
    # We must know the number of iterations per epoch in advance,
    # because train_step() takes no inputs and fetches its batches internally.
    # 1000 steps stays within the 1250 batches available (10000 examples / batch size 8).
    n_steps = 1000
    pbar = tf.keras.utils.Progbar(n_steps)
    for step in range(1, n_steps + 1):
        computed_loss = train_step()
        pbar.update(step, [['loss',computed_loss]])

In [61]:
train_epoch(dataset,model,optimizer)
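
Because the loop needs the iteration count up front, it helps to derive it from the dataset size rather than hard-coding it. The helper below is a hypothetical sketch (the name `steps_per_epoch` and its `drop_remainder` flag are not part of the code above). It assumes the number of examples is known, as it is for the in-memory tensors used here.

```python
import math

def steps_per_epoch(num_examples, batch_size, drop_remainder=False):
    """Number of batches produced by batching `num_examples` items.

    With drop_remainder=True a final partial batch is discarded,
    otherwise it counts as one extra step.
    """
    if drop_remainder:
        return num_examples // batch_size
    return math.ceil(num_examples / batch_size)

# Values matching the dataset above: 10000 examples, batch size 8.
n_steps = steps_per_epoch(10000, 8)
print(n_steps)
```

Looping for more steps than the iterator can supply would exhaust it, so the step count should never exceed this value.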

