In [2]:
from modeling_modifier import nnscaler_flash_attention_forward

In [3]:
import torch
import math
from torch.autograd import gradcheck

# Function and custom attention implementation assumed to be defined or imported

# Test parameters
batch_size = 2
seq_len = 128
num_heads = 8
head_dim = 64


# Random input tensors
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(1024)
dtype = torch.float16
query_states = torch.randn(
    batch_size, num_heads, seq_len, head_dim, 
    requires_grad=True, dtype=dtype, device=device)
key_states = torch.randn(
    batch_size, num_heads, seq_len, head_dim, 
    requires_grad=True, dtype=dtype, device=device)
value_states = torch.randn(
    batch_size, num_heads, seq_len, head_dim, 
    requires_grad=True, dtype=dtype, device=device)

# Parameters for the attention function
dropout = 0.
softmax_scale = 1 / math.sqrt(head_dim)
causal = True

In [4]:
torch.is_autocast_enabled()

False

In [5]:
from torch import nn

intermediate_results_dict = {
    'custom': {},
    'standard_attn': {}
}

class CustomAttenFunc(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx, 
        query_states, key_states, value_states, 
        attn_mask, attn_dropout, training
    ):
        # query_states: [batch_size, num_heads, query_len, head_dim]
        # attn_mask: [batch_size, num_heads, query_len, key_len]
        head_dim = query_states.size(-1)

        # Compute scaled dot-product attention scores
        # orig_attn_weights - [batch_size, num_heads, query_len, query_len]
        orig_attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(head_dim)
        
        # Apply mask if provided
        if attn_mask is not None:
            causal_mask = attn_mask[:, :, :, :key_states.shape[-2]]
            orig_attn_weights = orig_attn_weights + causal_mask

        # Softmax across the last dimension and apply dropout if in training mode
        orig_attn_weights = nn.functional.softmax(orig_attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(orig_attn_weights, p=attn_dropout, training=training)

        # Multiply attention weights by value states to get the output
        # attn_output - [batch_size, num_heads, query_len, head_dim]
        attn_output = torch.matmul(attn_weights, value_states)

        # Save for backward pass
        ctx.save_for_backward(query_states, key_states, value_states, attn_weights, orig_attn_weights, attn_mask)
        ctx.attn_dropout = attn_dropout
        ctx.training = training

        return attn_output

    @staticmethod
    def backward(ctx: torch.autograd.function.FunctionCtx, grad_output):
        query_states, key_states, value_states, attn_weights, orig_attn_weights, attn_mask = ctx.saved_tensors
        attn_dropout = ctx.attn_dropout
        training = ctx.training

        # Compute gradients w.r.t. value states
        # attn_weights - [batch_size, num_heads, query_len, query_len]
        # grad_output - [batch_size, num_heads, query_len, head_dim]
        # grad_value_states - [batch_size, num_heads, query_len, head_dim]
        grad_value_states = torch.matmul(attn_weights.transpose(-2, -1), grad_output)

        # Compute gradients w.r.t. attention weights
        # grad_output - [batch_size, num_heads, query_len, head_dim]
        # value_states - [batch_size, num_heads, query_len, head_dim]
        # grad_attn_weights - [batch_size, num_heads, query_len, query_len]
        grad_attn_weights = torch.matmul(grad_output, value_states.transpose(-2, -1))

        # Directly assume grad w.r.t attn_weights is equivalent to grad w.r.t orig_attn_weights because 
        grad_orig_attn_weights = grad_attn_weights

        # Compute gradients w.r.t. matrix multiplication before softmax
        # Because the direct operand of Softmax is the matrix multiplication result before dropout, so we use orig_attn_weights here
        # grad_orig_attn_mul = grad_orig_attn_weights * orig_attn_weights * (1 - orig_attn_weights)
        grad_orig_attn_mul = orig_attn_weights * (
            grad_orig_attn_weights - torch.sum(orig_attn_weights * grad_orig_attn_weights, dim=-1, keepdim=True)
        )

        # ------------------------------------------------
        # Save the gradients w.r.t (orig)_attn_weights to files


        # -------------------------------------------------


        # Compute gradients w.r.t. query and key states
        # grad_orig_attn_mul - [batch_size, num_heads, query_len, query_len]
        # key_states - [batch_size, num_heads, query_len, head_dim]
        # query_states - [batch_size, num_heads, query_len, head_dim]
        grad_query_states = torch.matmul(grad_orig_attn_mul, key_states) / math.sqrt(query_states.size(-1))
        grad_key_states = torch.matmul(query_states.transpose(-2, -1), grad_orig_attn_mul) / math.sqrt(query_states.size(-1))
        grad_key_states = grad_key_states.transpose(-2, -1)

        return grad_query_states, grad_key_states, grad_value_states, None, None, None

class CaptureAttention(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx, 
        attn_weights: torch.Tensor,
    ):
        print('Attention weights:', attn_weights.size())
        return attn_weights
        
    @staticmethod
    def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
        print('Gradient w.r.t attention weights:', grad_output.size())
        with open("./attention_gradients.pt", "wb") as f:
            torch.save(grad_output, f)
        return grad_output


def custom_attn(
        query_states, key_states, value_states, 
        attn_mask, attn_dropout, training):
    head_dim = query_states.size(-1)
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

    if attn_mask is not None:  # no matter the length, we just slice it
        causal_mask = attn_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = CaptureAttention.apply(attn_weights)
    attn_weights = nn.functional.dropout(attn_weights, p=attn_dropout, training=training)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output

def standard_attn(
        query_states, key_states, value_states, 
        attn_mask, attn_dropout, training):
    head_dim = query_states.size(-1)
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)

    if attn_mask is not None:  # no matter the length, we just slice it
        causal_mask = attn_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=attn_dropout, training=training)
    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output

