In [4]:
import torch
import torch.nn as nn

class BatchNorm1d(nn.Module):
    def __init__(self):
        super().__init__()
        self.epsilon = torch.tensor(0.0001)

    def forward(self, x, lamb=0.1, beta=0.1):
        mu = x.mean(dim=0)
        var = x.var(dim=0)
        x = (x - mu) / torch.sqrt(var + self.epsilon)
        x = lamb * x + beta

        return x

In [5]:
bn = BatchNorm1d()
x = torch.randn(16, 10)
out = bn(x)
print(out)

tensor([[ 7.2879e-02,  2.8467e-01,  2.1767e-02,  8.0477e-02,  8.5065e-02,
          2.1376e-02,  2.7958e-01,  1.8744e-01,  1.9016e-01,  2.0988e-01],
        [ 8.7985e-02,  1.4279e-01,  1.9293e-01,  2.8914e-02,  1.1653e-01,
          1.7054e-01,  1.8893e-01,  3.6919e-02, -7.0095e-02,  1.2405e-01],
        [ 1.4876e-01,  1.9134e-01,  1.1250e-01,  9.5024e-02,  4.8234e-02,
          3.3823e-02,  2.0502e-02, -1.5885e-02,  1.7252e-01,  4.8918e-02],
        [ 1.3028e-01,  2.2257e-02,  1.0642e-01,  2.3844e-01,  1.3062e-01,
          2.5044e-01,  2.6934e-01,  1.3307e-01, -2.0873e-02, -3.2597e-02],
        [ 9.8789e-02,  1.9511e-01,  1.9666e-02,  1.6306e-01,  9.3267e-02,
          1.0013e-01,  1.2962e-01,  1.2781e-01,  9.5577e-02,  1.8712e-01],
        [ 6.2113e-03,  8.7198e-02,  3.4345e-01, -2.1714e-02,  1.0512e-02,
          6.2747e-02,  5.0830e-03, -1.2047e-01,  1.5706e-01,  1.7864e-02],
        [ 1.9369e-01,  6.9565e-02, -6.4500e-02,  2.3176e-01,  1.7450e-01,
         -6.3438e-02,  8.1105e-0

In [5]:
import torch
import torch.nn as nn

class BatchNorm2d(nn.Module):
    def __init__(self, lambd = 0.99, gamma=0.1, beta=0.1):
        super().__init__()
        self.gamma = gamma
        self.beta = beta
        self.lambd = lambd
        self.running_mean = 0
        self.running_var = 0

    def forward(self, x):
        #b, c, h, w
        mu = x.mean(dim=[0,2,3])
        var = x.var(dim=[0,2,3])
        x = (x - mu.unsqueeze(1).unsqueeze(2)) / torch.sqrt(var.unsqueeze(1).unsqueeze(2) + 1e-5)
        x = self.gamma * x + self.beta

        #EMA方法计算mean与var
        self.running_mean = self.lambd * self.running_mean + (1 - self.lambd) * self.running_mean
        self.running_var = self.lambd * self.running_var + (1 - self.lambd) * self.running_var

        return x


In [4]:
bn = BatchNorm2d()
x = torch.rand(16, 100, 80, 60)
out = bn(x)
print(out.shape)

torch.Size([16, 100, 80, 60])
