Skip to content

Commit

Permalink
handle batchnorm fix_gamma attr properly
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Nov 29, 2018
1 parent c77383f commit ed2a24c
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions tools/mxnet/mxnet2ncnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,11 @@ std::vector<float> MXNetNode::weight(int i, int init_len) const

if (!p.init.empty() && init_len != 0)
{
if (p.init == "[\\$zero\\$, {}]")
if (p.init == "[\\$zero\\$, {}]" || p.init == "[\\\"zero\\\", {}]")
{
data.resize(init_len, 0.f);
}
else if (p.init == "[\\$one\\$, {}]")
else if (p.init == "[\\$one\\$, {}]" || p.init == "[\\\"one\\\", {}]")
{
data.resize(init_len, 1.f);
}
Expand Down Expand Up @@ -1483,6 +1483,16 @@ int main(int argc, char** argv)

fprintf(pp, " 0=%d", channels);

int fix_gamma = n.has_attr("fix_gamma") ? n.attr("fix_gamma") : 0;
if (fix_gamma)
{
// slope data are all 0 here, force set 1
for (int j=0; j<channels; j++)
{
slope_data[j] = 1.f;
}
}

fwrite(slope_data.data(), sizeof(float), slope_data.size(), bp);
fwrite(mean_data.data(), sizeof(float), mean_data.size(), bp);
fwrite(var_data.data(), sizeof(float), var_data.size(), bp);
Expand Down

0 comments on commit ed2a24c

Please sign in to comment.