[MXNET-979] Add fix_beta support in BatchNorm #12625

Merged: sandeep-krishnamurthy merged 10 commits into apache:master from sandeep-krishnamurthy:support_fix_beta_batchnorm on Oct 10, 2018.
Commits (10, all by sandeep-krishnamurthy):

6b97294  Add fix_beta support in BatchNorm CPU implementation
cf618c2  Fix lint checks. Update GPU tests
070fc6f  Fix gpu tests
fbefee8  make fix_beta not available for sparse. Update fix_beta for mkldnn
9e95a61  Make default fix_beta to False for backward compatibility
c374f38  Add fix_beta to cudnn batchnorm operator
cecec80  Add tests for missing fix_beta and fix_gamma params
b7733e9  fix indentation
60152f7  Fix failing tests
f11de73  simplify the cases with defaults for gamma, beta
@@ -155,7 +155,7 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,
       // compute output
       AccReal *w = weights.dptr<AccReal>();
-      const AccReal *b = bias.dptr<AccReal>();
+      AccReal *b = bias.dptr<AccReal>();

       const AccReal thisMean = mean[channel];
       const AccReal thisInvstd = var[channel];
@@ -164,25 +164,65 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,

       // note that var is still invstd
       if (!param_.fix_gamma) {
-        if (IsBNWriting(req[batchnorm::kData])) {
-          ForEachFast(inputData, outputData, channel,
-                      [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
-                                                                   DType *out_data) {
-                        *out_data = static_cast<DType>(
-                          ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
-                      });
-        }
+        if (!param_.fix_beta) {
+          // Case 1
+          // fix_gamma = False
+          // fix_beta = False
+          if (IsBNWriting(req[batchnorm::kData])) {
+            ForEachFast(inputData, outputData, channel,
+                        [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
+                                                                     DType *out_data) {
+                          *out_data = static_cast<DType>(
+                            ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
+                        });
+          }
+        } else {
+          // Case 2
+          // fix_gamma = False
+          // fix_beta = True
+          if (IsBNWriting(req[batchnorm::kBeta])) {
+            b[channel] = AccReal(0);
+          }
+          if (IsBNWriting(req[batchnorm::kData])) {
+            ForEachFast(inputData, outputData, channel,
+                        [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
+                                                                     DType *out_data) {
+                          *out_data = static_cast<DType>(
+                            ((*in_data - thisMean) * thisInvstd) * thisWeight);
+                        });
+          }
+        }
       } else {
         if (IsBNWriting(req[batchnorm::kGamma])) {
           w[channel] = AccReal(1);
         }
-        if (IsBNWriting(req[batchnorm::kData])) {
-          ForEachFast(inputData, outputData, channel,
-                      [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
-                                                                   DType *out_data) {
-                        *out_data = static_cast<DType>(
-                          ((*in_data - thisMean) * thisInvstd) + thisBias);
-                      });
-        }
+        if (!param_.fix_beta) {
+          // Case 3
+          // fix_gamma = True
+          // fix_beta = False
+          if (IsBNWriting(req[batchnorm::kData])) {
+            ForEachFast(inputData, outputData, channel,
+                        [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
+                                                                     DType *out_data) {
+                          *out_data = static_cast<DType>(
+                            ((*in_data - thisMean) * thisInvstd) + thisBias);
+                        });
+          }
+        } else {
+          // Case 4
+          // fix_gamma = True
+          // fix_beta = True
+          if (IsBNWriting(req[batchnorm::kBeta])) {
+            b[channel] = AccReal(0);
+          }
+          if (IsBNWriting(req[batchnorm::kData])) {
+            ForEachFast(inputData, outputData, channel,
+                        [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
+                                                                     DType *out_data) {
+                          *out_data = static_cast<DType>(
+                            ((*in_data - thisMean) * thisInvstd));
+                        });
+          }
+        }
       }

Review thread on the ForEachFast lambda above:

Reviewer: Instead of using a lambda, I think it may be better to define a function here, since it is used in all 4 cases. A function with a proper name will make the code more readable.

sandeep-krishnamurthy: As discussed offline, this change would make readability difficult and pass more params unnecessarily across functions.
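For reference, the four cases in the forward hunk above reduce to the following output formulas, with the normalized input written as a hat (this is a restatement of the code, not new behavior):

$$\hat{x} = (x - \text{mean}) \cdot \text{invstd}, \qquad
y = \begin{cases}
\gamma \hat{x} + \beta & \text{fix\_gamma = False, fix\_beta = False (Case 1)} \\
\gamma \hat{x} & \text{fix\_gamma = False, fix\_beta = True (Case 2)} \\
\hat{x} + \beta & \text{fix\_gamma = True, fix\_beta = False (Case 3)} \\
\hat{x} & \text{fix\_gamma = True, fix\_beta = True (Case 4)}
\end{cases}$$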
@@ -309,7 +349,11 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
         }

         if (IsBNWriting(req[batchnorm::kBeta])) {
-          gradBiasData[channel] = scale * sumGradOut;
+          if (!param_.fix_beta) {
+            gradBiasData[channel] = scale * sumGradOut;
+          } else {
+            gradBiasData[channel] = AccReal(0);
+          }
         }
       }
     }
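In equation form, the backward change zeroes the bias gradient when beta is fixed; otherwise it keeps the usual BatchNorm bias gradient, the scaled sum of output gradients over the elements of channel c (scale and sumGradOut are the quantities already computed in this function):

$$\frac{\partial L}{\partial \beta_c} = \begin{cases}
\text{scale} \cdot \sum_i \frac{\partial L}{\partial y_i} & \text{fix\_beta = False} \\
0 & \text{fix\_beta = True}
\end{cases}$$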
@@ -478,6 +522,9 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
   if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
     LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
   }
+  if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_beta) {
+    LOG(FATAL) << "fix_beta=True is not supported for sparse ndarrays. Tracked at #11647";
+  }
   return dispatched;
 }
@@ -565,11 +612,12 @@ the 'channel' (separately normalized groups). The default is 1. Specifying -1
 axis to be the last item in the input shape.

 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
-then set ``gamma`` to 1 and its gradient to 0.
+then set ``gamma`` to 1 and its gradient to 0. If ``fix_beta`` is true, then set ``beta`` to 0
+and its gradient to 0.

 Note::

-  When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
+  When fix_gamma/fix_beta is set to True, no sparse support is provided. If fix_gamma/fix_beta is set to False,
   the sparse tensors will fallback.

 )code" ADD_FILELINE)
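For context, a hypothetical usage sketch through the mxnet-cpp frontend; this assumes the frontend forwards registered operator parameters, so the new fix_beta parameter would be set the same way as fix_gamma (the exact calls below are illustrative, not code from this PR):

#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

int main() {
  Symbol data = Symbol::Variable("data");
  Symbol gamma = Symbol::Variable("gamma");
  Symbol beta = Symbol::Variable("beta");
  Symbol mmean = Symbol::Variable("moving_mean");
  Symbol mvar = Symbol::Variable("moving_var");

  // fix_beta=True holds beta at 0 and zeroes its gradient,
  // mirroring the existing fix_gamma behavior.
  Symbol bn = Operator("BatchNorm")
                  .SetParam("fix_gamma", true)
                  .SetParam("fix_beta", true)  // new parameter from this PR
                  .SetInput("data", data)
                  .SetInput("gamma", gamma)
                  .SetInput("beta", beta)
                  .SetInput("moving_mean", mmean)
                  .SetInput("moving_var", mvar)
                  .CreateSymbol("bn");
  return 0;
}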
Review discussion:

Reviewer: Do you really need 4 cases, rather than collating 2 cases separately?

sandeep-krishnamurthy: It is basically 2 cases, based on fix_gamma; the fix_beta cases are collated within them. This has more readability, with flags on the cases. Suggestions?

Reviewer: Fixing gamma and beta only determines whether the (*gamma) and (+bias) terms are included, but here more than 50 lines are spent to accomplish that. The current lambda approach is fine with one or two cases, but with 4 cases, or maybe 8 in the future, it is going to cause a lot of trouble.
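To illustrate the reviewer's point, here is a minimal consolidation sketch. It is hypothetical, not code from this PR: float and the function name are simplified stand-ins for the operator's DType/AccReal machinery, and it folds fix_gamma and fix_beta into an effective per-channel scale and shift so that a single loop covers all four cases.

#include <cstddef>
#include <vector>

// Hypothetical sketch (not part of this PR): fold fix_gamma / fix_beta
// into an effective scale and shift, so one loop replaces the four
// explicit cases. float stands in for the operator's DType/AccReal.
void BatchNormChannelForward(const std::vector<float> &in, std::vector<float> *out,
                             float mean, float invstd, float gamma, float beta,
                             bool fix_gamma, bool fix_beta) {
  const float eff_weight = fix_gamma ? 1.0f : gamma;  // fixed gamma behaves as 1
  const float eff_bias   = fix_beta  ? 0.0f : beta;   // fixed beta behaves as 0
  for (std::size_t i = 0; i < in.size(); ++i) {
    (*out)[i] = ((in[i] - mean) * invstd) * eff_weight + eff_bias;
  }
}

The writes that reset w[channel] to 1 and b[channel] to 0 when the parameters are fixed would still happen once, outside the loop, exactly as in the diff.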