In [221]:
import torch
import torch.nn as nn

## BatchNorm1d - (N, C) の2Dテンソルを入力とした時

- 特徴量Cの各次元ごとに、バッチ方向(N)で平均・分散を求めて標準化している

In [328]:
# Plain normalization: no learnable scale/shift (affine=False) and no running
# averages (track_running_stats=False), so the batch statistics are always used.
# eps is spelled out because the manual reproduction below uses the same value.
m = nn.BatchNorm1d(
    num_features=100,
    eps=1e-5,
    affine=False,
    track_running_stats=False,
)

In [329]:
# Fix the RNG so that Restart & Run All reproduces the same data and outputs.
torch.manual_seed(0)
input = torch.rand(16, 100)

In [330]:
# Reference result from the module; with track_running_stats=False it normalizes
# using the statistics of this batch.
out = m(input)

In [331]:
# Per-feature statistics, reduced over dim 0 (the batch axis of the (N, C) input).
mu = torch.mean(input, dim=0)
# Biased variance (divide by N), which is what reproduces the module's output.
var = torch.var(input, dim=0, unbiased=False)
# Use the module's own eps instead of repeating the literal 1e-5, so the manual
# formula stays in sync with `m` if its eps is ever changed.
sigma = torch.sqrt(var + m.eps)

In [354]:
# One mean per feature column.
mu.shape

torch.Size([100])

In [332]:
# Manual standardization; mu and sigma broadcast across the batch rows.
out2 = input.sub(mu).div(sigma)

In [339]:
# Manual result (the repr only shows the corner values of the tensor).
out2

tensor([[-5.3436e-01, -1.2006e+00,  1.6590e+00,  ...,  8.0590e-01,
          8.9554e-01, -1.1625e+00],
        [-9.4430e-02,  1.0396e+00,  1.0832e+00,  ..., -2.1762e-01,
         -1.0105e+00,  6.0995e-01],
        [-1.3220e+00, -6.3421e-01,  9.1513e-02,  ..., -8.8394e-01,
         -6.8018e-01,  1.1044e+00],
        ...,
        [-1.2507e+00,  5.3776e-01, -9.0805e-01,  ...,  1.1384e+00,
          7.5489e-01, -7.2468e-01],
        [ 1.7693e+00,  4.2704e-01, -5.9021e-01,  ...,  2.8768e-01,
          1.3708e+00, -8.9842e-01],
        [ 1.0805e-01, -3.1008e-01, -6.5997e-01,  ...,  1.2403e+00,
         -2.2597e+00, -8.8617e-01]])

In [340]:
# The printouts only show corner values, and they already differ in the last
# displayed digit (float32 rounding). Compare every element numerically instead
# of by eye; on failure, report the largest deviation.
assert torch.allclose(out, out2, atol=1e-4), (out - out2).abs().max()
out

tensor([[-5.3436e-01, -1.2006e+00,  1.6590e+00,  ...,  8.0590e-01,
          8.9554e-01, -1.1625e+00],
        [-9.4430e-02,  1.0396e+00,  1.0832e+00,  ..., -2.1763e-01,
         -1.0105e+00,  6.0995e-01],
        [-1.3220e+00, -6.3421e-01,  9.1513e-02,  ..., -8.8394e-01,
         -6.8018e-01,  1.1044e+00],
        ...,
        [-1.2507e+00,  5.3776e-01, -9.0805e-01,  ...,  1.1384e+00,
          7.5489e-01, -7.2468e-01],
        [ 1.7693e+00,  4.2704e-01, -5.9021e-01,  ...,  2.8768e-01,
          1.3708e+00, -8.9842e-01],
        [ 1.0805e-01, -3.1008e-01, -6.5997e-01,  ...,  1.2403e+00,
         -2.2597e+00, -8.8617e-01]])

## BatchNorm1d - (N, C, L) の3Dテンソルを入力とした時

- チャネルCごとに、(N, L) 全体で平均・分散をとって標準化している

In [394]:
# Fix the RNG so that Restart & Run All reproduces the same data and outputs.
torch.manual_seed(0)
input = torch.rand(16, 100, 256)

