Window-based attention benchmark
Include:
- Regular Swin Transformer Window Attention
- 2D Window Flex Attention
- Hilbert Window Attention

Sliding Window / Neighborhood attention benchmark
Include:
- Regular Slide Attention
- Regular NATTEN Attention
- 2D Flex Sliding Window Attention
- Hilbert Sliding Window Attention

# Import library and initial setup

In [1]:
import random
from functools import lru_cache, partial

import torch
import torch.nn.functional as F

from tabulate import tabulate
from torch.nn.attention.flex_attention import (
    _DEFAULT_SPARSE_BLOCK_SIZE,
    create_block_mask,
    create_mask,
    flex_attention,
)
from triton.testing import do_bench

torch.set_default_device("cuda")
torch.manual_seed(0)

torch._dynamo.config.cache_size_limit = 1000

# Compile the flex_attention function
flex_attention = torch.compile(flex_attention, dynamic=False)

data_type = torch.float16

# The kernels will utilize block sparisty to increase performance
print(f"Using the default sparsity block size: {_DEFAULT_SPARSE_BLOCK_SIZE}")

Using the default sparsity block size: 128


# WSA

In [None]:
import torch
import math

def swin_window_rearrange(x, H_img: int, W_img: int, WINDOW: int):
    B, H_img, W_img, C = x.shape
    assert H_img % WINDOW == 0 and W_img % WINDOW == 0

    num_win_h = H_img // WINDOW
    num_win_w = W_img // WINDOW
    Lw = WINDOW * WINDOW
    Nw = num_win_h * num_win_w

    # rearrange the 2D image to window sequence
    x_windows = x.view(B, num_win_h, WINDOW, num_win_w, WINDOW, C)
    x_windows = x_windows.permute(0, 1, 3, 2, 4, 5).contiguous()
    x_windows = x_windows.view(B * Nw, Lw, C)

    shape_info = (B, C, H_img, W_img, num_win_h, num_win_w, Nw, Lw)
    return x_windows, shape_info

def swin_qkv_projection(x_windows, num_heads: int):

    B_Nw, Lw, C = x_windows.shape
    head_dim = C // num_heads

    # QKV projection - create fixed weight matrix
    if not hasattr(swin_qkv_projection, 'qkv_weight'):
        swin_qkv_projection.qkv_weight = torch.randn(
            C, 3*C, device=x_windows.device, dtype=x_windows.dtype
        )

    # QKV projection (self.qkv(x) in SWIN)
    qkv = x_windows @ swin_qkv_projection.qkv_weight
    qkv = qkv.reshape(B_Nw, Lw, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4).contiguous()
    q, k, v = qkv[0], qkv[1], qkv[2]  # shape: (B*Nw, num_heads, Lw, head_dim)

    return q, k, v

def restore_windows(windows, shape_info):
    B, C, H_img, W_img, num_win_h, num_win_w, Nw, Lw = shape_info
    WINDOW = int(math.sqrt(Lw))

    if len(windows.shape) == 4:  # (B*Nw, num_heads, Lw, head_dim)
        B_Nw, num_heads, Lw, head_dim = windows.shape
        windows = windows.transpose(1, 2).reshape(B_Nw, Lw, num_heads * head_dim)
    x = windows.view(B, num_win_h, num_win_w, WINDOW, WINDOW, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H_img, W_img, C)
    x = x.view(B, H_img * W_img, C)
    return x

def swin_attention_only(q, k, v, rpb=None):
    """Pure Swin attention calculation, without QKV projection"""
    B_Nw, num_heads, Lw, head_dim = q.shape

    q = q * (head_dim ** -0.5)
    attn = q @ k.transpose(-2, -1)  # (B*Nw, num_heads, Lw, Lw)
    if rpb is not None:
        attn = attn + rpb.unsqueeze(0)
    attn = attn.softmax(dim=-1)

    out = attn @ v  # (B*Nw, num_heads, Lw, head_dim)
    out = out.transpose(1, 2).reshape(B_Nw, Lw, num_heads * head_dim)  # (B*Nw, Lw, C)
    # out = F.scaled_dot_product_attention(
    # q, k, v,
    # attn_mask=rpb,            # ← additive bias (RPB)
    # dropout_p=0.0,
    # is_causal=False
    # )
    return out

def create_simple_rpb(Lw, num_heads, device, dtype):
    """Create a simple 1D global RPB outside"""
    # directly create the bias matrix (num_heads, Lw, Lw)
    rpb = torch.randn(num_heads, Lw, Lw, device=device, dtype=dtype) * 0.02
    return rpb

def test():
    B, H_img, W_img, C, WINDOW = 16, 128, 128, 128, 16
    S = H_img * W_img
    num_heads = 2
    device = "cuda"
    dtype = torch.float16

    # prepare the input and gradient
    x = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16, requires_grad=True)
    gradOut = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16)

    # test the window rearrange and QKV projection
    x_windows, shape_info = swin_window_rearrange(x, H_img, W_img, WINDOW)
    q, k, v = swin_qkv_projection(x_windows, num_heads)

    # prepare the gradient of the windowed
    gradOut_windows = gradOut.view(B, H_img, W_img, C)
    gradOut_windows = gradOut_windows.view(B, H_img//WINDOW, WINDOW, W_img//WINDOW, WINDOW, C)
    gradOut_windows = gradOut_windows.permute(0, 1, 3, 2, 4, 5).contiguous()
    gradOut_windows = gradOut_windows.view(B * (H_img//WINDOW) * (W_img//WINDOW), WINDOW*WINDOW, C)

    Lw = WINDOW * WINDOW
    rpb = create_simple_rpb(Lw, num_heads, device, dtype)

    # Forward - window rearrange time
    window_rearrange_call = lambda: swin_window_rearrange(x, H_img, W_img, WINDOW)
    window_rearrange_fw = do_bench(window_rearrange_call)

    # Forward - QKV projection time
    qkv_proj_call = lambda: swin_qkv_projection(x_windows, num_heads)
    qkv_proj_fw = do_bench(qkv_proj_call)

    # Forward - pure attention (without QKV projection and rearrange)
    pure_attention_call = lambda: swin_attention_only(q, k, v, rpb=rpb)
    pure_fw = do_bench(pure_attention_call)

    # Forward - end-to-end (with rearrange+QKV+attention+restore)
    def e2e_call():
        x_windows_tmp, shape_tmp = swin_window_rearrange(x, H_img, W_img, WINDOW)
        q_tmp, k_tmp, v_tmp = swin_qkv_projection(x_windows_tmp, num_heads)
        out_win = swin_attention_only(q_tmp, k_tmp, v_tmp, rpb=rpb)
        # out = restore_windows(out_win, shape_tmp)
        return out_win
    e2e_fw = do_bench(e2e_call)


    print(f"\nPerformance test results:")
    print(f"Window rearrange time: {window_rearrange_fw:.4f} ms")
    print(f"QKV projection time: {qkv_proj_fw:.4f} ms")
    print(f"Pure attention forward: {pure_fw:.4f} ms")
    print(f"End-to-end forward: {e2e_fw:.4f} ms")

test()


性能测试结果:
窗口重排时间: 0.1978 ms
QKV投影时间: 0.9361 ms
纯注意力前向: 0.6680 ms
端到端前向: 1.5784 ms


# HWA

In [None]:
def sgn(x):
    return -1 if x < 0 else (1 if x > 0 else 0)

def generate2d(x: int, y: int, ax: int, ay: int, bx: int, by: int, result):
    w = abs(ax + ay)
    h = abs(bx + by)
    dax, day = sgn(ax), sgn(ay)
    dbx, dby = sgn(bx), sgn(by)

    if h == 1 or w == 1:
        if h == 1:
            for _ in range(w):
                result.append((x, y))
                x, y = x + dax, y + day
        elif w == 1:
            for _ in range(h):
                result.append((x, y))
                x, y = x + dbx, y + dby
        return

    ax2, ay2 = ax // 2, ay // 2
    bx2, by2 = bx // 2, by // 2
    w2 = abs(ax2 + ay2)
    h2 = abs(bx2 + by2)

    if 2 * w > 3 * h:
        if w2 % 2 and w > 2:
            ax2, ay2 = ax2 + dax, ay2 + day
        generate2d(x, y, ax2, ay2, bx, by, result)
        generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by, result)
    else:
        if h2 % 2 and h > 2:
            bx2, by2 = bx2 + dbx, by2 + dby
        generate2d(x, y, bx2, by2, ax2, ay2, result)
        generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2, result)
        generate2d(x + (ax - dax) + (bx2 - dbx),
                   y + (ay - day) + (by2 - dby),
                   -bx2, -by2, -(ax - ax2), -(ay - ay2), result)

def gilbert2d(width, height):
    result = []
    if width >= height:
        generate2d(0, 0, width, 0, 0, height, result)
    else:
        generate2d(0, 0, 0, height, width, 0, result)
    return result

class GilbertPathCache:
    def __init__(self):
        self.cache = {}
        self.device_index_cache = {}

    def get_or_create_path(self, H, W):
        key = (H, W)
        if key not in self.cache:
            path = gilbert2d(W, H)

            forward_map = torch.zeros((H, W), dtype=torch.long)
            reverse_map = torch.zeros((H * W, 2), dtype=torch.long)

            for idx, (x, y) in enumerate(path[:H*W]):
                if y < H and x < W:
                    forward_map[y, x] = idx
                    reverse_map[idx, 0] = y
                    reverse_map[idx, 1] = x

            self.cache[key] = {
                'path': path,
                'forward_map': forward_map,
                'reverse_map': reverse_map,
                'y_indices': reverse_map[:, 0].clone(),
                'x_indices': reverse_map[:, 1].clone(),
                'H': H,
                'W': W
            }

        return self.cache[key]

    def get_indices_on_device(self, H, W, device):
        device_key = (H, W, str(device))
        if device_key in self.device_index_cache:
            return self.device_index_cache[device_key]
        info = self.get_or_create_path(H, W)
        y_dev = info['y_indices'].to(device)
        x_dev = info['x_indices'].to(device)
        self.device_index_cache[device_key] = (y_dev, x_dev)
        return y_dev, x_dev

    def precompute_paths(self, resolutions):
        for H, W in resolutions:
            self.get_or_create_path(H, W)

    def clear_cache(self):
        self.cache.clear()

_global_gilbert_cache = GilbertPathCache()

def tensor_to_gilbert_path(x, cache=None):
    """
    Args:
        x: Input tensor, shape (B, H, W, C)
        cache: Optional GilbertPathCache instance, use global cache if None
    Returns:
        Reordered tensor, shape (B, H*W, C)
    """
    B, H, W, C = x.shape
    device = x.device
    if cache is None:
        cache = _global_gilbert_cache

    y_indices, x_indices = cache.get_indices_on_device(H, W, device)
    gilbert_tensor = x[:, y_indices, x_indices, :]  # (B, H*W, C)

    return gilbert_tensor

def gilbert_tensor_to_2d(x, H, W, cache=None):
    """
    Args:
        x: Gilbert sequence tensor, shape (B, H*W, C)
        H: Target height
        W: Target width
        cache: Optional GilbertPathCache instance, use global cache if None
    Returns:
        2D layout tensor, shape (B, H, W, C)
    """
    B, N, C = x.shape
    device = x.device

    if cache is None:
        cache = _global_gilbert_cache

    output_2d = torch.zeros((B, H, W, C), dtype=x.dtype, device=device)

    valid_n = min(N, H * W)
    if valid_n > 0:
        y_all, x_all = cache.get_indices_on_device(H, W, device)
        y_indices = y_all[:valid_n]
        x_indices = x_all[:valid_n]

        output_2d[:, y_indices, x_indices, :] = x[:, :valid_n, :]

    return output_2d

@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda",BLOCK_SIZE=128):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device,BLOCK_SIZE=BLOCK_SIZE, _compile=True)
    return block_mask

def calculate_tflops(flops: float, time_ms: float, multiplier: int) -> float:
    return multiplier * flops * (1e3 / time_ms) / 1e12

def hilbert_rearrange(x):
    x_seq = tensor_to_gilbert_path(x)  # (B, H_img*W_img, C)
    return x_seq

def hilbert_qkv_projection(x_seq, num_heads: int):
    B, S, C = x_seq.shape
    head_dim = C // num_heads

    if not hasattr(hilbert_qkv_projection, 'qkv_weight'):
        hilbert_qkv_projection.qkv_weight = torch.randn(
            C, 3*C, device=x_seq.device, dtype=x_seq.dtype
        )

    qkv = x_seq @ hilbert_qkv_projection.qkv_weight  # (B, H_img*W_img, 3*C)
    qkv = qkv.view(B, S, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]

    return q, k, v

def prepare_hilbert_qkv(x, num_heads):
    # Hilbert reorder
    x_seq = hilbert_rearrange(x)

    # QKV projection and reshape
    q, k, v = hilbert_qkv_projection(x_seq, num_heads)

    return q, k, v

def hilbert_flex_attention_only(q, k, v, score_mod=None, block_mask=None):

    return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)


