#### Build the cuda attention module

In [1]:
!python setup.py install

running install
!!

        ********************************************************************************
        Please avoid running ``setup.py`` directly.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://blog.ganssle.io/articles/2021/10/setup-py-deprecated.html for details.
        ********************************************************************************

!!
  self.initialize_options()
!!

        ********************************************************************************
        Please avoid running ``setup.py`` and ``easy_install``.
        Instead, use pypa/build, pypa/installer or other
        standards-based tools.

        See https://github.com/pypa/setuptools/issues/917 for details.
        ********************************************************************************

!!
  self.initialize_options()
running bdist_egg
running egg_info
creating custom_cuda.egg-info
writing custom_cuda.egg-info/PKG-IN

#### import
    - flash_attention.py contains the triton version and its equivalent torch version
    - looking at the torch version we can follow the code in the triton version

In [1]:
import torch, cuda_attention, flash_attention

#### samples

In [2]:
batch_size, seq_len, num_heads, head_dim = (
            4,
            4096,
            32,
            32,
        )

#### query, key, value

In [3]:
torch.manual_seed(1)
q = torch.randn(batch_size,num_heads, seq_len, head_dim, device='cuda') # query
k, v = q,q
kp = q.transpose(2,3).contiguous()  # the triton version uses shape (batch_size,num_heads, head_dim, seq_len)


#### torch version 

In [4]:
ref_out = torch.nn.functional.scaled_dot_product_attention(q,k,v, is_causal=True, scale=None)
print(f"ref_out: {ref_out.shape}")

ref_out: torch.Size([4, 32, 4096, 32])


#### custom flash attention 
    -- flash_attention
        -- using loops (for understanding the triton version)
    -- its slow, comment out for larger shapes
    

In [5]:
# COMMENTED OUT (SLOW)

# torch_out = flash_attention.flash_attention(q,k,v)
# print(f"torch_out: {ref_out.shape}")
# print(torch.allclose(ref_out, torch_out, atol=1e-3, rtol=1e-3))

#### triton version
    -- triton_flash_attention which uses the kernel: flash_attention_kernel

In [6]:
triton_out = flash_attention.triton_flash_attention(q, kp,v)
print(f"triton_out: {triton_out.shape}")

Grid: (256, 128, 1)
triton_out: torch.Size([4, 32, 4096, 32])


#### cuda version 
    -- kernels are found in attention.cu

In [7]:
cuda_out = cuda_attention.attention_forward(q,k,v, True, True)
print(f"cuda_out: {cuda_out.shape}")

cuda_out: torch.Size([4, 32, 4096, 32])


In [8]:
# in some its not matching. need to dig further

if not torch.allclose(triton_out, ref_out,atol=1e-2, rtol=1e-3):
    print(f"triton_out, ref_out not matching. Try with larger values of atol and rtol\n")
    print(f"Sample:\nref_out: {ref_out[-1,-1,-1,-5:]}\ntriton_out: {triton_out[-1,-1,-1,-5:]}")
else:
    print(True)

True


In [9]:
if not torch.allclose(cuda_out, ref_out,atol=1e-2, rtol=1e-3):
    print(f"cuda_out, ref_out not matching. Try with larger values of atol and rtol\n")
    print(f"cuda_out: {cuda_out[-1,-1,-1,-5:]}\nref_out: {ref_out[-1,-1,-1,-5:]}\ntriton_out: {triton_out[-1,-1,-1,-5:]}")
else:
    print(True)

True
