Source: https://mxnet.apache.org/versions/1.5.0/tutorials/gluon/multi_gpu.html

In [2]:
import mxnet as mx

# Pick the two device contexts used by the rest of the notebook: two distinct
# GPUs when at least two are visible, otherwise the only GPU (or the CPU)
# listed twice so that the list always has two entries.
# NOTE(review): in both fallbacks the same device appears twice -- confirm
# that Gluon keeps and updates a separate parameter replica for the repeated
# entry before relying on this to emulate multi-device training.
n_gpu = mx.context.num_gpus()
if n_gpu >= 2:
    context = [mx.gpu(0), mx.gpu(1)]
elif n_gpu == 1:
    context = [mx.gpu(), mx.gpu()]
else:
    context = [mx.cpu(), mx.cpu()]

# Allocate one small array on each of the two selected contexts.
a = mx.nd.array([1, 2, 3], ctx=context[0])
b = mx.nd.array([5, 6, 7], ctx=context[1])



Storing the network on multiple GPUs

In [18]:
from mxnet import init
from mxnet.gluon import nn

# Small convolutional classifier, assembled stage by stage.
net = nn.Sequential()

# Three convolution + 2x2 max-pooling stages (6, 16 and 16 channels).
net.add(nn.Conv2D(channels=6, kernel_size=5, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2))
net.add(nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2))
net.add(nn.Conv2D(channels=16, kernel_size=3, activation='relu'),
        nn.MaxPool2D(pool_size=2, strides=2))

# Flatten the feature maps and classify with three dense layers; the last
# layer emits 10 raw scores (no activation).
net.add(nn.Flatten(),
        nn.Dense(120, activation="relu"),
        nn.Dense(84, activation="relu"),
        nn.Dense(10))

# Passing the list of contexts asks Gluon to initialize the parameters on
# every context in `context`, using Xavier initialization.
net.initialize(init.Xavier(), ctx=context)


Splitting data between GPUs

In [20]:
# Demo: draw a random (100, 10) array and let `split_and_load` slice it and
# place one slice on each entry of `context`.
data = mx.random.uniform(shape=(100, 10))
result = mx.gluon.utils.split_and_load(data, context)

In [21]:
# MNIST training and validation splits. `transform_first` applies `ToTensor`
# to the first element of each sample (the image) and leaves the label as is.
train_data = mx.gluon.data.vision.MNIST(train=True).transform_first(
    mx.gluon.data.vision.transforms.ToTensor()
)
val_data = mx.gluon.data.vision.MNIST(train=False).transform_first(
    mx.gluon.data.vision.transforms.ToTensor()
)

In [22]:
# Number of samples per mini-batch, before the batch is split across contexts.
batch_size = 256

# Shuffle the training set on every pass; keep the validation order fixed.
train_loader = mx.gluon.data.DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
)
val_loader = mx.gluon.data.DataLoader(
    val_data,
    batch_size=batch_size,
    shuffle=False,
)

In [23]:
# Loss minimized during training and the metric reported after each epoch.
loss_function = mx.gluon.loss.SoftmaxCrossEntropyLoss()
metric = mx.metric.Accuracy()

# Plain SGD over every parameter of `net`.
trainer = mx.gluon.Trainer(
    net.collect_params(),
    'sgd',
    {'learning_rate': 0.005},
)

In [24]:
# Number of full passes over the training set.
num_epochs = 100

for epoch in range(num_epochs):
    for inputs, labels in train_loader:
        # Remember the size of the whole batch before it is split; the last
        # batch of an epoch may be smaller than `batch_size`.
        actual_batch_size = inputs.shape[0]

        # Split the batch across the contexts. `split_and_load` is
        # deterministic, so inputs and labels are partitioned identically and
        # slice i of the inputs lines up with slice i of the labels.
        # `even_split=False` tolerates batches that do not divide evenly.
        inputs = mx.gluon.utils.split_and_load(inputs, ctx_list=context, even_split=False)
        labels = mx.gluon.utils.split_and_load(labels, ctx_list=context, even_split=False)

        # The forward pass and the loss computation must run inside a
        # `record()` scope so that the computational graph is captured and
        # gradients can be derived from it during the backward pass.
        with mx.autograd.record():
            outputs = [net(input_slice) for input_slice in inputs]
            losses = [loss_function(o, l) for o, l in zip(outputs, labels)]

        # Back-propagate all per-slice losses in ONE call. A separate
        # `loss.backward()` per slice rewrites the gradient buffers it touches
        # (default grad_req='write'), so when two slices live on the same
        # device -- as in the single-device fallbacks where `context` repeats
        # one entry -- the later call would discard the earlier slice's
        # gradient. A single backward pass over all heads sums their
        # contributions instead; with distinct devices the result is the same
        # as calling `backward()` on each loss.
        mx.autograd.backward(losses)

        # Update the metric with every (label slice, output slice) pair.
        for l, o in zip(labels, outputs):
            metric.update(l, o)

        # Update the parameters by stepping the trainer; the batch size
        # is required to normalize the gradients by `1 / batch_size`.
        # `ignore_stale_grad=True` is kept because a parameter replica that
        # took part in no forward pass (possible when `context` repeats a
        # device) never receives a fresh gradient and would otherwise make
        # `step` raise.
        trainer.step(batch_size=actual_batch_size, ignore_stale_grad=True)

    # Print the training metric for this epoch and reset it for the next one.
    name, acc = metric.get()
    print('After epoch {}: {} = {}'.format(epoch + 1, name, acc))
    metric.reset()

[12:36:33] ../src/kvstore/././comm.h:741: only 0 out of 2 GPU pairs are enabled direct access. It may affect the performance. You can set MXNET_ENABLE_GPU_P2P=0 to turn it off
[12:36:33] ../src/kvstore/././comm.h:750: ..
[12:36:33] ../src/kvstore/././comm.h:750: ..


After epoch 1: accuracy = 0.18065
After epoch 2: accuracy = 0.37133333333333335
After epoch 3: accuracy = 0.6742833333333333
After epoch 4: accuracy = 0.85235
After epoch 5: accuracy = 0.8885166666666666
After epoch 6: accuracy = 0.9051333333333333
After epoch 7: accuracy = 0.9148166666666666
After epoch 8: accuracy = 0.9229333333333334
After epoch 9: accuracy = 0.9292333333333334
After epoch 10: accuracy = 0.9342833333333334
After epoch 11: accuracy = 0.9393833333333333
After epoch 12: accuracy = 0.9430333333333333
After epoch 13: accuracy = 0.9471166666666667
After epoch 14: accuracy = 0.9495166666666667
After epoch 15: accuracy = 0.9531666666666667
After epoch 16: accuracy = 0.9546
After epoch 17: accuracy = 0.9572166666666667
After epoch 18: accuracy = 0.95825
After epoch 19: accuracy = 0.9605166666666667
After epoch 20: accuracy = 0.96175
After epoch 21: accuracy = 0.9627166666666667
After epoch 22: accuracy = 0.9642333333333334
After epoch 23: accuracy = 0.9652166666666666
After 