# Method2: Maclurain Random Feature implemented on one row of softmax

# Precompute all


In [None]:
import numpy as np
import math
import torch

# ——— Load and prepare q0, k0 as NumPy arrays ———
q_full = torch.load("subset_qk/block_1_q_proj_batch_6.pt", map_location="cpu")
k_full = torch.load("subset_qk/block_1_k_proj_batch_6.pt", map_location="cpu")

q = q_full[0]  # shape [L, d_model]
k = k_full[0]

L, d_model = q.shape
num_heads  = 32
d_head     = d_model // num_heads

# pick head 15 and first `sample` positions
sample = 4096
def sampling(q, k, sample):
    q0 = (
        q
        .view(L, num_heads, d_head)
        .permute(1, 0, 2)[15, :sample]
        .numpy()
    )   # shape [sample, d_head]
    k0 = (
        k
        .view(L, num_heads, d_head)
        .permute(1, 0, 2)[15, :sample]
        .numpy()
    )

    q0 = q0 / 128 ** 0.25
    k0 = k0 / 128 ** 0.25
    
    return q0, k0


def rfa_attention_fast(q0, k0, P=8, D=2000, shrink=10.0, power=100):
    """
    q0, k0:  (N, d) arrays of queries & keys
    returns: (N, N) approximate softmax(QK^T)
    """
    # 1) Pre-scale and cast to float32
    X = (q0 / shrink).astype(np.float64)   # (N, d)
    Y = (k0 / shrink).astype(np.float64)

    N, d = X.shape

    # 2) Sample ±1 weights in one flat block, cast to float32
    #    shape = (P*D, d)
    w_flat = np.sign(np.random.randn(P * D, d)).astype(np.float64)

    # 3) One mat-mul to get all projections, then reshape to (N,P,D)
    #    proj_x[n, p*D + j] = w_flat[p*D + j] ⋅ X[n]
    proj_x_flat = X.dot(w_flat.T)           # (N, P*D)
    proj_y_flat = Y.dot(w_flat.T)

    proj_x = proj_x_flat.reshape(N, P, D)   # (N, P, D)
    proj_y = proj_y_flat.reshape(N, P, D)

    # 4) Build per-degree normalizers √(D·p!) for p=1..P
    facts = np.array([math.sqrt(math.factorial(p+1)) for p in range(P)],
                     dtype=np.float64)     # (P,)
    normalizer = np.sqrt(D, dtype=np.float64) * facts
    normalizer = normalizer.reshape(1, P, 1)  # (1, P, 1)

    # 5) Cumulative product along the P-axis to get φ_p
    #    φ_p = ∏_{m=1..p} (proj[...,m-1] / normalizer[...,m-1])
    phi_x = np.cumprod(proj_x / normalizer, axis=1)  # (N, P, D)
    phi_y = np.cumprod(proj_y / normalizer, axis=1)

    # 6) Flatten φ back to (N, P*D)
    phi_x_flat = phi_x.reshape(N, P * D)
    phi_y_flat = phi_y.reshape(N, P * D)

    # 7) One big BLAS mat-mul to form the kernel matrix
    S = phi_x_flat.dot(phi_y_flat.T)  # (N, N)

    # 8) Sharpen & row-normalize
    M = (1.0 + S) ** power
    M /= M.max(axis=1, keepdims=True)
    M /= M.sum(axis=1, keepdims=True)

    return M


# usage:
# approx = rfa_attention_vectorized(q0, k0)

def true_softmax(q0, k0):
    dot = q0 @ k0.T
    true_val = np.exp(dot - dot.max(axis=1, keepdims=True))
    true_val /= true_val.sum(axis=1, keepdims=True)
    return true_val

def report_error(record_approx_values, true_val):
    return torch.norm(torch.tensor(record_approx_values - true_val)) / torch.norm(torch.tensor(true_val))



# Usage 

In [None]:
P, D, d = 4, 2000, 128
sample=4096
q0, k0 = sampling(q, k, sample)
v0 =(q0+ 4 * k0) /3

In [None]:
record_approx_values = rfa_attention_fast(q0, k0, P, D)

