In [2]:
import sys

sys.path.append("..")

from src.bert_layers.softpick import naive_softpick_attn, parallel_softpick_attn

  from .autonotebook import tqdm as notebook_tqdm
  @custom_fwd
  @custom_bwd


In [9]:
def pack_sequences(inputs, attention_mask):
    """
    inputs: Tensor of shape [batch_size, max_seq_len, ...]
    attention_mask: Tensor of shape [batch_size, max_seq_len] (1=valid, 0=pad)
    """
    # Get actual sequence lengths
    seq_lens = attention_mask.sum(dim=1).cpu().tolist()
    
    # Remove padding and concatenate sequences
    packed = torch.cat([inputs[i,:l] for i,l in enumerate(seq_lens)], dim=0)
    
    # Create cumulative sequence lengths
    cu_seqlens = torch.tensor(
        [0] + [sum(seq_lens[:i+1]) for i in range(len(seq_lens))],
        dtype=torch.long,
        device=inputs.device
    )
    
    return packed.unsqueeze(0), cu_seqlens  # Add batch dimension

# Original padded batch
batch = torch.tensor([
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.0, 0.0]],  # seq_len=3
    [[0.7, 0.8], [0.9, 1.0], [0.0, 0.0], [0.0, 0.0]]   # seq_len=2
])
mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0]
])

# Pack sequences
packed, cu_seqlens = pack_sequences(batch, mask)
print(packed.shape)  # torch.Size([1, 5, 2]) 
print(cu_seqlens)    # tensor([0, 3, 5])

torch.Size([1, 5, 2])
tensor([0, 3, 5])


In [20]:
from einops import rearrange

B, H, S, D = 4, 16, 128, 64

# Create random input and attention mask
inputs = torch.randn(B, S, H * D, device="cuda")
attention_mask = torch.randint(0, 2, (B, S), device="cuda").to(torch.bool)
# Create packed sequences
packed, cu_seqlens = pack_sequences(inputs, attention_mask)

query = torch.randn(B, S, H * D, device="cuda")
key = torch.randn(B, S, H * D, device="cuda")
value = torch.randn(B, S, H * D, device="cuda")

query = rearrange(query, "b s (h d) -> b h s d", h=H)
key = rearrange(key, "b s (h d) -> b h s d", h=H)
value = rearrange(value, "b s (h d) -> b h s d", h=H)

# Test naive softpick attention
print("Testing naive softpick attention...")
naive_attn, _ = naive_softpick_attn(query, key, value, mask=attention_mask, scale=D**0.5, head_first=True)

# Test parallel softpick attention
print("Testing parallel softpick attention...")
packed_query, cu_seqlens = pack_sequences(query, attention_mask)
packed_key, _ = pack_sequences(key, attention_mask)
packed_value, _ = pack_sequences(value, attention_mask)
parallel_attn = parallel_softpick_attn(packed_query, packed_key, packed_value, scale=D**0.5, cu_seqlens=cu_seqlens, head_first=True)

print(naive_attn)  # Expected shape: [B, H, S, D]
# print(attention_mask)


