# Method2: Maclurain Random Feature implemented on one row of softmax

# Precompute all


In [24]:
import numpy as np
import math
import torch

# ——— Load and prepare q0, k0 as NumPy arrays ———
q_full = torch.load("subset_qk/block_1_q_proj_batch_6.pt", map_location="cpu")
k_full = torch.load("subset_qk/block_1_k_proj_batch_6.pt", map_location="cpu")

q = q_full[0]  # shape [L, d_model]
k = k_full[0]

L, d_model = q.shape
num_heads  = 32
d_head     = d_model // num_heads

# pick head 15 and first `sample` positions
sample = 4096
def sampling(q, k, sample):
    q0 = (
        q
        .view(L, num_heads, d_head)
        .permute(1, 0, 2)[15, :sample]
        .numpy()
    )   # shape [sample, d_head]
    k0 = (
        k
        .view(L, num_heads, d_head)
        .permute(1, 0, 2)[15, :sample]
        .numpy()
    )

    q0 = q0 / 128 ** 0.25
    k0 = k0 / 128 ** 0.25
    
    return q0, k0


def rfa_attention_fast(q0, k0, P=8, D=2000, shrink=10.0, power=100):
    """
    q0, k0:  (N, d) arrays of queries & keys
    returns: (N, N) approximate softmax(QK^T)
    """
    # 1) Pre-scale and cast to float32
    X = (q0 / shrink).astype(np.float64)   # (N, d)
    Y = (k0 / shrink).astype(np.float64)

    N, d = X.shape

    # 2) Sample ±1 weights in one flat block, cast to float32
    #    shape = (P*D, d)
    w_flat = np.sign(np.random.randn(P * D, d)).astype(np.float64)

    # 3) One mat-mul to get all projections, then reshape to (N,P,D)
    #    proj_x[n, p*D + j] = w_flat[p*D + j] ⋅ X[n]
    proj_x_flat = X.dot(w_flat.T)           # (N, P*D)
    proj_y_flat = Y.dot(w_flat.T)

    proj_x = proj_x_flat.reshape(N, P, D)   # (N, P, D)
    proj_y = proj_y_flat.reshape(N, P, D)

    # 4) Build per-degree normalizers √(D·p!) for p=1..P
    facts = np.array([math.sqrt(math.factorial(p+1)) for p in range(P)],
                     dtype=np.float64)     # (P,)
    normalizer = np.sqrt(D, dtype=np.float64) * facts
    normalizer = normalizer.reshape(1, P, 1)  # (1, P, 1)

    # 5) Cumulative product along the P-axis to get φ_p
    #    φ_p = ∏_{m=1..p} (proj[...,m-1] / normalizer[...,m-1])
    phi_x = np.cumprod(proj_x / normalizer, axis=1)  # (N, P, D)
    phi_y = np.cumprod(proj_y / normalizer, axis=1)

    # 6) Flatten φ back to (N, P*D)
    phi_x_flat = phi_x.reshape(N, P * D)
    phi_y_flat = phi_y.reshape(N, P * D)

    # 7) One big BLAS mat-mul to form the kernel matrix
    S = phi_x_flat.dot(phi_y_flat.T)  # (N, N)

    # 8) Sharpen & row-normalize
    M = (1.0 + S) ** power
    M /= M.max(axis=1, keepdims=True)
    M /= M.sum(axis=1, keepdims=True)

    return M


# usage:
# approx = rfa_attention_vectorized(q0, k0)

def true_softmax(q0, k0):
    dot = q0 @ k0.T
    true_val = np.exp(dot - dot.max(axis=1, keepdims=True))
    true_val /= true_val.sum(axis=1, keepdims=True)
    return true_val

def report_error(record_approx_values, true_val):
    return torch.norm(torch.tensor(record_approx_values - true_val)) / torch.norm(torch.tensor(true_val))



# Usage 

In [25]:
P, D, d = 4, 2000, 128
sample=4096
q0, k0 = sampling(q, k, sample)
v0 =(q0+ 4 * k0) /3

In [26]:
record_approx_values = rfa_attention_fast(q0, k0, P, D)

