# Experiment 1: The Foundational Effect of Layer Normalization

### Summary
This notebook reproduces the foundational experiment from the paper, **"Stability and Expression: The Dual-Mechanism of Normalization in Deep Learning."** We investigate the fundamental impact of Layer Normalization (LayerNorm) on the training dynamics and converged solution of a minimalist Transformer model.

### Hypothesis
The experiment tests the hypothesis that LayerNorm acts as a powerful implicit regularizer. Specifically, we expect to observe that:
1.  **Loss Landscape:** The model with LayerNorm will converge to a wide, **flat minimum**, which is characteristic of generalizable solutions. The model without normalization will converge to a sharp, narrow minimum, indicative of memorization.
2.  **Internal Representations:** The model with LayerNorm will maintain high-dimensional, **isotropic** internal representations. The model without normalization will suffer from **representational collapse**, where activations collapse into a low-dimensional subspace.

### Methodology
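We compare two minimalist Transformer models trained on a synthetic "Low-Rank Associative Recall" task. The models are architecturally identical apart from the presence or absence of LayerNorm layers. Note that the execution cell below also trains the variant without LayerNorm at a higher learning rate (5e-4 instead of 1e-4). We then analyze and plot three key aspects: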
1.  The training loss curves.
2.  A 1D slice of the final loss landscape to measure flatness.
3.  The singular value spectrum of the internal activations to measure isotropy.

### Setup and Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import copy

# Configuration
###### All parameters for the synthetic experiment are defined here

In [None]:
class Config:
    """Hyperparameters for the synthetic associative-recall experiment.

    Every setting is a class attribute, so it can be read from the class or
    from an instance; assigning to an instance overrides that instance only.
    """

    # --- Task definition ---
    N_ENTITIES = 20
    M_ATTRIBUTES = 20
    # Key ids cover range(N_ENTITIES * M_ATTRIBUTES); two further ids are added
    # on top (presumably reserved for special tokens -- the data generator in
    # this file never emits them).
    VOCAB_SIZE = 2 + N_ENTITIES * M_ATTRIBUTES
    THEORETICAL_RANK = M_ATTRIBUTES + N_ENTITIES
    # One slot more than the number of keys drawn per sample.
    SEQ_LENGTH = 1 + THEORETICAL_RANK

    # --- Model ---
    EMBED_DIM = 128
    NUM_HEADS = 1
    NUM_BLOCKS = 1  # the model defined in this file always builds one block

    # --- Optimisation ---
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 128
    TRAINING_STEPS = 4000

    # --- Loss-landscape probe ---
    LANDSCAPE_RESOLUTION = 50  # number of alpha samples along the slice
    LANDSCAPE_RANGE = 0.5      # alphas span [-LANDSCAPE_RANGE, +LANDSCAPE_RANGE]


# Module-level settings instance; the main execution cell rebinds `config`
# to a fresh Config() before each training run.
config = Config()

### 3. Synthetic Data Generation
This function creates batches for the "Low-Rank Associative Recall" task.
The task is designed to have a known low-rank structure, making it easy to distinguish between models that learn the underlying structure versus those that simply memorize the data.

In [None]:


def generate_associative_recall_batch(config):
    """Sample one batch of (sequence, target) pairs for the recall task.

    For every sample, ``config.THEORETICAL_RANK`` distinct key ids are drawn
    from ``range(config.N_ENTITIES * config.M_ATTRIBUTES)`` with NumPy's global
    RNG. The first drawn id becomes the query (and the target); the remaining
    ids form the context written at the start of the sequence. The query is
    written to the last position. Any position between the context and the
    final slot keeps its initial value of 0.

    Args:
        config: Object exposing N_ENTITIES, M_ATTRIBUTES, THEORETICAL_RANK,
            SEQ_LENGTH and BATCH_SIZE.

    Returns:
        Tuple ``(X, Y)`` of ``torch.long`` tensors with shapes
        ``(BATCH_SIZE, SEQ_LENGTH)`` and ``(BATCH_SIZE,)``.
    """
    num_keys = config.N_ENTITIES * config.M_ATTRIBUTES
    batch_size = config.BATCH_SIZE

    inputs = torch.zeros((batch_size, config.SEQ_LENGTH), dtype=torch.long)
    targets = torch.zeros(batch_size, dtype=torch.long)

    for row in range(batch_size):
        # One draw per sample keeps the RNG consumption pattern per-row.
        drawn = np.random.choice(num_keys, config.THEORETICAL_RANK, replace=False)
        query, context = drawn[0], drawn[1:]

        inputs[row, :context.shape[0]] = torch.from_numpy(context)
        inputs[row, -1] = int(query)
        targets[row] = int(query)

    return inputs, targets

### 4. Minimalist Transformer Model
We define a simple Transformer with a single attention block.
The `use_layernorm` flag allows us to instantiate two versions of this
model: one with LayerNorm and one without.

In [None]:


class MinimalistTransformerBlock(nn.Module):
    """Self-attention + feed-forward block with optional LayerNorm.

    Each sub-layer is wrapped in a residual connection. When ``use_layernorm``
    is true, a LayerNorm is applied after each residual addition; otherwise
    the ``norm1``/``norm2`` sub-modules are not created at all.

    Args:
        embed_dim: Width of the token representations.
        num_heads: Number of attention heads.
        use_layernorm: Whether to create and apply the two LayerNorm layers.
    """

    def __init__(self, embed_dim, num_heads, use_layernorm=True):
        super().__init__()
        self.use_layernorm = use_layernorm

        # Sub-modules are created in a fixed order (attention, then FFN) so
        # that seeded initialisation stays reproducible.
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        hidden_dim = 4 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

        if not self.use_layernorm:
            return
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Apply the attention and feed-forward sub-layers to ``x``."""
        # Self-attention: the same tensor serves as query, key and value.
        hidden = x + self.attention(x, x, x)[0]
        if self.use_layernorm:
            hidden = self.norm1(hidden)

        out = hidden + self.ffn(hidden)
        return self.norm2(out) if self.use_layernorm else out

class TransformerModel(nn.Module):
    """Token embedding, one Transformer block, and a vocabulary read-out.

    Only the representation at the last sequence position is projected to
    logits. No positional information is added in this forward pass.

    Args:
        config: Object exposing VOCAB_SIZE, EMBED_DIM and NUM_HEADS.
        use_layernorm: Forwarded to the Transformer block.
    """

    def __init__(self, config, use_layernorm=True):
        super().__init__()
        self.config = config

        vocab_size, embed_dim = config.VOCAB_SIZE, config.EMBED_DIM
        self.token_embeddings = nn.Embedding(vocab_size, embed_dim)
        self.transformer_block = MinimalistTransformerBlock(
            embed_dim, config.NUM_HEADS, use_layernorm
        )
        self.to_vocab = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` to logits computed from the final position."""
        hidden = self.transformer_block(self.token_embeddings(x))
        # The query token sits in the last slot of every sequence.
        return self.to_vocab(hidden[:, -1, :])

### 5. Training Loop

In [None]:


def train_model(config, use_layernorm, device):
    """Train a fresh TransformerModel on freshly sampled batches.

    A new batch is generated at every step; the loss of each step is recorded
    and progress is printed every 500 steps.

    Args:
        config: Settings object (LEARNING_RATE, TRAINING_STEPS and whatever the
            model and the batch generator read).
        use_layernorm: Whether the model is built with LayerNorm layers.
        device: Device on which the model and the batches are placed.

    Returns:
        Tuple ``(model, loss_history)`` where ``loss_history`` holds one float
        per training step.
    """
    print(f"\n--- Training model {'WITH' if use_layernorm else 'WITHOUT'} LayerNorm ---")

    net = TransformerModel(config, use_layernorm=use_layernorm).to(device)
    adam = optim.Adam(net.parameters(), lr=config.LEARNING_RATE)
    loss_fn = nn.CrossEntropyLoss()

    loss_history = []
    for step in range(config.TRAINING_STEPS):
        net.train()
        inputs, targets = generate_associative_recall_batch(config)
        inputs = inputs.to(device)
        targets = targets.to(device)

        adam.zero_grad()
        batch_loss = loss_fn(net(inputs), targets)
        batch_loss.backward()
        adam.step()

        loss_value = batch_loss.item()
        loss_history.append(loss_value)

        if step % 500 == 0:
            print(f"Step {step}/{config.TRAINING_STEPS}, Loss: {loss_value:.4f}")

    return net, loss_history

### 6. Analysis Functions
These functions are used to probe the properties of the trained models.


In [None]:

# Module-level slot written by get_activation_spectrum: holds the most
# recently captured (detached) output of the hooked Transformer block,
# or None before the first capture.
activation_output = None

def get_loss_landscape(model, config, device):
    """Evaluate the loss along one random direction through the trained weights.

    A Gaussian random direction is drawn over all trainable parameters and
    scaled to unit global L2 norm. The loss is then evaluated on a single
    fixed batch at ``weights + alpha * direction`` for each alpha in
    ``linspace(-LANDSCAPE_RANGE, LANDSCAPE_RANGE, LANDSCAPE_RESOLUTION)``.

    The model is switched to eval mode and left there. Its weights are
    restored to their original values before returning, even if an evaluation
    raises.

    Args:
        model: Trained model to probe.
        config: Settings object (LANDSCAPE_RANGE, LANDSCAPE_RESOLUTION and
            whatever the batch generator reads).
        device: Device on which the probe batch is placed.

    Returns:
        Tuple ``(alphas, losses)``: a NumPy array of offsets and a list with
        one float loss per offset.
    """
    print("Analyzing loss landscape...")
    # Full snapshot (parameters and buffers) used for the final restore.
    final_state_dict = copy.deepcopy(model.state_dict())

    # Work on the parameter objects themselves rather than on state-dict
    # entries: a state dict also contains buffers, which are not parameters
    # and must not be perturbed or looked up as such.
    trainable = [p for p in model.parameters() if p.requires_grad]
    direction = [torch.randn_like(p) for p in trainable]
    norm = np.sqrt(sum(torch.sum(d**2).item() for d in direction))
    direction = [d / norm for d in direction]
    origin = [p.detach().clone() for p in trainable]

    alphas = np.linspace(-config.LANDSCAPE_RANGE, config.LANDSCAPE_RANGE, config.LANDSCAPE_RESOLUTION)
    criterion = nn.CrossEntropyLoss()
    losses = []

    model.eval()
    try:
        with torch.no_grad():
            # One batch is reused for every alpha so the slice is comparable
            # point to point.
            X, Y = generate_associative_recall_batch(config)
            X, Y = X.to(device), Y.to(device)

            for alpha in alphas:
                step = float(alpha)
                # Always offset from the saved origin so perturbations do not
                # accumulate across alphas.
                for param, start, delta in zip(trainable, origin, direction):
                    param.copy_(start + step * delta)
                losses.append(criterion(model(X), Y).item())
    finally:
        # Never leave the caller with perturbed weights.
        model.load_state_dict(final_state_dict)

    return alphas, losses

def get_activation_spectrum(model, config, device):
    """Return the singular values of the Transformer block's output activations.

    A forward hook on ``model.transformer_block`` captures the block output
    for one freshly generated batch. The captured tensor is flattened to
    ``(-1, config.EMBED_DIM)`` and its singular values are computed.

    The model is switched to eval mode and left there. The hook is always
    removed, even if the forward pass raises. The module-level
    ``activation_output`` variable is updated with the captured tensor.

    Args:
        model: Model exposing a ``transformer_block`` sub-module.
        config: Settings object (EMBED_DIM and whatever the batch generator
            reads).
        device: Device on which the probe batch is placed.

    Returns:
        NumPy array of singular values in descending order.

    Raises:
        RuntimeError: If the hook did not fire during the forward pass.
    """
    print("Analyzing activation spectrum...")
    global activation_output

    # Capture into a call-local list so a previous call's activations can
    # never be mistaken for this call's result.
    captured = []

    def capture_output(module, inputs, output):
        captured.append(output.detach())

    handle = model.transformer_block.register_forward_hook(capture_output)

    model.eval()
    try:
        with torch.no_grad():
            X, _ = generate_associative_recall_batch(config)
            model(X.to(device))
    finally:
        # Do not leave the hook attached if the forward pass fails.
        handle.remove()

    if not captured:
        raise RuntimeError("forward hook on model.transformer_block did not fire")

    # Kept in sync for any code that still reads the module-level variable.
    activation_output = captured[-1]

    activations_matrix = activation_output.reshape(-1, config.EMBED_DIM)
    # Only the singular values are needed, so skip computing U and V.
    singular_values = torch.linalg.svdvals(activations_matrix)
    return singular_values.cpu().numpy()

### 7. Main Execution and Data Collection
This block runs the full experiment for both model variants and stores the results.

In [None]:


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def _run_variant(cfg, use_layernorm):
    """Train one model variant, then probe its loss landscape and activations."""
    trained, losses = train_model(cfg, use_layernorm=use_layernorm, device=device)
    offsets, landscape = get_loss_landscape(trained, cfg, device)
    spectrum = get_activation_spectrum(trained, cfg, device)
    return trained, losses, offsets, landscape, spectrum


# Variant 1: with LayerNorm, default settings.
config = Config()
model_ln, loss_ln, alphas_ln, landscape_ln, spectrum_ln = _run_variant(
    config, use_layernorm=True
)

# Variant 2: without LayerNorm. The learning rate is overridden on this
# Config instance only.
# NOTE(review): this run therefore uses 5e-4 while the LayerNorm run uses the
# default 1e-4 -- confirm the difference is intended.
config = Config()
config.LEARNING_RATE = 5e-4
model_no_ln, loss_no_ln, alphas_no_ln, landscape_no_ln, spectrum_no_ln = _run_variant(
    config, use_layernorm=False
)

### 8. Plotting and Visualization
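This cell generates the plots corresponding to Figure 1 in the paper, laid out here as a single 1x3 row of panels (training loss, loss-landscape slice, activation spectrum) rather than a 2x2 grid.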

In [None]:
plt.style.use('seaborn-v0_8-whitegrid')
# A single row of three panels is used here for clarity.
fig, axes = plt.subplots(1, 3, figsize=(24, 7), sharey=False)
fig.suptitle("Analysis of Layer Normalization's Implicit Regularization", fontsize=20, weight='bold')

# One entry per panel, left to right. Each series is (positional plot
# arguments, keyword styling); the flags select the per-panel axis tweaks.
panels = [
    {
        'series': [
            ((loss_ln,), dict(label='With LayerNorm', color='blue', linewidth=2)),
            ((loss_no_ln,), dict(label='Without LayerNorm', color='red', linewidth=2, alpha=0.9)),
        ],
        'title': 'Training Loss Curves',
        'xlabel': 'Training Step',
        'ylabel': 'Cross-Entropy Loss (Log Scale)',
        'log_y': True,
        'sci_y': False,
    },
    {
        'series': [
            ((alphas_ln, landscape_ln),
             dict(label='With LayerNorm', color='blue', marker='o', markersize=5)),
            ((alphas_no_ln, landscape_no_ln),
             dict(label='Without LayerNorm', color='red', marker='x', markersize=5)),
        ],
        'title': '1D Loss Landscape Slice',
        'xlabel': 'Perturbation along Random Direction (α)',
        'ylabel': 'Loss',
        'log_y': False,
        'sci_y': True,
    },
    {
        'series': [
            ((spectrum_ln,), dict(label='With LayerNorm (Isotropic)', color='blue', linewidth=2)),
            ((spectrum_no_ln,), dict(label='Without LayerNorm (Anisotropic)', color='red', linewidth=2)),
        ],
        'title': 'Activation Singular Value Spectrum',
        'xlabel': 'Singular Value Index',
        'ylabel': 'Singular Value Magnitude (Log Scale)',
        'log_y': True,
        'sci_y': False,
    },
]

for ax, panel in zip(axes, panels):
    for args, style in panel['series']:
        ax.plot(*args, **style)

    ax.set_title(panel['title'], fontsize=16)
    ax.set_xlabel(panel['xlabel'], fontsize=12)
    ax.set_ylabel(panel['ylabel'], fontsize=12)
    if panel['log_y']:
        ax.set_yscale('log')
    ax.legend(fontsize=11)
    if panel['log_y']:
        # Show minor grid lines as well on the log-scaled panels.
        ax.grid(True, which="both", ls="--")
    if panel['sci_y']:
        ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))

# Leave room at the top for the figure-level title.
plt.tight_layout(rect=[0, 0.03, 1, 0.93])
plt.show()