add axes support for dropouts in gluon #10032
Conversation
@@ -239,15 +241,16 @@ class Dropout(HybridBlock):
    `Dropout: A Simple Way to Prevent Neural Networks from Overfitting
    <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_
    """
-    def __init__(self, rate, **kwargs):
+    def __init__(self, rate, axes=(), **kwargs):
Default to None?
This is the same default value as the one added to the Dropout op. I set it this way so that I don't have to handle None.
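For reference, a minimal usage sketch of the new argument (the shapes and the choice of the time axis below are illustrative, not taken from the PR):

```python
import mxnet as mx
from mxnet.gluon import nn

# axes=() (the default) keeps the old behaviour: an independent mask per element.
# axes=(1,) shares the dropout mask along axis 1, e.g. across the time steps of a
# (batch, time, channel) input, as in variational dropout.
layer = nn.Dropout(rate=0.5, axes=(1,))

x = mx.nd.ones((2, 4, 3))
with mx.autograd.record(train_mode=True):  # dropout is only applied in training mode
    y = layer(x)
# For a fixed (batch, channel) position, every time step is kept or zeroed together,
# because the mask is broadcast along axis 1.
print(y)
```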
# Collapse the dropped axes to size 1, then broadcast back to the full shape,
# so values are constant along the axes the mask is shared over.
compactshape[axis] = 1
compactx = mx.random.uniform(shape=tuple(compactshape))
broadcastx = compactx.broadcast_to(shape)
# Apply Gluon Dropout with the given axes.
dropouty = mx.gluon.nn.Dropout(rate=ratio, axes=axes)(broadcastx)
We may need to consider improving the test in the future. Currently there is no guarantee that the observed dropout ratio matches the given dropout ratio.
The axes=() problem and the test should be revised later. The overall logic looks good.
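As a possible follow-up to the test concern above, here is a rough sketch of how the observed drop fraction could be compared against the requested rate (the helper name, shapes, and tolerance are assumptions, not part of the PR):

```python
import mxnet as mx

def check_observed_ratio(rate, shape=(1000, 200), tol=0.05):
    """Drop a large all-ones array and compare the fraction of zeros to `rate`."""
    x = mx.nd.ones(shape)
    with mx.autograd.record(train_mode=True):  # ensure the dropout mask is actually applied
        y = mx.gluon.nn.Dropout(rate=rate)(x)
    observed = (y.asnumpy() == 0).mean()
    # The mask is sampled randomly, so only an approximate match can be expected.
    assert abs(observed - rate) < tol, (observed, rate)

check_observed_ratio(0.5)
```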
Description
add axes support for dropouts in gluon
Checklist
Essentials
Passed code style checking (make lint)
Changes