In [1]:
import torch, math, gc

# For details on recording and visualizing CUDA memory snapshots, see https://pytorch.org/docs/stable/torch_cuda_memory.html

In [3]:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    """Naive (unfused) scaled dot-product attention.

    Computes ``softmax(query @ key^T * scale + bias) @ value`` step by step,
    materializing the full attention-weight matrix so that every intermediate
    allocation is visible to a memory profiler.

    Args:
        query: Tensor whose last two dimensions are ``(L, E)``.
        key: Tensor whose last two dimensions are ``(S, E)``.
        value: Tensor whose second-to-last dimension is ``S``.
        attn_mask: Optional mask broadcastable to the attention-weight
            matrix. A boolean mask keeps positions where it is ``True``;
            any other dtype is added to the attention scores.
        dropout_p: Dropout probability applied to the attention weights.
        is_causal: If ``True``, position ``i`` only attends to keys ``<= i``.
        scale: Multiplier for the scores; defaults to ``1 / sqrt(E)``.

    Returns:
        ``attn_weight @ value``.

    Raises:
        ValueError: If both ``attn_mask`` and ``is_causal=True`` are given.
    """
    if is_causal and attn_mask is not None:
        raise ValueError("attn_mask and is_causal=True are mutually exclusive")

    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    # Allocate the bias next to the inputs instead of hard-coding 'cuda', so
    # the function works on CPU and on non-default CUDA devices as well.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)

    if is_causal:
        causal_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(causal_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            # Turn the keep-mask into an additive mask: 0 where True, -inf where False.
            additive_mask = torch.zeros_like(attn_mask, dtype=query.dtype)
            additive_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias + additive_mask
        else:
            attn_bias = attn_bias + attn_mask

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    # Dropout is always run in training mode (even with dropout_p == 0.0);
    # the original author notes it is applied to balance the workload.
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)

    return attn_weight @ value

In [4]:
def gen_memory_timeline(seq_lengths, repeat_time = 5, output_file = './my_snapshot.pickle'):
    """Record a CUDA memory-history snapshot while running naive attention.

    For every entry of ``seq_lengths`` this creates random float16 Q/K/V
    tensors of shape ``(1, 8, seq_len, 64)`` on the GPU, calls
    ``scaled_dot_product_attention`` ``repeat_time`` times, and then releases
    the tensors before moving on. The recorded history is written to
    ``output_file``.

    Args:
        seq_lengths: Iterable of sequence lengths to profile, in order.
        repeat_time: Number of attention calls per sequence length.
        output_file: Path the snapshot pickle is written to.

    The snapshot is dumped and recording is stopped even when the workload
    raises (for example on an out-of-memory error), so the history leading up
    to the failure is not lost; the exception is then re-raised.
    """
    torch.cuda.memory._record_memory_history()
    try:
        for seq_len in seq_lengths:
            Q = torch.rand(1, 8, seq_len, 64, dtype=torch.float16, device="cuda")
            K = torch.rand(1, 8, seq_len, 64, dtype=torch.float16, device="cuda")
            V = torch.rand(1, 8, seq_len, 64, dtype=torch.float16, device="cuda")
            res = None  # keeps the `del` below valid when repeat_time == 0
            for _ in range(repeat_time):
                res = scaled_dot_product_attention(Q, K, V)
            # Drop our references first, then collect, then empty the cache:
            # empty_cache() can only release blocks that no live tensor uses.
            del Q, K, V, res
            gc.collect()
            torch.cuda.empty_cache()
    finally:
        try:
            torch.cuda.memory._dump_snapshot(output_file)
        finally:
            # Stop recording so later cells are not traced as well.
            torch.cuda.memory._record_memory_history(enabled=None)

In [5]:
# Sequence lengths to profile, smallest first; each one is run through
# gen_memory_timeline with its default repeat count and output path.
seq_lengths = [10_000, 20_000, 30_000, 40_000, 45_000]
gen_memory_timeline(seq_lengths=seq_lengths)

