<a href="https://colab.research.google.com/github/MLDreamer/AIMathematicallyexplained/blob/main/Value_Matrix_Interactive_Playground.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
Value Matrix as Learned Database: Interactive Playground
Run this in Google Colab to experiment with Value matrices and compression

Author: [Your Name]
Article: [Link to Medium article]
"""

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import seaborn as sns

# Pin the global PyTorch and NumPy random generators so that every run of the
# playground draws the same random tensors.
torch.manual_seed(42)
np.random.seed(42)

# Title banner: a 60-character rule above and below the heading.
print(
    "=" * 60,
    "Value Matrix as Database: Interactive Experiments",
    "=" * 60,
    sep="\n",
)

# ============================================================================
# EXPERIMENT 1: Value Matrix as Lossy Compression
# ============================================================================

def experiment_1_compression():
    """
    Demonstrate how the Value projection compresses token embeddings.

    Random embeddings ``X`` (10 tokens x 512 dims) are projected through a
    random ``W_V`` (512 x 64) to obtain ``V``.  ``V`` is then mapped back with
    the pseudo-inverse of ``W_V`` and the reconstruction error is printed as a
    measure of how lossy the projection is.  The singular values of ``W_V``
    and its effective rank (95% of the squared spectrum) are printed as well.

    Returns:
        tuple: ``(X, V, W_V, S)`` -- the embeddings, their projected values,
        the projection matrix, and the singular values of ``W_V`` in
        descending order.
    """
    print("\n" + "=" * 60)
    print("EXPERIMENT 1: Value Matrix as Lossy Compression")
    print("=" * 60)

    # Original embedding dimension
    d_model = 512
    # Compressed dimension (per head)
    d_v = 64
    # Number of tokens
    n_tokens = 10

    # Create random embeddings (simulating input tokens)
    X = torch.randn(n_tokens, d_model)

    # Create Value projection matrix; the 1/sqrt(d_model) scale keeps each
    # entry of V at roughly unit variance for unit-variance inputs.
    W_V = torch.randn(d_model, d_v) / np.sqrt(d_model)

    # Project to Values
    V = torch.mm(X, W_V)

    # Measure information loss via reconstruction: map V back to d_model
    # dimensions with the Moore-Penrose pseudo-inverse of W_V.
    W_V_pinv = torch.linalg.pinv(W_V)
    X_reconstructed = torch.mm(V, W_V_pinv)

    # Calculate reconstruction error, absolute and relative to signal power
    mse = torch.mean((X - X_reconstructed) ** 2).item()
    relative_error = mse / torch.mean(X ** 2).item()

    print(f"\nOriginal dimension: {d_model}")
    print(f"Compressed dimension: {d_v}")
    print(f"Compression ratio: {d_model / d_v:.1f}x")
    print(f"\nReconstruction MSE: {mse:.6f}")
    print(f"Relative error: {relative_error:.2%}")
    print(f"\nInformation loss: {relative_error:.2%} of original signal")

    # Only the singular values are needed, so skip computing the singular
    # vectors.  (torch.svd is deprecated and returned V rather than Vh.)
    S = torch.linalg.svdvals(W_V)

    print("\nSingular value spectrum (top 10):")
    print(S[:10].numpy())

    # Effective rank: number of leading singular values needed for the
    # cumulative squared spectrum to reach 95% of the total.  Counting the
    # entries still below the threshold and adding one gives that index.
    total_variance = torch.sum(S ** 2)
    cumsum = torch.cumsum(S ** 2, dim=0) / total_variance
    effective_rank = torch.sum(cumsum < 0.95).item() + 1

    print(f"\nEffective rank (95% variance): {effective_rank}")
    print(f"Out of maximum rank: {d_v}")

    return X, V, W_V, S

# ============================================================================
# EXPERIMENT 2: Multi-Head vs Single-Head Value Storage
# ============================================================================

def experiment_2_multihead():
    """
    Contrast one wide Value projection with several narrow per-head ones.

    Projects the same random embeddings once through a single 512-wide
    random matrix and once through eight independent 64-wide random
    matrices, prints the resulting shapes, and reports the mean absolute
    Pearson correlation between the flattened outputs of every pair of heads.

    Returns:
        tuple: ``(V_single, V_multi)`` -- the single-head value tensor and
        the list of per-head value tensors.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("EXPERIMENT 2: Multi-Head Value Storage")
    print(rule)

    d_model, n_tokens = 512, 20
    init_scale = np.sqrt(d_model)
    X = torch.randn(n_tokens, d_model)

    # One head whose value width equals the model width.
    d_v_single = 512
    V_single = torch.mm(X, torch.randn(d_model, d_v_single) / init_scale)

    # Eight heads of width 64 each: 8 * 64 = 512, the same total width.
    n_heads = 8
    d_v_per_head = 64
    V_multi = [
        torch.mm(X, torch.randn(d_model, d_v_per_head) / init_scale)
        for _ in range(n_heads)
    ]
    V_multi_concat = torch.cat(V_multi, dim=1)

    print("\nSingle head:")
    print(f"  V shape: {V_single.shape}")
    print(f"  Total dimensions: {d_v_single}")

    print(f"\nMulti-head ({n_heads} heads):")
    print(f"  V shape per head: {d_v_per_head}")
    print(f"  V concatenated shape: {V_multi_concat.shape}")
    print(f"  Total dimensions: {n_heads * d_v_per_head}")

    # Diversity measure: |correlation| for every unordered pair of heads,
    # each head being flattened to a single long vector first.
    flat_heads = [head.flatten() for head in V_multi]
    correlations = [
        abs(
            torch.corrcoef(
                torch.stack([flat_heads[a], flat_heads[b]])
            )[0, 1].item()
        )
        for a in range(n_heads)
        for b in range(a + 1, n_heads)
    ]
    avg_correlation = np.mean(correlations)

    print(f"\nAverage absolute correlation between heads: {avg_correlation:.3f}")
    print("Lower correlation = more diverse information storage")

    verdict = (
        "✓ Heads are storing diverse information!"
        if avg_correlation < 0.3
        else "⚠ Heads might be redundant"
    )
    print(verdict)

    return V_single, V_multi

