In [1]:
import torch
import torch.nn as nn
import numpy as np
import math
from timm.models.vision_transformer import PatchEmbed, Attention, Mlp

from torch.jit import Final

In [162]:
class LayerNormCustom(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=False, bias=True, device=None, dtype=None):
        super(LayerNormCustom, self).__init__()
        if isinstance(normalized_shape, (tuple, list)):
            self.normalized_shape = normalized_shape
        else:
            self.normalized_shape = (normalized_shape,)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
    
    def forward(self, x):
        # mean = x.mean(dim=-1, keepdim=True)
        mean = torch.mean(x, dim=-1,keepdim=True)
        var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
        # var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_normalized = (x - mean) / torch.sqrt(var + self.eps)
        return x_normalized

In [187]:
print("Running Basic Equivalence Test...")
x = torch.randn(2, 3, 4, requires_grad=True)
norm1 = nn.LayerNorm(4)
norm2 = LayerNormCustom(4)
output1 = norm1(x)
output2 = norm2(x)
assert torch.allclose(output1, output2, atol=1e-5), f"Outputs are not close! {output1} {output2}"
print("Basic Equivalence Test passed!")

Running Basic Equivalence Test...
Basic Equivalence Test passed!


In [189]:
print("Running Shape Variability Test...")
shapes = [(4,), (2, 4), (2, 3, 4), (2, 3, 4, 5)]
for shape in shapes:
    x = torch.randn(*shape, requires_grad=True)
    norm1 = nn.LayerNorm(shape[-1])
    norm2 = LayerNormCustom(shape[-1])
    assert torch.allclose(norm1(x), norm2(x), atol=1e-4), f"Mismatch for shape {shape} nums {norm1(x)} {norm2(x)}"
print("Shape Variability Test passed!")

Running Shape Variability Test...
Shape Variability Test passed!


In [192]:
print("Running Gradient Consistency Test...")
x = torch.randn(2, 3, 4, requires_grad=True)
norm1 = nn.LayerNorm(4)
norm2 = LayerNormCustom(4)
output1 = norm1(x)
output2 = norm2(x)
loss1 = output1.sum()
loss2 = output2.sum()
grad1 = torch.autograd.grad(loss1, x, create_graph=True)[0]
grad2 = torch.autograd.grad(loss2, x, create_graph=True)[0]
assert torch.allclose(grad1, grad2, atol=1e-6), "Gradient mismatch!"
print("Gradient Consistency Test passed!")

Running Gradient Consistency Test...
Gradient Consistency Test passed!


In [194]:
print("Running Edge Cases Test...")
# All zeros
x = torch.zeros(2, 3, 4)
norm1 = nn.LayerNorm(4)
norm2 = LayerNormCustom(4)
assert torch.allclose(norm1(x), norm2(x), atol=1e-6), "Mismatch for all zeros"
# Very large values
x = torch.full((2, 3, 4), 1e6)
assert torch.allclose(norm1(x), norm2(x), atol=1e-6), "Mismatch for large values"
# Very small values
x = torch.full((2, 3, 4), 1e-6)
assert torch.allclose(norm1(x), norm2(x), atol=1e-6), "Mismatch for small values"
print("Edge Cases Test passed!")

Running Edge Cases Test...
Edge Cases Test passed!


In [198]:
print("Running Performance Test...")
x = torch.randn(512, 512, 512)
norm1 = nn.LayerNorm(512)
norm2 = LayerNormCustom(512)

import time
start = time.time()
norm1(x)
torch.cuda.synchronize()
print(f"nn.LayerNorm: {time.time() - start:.6f} seconds")

start = time.time()
norm2(x)
torch.cuda.synchronize()
print(f"CustomLayerNorm: {time.time() - start:.6f} seconds")


Running Performance Test...
nn.LayerNorm: 0.043641 seconds
CustomLayerNorm: 0.247132 seconds
