# Experiment 3: The Decisive Test - Decomposing the Normalization Mechanism

### Summary
This notebook reproduces the third and final synthetic experiment from the paper, **"Stability and Expression: The Dual-Mechanism of Normalization in Deep Learning."** This is the decisive test designed to isolate and prove the function of the two core components of our proposed dual-mechanism: the geometric constraint and the learnable affine transformation.

### Hypothesis
This experiment tests the hypothesis that normalization's success is a partnership between two parts:
1.  **The Geometric Stabilizer:** The core act of normalizing activation vectors, which prevents representational collapse and guides the optimizer to a flat region of the loss landscape.
2.  **The Expressive Engine:** The learnable affine parameters (gain and bias), which provide the model the freedom to find a high-performance solution *within* that stable region.

To test this, we introduce a **"Manual L2 Norm"** model. We predict:
- It will exhibit the **same stability** as LayerNorm/BatchNorm (flat landscape, isotropic representations) because it shares the geometric constraint.
- It will **fail to learn the task** (high final loss) because it lacks the learnable affine parameters (the expressive engine).

### Methodology
We compare four minimalist Transformer models:
1.  LayerNorm (Complete Mechanism)
2.  BatchNorm (Complete Mechanism)
3.  **Manual L2 Norm (Stabilizer Only)**
4.  No Norm (Unconstrained Baseline)

A successful outcome, where Manual L2 Norm shows stability but poor performance, would provide strong causal evidence for our two-part hypothesis.

### 1. Setup and Imports

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import copy

### 2. Configuration
Parameters for the synthetic experiment are defined here.

In [None]:


class Config:
    """Hyperparameters shared by every model variant in this experiment."""

    # --- Task definition ---
    N_ENTITIES = 20
    M_ATTRIBUTES = 20
    THEORETICAL_RANK = N_ENTITIES + M_ATTRIBUTES
    # N_ENTITIES * M_ATTRIBUTES key ids, plus 2 ids reserved for PAD and MASK.
    VOCAB_SIZE = 2 + N_ENTITIES * M_ATTRIBUTES

    # --- Model shape ---
    EMBED_DIM = 128
    NUM_HEADS = 1
    SEQ_LENGTH = 1 + THEORETICAL_RANK

    # --- Optimisation ---
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 128
    TRAINING_STEPS = 4000

    # --- Loss-landscape probe ---
    LANDSCAPE_RESOLUTION = 40  # number of alpha values sampled
    LANDSCAPE_RANGE = 0.5      # alphas span [-LANDSCAPE_RANGE, +LANDSCAPE_RANGE]


# Default configuration instance used by the cells below.
config = Config()

### 3. Synthetic Data Generation


In [None]:

def generate_associative_recall_batch(config):
    """Build one random batch for the associative-recall task.

    For every row, ``THEORETICAL_RANK`` distinct key ids are drawn from
    ``range(N_ENTITIES * M_ATTRIBUTES)``. The first drawn id is the query: it
    is written to the last sequence position and is also the target label.
    The remaining ids fill the sequence from position 0 onwards; positions
    that are not written keep their zero initialisation.

    Returns:
        (X, Y): ``X`` is a long tensor of shape (BATCH_SIZE, SEQ_LENGTH) and
        ``Y`` a long tensor of shape (BATCH_SIZE,).
    """
    num_keys = config.N_ENTITIES * config.M_ATTRIBUTES
    batch_size = config.BATCH_SIZE

    X = torch.zeros(batch_size, config.SEQ_LENGTH, dtype=torch.long)
    Y = torch.zeros(batch_size, dtype=torch.long)

    for row in range(batch_size):
        # Sampling without replacement keeps all ids within a row distinct.
        drawn = np.random.choice(num_keys, config.THEORETICAL_RANK, replace=False)
        query = int(drawn[0])
        context = torch.from_numpy(drawn[1:])

        X[row, :context.numel()] = context
        X[row, -1] = query
        Y[row] = query

    return X, Y

