In [1]:
import mxnet as mx
from mxnet import nd, gluon, autograd

In [2]:
# Container for the convolutional classifier; every parameter created inside
# its name scope is prefixed with "cnn_".
net = gluon.nn.Sequential(prefix="cnn_")
with net.name_scope():
    # The layers are instantiated inside the scope (so they receive the
    # prefix) and then appended in order: two conv/pool stages followed by
    # a flatten and two fully connected layers.
    layers = [
        gluon.nn.Conv2D(channels=20, kernel_size=3, activation='relu'),
        gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2)),
        gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'),
        gluon.nn.MaxPool2D(pool_size=(2,2), strides=(2,2)),
        gluon.nn.Flatten(),
        gluon.nn.Dense(128, activation="relu"),
        gluon.nn.Dense(10),
    ]
    for layer in layers:
        net.add(layer)
# Loss applied to the 10 raw outputs of the final Dense layer.
loss = gluon.loss.SoftmaxCrossEntropyLoss()

In [3]:
# Display the layer-by-layer summary of the network assembled above.
net

Sequential(
  (0): Conv2D(20, kernel_size=(3, 3), stride=(1, 1))
  (1): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
  (2): Conv2D(50, kernel_size=(5, 5), stride=(1, 1))
  (3): MaxPool2D(size=(2, 2), stride=(2, 2), padding=(0, 0), ceil_mode=False)
  (4): Flatten
  (5): Dense(128, Activation(relu))
  (6): Dense(10, linear)
)

In [7]:
# List every parameter of the network with its name, shape and dtype.
# NOTE(review): this cell carries execution count 7, i.e. it was run after the
# initialization and forward-pass cells that follow it in the document. The
# fully resolved input dimensions in the saved output presumably rely on that;
# a fresh top-to-bottom run may print not-yet-inferred shapes here -- confirm.
net.collect_params()

cnn_ (
  Parameter cnn_conv0_weight (shape=(20, 1, 3, 3), dtype=<class 'numpy.float32'>)
  Parameter cnn_conv0_bias (shape=(20,), dtype=<class 'numpy.float32'>)
  Parameter cnn_conv1_weight (shape=(50, 20, 5, 5), dtype=<class 'numpy.float32'>)
  Parameter cnn_conv1_bias (shape=(50,), dtype=<class 'numpy.float32'>)
  Parameter cnn_dense0_weight (shape=(128, 800), dtype=<class 'numpy.float32'>)
  Parameter cnn_dense0_bias (shape=(128,), dtype=<class 'numpy.float32'>)
  Parameter cnn_dense1_weight (shape=(10, 128), dtype=<class 'numpy.float32'>)
  Parameter cnn_dense1_bias (shape=(10,), dtype=<class 'numpy.float32'>)
)

In [4]:
# Number of GPUs to use; raise this on machines that have more devices.
GPU_COUNT = 2
# One context per GPU. Initializing with the whole list gives each device
# its own copy of every parameter.
ctx = list(map(mx.gpu, range(GPU_COUNT)))
net.collect_params().initialize(ctx=ctx)

In [5]:
from mxnet.test_utils import get_mnist
mnist = get_mnist()
# Take two training images per GPU so that every device receives a shard of
# the same size.
batch = mnist['train_data'][0:GPU_COUNT*2, :]
# Split the batch along its first axis and copy one shard to each context.
data = gluon.utils.split_and_load(batch, ctx)
# Run the forward pass on every shard. Iterating over the shards (instead of
# indexing data[0] and data[1]) keeps this cell valid for any GPU_COUNT.
for shard in data:
    print(net(shard))


[[-0.01876061 -0.02165035 -0.01293944  0.03837406 -0.00821797 -0.0091153
   0.00416799 -0.00729157 -0.0023271  -0.00155548]
 [ 0.00441475 -0.01953595 -0.00128483  0.02768222  0.01389614 -0.01320441
  -0.01166505 -0.00637777  0.01354249 -0.00611765]]
<NDArray 2x10 @gpu(0)>

