In [None]:
import mxnet as mx
import os, sys, requests

In [None]:
def get_iterators(batch_size, data_shape=(3, 224, 224),
                  train_rec='/home/m3rc3n4ry/seedclassification-mxnet/seed-train.rec',
                  val_rec='/home/m3rc3n4ry/seedclassification-mxnet/seed-val.rec'):
    """Build the training and validation RecordIO image iterators.

    Args:
        batch_size: number of images per batch for both iterators.
        data_shape: (channels, height, width) of each image fed to the network.
        train_rec: path to the training ``.rec`` file. The default is the
            original author's absolute path; pass your own path on other machines.
        val_rec: path to the validation ``.rec`` file (same caveat as ``train_rec``).

    Returns:
        ``(train, val)``. The training iterator shuffles and applies random
        crop/mirror augmentation; the validation iterator uses no shuffling
        and no random augmentation.
    """
    # Settings shared by both iterators. Keeping them in one place guarantees
    # that validation batches use the same input/label names and shape as training.
    common = dict(data_name='data',
                  label_name='softmax_label',
                  batch_size=batch_size,
                  data_shape=data_shape)
    train = mx.io.ImageRecordIter(path_imgrec=train_rec,
                                  shuffle=True,
                                  rand_crop=True,
                                  rand_mirror=True,
                                  **common)
    # shuffle=False is ImageRecordIter's default; it is stated explicitly so
    # the evaluation order is clearly fixed.
    val = mx.io.ImageRecordIter(path_imgrec=val_rec,
                                shuffle=False,
                                rand_crop=False,
                                rand_mirror=False,
                                **common)
    return (train, val)

In [None]:
# Load the pretrained ResNet-50 (prefix 'resnet-50', epoch 0). The checkpoint
# files ('resnet-50-symbol.json', 'resnet-50-0000.params') are resolved
# relative to the current working directory.
(sym, arg_params, aux_params) = mx.model.load_checkpoint(prefix='resnet-50', epoch=0)

In [None]:
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
    """Swap the classifier head of a pretrained network for a new one.

    The pretrained graph is cut at ``layer_name``. A fresh fully-connected
    layer named ``fc1`` with ``num_classes`` outputs, followed by a softmax
    output named ``softmax``, is stacked on top of it.

    Args:
        symbol: symbol of the pretrained network.
        arg_params: mapping of pretrained argument-parameter names to values.
        num_classes: number of classes in the fine-tuning dataset.
        layer_name: name of the layer just before the last fully-connected
            layer; its ``<layer_name>_output`` internal feeds the new head.

    Returns:
        ``(net, new_args)``: the new symbol, and a copy of ``arg_params`` with
        every entry whose name contains ``'fc1'`` removed. This means the new
        head's weights are not taken from the pretrained model.
    """
    features = symbol.get_internals()[layer_name + '_output']
    logits = mx.symbol.FullyConnected(data=features, num_hidden=num_classes, name='fc1')
    net = mx.symbol.SoftmaxOutput(data=logits, name='softmax')
    kept_args = {name: value for name, value in arg_params.items() if 'fc1' not in name}
    return (net, kept_args)

In [None]:
import logging

# Log format: timestamp (padded to 15 chars) followed by the message.
head = '%(asctime)-15s %(message)s'
# DEBUG level so MXNet's per-batch Speedometer and per-epoch messages are shown.
logging.basicConfig(format=head, level=logging.DEBUG)

def fit(symbol, arg_params, aux_params, train, val, batch_size, num_gpus,
        num_epoch=8, learning_rate=0.01):
    """Fine-tune ``symbol``, save a checkpoint, and score it on ``val``.

    Args:
        symbol: network symbol to train (e.g. from ``get_fine_tune_model``).
        arg_params: initial argument parameters; missing entries (the new
            head) are filled by the Xavier initializer.
        aux_params: initial auxiliary parameters.
        train: training data iterator.
        val: validation data iterator, used for per-epoch evaluation and
            for the final score.
        batch_size: batch size, used only by the Speedometer throughput log.
        num_gpus: number of GPUs to train on; 0 falls back to the CPU.
        num_epoch: number of training epochs (default 8).
        learning_rate: SGD learning rate (default 0.01).

    Returns:
        The result of ``Module.score`` with an Accuracy metric: a list of
        ``(metric_name, value)`` pairs such as ``[('accuracy', 0.95)]``.
        It is not a bare float, so callers must unpack it before comparing.
    """
    # Train on every requested GPU; use the CPU when no GPU is requested.
    devs = [mx.gpu(i) for i in range(num_gpus)] or [mx.cpu()]
    mod = mx.mod.Module(symbol=symbol, context=devs)
    mod.fit(train, val,
            num_epoch=num_epoch,
            arg_params=arg_params,
            aux_params=aux_params,
            # The new head's weights are absent from arg_params; allow_missing
            # lets the initializer fill them instead of raising.
            allow_missing=True,
            batch_end_callback=mx.callback.Speedometer(batch_size, 10),
            kvstore='device',
            optimizer='sgd',
            optimizer_params={'learning_rate': learning_rate},
            initializer=mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2),
            eval_metric='acc')
    # Saved under prefix 'resnet_checkpoint', epoch 1, without optimizer states.
    mod.save_checkpoint('resnet_checkpoint', 1, False)
    metric = mx.metric.Accuracy()
    return mod.score(val, metric)

