[wip] Momentum for CPU Batchnorm #485

Draft · wants to merge 1 commit into base: main
41 changes: 16 additions & 25 deletions flashlight/fl/autograd/backend/cpu/BatchNorm.cpp
@@ -36,7 +36,12 @@ Variable batchnorm(
Variable& runningVar,
const std::vector<int>& axes,
bool train,
double momentum,
double epsilon) {
if (input.type() == f16) {
throw std::runtime_error("Half precision is not supported on the CPU backend.");
}

auto output = af::array(input.dims(), input.type());

int nfeatures = 1;
@@ -52,6 +57,9 @@ Variable batchnorm(
runningMean = Variable(af::constant(0.0, nfeatures, input.type()), false);
}

Variable inputMean(runningMean.array().copy(), false);
Variable inputVar(runningVar.array().copy(), false);

// Check if axes are valid
auto maxAxis = *std::max_element(axes.begin(), axes.end());
auto minAxis = *std::min_element(axes.begin(), axes.end());
@@ -109,9 +117,9 @@ Variable batchnorm(
const detail::DnnlMemoryWrapper outputMemory(
output, inputOutputDims, formatNCHW);
const detail::DnnlMemoryWrapper meanMemory(
runningMean.array(), {runningMean.dims(0)}, formatX);
inputMean.array(), {inputMean.dims(0)}, formatX);
const detail::DnnlMemoryWrapper varMemory(
runningVar.array(), {runningVar.dims(0)}, formatX);
inputVar.array(), {inputVar.dims(0)}, formatX);
// combined scale and shift (weight and bias)
const detail::DnnlMemoryWrapper weightsMemory(
weightsDnnl, weightsDnnlDims, format2d);
@@ -142,6 +150,12 @@ Variable batchnorm(
network.push_back(bn);
detail::executeNetwork(network, fwdArgs);

// Update the running statistics: in training mode DNNL has written the batch
// mean and variance into inputMean/inputVar, so blend them into the running
// values using the momentum factor.
if (train) {
  runningMean = (1 - momentum) * runningMean + momentum * inputMean;
  runningVar = (1 - momentum) * runningVar + momentum * inputVar;
}

/****************************************************************************/
// Setup backward func

@@ -239,27 +253,4 @@ Variable batchnorm(
return Variable(output, {input, weight, bias}, gradFunc);
}

Variable batchnorm(
const Variable& input,
const Variable& weight,
const Variable& bias,
Variable& runningMean,
Variable& runningVar,
const std::vector<int>& axes,
bool train,
double momentum,
double epsilon) {
if (input.type() == f16) {
throw std::runtime_error("Half precision is not supported in CPU.");
}
// CPU backend DNNL doesn't support a momentum factor.
// If momentum is enabled, throw.
if (momentum == 0.0) {
return batchnorm(
input, weight, bias, runningMean, runningVar, axes, train, epsilon);
} else {
throw std::runtime_error("BatchNorm CPU backend doesn't support momentum.");
}
}

} // namespace fl
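Note: the running-statistics update added above is an exponential moving average in which momentum weights the batch statistics, the same convention used by cuDNN's exponentialAverageFactor and by PyTorch's momentum argument. A minimal standalone sketch of that blend follows; the function and variable names are illustrative and not part of the PR.

#include <vector>

// Blend per-channel batch statistics into the running statistics.
// momentum == 0 leaves the running stats untouched; momentum == 1 replaces
// them entirely with the batch stats.
void updateRunningStats(
    std::vector<float>& runningMean,
    std::vector<float>& runningVar,
    const std::vector<float>& batchMean,
    const std::vector<float>& batchVar,
    double momentum) {
  for (std::size_t i = 0; i < runningMean.size(); ++i) {
    runningMean[i] = static_cast<float>(
        (1 - momentum) * runningMean[i] + momentum * batchMean[i]);
    runningVar[i] = static_cast<float>(
        (1 - momentum) * runningVar[i] + momentum * batchVar[i]);
  }
}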
38 changes: 38 additions & 0 deletions flashlight/fl/test/autograd/AutogradTest.cpp
@@ -1524,6 +1524,44 @@ TEST(AutogradTest, BatchNormTrainModeOutputSingleAxis) {
ASSERT_TRUE(allClose(out.array(), expectedOut.array(), 1e-5));
}

TEST(AutogradTest, BatchNormTrainModeOutputSingleAxisMomentum) {
int numFeat = 3;
std::vector<int> featAxes = {2};
double momentum = 0.2;
double epsilon = 1E-5;
auto input = Variable(af::randu(13, 13, numFeat, 8), true);
af::print("input", input.array());
auto weight = Variable(af::randu(numFeat), true);
auto bias = Variable(af::randu(numFeat), true);
auto runningMean = Variable(af::randu(numFeat), false);
auto runningVar = Variable(af::randu(numFeat), false);

auto out = batchnorm(
input,
weight,
bias,
runningMean,
runningVar,
featAxes,
true,
momentum,
epsilon);

auto todim = af::dim4(1, 1, numFeat);
std::vector<int> nrm_axes = {0, 1, 3};
auto avg = moddims(mean(input, nrm_axes), todim);
auto variance =
moddims(var(input, nrm_axes, true /* population var */), todim);
auto expectedOut = (input - tileAs(avg, input)) /
fl::sqrt(tileAs(variance, input) + epsilon);
expectedOut = expectedOut * tileAs(moddims(weight, todim), input) +
tileAs(moddims(bias, todim), input);
ASSERT_TRUE(allClose(out.array(), expectedOut.array(), 1e-5));

af::print("runningMean momentum", runningMean.array());
af::print("runningVar momentum", runningVar.array());
}

TEST(AutogradTest, BatchNormTrainModeOutputMultipleAxis) {
std::vector<int> featAxes = {0, 1, 2};
auto input = Variable(af::randu(13, 13, 4, 8), true);
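The new test above prints the updated running statistics rather than asserting on them. A possible follow-up is sketched below; it assumes DNNL writes the per-channel batch mean and population variance into the mean/variance memories during a training-mode forward pass (which is what the momentum blend in BatchNorm.cpp relies on), and the test name and tolerance are illustrative only.

TEST(AutogradTest, BatchNormRunningStatsMomentum) {
  int numFeat = 3;
  double momentum = 0.2;
  double epsilon = 1E-5;
  auto input = Variable(af::randu(13, 13, numFeat, 8), true);
  auto weight = Variable(af::randu(numFeat), true);
  auto bias = Variable(af::randu(numFeat), true);
  auto runningMean = Variable(af::randu(numFeat), false);
  auto runningVar = Variable(af::randu(numFeat), false);
  // Snapshot the stats before the forward pass so the blend can be checked.
  auto initMean = Variable(runningMean.array().copy(), false);
  auto initVar = Variable(runningVar.array().copy(), false);

  batchnorm(
      input,
      weight,
      bias,
      runningMean,
      runningVar,
      {2} /* featAxes */,
      true /* train */,
      momentum,
      epsilon);

  // Expected batch statistics, computed the same way as in the output checks.
  std::vector<int> nrmAxes = {0, 1, 3};
  auto batchMean = moddims(mean(input, nrmAxes), af::dim4(numFeat));
  auto batchVar =
      moddims(var(input, nrmAxes, true /* population var */), af::dim4(numFeat));
  auto expectedMean = (1 - momentum) * initMean + momentum * batchMean;
  auto expectedVar = (1 - momentum) * initVar + momentum * batchVar;
  ASSERT_TRUE(allClose(runningMean.array(), expectedMean.array(), 1e-4));
  ASSERT_TRUE(allClose(runningVar.array(), expectedVar.array(), 1e-4));
}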