def hilbert_window_flex_attention(x, num_heads, score_mod=None, block_mask=None):

    x_seq = hilbert_rearrange(x)
    q, k, v = hilbert_qkv_projection(x_seq, num_heads)
    x_seq = hilbert_flex_attention_only(q, k, v, score_mod=score_mod, block_mask=block_mask)
    return x_seq

# Test performance of split functions
def test_split_performance():
    B, H_img, W_img, C, num_heads = 16, 64, 64, 128, 2
    BLOCK = 64  # Block size
    S = H_img * W_img

    # Prepare input
    x = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16, requires_grad=True)
    gradOut = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16)

    # Create block mask
    def block_window_1d(b, h, q_idx, kv_idx):
        return (q_idx // BLOCK) == (kv_idx // BLOCK)
    def score_mod_func(score, b, h, q_idx, kv_idx):
        rel_pos = (q_idx - kv_idx).to(score.dtype)
        return score + rel_pos

    block_mask = create_block_mask_cached(block_window_1d, 1, 1, S, S, device=x.device, BLOCK_SIZE=128)

    # Test Hilbert rearrangement only
    hilbert_rearrange_call = lambda: hilbert_rearrange(x)
    hilbert_rearrange_ms = do_bench(hilbert_rearrange_call)

    # Test QKV projection only
    x_seq = hilbert_rearrange_call()
    qkv_proj_call = lambda: hilbert_qkv_projection(x_seq, num_heads)
    qkv_proj_ms = do_bench(qkv_proj_call)

    # Test QKV preparation (combined)
    prepare_call = lambda: prepare_hilbert_qkv(x, num_heads)
    prepare_ms = do_bench(prepare_call)

    # Pre-compute QKV for pure attention test
    q, k, v = prepare_call()

    # Test pure flex attention only
    pure_attention_call = lambda: hilbert_flex_attention_only(q, k, v, score_mod=score_mod_func, block_mask=block_mask)
    pure_attention_ms = do_bench(pure_attention_call)

    # Test combined function
    combined_call = lambda: hilbert_window_flex_attention(x, num_heads, score_mod=score_mod_func, block_mask=block_mask)
    combined_ms = do_bench(combined_call)

    # Backward test
    combined_out = combined_call()
    pure_out = pure_attention_call()
    gradOut_seq = gradOut.view(B, H_img*W_img, num_heads, C // num_heads).permute(0, 2, 1, 3).contiguous()

    pure_bw_ms = do_bench(lambda: pure_out.backward(gradOut_seq, retain_graph=True))
    combined_bw_ms = do_bench(lambda: combined_out.backward(gradOut_seq, retain_graph=True))

    results = [
        ["Hilbert Rearrangement", f"{hilbert_rearrange_ms:.4f}", "-", "-", "-"],
        ["QKV Projection", f"{qkv_proj_ms:.4f}", "-", "-", "-"],
        ["QKV Preparation (Total)", f"{prepare_ms:.4f}", "-", "-", "-"],
        ["Pure Flex Attention", f"{pure_attention_ms:.4f}", "-", f"{pure_bw_ms:.4f}", "-"],
        ["Combined (Total)", f"{combined_ms:.4f}", "-", f"{combined_bw_ms:.4f}", "-"],
        ["Overhead", f"{combined_ms - pure_attention_ms:.4f}", "-", f"{combined_bw_ms - pure_bw_ms:.4f}", "-"],
    ]

    print(f"\nFunction performance test:")
    print(tabulate(results, headers=["Operation", "FW Time (ms)", "FW FLOPS (TF/s)", "BW Time (ms)", "BW FLOPS (TF/s)"], tablefmt="grid"))

    # Clean up
    del x, q, k, v, combined_out, pure_out
    torch.cuda.empty_cache()

test_split_performance()

# FlexWA

In [None]:
@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda",BLOCK_SIZE=128):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device,BLOCK_SIZE=BLOCK_SIZE, _compile=True)
    return block_mask

def calculate_tflops(flops: float, time_ms: float, multiplier: int) -> float:
    return multiplier * flops * (1e3 / time_ms) / 1e12

def prepare_hilbert_qkv(x, num_heads):

    B, H_img, W_img, C = x.shape
    S = H_img * W_img
    x_seq = x.view(B, S, C)  # Flatten the input

    # Add QKV projection
    if not hasattr(prepare_hilbert_qkv, 'qkv_weight'):
        prepare_hilbert_qkv.qkv_weight = torch.randn(
            C, 3*C, device=x_seq.device, dtype=x_seq.dtype
        )

    # QKV projection
    qkv = x_seq @ prepare_hilbert_qkv.qkv_weight  # (B, S, 3*C)
    qkv = qkv.view(B, S, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  #shape: (B, num_heads, S, C//num_heads)

    return q, k, v

def flex_attention_only(q, k, v, score_mod=None, block_mask=None):

    return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)

def window_flex_attention(x, num_heads, score_mod=None, block_mask=None):

    q, k, v = prepare_hilbert_qkv(x, num_heads)
    x_seq = flex_attention_only(q, k, v, score_mod=score_mod, block_mask=block_mask)
    return x_seq

# Test performance of split functions
def test_split_performance():
    B, H_img, W_img, C, num_heads = 16, 128, 128, 128, 2
    S = H_img * W_img

    # Prepare input
    x = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16, requires_grad=True)
    gradOut = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16)

    # Create block mask
    WINDOW = 16  # Block size
    def get_x_y(idx):
        return idx // W_img, idx % W_img 

    def swin_window_mask(b, h, q_idx, kv_idx):
        q_x, q_y = get_x_y(q_idx)
        k_x, k_y = get_x_y(kv_idx)
        same_win_row = (q_x // WINDOW) == (k_x // WINDOW)
        same_win_col = (q_y // WINDOW) == (k_y // WINDOW)
        return same_win_row & same_win_col

    def score_mod_func(score, b, h, q_idx, kv_idx):
        rel_pos = (q_idx - kv_idx).to(score.dtype)
        return score + rel_pos
    block_mask = create_block_mask_cached(swin_window_mask, 1, 1, S, S, device=x.device, BLOCK_SIZE=128)
    # Test QKV preparation only
    prepare_call = lambda: prepare_hilbert_qkv(x, num_heads)
    prepare_ms = do_bench(prepare_call)

    # Pre-compute QKV for pure attention test
    q, k, v = prepare_call()

    # Test pure flex attention only
    pure_attention_call = lambda: flex_attention_only(q, k, v, score_mod=score_mod_func, block_mask=block_mask)
    pure_attention_ms = do_bench(pure_attention_call)

    # Test combined function
    combined_call = lambda: window_flex_attention(x, num_heads,score_mod=score_mod_func, block_mask=block_mask)
    combined_ms = do_bench(combined_call)

    # Backward test
    combined_out = combined_call()
    pure_out = pure_attention_call()
    gradOut_seq = gradOut.view(B, H_img*W_img, num_heads, C // num_heads).permute(0, 2, 1, 3).contiguous()

    pure_bw_ms = do_bench(lambda: pure_out.backward(gradOut_seq, retain_graph=True))
    combined_bw_ms = do_bench(lambda: combined_out.backward(gradOut_seq, retain_graph=True))

    results = [
        ["QKV Preparation", f"{prepare_ms:.4f}", "-", "-", "-"],
        ["Pure Flex Attention", f"{pure_attention_ms:.4f}", "-", f"{pure_bw_ms:.4f}", "-"],
        ["Combined (Total)", f"{combined_ms:.4f}", "-", f"{combined_bw_ms:.4f}", "-"],
        ["Overhead", f"{combined_ms - pure_attention_ms:.4f}", "-", f"{combined_bw_ms - pure_bw_ms:.4f}", "-"],
    ]

    print(f"\nFunction performance test:")
    print(tabulate(results, headers=["Operation", "FW Time (ms)", "FW FLOPS (TF/s)", "BW Time (ms)", "BW FLOPS (TF/s)"], tablefmt="grid"))

    # Clean up
    del x, q, k, v, combined_out, pure_out
    torch.cuda.empty_cache()

test_split_performance()


分离函数性能测试:
+---------------------+----------------+-------------------+----------------+-------------------+
| Operation           |   FW Time (ms) | FW FLOPS (TF/s)   | BW Time (ms)   | BW FLOPS (TF/s)   |
| QKV Preparation     |         0.3212 | -                 | -              | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+
| Pure Flex Attention |         2.6972 | -                 | 13.1964        | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+
| Combined (Total)    |         2.9507 | -                 | 13.2315        | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+
| Overhead            |         0.2535 | -                 | 0.0351         | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+


Regular Slide Transformer Attention

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from triton.testing import do_bench

class SlideAttentionSplit(nn.Module):
    def __init__(
        self, input_resolution, dim, num_heads, ka, qkv_bias=True, qk_scale=None,
        attn_drop=0., proj_drop=0., padding_mode='zeros'):

        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim ** -0.5
        self.padding_mode = padding_mode
        self.ka = ka
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.dep_conv = nn.Conv2d(self.head_dim, self.ka*self.ka*self.head_dim, kernel_size=self.ka,
                                 bias=True, groups=self.head_dim, padding=self.ka//2, padding_mode=padding_mode)
        self.dep_conv1 = nn.Conv2d(self.head_dim, self.ka*self.ka*self.head_dim, kernel_size=self.ka,
                                  bias=True, groups=self.head_dim, padding=self.ka//2, padding_mode=padding_mode)

        self.reset_parameters()

        self.relative_position_bias_table = nn.Parameter(torch.zeros(1, self.num_heads, 1, self.ka*self.ka, 1, 1))
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=3)

    def reset_parameters(self):
        # shift initialization for group convolution
        kernel = torch.zeros(self.ka*self.ka, self.ka, self.ka)
        for i in range(self.ka*self.ka):
            kernel[i, i//self.ka, i%self.ka] = 1.
        kernel = kernel.unsqueeze(1).repeat(self.head_dim, 1, 1, 1)
        self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=False)

    def prepare_qkv(self, x):
        B, L, C = x.shape
        H, W = self.input_resolution
        x = x.view(B, H, W, C)
        qkv = self.qkv(x)

        f_conv = qkv.permute(0, 3, 1, 2).reshape(B*self.num_heads, 3*self.head_dim, H, W)

        q = (f_conv[:, :self.head_dim, :, :] * self.scale).reshape(B, self.num_heads, self.head_dim, 1, H, W)
        k = f_conv[:, self.head_dim:2*self.head_dim, :, :]
        v = f_conv[:, 2*self.head_dim:, :, :]

        k_conv = (self.dep_conv(k) + self.dep_conv1(k))
        v_conv = (self.dep_conv(v) + self.dep_conv1(v))

        k = k_conv.view(B, self.num_heads, self.ka*self.ka, self.head_dim, H, W).permute(0, 1, 3, 2, 4, 5)
        v = v_conv.view(B, self.num_heads, self.ka*self.ka, self.head_dim, H, W).permute(0, 1, 3, 2, 4, 5)

        k = k + self.relative_position_bias_table

        return q, k, v, (B, L, C)

    def compute_attention(self, q, k, v, orig_shape):

        B, L, C = orig_shape
        H, W = self.input_resolution

        attn = (q * k).sum(2, keepdim=True) # B, self.num_heads, 1, ka*ka, H, W

        attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn * v).sum(3).reshape(B, C, H, W).permute(0, 2, 3, 1).view(B, L, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def forward(self, x, mask=None):

        q, k, v, orig_shape = self.prepare_qkv(x)
        return self.compute_attention(q, k, v, orig_shape)


def simple_bench(fn, warmup=25, rep=100):
    # warmup
    for _ in range(warmup):
        fn()

    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    for _ in range(rep):
        fn()
    end.record()
    torch.cuda.synchronize()

    return start.elapsed_time(end) / rep

def test_split_slide_attention():

    B, H_img, W_img, dim, num_heads, ka = 16, 56, 56, 128, 2, 7
    S = H_img * W_img

    slide_attn = SlideAttentionSplit(
        input_resolution=(H_img, W_img),
        dim=dim,
        num_heads=num_heads,
        ka=ka
    ).cuda().to(torch.float16)

    x = torch.randn(B, S, dim, device="cuda", dtype=torch.float16)

    # test full forward time
    fw_fn = lambda: slide_attn(x.clone())
    fw_time = simple_bench(fw_fn)

    # pre-compute q,k,v for separate attention calculation
    with torch.no_grad():
        q, k, v, orig_shape = slide_attn.prepare_qkv(x)

    # test pre-processing time
    prep_fn = lambda: slide_attn.prepare_qkv(x.clone())
    prep_time = simple_bench(prep_fn)

    attn_fn = lambda: slide_attn.compute_attention(q, k, v, orig_shape)
    attn_time = simple_bench(attn_fn)

    print(f"full forward time: {fw_time:.4f} ms")
    print(f"pre-processing time: {prep_time:.4f} ms")
    print(f"attention calculation time: {attn_time:.4f} ms")
    print(f"pre-processing + attention ≈ {prep_time + attn_time:.4f} ms")

    return fw_time, prep_time, attn_time

print("start testing...")
test_split_slide_attention()

SlideAttentionSplit 类已创建，forward方法已拆分为:
开始性能测试...
完整forward时间: 111.9284 ms
预处理时间: 106.5070 ms
注意力计算时间: 5.8193 ms
预处理+注意力 ≈ 112.3262 ms
注意力计算占比: 5.2%


(111.928408203125, 106.506953125, 5.819269409179688)

# NA2D (NAT)

In [None]:
import torch
import torch.nn as nn
from natten import na2d  #install it refer to the NATTEN github
from triton.testing import do_bench
from tabulate import tabulate
from functools import lru_cache

def calculate_tflops(flops: float, time_ms: float, multiplier: int) -> float:
    """计算TFLOPS"""
    return multiplier * flops * (1e3 / time_ms) / 1e12

def prepare_natten_qkv(x, num_heads):
    """Prepare QKV tensors from 2D input for NATTEN"""
    B, H_img, W_img, C = x.shape
    head_dim = C // num_heads

    if not hasattr(prepare_natten_qkv, 'qkv_weight'):
        prepare_natten_qkv.qkv_weight = torch.randn(
            C, 3*C, device=x.device, dtype=x.dtype
        )

    qkv = x @ prepare_natten_qkv.qkv_weight  # (B, H_img, W_img, 3*C)
    qkv = qkv.reshape(B, H_img, W_img, 3, num_heads, head_dim)

    qkv = qkv.permute(3, 0, 1, 2, 4, 5).contiguous()  # (3, B, H_img, W_img, num_heads, head_dim)
    q, k, v = qkv[0], qkv[1], qkv[2]  # [B, H_img, W_img, num_heads, head_dim]

    return q, k, v

def natten_attention_only(q, k, v, kernel_size=7):
    """Execute NATTEN 2D attention on pre-processed QKV tensors

    Args:
        q, k, v: Query, Key, Value tensors, each with shape (B, H_img, W_img, num_heads, head_dim)
        kernel_size: Window size for neighborhood attention

    Returns:
        Output tensor from NATTEN, shape (B, H_img, W_img, num_heads, head_dim)
    """
    return na2d(q, k, v, kernel_size=kernel_size)

def natten_2d_combined(x, num_heads, kernel_size=7):
    """Combined function for QKV preparation and NATTEN attention"""
    q, k, v = prepare_natten_qkv(x, num_heads)
    output = natten_attention_only(q, k, v, kernel_size=kernel_size)
    return output

def test_natten_split_performance():
    """Test performance of split NATTEN functions"""

    print(" NATTEN 2D test")
    print("="*60)

    # 配置参数
    B, H_img, W_img, C, num_heads = 16, 96, 96, 128, 2
    kernel_size = 17
    S = H_img * W_img
    head_dim = C // num_heads

    print(f"\n test configuration:")
    print(f"   batch size: {B}")
    print(f"   image size: {H_img}×{W_img}")
    print(f"   sequence length: {S}")
    print(f"   feature dimension: {C}")
    print(f"   number of attention heads: {num_heads}")
    print(f"   head dimension: {head_dim}")
    print(f"   window size: {kernel_size}×{kernel_size}")

    # 准备输入
    x = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16, requires_grad=True)
    grad_out = torch.randn(B, H_img, W_img, num_heads, head_dim, device="cuda", dtype=torch.float16)

    print(f"\n start testing...")

    prepare_call = lambda: prepare_natten_qkv(x, num_heads)
    prepare_ms = do_bench(prepare_call)

    prepare_flops = B * S * C * C * 3  # QKV projection
    prepare_tflops = calculate_tflops(prepare_flops, prepare_ms, 2)

    q, k, v = prepare_call()

    pure_attention_call = lambda: natten_attention_only(q, k, v, kernel_size=kernel_size)
    pure_attention_ms = do_bench(pure_attention_call)

    # kernel_area = kernel_size * kernel_size
    # density = min(kernel_area / S, 1.0)
    # attention_flops = density * B * S * num_heads * head_dim * S
    # pure_attention_tflops = calculate_tflops(attention_flops, pure_attention_ms, 4)

    # ========== test combined function ==========
    combined_call = lambda: natten_2d_combined(x, num_heads, kernel_size=kernel_size)
    combined_ms = do_bench(combined_call)

    # combined_flops = prepare_flops + attention_flops
    # combined_tflops = calculate_tflops(combined_flops, combined_ms, 3)

    # ========== Backward ==========
    # pure attention backward
    pure_out = pure_attention_call()
    def pure_backward():
        out = pure_out.clone().requires_grad_(True)
        out.backward(grad_out, retain_graph=True)
    pure_bw_ms = do_bench(pure_backward)
    # pure_bw_tflops = calculate_tflops(attention_flops, pure_bw_ms, 10)

    # combined function backward
    combined_out = combined_call()
    def combined_backward():
        out = combined_out.clone().requires_grad_(True)
        out.backward(grad_out, retain_graph=True)
    combined_bw_ms = do_bench(combined_backward)
    # combined_bw_tflops = calculate_tflops(combined_flops, combined_bw_ms, 6)

    overhead_fw_ms = combined_ms - pure_attention_ms - prepare_ms
    overhead_bw_ms = combined_bw_ms - pure_bw_ms

    results = [
        ["QKV preparation",
         f"{prepare_ms:.3f}",
         f"{prepare_tflops:.2f}",
         "-",
         "-"],
        ["pure NATTEN attention",
         f"{pure_attention_ms:.3f}",
         f"{pure_bw_ms:.3f}"],
        ["combined function (total)",
         f"{combined_ms:.3f}",
         f"{combined_bw_ms:.3f}"],
    ]

    print(f"\n split function performance test results:")
    print(tabulate(
        results,
        headers=["operation", "forward time (ms)", "forward TFLOPS", "backward time (ms)", "backward TFLOPS"],
        tablefmt="grid"
    ))


    print(f"\n key performance indicators:")
    print(f"  forward time: {combined_ms:.3f}ms")
    print(f"  backward time: {combined_bw_ms:.3f}ms")

    # 清理
    del x, q, k, v, combined_out, pure_out, grad_out
    torch.cuda.empty_cache()

# 运行测试
if __name__ == "__main__":
    test_natten_split_performance()

🚀 NATTEN 2D 分离函数性能测试

📊 测试配置:
  批量大小: 16
  图像尺寸: 96×96
  序列长度: 9216
  特征维度: 128
  注意力头数: 2
  每头维度: 64
  窗口大小: 17×17

⚙️ 开始性能测试...

📊 分离函数性能测试结果:
+----------------+----------------+--------------+----------------+--------------+
| 操作           |   前向时间(ms) | 前向TFLOPS   | 反向时间(ms)   | 反向TFLOPS   |
| QKV准备        |          0.65  | 22.30        | -              | -            |
+----------------+----------------+--------------+----------------+--------------+
| 纯NATTEN注意力 |          1.27  | 17.18        | 9.953          | 5.48         |
+----------------+----------------+--------------+----------------+--------------+
| 组合函数(总计) |          1.905 | 20.01        | 9.915          | 7.69         |
+----------------+----------------+--------------+----------------+--------------+
| 额外开销       |         -0.016 | -            | -0.038         | -            |
+----------------+----------------+--------------+----------------+--------------+

📈 时间占比分析:
  QKV准备占前向时间: 34.1%
  NATTEN注意力占前向时间: 66.7%


# HNA (NAT)

In [None]:
import torch
import torch.nn as nn
from natten import na2d, na1d
from triton.testing import do_bench
from tabulate import tabulate
from functools import lru_cache
def sgn(x):
    return -1 if x < 0 else (1 if x > 0 else 0)

def generate2d(x: int, y: int, ax: int, ay: int, bx: int, by: int, result):
    w = abs(ax + ay)
    h = abs(bx + by)
    dax, day = sgn(ax), sgn(ay)
    dbx, dby = sgn(bx), sgn(by)

    if h == 1 or w == 1:
        if h == 1:
            for _ in range(w):
                result.append((x, y))
                x, y = x + dax, y + day
        elif w == 1:
            for _ in range(h):
                result.append((x, y))
                x, y = x + dbx, y + dby
        return

    ax2, ay2 = ax // 2, ay // 2
    bx2, by2 = bx // 2, by // 2
    w2 = abs(ax2 + ay2)
    h2 = abs(bx2 + by2)

    if 2 * w > 3 * h:
        if w2 % 2 and w > 2:
            ax2, ay2 = ax2 + dax, ay2 + day
        generate2d(x, y, ax2, ay2, bx, by, result)
        generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by, result)
    else:
        if h2 % 2 and h > 2:
            bx2, by2 = bx2 + dbx, by2 + dby
        generate2d(x, y, bx2, by2, ax2, ay2, result)
        generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2, result)
        generate2d(x + (ax - dax) + (bx2 - dbx),
                   y + (ay - day) + (by2 - dby),
                   -bx2, -by2, -(ax - ax2), -(ay - ay2), result)

def gilbert2d(width, height):
    result = []
    if width >= height:
        generate2d(0, 0, width, 0, 0, height, result)
    else:
        generate2d(0, 0, 0, height, width, 0, result)
    return result

class GilbertPathCache:
    def __init__(self):
        self.cache = {}
        self.device_index_cache = {}

    def get_or_create_path(self, H, W):
        key = (H, W)
        if key not in self.cache:
            path = gilbert2d(W, H)

            forward_map = torch.zeros((H, W), dtype=torch.long)
            reverse_map = torch.zeros((H * W, 2), dtype=torch.long)

            for idx, (x, y) in enumerate(path[:H*W]):
                if y < H and x < W:
                    forward_map[y, x] = idx
                    reverse_map[idx, 0] = y
                    reverse_map[idx, 1] = x

            self.cache[key] = {
                'path': path,
                'forward_map': forward_map,
                'reverse_map': reverse_map,
                'y_indices': reverse_map[:, 0].clone(),
                'x_indices': reverse_map[:, 1].clone(),
                'H': H,
                'W': W
            }

        return self.cache[key]

    def get_indices_on_device(self, H, W, device):
        device_key = (H, W, str(device))
        if device_key in self.device_index_cache:
            return self.device_index_cache[device_key]
        info = self.get_or_create_path(H, W)
        y_dev = info['y_indices'].to(device)
        x_dev = info['x_indices'].to(device)
        self.device_index_cache[device_key] = (y_dev, x_dev)
        return y_dev, x_dev

    def precompute_paths(self, resolutions):
        for H, W in resolutions:
            self.get_or_create_path(H, W)

    def clear_cache(self):
        self.cache.clear()

_global_gilbert_cache = GilbertPathCache()

def tensor_to_gilbert_path(x, cache=None):
    """
    Args:
        x: Input tensor, shape (B, H, W, C)
        cache: Optional GilbertPathCache instance, use global cache if None
    Returns:
        Reordered tensor, shape (B, H*W, C)
    """
    B, H, W, C = x.shape
    device = x.device
    if cache is None:
        cache = _global_gilbert_cache

    y_indices, x_indices = cache.get_indices_on_device(H, W, device)
    gilbert_tensor = x[:, y_indices, x_indices, :]  # (B, H*W, C)

    return gilbert_tensor

def gilbert_tensor_to_2d(x, H, W, cache=None):
    """
    Args:
        x: Gilbert sequence tensor, shape (B, H*W, C)
        H: Target height
        W: Target width
        cache: Optional GilbertPathCache instance, use global cache if None
    Returns:
        2D layout tensor, shape (B, H, W, C)
    """
    B, N, C = x.shape
    device = x.device

    if cache is None:
        cache = _global_gilbert_cache

    output_2d = torch.zeros((B, H, W, C), dtype=x.dtype, device=device)

    valid_n = min(N, H * W)
    if valid_n > 0:
        y_all, x_all = cache.get_indices_on_device(H, W, device)
        y_indices = y_all[:valid_n]
        x_indices = x_all[:valid_n]

        output_2d[:, y_indices, x_indices, :] = x[:, :valid_n, :]

    return output_2d
def calculate_tflops(flops: float, time_ms: float, multiplier: int) -> float:
    """计算TFLOPS"""
    return multiplier * flops * (1e3 / time_ms) / 1e12

def natten_hilbert_rearrange(x):

    x_seq = tensor_to_gilbert_path(x)  # (B, H_img*W_img, C)
    return x_seq

def natten_qkv_projection(x_seq, num_heads: int):

    B, S, C = x_seq.shape
    head_dim = C // num_heads

    if not hasattr(natten_qkv_projection, 'qkv_weight'):
        natten_qkv_projection.qkv_weight = torch.randn(
            C, 3*C, device=x_seq.device, dtype=x_seq.dtype
        )

    # QKV projection
    qkv = x_seq @ natten_qkv_projection.qkv_weight  # (B, S, 3*C)
    qkv = qkv.view(B, S, 3, num_heads, head_dim).permute(2, 0, 1, 3, 4).contiguous()  # (3, B, S, H, D)
    q, k, v = qkv[0], qkv[1], qkv[2]  

    return q, k, v

def prepare_natten_qkv_with_hilbert(x, num_heads):

    x_seq = natten_hilbert_rearrange(x)

    q, k, v = natten_qkv_projection(x_seq, num_heads)

    return q, k, v

def natten_attention_only(q, k, v, kernel_size=7):

    return na1d(q, k, v, kernel_size=kernel_size)

def natten_1d_combined(x, num_heads, kernel_size=7):

    x_seq = natten_hilbert_rearrange(x)
    q, k, v = natten_qkv_projection(x_seq, num_heads)
    output = natten_attention_only(q, k, v, kernel_size=kernel_size)
    return output

def test_natten_split_performance():
    """Test performance of split NATTEN functions"""

    # 配置参数
    B, H_img, W_img, C, num_heads = 16, 96, 96, 128, 2
    kernel_size = 289
    S = H_img * W_img
    head_dim = C // num_heads

    print(f"\n Test configuration:")
    print(f"   batch size: {B}")
    print(f"   image size: {H_img}×{W_img}")
    print(f"   sequence length: {S}")
    print(f"   feature dimension: {C}")
    print(f"   number of attention heads: {num_heads}")
    print(f"   head dimension: {head_dim}")
    print(f"   window size: {kernel_size}×{kernel_size}")

    # 准备输入
    x = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16, requires_grad=True)
    grad_out = torch.randn(B, S, num_heads, head_dim, device="cuda", dtype=torch.float16)

    print(f"\n start testing...")

    hilbert_rearrange_call = lambda: natten_hilbert_rearrange(x)
    hilbert_rearrange_ms = do_bench(hilbert_rearrange_call)


    x_seq = hilbert_rearrange_call()
    qkv_proj_call = lambda: natten_qkv_projection(x_seq, num_heads)
    qkv_proj_ms = do_bench(qkv_proj_call)


    prepare_call = lambda: prepare_natten_qkv_with_hilbert(x, num_heads)
    prepare_ms = do_bench(prepare_call)

    print(f"\n Hilbert reorder:")
    print(f"   Time: {hilbert_rearrange_ms:.4f}ms")
    print(f"\n QKV projection:")
    print(f"   Time: {qkv_proj_ms:.4f}ms")
    print(f"\n Hilbert sequence + QKV preparation:")
    print(f"   Time: {prepare_ms:.4f}ms")

    q, k, v = prepare_call()

    pure_attention_call = lambda: natten_attention_only(q, k, v, kernel_size=kernel_size)
    pure_attention_ms = do_bench(pure_attention_call)


    print(f"\n Pure NATTEN 1D attention:")
    print(f"   Forward time: {pure_attention_ms:.4f}ms")

    combined_call = lambda: natten_1d_combined(x, num_heads, kernel_size=kernel_size)
    combined_ms = do_bench(combined_call)

    print(f"\n Combined function (Hilbert + QKV + NATTEN):")
    print(f"   Total time: {combined_ms:.4f}ms")

    pure_out = pure_attention_call()
    def pure_backward():
        out = pure_out.clone().requires_grad_(True)
        out.backward(grad_out, retain_graph=True)
    pure_bw_ms = do_bench(pure_backward)

    combined_out = combined_call()
    def combined_backward():
        out = combined_out.clone().requires_grad_(True)
        out.backward(grad_out, retain_graph=True)
    combined_bw_ms = do_bench(combined_backward)

    results = [
        ["Hilbert reorder",
         f"{hilbert_rearrange_ms:.4f}",
         "-",
         "-",
         "-"],
        ["QKV projection",
         f"{qkv_proj_ms:.4f}",
         "-",
         "-",
         "-"],
        ["Hilbert+QKV preparation",
         f"{prepare_ms:.4f}",
         "-",
         "-",
         "-"],
        ["Pure NATTEN 1D",
         f"{pure_attention_ms:.4f}",
         f"{pure_bw_ms:.4f}",],
        ["Combined function",
         f"{combined_ms:.4f}",
         f"{combined_bw_ms:.4f}",],
    ]

    print(f"\n Performance test results:")
    print(tabulate(
        results,
        headers=["Operation", "Forward (ms)", "Backward (ms)"],
        tablefmt="grid"
    ))


    # clear memory  
    del x, q, k, v, combined_out, pure_out, grad_out
    torch.cuda.empty_cache()