In [None]:
num_classes = 12
batch_per_gpu = 16
num_gpus = 1
MIN_VAL_ACCURACY = 0.77  # minimum acceptable validation accuracy after fine-tuning

(new_sym, new_args) = get_fine_tune_model(sym, arg_params, num_classes)

batch_size = batch_per_gpu * num_gpus
(train, val) = get_iterators(batch_size)
mod_score = fit(new_sym, new_args, aux_params, train, val, batch_size, num_gpus)
# Module.score returns a list of (metric_name, value) pairs, e.g.
# [('accuracy', 0.93)]. Comparing that list to a float raises TypeError,
# so extract the scalar accuracy first.
val_accuracy = dict(mod_score)['accuracy']
assert val_accuracy > MIN_VAL_ACCURACY, "Low validation accuracy: %.4f" % val_accuracy

2018-06-28 22:11:04,101 Epoch[0] Batch [10]	Speed: 16.91 samples/sec	accuracy=0.159091
2018-06-28 22:11:13,572 Epoch[0] Batch [20]	Speed: 16.90 samples/sec	accuracy=0.156250
2018-06-28 22:11:23,038 Epoch[0] Batch [30]	Speed: 16.90 samples/sec	accuracy=0.262500
2018-06-28 22:11:32,520 Epoch[0] Batch [40]	Speed: 16.88 samples/sec	accuracy=0.343750
2018-06-28 22:11:42,003 Epoch[0] Batch [50]	Speed: 16.87 samples/sec	accuracy=0.462500
2018-06-28 22:11:51,517 Epoch[0] Batch [60]	Speed: 16.82 samples/sec	accuracy=0.400000
2018-06-28 22:12:01,016 Epoch[0] Batch [70]	Speed: 16.85 samples/sec	accuracy=0.606250
2018-06-28 22:12:10,489 Epoch[0] Batch [80]	Speed: 16.89 samples/sec	accuracy=0.531250
2018-06-28 22:12:19,969 Epoch[0] Batch [90]	Speed: 16.88 samples/sec	accuracy=0.662500
2018-06-28 22:12:29,444 Epoch[0] Batch [100]	Speed: 16.89 samples/sec	accuracy=0.625000
2018-06-28 22:12:38,934 Epoch[0] Batch [110]	Speed: 16.86 samples/sec	accuracy=0.668750
2018-06-28 22:12:48,425 Epoch[0] Batch [1

2018-06-28 22:26:03,202 Epoch[3] Batch [20]	Speed: 16.81 samples/sec	accuracy=0.931250
2018-06-28 22:26:12,688 Epoch[3] Batch [30]	Speed: 16.87 samples/sec	accuracy=0.937500
2018-06-28 22:26:22,168 Epoch[3] Batch [40]	Speed: 16.88 samples/sec	accuracy=0.937500
2018-06-28 22:26:31,661 Epoch[3] Batch [50]	Speed: 16.86 samples/sec	accuracy=0.925000
2018-06-28 22:26:41,154 Epoch[3] Batch [60]	Speed: 16.86 samples/sec	accuracy=0.912500
2018-06-28 22:26:50,644 Epoch[3] Batch [70]	Speed: 16.86 samples/sec	accuracy=0.925000
2018-06-28 22:27:00,136 Epoch[3] Batch [80]	Speed: 16.86 samples/sec	accuracy=0.950000
2018-06-28 22:27:09,622 Epoch[3] Batch [90]	Speed: 16.87 samples/sec	accuracy=0.925000
2018-06-28 22:27:19,119 Epoch[3] Batch [100]	Speed: 16.85 samples/sec	accuracy=0.943750
2018-06-28 22:27:28,595 Epoch[3] Batch [110]	Speed: 16.89 samples/sec	accuracy=0.906250
2018-06-28 22:27:38,117 Epoch[3] Batch [120]	Speed: 16.80 samples/sec	accuracy=0.931250
2018-06-28 22:27:47,601 Epoch[3] Batch [

2018-06-28 22:41:01,034 Epoch[6] Batch [30]	Speed: 16.85 samples/sec	accuracy=0.975000
2018-06-28 22:41:10,520 Epoch[6] Batch [40]	Speed: 16.87 samples/sec	accuracy=0.993750
2018-06-28 22:41:20,313 Epoch[6] Batch [50]	Speed: 16.34 samples/sec	accuracy=0.937500
2018-06-28 22:41:29,850 Epoch[6] Batch [60]	Speed: 16.78 samples/sec	accuracy=0.981250
2018-06-28 22:41:39,339 Epoch[6] Batch [70]	Speed: 16.86 samples/sec	accuracy=0.956250
2018-06-28 22:41:48,818 Epoch[6] Batch [80]	Speed: 16.88 samples/sec	accuracy=0.968750
2018-06-28 22:41:58,391 Epoch[6] Batch [90]	Speed: 16.72 samples/sec	accuracy=0.962500
2018-06-28 22:42:07,886 Epoch[6] Batch [100]	Speed: 16.85 samples/sec	accuracy=0.931250
2018-06-28 22:42:17,377 Epoch[6] Batch [110]	Speed: 16.86 samples/sec	accuracy=0.962500
2018-06-28 22:42:26,859 Epoch[6] Batch [120]	Speed: 16.88 samples/sec	accuracy=0.956250
2018-06-28 22:42:36,344 Epoch[6] Batch [130]	Speed: 16.87 samples/sec	accuracy=0.943750
2018-06-28 22:42:45,833 Epoch[6] Batch 

TypeError: unorderable types: list() > float()