In [1]:
import torch
import torch.nn as nn

In [2]:
# ----------------------------
# 1) BatchNorm (N, C, *spatial)
# ----------------------------
class BatchNormManual(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats

        if affine:
            self.weight = nn.Parameter(torch.ones(num_features))  # gamma
            self.bias   = nn.Parameter(torch.zeros(num_features)) # beta
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        if track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var',  torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var',  None)
            self.register_parameter('num_batches_tracked', None)

    def forward(self, x):
        # x: (N, C, *S)
        original_dtype = x.dtype
        x_float = x.float()

        # axes to reduce: all except channel (dim=1)
        reduce_dims = [0] + list(range(2, x.dim()))
        if self.training:
            # batch statistics
            batch_mean = x_float.mean(dim=reduce_dims, keepdim=False)
            batch_var  = x_float.var(dim=reduce_dims, unbiased=False, keepdim=False)

            if self.track_running_stats:
                with torch.no_grad():
                    self.num_batches_tracked += 1
                    # EMA update
                    self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
                    self.running_var  = (1 - self.momentum) * self.running_var  + self.momentum * batch_var

            mean = batch_mean
            var  = batch_var
        else:
            # inference: use running stats if available, else fall back to batch stats
            if self.track_running_stats and self.running_mean is not None:
                mean = self.running_mean
                var  = self.running_var
            else:
                mean = x_float.mean(dim=reduce_dims, keepdim=False)
                var  = x_float.var(dim=reduce_dims, unbiased=False, keepdim=False)

        # reshape (C,) -> (1,C,1,1,...) for broadcasting
        shape = [1, -1] + [1] * (x.dim() - 2)
        mean = mean.view(*shape)
        var  = var.view(*shape)

        x_hat = (x_float - mean) / torch.sqrt(var + self.eps)
        if self.affine:
            weight = self.weight.view(*shape).to(x_hat.dtype)
            bias   = self.bias.view(*shape).to(x_hat.dtype)
            y = weight * x_hat + bias
        else:
            y = x_hat

        return y.to(original_dtype)


# -----------------------------------
# 2) LayerNorm over the last K dims
# -----------------------------------
class LayerNormManual(nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        """
        normalized_shape: int or tuple of ints for the trailing dimensions to normalize.
        Example: hidden_size (int) or (H, W) for 2D per-sample spatial LN.
        """
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(*self.normalized_shape))  # gamma
            self.bias   = nn.Parameter(torch.zeros(*self.normalized_shape)) # beta
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

    def forward(self, x):
        # reduce over the last len(normalized_shape) dims
        reduce_dims = tuple(range(x.dim() - len(self.normalized_shape), x.dim()))
        x_float = x.float()
        mean = x_float.mean(dim=reduce_dims, keepdim=True)
        var  = x_float.var(dim=reduce_dims, unbiased=False, keepdim=True)

        x_hat = (x_float - mean) / torch.sqrt(var + self.eps)
        if self.elementwise_affine:
            y = x_hat * self.weight + self.bias
        else:
            y = x_hat
        return y.to(x.dtype)


# ---------------------------
# 3) RMSNorm (no centering)
# ---------------------------
class RMSNormManual(nn.Module):
    def __init__(self, dim, eps=1e-8, elementwise_affine=True):
        """
        dim: size of the last (normalized) dimension.
        Only gamma is typically used; beta is optional (default off here).
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))  # gamma
        else:
            self.register_parameter('weight', None)

    def forward(self, x):
        # normalize over last dimension
        x_float = x.float()
        # rms = sqrt(mean(x^2) + eps)
        rms = torch.sqrt(x_float.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        y = x_float / rms
        if self.elementwise_affine:
            y = y * self.weight.view(*([1] * (x.dim() - 1)), self.dim)
        return y.to(x.dtype)

In [3]:
# ----- LayerNorm check -----
x = torch.randn(4, 16, 64)
ln_ref = nn.LayerNorm(64)
ln_my  = LayerNormManual(64)
# copy params to compare apples-to-apples
ln_my.weight.data.copy_(ln_ref.weight.data)
ln_my.bias.data.copy_(ln_ref.bias.data)
print("LN max diff:", (ln_ref(x) - ln_my(x)).abs().max().item())

# ----- BatchNorm2d check -----
x = torch.randn(8, 32, 16, 16)
bn_ref = nn.BatchNorm2d(32, affine=True, momentum=0.1, eps=1e-5, track_running_stats=True)
bn_my  = BatchNormManual(32, affine=True, momentum=0.1, eps=1e-5, track_running_stats=True)
# train mode: both update running stats from same batch
bn_ref.train(); bn_my.train()
# align parameters
bn_my.weight.data.copy_(bn_ref.weight.data)
bn_my.bias.data.copy_(bn_ref.bias.data)
y_ref = bn_ref(x)
y_my  = bn_my(x)
print("BN train max diff:", (y_ref - y_my).abs().max().item())

# switch to eval: both should use their accumulated running stats
bn_ref.eval(); bn_my.eval()
y_ref2 = bn_ref(x)
y_my2  = bn_my(x)
print("BN eval max diff:", (y_ref2 - y_my2).abs().max().item())

# ----- RMSNorm check (no official ref, just shape/run) -----
x = torch.randn(2, 5, 768, dtype=torch.bfloat16)  # bf16 friendly
rms = RMSNormManual(768)
y = rms(x)
print("RMSNorm output shape:", y.shape)

LN max diff: 4.76837158203125e-07
BN train max diff: 4.76837158203125e-07
BN eval max diff: 0.0001239776611328125
RMSNorm output shape: torch.Size([2, 5, 768])
