In [18]:
import torch
import numpy as np

# 1) Load your q0 and k0 from the provided files
q_full = torch.load("subset_qk/block_1_q_proj_batch_6.pt")
k_full = torch.load("subset_qk/block_1_k_proj_batch_6.pt")
q = q_full[0]
k = k_full[0]
L, d_model = q.shape
num_heads = 32
d_head = d_model // num_heads

# Extract head 15
q0 = q.view(L, num_heads, d_head).permute(1, 0, 2)[15]
k0 = k.view(L, num_heads, d_head).permute(1, 0, 2)[15]
# Normalize as per the paper
q0 = q0 / (128**0.25)
k0 = k0 / (128**0.25)

# Convert to NumPy
q0_np = q0.cpu().numpy()
k0_np = k0.cpu().numpy()
V_np = (3 * q0 + 4* k0) /10

# Method1: Unbiased Implementation

In [19]:
import numpy as np

def random_feature_attention(Q, K, V, D=20000, seed=42):
    """
    Approximate softmax-attention using D random‐feature samples.
    Q: (L, d), K: (L, d), V: (L, dv)
    Returns:
      alpha_hat: (L, L)    approximate attention weights
      out:       (L, dv)   approximate attended outputs
    """
    rng = np.random.RandomState(seed)
    L, d = Q.shape

    # 1) draw random weights W ~ N(0,1), shape (D, d)
    W = rng.randn(D, d)  # (D, d)

    # 2) compute features φ_Q and φ_K:
    #    φ(x) = exp(x·Wᵀ − ½‖x‖²) / √D
    Q2 = np.sum(Q*Q, axis=1, keepdims=True)   # (L,1)
    K2 = np.sum(K*K, axis=1, keepdims=True)   # (L,1)
    phi_Q = np.exp(Q.dot(W.T) - 0.5*Q2) / np.sqrt(D)  # (L, D)
    phi_K = np.exp(K.dot(W.T) - 0.5*K2) / np.sqrt(D)  # (L, D)

    # 3) approximate kernel: φ_Q φ_Kᵀ ≈ exp(Q Kᵀ)
    A_hat = phi_Q.dot(phi_K.T)                # (L, L)

    # 4) row‐normalize to get weights
    alpha_hat = A_hat / A_hat.sum(axis=1, keepdims=True)  # (L, L)

    # 5) attended outputs
    out = alpha_hat.dot(V)                    # (L, dv)
    return alpha_hat, out

# ────────────────────────────────────────────────────────
# Usage:
# (make sure q0_np, k0_np are already NumPy arrays in your session)

D_features = 20000
alpha_approx, out_approx = random_feature_attention(
    q0_np, k0_np, V_np, D=D_features, seed=0
)

print("alpha_approx.shape:", alpha_approx.shape)  # → (4096, 4096)
print("out_approx.shape:  ", out_approx.shape)    # → (4096, 128)


alpha_approx.shape: (4096, 4096)
out_approx.shape:   (4096, 128)


In [20]:
import numpy as np
import torch  # only for loading your .pt files



# 2) Exact softmax-attention (NumPy vectorized)
scores = q0_np.dot(k0_np.T)                          # (4096,4096)
exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
alpha_full = exp_scores / exp_scores.sum(axis=1, keepdims=True)  # (4096,4096)
out_full = alpha_full.dot(V_np)                  # (4096,128)

print("alpha_full.shape:", alpha_full.shape)    # -> (4096,4096)
print("out_full.shape:  ", out_full.shape)      # -> (4096,128)


alpha_full.shape: (4096, 4096)
out_full.shape:   (4096, 128)


In [23]:
np.linalg.norm(alpha_full - alpha_approx) / np.linalg.norm(alpha_full)

2.7916740583049093

In [None]:
np.linalg.norm(out_full)

137.0

In [24]:
np.linalg.norm(out_approx - out_full) / np.linalg.norm(out_full)

0.2713716779893449

In [None]:
np.linalg.norm(out_full)

137.0