# ============================================================================
# EXPERIMENT 3: Retrieval with Different Attention Patterns
# ============================================================================

def experiment_3_retrieval():
    """
    Show how attention weights select what is read out of a Value matrix.

    Builds an 8 x 64 Value matrix whose rows carry hand-placed "person",
    "location" and "action" patterns plus two low-amplitude noise rows, then
    forms the attention-weighted sum of the rows under four weightings
    (uniform, person-only, location-only, person + action) and prints the
    norm of each result overall and within each pattern's dimension range.

    Returns:
        tuple: ``(V, alphas)`` where ``alphas`` is the list
        ``[alpha_uniform, alpha_person, alpha_location, alpha_mixed]``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("EXPERIMENT 3: Information Retrieval via Attention")
    print(rule)

    n_tokens = 8
    d_v = 64

    # (token row, dimension range, strength) for each hand-placed pattern:
    # person in dims 0-31, location in dims 16-47, action in dims 32-63.
    placements = (
        (0, slice(None, 32), 1.0),
        (1, slice(None, 32), 0.8),
        (2, slice(16, 48), 1.0),
        (3, slice(16, 48), 0.7),
        (4, slice(32, None), 1.0),
        (5, slice(32, None), 0.9),
    )
    V = torch.zeros(n_tokens, d_v)
    for row, dims, strength in placements:
        V[row, dims] = strength

    # The last two rows carry only small Gaussian noise.
    V[6:] += torch.randn(2, d_v) * 0.1

    print("\nValue matrix structure:")
    print("  Tokens 0-1: Person information (dimensions 0-31)")
    print("  Tokens 2-3: Location information (dimensions 16-47)")
    print("  Tokens 4-5: Action information (dimensions 32-63)")
    print("  Tokens 6-7: Noise")

    def focused(*spans):
        """Return a zero weight vector with ``weight`` written on each span."""
        weights = torch.zeros(n_tokens)
        for start, stop, weight in spans:
            weights[start:stop] = weight
        return weights

    def retrieve(alpha):
        """Attention-weighted sum of the rows of V."""
        return torch.mv(V.t(), alpha)

    alpha_uniform = torch.ones(n_tokens) / n_tokens
    alpha_person = focused((0, 2, 0.5))
    alpha_location = focused((2, 4, 0.5))
    alpha_mixed = focused((0, 2, 0.3), (4, 6, 0.2))

    # Compute every readout before printing anything.
    scenarios = [
        (
            "\nUniform attention (no focus):",
            retrieve(alpha_uniform),
            "  → Blurry, unfocused retrieval",
        ),
        (
            "\nFocused on person tokens:",
            retrieve(alpha_person),
            "  → Strong person information retrieved!",
        ),
        (
            "\nFocused on location tokens:",
            retrieve(alpha_location),
            "  → Strong location information retrieved!",
        ),
        (
            "\nMixed focus (person + action):",
            retrieve(alpha_mixed),
            "  → Both person and action information retrieved!",
        ),
    ]

    print("\n" + "-" * 60)
    print("Retrieval Results:")
    print("-" * 60)

    for heading, retrieved, verdict in scenarios:
        print(heading)
        print(f"  Output norm: {torch.norm(retrieved):.3f}")
        print(f"  Person dims [0-31]: {torch.norm(retrieved[:32]):.3f}")
        print(f"  Location dims [16-47]: {torch.norm(retrieved[16:48]):.3f}")
        print(f"  Action dims [32-63]: {torch.norm(retrieved[32:]):.3f}")
        print(verdict)

    return V, [alpha_uniform, alpha_person, alpha_location, alpha_mixed]

# ============================================================================
# EXPERIMENT 4: KV Cache Memory Calculation
# ============================================================================

def experiment_4_kv_cache():
    """
    Print KV cache sizes for a few model configurations.

    For each configuration and sequence length the cache size is computed as
    ``2 * seq_len * d_v * n_kv_heads * n_layers * bytes_per_value`` (the
    leading 2 counts one K and one V tensor) and printed in MB.  A second
    table compares standard attention with Multi-Query (1 KV head) and
    Grouped-Query (4 KV heads) attention for LLaMA-7B at 2048 tokens.

    Returns:
        None.  All results are printed.
    """
    rule = "=" * 60
    bytes_per_mb = 1024 ** 2
    precision_bytes = 2  # fp16

    def cache_bytes(cfg, seq_len, n_kv_heads):
        """Bytes held by the K and V caches together for one sequence."""
        return 2 * seq_len * cfg['d_v'] * n_kv_heads * cfg['n_layers'] * precision_bytes

    print("\n" + rule)
    print("EXPERIMENT 4: KV Cache Memory Requirements")
    print(rule)

    # Model configurations
    configs = {
        "GPT-2 Small": {"n_layers": 12, "n_heads": 12, "d_v": 64},
        "GPT-2 Large": {"n_layers": 36, "n_heads": 20, "d_v": 64},
        "LLaMA-7B": {"n_layers": 32, "n_heads": 32, "d_v": 128},
        "LLaMA-13B": {"n_layers": 40, "n_heads": 40, "d_v": 128},
    }
    sequence_lengths = [512, 1024, 2048, 4096]

    print("\nMemory per sequence (in MB):")
    print("-" * 60)

    for model_name, config in configs.items():
        print(f"\n{model_name}:")
        print(f"  Layers: {config['n_layers']}, Heads: {config['n_heads']}, d_v: {config['d_v']}")

        for seq_len in sequence_lengths:
            memory_mb = cache_bytes(config, seq_len, config['n_heads']) / bytes_per_mb
            print(f"    Seq len {seq_len}: {memory_mb:.1f} MB")

    # Multi-Query / Grouped-Query Attention comparison
    print("\n" + rule)
    print("Multi-Query Attention (MQA) Savings:")
    print(rule)

    model_name = "LLaMA-7B"
    config = configs[model_name]
    seq_len = 2048
    n_kv_heads_gqa = 4

    # Standard attention keeps one KV head per query head; MQA shares a
    # single KV head across all query heads; GQA shares a small group.
    standard_memory = cache_bytes(config, seq_len, config['n_heads'])
    mqa_memory = cache_bytes(config, seq_len, 1)
    gqa_memory = cache_bytes(config, seq_len, n_kv_heads_gqa)

    print(f"\n{model_name} with seq_len={seq_len}:")
    print(f"  Standard attention: {standard_memory / bytes_per_mb:.1f} MB")
    print(f"  MQA (1 KV head): {mqa_memory / bytes_per_mb:.1f} MB")
    print(f"  GQA ({n_kv_heads_gqa} KV heads): {gqa_memory / bytes_per_mb:.1f} MB")
    print(f"\n  MQA reduction: {config['n_heads']}x")
    print(f"  GQA reduction: {config['n_heads'] / n_kv_heads_gqa:.1f}x")

# ============================================================================
# EXPERIMENT 5: Learning Value Compression
# ============================================================================

def experiment_5_learning():
    """
    Train a small Value projection to keep colour and drop noise.

    The toy inputs are 16-dimensional: the first three dimensions hold a
    higher-variance "RGB colour" signal and the remaining thirteen hold
    unit-variance noise.  A 16 -> 4 projection ``W_V`` followed by a 4 -> 3
    readout ``W_out`` is fitted with Adam for 500 full-batch steps to
    reproduce the three colour dimensions.  Afterwards the row norms of
    ``W_V`` are compared between colour rows and noise rows.

    Returns:
        tuple: ``(W_V_learned, losses)`` -- the detached projection matrix
        and the list of per-step MSE losses.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("EXPERIMENT 5: Learning Value Compression")
    print(rule)

    d_model = 16
    d_v = 4  # Compress to 4 dimensions
    n_samples = 100
    n_color = 3

    # Inputs: overwrite the first three columns with a stronger signal
    # (what should be remembered); the rest stays as noise.
    X = torch.randn(n_samples, d_model)
    X[:, :n_color] = torch.randn(n_samples, n_color) * 2

    # Target: the colour columns themselves.
    Y = X[:, :n_color]

    # Trainable compression and readout matrices.
    W_V = nn.Parameter(torch.randn(d_model, d_v) / np.sqrt(d_model))
    W_out = nn.Parameter(torch.randn(d_v, n_color) / np.sqrt(d_v))

    optimizer = torch.optim.Adam([W_V, W_out], lr=0.01)

    print("\nTraining Value matrix to compress color information...")
    print("Task: Remember RGB color (dims 0-2), forget noise (dims 3-15)")

    losses = []
    for epoch in range(500):
        # Compress, read the colour back out, and score with MSE.
        compressed = torch.mm(X, W_V)
        prediction = torch.mm(compressed, W_out)
        loss = ((Y - prediction) ** 2).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        losses.append(step_loss)

        if epoch % 100 == 0:
            print(f"  Epoch {epoch}: Loss = {step_loss:.6f}")

    print(f"\nFinal loss: {losses[-1]:.6f}")

    # Inspect the learned projection: one L2 norm per input dimension (row).
    W_V_learned = W_V.detach()
    row_norms = torch.norm(W_V_learned, dim=1)
    color_mean = torch.mean(row_norms[:n_color])
    noise_mean = torch.mean(row_norms[n_color:])

    print("\nLearned compression matrix W_V:")
    print(f"  Avg weight magnitude for color dims [0-2]: {color_mean:.3f}")
    print(f"  Avg weight magnitude for noise dims [3-15]: {noise_mean:.3f}")
    print(f"  Ratio: {color_mean / noise_mean:.2f}x")

    if color_mean > noise_mean * 1.5:
        print("\n✓ W_V learned to focus on color information!")
        print("  It's compressing what matters and discarding noise.")
    else:
        print("\n⚠ W_V didn't learn clear focus")

    return W_V_learned, losses

