In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device):
    """Build an interleaved sin/cos position table replicated per batch and head.

    For position ``p`` and pair index ``i`` (``0 <= i < output_dim // 2``) the
    angle is ``p * 10000 ** (-2 * i / output_dim)``. Along the last axis the
    even slots hold ``sin(angle)`` and the odd slots hold ``cos(angle)``.

    Args:
        batch_size: Number of copies along the first axis.
        nums_head: Number of copies along the second axis.
        max_len: Number of positions (``0 .. max_len - 1``).
        output_dim: Size of the last axis; the final reshape only succeeds
            when it equals ``2 * (output_dim // 2)``, i.e. when it is even.
        device: Target device the finished table is moved to.

    Returns:
        A float32 tensor of shape ``(batch_size, nums_head, max_len, output_dim)``.
    """
    half_dim = output_dim // 2

    # Column vector of positions: (max_len, 1).
    positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1)
    # Pair index i from the formula: (half_dim,).
    pair_index = torch.arange(0, half_dim, dtype=torch.float)
    inv_freq = torch.pow(10000, -2 * pair_index / output_dim)

    # Broadcast product gives pos / 10000^(2i/d): (max_len, half_dim).
    angles = positions * inv_freq

    # Pair every sin with its cos on a new trailing axis: (max_len, half_dim, 2).
    table = torch.stack((torch.sin(angles), torch.cos(angles)), dim=-1)

    # Tile over batch and head: (batch_size, nums_head, max_len, half_dim, 2).
    table = table.repeat(batch_size, nums_head, 1, 1, 1)

    # Flatten the (half_dim, 2) axes so sin and cos alternate along the last axis.
    table = torch.reshape(table, (batch_size, nums_head, max_len, output_dim))
    return table.to(device)


def _rotate_adjacent_pairs(x):
    """Map (x0, x1, x2, x3, ...) to (-x1, x0, -x3, x2, ...) along the last axis."""
    rotated = torch.stack([-x[..., 1::2], x[..., ::2]], dim=-1)
    # Flattening the trailing pair axis interleaves the negated odd entries
    # with the even entries, restoring the input shape.
    return rotated.reshape(x.shape)


def RoPE(q, k):
    """Apply rotary position embedding to query and key tensors.

    Each adjacent pair ``(x[2i], x[2i + 1])`` on the last axis is rotated by
    the angle stored for its position in the table returned by
    ``sinusoidal_position_embedding``.

    Args:
        q: 4-D tensor ``(bs, head, max_len, output_dim)``; its last two sizes
            determine the position table.
        k: Tensor that broadcasts against a ``(1, 1, max_len, output_dim)``
            table (normally the same shape as ``q``).

    Returns:
        Tuple ``(q_rotated, k_rotated)``. Each result keeps the dtype of its
        input, so half-precision inputs are not silently promoted to float32
        by the float32 position table.
    """
    _, _, max_len, output_dim = q.shape

    # A single (1, 1, max_len, output_dim) table is enough: the elementwise
    # products below broadcast over batch and head, so there is no need to
    # materialise bs * head identical copies.
    pos_emb = sinusoidal_position_embedding(1, 1, max_len, output_dim, q.device)

    # The table alternates sin (even slots) and cos (odd slots). Both entries
    # of a pair share one angle, so duplicate each value: (1, 2, 3) becomes
    # (1, 1, 2, 2, 3, 3).
    cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1)
    sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1)

    # x * cos + rotate(x) * sin is the 2-D rotation applied to every pair.
    q_rot = q * cos_pos + _rotate_adjacent_pairs(q) * sin_pos
    k_rot = k * cos_pos + _rotate_adjacent_pairs(k) * sin_pos

    # The arithmetic runs in the promoted dtype; cast back so downstream
    # matmuls against tensors in the original dtype (e.g. v) still work.
    return q_rot.to(q.dtype), k_rot.to(k.dtype)


def attention(q, k, v, mask=None, dropout=None, use_RoPE=True):
    if use_RoPE:
        q, k = RoPE(q, k)
    d_k = k.size()[-1]

    att_logits = torch.matmul(q, k.transpose(-2, -1))  # (bs, head, seq_len, seq_len)
    att_logits /= math.sqrt(d_k)

    if mask is not None:
        att_logits = att_logits.masked_fill(mask == 0, -1e9)  # mask掉为0的部分，设为无穷大

    att_scores = F.softmax(att_logits, dim=-1)  # (bs, head, seq_len, seq_len)

    if dropout is not None:
        att_scores = dropout(att_scores)

    # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk)
    return torch.matmul(att_scores, v), att_scores


if __name__ == "__main__":
    # Smoke run: random query/key/value tensors of identical shape are pushed
    # through attention with rotary embedding enabled and no mask or dropout.
    shape = (8, 12, 10, 32)  # (bs, head, seq_len, dk)
    q, k, v = (torch.randn(shape) for _ in range(3))

    res, att_score = attention(
        q,
        k,
        v,
        mask=None,
        dropout=None,
        use_RoPE=True,
    )