In [27]:
approx_v=record_approx_values @ v0

In [28]:
true_vals = true_softmax(q0, k0)

In [29]:
true_val = true_vals @ v0

In [30]:
import torch
torch.norm(torch.tensor(record_approx_values - true_vals)) / torch.norm(torch.tensor(true_vals))

tensor(0.2413, dtype=torch.float64)

In [31]:
import torch
torch.norm(torch.tensor(approx_v - true_val)) / torch.norm(torch.tensor(true_val))

tensor(0.0297, dtype=torch.float64)

In [32]:
class RFAMultiHeadAttention(nn.Module):
    def __init__(self, 
                 d_model: int,
                 num_heads: int,
                 P: int = 8,
                 D: int = 2000,
                 shrink: float = 10.0, 
                 power: float = 100.0,
                 dropout: float = 0.1):
        super().__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.P = P
        self.D = D
        self.shrink = shrink
        self.power = power
        
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
        # 推迟随机特征的初始化，直到我们知道输入类型
        self.random_weights = None
        self.normalizer = None
        
    def _init_random_features(self, dtype, device):
        """初始化每个头的随机特征权重，根据输入的dtype和device"""
        self.random_weights = nn.Parameter(
            torch.randn(self.num_heads, self.P * self.D, self.d_head, 
                       dtype=dtype, device=device),
            requires_grad=False
        )
        # 将权重转换为±1
        self.random_weights.data.sign_()
        
        # 预计算归一化因子
        facts = torch.tensor([math.sqrt(math.factorial(p+1)) for p in range(self.P)],
                            dtype=dtype, device=device)
        self.normalizer = torch.sqrt(torch.tensor(self.D, dtype=dtype, device=device)) * facts.view(1, -1, 1)
        
    def _rfa_attention_head(self, q, k, v, head_idx):
        # 确保初始化了随机权重
        if self.random_weights is None:
            self._init_random_features(q.dtype, q.device)
        
        batch_size, seq_len, _ = q.shape
        
        # 确保所有计算都使用相同的dtype
        X = q / self.shrink
        Y = k / self.shrink
        
        # 获取并确保相同dtype的随机权重
        w_flat = self.random_weights[head_idx]
        
        # 确认类型匹配
        assert X.dtype == w_flat.dtype, f"类型不匹配: X.dtype={X.dtype}, w_flat.dtype={w_flat.dtype}"
        
        # 执行计算
        proj_x_flat = X @ w_flat.t()
        proj_y_flat = Y @ w_flat.t()
        
        proj_x = proj_x_flat.view(batch_size, seq_len, self.P, self.D)
        proj_y = proj_y_flat.view(batch_size, seq_len, self.P, self.D)
        
        phi_x = torch.cumprod(proj_x / self.normalizer, dim=2)
        phi_y = torch.cumprod(proj_y / self.normalizer, dim=2)
        
        phi_x_flat = phi_x.view(batch_size, seq_len, -1)
        phi_y_flat = phi_y.view(batch_size, seq_len, -1)
        
        S = phi_x_flat @ phi_y_flat.transpose(-2, -1)
        
        M = (1.0 + S).pow(self.power)
        M = M / M.max(dim=-1, keepdim=True)[0]
        attn = self.dropout(M / M.sum(dim=-1, keepdim=True))
        
        return attn @ v

In [33]:
import torch
import time
from torch.utils.benchmark import Timer

# 假设RFAMultiHeadAttention已在同目录下定义好