### 4. Flexible Transformer Model
This version of the Transformer block includes the crucial 'manual_l2' option.
This operation has no learnable parameters: it only rescales each activation vector to unit L2 length, so it enforces the geometric constraint without any gain or bias.

In [None]:


class MinimalistTransformerBlock(nn.Module):
    """Self-attention + feed-forward block with a switchable normalization.

    The chosen normalization is applied after each residual addition.
    ``normalization_type`` may be ``'layernorm'``, ``'batchnorm'`` or
    ``'manual_l2'`` (rescale every vector to unit L2 length, no learnable
    parameters). Any other value leaves the activations untouched.
    """

    def __init__(self, embed_dim, num_heads, normalization_type='layernorm'):
        super().__init__()
        self.normalization_type = normalization_type
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

        hidden_dim = 4 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

        # Only the parametric variants own norm modules.
        if normalization_type == 'layernorm':
            norm_cls = nn.LayerNorm
        elif normalization_type == 'batchnorm':
            norm_cls = nn.BatchNorm1d
        else:
            norm_cls = None
        if norm_cls is not None:
            self.norm1 = norm_cls(embed_dim)
            self.norm2 = norm_cls(embed_dim)

    def _apply_norm(self, x, norm_layer_idx):
        """Normalize ``x`` with the configured scheme (index 1 -> norm1, else norm2)."""
        kind = self.normalization_type

        if kind == 'manual_l2':
            # Clamp the length so an all-zero vector does not divide by zero.
            lengths = torch.norm(x, p=2, dim=-1, keepdim=True)
            return x / lengths.clamp(min=1e-9)

        if kind not in ('layernorm', 'batchnorm'):
            return x  # 'none' case

        layer = self.norm1 if norm_layer_idx == 1 else self.norm2
        if kind == 'layernorm':
            return layer(x)

        # BatchNorm1d normalizes over dim 1, so swap the last two axes
        # around the call and swap them back afterwards.
        return layer(x.permute(0, 2, 1)).permute(0, 2, 1)

    def forward(self, x):
        """Run attention and feed-forward sublayers, each followed by the norm."""
        attended, _ = self.attention(x, x, x)
        x = self._apply_norm(x + attended, 1)
        x = self._apply_norm(x + self.ffn(x), 2)
        return x

class TransformerModel(nn.Module):
    """Token embedding -> one MinimalistTransformerBlock -> vocabulary head.

    The prediction is read from the representation at the last sequence
    position only.
    """

    def __init__(self, config, normalization_type='layernorm'):
        super().__init__()
        dim, vocab = config.EMBED_DIM, config.VOCAB_SIZE
        self.token_embeddings = nn.Embedding(vocab, dim)
        self.transformer_block = MinimalistTransformerBlock(
            dim, config.NUM_HEADS, normalization_type
        )
        self.to_vocab = nn.Linear(dim, vocab)

    def forward(self, x):
        """Map token ids of shape (batch, seq) to logits of shape (batch, VOCAB_SIZE)."""
        hidden = self.transformer_block(self.token_embeddings(x))
        # Only the final position is decoded.
        return self.to_vocab(hidden[:, -1, :])

### 5. Training Loop

In [None]:


def train_model(config, normalization_type, device):
    """Train a fresh TransformerModel on freshly generated batches.

    Args:
        config: object providing LEARNING_RATE, TRAINING_STEPS and the fields
            needed by the model and the batch generator.
        normalization_type: normalization variant forwarded to the model.
        device: torch device on which the model and batches are placed.

    Returns:
        (model, loss_history): the trained model and a list holding the
        training loss of every step.
    """
    print(f"\n--- Training model with: {normalization_type.upper()} ---")
    model = TransformerModel(config, normalization_type=normalization_type).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    loss_history = []
    model.train()
    for step in range(config.TRAINING_STEPS):
        inputs, targets = generate_associative_recall_batch(config)
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        loss_history.append(step_loss)

        # Progress report every 500 steps.
        if step % 500 == 0:
            print(f"Step {step}/{config.TRAINING_STEPS}, Loss: {step_loss:.4f}")

    return model, loss_history