[[ -6.78736810e-03  -8.86893645e-03  -1.04004759e-02   1.72976386e-02
    2.26115324e-02  -6.36630971e-03  -1.54974945e-02  -1.22633735e-02
    1.19591532e-02  -6.60009682e-05]
 [ -1.17358584e-02  -2.16879621e-02   1.71219651e-03   2.49827579e-02
    1.16810845e-02  -9.52543132e-03  -1.03610354e-02   5.08510135e-03
    7.06663402e-03  -9.25292633e-03]]
<NDArray 2x10 @gpu(1)>


In [6]:
# Look up the first convolution's weight by its prefixed parameter name.
weight = net.collect_params()['cnn_conv0_weight']

# Every device holds its own copy of the parameter; print the first output
# channel of each copy so they can be compared side by side.
channel_report = '=== channel 0 of the first conv on {} ==={}'
for c in ctx:
    first_channel = weight.data(ctx=c)[0]
    print(channel_report.format(c, first_channel))

=== channel 0 of the first conv on gpu(0) ===
[[[ 0.04118239  0.05352169 -0.04762455]
  [ 0.06035256 -0.01528978  0.04946674]
  [ 0.06110793 -0.00081179  0.02191102]]]
<NDArray 1x3x3 @gpu(0)>
=== channel 0 of the first conv on gpu(1) ===
[[[ 0.04118239  0.05352169 -0.04762455]
  [ 0.06035256 -0.01528978  0.04946674]
  [ 0.06110793 -0.00081179  0.02191102]]]
<NDArray 1x3x3 @gpu(1)>


In [8]:
def forward_backward(net, data, label):
    """Compute the loss on every shard and back-propagate each one.

    Parameters
    ----------
    net : callable network applied to each data shard.
    data : sequence of input shards, one per device.
    label : sequence of label shards, paired element-wise with ``data``.

    Uses the module-level ``loss`` function. Returns nothing; the gradients
    are produced as a side effect of ``backward()``.
    """
    losses = []
    # Record the forward computation of all shards so gradients can be derived.
    with autograd.record():
        for X, Y in zip(data, label):
            losses.append(loss(net(X), Y))
    # Back-propagate outside of the recording scope, one shard at a time.
    for l in losses:
        l.backward()

# Load exactly as many labels as there are samples in the data batch
# (GPU_COUNT*2) so data and label shards stay paired for any GPU_COUNT.
label = gluon.utils.split_and_load(mnist['train_label'][0:GPU_COUNT*2], ctx)
forward_backward(net, data, label)
# Each device computed gradients from its own shard only, so the per-device
# gradients printed below differ from one another.
for c in ctx:
    print('=== grad of channel 0 of the first conv2d on {} ==={}'.format(
        c, weight.grad(ctx=c)[0]))

=== grad of channel 0 of the first conv2d on gpu(0) ===
[[[-0.02078936 -0.00562427  0.01711006]
  [ 0.01138538  0.0280002   0.04094724]
  [ 0.00993335  0.01218192  0.02122577]]]
<NDArray 1x3x3 @gpu(0)>
=== grad of channel 0 of the first conv2d on gpu(1) ===
[[[-0.02543038 -0.0278994  -0.00302116]
  [-0.04816785 -0.03347274 -0.00403482]
  [-0.03178394 -0.01254032  0.00855637]]]
<NDArray 1x3x3 @gpu(1)>


In [9]:

from mxnet.io import NDArrayIter
from time import time

def train_batch(batch, ctx, net, trainer):
    """Run one training step on ``batch`` spread across the devices in ``ctx``.

    Parameters
    ----------
    batch : data batch whose first ``data`` / ``label`` entries hold the
        inputs and the targets.
    ctx : list of contexts to split the batch over.
    net : the network being trained.
    trainer : optimizer wrapper whose ``step`` applies the update.
    """
    inputs, targets = batch.data[0], batch.label[0]
    # Slice inputs and targets and copy one slice to each device.
    data = gluon.utils.split_and_load(inputs, ctx)
    label = gluon.utils.split_and_load(targets, ctx)
    # Gradients are computed per device.
    forward_backward(net, data, label)
    # The step is given the size of the whole batch, not of a single shard.
    trainer.step(inputs.shape[0])

