In [1]:
import mxnet as mx

In [2]:
mnist = mx.test_utils.get_mnist()
mx.random.seed(1)
batch_size = 32
ctx = mx.gpu()
nd_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size)

In [3]:
data = mx.sym.Variable('data')
conv1= mx.symbol.Convolution(data = data, name='conv1', num_filter=32, kernel=(3,3), stride=(2,2))
bn1 = mx.symbol.BatchNorm(data = conv1, name="bn1")
act1 = mx.symbol.Activation(data = bn1, name='relu1', act_type="relu")
mp1 = mx.symbol.Pooling(data = act1, name = 'mp1', kernel=(2,2), stride=(2,2), pool_type='max')

conv2= mx.symbol.Convolution(data = mp1, name='conv2', num_filter=32, kernel=(3,3), stride=(2,2))
bn2 = mx.symbol.BatchNorm(data = conv2, name="bn2")
act2 = mx.symbol.Activation(data = bn2, name='relu2', act_type="relu")
mp2 = mx.symbol.Pooling(data = act2, name = 'mp2', kernel=(2,2), stride=(2,2), pool_type='max')

fl = mx.symbol.Flatten(data = mp2, name="flatten")
fc2 = mx.symbol.FullyConnected(data = fl, name='fc2', num_hidden=10)
softmax = mx.symbol.SoftmaxOutput(data = fc2, name = 'softmax')
mod = mx.mod.Module(softmax)  # create a module by given a Symbol

In [5]:
softmax.get_internals().list_outputs()

['data',
 'conv1_weight',
 'conv1_bias',
 'conv1_output',
 'bn1_gamma',
 'bn1_beta',
 'bn1_moving_mean',
 'bn1_moving_var',
 'bn1_output',
 'relu1_output',
 'mp1_output',
 'conv2_weight',
 'conv2_bias',
 'conv2_output',
 'bn2_gamma',
 'bn2_beta',
 'bn2_moving_mean',
 'bn2_moving_var',
 'bn2_output',
 'relu2_output',
 'mp2_output',
 'flatten_output',
 'fc2_weight',
 'fc2_bias',
 'fc2_output',
 'softmax_label',
 'softmax_output']

In [6]:
# mod.bind(data_shapes=nd_iter.provide_data,
#          label_shapes=nd_iter.provide_label) # create memory by given input shapes
# mod.init_params() # initial parameters with the default random initializer

In [7]:
mod.fit(nd_iter, num_epoch=10) # train
mod.predict(nd_iter) # predict on new data


[[6.1429414e-04 2.2412802e-05 2.9679781e-04 ... 1.0351928e-04
  1.1245396e-03 9.4142102e-04]
 [9.9995208e-01 1.5527347e-07 8.6388927e-06 ... 1.3666230e-05
  5.2755231e-06 3.3966360e-06]
 [1.3097575e-04 4.4493489e-02 7.5362218e-03 ... 7.9570457e-02
  4.3649800e-04 3.5894753e-03]
 ...
 [1.7041208e-05 1.5998005e-05 6.0236711e-05 ... 2.1640272e-04
  5.8221180e-05 1.4509353e-03]
 [9.2443489e-02 7.2884701e-05 5.5742497e-05 ... 2.3522411e-05
  8.3520580e-03 8.2647261e-05]
 [1.6435813e-03 8.2000879e-06 1.7693899e-04 ... 3.0793424e-05
  9.8038894e-01 1.4470483e-02]]
<NDArray 60000x10 @cpu(0)>

In [8]:
mod.save_checkpoint('mod', 0)

In [9]:
sym, arg_params, aux_params = mx.model.load_checkpoint('mod', 0)

In [10]:
import logging
nd_iter.reset()
qsym, qarg_params, qaux_params = mx.contrib.quantization.quantize_model(sym=sym, 
                                                                        arg_params=arg_params, 
                                                                        aux_params=aux_params,
                                                                        excluded_sym_names=['conv1'],
                                                ctx=ctx, calib_data=nd_iter, logger=logging)

In [11]:
mod = mx.mod.Module(qsym, context=ctx)

In [12]:
nd_iter.reset()
mod.bind(for_training=False,
         data_shapes=nd_iter.provide_data,
         label_shapes=nd_iter.provide_label)

In [14]:
mod.set_params(qarg_params, qaux_params)

In [15]:
nd_iter.reset()
mod.predict(nd_iter)


[[7.4198929e-04 2.8709939e-05 3.2908368e-04 ... 1.4595370e-04
  1.4609668e-03 1.2758268e-03]
 [9.9994636e-01 1.9554305e-07 8.6896416e-06 ... 1.7109787e-05
  5.7870370e-06 3.8539906e-06]
 [1.4486852e-04 4.2915381e-02 6.4377384e-03 ... 8.4499799e-02
  4.2831406e-04 3.2695699e-03]
 ...
 [1.9457897e-05 1.4838801e-05 6.5876840e-05 ... 2.2303322e-04
  5.7528650e-05 1.4867883e-03]
 [1.1410713e-01 9.9354729e-05 7.5768985e-05 ... 3.3604683e-05
  1.1399541e-02 1.3028237e-04]
 [1.6800329e-03 8.5157790e-06 2.2008461e-04 ... 3.3014912e-05
  9.7994012e-01 1.4685692e-02]]
<NDArray 60000x10 @gpu(0)>

In [16]:
mod = mx.mod.Module(qsym, context=mx.cpu())

# Kernel dies here upon using cpu()

In [None]:
nd_iter.reset()
mod.bind(for_training=False,
         data_shapes=nd_iter.provide_data,
         label_shapes=nd_iter.provide_label)