if __name__ == "__main__":
    test_natten_split_performance()


�� NATTEN 2D 分离函数性能测试

�� 测试配置:
  批量大小: 16
  图像尺寸: 96×96
  序列长度: 9216
  特征维度: 128
  注意力头数: 2
  每头维度: 64
  窗口大小: 289×289

⚙️ 开始性能测试...

1️⃣ Hilbert重排:
   时间: 0.2201ms

2️⃣ QKV投影:
   时间: 0.6507ms

3️⃣ Hilbert序列 + QKV准备:
   时间: 0.8679ms

4️⃣ 纯NATTEN 1D注意力:
   前向时间: 0.7749ms
   前向TFLOPS: 28.16
   稀疏密度: 3.14%

5️⃣ 组合函数 (Hilbert + QKV + NATTEN):
   总时间: 1.6144ms
   TFLOPS: 13.51

📊 性能测试结果:
+-----------------+------------+------------+------------+------------+
| 操作            |   前向(ms) | 前向TF/s   | 反向(ms)   | 反向TF/s   |
| Hilbert重排     |     0.2201 | -          | -          | -          |
+-----------------+------------+------------+------------+------------+
| QKV投影         |     0.6507 | -          | -          | -          |
+-----------------+------------+------------+------------+------------+
| Hilbert+QKV准备 |     0.8679 | -          | -          | -          |
+-----------------+------------+------------+------------+------------+
| 纯NATTEN 1D     |     0.7749 | 28.16      | 6.6319  