def valid_batch(batch, ctx, net):
    """Count the correctly classified samples in one validation batch.

    The whole batch is evaluated on the first device of ``ctx`` only.
    Samples that the iterator appended to fill up an incomplete last batch
    (reported through ``batch.pad``) are excluded, so they are not counted a
    second time.

    Parameters
    ----------
    batch : data batch exposing ``data``, ``label`` and, optionally, ``pad``.
    ctx : list of contexts; only ``ctx[0]`` is used.
    net : the network to evaluate.

    Returns
    -------
    Number of non-padding samples whose arg-max prediction equals the label.
    """
    data = batch.data[0].as_in_context(ctx[0])
    label = batch.label[0].as_in_context(ctx[0])
    # ``pad`` may be missing or None on batches that carry no padding.
    num_real = data.shape[0] - (getattr(batch, 'pad', None) or 0)
    pred = nd.argmax(net(data), axis=1)
    # Padding sits at the end of the batch, so keep only the leading samples.
    return nd.sum(pred[:num_real] == label[:num_real]).asscalar()

def run(num_gpus, batch_size, lr):
    """Train the module-level ``net`` on MNIST for 5 epochs and report progress.

    Parameters
    ----------
    num_gpus : number of GPUs to train on (``gpu(0)`` .. ``gpu(num_gpus-1)``).
    batch_size : total batch size, split across the GPUs at every step.
    lr : learning rate passed to the SGD trainer.

    Prints the training time and the validation accuracy of every epoch.
    Note that the parameters of ``net`` are re-initialized on entry.
    """
    # the list of GPUs will be used
    ctx = [mx.gpu(i) for i in range(num_gpus)]
    print('Running on {}'.format(ctx))

    # data iterator
    mnist = get_mnist()
    train_data = NDArrayIter(mnist["train_data"], mnist["train_label"], batch_size)
    valid_data = NDArrayIter(mnist["test_data"], mnist["test_label"], batch_size)
    print('Batch size is {}'.format(batch_size))

    # start every run from fresh parameters placed on the requested devices
    net.collect_params().initialize(force_reinit=True, ctx=ctx)
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': lr})
    for epoch in range(5):
        # train
        start = time()
        train_data.reset()
        for batch in train_data:
            train_batch(batch, ctx, net, trainer)
        nd.waitall()  # wait until all computations are finished to benchmark the time
        print('Epoch %d, training time = %.1f sec'%(epoch, time()-start))

        # validating
        valid_data.reset()
        correct, num = 0.0, 0.0
        for batch in valid_data:
            correct += valid_batch(batch, ctx, net)
            # count only real samples; the iterator pads the last batch
            num += batch.data[0].shape[0] - (getattr(batch, 'pad', None) or 0)
        print('         validation accuracy = %.4f'%(correct/num))

# Train once on a single GPU and once on all GPUs. The batch size grows with
# the device count (64 samples per GPU) while the learning rate stays at 0.3.
for num_gpus in (1, GPU_COUNT):
    run(num_gpus, 64*num_gpus, .3)

Running on [gpu(0)]
Batch size is 64
Epoch 0, training time = 2.9 sec
         validation accuracy = 0.9701
Epoch 1, training time = 2.6 sec
         validation accuracy = 0.9826
Epoch 2, training time = 2.6 sec
         validation accuracy = 0.9852
Epoch 3, training time = 2.6 sec
         validation accuracy = 0.9862
Epoch 4, training time = 2.7 sec
         validation accuracy = 0.9867
Running on [gpu(0), gpu(1)]
Batch size is 128
Epoch 0, training time = 2.8 sec
         validation accuracy = 0.9493
Epoch 1, training time = 2.6 sec
         validation accuracy = 0.9714
Epoch 2, training time = 2.4 sec
         validation accuracy = 0.9788
Epoch 3, training time = 2.5 sec
         validation accuracy = 0.9830
Epoch 4, training time = 2.3 sec
         validation accuracy = 0.9852
