# Method2: Maclurain Random Feature implemented on one row of softmax

# Precompute all


In [7]:
import numpy as np
import math
import torch

# ——— Load and prepare q0, k0 as NumPy arrays ———
q_full = torch.load("subset_qk/block_1_q_proj_batch_6.pt", map_location="cpu")
k_full = torch.load("subset_qk/block_1_k_proj_batch_6.pt", map_location="cpu")

q = q_full[0]  # shape [L, d_model]
k = k_full[0]

L, d_model = q.shape
num_heads  = 32
d_head     = d_model // num_heads

# pick head 15 and first `sample` positions
sample = 4096
def sampling(q, k, sample):
    q0 = (
        q
        .view(L, num_heads, d_head)
        .permute(1, 0, 2)[15, :sample]
        .numpy()
    )   # shape [sample, d_head]
    k0 = (
        k
        .view(L, num_heads, d_head)
        .permute(1, 0, 2)[15, :sample]
        .numpy()
    )

    q0 = q0 / 128 ** 0.25
    k0 = k0 / 128 ** 0.25
    
    return q0, k0


def rfa_attention_fast(q0, k0, P=8, D=2000, shrink=10.0, power=100):
    """
    q0, k0:  (N, d) arrays of queries & keys
    returns: (N, N) approximate softmax(QK^T)
    """
    # 1) Pre-scale and cast to float32
    X = (q0 / shrink).astype(np.float64)   # (N, d)
    Y = (k0 / shrink).astype(np.float64)

    N, d = X.shape

    # 2) Sample ±1 weights in one flat block, cast to float32
    #    shape = (P*D, d)
    w_flat = np.sign(np.random.randn(P * D, d)).astype(np.float64)

    # 3) One mat-mul to get all projections, then reshape to (N,P,D)
    #    proj_x[n, p*D + j] = w_flat[p*D + j] ⋅ X[n]
    proj_x_flat = X.dot(w_flat.T)           # (N, P*D)
    proj_y_flat = Y.dot(w_flat.T)

    proj_x = proj_x_flat.reshape(N, P, D)   # (N, P, D)
    proj_y = proj_y_flat.reshape(N, P, D)

    # 4) Build per-degree normalizers √(D·p!) for p=1..P
    facts = np.array([math.sqrt(math.factorial(p+1)) for p in range(P)],
                     dtype=np.float64)     # (P,)
    normalizer = np.sqrt(D, dtype=np.float64) * facts
    normalizer = normalizer.reshape(1, P, 1)  # (1, P, 1)

    # 5) Cumulative product along the P-axis to get φ_p
    #    φ_p = ∏_{m=1..p} (proj[...,m-1] / normalizer[...,m-1])
    phi_x = np.cumprod(proj_x / normalizer, axis=1)  # (N, P, D)
    phi_y = np.cumprod(proj_y / normalizer, axis=1)

    # 6) Flatten φ back to (N, P*D)
    phi_x_flat = phi_x.reshape(N, P * D)
    phi_y_flat = phi_y.reshape(N, P * D)

    # 7) One big BLAS mat-mul to form the kernel matrix
    S = phi_x_flat.dot(phi_y_flat.T)  # (N, N)

    # 8) Sharpen & row-normalize
    M = (1.0 + S) ** power
    M /= M.max(axis=1, keepdims=True)
    M /= M.sum(axis=1, keepdims=True)

    return M


# usage:
# approx = rfa_attention_vectorized(q0, k0)

def true_softmax(q0, k0):
    dot = q0 @ k0.T
    true_val = np.exp(dot - dot.max(axis=1, keepdims=True))
    true_val /= true_val.sum(axis=1, keepdims=True)
    return true_val

def report_error(record_approx_values, true_val):
    return torch.norm(torch.tensor(record_approx_values - true_val)) / torch.norm(torch.tensor(true_val))



# Usage 

In [8]:
P, D, d = 4, 2000, 128
sample=4096
q0, k0 = sampling(q, k, sample)
v0 =(q0+ 4 * k0) /3

In [9]:
record_approx_values = rfa_attention_fast(q0, k0, P, D)

In [10]:
approx_v=record_approx_values @ v0

In [11]:
true_vals = true_softmax(q0, k0)

In [12]:
true_val = true_vals @ v0

In [13]:
import torch
torch.norm(torch.tensor(record_approx_values - true_vals)) / torch.norm(torch.tensor(true_vals))

tensor(0.2230, dtype=torch.float64)

In [14]:
import torch
torch.norm(torch.tensor(approx_v - true_val)) / torch.norm(torch.tensor(true_val))

tensor(0.0246, dtype=torch.float64)