# HSA

In [None]:
def sgn(x):
    return -1 if x < 0 else (1 if x > 0 else 0)

def generate2d(x: int, y: int, ax: int, ay: int, bx: int, by: int, result):
    w = abs(ax + ay)
    h = abs(bx + by)
    dax, day = sgn(ax), sgn(ay)
    dbx, dby = sgn(bx), sgn(by)

    if h == 1 or w == 1:
        if h == 1:
            for _ in range(w):
                result.append((x, y))
                x, y = x + dax, y + day
        elif w == 1:
            for _ in range(h):
                result.append((x, y))
                x, y = x + dbx, y + dby
        return

    ax2, ay2 = ax // 2, ay // 2
    bx2, by2 = bx // 2, by // 2
    w2 = abs(ax2 + ay2)
    h2 = abs(bx2 + by2)

    if 2 * w > 3 * h:
        if w2 % 2 and w > 2:
            ax2, ay2 = ax2 + dax, ay2 + day
        generate2d(x, y, ax2, ay2, bx, by, result)
        generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by, result)
    else:
        if h2 % 2 and h > 2:
            bx2, by2 = bx2 + dbx, by2 + dby
        generate2d(x, y, bx2, by2, ax2, ay2, result)
        generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2, result)
        generate2d(x + (ax - dax) + (bx2 - dbx),
                   y + (ay - day) + (by2 - dby),
                   -bx2, -by2, -(ax - ax2), -(ay - ay2), result)