In [None]:
approx_v=record_approx_values @ v0

In [None]:
true_vals = true_softmax(q0, k0)

In [None]:
true_val = true_vals @ v0

In [None]:
import torch
torch.norm(torch.tensor(record_approx_values - true_vals)) / torch.norm(torch.tensor(true_vals))

In [None]:
import torch
torch.norm(torch.tensor(approx_v - true_val)) / torch.norm(torch.tensor(true_val))

In [None]:
import torch
import torch.nn as nn
import math

class OptimizedRFAMultiHeadAttention(nn.Module):
    def __init__(self, 
                 d_model: int,
                 num_heads: int,
                 P: int = 8,
                 D: int = 2000,
                 shrink: float = 10.0, 
                 power: float = 100.0,
                 dropout: float = 0.1):
        super().__init__()
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_head = d_model // num_heads
        self.P = P
        self.D = D
        self.shrink = shrink
        self.power = power
        
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
        # 推迟随机特征的初始化，直到我们知道输入类型
        self.random_weights = None
        self.normalizer = None
        
        # 用于存储预计算的kv特征
        self.kv_cache = {}
        
    def _init_random_features(self, dtype, device):
        """初始化每个头的随机特征权重，根据输入的dtype和device"""
        self.random_weights = nn.Parameter(
            torch.randn(self.num_heads, self.P * self.D, self.d_head, 
                       dtype=dtype, device=device),
            requires_grad=False
        )
        # 将权重转换为±1
        self.random_weights.data.sign_()
        
        # 预计算归一化因子
        facts = torch.tensor([math.sqrt(math.factorial(p+1)) for p in range(self.P)],
                            dtype=dtype, device=device)
        self.normalizer = torch.sqrt(torch.tensor(self.D, dtype=dtype, device=device)) * facts.view(1, -1, 1)
    
    def _precompute_kv_features(self, k, v, head_idx):
        """预计算k和v的特征"""
        if self.random_weights is None:
            self._init_random_features(k.dtype, k.device)
        
        batch_size, seq_len, _ = k.shape
        
        # 确保所有计算都使用相同的dtype
        Y = k / self.shrink
        w_flat = self.random_weights[head_idx]
        
        # 执行计算
        proj_y_flat = Y @ w_flat.t()
        proj_y = proj_y_flat.view(batch_size, seq_len, self.P, self.D)
        
        # 计算phi(k)
        phi_y = torch.cumprod(proj_y / self.normalizer, dim=2)
        phi_y_flat = phi_y.reshape(batch_size, seq_len, self.P * self.D)
        
        # 计算phi(k)v
        phi_y_v = torch.bmm(phi_y_flat.transpose(1, 2), v)  # (batch_size, P*D, d_head)
        
        # 计算phi(k)的归一化因子
        phi_y_sum = phi_y_flat.sum(dim=1, keepdim=True)  # (batch_size, 1, P*D)
        
        return {
            'phi_y_flat': phi_y_flat,  # (batch_size, seq_len, P*D)
            'phi_y_v': phi_y_v,        # (batch_size, P*D, d_head)
            'phi_y_sum': phi_y_sum,    # (batch_size, 1, P*D)
            'v': v                      # (batch_size, seq_len, d_head)
        }
    
    def _compute_attention_with_precomputed(self, q, kv_features, head_idx):
        """使用预计算的kv特征计算注意力"""
        if self.random_weights is None:
            self._init_random_features(q.dtype, q.device)
        
        batch_size, seq_len, _ = q.shape
        
        # 确保所有计算都使用相同的dtype
        X = q / self.shrink
        w_flat = self.random_weights[head_idx]
        
        # 执行计算
        proj_x_flat = X @ w_flat.t()
        proj_x = proj_x_flat.view(batch_size, seq_len, self.P, self.D)
        
        # 计算phi(q)
        phi_x = torch.cumprod(proj_x / self.normalizer, dim=2)
        phi_x_flat = phi_x.reshape(batch_size, seq_len, self.P * self.D)
        
        # 计算phi(q)phi(k)v
        attn_output = torch.bmm(phi_x_flat, kv_features['phi_y_v'])  # (batch_size, seq_len, d_head)
        
        # 计算归一化因子
        norm_factor = torch.bmm(phi_x_flat, kv_features['phi_y_sum'].transpose(1, 2))  # (batch_size, seq_len, 1)
        
        # 应用power和归一化
        attn_output = attn_output / (norm_factor + 1e-6)
        attn_output = (1.0 + attn_output) ** self.power
        
        # 最终归一化
        attn_output = attn_output / attn_output.sum(dim=1, keepdim=True)
        
        return attn_output
        
    def _rfa_attention_head(self, q, k, v, head_idx):
        """计算单个头的RFA注意力"""
        if self.random_weights is None:
            self._init_random_features(q.dtype, q.device)
        
        batch_size, seq_len, _ = q.shape
        
        # 确保所有计算都使用相同的dtype
        X = q / self.shrink
        Y = k / self.shrink
        w_flat = self.random_weights[head_idx]
        
        # 执行计算
        proj_x_flat = X @ w_flat.t()
        proj_y_flat = Y @ w_flat.t()
        
        proj_x = proj_x_flat.view(batch_size, seq_len, self.P, self.D)
        proj_y = proj_y_flat.view(batch_size, seq_len, self.P, self.D)
        
        # 计算phi(k)和phi(q)
        phi_x = torch.cumprod(proj_x / self.normalizer, dim=2)
        phi_y = torch.cumprod(proj_y / self.normalizer, dim=2)
        
        # 重塑为(batch_size, seq_len, P*D)
        phi_x_flat = phi_x.reshape(batch_size, seq_len, self.P * self.D)
        phi_y_flat = phi_y.reshape(batch_size, seq_len, self.P * self.D)
        v = v ** 0.01
        # 计算phi(k)v
        phi_y_v = torch.bmm(phi_y_flat.transpose(1, 2), v)  # (batch_size, P*D, d_head)
        
        # 计算phi(q)phi(k)v
        attn_output = torch.bmm(phi_x_flat, phi_y_v)  # (batch_size, seq_len, d_head)
        
        # 计算归一化因子
        phi_y_sum = phi_y_flat.sum(dim=1, keepdim=True)  # (batch_size, 1, P*D)
        norm_factor = torch.bmm(phi_x_flat, phi_y_sum.transpose(1, 2))  # (batch_size, seq_len, 1)
        
        # 应用power和归一化
        attn_output = attn_output / (norm_factor + 1e-6)
        attn_output = (1.0 + attn_output) ** self.power
        
        # 最终归一化
        attn_output = attn_output / attn_output.sum(dim=1, keepdim=True)
        
        return attn_output
        
    def forward(self, query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None):
        tgt_len, bsz, embed_dim = query.shape
        src_len = key.size(0)
        
        # 投影
        q = self.q_proj(query)  # [tgt_len, bsz, embed_dim]
        k = self.k_proj(key)    # [src_len, bsz, embed_dim]
        v = self.v_proj(value)
        
        # 重塑为多头形式
        q = q.reshape(tgt_len, bsz * self.num_heads, self.d_head).transpose(0, 1)
        k = k.reshape(src_len, bsz * self.num_heads, self.d_head).transpose(0, 1)
        v = v.reshape(src_len, bsz * self.num_heads, self.d_head).transpose(0, 1)
        
        # 计算每个头的注意力
        attn_outputs = []
        for h in range(self.num_heads):
            # 获取当前头的q、k、v
            q_h = q[h*bsz:(h+1)*bsz]  # [bsz, tgt_len, d_head]
            k_h = k[h*bsz:(h+1)*bsz]  # [bsz, src_len, d_head]
            v_h = v[h*bsz:(h+1)*bsz]  # [bsz, src_len, d_head]
            
            # 生成缓存键
            cache_key = f"head_{h}_kv"
            
            # 检查是否需要预计算kv特征
            if cache_key not in self.kv_cache:
                self.kv_cache[cache_key] = self._precompute_kv_features(k_h, v_h, h)
            
            # 使用预计算的特征计算注意力
            attn_output = self._compute_attention_with_precomputed(
                q_h,
                self.kv_cache[cache_key],
                h
            )
            attn_outputs.append(attn_output)
        
        # 合并多头输出
        attn_output = torch.cat(attn_outputs, dim=0)
        attn_output = attn_output.transpose(0, 1).reshape(tgt_len, bsz, embed_dim)
        
        # 输出投影
        attn_output = self.out_proj(attn_output)
        
        return attn_output, None