In [395]:
# Plain normalization over 100 channels: no learnable scale/shift and no
# running averages, so the batch statistics are always used.
m = nn.BatchNorm1d(
    num_features=100,
    eps=1e-5,
    affine=False,
    track_running_stats=False,
)

In [396]:
# Reference result from the module for the (N, C, L) input.
out = m(input)

In [397]:
# Move channels last: (N, C, L) -> (N, L, C).
input = input.transpose(1, 2)
print(input.size())
# One row per (sample, position) pair, one column per channel. The sizes come
# from the tensor itself rather than a hard-coded 16 * 256, 100.
input = input.contiguous().view(-1, input.size(-1))
print(input.size())

torch.Size([16, 256, 100])
torch.Size([4096, 100])


In [398]:
# Per-channel statistics, reduced over dim 0 (every (sample, position) row).
mu = torch.mean(input, dim=0)
# Biased variance (divide by the row count), matching the module's normalization.
var = torch.var(input, dim=0, unbiased=False)
# Use the module's own eps instead of repeating the literal 1e-5.
sigma = torch.sqrt(var + m.eps)

In [399]:
# Both statistics should have one entry per channel.
print(mu.size(), var.size(), sep="\n")

torch.Size([100])
torch.Size([100])


In [400]:
# Manual standardization; per-channel mu and sigma broadcast across all rows.
out2 = input.sub(mu).div(sigma)

In [401]:
print(out2.size())
# Undo the flattening with the shape of the module output (N, C, L) rather
# than a hard-coded 16, 256, 100: (N*L, C) -> (N, L, C) -> (N, C, L).
n_batch, n_chan, length = out.shape
out2 = out2.view(n_batch, length, n_chan)
out2 = out2.transpose(1, 2)
print(out2.size())

torch.Size([4096, 100])
torch.Size([16, 100, 256])


In [405]:
# First sample of the manual result, laid out as (C, L).
out2[0]

tensor([[ 5.6969e-01, -3.9478e-01,  4.3667e-01,  ..., -1.1388e+00,
         -1.3282e+00,  1.5558e+00],
        [-1.4932e+00, -1.0965e+00,  1.6722e+00,  ..., -1.4873e+00,
          9.7042e-01, -1.1067e+00],
        [-7.0421e-02, -9.0799e-02, -1.6933e+00,  ...,  1.3000e+00,
          9.8717e-01, -9.6405e-01],
        ...,
        [ 7.2707e-01,  1.4468e+00, -4.5854e-01,  ...,  1.1904e+00,
          3.9943e-01, -1.3513e-01],
        [-2.8769e-01, -8.7361e-01,  1.1163e+00,  ...,  4.8414e-02,
          1.2485e+00, -5.0380e-01],
        [-1.5412e+00, -6.9136e-01,  2.4907e-02,  ...,  1.4118e+00,
          2.0856e-01, -6.2272e-01]])

In [406]:
# The printouts only show corner values and differ in the last displayed digit
# (float32 rounding), so check every element numerically; on failure, report
# the largest deviation.
assert torch.allclose(out, out2, atol=1e-4), (out - out2).abs().max()
out[0]

tensor([[ 5.6969e-01, -3.9478e-01,  4.3667e-01,  ..., -1.1388e+00,
         -1.3282e+00,  1.5558e+00],
        [-1.4932e+00, -1.0965e+00,  1.6722e+00,  ..., -1.4873e+00,
          9.7042e-01, -1.1067e+00],
        [-7.0420e-02, -9.0798e-02, -1.6932e+00,  ...,  1.3000e+00,
          9.8717e-01, -9.6405e-01],
        ...,
        [ 7.2706e-01,  1.4468e+00, -4.5854e-01,  ...,  1.1904e+00,
          3.9943e-01, -1.3513e-01],
        [-2.8769e-01, -8.7361e-01,  1.1163e+00,  ...,  4.8414e-02,
          1.2485e+00, -5.0380e-01],
        [-1.5412e+00, -6.9136e-01,  2.4905e-02,  ...,  1.4118e+00,
          2.0856e-01, -6.2273e-01]])

## BatchNorm2d - (N, C, H, W) の4Dテンソルを入力とした時

