Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #176 from tqchen/master
Browse files Browse the repository at this point in the history
Autoselect kvstore update mode
  • Loading branch information
tqchen committed Sep 28, 2015
2 parents e373355 + f623a5b commit 927c102
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 25 deletions.
20 changes: 0 additions & 20 deletions python/mxnet/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,6 @@ def update(self, label, pred):
self.sum_metric += numpy.sum(pred_label == label)
self.num_inst += label.size

# pylint: disable=pointless-string-statement
"""
class LogLoss(EvalMetric):
# removed because it is too slow
def __init__(self):
self.eps = 1e-15
super(LogLoss, self).__init__('logloss')
def update(self, label, pred):
# pylint: disable=invalid-name
pred = pred.asnumpy()
label = label.asnumpy().astype('int32')
for i in range(label.size):
p = pred[i][label[i]]
assert(numpy.isnan(p) == False)
p = max(min(p, 1 - self.eps), self.eps)
self.sum_metric += -numpy.log(p)
self.num_inst += label.size
"""
# pylint: enable=pointless-string-statement

class CustomMetric(EvalMetric):
"""Custom evaluation metric that takes a NDArray function.
Expand Down
20 changes: 15 additions & 5 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _train_multi_device(symbol, ctx, input_shape,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, epoch_end_callback=None,
update_on_kvstore=False,
update_on_kvstore=None,
logger=None):
"""Internal training function on multiple devices.
Expand Down Expand Up @@ -183,8 +183,9 @@ def _train_multi_device(symbol, ctx, input_shape,
-----
- This function will update the NDArrays in arg_params and aux_states in place.
- Turning update_on_kvstore on and off can affect speed of multi-gpu training.
- update_on_kvstore=True works well for inception type nets that contains many small weights.
- update_on_kvstore=False works better for Alexnet style net with bulk weights.
- It is auto selected by default.
  - update_on_kvstore=True works well for inception-type nets that contain many small weights.
- update_on_kvstore=False works better for Alexnet style net with bulk weights.
"""
if logger is None:
logger = logging
Expand All @@ -210,10 +211,17 @@ def _train_multi_device(symbol, ctx, input_shape,

for texec in train_execs:
texec.copy_params_from(arg_params, aux_params)

    # key value store
kv = kvstore.create() if num_device != 1 else None
if kv is None:
update_on_kvstore = False
else:
# auto decide update_on_kvstore
if update_on_kvstore is None:
max_size = max(np.prod(param.shape) for param in arg_params.values())
update_on_kvstore = max_size < 1024 * 1024 * 16
logging.info('Auto-select update_on_kvstore=%s', str(update_on_kvstore))

opt_state_blocks = []
# If there are multiple devices, initialize the weights.
Expand Down Expand Up @@ -586,7 +594,7 @@ def predict(self, X):

def fit(self, X, y=None, eval_data=None, eval_metric='acc',
iter_end_callback=None, epoch_end_callback=None,
update_on_kvstore=False, logger=None):
update_on_kvstore=None, logger=None):
"""Fit the model.
Parameters
Expand Down Expand Up @@ -618,6 +626,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
update_on_kvstore: boolean, optional
Whether to perform parameter update on kvstore instead of training device.
By default, the trainer will automatically decide the policy.
logger : logging logger, optional
When not specified, default logger will be used.
Expand Down Expand Up @@ -711,7 +720,7 @@ def load(prefix, iteration, ctx=None):
def create(symbol, X, y=None, ctx=None,
num_round=None, optimizer='sgd', initializer=Xavier(),
eval_data=None, eval_metric='acc', iter_end_callback=None,
update_on_kvstore=False, logger=None, **kwargs):
update_on_kvstore=None, logger=None, **kwargs):
"""Functional style to create a model.
This function will be more consistent with functional
Expand Down Expand Up @@ -755,6 +764,7 @@ def create(symbol, X, y=None, ctx=None,
update_on_kvstore: boolean, optional
Whether to perform parameter update on kvstore instead of training device.
By default, the trainer will automatically decide the policy.
logger : logging logger, optional
"""
Expand Down

0 comments on commit 927c102

Please sign in to comment.