def gilbert2d(width, height):
    result = []
    if width >= height:
        generate2d(0, 0, width, 0, 0, height, result)
    else:
        generate2d(0, 0, 0, height, width, 0, result)
    return result

class GilbertPathCache:
    def __init__(self):
        self.cache = {}
        self.device_index_cache = {}

    def get_or_create_path(self, H, W):
        key = (H, W)
        if key not in self.cache:
            path = gilbert2d(W, H)

            forward_map = torch.zeros((H, W), dtype=torch.long)
            reverse_map = torch.zeros((H * W, 2), dtype=torch.long)

            for idx, (x, y) in enumerate(path[:H*W]):
                if y < H and x < W:
                    forward_map[y, x] = idx
                    reverse_map[idx, 0] = y
                    reverse_map[idx, 1] = x

            self.cache[key] = {
                'path': path,
                'forward_map': forward_map,
                'reverse_map': reverse_map,
                'y_indices': reverse_map[:, 0].clone(),
                'x_indices': reverse_map[:, 1].clone(),
                'H': H,
                'W': W
            }

        return self.cache[key]

    def get_indices_on_device(self, H, W, device):
        device_key = (H, W, str(device))
        if device_key in self.device_index_cache:
            return self.device_index_cache[device_key]
        info = self.get_or_create_path(H, W)
        y_dev = info['y_indices'].to(device)
        x_dev = info['x_indices'].to(device)
        self.device_index_cache[device_key] = (y_dev, x_dev)
        return y_dev, x_dev

    def precompute_paths(self, resolutions):
        for H, W in resolutions:
            self.get_or_create_path(H, W)

    def clear_cache(self):
        self.cache.clear()

