Rename BatchNorm running_variance to running_var (pytorch#17371)
Summary:
Currently there is a naming mismatch between the Python BatchNorm buffer `running_var` and the C++ BatchNorm buffer `running_variance`, which causes loading of JIT model parameters to fail (pytorch/vision#728 (comment)):
```
terminate called after throwing an instance of 'c10::Error'
  what():  No such serialized tensor 'running_variance' (read at /home/shahriar/Build/pytorch/torch/csrc/api/src/serialize/input-archive.cpp:27)
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x85 (0x7f2d92d32f95 in /usr/local/lib/libc10.so)
frame #1: torch::serialize::InputArchive::read(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, at::Tensor&, bool) + 0xdeb (0x7f2d938551ab in /usr/local/lib/libtorch.so.1)
frame #2: torch::nn::Module::load(torch::serialize::InputArchive&) + 0x98 (0x7f2d9381cd08 in /usr/local/lib/libtorch.so.1)
frame #3: torch::nn::Module::load(torch::serialize::InputArchive&) + 0xf9 (0x7f2d9381cd69 in /usr/local/lib/libtorch.so.1)
frame #4: torch::nn::Module::load(torch::serialize::InputArchive&) + 0xf9 (0x7f2d9381cd69 in /usr/local/lib/libtorch.so.1)
frame #5: torch::nn::operator>>(torch::serialize::InputArchive&, std::shared_ptr<torch::nn::Module> const&) + 0x32 (0x7f2d9381c7b2 in /usr/local/lib/libtorch.so.1)
frame #6: <unknown function> + 0x2b16c (0x5645f4d1916c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
frame #7: <unknown function> + 0x27a3c (0x5645f4d15a3c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
frame #8: <unknown function> + 0x2165c (0x5645f4d0f65c in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
frame #9: <unknown function> + 0x1540b (0x5645f4d0340b in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
frame #10: __libc_start_main + 0xf3 (0x7f2d051dd223 in /usr/lib/libc.so.6)
frame #11: <unknown function> + 0x1381e (0x5645f4d0181e in /home/shahriar/Projects/CXX/build-TorchVisionTest-Desktop_Qt_5_12_1_GCC_64bit-Debug/TorchVisionTest)
```
Renaming C++ BatchNorm `running_variance` to `running_var` should fix this problem.

This is a BC-breaking change, but it should be easy for end users to rename `running_variance` to `running_var` at their call sites.
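For downstream C++ code, the migration is a mechanical rename at each call site. A hedged sketch of what a typical user-side update might look like (the surrounding code is illustrative, not taken from this PR):

```diff
 torch::nn::BatchNorm bn(torch::nn::BatchNormOptions(64));
 // ... after training ...
-auto variance = bn->running_variance.clone();
+auto variance = bn->running_var.clone();
```

After the rename, C++ checkpoints and Python checkpoints serialize the buffer under the same key, so `torch::load` on a module exported from Python no longer throws the "No such serialized tensor" error shown above.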
Pull Request resolved: pytorch#17371

Reviewed By: goldsborough

Differential Revision: D14172775

Pulled By: yf225

fbshipit-source-id: b9d3729ec79272a8084269756f28a8f7c4dd16b6
Will Feng authored and facebook-github-bot committed Feb 22, 2019
1 parent 562fa55 commit be6ad7d
Showing 3 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions test/cpp/api/modules.cpp
```diff
@@ -248,9 +248,9 @@ TEST_F(ModulesTest, BatchNormStateful) {
   ASSERT_EQ(bn->running_mean.dim(), 1);
   ASSERT_EQ(bn->running_mean.size(0), 5);

-  ASSERT_TRUE(bn->running_variance.defined());
-  ASSERT_EQ(bn->running_variance.dim(), 1);
-  ASSERT_EQ(bn->running_variance.size(0), 5);
+  ASSERT_TRUE(bn->running_var.defined());
+  ASSERT_EQ(bn->running_var.dim(), 1);
+  ASSERT_EQ(bn->running_var.size(0), 5);

   // Is affine by default.
   ASSERT_TRUE(bn->options.affine());
@@ -267,7 +267,7 @@ TEST_F(ModulesTest, BatchNormStateless) {
   BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));

   ASSERT_FALSE(bn->running_mean.defined());
-  ASSERT_FALSE(bn->running_variance.defined());
+  ASSERT_FALSE(bn->running_var.defined());
   ASSERT_FALSE(bn->weight.defined());
   ASSERT_FALSE(bn->bias.defined());
```
2 changes: 1 addition & 1 deletion torch/csrc/api/include/torch/nn/modules/batchnorm.h
```diff
@@ -88,7 +88,7 @@ class TORCH_API BatchNormImpl : public torch::nn::Cloneable<BatchNormImpl> {

   /// The running variance.
   /// Only defined if the `stateful` option was `true` upon construction.
-  Tensor running_variance;
+  Tensor running_var;
 };

 /// A `ModuleHolder` subclass for `BatchNormImpl`.
```
6 changes: 3 additions & 3 deletions torch/csrc/api/src/nn/modules/batchnorm.cpp
```diff
@@ -28,8 +28,8 @@ void BatchNormImpl::reset() {
   if (options.stateful_) {
     running_mean =
         register_buffer("running_mean", torch::zeros({options.features_}));
-    running_variance =
-        register_buffer("running_variance", torch::ones({options.features_}));
+    running_var =
+        register_buffer("running_var", torch::ones({options.features_}));
   }
 }
@@ -47,7 +47,7 @@ Tensor BatchNormImpl::forward(const Tensor& input) {
       "Calling BatchNorm::forward is only permitted when "
       "the 'stateful' option is true (was false). "
       "Use BatchNorm::pure_forward instead.");
-  return pure_forward(input, running_mean, running_variance);
+  return pure_forward(input, running_mean, running_var);
 }

 Tensor BatchNormImpl::pure_forward(
```