def test_rfa_vs_torch_mha():
    # 1. 加载数据
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    q_full = torch.load("subset_qk/block_1_q_proj_batch_6.pt", map_location=device)
    k_full = torch.load("subset_qk/block_1_k_proj_batch_6.pt", map_location=device)
    
    sample_size = 4096
    d_model = 4096
    num_heads = 32
    d_head = d_model // num_heads
    head_idx = 15

    def prepare_test_data():
        q = q_full[0].float()
        k = k_full[0].float()
        # 取出指定head
        q0 = (q.view(-1, num_heads, d_head)
              .permute(1, 0, 2)[head_idx, :sample_size]
              .div_(128 ** 0.25)
              .to(device))
        k0 = (k.view(-1, num_heads, d_head)
              .permute(1, 0, 2)[head_idx, :sample_size]
              .div_(128 ** 0.25)
              .to(device))
        v0 = ((q0 + 4 * k0) / 3).float()
        # 变成(seq_len, batch, embed_dim)格式，batch=1
        return q0.unsqueeze(1), k0.unsqueeze(1), v0.unsqueeze(1)

    # 2. 初始化RFA和官方MHA
    rfa_model = RFAMultiHeadAttention(
        d_model=d_head,
        num_heads=1,
        P=8,
        D=2000,
        shrink=10.0,
        power=100.0
    ).to(device)
    torch_mha = torch.nn.MultiheadAttention(
        embed_dim=d_head,
        num_heads=1,
        batch_first=False,
        bias=False
    ).to(device)
    # 让MHA的权重为单位阵，等价于直接qk^T
    torch_mha.in_proj_weight.data.copy_(torch.cat([
        torch.eye(d_head), torch.eye(d_head), torch.eye(d_head)
    ], dim=0))

    torch_mha.out_proj.weight.data.copy_(torch.eye(d_head))


    # 3. 定义基准函数，每次都重新加载qkv
    def benchmark_rfa():
        q0, k0, v0 = prepare_test_data()
        # (1, seq, d_head) -> (batch, seq, d_head)
        return rfa_model._rfa_attention_head(
            q0.permute(1,0,2), k0.permute(1,0,2), v0.permute(1,0,2), head_idx=0
        ).squeeze(0)  # (seq, d_head)

    def benchmark_torch_mha():
        q0, k0, v0 = prepare_test_data()
        # (seq, batch, d_head)
        out, _ = torch_mha(q0, k0, v0, need_weights=False)
        return out.squeeze(1)  # (seq, d_head)

    # 4. 预热
    for _ in range(3):
        benchmark_rfa()
        benchmark_torch_mha()

    # 5. 速度测试
    timer_rfa = Timer(
        stmt="benchmark_rfa()",
        globals={'benchmark_rfa': benchmark_rfa}
    )
    timer_torch = Timer(
        stmt="benchmark_torch_mha()",
        globals={'benchmark_torch_mha': benchmark_torch_mha}
    )
    print("\n速度测试结果:")
    print(f"RFA Attention: {timer_rfa.timeit(10)}")
    print(f"官方MultiheadAttention: {timer_torch.timeit(10)}")

    # 6. 精度测试
    with torch.no_grad():
        rfa_out = benchmark_rfa()
        torch_out = benchmark_torch_mha()
        # 计算相对L2误差
        error = torch.norm(rfa_out - torch_out) / torch.norm(torch_out)
        print(f"\n相对L2误差: {error.item():.3e}")
        print(f"最大绝对误差: {torch.max(torch.abs(rfa_out - torch_out)).item():.3e}")
        print(f"平均绝对误差: {torch.mean(torch.abs(rfa_out - torch_out)).item():.3e}")

if __name__ == "__main__":
    test_rfa_vs_torch_mha()


速度测试结果:
RFA Attention: <torch.utils.benchmark.utils.common.Measurement object at 0x7f143d329360>
benchmark_rfa()
  43.17 ms
  1 measurement, 10 runs , 1 thread
官方MultiheadAttention: <torch.utils.benchmark.utils.common.Measurement object at 0x7f143d328af0>
benchmark_torch_mha()
  1.53 ms
  1 measurement, 10 runs , 1 thread

相对L2误差: 1.217e-01
最大绝对误差: 8.446e-01
平均绝对误差: 4.437e-02



速度测试结果:
RFA Attention: <torch.utils.benchmark.utils.common.Measurement object at 0x7f143d329360>
benchmark_rfa()
  43.17 ms
  1 measurement, 10 runs , 1 thread
官方MultiheadAttention: <torch.utils.benchmark.utils.common.Measurement object at 0x7f143d328af0>
benchmark_torch_mha()
  1.53 ms
  1 measurement, 10 runs , 1 thread

相对L2误差: 1.217e-01
最大绝对误差: 8.446e-01
平均绝对误差: 4.437e-02