# Training on multiple GPUs with gluon
Adapted from: https://gluon.mxnet.io/chapter07_distributed-learning/multiple-gpus-gluon.html


In [1]:
import mxnet as mx
from mxnet import nd, gluon, autograd
# Small convolutional network: two conv / max-pool stages followed by two
# dense layers, the last of which produces 10 outputs.
net = gluon.nn.Sequential(prefix='cnn_')
with net.name_scope():
    # The layers are created inside the name scope (and in this order) so
    # that their parameter names pick up the 'cnn_' prefix.
    layers = [
        gluon.nn.Conv2D(channels=20, kernel_size=3, activation='relu'),
        gluon.nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
        gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'),
        gluon.nn.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
        gluon.nn.Flatten(),
        gluon.nn.Dense(128, activation="relu"),
        gluon.nn.Dense(10),
    ]
    for layer in layers:
        net.add(layer)

# Loss shared by the training helpers defined later in the notebook.
loss = gluon.loss.SoftmaxCrossEntropyLoss()

In [2]:
# Number of GPUs to use; increase if you have more.
GPU_COUNT = 4
# One context per GPU, with device ids 0 .. GPU_COUNT-1.
ctx = list(map(mx.gpu, range(GPU_COUNT)))
# Initialize the network parameters on every context in the list.
params = net.collect_params()
params.initialize(ctx=ctx)

In [4]:
from mxnet.test_utils import get_mnist
# Load MNIST and take a small sample batch sized to the number of GPUs
# (four samples per device).
mnist = get_mnist()
batch = mnist['train_data'][0:GPU_COUNT*4, :]
# Split the batch into one slice per context and copy each slice to its device.
data = gluon.utils.split_and_load(batch, ctx)
# Run the forward pass on every device's slice. Iterating over `data`
# (instead of indexing data[0]..data[3]) keeps this cell valid for any
# GPU_COUNT.
for device_data in data:
    print(net(device_data))


[[-0.00481915  0.0019142   0.02079279 -0.01032627  0.01591925 -0.00201489
  -0.00969937 -0.01089183 -0.00311435  0.00096905]
 [ 0.00226964 -0.00863827  0.01155043 -0.03112907  0.04687939 -0.01447407
  -0.00345835 -0.01686897 -0.00883059 -0.0048196 ]
 [-0.01494183  0.00641099  0.02740017 -0.01331076  0.01718415 -0.02049803
  -0.00956278 -0.0193483   0.0079064  -0.01144795]
 [-0.00765511  0.00055096  0.00661376 -0.01600863  0.02302844 -0.00434789
  -0.01212487 -0.01363104  0.00415654 -0.00635324]]
<NDArray 4x10 @gpu(0)>

