# Experiment 2: Generalizing the Principle Beyond LayerNorm

### Summary
This notebook reproduces the second experiment from the paper, **"Stability and Expression: The Dual-Mechanism of Normalization in Deep Learning."** Having established the foundational effect of LayerNorm in the first experiment, this notebook investigates its specificity. We compare LayerNorm against Batch Normalization (BatchNorm) and a baseline with no normalization.

### Hypothesis
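This experiment tests the hypothesis that the implicit regularization effect is not an idiosyncratic property of LayerNorm, but stems from a more general principle shared by other normalization layers: **the active constraint of activation statistics on an instance-by-instance basis.** (Strictly, only LayerNorm normalizes each token on its own; BatchNorm constrains the same statistics per feature across the batch and sequence positions, so including it also tests whether the per-instance aspect is essential.) We predict that: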
1.  **Equivalence in Properties:** LayerNorm and BatchNorm will exhibit nearly identical behavior. Both will converge to flat minima and maintain isotropic representations.
2.  **Contrast with Baseline:** Both normalization methods will stand in stark contrast to the unconstrained "No Norm" model, which is expected to converge to a sharp minimum with collapsed representations.

### Methodology
We train three minimalist Transformer models on the synthetic "Low-Rank Associative Recall" task. The models differ in their normalization strategy; in addition, the No Norm baseline is trained with a higher learning rate (5e-4 instead of 1e-4):
1.  LayerNorm
2.  BatchNorm
3.  No Norm (Baseline)

We then perform the same analysis as in Experiment 1, plotting the training loss, the loss landscape flatness, and the activation spectrum for all three models to facilitate a direct comparison.

### 1. Setup and Imports

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import copy

### 2. Configuration
Parameters for the synthetic experiment are defined here.

In [None]:


class Config:
    """Hyperparameters for the synthetic low-rank associative recall experiment.

    Every value is a class attribute. Instances are created so that a single
    run can override a field (e.g. ``LEARNING_RATE``) on the instance without
    changing the class-level defaults seen by other runs.
    """

    # Task definition.
    N_ENTITIES = 20
    M_ATTRIBUTES = 20
    # One id per (entity, attribute) pair plus two extra ids.
    # NOTE(review): the batch generator in this notebook never emits the two
    # extra ids; presumably they are reserved special tokens — confirm.
    VOCAB_SIZE = (N_ENTITIES * M_ATTRIBUTES) + 2
    THEORETICAL_RANK = N_ENTITIES + M_ATTRIBUTES
    # One slot longer than the number of ids sampled per sequence.
    SEQ_LENGTH = THEORETICAL_RANK + 1

    # Model size.
    EMBED_DIM = 128
    NUM_HEADS = 1

    # Optimisation.
    LEARNING_RATE = 1e-4
    BATCH_SIZE = 128
    TRAINING_STEPS = 4000

    # 1-D loss-landscape probe: number of points and half-width of the range.
    LANDSCAPE_RESOLUTION = 40
    LANDSCAPE_RANGE = 0.5


config = Config()

### 3. Synthetic Data Generation
This function creates batches for the "Low-Rank Associative Recall" task.

In [None]:


def generate_associative_recall_batch(config):
    """Sample one batch for the low-rank associative recall task.

    For every row, ``config.THEORETICAL_RANK`` distinct ids are drawn from
    ``[0, N_ENTITIES * M_ATTRIBUTES)`` with NumPy's global RNG. The first id
    drawn is the query: it is written to the last position of the row and is
    also that row's target. The remaining ids are written from position 0
    onward; any positions between them and the last slot keep the value 0.

    NOTE(review): because the query is taken out of the drawn set, it never
    appears in the context and the target equals the last input token; the
    fill value 0 is also a valid token id. Confirm both are intended.

    Args:
        config: object with ``N_ENTITIES``, ``M_ATTRIBUTES``,
            ``THEORETICAL_RANK``, ``SEQ_LENGTH`` and ``BATCH_SIZE``.

    Returns:
        ``(X, Y)``: ``X`` is a ``(BATCH_SIZE, SEQ_LENGTH)`` long tensor of
        token ids, ``Y`` a ``(BATCH_SIZE,)`` long tensor of target ids.
    """
    n_keys = config.N_ENTITIES * config.M_ATTRIBUTES
    inputs = torch.zeros((config.BATCH_SIZE, config.SEQ_LENGTH), dtype=torch.long)
    targets = torch.zeros(config.BATCH_SIZE, dtype=torch.long)

    for row in range(config.BATCH_SIZE):
        draw = np.random.choice(n_keys, config.THEORETICAL_RANK, replace=False)
        query, context = int(draw[0]), draw[1:]
        inputs[row, :len(context)] = torch.from_numpy(context)
        inputs[row, -1] = query
        targets[row] = query

    return inputs, targets

### 4. Flexible Transformer Model
The Transformer block is updated to support different normalization types:
'layernorm', 'batchnorm', or 'none'. This flexibility is key to the experiment.


In [None]:

class MinimalistTransformerBlock(nn.Module):
    """A post-norm Transformer block with a selectable normalization layer.

    Data flow: ``x -> x + Attn(x) -> norm1 -> x + FFN(x) -> norm2``.

    Args:
        embed_dim: width of the token representations.
        num_heads: number of attention heads (``nn.MultiheadAttention``
            requires ``embed_dim`` to be divisible by it).
        normalization_type: ``'layernorm'``, ``'batchnorm'`` or ``'none'``.

    Raises:
        ValueError: if ``normalization_type`` is not one of the supported
            values, so a typo cannot silently yield an unnormalized block.
    """

    SUPPORTED_NORMALIZATIONS = ('layernorm', 'batchnorm', 'none')

    def __init__(self, embed_dim, num_heads, normalization_type='layernorm'):
        super().__init__()
        if normalization_type not in self.SUPPORTED_NORMALIZATIONS:
            raise ValueError(
                f"normalization_type must be one of {self.SUPPORTED_NORMALIZATIONS}, "
                f"got {normalization_type!r}"
            )
        self.normalization_type = normalization_type
        # batch_first=True: inputs are laid out as (batch, seq, embed).
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim)
        )
        if normalization_type == 'layernorm':
            self.norm1 = nn.LayerNorm(embed_dim)
            self.norm2 = nn.LayerNorm(embed_dim)
        elif normalization_type == 'batchnorm':
            self.norm1 = nn.BatchNorm1d(embed_dim)
            self.norm2 = nn.BatchNorm1d(embed_dim)
        else:
            # Parameter-free pass-through keeps forward() uniform and adds
            # nothing to state_dict().
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()

    def _normalize(self, norm, x):
        """Apply ``norm`` to ``x`` of shape (batch, seq, embed)."""
        if self.normalization_type == 'batchnorm':
            # BatchNorm1d normalizes channel dim 1, so move embed there and
            # back: statistics are per feature over batch and sequence.
            return norm(x.transpose(1, 2)).transpose(1, 2)
        return norm(x)

    def forward(self, x):
        """Run self-attention and the FFN, each followed by residual + norm."""
        attn_output, _ = self.attention(x, x, x)
        x = self._normalize(self.norm1, x + attn_output)
        x = self._normalize(self.norm2, x + self.ffn(x))
        return x

class TransformerModel(nn.Module):
    """Token embedding, one ``MinimalistTransformerBlock``, and a vocab head.

    Only the representation at the final sequence position is projected to
    vocabulary logits, so ``forward`` yields one prediction per sequence.

    Args:
        config: object providing ``VOCAB_SIZE``, ``EMBED_DIM`` and ``NUM_HEADS``.
        normalization_type: passed through unchanged to the Transformer block.
    """

    def __init__(self, config, normalization_type='layernorm'):
        super().__init__()
        vocab, width = config.VOCAB_SIZE, config.EMBED_DIM
        # Submodules are built in this fixed order so parameter
        # initialisation consumes the RNG in the same sequence.
        self.token_embeddings = nn.Embedding(vocab, width)
        self.transformer_block = MinimalistTransformerBlock(
            width, config.NUM_HEADS, normalization_type
        )
        self.to_vocab = nn.Linear(width, vocab)

    def forward(self, x):
        """Map token ids ``x`` (assumed (batch, seq)) to last-position logits."""
        hidden = self.transformer_block(self.token_embeddings(x))
        last_position = hidden[:, -1, :]
        return self.to_vocab(last_position)

### 5. Training Loop

In [None]:


def train_model(config, normalization_type, device, log_every=500):
    """Train a fresh ``TransformerModel`` on freshly sampled batches.

    Args:
        config: experiment configuration; this function reads
            ``LEARNING_RATE`` and ``TRAINING_STEPS`` (the model and the batch
            generator read further fields).
        normalization_type: forwarded to ``TransformerModel``.
        device: torch device that the model and every batch are moved to.
        log_every: print the current loss every ``log_every`` steps (default
            500). A non-positive value disables progress printing.

    Returns:
        ``(model, loss_history)`` where ``loss_history`` has one float per step.
    """
    print(f"\n--- Training model with: {normalization_type.upper()} ---")
    model = TransformerModel(config, normalization_type=normalization_type).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    # Nothing inside the loop switches to eval mode, so set train mode once.
    model.train()
    loss_history = []
    for step in range(config.TRAINING_STEPS):
        X, Y = generate_associative_recall_batch(config)
        X, Y = X.to(device), Y.to(device)

        optimizer.zero_grad()
        loss = criterion(model(X), Y)
        loss.backward()
        optimizer.step()

        # One device->host sync per step, reused for history and logging.
        loss_value = loss.item()
        loss_history.append(loss_value)

        if log_every > 0 and step % log_every == 0:
            print(f"Step {step}/{config.TRAINING_STEPS}, Loss: {loss_value:.4f}")

    return model, loss_history

### 6. Analysis Functions

In [None]:


# Module-level slot that the forward hook registered by
# get_activation_spectrum() writes the transformer block's (detached) output
# into via `global activation_output`; None until that function has run.
activation_output = None

def get_loss_landscape(model, config, device, batch=None):
    """Evaluate the loss along a 1-D random slice through the trained weights.

    A single Gaussian direction ``d`` over all trainable parameters is drawn
    and scaled to unit global L2 norm. For each ``alpha`` in
    ``linspace(-LANDSCAPE_RANGE, LANDSCAPE_RANGE, LANDSCAPE_RESOLUTION)``, the
    loss is computed at ``theta + alpha * d``, where ``theta`` are the trained
    weights. Every point is measured from ``theta`` itself; perturbations do not
    accumulate across alphas. The model's state dict is restored afterwards,
    even if evaluation raises. The model is left in eval mode.

    NOTE(review): ``d`` is normalized jointly over all parameters, not per
    filter. Normalized models are scale-invariant in many weights, so flatness
    comparisons across variants depend on weight scale. Confirm this matches
    the intended protocol. With an even ``LANDSCAPE_RESOLUTION`` the grid
    also skips ``alpha = 0``.

    Args:
        model: a ``TransformerModel`` (``model.transformer_block.normalization_type``
            is read for logging).
        config: provides ``LANDSCAPE_RANGE`` and ``LANDSCAPE_RESOLUTION``, plus
            the batch-generator fields when ``batch`` is None.
        device: device the evaluation batch is moved to.
        batch: optional ``(X, Y)`` pair to evaluate on. If None, a fresh batch
            is sampled. Passing the same batch lets several models be compared
            on identical data.

    Returns:
        ``(alphas, losses)``: the numpy grid of step sizes and a list of floats.
    """
    norm_type_name = model.transformer_block.normalization_type.upper()
    print(f"Analyzing loss landscape for {norm_type_name}...")
    final_state_dict = copy.deepcopy(model.state_dict())

    params = [p for p in model.parameters() if p.requires_grad]
    direction = [torch.randn_like(p) for p in params]
    norm = np.sqrt(sum(torch.sum(d ** 2).item() for d in direction))
    direction = [d / norm for d in direction]
    # Trained weights; every perturbed point is computed from these.
    base_params = [p.detach().clone() for p in params]

    losses = []
    alphas = np.linspace(-config.LANDSCAPE_RANGE, config.LANDSCAPE_RANGE, config.LANDSCAPE_RESOLUTION)
    criterion = nn.CrossEntropyLoss()

    model.eval()
    with torch.no_grad():
        if batch is None:
            batch = generate_associative_recall_batch(config)
        X, Y = (t.to(device) for t in batch)

        try:
            for alpha in alphas:
                for param, base, d in zip(params, base_params, direction):
                    param.copy_(base + float(alpha) * d)
                logits = model(X)
                losses.append(criterion(logits, Y).item())
        finally:
            model.load_state_dict(final_state_dict)

    return alphas, losses

def get_activation_spectrum(model, config, device, batch=None):
    """Return the singular values of the transformer block's output activations.

    A forward hook on ``model.transformer_block`` captures its output for one
    batch. The hook stores it in the module-level ``activation_output`` and is
    always removed afterwards, even if the forward pass raises. The captured
    tensor is reshaped to ``(-1, config.EMBED_DIM)``, stacking every position of
    every sequence as a row, and its singular values are returned in
    descending order.

    Args:
        model: a ``TransformerModel`` (``model.transformer_block.normalization_type``
            is read for logging). The model is left in eval mode.
        config: provides ``EMBED_DIM``, plus the batch-generator fields when
            ``batch`` is None.
        device: device the input batch is moved to.
        batch: optional ``(X, Y)`` pair; only ``X`` is used. If None, a fresh
            batch is sampled.

    Returns:
        A 1-D numpy array of singular values.

    Raises:
        RuntimeError: if the hook did not fire during the forward pass.
    """
    norm_type_name = model.transformer_block.normalization_type.upper()
    print(f"Analyzing activation spectrum for {norm_type_name}...")
    global activation_output
    # Clear any capture left over from a previous call, so a stale tensor can
    # never be analysed.
    activation_output = None

    def hook(module, inputs, output):
        global activation_output
        activation_output = output.detach()

    handle = model.transformer_block.register_forward_hook(hook)
    try:
        model.eval()
        with torch.no_grad():
            if batch is None:
                batch = generate_associative_recall_batch(config)
            X = batch[0]
            model(X.to(device))
    finally:
        handle.remove()

    if activation_output is None:
        raise RuntimeError("forward hook on model.transformer_block did not fire")

    activations_matrix = activation_output.reshape(-1, config.EMBED_DIM)
    singular_values = torch.linalg.svdvals(activations_matrix)
    return singular_values.cpu().numpy()

### 7. Main Execution and Data Collection
This block runs the full experiment for all three model variants and stores the results.

In [None]:


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Fixed seed for reproducibility. Torch and NumPy are re-seeded before every
# variant. The batch generator draws from NumPy's global RNG, so all three
# models see the same stream of training batches, and differences between
# them are not caused by different data.
SEED = 0

results = {}
norm_types = ['layernorm', 'batchnorm', 'none']

for norm_type in norm_types:
    torch.manual_seed(SEED)
    np.random.seed(SEED)

    config = Config()
    if norm_type == 'none':
        # NOTE(review): the baseline also gets a 5x larger learning rate, so
        # it differs from the normalized variants in more than normalization.
        config.LEARNING_RATE = 5e-4

    model, loss_history = train_model(config, normalization_type=norm_type, device=device)
    alphas, landscape = get_loss_landscape(model, config, device)
    spectrum = get_activation_spectrum(model, config, device)
    results[norm_type] = {
        "loss": loss_history,
        "alphas": alphas,
        "landscape": landscape,
        "spectrum": spectrum
    }

### 8. Plotting and Visualization
This cell generates the 1x3 plot corresponding to Figure 2 in the paper.

In [None]:


# matplotlib >= 3.6 names this style 'seaborn-v0_8-whitegrid'; older releases
# ship it as 'seaborn-whitegrid'. Fall back so the cell runs on both.
style_name = 'seaborn-v0_8-whitegrid'
if style_name not in plt.style.available:
    style_name = 'seaborn-whitegrid'
plt.style.use(style_name)

fig, axes = plt.subplots(1, 3, figsize=(24, 7), sharey=False)
fig.suptitle('Comparative Analysis of Normalization Techniques', fontsize=20, weight='bold')

colors = {'layernorm': 'blue', 'batchnorm': 'green', 'none': 'red'}
labels = {'layernorm': 'LayerNorm', 'batchnorm': 'BatchNorm', 'none': 'No Norm'}
# Solid, dashed (3 on / 3 off), solid; LayerNorm vs No Norm differ by colour.
linestyles = {'layernorm': '-', 'batchnorm': (0, (3, 3)), 'none': '-'}

# Plot 1: Training Loss Curves
ax = axes[0]
for norm_type in norm_types:
    ax.plot(results[norm_type]['loss'], label=labels[norm_type], color=colors[norm_type],
            linestyle=linestyles[norm_type], linewidth=2.5)
ax.set_title('Training Loss Curves', fontsize=16)
ax.set_xlabel('Training Step', fontsize=12)
ax.set_ylabel('Cross-Entropy Loss (Log Scale)', fontsize=12)
ax.set_yscale('log')
ax.legend(fontsize=11)
ax.grid(True, which="both", ls="--")

# Plot 2: Loss Landscape Flatness
ax = axes[1]
for norm_type in norm_types:
    ax.plot(results[norm_type]['alphas'], results[norm_type]['landscape'],
            label=labels[norm_type], color=colors[norm_type], marker='o',
            markersize=5, alpha=0.9, linestyle=linestyles[norm_type])
ax.set_title('1D Loss Landscape Slice', fontsize=16)
ax.set_xlabel('Perturbation along Random Direction (α)', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.legend(fontsize=11)
ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
ax.grid(True, ls="--")  # match the dashed grids of the other two panels

# Plot 3: Activation Singular Value Spectrum
ax = axes[2]
for norm_type in norm_types:
    ax.plot(results[norm_type]['spectrum'], label=labels[norm_type], color=colors[norm_type],
            linestyle=linestyles[norm_type], linewidth=2.5)
ax.set_title('Activation Singular Value Spectrum', fontsize=16)
ax.set_xlabel('Singular Value Index', fontsize=12)
ax.set_ylabel('Singular Value Magnitude (Log Scale)', fontsize=12)
ax.set_yscale('log')
ax.legend(fontsize=11)
ax.grid(True, which="both", ls="--")

plt.tight_layout(rect=[0, 0.03, 1, 0.93])
plt.show()