- チャネルCごとに、(N, H, W) 全体で平均・分散をとって標準化している

In [407]:
# Fix the RNG so that Restart & Run All reproduces the same data and outputs.
torch.manual_seed(0)
input = torch.rand(16, 100, 32, 32)

In [410]:
# Plain 2D normalization over 100 channels: no learnable scale/shift and no
# running averages, so the batch statistics are always used.
m = nn.BatchNorm2d(
    num_features=100,
    eps=1e-5,
    affine=False,
    track_running_stats=False,
)

In [411]:
# Reference result from the module for the (N, C, H, W) input.
out = m(input)

In [412]:
# Channels last in a single step: (N, C, H, W) -> (N, H, W, C).
# This is the same view as transpose(1, 2) followed by transpose(2, 3).
input = input.permute(0, 2, 3, 1)
print(input.size())

torch.Size([16, 32, 32, 100])


In [415]:
# One row per (sample, h, w) location, one column per channel. The sizes come
# from the tensor itself rather than a hard-coded 16 * 32 * 32, 100.
input = input.contiguous().view(-1, input.size(-1))
print(input.size())

torch.Size([16384, 100])


In [416]:
# Per-channel statistics, reduced over dim 0 (every (sample, h, w) row).
mu = torch.mean(input, dim=0)
# Biased variance (divide by the row count), matching the module's normalization.
var = torch.var(input, dim=0, unbiased=False)
# Use the module's own eps instead of repeating the literal 1e-5.
sigma = torch.sqrt(var + m.eps)

In [417]:
# One mean per channel.
mu.shape

torch.Size([100])

In [418]:
# Manual standardization, still in the flattened (N*H*W, C) layout.
out2 = input.sub(mu).div(sigma)
print(out2.shape)

torch.Size([16384, 100])


In [419]:
# Restore (N, C, H, W) using the shape of the module output rather than a
# hard-coded 16, 32, 32, 100: (N*H*W, C) -> (N, H, W, C) -> (N, C, H, W).
n_batch, n_chan, height, width = out.shape
out2 = out2.view(n_batch, height, width, n_chan)
out2 = out2.permute(0, 3, 1, 2)
print(out2.size())

torch.Size([16, 100, 32, 32])


In [422]:
# First sample, first channel of the manual result: an (H, W) map.
out2[0, 0]

tensor([[-0.3676,  1.0660, -1.7270,  ...,  0.2349, -0.9167, -1.5125],
        [-1.5726, -0.9856,  1.6037,  ..., -1.7133,  0.8148,  0.8512],
        [ 0.8848,  0.4292,  1.2239,  ...,  0.4213,  0.6701, -1.3766],
        ...,
        [-1.3123,  0.3576,  1.6153,  ...,  0.2425, -0.4328, -0.0343],
        [-1.0931, -0.5602,  0.2153,  ..., -0.8334, -0.0813, -0.1970],
        [ 1.5390, -0.6581,  0.4920,  ..., -0.2716,  1.4207, -1.6593]])

In [423]:
# Only corner values are printed, so check every element numerically; on
# failure, report the largest deviation.
assert torch.allclose(out, out2, atol=1e-4), (out - out2).abs().max()
out[0][0]

tensor([[-0.3676,  1.0660, -1.7270,  ...,  0.2349, -0.9167, -1.5125],
        [-1.5726, -0.9856,  1.6037,  ..., -1.7133,  0.8148,  0.8512],
        [ 0.8848,  0.4292,  1.2239,  ...,  0.4213,  0.6701, -1.3766],
        ...,
        [-1.3123,  0.3576,  1.6153,  ...,  0.2425, -0.4328, -0.0343],
        [-1.0931, -0.5602,  0.2153,  ..., -0.8334, -0.0813, -0.1970],
        [ 1.5390, -0.6581,  0.4920,  ..., -0.2716,  1.4207, -1.6593]])

## InstanceNorm1d - (N, C, L) の3Dテンソルを入力とした時

- サンプルNとチャネルCの組ごとに、長さL方向で平均・分散をとって標準化している

In [455]:
# Fix the RNG so that Restart & Run All reproduces the same data and outputs.
torch.manual_seed(0)
input = torch.rand(16, 100, 256)