### 6. Analysis Functions (Corrected)
The get_loss_landscape function has been corrected to define 'logits' inside the loop.

In [None]:


# Module-level slot for the most recent transformer-block output captured by
# get_activation_spectrum. It is None until that function has stored a value.
activation_output = None

def get_loss_landscape(model, config, device):
    """Compute a 1D slice of the loss landscape around the model's weights.

    A single random direction (normalized to unit L2 norm over all trainable
    parameters together) is drawn. For every ``alpha`` in
    ``linspace(-LANDSCAPE_RANGE, LANDSCAPE_RANGE, LANDSCAPE_RESOLUTION)`` the
    trainable weights are set to ``base + alpha * direction`` and the loss is
    evaluated on one fixed batch.

    Args:
        model: trained model exposing ``transformer_block.normalization_type``.
        config: object providing LANDSCAPE_RANGE, LANDSCAPE_RESOLUTION and
            the fields needed by the batch generator.
        device: torch device holding the model.

    Returns:
        (alphas, losses): the numpy array of step sizes and the list of loss
        values measured at each of them.

    The model is left in eval mode with its original weights restored, even
    if the evaluation raises.
    """
    norm_type_name = model.transformer_block.normalization_type.upper()
    print(f"Analyzing loss landscape for {norm_type_name}...")
    final_state_dict = copy.deepcopy(model.state_dict())

    trainable = [p for p in model.parameters() if p.requires_grad]
    base_weights = [p.detach().clone() for p in trainable]

    direction = [torch.randn_like(p) for p in trainable]
    norm = np.sqrt(sum(torch.sum(d**2).item() for d in direction))
    direction = [d / norm for d in direction]

    losses = []
    alphas = np.linspace(-config.LANDSCAPE_RANGE, config.LANDSCAPE_RANGE, config.LANDSCAPE_RESOLUTION)
    criterion = nn.CrossEntropyLoss()

    model.eval()
    try:
        with torch.no_grad():
            X, Y = generate_associative_recall_batch(config)  # one fixed batch for every alpha
            X, Y = X.to(device), Y.to(device)

            for alpha in alphas:
                step = float(alpha)
                # Overwrite from the saved base weights instead of adding in
                # place, so perturbations do not accumulate across alphas and
                # each loss really belongs to the alpha it is reported for.
                for param, origin, d in zip(trainable, base_weights, direction):
                    param.copy_(origin + step * d)

                logits = model(X)
                losses.append(criterion(logits, Y).item())
    finally:
        model.load_state_dict(final_state_dict)  # Restore original weights

    return alphas, losses

def get_activation_spectrum(model, config, device):
    """Return the singular values of the transformer block's output activations.

    A forward hook on ``model.transformer_block`` records the block output for
    one generated batch. The output is flattened to ``(-1, EMBED_DIM)`` and
    its singular values are returned in descending order.

    Args:
        model: model exposing a ``transformer_block`` submodule.
        config: object providing EMBED_DIM and the fields needed by the batch
            generator.
        device: torch device holding the model.

    Returns:
        1D numpy array of singular values.

    Raises:
        RuntimeError: if the forward pass completed without the hook firing.

    The model is left in eval mode. The captured activations are also stored
    in the module-level ``activation_output`` variable.
    """
    norm_type_name = model.transformer_block.normalization_type.upper()
    print(f"Analyzing activation spectrum for {norm_type_name}...")
    global activation_output

    # Capture into a local list so a failed forward pass can never leave us
    # analysing activations left over from a previous call.
    captured = []

    def hook(module, inputs, output):
        captured.append(output.detach())

    handle = model.transformer_block.register_forward_hook(hook)

    model.eval()
    try:
        with torch.no_grad():
            X, _ = generate_associative_recall_batch(config)
            model(X.to(device))
    finally:
        # Always detach the hook, even when the forward pass raises.
        handle.remove()

    if not captured:
        raise RuntimeError("forward hook did not fire; no activations were captured")

    activation_output = captured[-1]
    activations_matrix = activation_output.reshape(-1, config.EMBED_DIM)
    # Only the singular values are needed, so skip computing U and V.
    S = torch.linalg.svdvals(activations_matrix)
    return S.cpu().numpy()

