In [None]:
import numpy as np

def softmax(x):
    """Return the softmax of *x* taken along its last axis.

    The per-row maximum is subtracted before exponentiating so that large
    inputs do not overflow; this shift leaves the normalised result unchanged.
    """
    row_max = np.max(x, axis=-1, keepdims=True)
    exps = np.exp(x - row_max)
    totals = np.sum(exps, axis=-1, keepdims=True)
    return exps / totals

def masked_attention(query, key, value, mask=None):
    """
    Compute scaled dot-product attention with an additive mask.

    :param query: query array of shape (batch_size, seq_len_q, d_k); extra
        leading dimensions (e.g. attention heads) are also accepted
    :param key: key array of shape (batch_size, seq_len_k, d_k)
    :param value: value array of shape (batch_size, seq_len_k, d_v)
    :param mask: array broadcastable to the score shape
        (batch_size, seq_len_q, seq_len_k), e.g. (batch_size, 1, seq_len_k) or
        (batch_size, seq_len_q, seq_len_k). A value of 1 marks a key position
        that must NOT be attended to; 0 leaves it unchanged. ``None`` (the
        default) applies no mask.
    :return: tuple ``(output, attention_weights)`` where ``output`` has shape
        (batch_size, seq_len_q, d_v) and ``attention_weights`` has shape
        (batch_size, seq_len_q, seq_len_k)
    """
    d_k = np.shape(query)[-1]

    # Attention scores, scaled by sqrt(d_k). Swapping only the last two axes
    # is identical to transpose(0, 2, 1) for 3-D keys, but also works when
    # there are additional leading (e.g. head) dimensions.
    scores = np.matmul(query, np.swapaxes(key, -1, -2)) / np.sqrt(d_k)

    if mask is not None:
        # Positions where mask == 1 receive a large negative bias, so their
        # softmax weight becomes ~0. The add is out-of-place so the mask may
        # broadcast the scores up to a larger shape if needed.
        scores = scores + np.asarray(mask) * -1e9

    # Normalise scores into attention weights over the key axis.
    attention_weights = softmax(scores)

    # Weighted sum of the values.
    output = np.matmul(attention_weights, value)
    return output, attention_weights

# Example data: shapes for a small toy batch
batch_size, seq_len, d_k, d_v = 2, 5, 4, 4

# Random query, key and value arrays, drawn in that order from NumPy's
# global random generator
query, key, value = (
    np.random.rand(batch_size, seq_len, depth) for depth in (d_k, d_k, d_v)
)

# Mask that allows attention only to the first 3 positions: 0.0 keeps a
# position, 1.0 blocks it (masked_attention adds mask * -1e9 to the scores).
# Shape: (batch_size, 1, seq_len).
mask = np.tile(np.where(np.arange(seq_len) >= 3, 1.0, 0.0), (batch_size, 1, 1))

# Run masked attention and display the results
output, attention_weights = masked_attention(query, key, value, mask)

print("注意力输出:\n", output)
print("注意力权重:\n", attention_weights)