_global_gilbert_cache = GilbertPathCache()

def tensor_to_gilbert_path(x, cache=None):
    """
    Args:
        x: Input tensor, shape (B, H, W, C)
        cache: Optional GilbertPathCache instance, use global cache if None
    Returns:
        Reordered tensor, shape (B, H*W, C)
    """
    B, H, W, C = x.shape
    device = x.device
    if cache is None:
        cache = _global_gilbert_cache

    y_indices, x_indices = cache.get_indices_on_device(H, W, device)
    gilbert_tensor = x[:, y_indices, x_indices, :]  # (B, H*W, C)

    return gilbert_tensor

def gilbert_tensor_to_2d(x, H, W, cache=None):
    """
    Args:
        x: Gilbert sequence tensor, shape (B, H*W, C)
        H: Target height
        W: Target width
        cache: Optional GilbertPathCache instance, use global cache if None
    Returns:
        2D layout tensor, shape (B, H, W, C)
    """
    B, N, C = x.shape
    device = x.device

    if cache is None:
        cache = _global_gilbert_cache

    output_2d = torch.zeros((B, H, W, C), dtype=x.dtype, device=device)

    valid_n = min(N, H * W)
    if valid_n > 0:
        y_all, x_all = cache.get_indices_on_device(H, W, device)
        y_indices = y_all[:valid_n]
        x_indices = x_all[:valid_n]

        output_2d[:, y_indices, x_indices, :] = x[:, :valid_n, :]

    return output_2d

@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda",BLOCK_SIZE=128):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device,BLOCK_SIZE=BLOCK_SIZE, _compile=True)
    return block_mask

def calculate_tflops(flops: float, time_ms: float, multiplier: int) -> float:
    return multiplier * flops * (1e3 / time_ms) / 1e12

def hilbert_rearrange(x):

    x_seq = tensor_to_gilbert_path(x)  # (B, H_img*W_img, C)
    return x_seq

def hilbert_qkv_projection(x_seq, num_heads: int):
    B, S, C = x_seq.shape
    head_dim = C // num_heads

    if not hasattr(hilbert_qkv_projection, 'qkv_weight'):
        hilbert_qkv_projection.qkv_weight = torch.randn(
            C, 3*C, device=x_seq.device, dtype=x_seq.dtype
        )

    qkv = x_seq @ hilbert_qkv_projection.qkv_weight  # (B, H_img*W_img, 3*C)
    qkv = qkv.view(B, S, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # 每个形状: (B, num_heads, H_img*W_img, C//num_heads)

    return q, k, v

def prepare_hilbert_qkv(x, num_heads):
    x_seq = hilbert_rearrange(x)

    q, k, v = hilbert_qkv_projection(x_seq, num_heads)

    return q, k, v

def hilbert_flex_attention_only(q, k, v, score_mod=None, block_mask=None):

    return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)

def hilbert_window_flex_attention(x, num_heads, score_mod=None, block_mask=None):

    x_seq = hilbert_rearrange(x)
    q, k, v = hilbert_qkv_projection(x_seq, num_heads)
    x_seq = hilbert_flex_attention_only(q, k, v, score_mod=score_mod, block_mask=block_mask)
    return x_seq

def test_split_performance():
    B, H_img, W_img, C, num_heads = 16, 96, 96, 128, 2
    WINDOW_SIZE = 121 
    S = H_img * W_img

    # Prepare input
    x = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16, requires_grad=True)
    gradOut = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16)

    # Create block mask
    def sliding_window_mask(b, h, q_idx, kv_idx):
        return (q_idx - kv_idx).abs() <= WINDOW_SIZE // 2
    def score_mod_func(score, b, h, q_idx, kv_idx):
        rel_pos = (q_idx - kv_idx).to(score.dtype)
        return score + rel_pos
    block_mask = create_block_mask_cached(sliding_window_mask, 1, 1, S, S, device=x.device, BLOCK_SIZE=128)

    # Test Hilbert rearrangement only
    hilbert_rearrange_call = lambda: hilbert_rearrange(x)
    hilbert_rearrange_ms = do_bench(hilbert_rearrange_call)

    # Test QKV projection only
    x_seq = hilbert_rearrange_call()
    qkv_proj_call = lambda: hilbert_qkv_projection(x_seq, num_heads)
    qkv_proj_ms = do_bench(qkv_proj_call)

    # Test QKV preparation (combined)
    prepare_call = lambda: prepare_hilbert_qkv(x, num_heads)
    prepare_ms = do_bench(prepare_call)

    # Pre-compute QKV for pure attention test
    q, k, v = prepare_call()

    # Test pure flex attention only
    pure_attention_call = lambda: hilbert_flex_attention_only(q, k, v, score_mod=score_mod_func, block_mask=block_mask)
    pure_attention_ms = do_bench(pure_attention_call)

    # Test combined function
    combined_call = lambda: hilbert_window_flex_attention(x, num_heads, score_mod=score_mod_func, block_mask=block_mask)
    combined_ms = do_bench(combined_call)

    # Backward test
    combined_out = combined_call()
    pure_out = pure_attention_call()
    gradOut_seq = gradOut.view(B, H_img*W_img, num_heads, C // num_heads).permute(0, 2, 1, 3).contiguous()

    pure_bw_ms = do_bench(lambda: pure_out.backward(gradOut_seq, retain_graph=True))
    combined_bw_ms = do_bench(lambda: combined_out.backward(gradOut_seq, retain_graph=True))

    results = [
        ["Hilbert Rearrangement", f"{hilbert_rearrange_ms:.4f}", "-", "-", "-"],
        ["QKV Projection", f"{qkv_proj_ms:.4f}", "-", "-", "-"],
        ["QKV Preparation (Total)", f"{prepare_ms:.4f}", "-", "-", "-"],
        ["Pure Flex Attention", f"{pure_attention_ms:.4f}", "-", f"{pure_bw_ms:.4f}", "-"],
        ["Combined (Total)", f"{combined_ms:.4f}", "-", f"{combined_bw_ms:.4f}", "-"],
    ]

    print(f"\nTest split performance:")
    print(tabulate(results, headers=["Operation", "FW Time (ms)", "BW Time (ms)"], tablefmt="grid"))

    # Clean up
    del x, q, k, v, combined_out, pure_out
    torch.cuda.empty_cache()

test_split_performance()


分离函数性能测试:
+-------------------------+----------------+-------------------+----------------+-------------------+
| Operation               |   FW Time (ms) | FW FLOPS (TF/s)   | BW Time (ms)   | BW FLOPS (TF/s)   |
| Hilbert Rearrangement   |         0.2395 | -                 | -              | -                 |
+-------------------------+----------------+-------------------+----------------+-------------------+
| QKV Projection          |         0.1892 | -                 | -              | -                 |
+-------------------------+----------------+-------------------+----------------+-------------------+
| QKV Preparation (Total) |         0.3507 | -                 | -              | -                 |
+-------------------------+----------------+-------------------+----------------+-------------------+
| Pure Flex Attention     |         0.3017 | -                 | 3.0607         | -                 |
+-------------------------+----------------+-------------------+-------

# FlexSA

In [None]:
@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda",BLOCK_SIZE=128):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device,BLOCK_SIZE=BLOCK_SIZE, _compile=True)
    return block_mask

def calculate_tflops(flops: float, time_ms: float, multiplier: int) -> float:
    return multiplier * flops * (1e3 / time_ms) / 1e12

def prepare_hilbert_qkv(x, num_heads):

    B, H_img, W_img, C = x.shape
    S = H_img * W_img
    x_seq = x.view(B, S, C) 

    if not hasattr(prepare_hilbert_qkv, 'qkv_weight'):
        prepare_hilbert_qkv.qkv_weight = torch.randn(
            C, 3*C, device=x_seq.device, dtype=x_seq.dtype
        )

    # QKV projection
    qkv = x_seq @ prepare_hilbert_qkv.qkv_weight  # (B, S, 3*C)
    qkv = qkv.view(B, S, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2] 

    return q, k, v

def flex_attention_only(q, k, v, score_mod=None, block_mask=None):

    return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)

def window_flex_attention(x, num_heads, score_mod=None, block_mask=None):
    """Original combined function for backward compatibility"""
    q, k, v = prepare_hilbert_qkv(x, num_heads)
    x_seq = flex_attention_only(q, k, v, score_mod=score_mod, block_mask=block_mask)
    return x_seq

