<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#MirroredStrategy" data-toc-modified-id="MirroredStrategy-1">MirroredStrategy</a></span><ul class="toc-item"><li><span><a href="#Generate-Fake-Data" data-toc-modified-id="Generate-Fake-Data-1.1">Generate Fake Data</a></span></li><li><span><a href="#Define-Network-and-Solver" data-toc-modified-id="Define-Network-and-Solver-1.2">Define Network and Solver</a></span></li><li><span><a href="#Define-Wrapper,-model_fn-and-per_device_dataset" data-toc-modified-id="Define-Wrapper,-model_fn-and-per_device_dataset-1.3">Define Wrapper, model_fn and per_device_dataset</a></span></li><li><span><a href="#Single-GPU-Training" data-toc-modified-id="Single-GPU-Training-1.4">Single GPU Training</a></span></li><li><span><a href="#Multi-GPUs-Training" data-toc-modified-id="Multi-GPUs-Training-1.5">Multi-GPUs Training</a></span></li></ul></li></ul></div>

# MirroredStrategy

We give a simple example of MirroredStrategy in TensorFlow, which can be used for multi-GPU training.

* Tensorflow 1.13.1
* Numpy 1.16.3

In [None]:
import argparse
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.distribute import values
from tensorflow.python.util import nest
import numpy as np
import collections

import os
# Restrict CUDA to the first two GPUs; the multi-GPU cell below targets GPU:0 and GPU:1.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# Optional (presumably to quiet TensorFlow logging and deprecation warnings): uncomment to enable.
# from tensorflow.python.util import deprecation
# os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
# deprecation._PRINT_DEPRECATION_WARNINGS = False

# Print the TensorFlow and Keras versions in use (tf.VERSION is the TF 1.x spelling).
print(tf.VERSION, tf.keras.__version__)

## Generate Fake Data

* Data point: (x1, x2) 
* Data label: 1 if x1 > 0 else 0

In [None]:
def gen_data(bs):
    """Draw a random batch of 2-D points and their binary labels.

    Args:
        bs: number of points to generate.

    Returns:
        A tuple ``(points, labels)``: ``points`` is a float32 array of shape
        ``(bs, 2)`` sampled uniformly from [-1, 1), and ``labels`` is an int32
        array of shape ``(bs,)`` holding 1 where the first coordinate is
        positive and 0 otherwise.
    """
    points = np.random.uniform(-1, 1, size=(bs, 2))
    first_coord = points[:, 0]
    # The label depends only on the sign of x1; x2 is ignored.
    is_positive = first_coord > 0
    return points.astype(np.float32), is_positive.astype(np.int32)

## Define Network and Solver

In [None]:
def network(features, labels):
    """Apply one 2-unit dense layer and return the unreduced cross-entropy losses.

    Args:
        features: input tensor fed to the dense layer.
        labels: integer class labels passed to the sparse softmax cross-entropy.

    Returns:
        The loss tensor produced with ``Reduction.NONE`` (no reduction applied).
    """
    dense = keras.layers.Dense(2, kernel_initializer="ones")
    logits = dense(features)
    # Reduction is NONE so the single-GPU run can inspect the whole batch of
    # losses; the caller is responsible for averaging them.
    return tf.losses.sparse_softmax_cross_entropy(
        labels=labels,
        logits=logits,
        reduction=tf.losses.Reduction.NONE,
    )

def solver(loss):
    """Build the training op that minimizes ``loss`` with Adam (learning rate 0.01).

    Args:
        loss: scalar loss tensor to minimize.

    Returns:
        The op returned by ``AdamOptimizer.minimize``; the global step is
        passed to ``minimize`` as ``global_step``.
    """
    step = tf.train.get_or_create_global_step()
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    return optimizer.minimize(loss, global_step=step)

## Define Wrapper, model_fn and per_device_dataset

In [None]:
class Wrapper(collections.namedtuple('Father', ['train_op', 'losses', 'loss'])):
    """Named tuple bundling what ``model_fn`` returns.

    Fields, in order: ``train_op``, ``losses``, ``loss``.
    """

    def __new__(cls, train_op, losses, loss):
        # All three fields are required; forward them to the namedtuple base.
        fields = (train_op, losses, loss)
        return super(Wrapper, cls).__new__(cls, *fields)

def model_fn(features, labels):
    """Build the model for one batch: losses, their mean, and the training op.

    Args:
        features: input tensor forwarded to ``network``.
        labels: label tensor forwarded to ``network``.

    Returns:
        A ``Wrapper`` with ``losses`` (the output of ``network``), ``loss``
        (their ``tf.reduce_mean``) and ``train_op`` (``solver`` applied to
        that mean).
    """
    unreduced = network(features, labels)
    mean_loss = tf.reduce_mean(unreduced)
    return Wrapper(train_op=solver(mean_loss), losses=unreduced, loss=mean_loss)

