groupnorm 1d/2d/4d (#4312)
nihui authored Oct 31, 2022
1 parent b853b3d commit 6e49fa3
Showing 4 changed files with 225 additions and 54 deletions.
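Note (editorial addition, not part of this commit): "1d/2d/4d" refers to the dims value of the ncnn::Mat handed to GroupNorm::forward_inplace, which was previously written only for dims == 3 (w x h x channels) blobs. A minimal sketch of how Mat constructors map onto those dims values, assuming a standard ncnn build (the shapes below are arbitrary examples):

// Sketch: the Mat layouts that GroupNorm::forward_inplace now distinguishes.
// Only the dims value matters for branch selection in the code below.
#include "mat.h" // ncnn::Mat (include path may differ depending on how ncnn is installed)
#include <cstdio>

int main()
{
    ncnn::Mat v(12);            // dims == 1: a vector of `channels` values
    ncnn::Mat m(24, 12);        // dims == 2: w values per channel, one row per channel
    ncnn::Mat c(16, 8, 12);     // dims == 3: a w x h plane per channel
    ncnn::Mat c4(16, 8, 4, 12); // dims == 4: a w x h x d volume per channel

    printf("%d %d %d %d\n", v.dims, m.dims, c.dims, c4.dims); // prints: 1 2 3 4
    return 0;
}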
194 changes: 154 additions & 40 deletions src/layer/groupnorm.cpp
@@ -52,66 +52,180 @@ int GroupNorm::load_model(const ModelBin& mb)

int GroupNorm::forward_inplace(Mat& bottom_top_blob, const Option& opt) const
{
    // x = (x - mean) / sqrt(var + eps) * gamma + beta
    const int dims = bottom_top_blob.dims;
    const int channels_per_group = channels / group;

    if (dims == 1)
    {
        #pragma omp parallel for num_threads(opt.num_threads)
        for (int g = 0; g < group; g++)
        {
            Mat bottom_top_blob_g = bottom_top_blob.range(g * channels_per_group, channels_per_group);
            const Mat gamma_data_g = gamma_data.range(g * channels_per_group, channels_per_group);
            const Mat beta_data_g = beta_data.range(g * channels_per_group, channels_per_group);

            // mean and var
            float sum = 0.f;
            for (int q = 0; q < channels_per_group; q++)
            {
                sum += bottom_top_blob_g[q];
            }
            float mean = sum / channels_per_group;

            float sqsum = 0.f;
            for (int q = 0; q < channels_per_group; q++)
            {
                float tmp = bottom_top_blob_g[q] - mean;
                sqsum += tmp * tmp;
            }
            float var = sqsum / channels_per_group;

            for (int q = 0; q < channels_per_group; q++)
            {
                float a;
                float b;
                if (affine)
                {
                    float gamma = gamma_data_g[q];
                    float beta = beta_data_g[q];

                    a = (float)(gamma / sqrt(var + eps));
                    b = -mean * a + beta;
                }
                else
                {
                    a = (float)(1.f / (sqrt(var + eps)));
                    b = -mean * a;
                }

                bottom_top_blob_g[q] = bottom_top_blob_g[q] * a + b;
            }
        }
    }

    if (dims == 2)
    {
        int w = bottom_top_blob.w;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int g = 0; g < group; g++)
        {
            Mat bottom_top_blob_g = bottom_top_blob.row_range(g * channels_per_group, channels_per_group);
            const Mat gamma_data_g = gamma_data.range(g * channels_per_group, channels_per_group);
            const Mat beta_data_g = beta_data.range(g * channels_per_group, channels_per_group);

            // mean and var
            float sum = 0.f;
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr = bottom_top_blob_g.row(q);
                for (int i = 0; i < w; i++)
                {
                    sum += ptr[i];
                }
            }
            float mean = sum / (channels_per_group * w);

            float sqsum = 0.f;
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr = bottom_top_blob_g.row(q);
                for (int i = 0; i < w; i++)
                {
                    float tmp = ptr[i] - mean;
                    sqsum += tmp * tmp;
                }
            }
            float var = sqsum / (channels_per_group * w);

            for (int q = 0; q < channels_per_group; q++)
            {
                float a;
                float b;
                if (affine)
                {
                    float gamma = gamma_data_g[q];
                    float beta = beta_data_g[q];

                    a = (float)(gamma / sqrt(var + eps));
                    b = -mean * a + beta;
                }
                else
                {
                    a = (float)(1.f / (sqrt(var + eps)));
                    b = -mean * a;
                }

                float* ptr = bottom_top_blob_g.row(q);
                for (int i = 0; i < w; i++)
                {
                    ptr[i] = ptr[i] * a + b;
                }
            }
        }
    }

    if (dims == 3 || dims == 4)
    {
        int w = bottom_top_blob.w;
        int h = bottom_top_blob.h;
        int d = bottom_top_blob.d;
        int size = w * h * d;

        #pragma omp parallel for num_threads(opt.num_threads)
        for (int g = 0; g < group; g++)
        {
            Mat bottom_top_blob_g = bottom_top_blob.channel_range(g * channels_per_group, channels_per_group);
            const Mat gamma_data_g = gamma_data.range(g * channels_per_group, channels_per_group);
            const Mat beta_data_g = beta_data.range(g * channels_per_group, channels_per_group);

            // mean and var
            float sum = 0.f;
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr = bottom_top_blob_g.channel(q);
                for (int i = 0; i < size; i++)
                {
                    sum += ptr[i];
                }
            }
            float mean = sum / (channels_per_group * size);

            float sqsum = 0.f;
            for (int q = 0; q < channels_per_group; q++)
            {
                const float* ptr = bottom_top_blob_g.channel(q);
                for (int i = 0; i < size; i++)
                {
                    float tmp = ptr[i] - mean;
                    sqsum += tmp * tmp;
                }
            }
            float var = sqsum / (channels_per_group * size);

            for (int q = 0; q < channels_per_group; q++)
            {
                float a;
                float b;
                if (affine)
                {
                    float gamma = gamma_data_g[q];
                    float beta = beta_data_g[q];

                    a = (float)(gamma / sqrt(var + eps));
                    b = -mean * a + beta;
                }
                else
                {
                    a = (float)(1.f / (sqrt(var + eps)));
                    b = -mean * a;
                }

                float* ptr = bottom_top_blob_g.channel(q);
                for (int i = 0; i < size; i++)
                {
                    ptr[i] = ptr[i] * a + b;
                }
            }
        }
    }
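All three branches above share the same per-element rewrite: (x - mean) / sqrt(var + eps) * gamma + beta is folded into a single multiply-add x * a + b, with a = gamma / sqrt(var + eps) and b = -mean * a + beta (gamma = 1 and beta = 0 when affine is off). A standalone check of that algebra (editorial sketch, not part of this commit; the constants are arbitrary):

#include <cmath>
#include <cstdio>

int main()
{
    const float x = 1.5f, mean = 0.25f, var = 0.8f, eps = 1e-5f;
    const float gamma = 1.2f, beta = -0.3f;

    // direct form, as in the comment at the top of forward_inplace
    const float y_direct = (x - mean) / sqrtf(var + eps) * gamma + beta;

    // factored form used in the per-channel loops above
    const float a = gamma / sqrtf(var + eps);
    const float b = -mean * a + beta;
    const float y_factored = x * a + b;

    // the two values agree up to float rounding
    printf("direct = %f, factored = %f\n", y_direct, y_factored);
    return 0;
}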
38 changes: 37 additions & 1 deletion tests/test_groupnorm.cpp
@@ -38,6 +38,17 @@ static int test_groupnorm(const ncnn::Mat& a, int group, float eps)
}

static int test_groupnorm_0()
{
    return 0
           || test_groupnorm(RandomMat(3, 6, 4, 2), 1, 0.01f)
           || test_groupnorm(RandomMat(2, 3, 3, 8), 2, 0.002f)
           || test_groupnorm(RandomMat(3, 4, 5, 6), 3, 0.01f)
           || test_groupnorm(RandomMat(4, 5, 6, 12), 4, 0.02f)
           || test_groupnorm(RandomMat(5, 6, 7, 24), 2, 0.001f)
           || test_groupnorm(RandomMat(2, 8, 9, 24), 3, 0.0001f);
}

static int test_groupnorm_1()
{
    return 0
           || test_groupnorm(RandomMat(6, 4, 2), 1, 0.01f)
@@ -48,10 +59,35 @@ static int test_groupnorm_0()
           || test_groupnorm(RandomMat(8, 9, 24), 3, 0.0001f);
}

static int test_groupnorm_2()
{
    return 0
           || test_groupnorm(RandomMat(24, 2), 1, 0.01f)
           || test_groupnorm(RandomMat(23, 8), 2, 0.002f)
           || test_groupnorm(RandomMat(25, 6), 3, 0.01f)
           || test_groupnorm(RandomMat(26, 12), 4, 0.02f)
           || test_groupnorm(RandomMat(27, 24), 2, 0.001f)
           || test_groupnorm(RandomMat(29, 24), 3, 0.0001f);
}

static int test_groupnorm_3()
{
    return 0
           || test_groupnorm(RandomMat(12), 1, 0.01f)
           || test_groupnorm(RandomMat(18), 2, 0.002f)
           || test_groupnorm(RandomMat(36), 3, 0.01f)
           || test_groupnorm(RandomMat(212), 4, 0.02f)
           || test_groupnorm(RandomMat(124), 2, 0.001f)
           || test_groupnorm(RandomMat(324), 3, 0.0001f);
}

int main()
{
    SRAND(7767517);

    return 0
           || test_groupnorm_0()
           || test_groupnorm_1()
           || test_groupnorm_2()
           || test_groupnorm_3();
}
18 changes: 13 additions & 5 deletions tools/pnnx/tests/ncnn/test_F_group_norm.py
@@ -20,29 +20,37 @@ class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()

        self.w3 = nn.Parameter(torch.rand(16))
        self.b3 = nn.Parameter(torch.rand(16))
        self.w4 = nn.Parameter(torch.rand(12))
        self.b4 = nn.Parameter(torch.rand(12))
        self.w5 = nn.Parameter(torch.rand(32))
        self.b5 = nn.Parameter(torch.rand(32))

    def forward(self, x, y, z):
        x = F.group_norm(x, 4, self.w3, self.b3)
        y = F.group_norm(y, 6, self.w4, self.b4)
        z = F.group_norm(z, 8, self.w5, self.b5, eps=1e-2)
        return x, y, z

def test():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 16)
    y = torch.rand(1, 12, 16)
    z = torch.rand(1, 32, 12, 16)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_F_group_norm.pt")

    # torchscript to pnnx
    import os
    os.system("../../src/pnnx test_F_group_norm.pt inputshape=[1,16],[1,12,16],[1,32,12,16]")

    # ncnn inference
    import test_F_group_norm_ncnn
29 changes: 21 additions & 8 deletions tools/pnnx/tests/ncnn/test_nn_GroupNorm.py
@@ -24,34 +24,47 @@ def __init__(self):
        self.gn_1 = nn.GroupNorm(num_groups=12, num_channels=12, eps=1e-2, affine=True)
        self.gn_2 = nn.GroupNorm(num_groups=1, num_channels=12, eps=1e-4, affine=True)

    def forward(self, x, y, z):
        x = self.gn_0(x)
        x = self.gn_1(x)
        x = self.gn_2(x)

        y = self.gn_0(y)
        y = self.gn_1(y)
        y = self.gn_2(y)

        z = self.gn_0(z)
        z = self.gn_1(z)
        z = self.gn_2(z)
        return x, y, z

def test():
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 12, 64)
    y = torch.rand(1, 12, 24, 64)
    z = torch.rand(1, 12, 24, 32, 64)

    a = net(x, y, z)

    # export torchscript
    mod = torch.jit.trace(net, (x, y, z))
    mod.save("test_nn_GroupNorm.pt")

    # torchscript to pnnx
    import os
    os.system("../../src/pnnx test_nn_GroupNorm.pt inputshape=[1,12,64],[1,12,24,64],[1,12,24,32,64]")

    # ncnn inference
    import test_nn_GroupNorm_ncnn
    b = test_nn_GroupNorm_ncnn.test_inference()

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-4, 1e-4):
            return False
    return True

if __name__ == "__main__":
    if test():