### 7. Main Execution and Data Collection
This block runs the full experiment for all four model variants.

In [None]:


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

norm_types = ['layernorm', 'batchnorm', 'manual_l2', 'none']
results = {}

for norm_type in norm_types:
    # A fresh Config per variant: the override below is set on this instance only.
    config = Config()
    if norm_type == 'none':
        config.LEARNING_RATE = 5e-4

    # Train, then run both analyses on the trained model.
    model, loss_history = train_model(config, normalization_type=norm_type, device=device)
    alphas, landscape = get_loss_landscape(model, config, device)
    spectrum = get_activation_spectrum(model, config, device)

    results[norm_type] = dict(
        loss=loss_history,
        alphas=alphas,
        landscape=landscape,
        spectrum=spectrum,
    )

### 8. Plotting and Visualization
This cell generates the 1x3 plot corresponding to Figure 3 in the paper.

In [None]:


# Draw the three result panels side by side: training loss, loss-landscape
# slice, and activation singular-value spectrum.
plt.style.use('seaborn-v0_8-whitegrid')
fig, axes = plt.subplots(1, 3, figsize=(24, 7), sharey=False)
fig.suptitle('The Decisive Test: Decomposing the Normalization Mechanism', fontsize=20, weight='bold')

# Per-variant styling shared by all three panels.
colors = {'layernorm': 'blue', 'batchnorm': 'green', 'manual_l2': 'purple', 'none': 'red'}
labels = {
    'layernorm':   'LayerNorm',
    'batchnorm':   'BatchNorm',
    'manual_l2':   'Manual L2 Norm (Pure Geometry)',
    'none':        'No Norm (Baseline)'
}
linestyles = {'layernorm': '-', 'batchnorm': '--', 'manual_l2': ':', 'none': '-'}

# Plot 1: Training Loss Curves
ax = axes[0]
for norm_type in norm_types:
    ax.plot(results[norm_type]['loss'], label=labels[norm_type], color=colors[norm_type],
            linestyle=linestyles[norm_type], linewidth=2.5)
ax.set_title('Training Loss Curves', fontsize=16)
ax.set_xlabel('Training Step', fontsize=12)
ax.set_ylabel('Cross-Entropy Loss (Log Scale)', fontsize=12)
ax.set_yscale('log')
ax.legend(fontsize=11)
ax.grid(True, which="both", ls="--")

# Plot 2: Loss Landscape Flatness
ax = axes[1]
for norm_type in norm_types:
    ax.plot(results[norm_type]['alphas'], results[norm_type]['landscape'],
            label=labels[norm_type], color=colors[norm_type], marker='o',
            markersize=5, alpha=0.9, linestyle=linestyles[norm_type])
ax.set_title('1D Loss Landscape Slice (Flatness)', fontsize=16)
# The Greek letter alpha is written as an escape so the label survives any
# re-encoding of this file (it previously rendered as mojibake).
ax.set_xlabel('Perturbation along Random Direction (\u03b1)', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.legend(fontsize=11)
ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))

# Plot 3: Activation Singular Value Spectrum
ax = axes[2]
for norm_type in norm_types:
    ax.plot(results[norm_type]['spectrum'], label=labels[norm_type], color=colors[norm_type],
            linestyle=linestyles[norm_type], linewidth=2.5)
ax.set_title('Activation Singular Value Spectrum (Isotropy)', fontsize=16)
ax.set_xlabel('Singular Value Index', fontsize=12)
ax.set_ylabel('Singular Value Magnitude (Log Scale)', fontsize=12)
ax.set_yscale('log')
ax.legend(fontsize=11)
ax.grid(True, which="both", ls="--")

# Leave room at the top for the figure-level title.
plt.tight_layout(rect=[0, 0.03, 1, 0.93])
plt.show()