def per_device_dataset(batch, devices):
    """Split a batch along its leading axis, one slice per device, and regroup it.

    Args:
        batch: a (possibly nested) structure whose leaves are indexed along
            axis 0 by device position, i.e. shaped
            [num_gpus, batch_size / num_gpus, data_dim1, ...]. In this
            notebook the feature leaf has shape [2, 1, 2].
        devices: GPU device names; the i-th device receives slice ``i`` of
            every leaf.

    Returns:
        The result of ``values.regroup`` applied to the
        ``{device: per-device slice structure}`` mapping.
    """
    # ``pos=pos`` binds the current index at lambda creation time, so each
    # device keeps its own slice index instead of sharing the loop variable.
    slices_by_device = {
        device: nest.map_structure(lambda leaf, pos=pos: leaf[pos], batch)
        for pos, device in enumerate(devices)
    }
    return values.regroup(slices_by_device)

## Single GPU Training

In [None]:
def main_1():
    """Train the model on a single device and print the loss every 100 steps.

    Resets the default graph, fixes the TensorFlow and NumPy seeds to 1234,
    builds the model on placeholders for a batch of 2 samples, and runs 1000
    training steps on freshly generated data. Every 100 steps it prints the
    mean loss followed by the unreduced losses.
    """
    tf.reset_default_graph()
    tf.set_random_seed(1234)
    x_input = tf.placeholder(tf.float32, shape=(2, 2), name="x_input")
    y_input = tf.placeholder(tf.int32, shape=(2, ), name="y_input")
    wp = model_fn(x_input, y_input)

    np.random.seed(1234)
    # Use the session as a context manager so it is closed (and its resources
    # released) when training ends or raises; the original never closed it,
    # leaving the session alive when the next cell builds a new graph.
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(1000):
            x, y = gen_data(bs=2)
            _, l1, l2 = sess.run([wp.train_op, wp.loss, wp.losses], feed_dict={x_input: x, y_input: y})
            if i % 100 == 0:
                # Print mean_loss and losses for checking loss = reduce_mean(losses)
                # and for comparing with the mirrored strategy results.
                print("step {}, loss {} {}".format(i, l1, l2))

In [None]:
# Run the single-GPU baseline; its printed losses are the reference for the multi-GPU run.
main_1()

## Multi-GPUs Training

Here we just use 2 gpus.

Single-GPU training and two-GPU training should produce the same loss and the same weight updates.

You can run main_1() and main_2() to check if the results are the same. 

In [None]:
def main_2():
    """Train the model on two GPUs with MirroredStrategy and print the loss every 100 steps.

    Uses the same seeds (1234), placeholder shapes and number of steps (1000)
    as the single-GPU run. Each step's batch of 2 samples is split so that
    each GPU receives one sample. Every 100 steps it prints the loss reduced
    across replicas followed by the stacked per-replica losses.
    """
    tf.reset_default_graph()
    tf.set_random_seed(1234)
    strategy = tf.distribute.MirroredStrategy(["device:GPU:0", "device:GPU:1"])

    with strategy.scope():
        x_input = tf.placeholder(tf.float32, shape=(2, 2), name="x_input")
        y_input = tf.placeholder(tf.int32, shape=(2, ), name="y_input")

        # -----------------------------------------------------------------------------------
        # Convert a batch with shape [bs, dim] to a batch [num_gpus, bs/num_gpus, dim]
        # NOTE(review): `strategy.extended._devices` is a private attribute and may
        # change between TensorFlow versions - confirm when upgrading.
        features, labels = per_device_dataset((tf.reshape(x_input, (2, 1, 2)),
                                               tf.reshape(y_input, (2, 1))), strategy.extended._devices)
        # Then we get PerReplica instances, in which each gpu entry is a batch with
        # shape [bs/num_gpus, dim].
        # Try print(features), which gives:
        # PerReplica:{'/replica:0/task:0/device:GPU:0': <tf.Tensor 'strided_slice:0' shape=(1, 2) dtype=float32>,
        #             '/replica:0/task:0/device:GPU:1': <tf.Tensor 'strided_slice_2:0' shape=(1, 2) dtype=float32>}

        # Call model_fn for each replica (i.e. gpu)
        grouped_wp = strategy.call_for_each_replica(model_fn, args=(features, labels))
        # Get loss reduction across all the gpus, i.e. mean loss
        mean_loss = strategy.reduce(tf.distribute.get_loss_reduction(), grouped_wp.loss)
        # We can also get the losses from all the gpus for checking
        concat_loss = tf.stack(strategy.unwrap(grouped_wp.loss), axis=0)
        # We just need to group the per-replica train ops into one op
        train_op = strategy.group(grouped_wp.train_op)
        # -----------------------------------------------------------------------------------

        np.random.seed(1234)
        # Use the session as a context manager so it is closed (and its
        # resources released) when training ends or raises; the original
        # never closed it.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for i in range(1000):
                x, y = gen_data(bs=2)
                _, l1, l2 = sess.run([train_op, mean_loss, concat_loss], feed_dict={x_input: x, y_input: y})
                if i % 100 == 0:
                    # Check mean_loss = reduce_mean(concat_loss)
                    print("step {}, loss {} {}".format(i, l1, l2))

In [None]:
# Run the two-GPU MirroredStrategy training; compare its printed losses with those of main_1().
main_2()