In [1]:
import torch
from torch import nn

# 实例化nn.MultiheadAttention


> class
> `torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)`

Parameters:
- embed_dim – Total dimension of the model.
- num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).
- dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
- bias – If specified, adds bias to input / output projection layers. Default: True.
- add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False.
- add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
- kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).
- vdim – Total number of features for values. Default: None (uses vdim=embed_dim).
- batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

In [4]:
# Build a 12-head attention layer over 768-dim embeddings (768 // 12 = 64 dims per head).
# batch_first=True selects the (batch, seq, feature) tensor layout; the library
# default of False would expect (seq, batch, feature) instead.
mha_config = dict(
    embed_dim=768,
    num_heads=12,
    dropout=0.1,
    bias=True,
    batch_first=True,
)
mha = nn.MultiheadAttention(**mha_config)
mha

MultiheadAttention(
  (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)

# 推理

In [5]:
# Dummy all-ones input laid out as [batch_size, seq_len, features].
input_shape = (2, 192, 768)
x = torch.ones(input_shape)
x.shape

torch.Size([2, 192, 768])

In [6]:
# Conceptually, for each of the 12 heads (head_dim = 768 // 12 = 64), after the
# input projections split q, k and v into heads:
#   attn = q @ k.transpose(-2, -1) / sqrt(head_dim)
#          [2, 12, 192, 64] @ [2, 12, 64, 192] -> [2, 12, 192, 192]
#   attn = attn.softmax(dim=-1)
#   out  = attn @ v
#          [2, 12, 192, 192] @ [2, 12, 192, 64] -> [2, 12, 192, 64]
# The heads are then concatenated back to [2, 192, 768] and passed through out_proj.
# With the default average_attn_weights=True the returned weights are averaged
# over the heads, which is why they come back as [2, 192, 192].

# This is inference: eval() turns off the dropout=0.1 configured on the layer
# (otherwise the output is randomly perturbed on every call) and no_grad()
# skips building the autograd graph.
mha.eval()
with torch.no_grad():
    attn_output, attn_output_weights = mha(x, x, x)
attn_output.shape, attn_output_weights.shape

(torch.Size([2, 192, 768]), torch.Size([2, 192, 192]))

In [7]:
# need_weights=False: do not return the attention weights; the second element
# of the result is then None.
# Run as inference: eval() disables the layer's dropout and no_grad() skips
# autograd bookkeeping. eval() is idempotent, so calling it here is harmless
# if the layer is already in eval mode.
mha.eval()
with torch.no_grad():
    attn_output, attn_output_weights = mha(x, x, x, need_weights=False)
attn_output.shape, attn_output_weights

(torch.Size([2, 192, 768]), None)