# Attention Mask

In [None]:
import torch

def create_attention_mask(
    input_ids: torch.Tensor,
    pad_token_id: int = 0,
    causal: bool = True,
) -> torch.Tensor:
    """Build a boolean attention mask in which True marks positions to hide.

    The result is meant to be passed to ``Tensor.masked_fill`` on a score
    tensor whose last two dimensions are (query position, key position).

    Args:
        input_ids: 2-D tensor of token ids, shape ``[batch_size, seq_len]``.
        pad_token_id: Token id treated as padding. Every key position that
            holds this id is masked for all queries.
        causal: When True, additionally mask every key position that lies
            strictly after the query position (upper triangle, diagonal
            excluded).

    Returns:
        A ``torch.bool`` tensor of shape ``[batch_size, 1, seq_len, seq_len]``
        on the same device as ``input_ids``. The singleton dimension lets the
        mask broadcast over an extra (e.g. per-head) dimension of the scores.
        With ``causal=False`` the result is an expanded (non-materialized)
        view of the padding mask.

    Raises:
        ValueError: If ``input_ids`` is not 2-D.

    Note:
        Only key positions are masked for padding. A query row whose visible
        keys are all padding (for example the leading rows of a left-padded
        sequence with ``causal=True``) is entirely True; filling such a row
        with ``-inf`` before a softmax would produce NaN, so callers may need
        to handle those rows separately.
    """
    if input_ids.dim() != 2:
        raise ValueError(
            "input_ids must be 2-D [batch_size, seq_len], "
            f"got shape {tuple(input_ids.shape)}"
        )
    batch_size, seq_len = input_ids.size()

    # True where the key token is padding: [batch_size, 1, 1, seq_len].
    padding_mask = (input_ids == pad_token_id)[:, None, None, :]

    if not causal:
        # Repeat the key mask for every query row without copying memory.
        return padding_mask.expand(batch_size, 1, seq_len, seq_len)

    # False on and below the diagonal, True above it. The mask must live on
    # the same device as input_ids, otherwise the `|` below fails for
    # non-CPU inputs.
    causal_mask = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=input_ids.device),
        diagonal=1,
    )

    # Broadcasts to [batch_size, 1, seq_len, seq_len] for masked_fill.
    return padding_mask | causal_mask[None, None, :, :]

# Usage (apply the mask to the raw scores, before softmax):
# mask = create_attention_mask(input_ids, pad_token_id=0)
# attention_scores = attention_scores.masked_fill(mask, float('-inf'))