In [None]:
from typing import Tuple

import torch
from torch import Tensor, device


def get_extended_attention_mask(
    self,
    attention_mask: Tensor,
    input_shape: Tuple[int],
    device: device,
    is_decoder: bool,
    has_query: bool = False,
) -> Tensor:
    """
    Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

    Arguments:
        attention_mask (:obj:`torch.Tensor`):
            Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
        input_shape (:obj:`Tuple[int]`):
            The shape of the input to the model.
        device (:obj:`torch.device`):
            The device of the input to the model.
        is_decoder (:obj:`bool`):
            Whether to apply a causal mask in addition to the padding mask.
        has_query (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether the extra prefix positions are learned queries (UniLM-style mask).

    Returns:
        :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
    """

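`get_extended_attention_mask` builds the extended (broadcastable) attention mask. It handles both the self-attention (padding) mask and the causal mask; the causal mask is needed when the model runs as a decoder.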

In [None]:
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
if attention_mask.dim() == 3:
    extended_attention_mask = attention_mask[:, None, :, :]

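Multi-head self-attention in a Transformer computes scores of shape [batch_size, num_heads, from_seq_length, to_seq_length], where num_heads is the number of attention heads, so the mask must broadcast against that shape. Inserting a new axis (None) at dimension 1 turns attention_mask into shape [batch_size, 1, from_seq_length, to_seq_length], which broadcasts across all attention heads.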
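Here is a quick shape check of this step. The sizes are made up (batch of 2, 8 heads, length 5), and a random tensor stands in for the attention scores. `torch.broadcast_shapes` is used only to display the shape that broadcasting produces:

```python
import torch

batch_size, num_heads, seq_len = 2, 8, 5

# User-supplied 3-D self-attention mask: [batch_size, from_seq_length, to_seq_length]
attention_mask = torch.ones(batch_size, seq_len, seq_len)

# Insert a head axis at dim 1 -> [batch_size, 1, from_seq_length, to_seq_length]
extended_attention_mask = attention_mask[:, None, :, :]
print(extended_attention_mask.shape)  # torch.Size([2, 1, 5, 5])

# Stand-in attention scores for all heads: [batch_size, num_heads, from, to]
scores = torch.randn(batch_size, num_heads, seq_len, seq_len)

# The size-1 head axis broadcasts against all 8 heads
print(torch.broadcast_shapes(extended_attention_mask.shape, scores.shape))
# torch.Size([2, 8, 5, 5])
```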

In [None]:
elif attention_mask.dim() == 2:
    # Provided a padding mask of dimensions [batch_size, seq_length]
    # - if the model is a decoder, apply a causal mask in addition to the padding mask
    # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]

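Padding mask: marks the padding positions in the input sequence. Padding only exists to bring every sequence in a batch to the same length, so these positions must not be taken into account when computing attention scores.
Causal mask: ensures that the decoder can only attend to the current and earlier tokens, which prevents information leakage. This is especially important for autoregressive generation models, because the model must not see future tokens while generating the next one.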
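Both masks can be illustrated on toy data. The padding values below are made up. The encoder-style broadcast `padding_mask[:, None, None, :]` is an assumption about the 2-D encoder branch, which this notebook does not show. It matches the comment "make the mask broadcastable" and the way Hugging Face's BERT handles this case. For brevity, `torch.tril` builds the causal pattern directly:

```python
import torch

# Padding mask for two sequences padded to length 4 (1 = attend, 0 = padding);
# the second sequence has only 2 real tokens
padding_mask = torch.tensor([[1, 1, 1, 1],
                             [1, 1, 0, 0]])

# Assumed encoder-style broadcast: [batch_size, seq_len] -> [batch_size, 1, 1, seq_len],
# so every head and every query row sees the same set of valid key columns
encoder_mask = padding_mask[:, None, None, :]
print(encoder_mask.shape)  # torch.Size([2, 1, 1, 4])

# Causal mask: query position i may attend to key position j only if j <= i
seq_len = padding_mask.shape[1]
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=padding_mask.dtype))
print(causal)
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
```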

In [None]:
if is_decoder:
    batch_size, seq_length = input_shape

    seq_ids = torch.arange(seq_length, device=device)
    causal_mask = (
        seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
        <= seq_ids[None, :, None]
    )

    # add a prefix ones mask to the causal mask
    # causal and attention masks must have same type with pytorch version < 1.3
    causal_mask = causal_mask.to(attention_mask.dtype)

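This code builds a causal mask that ensures each position can only attend to itself and the positions before it.
seq_ids[None, None, :] adds two leading dimensions, giving shape (1, 1, seq_length).
seq_ids[None, None, :].repeat(batch_size, seq_length, 1) tiles it to shape (batch_size, seq_length, seq_length); the value at (b, i, j) is j.
seq_ids[None, :, None] adds a leading and a trailing dimension, giving shape (1, seq_length, 1); after broadcasting, the value at (b, i, j) is i.
The comparison <= produces a boolean tensor causal_mask of shape (batch_size, seq_length, seq_length). Entry (i, j) is True if j <= i (query position i may attend to key position j) and False otherwise, which is a lower-triangular pattern.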
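The construction can be checked on a tiny example (batch of 2, length 4, device omitted). It confirms that entry (i, j) is True exactly when j <= i, which is the same as a lower-triangular matrix:

```python
import torch

batch_size, seq_length = 2, 4
seq_ids = torch.arange(seq_length)

# Left operand: the value at [b, i, j] is j (key / column index)
cols = seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
# Right operand: after broadcasting, the value at [b, i, j] is i (query / row index)
rows = seq_ids[None, :, None]

# True where j <= i: position i may look at itself and earlier positions
causal_mask = cols <= rows
print(causal_mask.shape)  # torch.Size([2, 4, 4])
print(causal_mask[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)

# Identical to a boolean lower-triangular matrix
assert torch.equal(causal_mask[0], torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool)))
```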

In [None]:
if causal_mask.shape[1] < attention_mask.shape[1]:
    prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
    if has_query:  # UniLM style attention mask
        causal_mask = torch.cat(
            [
                torch.zeros(
                    (batch_size, prefix_seq_len, seq_length),
                    device=device,
                    dtype=causal_mask.dtype,
                ),
                causal_mask,
            ],
            axis=1,
        )

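attention_mask covers every position, whereas causal_mask only covers the causal (text) part, so causal_mask can be shorter than attention_mask. The difference, prefix_seq_len, is the length of the prefix.
If the prefix consists of queries (the learned queries from the paper), i.e. has_query is True, prefix_seq_len rows of zeros are prepended along dimension 1. As a result, the query positions cannot attend to any text token.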
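A toy run of the has_query step, assuming 2 learned queries and 3 text tokens with batch size 1. `torch.tril` stands in for the text-only causal mask built earlier, and `dim=1` is used, which is equivalent to the notebook's `axis=1`:

```python
import torch

batch_size = 1
prefix_seq_len, seq_length = 2, 3   # 2 learned queries (prefix), 3 text tokens

# Causal mask over the text tokens only: [batch_size, seq_length, seq_length]
causal_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length))

# UniLM-style step: prepend prefix_seq_len rows of zeros along dim 1,
# so the query rows cannot attend to any text column
causal_mask = torch.cat(
    [torch.zeros(batch_size, prefix_seq_len, seq_length), causal_mask],
    dim=1,
)
print(causal_mask.shape)  # torch.Size([1, 5, 3])
print(causal_mask[0])
# tensor([[0., 0., 0.],
#         [0., 0., 0.],
#         [1., 0., 0.],
#         [1., 1., 0.],
#         [1., 1., 1.]])
```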

In [None]:
causal_mask = torch.cat(
    [
        torch.ones(
            (batch_size, causal_mask.shape[1], prefix_seq_len),
            device=device,
            dtype=causal_mask.dtype,
        ),
        causal_mask,
    ],
    axis=-1,
)

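This step is not an else branch. It runs whether or not has_query is set: a block of ones of shape (batch_size, causal_mask.shape[1], prefix_seq_len) is prepended along the last dimension, so every row (query or text) can attend to all prefix positions. When has_query is set, the result is the UniLM-style mask. Queries attend only to one another, while text tokens attend to all queries and causally to earlier text.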
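The two concatenations can be combined in a small helper to show the complete prefix pattern. `prefix_causal_mask` is a hypothetical name used only for illustration. It uses `torch.tril` in place of the `seq_ids` comparison and omits the device argument:

```python
import torch


def prefix_causal_mask(batch_size, prefix_seq_len, seq_length, has_query):
    """Hypothetical helper reproducing the two torch.cat steps of the notebook."""
    causal_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length))
    if has_query:
        # Query rows: zeros over all text columns
        causal_mask = torch.cat(
            [torch.zeros(batch_size, prefix_seq_len, seq_length), causal_mask], dim=1
        )
    # Every row (query or text): ones over all prefix columns
    causal_mask = torch.cat(
        [torch.ones(batch_size, causal_mask.shape[1], prefix_seq_len), causal_mask], dim=-1
    )
    return causal_mask


print(prefix_causal_mask(1, 2, 3, has_query=True)[0])
# tensor([[1., 1., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 1.]])
print(prefix_causal_mask(1, 2, 3, has_query=False).shape)  # torch.Size([1, 3, 5])
```

With queries, the two query rows see each other bidirectionally but no text. Without queries, only the text rows exist, and each one sees the whole prefix plus earlier text.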

In [None]:
extended_attention_mask = (
    causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
)

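Broadcasting the two masks and multiplying them is what the earlier comment means by "apply a causal mask in addition to the padding mask". causal_mask[:, None, :, :] has shape (batch_size, 1, from_len, to_len) and attention_mask[:, None, None, :] has shape (batch_size, 1, 1, to_len). Their product therefore zeroes every padded key column in every row.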
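To see the final product end to end, the snippet below runs it on toy data. It assumes 2 queries, 3 text tokens, and has_query=True. The padding mask is made up: the second example's last text token is padding.

```python
import torch

batch_size, prefix_seq_len, seq_length = 2, 2, 3
total_len = prefix_seq_len + seq_length

# UniLM-style causal part (has_query=True): [batch_size, 5, 5]
causal_mask = torch.tril(torch.ones(batch_size, seq_length, seq_length))
causal_mask = torch.cat([torch.zeros(batch_size, prefix_seq_len, seq_length), causal_mask], dim=1)
causal_mask = torch.cat([torch.ones(batch_size, total_len, prefix_seq_len), causal_mask], dim=-1)

# Padding mask over all positions (queries + text)
attention_mask = torch.tensor([[1., 1., 1., 1., 1.],
                               [1., 1., 1., 1., 0.]])

# [batch, 1, 5, 5] * [batch, 1, 1, 5] -> [batch, 1, 5, 5]
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
print(extended_attention_mask.shape)  # torch.Size([2, 1, 5, 5])
print(extended_attention_mask[1, 0])
# tensor([[1., 1., 0., 0., 0.],
#         [1., 1., 0., 0., 0.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.],
#         [1., 1., 1., 1., 0.]])
```

The padded column is zeroed in every row of the second example, while the first example keeps the plain UniLM pattern.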