def flash_attn(
    query_states, key_states, value_states, 
    attn_mask, attn_dropout, training
):
    q_len = query_states.size(2)
    query_states = query_states.transpose(1, 2) # query_states - [batch_size, query_len, num_heads, head_dim]
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    dropout_rate = attn_dropout if training else 0.0
    input_dtype = query_states.dtype
    query_states = query_states.to(input_dtype)
    key_states = key_states.to(input_dtype)
    value_states = value_states.to(input_dtype)
    
    causal = True
    attn_output = nnscaler_flash_attention_forward(
        query_states, key_states, value_states, attn_mask, q_len, dropout=dropout_rate, causal=causal
    )
    return attn_output

In [6]:
# Clone inputs to keep gradients independent for each function
q_flash = query_states.clone().detach().requires_grad_(True)
k_flash = key_states.clone().detach().requires_grad_(True)
v_flash = value_states.clone().detach().requires_grad_(True)

q_custom = query_states.clone().detach().requires_grad_(True)
k_custom = key_states.clone().detach().requires_grad_(True)
v_custom = value_states.clone().detach().requires_grad_(True)

q_standard = query_states.clone().detach().requires_grad_(True)
k_standard = key_states.clone().detach().requires_grad_(True)
v_standard = value_states.clone().detach().requires_grad_(True)

attention_mask = None
flash_attn_mask = None

output_flash = flash_attn(q_flash, k_flash, v_flash, attention_mask, dropout, True)

# output_custom = CustomAttenFunc.apply(q_custom, k_custom, v_custom, custom_mask, dropout, True)
# output_custom = output_custom.transpose(1, 2).contiguous()
output_custom = custom_attn(q_custom, k_custom, v_custom, attention_mask, dropout, True)
output_standard = standard_attn(q_standard, k_standard, v_standard, attention_mask, dropout, True)

# Define a loss function and compute backward
loss_custom = output_custom.sum()
loss_standard = output_standard.sum()
loss_flash = output_flash.sum()

loss_custom.backward()
loss_standard.backward(w)
loss_flash.backward()

Attention weights: torch.Size([2, 8, 128, 128])
Gradient w.r.t attention weights: torch.Size([2, 8, 128, 128])


In [11]:
print('-' * 50)
print("Sizes of outputs:")
print(f"Output from flash-attention: {q_flash.grad.size(), k_flash.grad.size(), v_flash.grad.size()}")
print(f'Output from custom function: {q_custom.grad.size(), k_custom.grad.size(), v_custom.grad.size()}')
print(f'Output from standard function: {q_standard.grad.size(), k_standard.grad.size(), v_standard.grad.size()}')

print('-' * 50)
print("Output Equivalence check:")
print(f"Flash vs Custom: {torch.allclose(output_flash, output_custom, atol=1e-4)}")
print(f"Flash vs Standard: {torch.allclose(output_flash, output_standard, atol=1e-4)}")
print(f"Standard vs Custom: {torch.allclose(output_standard, output_custom, atol=1e-4)}")

