In [1]:
import torch
from torch import nn

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention."""

    def __init__(self, hidden_size, num_heads, dropout=0.0, bias=True, device=None):
        """
        Multi-head attention.

        Args:
            hidden_size (int): size of the input features, i.e. the last dimension of hidden_state.
            num_heads (int): number of attention heads; must divide hidden_size.
            dropout (float): dropout probability applied to the attention weights. Defaults to 0.0.
            bias (bool): whether the Q/K/V and output projections use a bias. Defaults to True.
            device: device on which the projection weights are created. Defaults to None
                (PyTorch's default device).
        """
        super(MultiHeadAttention, self).__init__()
        assert hidden_size % num_heads == 0, "hidden_size 必须能被 num_heads 整除"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads  # size of each head

        # Linear projections producing Q, K, V.
        self.query = nn.Linear(hidden_size, hidden_size, bias, device)
        self.key = nn.Linear(hidden_size, hidden_size, bias, device)
        self.value = nn.Linear(hidden_size, hidden_size, bias, device)

        self.dropout = nn.Dropout(dropout)
        # NOTE(review): `out` is never used in forward (`out_projection` is). It is kept, in the
        # same creation order, only so existing state_dicts/checkpoints and seeded initialization
        # stay compatible. Consider removing it.
        self.out = nn.Linear(hidden_size, hidden_size, bias, device)
        self.out_projection = nn.Linear(hidden_size, hidden_size, bias, device)

    def forward(self, hidden_state, attention_mask=None):
        """
        Forward pass.

        Args:
            hidden_state (torch.Tensor): input of shape [batch_size, seq_len, hidden_size].
            attention_mask (torch.Tensor, optional): entries equal to 0 are excluded from attention.
                Either a key-padding mask of shape [batch_size, seq_len] (expanded to
                [batch_size, 1, 1, seq_len]) or any mask already broadcastable to
                [batch_size, num_heads, seq_len, seq_len], e.g. a causal mask. Defaults to None.

        Returns:
            torch.Tensor: attention output of shape [batch_size, seq_len, hidden_size].
        """
        batch_size, seq_len, _ = hidden_state.size()

        # 1. Linear projections to Q, K, V: [batch_size, seq_len, hidden_size]
        query = self.query(hidden_state)
        key = self.key(hidden_state)
        value = self.value(hidden_state)

        # 2. Split into heads: [batch_size, num_heads, seq_len, head_dim]
        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # 3. Scaled dot-product scores: [batch_size, num_heads, seq_len, seq_len]
        attention_weights = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # Key-padding mask [batch_size, seq_len] -> [batch_size, 1, 1, seq_len],
                # so every head and every query position ignores the masked keys.
                attention_mask = attention_mask[:, None, None, :]
            # Use the most negative finite value rather than -inf: masked entries still get
            # zero weight, but a fully masked row yields uniform weights instead of NaN.
            attention_weights = attention_weights.masked_fill(
                attention_mask == 0, torch.finfo(attention_weights.dtype).min
            )
        attention_weights = torch.softmax(attention_weights, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 4. Weighted sum of values: [batch_size, num_heads, seq_len, head_dim]
        context = torch.matmul(attention_weights, value)

        # 5. Merge heads -> [batch_size, seq_len, hidden_size]; transpose leaves the tensor
        #    non-contiguous, so contiguous() is needed before view().
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)

        # 6. Output projection
        return self.out_projection(context)

In [23]:
class RotaryPositionalEmbeddings(nn.Module):
    """
    Rotary positional embeddings (RoPE).

    Rotates the first ``d`` features of each vector by position-dependent angles and passes
    any remaining features through unchanged. The position index is taken from dim 0 of the
    input; dims 1 and 2 are broadcast over (e.g. batch and heads), so inputs are rank 4:
    ``[seq_len, dim1, dim2, features]``.
    """

    def __init__(self, d, base=10000):
        """
        Args:
            d (int): number of leading features to rotate. Feature i is paired with feature
                i + d/2, so d must be even.
            base (float): frequency base; pair i uses theta_i = base ** (-2i / d). Defaults to 10000.

        Raises:
            ValueError: if ``d`` is odd (the cos/sin table would not match the rotated slice).
        """
        super(RotaryPositionalEmbeddings, self).__init__()
        if d % 2 != 0:
            raise ValueError(f"d must be even, got {d}")
        self.d = d
        self.base = base
        # float32 cos/sin tables of shape [cached_len, 1, 1, d], built lazily in _build_cache.
        # They are plain attributes (not buffers), so they are rebuilt when the device changes.
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x):
        """Build (or reuse) the cos/sin tables for ``x.shape[0]`` positions on ``x.device``."""
        seq_len = x.shape[0]
        if (
            self.cos_cached is not None
            and seq_len <= self.cos_cached.shape[0]
            and self.cos_cached.device == x.device
        ):
            return

        # theta_i = base ** (-2i / d) for i = 0 .. d/2 - 1
        theta = 1. / (self.base ** (torch.arange(0, self.d, 2, device=x.device).float() / self.d))
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # Angle for every (position, pair): [seq_len, d/2]
        idx_theta = torch.einsum('n,d->nd', seq_idx, theta)
        # Duplicate so feature i and i + d/2 share the same angle: [seq_len, d]
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)

        self.cos_cached = torch.cos(idx_theta2)[:, None, None, :]
        self.sin_cached = torch.sin(idx_theta2)[:, None, None, :]

    def _neg_half(self, x):
        """Return [-x2, x1] where x1, x2 are the first and second halves of the last dim."""
        d_2 = self.d // 2
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x):
        """
        Apply RoPE to the first ``d`` features of ``x``.

        Args:
            x (torch.Tensor): rank-4 tensor with the position on dim 0 and at least ``d``
                features on the last dim.

        Returns:
            torch.Tensor: same shape as ``x``; same dtype for floating-point inputs.
        """
        self._build_cache(x)
        seq_len = x.shape[0]
        x_rope, x_pass = x[:, :, :, :self.d], x[:, :, :, self.d:]

        # Cast the float32 tables to x's dtype so half/double inputs keep their dtype and the
        # final cat does not mix dtypes.
        table_dtype = x.dtype if x.is_floating_point() else self.cos_cached.dtype
        cos = self.cos_cached[:seq_len].to(table_dtype)
        sin = self.sin_cached[:seq_len].to(table_dtype)

        x_rope = (x_rope * cos) + (self._neg_half(x_rope) * sin)
        return torch.cat([x_rope, x_pass], dim=-1)


In [25]:
class RotaryPEMultiHeadAttention(MultiHeadAttention):
    """
    Multi-head attention that applies rotary positional embeddings (RoPE) to the queries
    and keys before computing the attention scores.
    """

    def __init__(self, heads, d_model, rope_percentage, base, dropout_prob, bias, device):
        """
        Args:
            heads (int): number of attention heads.
            d_model (int): hidden size; must be divisible by ``heads``.
            rope_percentage (float): fraction of each head's features that are rotated.
            base (float): RoPE frequency base, forwarded to RotaryPositionalEmbeddings.
            dropout_prob (float): dropout probability on the attention weights.
            bias (bool): whether the linear projections use a bias.
            device: device on which the linear layers are created.
        """
        super(RotaryPEMultiHeadAttention, self).__init__(d_model, heads, dropout_prob, bias, device)
        # RoPE rotates features in pairs, so round the rotated width down to an even number.
        d_rope = int(self.head_dim * rope_percentage) // 2 * 2
        # `base` must be forwarded: RotaryPositionalEmbeddings(d, base) is the constructor
        # signature; omitting it previously made this constructor raise TypeError.
        self.query_rotary_pe = RotaryPositionalEmbeddings(d_rope, base)
        self.key_rotary_pe = RotaryPositionalEmbeddings(d_rope, base)

    def get_scores(self, query, key):
        """
        Unscaled dot-product scores between rotated queries and keys.

        Args:
            query, key (torch.Tensor): [seq_len, batch_size, num_heads, head_dim]
                (RotaryPositionalEmbeddings reads the position from dim 0).

        Returns:
            torch.Tensor: [seq_len_q, seq_len_k, batch_size, num_heads].
        """
        return torch.einsum('ibhd,jbhd->ijbh', self.query_rotary_pe(query), self.key_rotary_pe(key))

    def forward(self, hidden_state, attention_mask=None):
        """
        Self-attention with RoPE applied to Q and K via ``get_scores``.

        The inherited forward never calls ``get_scores``, so this override is what actually
        applies the rotary embeddings.

        Args:
            hidden_state (torch.Tensor): [batch_size, seq_len, hidden_size].
            attention_mask (torch.Tensor, optional): entries equal to 0 are not attended to.
                Either [batch_size, seq_len] (key padding) or broadcastable to
                [batch_size, num_heads, seq_len, seq_len]. Defaults to None.

        Returns:
            torch.Tensor: [batch_size, seq_len, hidden_size].
        """
        batch_size, seq_len, _ = hidden_state.size()

        def split_heads(projected):
            # [batch_size, seq_len, hidden_size] -> [batch_size, seq_len, num_heads, head_dim]
            return projected.view(batch_size, seq_len, self.num_heads, self.head_dim)

        # Q and K in the [seq_len, batch_size, num_heads, head_dim] layout get_scores expects.
        query = split_heads(self.query(hidden_state)).transpose(0, 1)
        key = split_heads(self.key(hidden_state)).transpose(0, 1)
        # V in [batch_size, num_heads, seq_len, head_dim] for the weighted sum below.
        value = split_heads(self.value(hidden_state)).transpose(1, 2)

        # [seq_q, seq_k, batch, heads] -> [batch, heads, seq_q, seq_k], then scale.
        scores = self.get_scores(query, key).permute(2, 3, 0, 1) / (self.head_dim ** 0.5)

        if attention_mask is not None:
            if attention_mask.dim() == 2:
                attention_mask = attention_mask[:, None, None, :]
            # Finite minimum instead of -inf keeps fully masked rows from becoming NaN.
            scores = scores.masked_fill(attention_mask == 0, torch.finfo(scores.dtype).min)

        weights = self.dropout(torch.softmax(scores, dim=-1))
        context = torch.matmul(weights, value)  # [batch_size, num_heads, seq_len, head_dim]
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        return self.out_projection(context)

In [None]:
if __name__ == '__main__':
    d_model = 4
    base = 10000
    # NOTE(review): the next three settings are not used by the RoPE demo below.
    rope_percentage = 0.5
    dropout_prob = 0
    bias = False
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Use the detected device instead of hard-coding 'cuda' so the demo also runs on CPU-only machines.
    x = torch.tensor([[1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]], dtype=torch.float, device=device)
    # RotaryPositionalEmbeddings reads the position from dim 0: [seq_len=3, 1, 1, d_model]
    x = x[:, None, None, :]
    print(x.shape)
    rotary_pe = RotaryPositionalEmbeddings(d_model, base)
    print(rotary_pe(x))



torch.Size([3, 1, 1, 4])
tensor([[[[  1.0000,   2.0000,   3.0000,   4.0000]]],


        [[[ -2.8876,   4.9298,   6.6077,   7.0496]]],


        [[[-11.0967,   7.7984,   2.6198,  10.1580]]]], device='cuda:0')