[[-0.00461744  0.01375779  0.03823002 -0.03126055  0.0333534  -0.01731854
  -0.00492363 -0.01554649 -0.00743299 -0.00639034]
 [-0.0091598  -0.00684692  0.01705891 -0.03048797  0.02992028 -0.01165187
  -0.01089531 -0.01976818  0.00234484  0.00295688]
 [-0.01501885  0.00427208  0.008796   -0.0090834   0.0166386  -0.00924412
  -0.00995954  0.00503164 -0.00183729 -0.00526959]
 [-0.0001063  -0.010359    0.02876046 -0.02604114  0.02878538 -0.01912549
  -0.00995447 -0.00781563

In [7]:
# Look up the first convolution's weight parameter by its full name.
weight = net.collect_params()['cnn_conv0_weight']

# Print index 0 of that weight as stored on each device, so the per-device
# copies can be compared.
for device in ctx:
    first_channel = weight.data(ctx=device)[0]
    print('=== channel 0 of the first conv on {} ==={}'.format(device, first_channel))

=== channel 0 of the first conv on gpu(0) ===
[[[ 0.0068339   0.01299825  0.0301265 ]
  [ 0.04819721  0.01438687  0.05011239]
  [ 0.00628365  0.04861524 -0.01068833]]]
<NDArray 1x3x3 @gpu(0)>
=== channel 0 of the first conv on gpu(1) ===
[[[ 0.0068339   0.01299825  0.0301265 ]
  [ 0.04819721  0.01438687  0.05011239]
  [ 0.00628365  0.04861524 -0.01068833]]]
<NDArray 1x3x3 @gpu(1)>
=== channel 0 of the first conv on gpu(2) ===
[[[ 0.0068339   0.01299825  0.0301265 ]
  [ 0.04819721  0.01438687  0.05011239]
  [ 0.00628365  0.04861524 -0.01068833]]]
<NDArray 1x3x3 @gpu(2)>
=== channel 0 of the first conv on gpu(3) ===
[[[ 0.0068339   0.01299825  0.0301265 ]
  [ 0.04819721  0.01438687  0.05011239]
  [ 0.00628365  0.04861524 -0.01068833]]]
<NDArray 1x3x3 @gpu(3)>


In [9]:
def forward_backward(net, data, label):
    """Compute the loss on every data slice and back-propagate each one.

    Args:
        net: network called on each element of ``data``.
        data: sequence of input slices (one per device in this notebook).
        label: sequence of label slices, paired element-wise with ``data``.

    Uses the module-level ``loss`` function. Returns nothing; gradients are
    produced as a side effect of ``backward()``.
    """
    per_device_losses = []
    # Record the forward passes so that autograd can differentiate them.
    with autograd.record():
        for inputs, targets in zip(data, label):
            per_device_losses.append(loss(net(inputs), targets))
    # Back-propagate outside the recording scope, one loss at a time.
    for device_loss in per_device_losses:
        device_loss.backward()

# Labels for the same GPU_COUNT*4 samples that were loaded into `data`,
# split across the devices. Deriving the slice from GPU_COUNT (rather than a
# literal 16) keeps data and labels aligned when GPU_COUNT changes.
label = gluon.utils.split_and_load(mnist['train_label'][0:GPU_COUNT*4], ctx)
forward_backward(net, data, label)
# Each device computed gradients from its own slice, so the values differ.
for c in ctx:
    print('=== grad of channel 0 of the first conv2d on {} ==={}'.format(
        c, weight.grad(ctx=c)[0]))

=== grad of channel 0 of the first conv2d on gpu(0) ===
[[[-0.05446478 -0.06920426 -0.03950822]
  [-0.0742496  -0.06848756 -0.02138805]
  [-0.00942846 -0.02732194 -0.02662538]]]
<NDArray 1x3x3 @gpu(0)>
=== grad of channel 0 of the first conv2d on gpu(1) ===
[[[ 0.13215259  0.13300328  0.09848808]
  [ 0.10377123  0.09225045  0.09771539]
  [ 0.0414196   0.08219814  0.14231277]]]
<NDArray 1x3x3 @gpu(1)>
=== grad of channel 0 of the first conv2d on gpu(2) ===
[[[-0.15555602 -0.12390864  0.00930843]
  [-0.17069009 -0.10882725  0.00049725]
  [-0.08165155 -0.04167913  0.01819187]]]
<NDArray 1x3x3 @gpu(2)>
=== grad of channel 0 of the first conv2d on gpu(3) ===
[[[-0.12295836 -0.08860554 -0.08663689]
  [-0.08101771 -0.10529852 -0.11382867]
  [-0.04248851 -0.08572525 -0.07257522]]]
<NDArray 1x3x3 @gpu(3)>


In [12]:
from mxnet.io import NDArrayIter
from time import time

def train_batch(batch, ctx, net, trainer):
    """Run one training step on ``batch`` using all contexts in ``ctx``.

    Args:
        batch: data batch exposing ``data[0]`` and ``label[0]``.
        ctx: list of contexts to spread the batch over.
        net: network to train.
        trainer: trainer whose ``step`` applies the parameter update.
    """
    inputs, targets = batch.data[0], batch.label[0]
    # Split inputs and labels the same way so the slices stay paired.
    data_slices = gluon.utils.split_and_load(inputs, ctx)
    label_slices = gluon.utils.split_and_load(targets, ctx)
    # Forward and backward pass on every slice.
    forward_backward(net, data_slices, label_slices)
    # Update parameters, passing the size of the whole (unsplit) batch.
    trainer.step(inputs.shape[0])

def valid_batch(batch, ctx, net):
    """Count the correct predictions in one validation batch.

    Evaluation runs on the first context only (``ctx[0]``).

    Args:
        batch: data batch exposing ``data[0]``, ``label[0]`` and ``pad``.
        ctx: list of contexts; only ``ctx[0]`` is used.
        net: network to evaluate.

    Returns:
        Number of real (non-padded) samples whose arg-max prediction equals
        the label, as a Python scalar.
    """
    # NDArrayIter pads the last batch by default when the dataset size is not
    # a multiple of the batch size; `batch.pad` is the number of trailing
    # padded rows. Drop them so no sample is scored twice.
    num_real = batch.data[0].shape[0] - (batch.pad or 0)
    data = batch.data[0][:num_real].as_in_context(ctx[0])
    label = batch.label[0][:num_real].as_in_context(ctx[0])
    pred = nd.argmax(net(data), axis=1)
    return nd.sum(pred == label).asscalar()

def run(num_gpus, batch_size, lr, num_epochs=5):
    """Train the module-level ``net`` on MNIST and report validation accuracy.

    The parameters of ``net`` are re-initialized on every call, so successive
    calls do not build on each other.

    Args:
        num_gpus: number of GPUs to train on (device ids 0 .. num_gpus-1).
        batch_size: total batch size, split across the GPUs.
        lr: SGD learning rate.
        num_epochs: number of passes over the training set (default 5).
    """
    # the list of GPUs will be used
    ctx = [mx.gpu(i) for i in range(num_gpus)]
    print('Running on {}'.format(ctx))

    # data iterator
    mnist = get_mnist()
    train_data = NDArrayIter(mnist["train_data"], mnist["train_label"], batch_size)
    valid_data = NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
    print('Batch size is {}'.format(batch_size))

    # Start from fresh parameters, placed on every selected device.
    net.collect_params().initialize(force_reinit=True, ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
    for epoch in range(num_epochs):
        # train
        start = time()
        train_data.reset()
        for batch in train_data:
            train_batch(batch, ctx, net, trainer)
        nd.waitall()  # wait until all computations are finished to benchmark the time
        print('Epoch %d, training time = %.1f sec'%(epoch, time()-start))

        # validating
        valid_data.reset()
        correct, num = 0.0, 0.0
        for batch in valid_data:
            correct += valid_batch(batch, ctx, net)
            # Count only real samples, matching what valid_batch scores.
            num += batch.data[0].shape[0] - (batch.pad or 0)
        print('         validation accuracy = %.4f'%(correct/num))

# Baseline: a single GPU with a batch of 64.
run(num_gpus=1, batch_size=64, lr=.3)
# All GPUs: scale the total batch size so that each device still gets 64 samples.
run(num_gpus=GPU_COUNT, batch_size=64*GPU_COUNT, lr=.3)

Running on [gpu(0)]
Batch size is 64
Epoch 0, training time = 3.6 sec
         validation accuracy = 0.9685
Epoch 1, training time = 3.5 sec
         validation accuracy = 0.9814
Epoch 2, training time = 3.5 sec
         validation accuracy = 0.9862
Epoch 3, training time = 3.4 sec
         validation accuracy = 0.9848
Epoch 4, training time = 3.5 sec
         validation accuracy = 0.9853
Running on [gpu(0), gpu(1), gpu(2), gpu(3)]
Batch size is 256
Epoch 0, training time = 2.4 sec
         validation accuracy = 0.9488
Epoch 1, training time = 2.3 sec
         validation accuracy = 0.9693
Epoch 2, training time = 2.4 sec
         validation accuracy = 0.9763
Epoch 3, training time = 2.4 sec
         validation accuracy = 0.9801
Epoch 4, training time = 2.3 sec
         validation accuracy = 0.9844