In [25]:
import torch
import time
import gc
from torch.utils.benchmark import Timer

# 1. 加载数据
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
q_full = torch.load("subset_qk/block_1_q_proj_batch_6.pt", map_location=device)
k_full = torch.load("subset_qk/block_1_k_proj_batch_6.pt", map_location=device)

sample_size = 4096
d_model = 4096
num_heads = 32
d_head = d_model // num_heads
num_q_permutations = 50 # q的排列次数

# 计算计数器
computation_count = 0
group_count = 0

def prepare_test_data():
    q = q_full[0].float()
    k = k_full[0].float()
    
    # 取出所有head的q
    q_all = (q.view(-1, num_heads, d_head)
            .permute(1, 0, 2)[:, :sample_size]
            .div_(128 ** 0.25)
            .to(device))  # [num_heads, sample_size, d_head]
    
    # 取出所有head的k
    k_all = (k.view(-1, num_heads, d_head)
            .permute(1, 0, 2)[:, :sample_size]
            .div_(128 ** 0.25)
            .to(device))  # [num_heads, sample_size, d_head]
    
    # 生成v
    v = ((q_all[0] + k_all[0]) / 2).float()  # [sample_size, d_head]
    
    # 生成50个不同的q排列
    q_permutations = []
    for _ in range(num_q_permutations):
        perm = torch.randperm(sample_size, device=device)
        q_perm = q_all[0][perm]  # 使用第一个head的q进行排列
        q_permutations.append(q_perm)
    
    return q_permutations, k_all[0], v  # 返回q排列列表，固定的k和v