print('-' * 50)
print("Loss equivalence check:")
print(f"Flash vs Custom: {torch.allclose(loss_flash, loss_custom, atol=1e-4)}")
print(f"Flash vs Standard: {torch.allclose(loss_flash, loss_standard, atol=1e-4)}")
print(f"Standard vs Custom: {torch.allclose(loss_standard, loss_custom, atol=1e-4)}")

--------------------------------------------------
Sizes of outputs:
Output from flash-attention: (torch.Size([2, 8, 128, 64]), torch.Size([2, 8, 128, 64]), torch.Size([2, 8, 128, 64]))
Output from custom function: (torch.Size([2, 8, 128, 64]), torch.Size([2, 8, 128, 64]), torch.Size([2, 8, 128, 64]))
Output from standard function: (torch.Size([2, 8, 128, 64]), torch.Size([2, 8, 128, 64]), torch.Size([2, 8, 128, 64]))
--------------------------------------------------
Output Equivalence check:
Flash vs Custom: False
Flash vs Standard: False
Standard vs Custom: True
--------------------------------------------------
Loss equivalence check:
Flash vs Custom: False
Flash vs Standard: False
Standard vs Custom: True


In [12]:
with torch.no_grad():
    print(output_flash - output_custom)

tensor([[[[ 1.6431e-01, -1.3013e-01,  7.6807e-01,  ...,  1.5557e+00,
           -9.8047e-01, -7.8223e-01],
          [-4.1797e-01, -2.8857e-01, -1.2910e+00,  ..., -2.7441e-01,
            2.0996e+00,  3.7769e-01],
          [-4.5654e-02,  4.2310e-01,  2.7422e+00,  ...,  4.8828e-01,
           -1.9102e+00,  4.2188e-01],
          ...,
          [-3.5889e-01, -1.0791e+00,  1.0559e-01,  ..., -8.3496e-01,
           -1.3301e+00, -8.0615e-01],
          [-1.6113e+00, -2.3511e-01, -3.9209e-01,  ..., -3.6084e-01,
           -2.1716e-01, -2.3633e-01],
          [-1.9102e+00,  8.7500e-01, -1.6748e-01,  ...,  2.2375e-01,
           -1.9434e+00, -9.2920e-01]],

         [[ 2.1448e-01, -5.6055e-01,  2.7002e-01,  ...,  1.5781e+00,
           -9.2090e-01, -5.0342e-01],
          [ 1.7322e-01, -1.2329e-02, -1.5596e+00,  ..., -6.9434e-01,
            7.2168e-01,  8.4277e-01],
          [-3.6621e-03,  9.1553e-02,  2.3789e+00,  ..., -1.3062e-02,
           -7.7832e-01,  6.0059e-01],
          ...,
     

In [95]:
print(f"Equality check for gradients:")

print('-' * 50)
print(f"Query:")
print(f'Custom vs Standard: {torch.allclose(q_custom.grad, q_standard.grad, atol=1e-4)}')

print('-' * 50)
print(f"Key:")
print(f'Custom vs Standard: {torch.allclose(k_custom.grad, k_standard.grad, atol=1e-4)}')

print('-' * 50)
print(f"Value:")
print(f'Custom vs Standard: {torch.allclose(v_custom.grad, v_standard.grad, atol=1e-4)}')

Equality check for gradients:
--------------------------------------------------
Query:
Custom vs Standard: True
--------------------------------------------------
Key:
Custom vs Standard: True
--------------------------------------------------
Value:
Custom vs Standard: True


In [84]:
with torch.no_grad():
    print(k_custom.grad - k_standard.grad)

tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 

tensor([ 0.0449, -0.4001,  0.1593,  0.0342,  0.4521,  0.4736, -0.2952, -1.1221,
        -0.2947,  0.1538, -0.2749,  0.2402,  0.0555,  0.2164, -0.3953, -0.5356,
         0.2583,  0.7163,  0.2443, -0.1348,  0.2311,  0.0945, -0.2944,  0.4666,
        -0.1801, -0.2402,  0.6650,  0.6274,  0.8462,  0.2656,  1.0762, -0.0509,
         0.3887, -0.0023,  0.8545,  0.0223, -0.1755, -0.5200,  1.1436,  0.7930,
        -1.0967, -0.4268,  0.2148, -0.0367, -0.0276, -0.1140, -0.1279, -0.2732,
        -0.3071,  0.3386,  0.5171,  0.4324, -0.4165,  0.5942,  0.9927, -0.4075,
         0.0721, -0.3181,  0.3213, -0.1255, -0.0927, -0.2695,  0.5225,  0.3623],
       device='cuda:0', dtype=torch.float16)