In [2]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F

In [3]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q @ K^T / sqrt(d_k)) @ V``.

    The module owns no parameters. Optional masks mark the key positions
    that must receive zero attention weight.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                casual_mask: Tensor = None,
                pad_mask: Tensor = None) -> Tensor:
        """Attend ``query`` over ``key`` and return the weighted sum of ``value``.

        Args:
            query: Tensor whose last two dims are ``(seq_q, d_k)``; any leading
                batch/head dims are allowed (e.g. ``[b, seq, d]`` or
                ``[b, heads, seq, d]``).
            key: Tensor whose last two dims are ``(seq_k, d_k)``.
            value: Tensor whose last two dims are ``(seq_k, d_v)``.
            casual_mask: Optional causal mask broadcastable to the score shape
                ``[..., seq_q, seq_k]``. Non-zero / ``True`` entries are masked
                out. (The parameter name keeps its original spelling so that
                existing keyword callers continue to work.)
            pad_mask: Optional padding mask of shape ``[b, seq_k]``. Non-zero /
                ``True`` entries mark padded key positions to be ignored.

        Returns:
            Tensor with the leading dims of the scores and last two dims
            ``(seq_q, d_v)``.
        """
        d_k = query.size(-1)

        # Scores have shape [..., seq_q, seq_k]. A plain Python scalar avoids
        # allocating a tensor on every call and keeps the dtype of `query`.
        attn_score = torch.matmul(query, key.transpose(-1, -2)) / (d_k ** 0.5)

        # Most negative finite value of the score dtype: softmax gives such
        # positions zero weight, and the constant stays representable in
        # reduced-precision dtypes where -1e9 is not.
        blocked_value = torch.finfo(attn_score.dtype).min

        if casual_mask is not None:
            attn_score = attn_score.masked_fill(casual_mask != 0, blocked_value)

        if pad_mask is not None:
            # Insert singleton axes after the batch dim until the mask has the
            # same rank as the scores: [b, seq] -> [b, 1, seq] for 3-D scores,
            # [b, 1, 1, seq] for 4-D scores. A fixed double unsqueeze only
            # lines up with 4-D scores.
            while pad_mask.dim() < attn_score.dim():
                pad_mask = pad_mask.unsqueeze(1)
            attn_score = attn_score.masked_fill(pad_mask != 0, blocked_value)

        attn_probs = F.softmax(attn_score, dim=-1)  # [..., seq_q, seq_k]
        return torch.matmul(attn_probs, value)  # [..., seq_q, d_v]
        

In [4]:
def sdpa_test(b: int = 128, s: int = 256, h: int = 1024):
    """Smoke-test ScaledDotProductAttention on random self-attention inputs.

    Args:
        b: Batch size of the random inputs.
        s: Sequence length of the random inputs.
        h: Hidden size (last dimension) of the random inputs.

    Returns:
        The attention output tensor of shape ``[b, s, h]``.

    Raises:
        AssertionError: If the output shape differs from ``[b, s, h]``.
    """
    q = torch.randn(b, s, h)
    k = torch.randn(b, s, h)
    v = torch.randn(b, s, h)
    sdpa = ScaledDotProductAttention()
    output = sdpa(q, k, v)
    # With no masks, query, key and value share one shape, so the output
    # must have that shape too.
    assert output.shape == (b, s, h), f"unexpected output shape {tuple(output.shape)}"
    return output

In [5]:
# Run the attention smoke test only when this code executes as the main
# module, not when it is imported from elsewhere.
if __name__ == "__main__":
    sdpa_test()