Testing naive softpick attention...
Testing parallel softpick attention...
tensor([[[[-1.7523e+00,  4.1025e-01, -8.9073e-01,  ..., -1.3521e+00,
            3.6076e-01,  5.1788e-01],
          [ 2.8459e-01, -1.1112e+00, -9.3327e-01,  ..., -7.5988e-02,
           -7.6570e-01,  1.4538e+00],
          [ 6.3318e-01, -6.1629e-01, -2.4211e+00,  ...,  5.6732e-01,
           -2.0947e+00, -5.4335e-01],
          ...,
          [-7.7210e-01,  8.7004e-01, -4.4575e-01,  ..., -1.3027e+00,
            1.2381e+00, -7.3564e-01],
          [-1.4301e+00,  3.9622e-01, -2.1280e+00,  ...,  1.4532e-01,
            6.7443e-01,  1.5650e+00],
          [-7.6796e-01, -1.0822e+00, -1.8001e+00,  ..., -1.1054e+00,
           -7.2440e-01,  5.2213e-01]],

         [[ 4.5712e-01,  6.0572e-01, -1.3262e-01,  ..., -1.7929e+00,
            2.5354e-01, -6.5786e-01],
          [ 1.0021e+00,  3.2965e-01,  7.7627e-01,  ...,  1.7770e+00,
           -3.5784e-01,  1.5578e+00],
          [-1.3292e+00,  1.0739e+00, -1.2882e-01,  .

In [21]:
torch.testing.assert_close(
    naive_attn, parallel_attn, rtol=1e-5, atol=1e-5
)

AssertionError: The values for attribute 'shape' do not match: torch.Size([4, 16, 128, 64]) != torch.Size([1, 64, 128, 64]).

In [3]:
import torch

def run_comparison_test():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cpu":
        return False

    dtype = torch.float32
    torch.manual_seed(0)

    B = 1
    T_actual = 3
    H = 1
    K_dim = 2
    V_dim = 2
    scale = K_dim ** -0.5

    q_data = torch.randn(B, T_actual, H, K_dim, device=device, dtype=dtype)
    k_data = torch.randn(B, T_actual, H, K_dim, device=device, dtype=dtype)
    v_data = torch.randn(B, T_actual, H, V_dim, device=device, dtype=dtype)

    q_n_for_naive = q_data.permute(0, 2, 1, 3)
    k_n_for_naive = k_data.permute(0, 2, 1, 3)
    v_n_for_naive = v_data.permute(0, 2, 1, 3)
    naive_bidirectional_mask = torch.ones(T_actual, T_actual, device=device, dtype=torch.bool)

    o_naive_heads_first, _ = naive_softpick_attn(
        q_n_for_naive, k_n_for_naive, v_n_for_naive, scale,
        head_first=True,
        mask=naive_bidirectional_mask
    )
    o_naive = o_naive_heads_first.permute(0, 2, 1, 3)

    cu_seqlens_parallel = torch.tensor([0, T_actual], device=device, dtype=torch.long)

    o_parallel = parallel_softpick_attn(
        q_data, k_data, v_data, scale,
        cu_seqlens=cu_seqlens_parallel,
        head_first=False
    )

    return torch.allclose(o_naive, o_parallel, atol=1e-2, rtol=1e-2)

# You need to have your `naive_softpick_attn` and `parallel_softpick_attn`
# (with the bidirectional modification) defined or imported for this to run.
# Then, you can call:
result = run_comparison_test()
print(result)

True


In [7]:
def run_batched_comparison_test():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cpu":
        print("CUDA not available, skipping batched Triton test.")
        return False

    dtype = torch.float32
    torch.manual_seed(0)

    B = 2
    T_padded = 5
    H = 2
    K_dim = 4
    V_dim = 4
    scale = K_dim ** -0.5

    actual_lengths_list = [3, T_padded] # First sequence is shorter
    actual_lengths = torch.tensor(actual_lengths_list, device=device, dtype=torch.long)

    q_base = torch.randn(B, T_padded, H, K_dim, device=device, dtype=dtype)
    k_base = torch.randn(B, T_padded, H, K_dim, device=device, dtype=dtype)
    v_base = torch.randn(B, T_padded, H, V_dim, device=device, dtype=dtype)

    o_naive_list = []
    for i in range(B):
        L_i = actual_lengths[i].item()
        q_i = q_base[i:i+1, :L_i, ...].permute(0, 2, 1, 3) # [1, H, L_i, K_dim]
        k_i = k_base[i:i+1, :L_i, ...].permute(0, 2, 1, 3) # [1, H, L_i, K_dim]
        v_i = v_base[i:i+1, :L_i, ...].permute(0, 2, 1, 3) # [1, H, L_i, V_dim]

        mask_i = torch.ones(L_i, L_i, device=device, dtype=torch.bool)

        o_i_calc, _ = naive_softpick_attn(q_i, k_i, v_i, scale, head_first=True, mask=mask_i)

        padded_o_i = torch.zeros(1, H, T_padded, V_dim, device=device, dtype=dtype)
        padded_o_i[:, :, :L_i, :] = o_i_calc
        o_naive_list.append(padded_o_i)

    o_naive_batched_hf = torch.cat(o_naive_list, dim=0) # [B, H, T_padded, V_dim]
    o_naive = o_naive_batched_hf.permute(0, 2, 1, 3) # -> [B, T_padded, H, V_dim]

    cu_seqlens = torch.cat([torch.tensor([0], device=device, dtype=torch.long),
                            torch.cumsum(actual_lengths, dim=0)])

    q_packed_list = [q_base[i, :actual_lengths[i].item(), ...] for i in range(B)]
    k_packed_list = [k_base[i, :actual_lengths[i].item(), ...] for i in range(B)]
    v_packed_list = [v_base[i, :actual_lengths[i].item(), ...] for i in range(B)]

    q_par = torch.cat(q_packed_list, dim=0).unsqueeze(0) # [1, total_tokens, H, K_dim]
    k_par = torch.cat(k_packed_list, dim=0).unsqueeze(0) # [1, total_tokens, H, K_dim]
    v_par = torch.cat(v_packed_list, dim=0).unsqueeze(0) # [1, total_tokens, H, V_dim]

    # Call your modified parallel_softpick_attn
    # It expects q,k,v as [B,T,H,D] if head_first=False (B=1 for packed)
    o_parallel_packed = parallel_softpick_attn(
        q_par, k_par, v_par, scale,
        cu_seqlens=cu_seqlens,
        head_first=False # Inputs are [1, total_T, H, D]
    )
    # o_parallel_packed is [1, total_tokens, H, V_dim] (if head_first=False output from func)
    # If parallel_softpick_attn returns head_first=True output ([1,H,total_tokens,D]), permute it:
    # if o_parallel_packed.shape[1] == H: # Heuristic to check if it's head-first
    #    o_parallel_packed = o_parallel_packed.permute(0,2,1,3)


    # Unpack the output from parallel version
    o_parallel_unpacked = torch.zeros_like(o_naive) # Target shape [B, T_padded, H, V_dim]
    current_pos = 0
    for i in range(B):
        L_i = actual_lengths[i].item()
        if L_i > 0:
            o_parallel_unpacked[i, :L_i, ...] = o_parallel_packed[0, current_pos:current_pos + L_i, ...]
        current_pos += L_i

    # --- Comparison (only for valid token positions) ---
    # Create a mask for positions that are not padding
    valid_token_mask_2d = torch.arange(T_padded, device=device).unsqueeze(0) < actual_lengths.unsqueeze(1)
    # valid_token_mask_2d is [B, T_padded]

    # Expand mask to match output tensor dimensions [B, T_padded, H, V_dim]
    valid_token_mask_expanded = valid_token_mask_2d.unsqueeze(-1).unsqueeze(-1).expand_as(o_naive)

    # Compare only the valid parts
    are_close = torch.allclose(
        o_naive[valid_token_mask_expanded],
        o_parallel_unpacked[valid_token_mask_expanded],
        atol=1e-2, rtol=1e-2 # Tolerances might be needed
    )
    if not are_close:
        # For debugging, print differences only on the valid parts
        # Iterate batch items to print manageable chunks
        for i in range(B):
            L_i = actual_lengths[i].item()
            if L_i == 0: continue
            # print(f"Batch item {i}, Len {L_i}")
            # print("Naive valid:", o_naive[i, :L_i].flatten()[:10]) # Print first 10 of flattened valid part
            # print("Parallel valid:", o_parallel_unpacked[i, :L_i].flatten()[:10])
            # print("Diff valid:", (o_naive[i, :L_i] - o_parallel_unpacked[i, :L_i]).abs().max())
            pass # Keep it quiet as per user request
    return are_close

# To run:
# 1. Replace dummy functions with your actual, MODIFIED code.
# 2. Then call:
result = run_batched_comparison_test()
print(f"Batched outputs are close: {result}")

Batched outputs are close: True


In [None]:
import torch
from typing import Tuple, List, Dict, Any

def pack_qkv_for_parallel_attn(
    q_batched: torch.Tensor,
    k_batched: torch.Tensor,
    v_batched: torch.Tensor,
    actual_lengths: torch.LongTensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.LongTensor, Dict[str, Any]]:
    """
    Packs batched Q, K, V tensors for attention mechanisms that use cu_seqlens.

    Args:
        q_batched (torch.Tensor): Batched queries, shape [B, T_padded, H, K_dim].
        k_batched (torch.Tensor): Batched keys, shape [B, T_padded, H, K_dim].
        v_batched (torch.Tensor): Batched values, shape [B, T_padded, H, V_dim].
        actual_lengths (torch.LongTensor): Tensor of actual sequence lengths, shape [B].

    Returns:
        Tuple containing:
            - q_packed (torch.Tensor): Packed queries, shape [1, total_tokens, H, K_dim].
            - k_packed (torch.Tensor): Packed keys, shape [1, total_tokens, H, K_dim].
            - v_packed (torch.Tensor): Packed values, shape [1, total_tokens, H, V_dim].
            - cu_seqlens (torch.LongTensor): Cumulative sequence lengths, shape [B + 1].
            - unpack_info (Dict[str, Any]): Information needed for unpacking.
    """
    original_batch_size = q_batched.shape[0]
    original_padded_length = q_batched.shape[1]
    num_heads_q = q_batched.shape[2]
    head_dim_k = q_batched.shape[3]
    num_heads_v = v_batched.shape[2] # Can be different for GQA/MQA for K,V
    head_dim_v = v_batched.shape[3]


    q_packed_list: List[torch.Tensor] = []
    k_packed_list: List[torch.Tensor] = []
    v_packed_list: List[torch.Tensor] = []

    for i in range(original_batch_size):
        len_i = actual_lengths[i].item()
        if len_i > 0:
            q_packed_list.append(q_batched[i, :len_i, ...])
            k_packed_list.append(k_batched[i, :len_i, ...])
            v_packed_list.append(v_batched[i, :len_i, ...])

    if not q_packed_list: # All sequences had length 0
        q_packed = torch.empty(1, 0, num_heads_q, head_dim_k, dtype=q_batched.dtype, device=q_batched.device)
        k_packed = torch.empty(1, 0, k_batched.shape[2], k_batched.shape[3], dtype=k_batched.dtype, device=k_batched.device)
        v_packed = torch.empty(1, 0, num_heads_v, head_dim_v, dtype=v_batched.dtype, device=v_batched.device)
    else:
        q_packed = torch.cat(q_packed_list, dim=0).unsqueeze(0)
        k_packed = torch.cat(k_packed_list, dim=0).unsqueeze(0)
        v_packed = torch.cat(v_packed_list, dim=0).unsqueeze(0)

    cu_seqlens = torch.cat([
        torch.tensor([0], device=actual_lengths.device, dtype=torch.long),
        torch.cumsum(actual_lengths, dim=0)
    ])

    unpack_info = {
        "original_batch_size": original_batch_size,
        "original_padded_length": original_padded_length,
        "actual_lengths": actual_lengths,
        "q_dtype": q_batched.dtype,
        "q_device": q_batched.device
    }

    return q_packed, k_packed, v_packed, cu_seqlens, unpack_info

def unpack_output_from_parallel_attn(
    output_packed: torch.Tensor,
    unpack_info: Dict[str, Any]
) -> torch.Tensor:
    """
    Unpacks the output from an attention mechanism that used packed inputs.

    Args:
        output_packed (torch.Tensor): Packed output, shape [1, total_tokens, H_out, D_out].
        unpack_info (Dict[str, Any]): Information from the packing function.

    Returns:
        torch.Tensor: Unpacked output, shape [B, T_padded, H_out, D_out].
    """
    original_batch_size = unpack_info["original_batch_size"]
    original_padded_length = unpack_info["original_padded_length"]
    actual_lengths = unpack_info["actual_lengths"]
    output_dtype = output_packed.dtype # Use dtype from packed output
    output_device = output_packed.device

    num_heads_out = output_packed.shape[2]
    head_dim_out = output_packed.shape[3]

    output_unpacked = torch.zeros(
        original_batch_size,
        original_padded_length,
        num_heads_out,
        head_dim_out,
        dtype=output_dtype,
        device=output_device
    )

    current_pos = 0
    for i in range(original_batch_size):
        len_i = actual_lengths[i].item()
        if len_i > 0:
            output_unpacked[i, :len_i, ...] = output_packed[0, current_pos : current_pos + len_i, ...]
        current_pos += len_i

    return output_unpacked


In [2]:
import torch

def run_comparison_test():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cpu":
        return False

    dtype = torch.float32
    torch.manual_seed(0)

    B = 1
    T_actual = 3
    H = 1
    K_dim = 2
    V_dim = 2
    scale = K_dim ** -0.5

    q_data = torch.randn(B, T_actual, H, K_dim, device=device, dtype=dtype)
    k_data = torch.randn(B, T_actual, H, K_dim, device=device, dtype=dtype)
    v_data = torch.randn(B, T_actual, H, V_dim, device=device, dtype=dtype)

    q_n_for_naive = q_data.permute(0, 2, 1, 3)
    k_n_for_naive = k_data.permute(0, 2, 1, 3)
    v_n_for_naive = v_data.permute(0, 2, 1, 3)
    naive_bidirectional_mask = torch.ones(T_actual, T_actual, device=device, dtype=torch.bool)

    o_naive_heads_first, _ = naive_softpick_attn(
        q_n_for_naive, k_n_for_naive, v_n_for_naive, scale,
        head_first=True,
        mask=naive_bidirectional_mask
    )
    o_naive = o_naive_heads_first.permute(0, 2, 1, 3)

    cu_seqlens_parallel = torch.tensor([0, T_actual], device=device, dtype=torch.long)

    o_parallel = parallel_softpick_attn(
        q_data, k_data, v_data, scale,
        cu_seqlens=cu_seqlens_parallel,
        head_first=False
    )

    print("Naive output shape:", o_naive.shape)
    print("Parallel output shape:", o_parallel.shape)
    print(f"{'Naive output:'}\n{o_naive}")
    print(f"{'Parallel output:'}\n{o_parallel}")

    return torch.allclose(o_naive, o_parallel, atol=1e-5, rtol=1e-3)

# You need to have your `naive_softpick_attn` and `parallel_softpick_attn`
# (with the bidirectional modification) defined or imported for this to run.
# Then, you can call:
result = run_comparison_test()
print(result)

Naive output shape: torch.Size([1, 3, 1, 2])
Parallel output shape: torch.Size([1, 3, 1, 2])
Naive output:
tensor([[[[ 0.1183, -0.0333]],

         [[ 0.0000,  0.0000]],

         [[ 0.5512, -0.1587]]]], device='cuda:0')
Parallel output:
tensor([[[[ 2.5420, -0.7160]],

         [[ 0.0000,  0.0000]],

         [[ 0.5509, -0.1587]]]], device='cuda:0')
False


In [8]:
import torch
from torch.testing import assert_close
from einops import rearrange

def test_softpick_attention_equivalence():
    # Test both packed sequences (cu_seqlens) and batched inputs
    for use_cu_seqlens in [True, False]:
        print(f"\nTesting {'PACKED SEQUENCES (cu_seqlens)' if use_cu_seqlens else 'BATCHED INPUT'}...")
        
        # Configuration
        num_sequences = 2  # Actual batch size for non-packed case
        max_seq_len = 4
        num_heads = 2
        head_dim = 8
        hidden_size = num_heads * head_dim
        scale = 1.0 / (head_dim ** 0.5)
        dtype = torch.float32
        torch.manual_seed(42)

        # Create different sequence lengths
        seq_lens = [3, 4] if use_cu_seqlens else [max_seq_len]*num_sequences
        
        # Generate test data --------------------------------------------------
        if use_cu_seqlens:
            # Packed format: single batch with concatenated sequences
            q_seqs = [torch.randn(1, l, hidden_size, device="cuda") for l in seq_lens]
            k_seqs = [torch.randn(1, l, hidden_size, device="cuda") for l in seq_lens]
            v_seqs = [torch.randn(1, l, hidden_size, device="cuda") for l in seq_lens]
            
            q = torch.cat(q_seqs, dim=1)
            k = torch.cat(k_seqs, dim=1)
            v = torch.cat(v_seqs, dim=1)
            
            # Create cu_seqlens [0, 3, 7] for seq_lens [3,4]
            cu_seqlens = torch.tensor(
                [0] + [sum(seq_lens[:i+1]) for i in range(len(seq_lens))],
                dtype=torch.long, device="cuda"
            )
            attention_mask = None  # Masking handled via cu_seqlens
        else:
            # Batched format: separate entries in batch dimension
            q = torch.randn(num_sequences, max_seq_len, hidden_size, device="cuda")
            k = torch.randn(num_sequences, max_seq_len, hidden_size, device="cuda")
            v = torch.randn(num_sequences, max_seq_len, hidden_size, device="cuda")
            
            # Create attention mask (1=valid, 0=pad)
            attention_mask = torch.zeros(num_sequences, max_seq_len, device="cuda")
            for i,l in enumerate(seq_lens):
                attention_mask[i, :l] = 1
            cu_seqlens = None

        # Convert to head-first format -----------------------------------------
        q_head_first = rearrange(q, "b s (h d) -> b h s d", h=num_heads)
        k_head_first = rearrange(k, "b s (h d) -> b h s d", h=num_heads)
        v_head_first = rearrange(v, "b s (h d) -> b h s d", h=num_heads)

        # Run implementations -------------------------------------------------
        # Triton implementation
        triton_output = parallel_softpick_attn(
            q_head_first,
            k_head_first,
            v_head_first,
            scale=scale,
            cu_seqlens=cu_seqlens if use_cu_seqlens else None,
            head_first=True
        )
        
        # Naive implementation
        naive_output, _ = naive_softpick_attn(
            q_head_first,
            k_head_first,
            v_head_first,
            scale=scale,
            cu_seqlens=cu_seqlens if use_cu_seqlens else None,
            head_first=True,
            mask=attention_mask
        )

        # Convert outputs ------------------------------------------------------
        triton_output = rearrange(triton_output, "b h s d -> b s (h d)")
        
        # For packed sequences: unpack results
        if use_cu_seqlens:
            # Split concatenated output back into original sequences
            triton_outputs = [
                triton_output[:, cu_seqlens[i]:cu_seqlens[i+1]] 
                for i in range(len(seq_lens))
            ]
            naive_outputs = naive_output.split(seq_lens, dim=1)
        else:
            triton_outputs = triton_output
            naive_outputs = naive_output

        # Validate results ----------------------------------------------------
        for i, (triton_out, naive_out) in enumerate(zip(triton_outputs, naive_outputs)):
            print(f"\nSequence {i+1} comparison:")
            print("Triton:", triton_out[0,:seq_lens[i]].flatten()[:5])
            print("Naive: ", naive_out[0,:seq_lens[i]].flatten()[:5])
            
            assert_close(
                triton_out[:, :seq_lens[i]],  # Handle padding in batched case
                naive_out[:, :seq_lens[i]],
                rtol=1e-3,
                atol=1e-5,
                check_dtype=False,
                msg=f"Mismatch in {'packed' if use_cu_seqlens else 'batched'} seq {i+1}"
            )

        print(f"Passed {'packed' if use_cu_seqlens else 'batched'} test!")

if __name__ == "__main__":
    test_softpick_attention_equivalence()


Testing PACKED SEQUENCES (cu_seqlens)...


RuntimeError: split_with_sizes expects split_sizes to sum exactly to 2 (input tensor's size at dimension 1), but got split_sizes=[3, 4]