# 2. 初始化优化后的RFA和官方MHA
optimized_rfa = OptimizedRFAMultiHeadAttention(
    d_model=d_head,
    num_heads=1,
    P=4,
    D=100,
    shrink=10.0,
    power=100.0
).to(device)

torch_mha = torch.nn.MultiheadAttention(
    embed_dim=d_head,
    num_heads=1,
    batch_first=False,
    bias=False
).to(device)

# 让MHA的权重为单位阵，等价于直接qk^T
torch_mha.in_proj_weight.data.copy_(torch.cat([
    torch.eye(d_head), torch.eye(d_head), torch.eye(d_head)
], dim=0))

torch_mha.out_proj.weight.data.copy_(torch.eye(d_head))

def clear_memory():
    gc.collect()
    torch.cuda.empty_cache()

# 3. 定义基准函数
def benchmark_optimized_rfa():
    global computation_count, group_count
    q_permutations, k, v = prepare_test_data()
    outputs = []
    
    # 预计算k和v的phi（只计算一次）
    k_phi = optimized_rfa._precompute_kv_features(
        k.unsqueeze(0),  # [1, sample_size, d_head]
        v.unsqueeze(0),  # [1, sample_size, d_head]
        head_idx=0
    )
    
    # 对每个q排列计算注意力
    for q in q_permutations:
        # 计算注意力
        out = optimized_rfa._compute_attention_with_precomputed(
            q.unsqueeze(0),  # [1, sample_size, d_head]
            k_phi,
            head_idx=0
        ).squeeze(0)  # [sample_size, d_head]
        
        outputs.append(out)
        computation_count += 1
        group_count += 1
    
    return torch.stack(outputs)  # [num_q_permutations, sample_size, d_head]

def benchmark_torch_mha():
    global computation_count
    q_permutations, k, v = prepare_test_data()
    outputs = []
    
    # 对每个q排列计算注意力
    for q in q_permutations:
        # 计算注意力
        out, _ = torch_mha(
            q.unsqueeze(1),  # [sample_size, 1, d_head]
            k.unsqueeze(1),  # [sample_size, 1, d_head]
            v.unsqueeze(1),  # [sample_size, 1, d_head]
            need_weights=False
        )
        outputs.append(out.squeeze(1))  # [sample_size, d_head]
        computation_count += 1
    
    return torch.stack(outputs)  # [num_q_permutations, sample_size, d_head]

