In [1]:
import os
import os.path as osp
import time
import numpy as np
import torch
os.chdir('../')
from mmcv.cnn.bricks.transformer import MultiheadAttention
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
from my_projects.CMT.cmt.models.utils.flash_attention import FlashMHA
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-cuxlzjsv because the default path (/home/hello/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
# Create an instance of the regular attention layer
regular_attention = MultiheadAttention(embed_dims=256, num_heads=8).cuda()

# Create an instance of the deformable attention layer
deformable_attention = MultiScaleDeformableAttention(embed_dims=256, num_heads=8).cuda()

# Create an instance of the flash attention layer
flash_attention = FlashMHA(embed_dim=256, num_heads=8).cuda()

In [3]:
start = time.time()
with torch.no_grad():
    query = torch.randn(1024, 1, 256, device=torch.device('cuda'))
    attn_mask = torch.zeros(1024, 1024, device=torch.device('cuda')).bool()
    for _ in range(1000):
        out = regular_attention(query, attn_mask=attn_mask)
print(time.time() - start)

1.3918044567108154


In [4]:
start = time.time()
with torch.no_grad():
    query = torch.randn(1024, 1, 256, device=torch.device('cuda'))
    reference_points = torch.randn(16, 1024, 1, 2, device=torch.device('cuda'))
    spatial_shapes = torch.tensor([[32, 32]], device=torch.device('cuda'))
    level_start_index = torch.tensor([0], device=torch.device('cuda'))
    for _ in range(1000):
        out = deformable_attention(query, reference_points=reference_points, 
            spatial_shapes=spatial_shapes, level_start_index=level_start_index)
print(time.time() - start)

0.9290800094604492


In [5]:
start = time.time()
with torch.no_grad():
    query = torch.randn(1024, 1, 256, device=torch.device('cuda'))
    for _ in range(1000):
        out = flash_attention(query, query, query)
print(time.time() - start)

0.9855546951293945