In [456]:
# Instance normalization over 100 channels: no learnable scale/shift and no
# running averages.
m = nn.InstanceNorm1d(
    num_features=100,
    eps=1e-5,
    affine=False,
    track_running_stats=False,
)

In [457]:
# Reference result from the module for the (N, C, L) input.
out = m(input)

In [458]:
# Give every (sample, channel) pair its own column: (N, C, L) -> (N*C, L) -> (L, N*C),
# so that reducing over dim 0 averages along the length axis only. The sizes
# come from the tensor itself rather than a hard-coded 16 * 100, 256.
input = input.view(-1, input.size(-1))
input = input.transpose(0, 1)
print(input.size())

torch.Size([256, 1600])


In [459]:
# Per-(sample, channel) statistics, reduced over dim 0 (the length axis).
mu = torch.mean(input, dim=0)
# Biased variance (divide by L), matching the module's normalization.
var = torch.var(input, dim=0, unbiased=False)
# Use the module's own eps instead of repeating the literal 1e-5.
sigma = torch.sqrt(var + m.eps)

In [460]:
# One mean per (sample, channel) column.
print(mu.shape)

torch.Size([1600])


In [461]:
# Manual standardization, still in the (L, N*C) layout.
out2 = input.sub(mu).div(sigma)
print(out2.shape)

torch.Size([256, 1600])


In [462]:
out2 = out2.transpose(0, 1)
# `out2` was produced by arithmetic on a transposed view, so whether this
# transpose is contiguous depends on the memory layout PyTorch chose for the
# result. `reshape` returns a view when possible and copies otherwise, whereas
# `view` would raise. The sizes come from the module output (N, C, L) instead
# of a hard-coded 16, 100, 256.
n_batch, n_chan, length = out.shape
out2 = out2.reshape(n_batch, n_chan, length)
print(out2.size())

torch.Size([16, 100, 256])


In [464]:
# First sample of the manual result, laid out as (C, L).
out2[0]

tensor([[-6.2974e-01, -6.5277e-01, -8.9465e-01,  ...,  1.4388e+00,
         -1.0573e+00,  4.0100e-01],
        [ 9.3242e-01, -1.4115e+00, -1.2465e+00,  ...,  1.6629e+00,
          5.3434e-01,  1.3269e+00],
        [-1.7811e+00,  1.3545e+00,  1.0929e+00,  ..., -1.0244e+00,
          1.6471e-01, -1.0048e+00],
        ...,
        [-1.4536e+00, -1.5437e+00,  3.6212e-01,  ...,  6.8613e-01,
          7.9932e-01,  7.9058e-01],
        [ 9.9791e-02,  9.2586e-02,  1.2675e+00,  ...,  1.2594e+00,
          5.0791e-02,  5.4029e-01],
        [-1.7442e-01, -1.5277e+00,  6.8518e-01,  ...,  1.3089e+00,
         -5.1909e-01, -2.7275e-01]])

In [465]:
# Only corner values are printed, so check every element numerically; on
# failure, report the largest deviation.
assert torch.allclose(out, out2, atol=1e-4), (out - out2).abs().max()
out[0]

tensor([[-6.2974e-01, -6.5277e-01, -8.9465e-01,  ...,  1.4388e+00,
         -1.0573e+00,  4.0100e-01],
        [ 9.3242e-01, -1.4115e+00, -1.2465e+00,  ...,  1.6629e+00,
          5.3434e-01,  1.3269e+00],
        [-1.7811e+00,  1.3545e+00,  1.0929e+00,  ..., -1.0244e+00,
          1.6471e-01, -1.0048e+00],
        ...,
        [-1.4536e+00, -1.5437e+00,  3.6212e-01,  ...,  6.8613e-01,
          7.9932e-01,  7.9058e-01],
        [ 9.9791e-02,  9.2586e-02,  1.2675e+00,  ...,  1.2594e+00,
          5.0791e-02,  5.4029e-01],
        [-1.7442e-01, -1.5277e+00,  6.8518e-01,  ...,  1.3089e+00,
         -5.1909e-01, -2.7275e-01]])