# Batch Normalization

syncbatchnorm on the other hand is just batchnorm code, but is made so that it works with distributed training, with the running varience and running mean being synchronised

In [3]:
import torch
import torch.nn as nn
import numpy as np

class CustomBatchNorm2d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momenßtum=0.1):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        if self.training:
            mean = input.mean([0, 2, 3])
            var = input.var([0, 2, 3], unbiased=False)

            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
        else:
            mean = self.running_mean
            var = self.running_var

        input_normalized = (input - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        return self.weight[None, :, None, None] * input_normalized + self.bias[None, :, None, None]

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Create a random input tensor
batch_size, channels, height, width = 4, 3, 2, 2
x = torch.randn(batch_size, channels, height, width)

# Create custom BatchNorm2d layer
custom_bn = CustomBatchNorm2d(channels, eps=1e-5, momentum=0.1)

# Create PyTorch's BatchNorm2d layer
torch_bn = nn.BatchNorm2d(channels, eps=1e-5, momentum=0.1)

# Ensure both are in training mode
custom_bn.train()
torch_bn.train()

# Forward pass
custom_output = custom_bn(x)
torch_output = torch_bn(x)

print("Input shape:", x.shape)
print("\nInput:")
print(x)

print("\nCustom BatchNorm output:")
print(custom_output)

print("\nPyTorch BatchNorm output:")
print(torch_output)

print("\nDifference between custom and PyTorch implementations:")
print(torch.abs(custom_output - torch_output).max().item())

print("\nCustom BatchNorm running mean:")
print(custom_bn.running_mean)

print("\nPyTorch BatchNorm running mean:")
print(torch_bn.running_mean)

print("\nCustom BatchNorm running variance:")
print(custom_bn.running_var)

print("\nPyTorch BatchNorm running variance:")
print(torch_bn.running_var)

# Test in eval mode
custom_bn.eval()
torch_bn.eval()

custom_eval_output = custom_bn(x)
torch_eval_output = torch_bn(x)

print("\nDifference between custom and PyTorch implementations (eval mode):")
print(torch.abs(custom_eval_output - torch_eval_output).max().item())

Input shape: torch.Size([4, 3, 2, 2])

Input:
tensor([[[[ 1.9269,  1.4873],
          [ 0.9007, -2.1055]],

         [[ 0.6784, -1.2345],
          [-0.0431, -1.6047]],

         [[-0.7521,  1.6487],
          [-0.3925, -1.4036]]],


        [[[-0.7279, -0.5594],
          [-0.7688,  0.7624]],

         [[ 1.6423, -0.1596],
          [-0.4974,  0.4396]],

         [[-0.7581,  1.0783],
          [ 0.8008,  1.6806]]],


        [[[ 1.2791,  1.2964],
          [ 0.6105,  1.3347]],

         [[-0.2316,  0.0418],
          [-0.2516,  0.8599]],

         [[-1.3847, -0.8712],
          [-0.2234,  1.7174]]],


        [[[ 0.3189, -0.4245],
          [ 0.3057, -0.7746]],

         [[-1.5576,  0.9956],
          [-0.8798, -0.6011]],

         [[-1.2742,  2.1228],
          [-1.2347, -0.4879]]]])

Custom BatchNorm output:
tensor([[[[ 1.5237e+00,  1.1110e+00],
          [ 5.6032e-01, -2.2619e+00]],

         [[ 9.2737e-01, -1.2135e+00],
          [ 1.1991e-01, -1.6278e+00]],

         [[-6.2543e-0

# Group Normalization

- seems to be the same as BN, except we are splitting them into subgroups, and normalising from there. read: https://claude.ai/chat/445c1357-378a-4261-8d49-ae077ef0e190

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Create a random tensor
batch_size = 4
channels = 8
height = 4
width = 4
x = torch.randn(batch_size, channels, height, width)

print("Input tensor shape:", x.shape)

# Define normalization layers
num_groups = 4  # for Group Normalization
batch_norm = nn.BatchNorm2d(channels)
group_norm = nn.GroupNorm(num_groups, channels)

# Apply normalizations
batch_norm_output = batch_norm(x)
group_norm_output = group_norm(x)

print("\nBatch Normalization output shape:", batch_norm_output.shape)
print("Group Normalization output shape:", group_norm_output.shape)

# Manual calculation for Group Normalization
def manual_group_norm(x, num_groups, eps=1e-5):
    batch_size, channels, height, width = x.shape
    x = x.view(batch_size, num_groups, -1)

    mean = x.mean(dim=2, keepdim=True)
    var = x.var(dim=2, keepdim=True)
    x = (x - mean) / torch.sqrt(var + eps)

    return x.view(batch_size, channels, height, width)

manual_group_norm_output = manual_group_norm(x, num_groups)

print("\nManual Group Normalization output shape:", manual_group_norm_output.shape)

# Compare results
print("\nMax difference between PyTorch and manual Group Norm:")
print(torch.max(torch.abs(group_norm_output - manual_group_norm_output)))

# Visualize a slice of the data
slice_idx = 0
print(f"\nOriginal data (first channel, first sample):\n{x[slice_idx, 0]}")
print(f"\nBatch Normalized data (first channel, first sample):\n{batch_norm_output[slice_idx, 0]}")
print(f"\nGroup Normalized data (first channel, first sample):\n{group_norm_output[slice_idx, 0]}")

Input tensor shape: torch.Size([4, 8, 4, 4])

Batch Normalization output shape: torch.Size([4, 8, 4, 4])
Group Normalization output shape: torch.Size([4, 8, 4, 4])

Manual Group Normalization output shape: torch.Size([4, 8, 4, 4])

Max difference between PyTorch and manual Group Norm:
tensor(0.0471, grad_fn=<MaxBackward1>)

Original data (first channel, first sample):
tensor([[-0.9138, -0.6581,  0.0780,  0.5258],
        [-0.4880,  1.1914, -0.8140, -0.7360],
        [-1.4032,  0.0360, -0.0635,  0.6756],
        [-0.0978,  1.8446, -1.1845,  1.3835]])

Batch Normalized data (first channel, first sample):
tensor([[-0.7922, -0.5238,  0.2491,  0.7192],
        [-0.3452,  1.4179, -0.6874, -0.6055],
        [-1.3061,  0.2050,  0.1005,  0.8765],
        [ 0.0645,  2.1037, -1.0764,  1.6197]], grad_fn=<SelectBackward0>)

Group Normalized data (first channel, first sample):
tensor([[-1.0120, -0.7520, -0.0033,  0.4520],
        [-0.5789,  1.1288, -0.9105, -0.8311],
        [-1.5097, -0.0461, -0.14

# Power Norm

- difference in mathematics vs batch norm

overall, very simillar. both have learnable parameters like weights and biases. But slight differences in their varience and usage of means and moments etc





You're right to ask for a clear comparison of the mathematics. Let's break down the key equations for BatchNorm, LayerNorm, and PowerNorm (as implemented in the MaskPowerNorm class you've shared):

1. Batch Normalization:

y = γ * ((x - μ) / √(σ² + ε)) + β

Where:
- μ is the mean of x over the batch dimension
- σ² is the variance of x over the batch dimension
- γ and β are learnable parameters
- ε is a small constant for numerical stability

2. Layer Normalization:

y = γ * ((x - μ) / √(σ² + ε)) + β

Where:
- μ and σ² are calculated across the feature dimension for each sample
- γ and β are learnable parameters
- ε is a small constant for numerical stability

3. Power Normalization (from the provided MaskPowerNorm implementation):

y = γ * (x / √(E[x²] + ε)) + β

Where:
- E[x²] is the second moment (mean of squared values) of x
- γ (weight) and β (bias) are learnable parameters
- ε is a small constant for numerical stability

Key differences:

1. Moment calculation:
   - BatchNorm: Uses first (mean) and second (variance) moments across the batch dimension
   - LayerNorm: Uses first (mean) and second (variance) moments across the feature dimension
   - PowerNorm: Uses only the second moment (E[x²]) across a specified dimension (often feature dimension)

2. Centering:
   - BatchNorm and LayerNorm: Center the data by subtracting the mean (x - μ)
   - PowerNorm: Does not center the data

3. Scaling factor:
   - BatchNorm and LayerNorm: Use standard deviation √(σ² + ε)
   - PowerNorm: Uses root mean square √(E[x²] + ε)

4. Running statistics:
   - BatchNorm: Typically maintains running mean and variance
   - LayerNorm: Does not maintain running statistics
   - PowerNorm: Maintains a running second moment (running_phi in the code)

5. Additional features in PowerNorm:
   - Uses an exponential moving average of gradients (ema_gz)
   - Has separate forward and backward momentum parameters (alpha_fwd and alpha_bkw)
   - Implements a warm-up period

The key innovation in PowerNorm is its use of only the second moment for normalization, which can make it more robust to varying input distributions, especially in NLP tasks. By not centering the data, it preserves some information about the absolute magnitude of the inputs, which can be beneficial in certain scenarios.

The additional features like the gradient EMA and separate forward/backward momenta are designed to stabilize training and potentially improve performance, especially in scenarios with highly variable input distributions.

# Checking Masking

- https://claude.ai/chat/2c32cab8-0c1d-449e-be72-da7a79deaa2a

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimplifiedMaskPowerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        self.running_phi = torch.ones(1, num_features, 1, 1)

    def forward(self, input, pad_mask=None):
        shaped_input = (len(input.shape) == 2)
        if shaped_input:
            input = input.unsqueeze(0)
        T, B, C = input.shape

        if pad_mask is None:
            mask_input = input.clone()
        else:
            bn_mask = ~pad_mask
            bn_mask = bn_mask.transpose(0, 1)

        if pad_mask is not None:
            pad_size = (~bn_mask).sum()
            mask_input = input[bn_mask, :]
        else:
            mask_input = input.clone()

        mask_input = mask_input.reshape(-1, self.num_features)

        input = input.permute(1, 2, 0).contiguous()
        input_shape = input.size()
        input = input.reshape(input.size(0), self.num_features, -1)
        input = input.unsqueeze(-1)

        # Simplified PowerFunction logic
        N, C, H, W = input.size()
        x2 = (mask_input * mask_input).mean(dim=0)
        var = x2.reshape(1, C, 1, 1)

        output = input / (var + self.eps).sqrt()
        output = self.weight.reshape(1,C,1,1) * output + self.bias.reshape(1,C,1,1)

        output = output.reshape(input_shape)
        output = output.permute(2, 0, 1).contiguous()

        if shaped_input:
            output = output.squeeze(0)

        return output

# Create sample input and mask
T, B, C = 5, 2, 8
input_tensor = torch.randn(T, B, C)
pad_mask = torch.tensor([[False, False, False, True, True],
                         [False, False, True, True, True]])

# Initialize and apply SimplifiedMaskPowerNorm
power_norm = SimplifiedMaskPowerNorm(num_features=C)
output = power_norm(input_tensor, pad_mask)

# Print shapes and results
print("Input shape:", input_tensor.shape)
print("Pad mask shape:", pad_mask.shape)
print("Output shape:", output.shape)

# Print a slice of input and output for comparison
print("\nInput slice (first sequence, first batch):")
print(input_tensor[:, 0, :3])  # First 3 features
print("\nOutput slice (first sequence, first batch):")
print(output[:, 0, :3])  # First 3 features

# Verify that padding is handled correctly
print("\nInput at padded position (3, 0, 0):", input_tensor[3, 0, 0].item())
print("Output at padded position (3, 0, 0):", output[3, 0, 0].item())
print("Input at padded position (4, 1, 0):", input_tensor[4, 1, 0].item())
print("Output at padded position (4, 1, 0):", output[4, 1, 0].item())

Input shape: torch.Size([5, 2, 8])
Pad mask shape: torch.Size([2, 5])
Output shape: torch.Size([5, 2, 8])

Input slice (first sequence, first batch):
tensor([[ 1.8595,  2.6221,  0.3691],
        [-0.8032, -1.1209,  0.1956],
        [ 0.0613,  0.0853,  0.7481],
        [-1.8737,  2.3259, -0.9204],
        [-0.4138,  0.5184, -0.7015]])

Output slice (first sequence, first batch):
tensor([[ 1.9468,  1.8341,  0.5079],
        [-0.8409, -0.7841,  0.2692],
        [ 0.0642,  0.0596,  1.0295],
        [-1.9616,  1.6269, -1.2666],
        [-0.4332,  0.3626, -0.9654]], grad_fn=<SliceBackward0>)

Input at padded position (3, 0, 0): -1.8736529350280762
Output at padded position (3, 0, 0): -1.9615864753723145
Input at padded position (4, 1, 0): -1.083811640739441
Output at padded position (4, 1, 0): -1.1346766948699951


# Conclusion

- unsure whats masking doing, the pad_mask, GPT have kinda given me a wrong answer instead, as it doesn't understand how does masking work too.

- idk, it looks like that, pad_mask isn't used for LN, GN, but is used by BN, PN.

- https://claude.ai/chat/2c32cab8-0c1d-449e-be72-da7a79deaa2a

You're correct in your observation. Let's break this down to understand why:

1. Layer Normalization (LN):
   - LN normalizes across the feature dimension for each sample independently.
   - It doesn't need a pad_mask because it operates on each sample separately, so padding in one sample doesn't affect the normalization of another.

2. Group Normalization (GN):
   - Similar to LN, GN operates independently on each sample, just dividing features into groups.
   - Again, no pad_mask is needed as it doesn't compute statistics across the batch dimension.

3. Batch Normalization (BN) and Power Normalization (PN):
   - These compute statistics across the batch dimension.
   - In sequences of varying lengths (common in NLP tasks), padding is often used to make all sequences in a batch the same length.
   - Without a pad_mask, BN and PN would include padding values in their statistics computation, potentially skewing the results.
   - The pad_mask allows these methods to compute accurate statistics by excluding padding from the calculations.

The key difference is that BN and PN compute statistics across samples in a batch, while LN and GN compute statistics within each sample independently. This is why BN and PN need to account for padding in variable-length sequences, but LN and GN do not.

In the context of transformers or other models dealing with variable-length sequences:
- LN and GN can be applied as-is without worrying about padding.
- BN and PN need the additional pad_mask to ensure they're not affected by padding values when computing batch statistics.

This is one reason why LN is often preferred in NLP tasks with transformers - it naturally handles variable-length sequences without the need for explicit padding masks in the normalization layer.

# pad-mask testing

In [6]:
import torch

# Let's say we have a batch size of 3 and sequence length of 5
B, T = 3, 5
pad_token = 0

# Example of x after padding (0 is our pad token)
x = torch.tensor([
    [1, 2, 3, 4, 5],    # Full sequence
    [6, 7, 8, 0, 0],    # Padded sequence
    [9, 0, 0, 0, 0]     # Heavily padded sequence
])

print("x:")
print(x)

# Create pad_mask
pad_mask = (x != pad_token).float()

print("\npad_mask:")
print(pad_mask)

x:
tensor([[1, 2, 3, 4, 5],
        [6, 7, 8, 0, 0],
        [9, 0, 0, 0, 0]])

pad_mask:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 0.],
        [1., 0., 0., 0., 0.]])
