In [1]:
import mxnet as mx

In [2]:
mnist = mx.test_utils.get_mnist()
mx.random.seed(1)
batch_size = 32
ctx = mx.cpu()
nd_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size)

In [3]:
data = mx.sym.Variable('data')
conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2))
bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1")
act1 = mx.symbol.Activation(data = bn1, name='relu1', act_type="relu")
mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max')

conv2= mx.symbol.Convolution(data = mp1, name='conv2', num_filter=32, kernel=(3,3), stride=(2,2))
bn2 = mx.symbol.BatchNorm(data = conv2, name="bn2")
act2 = mx.symbol.Activation(data = bn2, name='relu2', act_type="relu")
mp2 = mx.symbol.Pooling(data = act2, name = 'mp2', kernel=(2,2), stride=(2,2), pool_type='max')

fl = mx.symbol.Flatten(data = mp2, name="flatten")
fc2 = mx.symbol.FullyConnected(data = fl, name='fc2', num_hidden=10)
softmax = mx.symbol.SoftmaxOutput(data = fc2, name = 'softmax')
mod = mx.mod.Module(softmax)  # create a module by given a Symbol

In [12]:
softmax.get_internals().list_outputs()

['data',
 'conv1_weight',
 'conv1_bias',
 'conv1_output',
 'bn1_gamma',
 'bn1_beta',
 'bn1_moving_mean',
 'bn1_moving_var',
 'bn1_output',
 'relu1_output',
 'mp1_output',
 'conv2_weight',
 'conv2_bias',
 'conv2_output',
 'bn2_gamma',
 'bn2_beta',
 'bn2_moving_mean',
 'bn2_moving_var',
 'bn2_output',
 'relu2_output',
 'mp2_output',
 'flatten_output',
 'fc2_weight',
 'fc2_bias',
 'fc2_output',
 'softmax_label',
 'softmax_output']

In [4]:
# mod.bind(data_shapes=nd_iter.provide_data,
#          label_shapes=nd_iter.provide_label) # create memory by given input shapes
# mod.init_params() # initial parameters with the default random initializer

In [4]:
mod.fit(nd_iter, num_epoch=10) # train
mod.predict(nd_iter) # predict on new data


[[2.8249732e-04 8.5811716e-06 1.9661812e-04 ... 6.7650355e-05
  1.2813369e-03 7.9308427e-04]
 [9.9995291e-01 1.3227522e-07 1.2726049e-05 ... 7.0621363e-06
  4.6420523e-06 2.0200198e-06]
 [9.8071301e-05 1.7094301e-02 4.5305174e-03 ... 4.4743259e-02
  3.1286711e-04 3.2336875e-03]
 ...
 [1.5833242e-05 1.3727446e-05 3.5502053e-05 ... 1.7425697e-04
  5.5069049e-05 1.5886332e-03]
 [5.1516660e-02 9.4310803e-05 7.4359792e-05 ... 2.8431865e-05
  6.8629393e-03 6.4306769e-05]
 [1.5746262e-03 5.8110127e-06 1.7276568e-04 ... 4.3077554e-05
  9.8375678e-01 1.2043254e-02]]
<NDArray 60000x10 @cpu(0)>

In [5]:
mod.save_checkpoint('mod', 0)

In [6]:
sym, arg_params, aux_params = mx.model.load_checkpoint('mod', 0)

In [None]:
import logging
nd_iter.reset()
qsym, qarg_params, qaux_params = mx.contrib.quantization.quantize_model(sym=sym, 
                                                                        arg_params=arg_params, 
                                                                        aux_params=aux_params,
                                                                        excluded_sym_names=['conv1'],
                                                ctx=ctx, calib_data=nd_iter, logger=logging)

In [8]:
mod = mx.mod.Module(qsym, context=ctx)

In [9]:
nd_iter.reset()
mod.bind(for_training=False,
         data_shapes=nd_iter.provide_data,
         label_shapes=nd_iter.provide_label)

RuntimeError: simple_bind error. Arguments:
data: (32, 1, 28, 28)
softmax_label: (32,)
Error in operator quantized_conv1: [09:55:24] src/operator/quantization/quantized_conv.cc:50: Check failed: dshape[C] % 4 == 0U (1 vs. 0) for 8bit cudnn conv, the number of channel must be multiple of 4

Stack trace returned 10 entries:
[bt] (0) 0   libmxnet.so                         0x000000010a205684 libmxnet.so + 26244
[bt] (1) 1   libmxnet.so                         0x000000010a20543f libmxnet.so + 25663
[bt] (2) 2   libmxnet.so                         0x000000010a49b18d libmxnet.so + 2736525
[bt] (3) 3   libmxnet.so                         0x000000010b32b4ea MXNDListFree + 310682
[bt] (4) 4   libmxnet.so                         0x000000010b323b0c MXNDListFree + 279484
[bt] (5) 5   libmxnet.so                         0x000000010b318586 MXNDListFree + 233014
[bt] (6) 6   libmxnet.so                         0x000000010b31bf3a MXNDListFree + 247786
[bt] (7) 7   libmxnet.so                         0x000000010b2ad130 MXExecutorSimpleBind + 8656
[bt] (8) 8   _ctypes.cpython-36m-darwin.so       0x0000000107e7fd17 ffi_call_unix64 + 79
[bt] (9) 9   python                              0x00007fff5bff9cf0 __progname + 140730441894984



In [None]:
mod.set_params(arg_params, aux_params)