# 4. 预热
print("\n预热中...")
for _ in range(3):
    benchmark_optimized_rfa()
    benchmark_torch_mha()
    clear_memory()

# 5. 速度测试
print("\n开始速度测试...")
start_time = time.time()
optimized_out = benchmark_optimized_rfa()
optimized_time = time.time() - start_time
clear_memory()

start_time = time.time()
torch_out = benchmark_torch_mha()
torch_time = time.time() - start_time
clear_memory()

print(f"\Time test:")
print(f"RFA Attention: {optimized_time:.3f} s")
print(f"MultiheadAttention: {torch_time:.3f} s")
print(f"Ratio: {torch_time/optimized_time:.2f}x")
# 6. 精度测试
print("\n开始精度测试...")
with torch.no_grad():
    # 计算每个组合的误差
    errors = []
    max_abs_errors = []
    mean_abs_errors = []
    
    # 确保使用正确的循环范围
    for i in range(num_q_permutations):  # 使用q排列的数量
        # 计算相对L2误差
        error = torch.norm(optimized_out[i] - torch_out[i]) / torch.norm(torch_out[i])
        errors.append(error.item())
        
        # 计算最大绝对误差
        max_abs_error = torch.max(torch.abs(optimized_out[i] - torch_out[i])).item()
        max_abs_errors.append(max_abs_error)
        
        # 计算平均绝对误差
        mean_abs_error = torch.mean(torch.abs(optimized_out[i] - torch_out[i])).item()
        mean_abs_errors.append(mean_abs_error)
    
    print(f"\n精度测试结果:")
    print(f"所有组合的平均相对L2误差: {sum(errors)/len(errors):.3e}")
    print(f"所有组合的平均最大绝对误差: {sum(max_abs_errors)/len(max_abs_errors):.3e}")
    print(f"所有组合的平均平均绝对误差: {sum(mean_abs_errors)/len(mean_abs_errors):.3e}")
# 7. 统计信息
print(f"\n统计信息:")
print(f"总组数: {group_count:,}")
print(f"总计算次数: {computation_count:,}")
print(f"每组平均计算次数: {computation_count/group_count:.2f}")

# 最后清理内存
clear_memory()


预热中...

开始速度测试...
\Time test:
RFA Attention: 0.014 s
MultiheadAttention: 0.025 s
Ratio: 1.80x

开始精度测试...

精度测试结果:
所有组合的平均相对L2误差: nan
所有组合的平均最大绝对误差: nan
所有组合的平均平均绝对误差: nan

统计信息:
总组数: 200
总计算次数: 400
每组平均计算次数: 2.00


In [None]:
optimized_out

In [None]:
# 6. 精度测试
print("\n开始精度测试...")
with torch.no_grad():
    # 计算每个组合的误差
    errors = []
    max_abs_errors = []
    mean_abs_errors = []
    
    # 确保使用正确的循环范围
    for i in range(num_q_permutations):  # 使用q排列的数量
        # 计算相对L2误差
        error = torch.norm(optimized_out[i] - torch_out[i]) / torch.norm(torch_out[i])
        errors.append(error.item())
        
        # 计算最大绝对误差
        max_abs_error = torch.max(torch.abs(optimized_out[i] - torch_out[i])).item()
        max_abs_errors.append(max_abs_error)
        
        # 计算平均绝对误差
        mean_abs_error = torch.mean(torch.abs(optimized_out[i] - torch_out[i])).item()
        mean_abs_errors.append(mean_abs_error)
    
    print(f"\n精度测试结果:")
    print(f"所有组合的平均相对L2误差: {sum(errors)/len(errors):.3e}")
    print(f"所有组合的平均最大绝对误差: {sum(max_abs_errors)/len(max_abs_errors):.3e}")
    print(f"所有组合的平均平均绝对误差: {sum(mean_abs_errors)/len(mean_abs_errors):.3e}")

In [None]:
clear_memory()

In [None]:
clear_memory()