# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Train mnist, see more explanation at http://mxnet.io/tutorials/python/mnist.html"""
import os
import argparse
import logging
from common import find_mxnet, fit
from common.util import download_file
import mxnet as mx
import numpy as np
import gzip, struct
def read_data(label, image):
    """Download (if needed) and read one MNIST label/image file pair into numpy.

    Both files are gzip-compressed IDX files fetched from ``base_url`` and
    stored under the local ``data`` directory.

    Parameters
    ----------
    label : str
        Name of the gzipped IDX label file, e.g. 'train-labels-idx1-ubyte.gz'.
    image : str
        Name of the gzipped IDX image file, e.g. 'train-images-idx3-ubyte.gz'.

    Returns
    -------
    tuple of numpy.ndarray
        ``(labels, images)``: ``labels`` is a 1-D int8 array of length N and
        ``images`` is a uint8 array of shape (N, rows, cols).

    Raises
    ------
    ValueError
        If a file does not start with the expected IDX magic number, or the
        image count in the header differs from the number of labels read.
    """
    base_url = 'http://yann.lecun.com/exdb/mnist/'

    # NOTE(review): download_file is assumed to return the local path it wrote
    # (or found already cached) -- confirm in common/util.py.
    label_path = download_file(base_url + label, os.path.join('data', label))
    with gzip.open(label_path, 'rb') as flbl:
        # IDX1 header: big-endian uint32 magic number, then item count.
        magic, _num_labels = struct.unpack(">II", flbl.read(8))
        if magic != 2049:  # 0x00000801: IDX unsigned-byte vector
            raise ValueError('%s: unexpected IDX label magic number %d'
                             % (label, magic))
        # np.fromstring is deprecated for binary input; frombuffer over the raw
        # bytes is read-only, so copy to return a writable array as before.
        labels = np.frombuffer(flbl.read(), dtype=np.int8).copy()

    image_path = download_file(base_url + image, os.path.join('data', image))
    with gzip.open(image_path, 'rb') as fimg:
        # IDX3 header: magic number, image count, rows, cols (big-endian uint32).
        magic, num_images, rows, cols = struct.unpack(">IIII", fimg.read(16))
        if magic != 2051:  # 0x00000803: IDX unsigned-byte 3-D tensor
            raise ValueError('%s: unexpected IDX image magic number %d'
                             % (image, magic))
        if num_images != len(labels):
            raise ValueError('%s has %d images but %s has %d labels'
                             % (image, num_images, label, len(labels)))
        images = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(
            num_images, rows, cols).copy()

    return (labels, images)
def to4d(img):
    """Reshape a stack of grayscale images to 4-D float32 scaled to [0, 1].

    Parameters
    ----------
    img : numpy.ndarray
        Either an (N, rows, cols) stack, whose own spatial size is used, or
        any other array with N * 28 * 28 elements (e.g. (N, 784) or
        (N, 1, 28, 28)), which is reshaped to 28x28 as before.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (N, 1, rows, cols) with values divided by 255.
    """
    if img.ndim == 3:
        # Use the real image size instead of assuming 28x28.
        rows, cols = img.shape[1], img.shape[2]
    else:
        rows, cols = 28, 28
    return img.reshape(img.shape[0], 1, rows, cols).astype(np.float32) / 255
def get_mnist_iter(args, kv):
    """Create the MNIST training and validation iterators with NDArrayIter.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line arguments; only ``args.batch_size`` is read here.
    kv : object
        Not used by this function. Presumably the KVStore that ``fit.fit``
        passes to every data loader so they share one signature -- confirm
        against common/fit.py.

    Returns
    -------
    tuple
        ``(train, val)`` of ``mx.io.NDArrayIter``; ``train`` is created with
        ``shuffle=True``, ``val`` without shuffling.
    """
    (train_lbl, train_img) = read_data(
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
    (val_lbl, val_img) = read_data(
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
    train = mx.io.NDArrayIter(
        to4d(train_img), train_lbl, args.batch_size, shuffle=True)
    val = mx.io.NDArrayIter(
        to4d(val_img), val_lbl, args.batch_size)
    return (train, val)
if __name__ == '__main__':
    # parse args
    parser = argparse.ArgumentParser(
        description="train mnist",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--num-classes', type=int, default=10,
                        help='the number of classes')
    parser.add_argument('--num-examples', type=int, default=60000,
                        help='the number of training examples')
    parser.add_argument('--add_stn', action="store_true", default=False,
                        help='Add Spatial Transformer Network Layer (lenet only)')
    # Shared training options (network, gpus, batch size, lr, ...) come from
    # the common fit module; MNIST-specific defaults are applied below.
    fit.add_fit_args(parser)
    parser.set_defaults(
        # network
        network='mlp',
        # train
        gpus=None,
        batch_size=64,
        disp_batches=100,
        num_epochs=20,
        lr=.05,
        lr_step_epochs='10',
    )
    args = parser.parse_args()

    # load network: symbols/<network>.py must provide get_symbol(**kwargs)
    from importlib import import_module
    net = import_module('symbols.' + args.network)
    sym = net.get_symbol(**vars(args))

    # train
    fit.fit(args, sym, get_mnist_iter)