# Test performance of split functions
def test_split_performance():
    B, H_img, W_img, C, num_heads = 16, 128, 128, 128, 2
    S = H_img * W_img
    WINDOW = 17

    # Prepare input
    x = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16, requires_grad=True)
    gradOut = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16)

    # Create block mask
    def sasa_mask(b, h, q_idx, kv_idx):

        def get_x_y(idx):
            return idx // W_img, idx % W_img

        q_x, q_y = get_x_y(q_idx)
        kv_x, kv_y = get_x_y(kv_idx)
        horizontal_mask = (q_x - kv_x).abs() <= WINDOW // 2
        vertical_mask = (q_y - kv_y).abs() <= WINDOW // 2
        return horizontal_mask & vertical_mask

    def score_mod_func(score, b, h, q_idx, kv_idx):
        rel_pos = (q_idx - kv_idx).to(score.dtype)
        return score + rel_pos

    block_mask = create_block_mask_cached(sasa_mask, 1, 1, S, S, device=x.device, BLOCK_SIZE=128)
    # Test QKV preparation only
    prepare_call = lambda: prepare_hilbert_qkv(x, num_heads)
    prepare_ms = do_bench(prepare_call)

    # Pre-compute QKV for pure attention test
    q, k, v = prepare_call()

    # Test pure flex attention only
    pure_attention_call = lambda: flex_attention_only(q, k, v, score_mod=score_mod_func, block_mask=block_mask)
    pure_attention_ms = do_bench(pure_attention_call)

    # Test combined function
    combined_call = lambda: window_flex_attention(x, num_heads,score_mod=score_mod_func, block_mask=block_mask)
    combined_ms = do_bench(combined_call)

    # Backward test
    combined_out = combined_call()
    pure_out = pure_attention_call()
    gradOut_seq = gradOut.view(B, H_img*W_img, num_heads, C // num_heads).permute(0, 2, 1, 3).contiguous()

    pure_bw_ms = do_bench(lambda: pure_out.backward(gradOut_seq, retain_graph=True))
    combined_bw_ms = do_bench(lambda: combined_out.backward(gradOut_seq, retain_graph=True))

    results = [
        ["QKV Preparation", f"{prepare_ms:.4f}", "-", "-", "-"],
        ["Pure Flex Attention", f"{pure_attention_ms:.4f}", "-", f"{pure_bw_ms:.4f}", "-"],
        ["Combined (Total)", f"{combined_ms:.4f}", "-", f"{combined_bw_ms:.4f}", "-"],
    ]

    print(f"\nTest split performance:")
    print(tabulate(results, headers=["Operation", "FW Time (ms)", "BW Time (ms)"], tablefmt="grid"))

    # Clean up
    del x, q, k, v, combined_out, pure_out
    torch.cuda.empty_cache()

test_split_performance()


分离函数性能测试:
+---------------------+----------------+-------------------+----------------+-------------------+
| Operation           |   FW Time (ms) | FW FLOPS (TF/s)   | BW Time (ms)   | BW FLOPS (TF/s)   |
| QKV Preparation     |         0.3211 | -                 | -              | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+
| Pure Flex Attention |         4.123  | -                 | 13.8566        | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+
| Combined (Total)    |         4.2111 | -                 | 13.6113        | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+
| Overhead            |         0.0881 | -                 | -0.2453        | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+


# NA2D (Flex)

In [None]:
@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda",BLOCK_SIZE=128):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device,BLOCK_SIZE=BLOCK_SIZE, _compile=True)
    return block_mask

def calculate_tflops(flops: float, time_ms: float, multiplier: int) -> float:
    return multiplier * flops * (1e3 / time_ms) / 1e12

def prepare_hilbert_qkv(x, num_heads):

    B, H_img, W_img, C = x.shape
    S = H_img * W_img
    x_seq = x.view(B, S, C) 

    if not hasattr(prepare_hilbert_qkv, 'qkv_weight'):
        prepare_hilbert_qkv.qkv_weight = torch.randn(
            C, 3*C, device=x_seq.device, dtype=x_seq.dtype
        )

    # QKV projection
    qkv = x_seq @ prepare_hilbert_qkv.qkv_weight  # (B, S, 3*C)
    qkv = qkv.view(B, S, 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]

    return q, k, v

def flex_attention_only(q, k, v, score_mod=None, block_mask=None):

    return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)

def window_flex_attention(x, num_heads, score_mod=None, block_mask=None):

    q, k, v = prepare_hilbert_qkv(x, num_heads)
    x_seq = flex_attention_only(q, k, v, score_mod=score_mod, block_mask=block_mask)
    return x_seq

