|
|
@@ -1018,7 +1018,13 @@ void UpgradeNetBatchNorm(NetParameter* net_param) { |
|
|
// the previous BatchNorm layer definition.
|
|
|
if (net_param->layer(i).type() == "BatchNorm"
|
|
|
&& net_param->layer(i).param_size() == 3) {
|
|
|
- net_param->mutable_layer(i)->clear_param();
|
|
|
+ // set lr_mult and decay_mult to zero. leave all other param intact.
|
|
|
+ for (int ip = 0; ip < net_param->layer(i).param_size(); ip++) {
|
|
|
+ ParamSpec* fixed_param_spec =
|
|
|
+ net_param->mutable_layer(i)->mutable_param(ip);
|
|
|
+ fixed_param_spec->set_lr_mult(0.f);
|
|
|
+ fixed_param_spec->set_decay_mult(0.f);
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
|