In [1]:
import sys
sys.path.append("../")
import torch

# LayerNorm
from src.components.layer_norm import LayerNorm

**1. Parameters**

In [3]:
batch_size = 2
seq_len = 10
d_model = 128

**2. Embedding creation**

*We create a random tensor to simulate the embeddings that LayerNorm receives as input.*

In [4]:
# Create input (from embeddings)
# 2 sequences / 10 tokens each / embedding size 128
x = torch.randn(batch_size, seq_len, d_model)
print(x.shape)
x[0]

torch.Size([2, 10, 128])


tensor([[ 0.3424,  0.2555,  1.0064,  ...,  0.7846,  0.0647,  0.2347],
        [ 0.5301, -0.3582, -0.8590,  ..., -0.0379, -0.6900, -0.3392],
        [ 0.4098,  1.6771,  1.1958,  ...,  2.1459,  0.3085, -0.7250],
        ...,
        [-0.9142, -0.6863,  0.7167,  ..., -1.5258, -1.2470, -0.4650],
        [ 1.4840, -0.5677,  0.3748,  ..., -0.5983,  0.7842, -1.2540],
        [ 0.7234,  0.4865, -0.9085,  ...,  0.3246, -1.1249,  0.3131]])

**3. Normalization Layer**

*We will see the exact order of the layers when we build the transformer.*

In [5]:
ln = LayerNorm(d_model)
out = ln(x)

2026-02-03 11:43:27.346 | DEBUG    | src.components.layer_norm:__init__:24 - LayerNorm: d_model=128, eps=1e-06


In [7]:
out.shape

torch.Size([2, 10, 128])

Problem: During training, values in the network can become very large or very small (exploding/vanishing), making training unstable.

Solution: LayerNorm normalizes each token's features (the last dimension, of size d_model) to have:

Mean = 0
Standard deviation = 1

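*After training, the output will generally not have an exact mean of 0 and std of 1, because LayerNorm has learnable parameters that scale and shift the normalized values. The check below runs on a freshly created layer, so the output still shows a mean of about 0 and a std of about 1.*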
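The normalization step can be sketched by hand as follows. This is a minimal sketch, not the code in `src.components.layer_norm`: it assumes the biased variance with `eps` added inside the square root (the convention of `torch.nn.functional.layer_norm`), it leaves out the learnable scale and shift, and the helper name `manual_layer_norm` is hypothetical.

```python
import torch
import torch.nn.functional as F


def manual_layer_norm(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Normalize the last dimension to zero mean and unit variance (no scale/shift)."""
    mean = x.mean(dim=-1, keepdim=True)
    # Biased variance is an assumption here; other implementations may differ.
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps)


torch.manual_seed(0)
x_demo = torch.randn(2, 10, 128)  # (batch_size, seq_len, d_model)

y_manual = manual_layer_norm(x_demo)
# Reference result from PyTorch's functional API, without weight or bias.
y_ref = F.layer_norm(x_demo, normalized_shape=(128,), eps=1e-5)

print(y_manual.shape)
print(torch.allclose(y_manual, y_ref, atol=1e-5))
```

Each of the 2 x 10 token vectors is normalized independently of the others, which is why the statistics are taken over the last dimension only.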

In [23]:
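# out[0] is the first sequence: 10 tokens, each with 128 features
first_seq = out[0]
mean = first_seq.mean(-1)  # per-token mean over the feature dimension
std = first_seq.std(-1)    # per-token standard deviation (unbiased by default)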

print(mean)
print(std)

tensor([-2.0489e-08,  0.0000e+00,  1.4901e-08,  1.3970e-08,  1.4901e-08,
         0.0000e+00,  1.8626e-08,  3.7253e-09,  8.3819e-09,  0.0000e+00],
       grad_fn=<MeanBackward1>)
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000], grad_fn=<StdBackward0>)
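To illustrate how the learnable parameters move the statistics away from 0 and 1, here is a small sketch using PyTorch's built-in `torch.nn.LayerNorm` rather than the notebook's own class, whose parameter names are not shown here. The scale of 2.0 and shift of 0.5 are arbitrary values chosen for the demonstration, standing in for what training might learn.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model = 128
x_demo = torch.randn(2, 10, d_model)  # (batch_size, seq_len, d_model)

ln_demo = nn.LayerNorm(d_model)  # PyTorch built-in, not src.components.layer_norm
with torch.no_grad():
    ln_demo.weight.fill_(2.0)  # scale: arbitrary demo value
    ln_demo.bias.fill_(0.5)    # shift: arbitrary demo value
    out_demo = ln_demo(x_demo)

# Per-token statistics over the feature dimension.
mean_demo = out_demo.mean(-1)
std_demo = out_demo.std(-1, unbiased=False)

print(mean_demo[0, :3])
print(std_demo[0, :3])
```

With these values every token ends up with a mean close to the shift and a standard deviation close to the scale, so the output is no longer standardized even though the normalization step itself is unchanged.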


---