In [27]:
import torch
import torch.nn as nn

In [26]:
a = torch.Tensor([[[2, 3, 4], [1, 1, 1], [0, -4, 18], [5, 6, 7]],
                  [[1, 2, 55], [5, 34, 13], [0, 0, 0], [-10, -6, 7]]])
a.shape # .shape -> [batch_size, seq_len, emb_size]

torch.Size([2, 4, 3])

## BatchNorm

In [34]:
# BatchNorm -> часто применяется для CNN и в целом в CV
class BatchNorm:
  def __init__(self, emb_size, eps=1e-5, momentum=0.1):
    self.emb_size = emb_size
    self.eps = eps
    self.momentum = momentum
    self.gamma = torch.ones(1, 1, emb_size, requires_grad=True)
    self.beta = torch.zeros(1, 1, emb_size, requires_grad=True)
    # для инференса скользящие статистики (среднее и дисперсия)
    self.running_mean = torch.zeros(1, 1, emb_size)
    self.running_var = torch.ones(1, 1, emb_size)
    self.training = True

  def __call__(self, x):
    if self.training:
      # вычисляем mean и var по осям: batch_size и seq_len (для каждого признака)
      batch_mean = x.mean(dim=(0, 1), keepdim=True) # [1, 1, emb_size]
      batch_var = x.var(dim=(0, 1), keepdim=True) # [1, 1, emb_size]

      self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
      self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
    else:
      batch_mean = self.running_mean
      batch_var = self.running_var

    x_norm = (x - batch_mean) / torch.sqrt(batch_var + self.eps)
    out = self.gamma * x_norm + self.beta
    return out

In [39]:
bn = BatchNorm(emb_size=a.size(2))
out = bn(a)
out, out.shape, out.mean(dim=(0, 1)), out.var(dim=(0, 1))

(tensor([[[ 0.3198, -0.1199, -0.5084],
          [ 0.1066, -0.2797, -0.6756],
          [-0.1066, -0.6793,  0.2716],
          [ 0.9594,  0.1199, -0.3413]],
 
         [[ 0.1066, -0.1998,  2.3331],
          [ 0.9594,  2.3576, -0.0070],
          [-0.1066, -0.3596, -0.7313],
          [-2.2386, -0.8391, -0.3413]]], grad_fn=<AddBackward0>),
 torch.Size([2, 4, 3]),
 tensor([1.4901e-08, 1.4901e-08, 2.2352e-08], grad_fn=<MeanBackward1>),
 tensor([1.0000, 1.0000, 1.0000], grad_fn=<VarBackward0>))

## LayerNorm

In [42]:
# LayerNorm - часто применяется в трансформерах и RNN
class LayerNorm:
  def __init__(self, emb_size, eps=1e-5):
    self.emb_size = emb_size
    self.eps = eps
    self.gamma = torch.ones(emb_size, requires_grad=True)
    self.beta = torch.zeros(emb_size, requires_grad=True)

  def __call__(self, x):
    emb_mean = x.mean(dim=2, keepdim=True) # [batch_size, seq_len, 1]
    emb_var = x.var(dim=2, unbiased=False, keepdim=True) # [batch_size, seq_len, 1]
    x_norm = (x - emb_mean) / torch.sqrt(emb_var + self.eps)
    out = x_norm * self.gamma + self.beta
    return out

In [43]:
ln = LayerNorm(emb_size=a.size(2))
out = ln(a)
out, out.shape, out.mean(dim=2), out.var(dim=2)

(tensor([[[-1.2247,  0.0000,  1.2247],
          [ 0.0000,  0.0000,  0.0000],
          [-0.4877, -0.9058,  1.3935],
          [-1.2247,  0.0000,  1.2247]],
 
         [[-0.7268, -0.6872,  1.4140],
          [-1.0085,  1.3628, -0.3543],
          [ 0.0000,  0.0000,  0.0000],
          [-0.9646, -0.4134,  1.3779]]], grad_fn=<AddBackward0>),
 torch.Size([2, 4, 3]),
 tensor([[ 0.0000e+00,  0.0000e+00,  3.9736e-08,  0.0000e+00],
         [-7.9473e-08, -3.9736e-08,  0.0000e+00,  0.0000e+00]],
        grad_fn=<MeanBackward1>),
 tensor([[1.5000, 0.0000, 1.5000, 1.5000],
         [1.5000, 1.5000, 0.0000, 1.5000]], grad_fn=<VarBackward0>))

## RMSNorm

In [46]:
# RMSNorm - часто применяют в современных LLM
class RMSNorm:
  def __init__(self, emb_size, eps=1e-5):
    self.emb_size = emb_size
    self.eps = eps
    self.gamma = torch.ones(emb_size, requires_grad=True)

  def __call__(self, x):
    rms = torch.sqrt(torch.var(x**2, dim=-1, keepdim=True) + self.eps) # [batch_size, seq_len, 1]
    x_norm = x / rms
    out = self.gamma * x_norm
    return out

In [47]:
rms_norm = RMSNorm(emb_size=a.size(2))
out = rms_norm(a)
out, out.shape, out.mean(dim=2), out.var(dim=2)

(tensor([[[ 3.3180e-01,  4.9770e-01,  6.6360e-01],
          [ 3.1623e+02,  3.1623e+02,  3.1623e+02],
          [ 0.0000e+00, -2.1904e-02,  9.8566e-02],
          [ 4.1619e-01,  4.9942e-01,  5.8266e-01]],
 
         [[ 5.7305e-04,  1.1461e-03,  3.1518e-02],
          [ 8.1216e-03,  5.5227e-02,  2.1116e-02],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [-2.9561e-01, -1.7737e-01,  2.0693e-01]]], grad_fn=<MulBackward0>),
 torch.Size([2, 4, 3]),
 tensor([[ 4.9770e-01,  3.1623e+02,  2.5554e-02,  4.9942e-01],
         [ 1.1079e-02,  2.8155e-02,  0.0000e+00, -8.8684e-02]],
        grad_fn=<MeanBackward1>),
 tensor([[0.0275, 0.0000, 0.0041, 0.0069],
         [0.0003, 0.0006, 0.0000, 0.0690]], grad_fn=<VarBackward0>))