Permalink
Fetching contributors…
Cannot retrieve contributors at this time
103 lines (92 sloc) 4.63 KB
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet
from common import data, fit, modelzoo
import mxnet as mx
import numpy as np
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name, dtype='float32'):
    """
    Replace the classifier head of a pre-trained network for fine-tuning.

    symbol: the pre-trained network symbol
    arg_params: the argument parameters of the pre-trained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer
    dtype: when 'float16', the output of the new fully-connected layer is
        cast to float32 before it is fed to the softmax

    Returns a tuple (net, new_args): the new symbol ending in a softmax
    output named 'softmax', and a new dict holding the entries of
    arg_params whose name does not contain 'fc'.
    """
    # cut the pre-trained graph at the output of the requested layer
    internals = symbol.get_internals()
    features = internals[layer_name + '_output']

    # attach a fresh classifier sized for the fine-tune dataset
    head = mx.symbol.FullyConnected(data=features, num_hidden=num_classes, name='fc')
    if dtype == 'float16':
        head = mx.sym.Cast(data=head, dtype=np.float32)
    head = mx.symbol.SoftmaxOutput(data=head, name='softmax')

    # carry over every pre-trained parameter except those with 'fc' in the name
    kept_args = {}
    for key in arg_params:
        if 'fc' in key:
            continue
        kept_args[key] = arg_params[key]
    return (head, kept_args)
if __name__ == "__main__":
    # parse args
    parser = argparse.ArgumentParser(
        description="fine-tune a dataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # NOTE(review): these helpers are presumably what register the training,
    # data and augmentation options read below (args.dtype, args.network,
    # args.num_classes, args.load_epoch are not defined in this file).
    # Their return values are not used here.
    fit.add_fit_args(parser)
    data.add_data_args(parser)
    data.add_data_aug_args(parser)
    # the pre-trained model is mandatory: without it there is nothing to load
    parser.add_argument('--pretrained-model', type=str, required=True,
                        help='the pre-trained model. can be prefix of local model files prefix '
                             'or a model name from common/modelzoo')
    parser.add_argument('--layer-before-fullc', type=str, default='flatten0',
                        help='the name of the layer before the last fullc layer')
    # use less augmentations for fine-tune. by default here it uses no augmentations
    # use a small learning rate and less regularizations
    parser.set_defaults(image_shape='3,224,224',
                        num_epochs=30,
                        lr=.01,
                        lr_step_epochs='20',
                        wd=0,
                        mom=0)
    args = parser.parse_args()

    # load pretrained model and params
    dir_path = os.path.dirname(os.path.realpath(__file__))
    (prefix, epoch) = modelzoo.download_model(
        args.pretrained_model, os.path.join(dir_path, 'model'))
    if prefix is None:
        # modelzoo returned no prefix: fall back to treating the argument
        # as a checkpoint prefix, loaded at --load-epoch
        (prefix, epoch) = (args.pretrained_model, args.load_epoch)
    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

    if args.dtype != 'float32':
        # load symbol of trained network, so we can cast it to support other dtype
        # fine tuning a network in a datatype which was not used for training originally,
        # requires access to the code used to generate the symbol used to train that model.
        # we then need to modify the symbol to add a layer at the beginning
        # to cast data to that dtype. We also need to cast output of layers before softmax
        # to float32 so that softmax can still be in float32.
        # if the network chosen from symbols/ folder doesn't have cast for the new datatype,
        # it will still train in fp32
        if args.network not in ['inception-v3', 'inception-v4', 'resnet-v1',
                                'resnet', 'resnext', 'vgg']:
            raise ValueError(
                'Given network does not have support for dtypes other than float32. '
                'Please add a cast layer at the beginning to train in that mode.')
        from importlib import import_module
        net = import_module('symbols.' + args.network)
        sym = net.get_symbol(**vars(args))

    # remove the last fullc layer and add a new softmax layer
    (new_sym, new_args) = get_fine_tune_model(sym, arg_params, args.num_classes,
                                              args.layer_before_fullc, args.dtype)

    # train
    fit.fit(args=args,
            network=new_sym,
            data_loader=data.get_rec_iter,
            arg_params=new_args,
            aux_params=aux_params)