In [3]:
import torch
from torch import nn

In [4]:
class MyLinear(nn.Module):
    def __init__(self, in_features, out_features) -> None:
        super().__init__()
        self.weights = nn.Parameter(torch.rand(in_features, out_features))
        print(self.weights.requires_grad)
        self.bias = nn.Parameter(torch.rand(out_features))
        print(self.bias.shape)

    def forward(self, inputs):
        return torch.matmul(inputs, self.weights) + self.bias

In [7]:
inputs = torch.rand(4, 5)
fc = MyLinear(5, 3)
outputs = fc(inputs)

True
torch.Size([3])


In [None]:
class BatchNorm2d(nn.Module):
    def __init__(self, channel, eps, affine=True, momentum=0.9) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.one(1, channel, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channel, 1, 1))
        self.register_buffer('running_mean', torch.Tensor((1,)))
        self.register_buffer('running_variance', torch.tenosr((1,)))
        self.register_buffer('num_batches_tracked')
        self.register_parameter()
        self.eps = eps
        self.affine = affine
        self.momentum = momentum
    
    def forward(self, input):# BCHW
        if self.trainning:
            means = input.mean((0, 2, 3), keepdim=True)
            vars = torch.sum((input - means)**2, (0, 2, 3), keepdim=True)
            
            self.running_mean = self.momentum * self.running_mean + means * (1 - self.momentum)
            self.running_variance = self.momentum * self.running_variance + vars * (1 - self.momentum)
            self.num_batches_tracked += 1
        else:
            means = self.running_mean
            vars = self.running_variance    
        std = (vars + self.eps).sqrt() 
        input = (input - means) / std
        if self.affine:
            input = input * self.alpha + self.beta
        return input


In [13]:
data = [5, 10, 15]
s = data[0]
momentum = 0.9
for j in range(30):
    for i in data:
        s = s * momentum + i * (1 - momentum)
        print(s)
    print("------------")

5.0
5.5
6.449999999999999
------------
6.305
6.6745
7.50705
------------
7.256345
7.5307105
8.277639449999999
------------
7.949875504999999
8.154887954499998
8.839399159049998
------------
8.455459243144999
8.609913318830499
9.24892198694745
------------
8.824029788252705
8.941626809427435
9.547464128484691
------------
9.092717715636223
9.1834459440726
9.76510134966534
------------
9.288591214698807
9.359732093228926
9.923758883906034
------------
9.431382995515431
9.488244695963889
10.0394202263675
------------
9.53547820373075
9.581930383357676
10.123737345021908
------------
9.611363610519717
9.650227249467745
10.185204524520971
------------
9.666684072068874
9.700015664861986
10.230014098375788
------------
9.70701268853821
9.73631141968439
10.262680277715951
------------
9.736412249944356
9.76277102494992
10.286493922454929
------------
9.757844530209436
9.782060077188493
10.303854069469644
------------
9.77346866252268
9.796121796270413
10.316509616643371
------------
9.7848586