# Test performance of split functions
def test_split_performance():
    B, H_img, W_img, C, num_heads = 16, 56, 56, 128, 2
    S = H_img * W_img

    # Prepare input
    x = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16, requires_grad=True)
    gradOut = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16)

    # Create block mask
    K_W = 7
    K_H = 7
    def get_x_y(idx):
        return idx // W_img, idx % W_img


    def natten_mask(
        b,
        h,
        q_idx,
        kv_idx,
    ):
        q_x, q_y = get_x_y(q_idx)
        kv_x, kv_y = get_x_y(kv_idx)

        kernel_x = q_x.clamp(K_W // 2, (W_img - 1) - K_W // 2)
        kernel_y = q_y.clamp(K_H // 2, (H_img - 1) - K_H // 2)
        hori_mask = (kernel_x - kv_x).abs() <= K_W // 2
        vert_mask = (kernel_y - kv_y).abs() <= K_H // 2
        return hori_mask & vert_mask

    def score_mod_func(score, b, h, q_idx, kv_idx):
        rel_pos = (q_idx - kv_idx).to(score.dtype)
        return score + rel_pos

    block_mask = create_block_mask_cached(natten_mask, 1, 1, S, S, device=x.device, BLOCK_SIZE=128)
    # Test QKV preparation only
    prepare_call = lambda: prepare_hilbert_qkv(x, num_heads)
    prepare_ms = do_bench(prepare_call)

    # Pre-compute QKV for pure attention test
    q, k, v = prepare_call()

    # Test pure flex attention only
    pure_attention_call = lambda: flex_attention_only(q, k, v, score_mod=score_mod_func, block_mask=block_mask)
    pure_attention_ms = do_bench(pure_attention_call)

    # Test combined function
    combined_call = lambda: window_flex_attention(x, num_heads,score_mod=score_mod_func, block_mask=block_mask)
    combined_ms = do_bench(combined_call)

    # Backward test
    combined_out = combined_call()
    pure_out = pure_attention_call()
    gradOut_seq = gradOut.view(B, H_img*W_img, num_heads, C // num_heads).permute(0, 2, 1, 3).contiguous()

    pure_bw_ms = do_bench(lambda: pure_out.backward(gradOut_seq, retain_graph=True))
    combined_bw_ms = do_bench(lambda: combined_out.backward(gradOut_seq, retain_graph=True))

    results = [
        ["QKV Preparation", f"{prepare_ms:.4f}", "-", "-", "-"],
        ["Pure Flex Attention", f"{pure_attention_ms:.4f}", "-", f"{pure_bw_ms:.4f}", "-"],
        ["Combined (Total)", f"{combined_ms:.4f}", "-", f"{combined_bw_ms:.4f}", "-"],
    ]

    print(f"\nTest split performance:")
    print(tabulate(results, headers=["Operation", "FW Time (ms)", "BW Time (ms)"], tablefmt="grid"))

    # Clean up
    del x, q, k, v, combined_out, pure_out
    torch.cuda.empty_cache()

test_split_performance()


分离函数性能测试:
+---------------------+----------------+-------------------+----------------+-------------------+
| Operation           |   FW Time (ms) | FW FLOPS (TF/s)   | BW Time (ms)   | BW FLOPS (TF/s)   |
| QKV Preparation     |         0.0769 | -                 | -              | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+
| Pure Flex Attention |         0.3365 | -                 | 1.5804         | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+
| Combined (Total)    |         0.4002 | -                 | 1.5789         | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+
| Overhead            |         0.0637 | -                 | -0.0015        | -                 |
+---------------------+----------------+-------------------+----------------+-------------------+


# HNA (Flex)

In [None]:
def sgn(x):
    return -1 if x < 0 else (1 if x > 0 else 0)

def generate2d(x: int, y: int, ax: int, ay: int, bx: int, by: int, result):
    w = abs(ax + ay)
    h = abs(bx + by)
    dax, day = sgn(ax), sgn(ay)
    dbx, dby = sgn(bx), sgn(by)

    if h == 1 or w == 1:
        if h == 1:
            for _ in range(w):
                result.append((x, y))
                x, y = x + dax, y + day
        elif w == 1:
            for _ in range(h):
                result.append((x, y))
                x, y = x + dbx, y + dby
        return

    ax2, ay2 = ax // 2, ay // 2
    bx2, by2 = bx // 2, by // 2
    w2 = abs(ax2 + ay2)
    h2 = abs(bx2 + by2)

    if 2 * w > 3 * h:
        if w2 % 2 and w > 2:
            ax2, ay2 = ax2 + dax, ay2 + day
        generate2d(x, y, ax2, ay2, bx, by, result)
        generate2d(x + ax2, y + ay2, ax - ax2, ay - ay2, bx, by, result)
    else:
        if h2 % 2 and h > 2:
            bx2, by2 = bx2 + dbx, by2 + dby
        generate2d(x, y, bx2, by2, ax2, ay2, result)
        generate2d(x + bx2, y + by2, ax, ay, bx - bx2, by - by2, result)
        generate2d(x + (ax - dax) + (bx2 - dbx),
                   y + (ay - day) + (by2 - dby),
                   -bx2, -by2, -(ax - ax2), -(ay - ay2), result)

def gilbert2d(width, height):
    result = []
    if width >= height:
        generate2d(0, 0, width, 0, 0, height, result)
    else:
        generate2d(0, 0, 0, height, width, 0, result)
    return result

class GilbertPathCache:
    def __init__(self):
        self.cache = {}
        self.device_index_cache = {}

    def get_or_create_path(self, H, W):
        key = (H, W)
        if key not in self.cache:
            path = gilbert2d(W, H)

            forward_map = torch.zeros((H, W), dtype=torch.long)
            reverse_map = torch.zeros((H * W, 2), dtype=torch.long)

            for idx, (x, y) in enumerate(path[:H*W]):
                if y < H and x < W:
                    forward_map[y, x] = idx
                    reverse_map[idx, 0] = y
                    reverse_map[idx, 1] = x

            self.cache[key] = {
                'path': path,
                'forward_map': forward_map,
                'reverse_map': reverse_map,
                'y_indices': reverse_map[:, 0].clone(),
                'x_indices': reverse_map[:, 1].clone(),
                'H': H,
                'W': W
            }

        return self.cache[key]

    def get_indices_on_device(self, H, W, device):
        device_key = (H, W, str(device))
        if device_key in self.device_index_cache:
            return self.device_index_cache[device_key]
        info = self.get_or_create_path(H, W)
        y_dev = info['y_indices'].to(device)
        x_dev = info['x_indices'].to(device)
        self.device_index_cache[device_key] = (y_dev, x_dev)
        return y_dev, x_dev

    def precompute_paths(self, resolutions):
        for H, W in resolutions:
            self.get_or_create_path(H, W)

    def clear_cache(self):
        self.cache.clear()

_global_gilbert_cache = GilbertPathCache()

def tensor_to_gilbert_path(x, cache=None):
    """
    Args:
        x: Input tensor, shape (B, H, W, C)
        cache: Optional GilbertPathCache instance, use global cache if None
    Returns:
        Reordered tensor, shape (B, H*W, C)
    """
    B, H, W, C = x.shape
    device = x.device
    if cache is None:
        cache = _global_gilbert_cache

    y_indices, x_indices = cache.get_indices_on_device(H, W, device)
    gilbert_tensor = x[:, y_indices, x_indices, :]  # (B, H*W, C)

    return gilbert_tensor

def gilbert_tensor_to_2d(x, H, W, cache=None):
    """
    Args:
        x: Gilbert sequence tensor, shape (B, H*W, C)
        H: Target height
        W: Target width
        cache: Optional GilbertPathCache instance, use global cache if None
    Returns:
        2D layout tensor, shape (B, H, W, C)
    """
    B, N, C = x.shape
    device = x.device

    if cache is None:
        cache = _global_gilbert_cache

    output_2d = torch.zeros((B, H, W, C), dtype=x.dtype, device=device)

    valid_n = min(N, H * W)
    if valid_n > 0:
        y_all, x_all = cache.get_indices_on_device(H, W, device)
        y_indices = y_all[:valid_n]
        x_indices = x_all[:valid_n]

        output_2d[:, y_indices, x_indices, :] = x[:, :valid_n, :]

    return output_2d

@lru_cache
def create_block_mask_cached(score_mod, B, H, M, N, device="cuda",BLOCK_SIZE=128):
    block_mask = create_block_mask(score_mod, B, H, M, N, device=device,BLOCK_SIZE=BLOCK_SIZE, _compile=True)
    return block_mask

def calculate_tflops(flops: float, time_ms: float, multiplier: int) -> float:
    return multiplier * flops * (1e3 / time_ms) / 1e12

def hilbert_rearrange(x):

    x_seq = tensor_to_gilbert_path(x)  # (B, H_img*W_img, C)
    return x_seq

def hilbert_qkv_projection(x_seq, num_heads: int):

    B, S, C = x_seq.shape
    head_dim = C // num_heads

    if not hasattr(hilbert_qkv_projection, 'qkv_weight'):
        hilbert_qkv_projection.qkv_weight = torch.randn(
            C, 3*C, device=x_seq.device, dtype=x_seq.dtype
        )

    # QKV projection
    qkv = x_seq @ hilbert_qkv_projection.qkv_weight  # (B, H_img*W_img, 3*C)
    qkv = qkv.view(B, S, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # 每个形状: (B, num_heads, H_img*W_img, C//num_heads)

    return q, k, v

def prepare_hilbert_qkv(x, num_heads):
    """Prepare QKV tensors from 2D input using Hilbert curve ordering"""
    x_seq = hilbert_rearrange(x)

    q, k, v = hilbert_qkv_projection(x_seq, num_heads)

    return q, k, v

def hilbert_flex_attention_only(q, k, v, score_mod=None, block_mask=None):

    return flex_attention(q, k, v, score_mod=score_mod, block_mask=block_mask)

def hilbert_window_flex_attention(x, num_heads, score_mod=None, block_mask=None):
    x_seq = hilbert_rearrange(x)
    q, k, v = hilbert_qkv_projection(x_seq, num_heads)
    x_seq = hilbert_flex_attention_only(q, k, v, score_mod=score_mod, block_mask=block_mask)
    return x_seq

# Test performance of split functions
def test_split_performance():
    B, H_img, W_img, C, num_heads = 16, 56, 56, 128, 2
    BLOCK = 49 
    S = H_img * W_img

    # Prepare input
    x = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16, requires_grad=True)
    gradOut = torch.randn(B, H_img, W_img, C, device="cuda", dtype=torch.float16)

    # Create block mask
    def natten_1d_mask(b, h, q_idx, kv_idx, KERNEL_SIZE=289, SEQ_LEN=16384):
        kernel_center = q_idx.clamp(KERNEL_SIZE // 2, SEQ_LEN - 1 - KERNEL_SIZE // 2)
        return (kernel_center - kv_idx).abs() <= KERNEL_SIZE // 2 
    def score_mod_func(score, b, h, q_idx, kv_idx):
        rel_pos = (q_idx - kv_idx).to(score.dtype)
        return score + rel_pos
    block_mask = create_block_mask_cached(natten_1d_mask, 1, 1, S, S, device=x.device, BLOCK_SIZE=128)

    # Test Hilbert rearrangement only
    hilbert_rearrange_call = lambda: hilbert_rearrange(x)
    hilbert_rearrange_ms = do_bench(hilbert_rearrange_call)

    # Test QKV projection only
    x_seq = hilbert_rearrange_call()
    qkv_proj_call = lambda: hilbert_qkv_projection(x_seq, num_heads)
    qkv_proj_ms = do_bench(qkv_proj_call)

    # Test QKV preparation (combined)
    prepare_call = lambda: prepare_hilbert_qkv(x, num_heads)
    prepare_ms = do_bench(prepare_call)

    # Pre-compute QKV for pure attention test
    q, k, v = prepare_call()

    # Test pure flex attention only
    pure_attention_call = lambda: hilbert_flex_attention_only(q, k, v, score_mod=score_mod_func, block_mask=block_mask)
    pure_attention_ms = do_bench(pure_attention_call)

    # Test combined function
    combined_call = lambda: hilbert_window_flex_attention(x, num_heads, score_mod=score_mod_func, block_mask=block_mask)
    combined_ms = do_bench(combined_call)

    # Backward test
    combined_out = combined_call()
    pure_out = pure_attention_call()
    gradOut_seq = gradOut.view(B, H_img*W_img, num_heads, C // num_heads).permute(0, 2, 1, 3).contiguous()

    pure_bw_ms = do_bench(lambda: pure_out.backward(gradOut_seq, retain_graph=True))
    combined_bw_ms = do_bench(lambda: combined_out.backward(gradOut_seq, retain_graph=True))

    results = [
        ["Hilbert Rearrangement", f"{hilbert_rearrange_ms:.4f}", "-", "-", "-"],
        ["QKV Projection", f"{qkv_proj_ms:.4f}", "-", "-", "-"],
        ["QKV Preparation (Total)", f"{prepare_ms:.4f}", "-", "-", "-"],
        ["Pure Flex Attention", f"{pure_attention_ms:.4f}", "-", f"{pure_bw_ms:.4f}", "-"],
        ["Combined (Total)", f"{combined_ms:.4f}", "-", f"{combined_bw_ms:.4f}", "-"],
]

    print(f"\nTest split performance:")
    print(tabulate(results, headers=["Operation", "FW Time (ms)", "BW Time (ms)"], tablefmt="grid"))

    # Clean up
    del x, q, k, v, combined_out, pure_out
    torch.cuda.empty_cache()

test_split_performance()


分离函数性能测试:
+-------------------------+----------------+-------------------+----------------+-------------------+
| Operation               |   FW Time (ms) | FW FLOPS (TF/s)   | BW Time (ms)   | BW FLOPS (TF/s)   |
| Hilbert Rearrangement   |         0.0896 | -                 | -              | -                 |
+-------------------------+----------------+-------------------+----------------+-------------------+
| QKV Projection          |         0.0776 | -                 | -              | -                 |
+-------------------------+----------------+-------------------+----------------+-------------------+
| QKV Preparation (Total) |         0.1298 | -                 | -              | -                 |
+-------------------------+----------------+-------------------+----------------+-------------------+
| Pure Flex Attention     |         0.2625 | -                 | 1.5588         | -                 |
+-------------------------+----------------+-------------------+-------