Refactor AdaGrad optimizer to support sparse tensors + unary and binary refactoring for new infer storage type logic #7903
Conversation
@@ -36,21 +36,21 @@ NNVM_REGISTER_OP(_backward_add)
ElemwiseBinaryOp::BackwardUseNoneWithHalf2<gpu, mshadow_op::identity,
                                           mshadow_op::identity>);

NNVM_REGISTER_OP(_sub)
NNVM_REGISTER_OP(elemwise_sub)
Doesn't this break backward compatibility in the cpp package?
Not that I know of. I don't think anyone using cpp-package uses these sorts of operators.
The naming inconsistency between elemwise_sub and the other similar three is nonsensical.
The definition for sparse update is to only apply update on weight/states on the rows whose gradients are non-zeros.
We should revisit our approach and have sparse update primitives like:
`scatter_add`: adds `scalar` to all rows of `lhs` (row_sparse) specified by `idx`, which returns a row_sparse result.
`scatter_div`: divides `lhs` (row_sparse) by `rhs` (row_sparse) for the rows specified by `idx`, which returns the updated row_sparse result
so that Adagrad can be implemented as:
indices = grad.indices
history[:] = op.elemwise_add(history, op.square(grad))
srt = op.sqrt(nd.sparse.scatter_add(lhs=history, scalar=eps, idx=indices))
div = nd.sparse.scatter_div(lhs=grad, rhs=srt, idx=indices)
weight[:] += (div + sparse.retain(weight, indices) * wd) * -lr
Implementing these primitives takes extra time, but it is definitely very useful when it comes to supporting sparse updates for other optimizers implemented in python. @cjolivier01 This involves a large scope. What do you think?
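For concreteness, here is a dense NumPy stand-in for the proposed semantics (`scatter_add`, `scatter_div`, and `retain` below are hypothetical helpers that only illustrate the suggestion; they are not existing MXNet operators):

```python
import numpy as np

def scatter_add(lhs, scalar, idx):
    # Hypothetical primitive: add `scalar` only to the rows of `lhs` listed in `idx`.
    out = lhs.copy()
    out[idx] += scalar
    return out

def scatter_div(lhs, rhs, idx):
    # Hypothetical primitive: divide the rows of `lhs` listed in `idx` by the same rows of `rhs`.
    out = lhs.copy()
    out[idx] = lhs[idx] / rhs[idx]
    return out

def retain(arr, idx):
    # Dense stand-in for sparse.retain: keep only the rows listed in `idx`.
    out = np.zeros_like(arr)
    out[idx] = arr[idx]
    return out

lr, wd, eps = 0.1, 1e-4, 1e-7
weight  = np.random.rand(5, 3)
history = np.zeros((5, 3))
grad    = np.zeros((5, 3))
indices = np.array([1, 3])                        # rows that actually have gradients
grad[indices] = np.random.rand(len(indices), 3)

history += grad ** 2                                      # history = history + grad^2
srt = np.sqrt(scatter_add(history, eps, indices))         # sqrt(history + eps) on touched rows only
div = scatter_div(grad, srt, indices)                     # grad / sqrt(history + eps) on touched rows
weight += (div + retain(weight, indices) * wd) * -lr      # only rows in `indices` change
```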
cjolivier01: We discussed before that these scatter ops were internal. It's not clear what you're suggesting with that adagrad code. What are you saying would be different? Feel free to implement other operators if you like.
python/mxnet/optimizer.py
Outdated
def create_state(self, index, weight):
    return zeros(weight.shape, weight.context)  # history
    return zeros(weight.shape, weight.context, stype=self.stype)  # history
Please create the state based on weight.stype. The states should always have the same stype as the weight. Perform sparse update only when w.stype == g.stype == state.stype
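For reference, a minimal sketch of that suggestion (not the PR's final code; assumes the usual `import mxnet as mx` and the standard Optimizer API):

```python
import mxnet as mx

def create_state(self, index, weight):
    # The state (history) mirrors the weight's storage type, so the sparse update
    # path is only taken when weight, grad, and state stypes all agree.
    return mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype,
                       stype=weight.stype)  # history
```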
They don't, actually. I get update calls for size-1 dense weights along with the sparse ones. Is the update not expected to occur then? Because the non-sparse version updates them.
done
python/mxnet/optimizer.py
Outdated
@@ -665,26 +667,46 @@ class AdaGrad(Optimizer):
    eps: float, optional
        Small value to avoid division by 0.
    """
    def __init__(self, eps=1e-7, **kwargs):
    def __init__(self, eps=1e-7, stype='default', **kwargs):
Please remove stype argument here.
done
python/mxnet/optimizer.py
Outdated
@@ -665,26 +667,46 @@ class AdaGrad(Optimizer):
    eps: float, optional
        Small value to avoid division by 0.
    """
    def __init__(self, eps=1e-7, **kwargs):
Please update the documentation the same way as Adam/SGD
Can you link to the PR?
python/mxnet/optimizer.py
Outdated
weight[:] += -lr * (grad / sqrt(history + self.float_stable_eps) + wd * weight)
save_history_stype = history.stype

is_sparse = True if weight.stype != 'default' or grad.stype != 'default' else False
is_sparse = True iff w.stype == g.stype == state.stype == row_sparse
either
`x = True if cond else False` <=> `x = cond`
For Adam and SGD, it:
- performs dense updates if everything is default
- performs sparse updates if everything is row_sparse
- falls back to dense and prints a warning message if the inputs have both sparse and dense (see the sketch below)
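Roughly, the behavior described above looks like this (a sketch only; `_sparse_update` and `_dense_update` are hypothetical helper names, not the actual Adam/SGD code):

```python
import warnings

def _update_impl(self, index, weight, grad, state):
    stypes = {weight.stype, grad.stype, state.stype}
    if stypes == {'row_sparse'}:
        # Everything is row_sparse: take the sparse update path.
        self._sparse_update(index, weight, grad, state)
    else:
        if 'row_sparse' in stypes and 'default' in stypes:
            warnings.warn('Mixed sparse/dense inputs; falling back to the dense update.')
        # Everything default, or a mix of stypes: take the dense update path.
        self._dense_update(index, weight, grad, state)
```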
Many of these ops support both sparse and dense input combinations (and handle them in an efficient manner without fallback). Not to say it's the most efficient way to do it, but it's legal.
python/mxnet/optimizer.py
Outdated
if is_sparse:
    history[:] = op.elemwise_add(history, op.square(grad))
    assert history.stype == save_history_stype
    srt = op.sqrt(history)
Not adding eps will lead to numerical errors (NaN) since some entries in grad.data are zero.
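A tiny NumPy illustration of the failure mode (illustrative only, not MXNet code):

```python
import numpy as np

grad = np.array([0.5, 0.0])        # one stored row of the gradient; second entry is zero
history = grad ** 2                # accumulated squared gradients: [0.25, 0.0]

print(grad / np.sqrt(history))          # [2., nan] -- the 0/0 entry becomes NaN
print(grad / np.sqrt(history + 1e-7))   # adding eps keeps every entry finite
```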
ok, will scatter_plus them
@cjolivier01 The purpose of the revised scatter_div operator (with specified indices) is that we can make it public to other people who want to work on other optimizers that perform sparse updates. The current scatter_div operator doesn't take indices as input and has some limitations.
BTW, when you edit my comment I don't see any notification, so I easily miss your updates...
python/mxnet/optimizer.py
Outdated
history[:] += square(grad)
div = grad / sqrt(history + self.float_stable_eps)

weight[:] += (div + weight * wd) * -lr
Instead of weight * wd, it should be sparse.retain(weight, grad.indices) * wd, since we're only updating the row slices that appear in grad.indices. Otherwise the update is not sparse - after one epoch each update touches a million rows if you use weight directly.
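In code, the suggested change (a fragment-style sketch using the variable names from the diff above; `sparse.retain` here is `mxnet.ndarray.sparse.retain`):

```python
from mxnet.ndarray import sparse

# Before (dense weight decay -- touches every row of the weight matrix):
#   weight[:] += (div + weight * wd) * -lr
# After (sparse weight decay -- only rows present in grad.indices are touched):
weight[:] += (div + sparse.retain(weight, grad.indices) * wd) * -lr
```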
Ah, that version made sense! thanks!
python/mxnet/optimizer.py
Outdated
Different/more selective scatter behavior can be done in a different PR. I don't have the bandwidth for that now.
python/mxnet/optimizer.py
Outdated
if is_sparse:
    history[:] = op.elemwise_add(history, op.square(grad))
    assert history.stype == save_history_stype
    srt = op.sqrt(_internal._scatter_plus_scalar(history, self.float_stable_eps))
Use scatter_plus(sparse.retain(history, indices)) instead of scatter_plus(history)? Otherwise the scatter_plus is expensive.
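That is, something like the following (a fragment-style sketch using the names from the diff above; `op` and `_internal._scatter_plus_scalar` are as in that diff, and `sparse.retain` is `mxnet.ndarray.sparse.retain`):

```python
from mxnet.ndarray import sparse

# Expensive: adds eps to every row stored in `history`, which keeps growing over epochs.
#   srt = op.sqrt(_internal._scatter_plus_scalar(history, self.float_stable_eps))

# Cheaper: retain only the rows being updated, then add eps and take the sqrt.
retained = sparse.retain(history, grad.indices)
srt = op.sqrt(_internal._scatter_plus_scalar(retained, self.float_stable_eps))
```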
ok...
@@ -697,6 +697,19 @@ def check_binary_op_with_scalar(stype,
                                force_overlap=force_overlap,
                                verbose=False)

# plus_scalar
Do we have tests for scatter_plus/scatter_div?
Yes
Also scatter_minus
mod.update()  # update parameters
# print('Epoch %d, Training %s' % (epoch, metric.get()))
assert(metric.get()[1] < 0.05), metric.get()[1]
def check_factorization_machine_module(optimizer=None, num_epochs=None):
Please add a unit test in test_optimizer.py to test sparse AdaGrad.
test_optimizer appears to test the C++ version against the python version. There is only a python version for AdaGrad, so it's not clear what it would test against. I am using the test_module() test with an expected accuracy rate to test.
To me, the purpose of the tests in test_optimizer is to verify that the update only involves rows that appear in grad.indices for rsp weight and rsp grad.
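A sketch of that kind of check (hypothetical test code, not what the PR ships; `check_sparse_rows_only` is a made-up name):

```python
import mxnet as mx
import numpy as np

def check_sparse_rows_only(opt, shape=(10, 4), rows=(1, 7)):
    # Row-sparse weight and a gradient that only has entries for `rows`.
    weight_np = np.random.uniform(size=shape)
    grad_np = np.zeros(shape)
    grad_np[list(rows)] = np.random.uniform(size=(len(rows), shape[1]))

    weight = mx.nd.array(weight_np).tostype('row_sparse')
    grad = mx.nd.array(grad_np).tostype('row_sparse')

    state = opt.create_state(0, weight)
    opt.update(0, weight, grad, state)
    after = weight.asnumpy()

    untouched = [r for r in range(shape[0]) if r not in rows]
    # Rows without gradient entries must be bit-for-bit unchanged...
    assert np.array_equal(weight_np[untouched], after[untouched])
    # ...while the rows listed in grad.indices should actually have moved.
    assert not np.array_equal(weight_np[list(rows)], after[list(rows)])

# Example usage with the python AdaGrad optimizer:
check_sparse_rows_only(mx.optimizer.AdaGrad(learning_rate=0.1))
```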
I have asserts in update that test for this as well as storage type, which run during the test_factorization_machine_module() test, which I modified to test multiple optimizers (sgd, adam, adagrad)
…o sparse_adagrad_pr
I don't think clip should be a sparse op.
/*!
 * \brief CSR operation requires temp space
 */
struct ResourceRequest {
??
Used to index into ctx.resources during the CSR pass. It is, in fact, referencing a ResourceRequest.
Clip handles that situation in FInferStorageType
Changed
Refactor AdaGrad optimizer to support sparse tensors + unary and binary refactoring for new infer storage type logic (apache#7903)
Refactor AdaGrad optimizer to support sparse tensors
Add sparse support for _plus_scalar, _minus_scalar, clip
Some additional unary and binary refactoring for new infer storage type logic