# ============================================================================
# RUN ALL EXPERIMENTS
# ============================================================================

if __name__ == "__main__":
    # Run the five experiments in order.  Results are bound at module level
    # so they stay available for inspection after the cell/script finishes.

    # Experiment 1: lossy compression through W_V
    X, V, W_V, S = experiment_1_compression()

    # Experiment 2: one wide head vs. several narrow heads
    V_single, V_multi = experiment_2_multihead()

    # Experiment 3: attention weights as the retrieval mechanism
    V_retrieval, attention_patterns = experiment_3_retrieval()

    # Experiment 4: KV cache sizing (prints only, returns nothing)
    experiment_4_kv_cache()

    # Experiment 5: learning a task-specific W_V
    W_V_learned, losses = experiment_5_learning()

    print("\n" + "=" * 60)
    print("All experiments complete!")
    print("=" * 60)
    print("\nKey Takeaways:")
    print("1. Value matrices compress information lossily")
    print("2. Multi-head attention creates diverse compressed views")
    print("3. Attention weights control what information gets retrieved")
    # Experiment 4 sizes the cache as 2 * seq_len * d_v * heads * layers *
    # bytes, i.e. K and V contribute equally; neither dominates.
    print("4. KV cache memory is shared equally by K and V and grows with sequence length and KV heads")
    print("5. Value matrices learn task-specific compression")
    print("\nThe Value matrix is the database. Attention is just the lookup.")

Value Matrix as Database: Interactive Experiments

EXPERIMENT 1: Value Matrix as Lossy Compression

Original dimension: 512
Compressed dimension: 64
Compression ratio: 8.0x

Reconstruction MSE: 0.865285
Relative error: 86.10%

Information loss: 86.10% of original signal

Singular value spectrum (top 10):
[1.3567705 1.3330604 1.3145422 1.2891326 1.2806014 1.2602961 1.2389673
 1.227566  1.2185081 1.2045442]

Effective rank (95% variance): 58
Out of maximum rank: 64

EXPERIMENT 2: Multi-Head Value Storage

Single head:
  V shape: torch.Size([20, 512])
  Total dimensions: 512

Multi-head (8 heads):
  V shape per head: 64
  V concatenated shape: torch.Size([20, 512])
  Total dimensions: 512

Average absolute correlation between heads: 0.026
Lower correlation = more diverse information storage
✓ Heads are storing diverse information!

EXPERIMENT 3: Information Retrieval via Attention

Value matrix structure:
  Tokens 0-1: Person information (dimensions 0-31)
  Tokens 2-3: Location informatio