In [3]:
import torch 
from torch import nn 
import torch.nn.functional as F

# 多头注意力
## Multihead Attention (MHA)


In [27]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention (MHA) with an optional causal mask.

    The input is projected to queries, keys and values, split into
    ``num_heads`` heads of size ``hidden_size // num_heads``, attended
    per head, merged back and passed through an output projection.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and
            to the output projection (only active in training mode).
    """

    def __init__(self, hidden_size: int, num_heads: int, dropout: float = 0.0):
        super().__init__()
        assert hidden_size % num_heads == 0, \
            "hidden_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Projection layers Wq / Wk / Wv, the output projection Wo, and dropout.
        self.q_linear = nn.Linear(hidden_size, hidden_size)   # Wq
        self.k_linear = nn.Linear(hidden_size, hidden_size)   # Wk
        self.v_linear = nn.Linear(hidden_size, hidden_size)   # Wv
        self.o_linear = nn.Linear(hidden_size, hidden_size)   # Wo
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)

        # True when this torch build ships the fused attention kernel.
        self.FlashAttention = hasattr(F, 'scaled_dot_product_attention')

    def forward(self, hidden_state: torch.Tensor, attention_mask: bool = True) -> torch.Tensor:
        """Apply self-attention to ``hidden_state``.

        Args:
            hidden_state: Tensor of shape ``(batch_size, seq_len, hidden_size)``.
            attention_mask: When true, apply a causal (lower-triangular) mask
                so each position only attends to itself and earlier positions.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)``.
        """
        batch_size, seq_len, hidden_size = hidden_state.size()
        causal = bool(attention_mask)

        q = self.q_linear(hidden_state)  # (batch_size, seq_len, hidden_size)
        k = self.k_linear(hidden_state)
        v = self.v_linear(hidden_state)

        # 1: (B, T, hidden) -> (B, T, heads, head_dim) -> (B, heads, T, head_dim)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        if self.FlashAttention:
            # Use the fused kernel when available. attn_mask and is_causal
            # are mutually exclusive, so only is_causal is passed.
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=causal,
            )
        else:
            # Manual scaled dot-product attention.
            scores = (q @ k.transpose(-1, -2)) * (self.head_dim ** -0.5)
            if causal:
                visible = torch.tril(
                    torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device)
                ).view(1, 1, seq_len, seq_len)
                # Mask out the *future* positions (where `visible` is False).
                scores = scores.masked_fill(~visible, float('-inf'))
            attn = F.softmax(scores, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v

        # 2: inverse of step 1, (B, heads, T, head_dim) -> (B, T, hidden).
        # contiguous() is required before view() because transpose() only
        # changes strides.
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        # Output projection followed by (optional) residual dropout.
        return self.resid_dropout(self.o_linear(y))



# 多查询注意力
## MultiQuery Attention (MQA)

In [25]:
# 这里我忽略掉dropout
# 具体处理参考上面MHA实现
class MultiQueryAttention(nn.Module):
    """Multi-query self-attention (MQA): many query heads, one shared K/V head.

    Queries are split into ``num_heads`` heads, while a single key head and
    a single value head (each of size ``hidden_size // num_heads``) are
    shared by all query heads. Dropout is intentionally omitted.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
    """

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        assert hidden_size % num_heads == 0, \
            "hidden_size must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Full-width query projection; K and V project to a single head.
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.head_dim)
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        # True when this torch build ships the fused attention kernel.
        self.FlashAttention = hasattr(F, 'scaled_dot_product_attention')

    def forward(self, hidden_state: torch.Tensor, attention_mask: bool = True) -> torch.Tensor:
        """Apply multi-query self-attention to ``hidden_state``.

        Args:
            hidden_state: Tensor of shape ``(batch_size, seq_len, hidden_size)``.
            attention_mask: When true, apply a causal (lower-triangular) mask.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)``.
        """
        batch_size, seq_len, hidden_size = hidden_state.size()
        causal = bool(attention_mask)
        scale = self.head_dim ** -0.5  # 1 / sqrt(d_k)

        q = self.q_linear(hidden_state)
        k = self.k_linear(hidden_state)
        v = self.v_linear(hidden_state)

        # Split q into heads; give the shared k / v a singleton head axis so
        # all three tensors are laid out as (B, heads, T, head_dim).
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)

        if self.FlashAttention:
            # Expand the shared K/V head to every query head as a stride-0
            # view (no copy) so the fused kernel receives matching head
            # counts regardless of backend.
            k_all = k.expand(batch_size, self.num_heads, seq_len, self.head_dim)
            v_all = v.expand(batch_size, self.num_heads, seq_len, self.head_dim)
            y = F.scaled_dot_product_attention(
                q, k_all, v_all, dropout_p=0.0, is_causal=causal
            )
        else:
            # The @ operator broadcasts the singleton head axis of k / v
            # across all query heads.
            scores = (q @ k.transpose(-1, -2)) * scale
            if causal:
                visible = torch.tril(
                    torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device)
                ).view(1, 1, seq_len, seq_len)
                # Mask out the *future* positions (where `visible` is False).
                scores = scores.masked_fill(~visible, float('-inf'))
            attn = F.softmax(scores, dim=-1)
            y = attn @ v

        # Merge heads: (B, heads, T, head_dim) -> (B, T, hidden).
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        y = self.o_linear(y)
        return y
            
               
        


# 组查询注意力
## Group Query Attention (GQA)

In [None]:
class GroupQueryAttention(nn.Module):
    """Grouped-query self-attention (GQA).

    Query heads are partitioned into ``num_groups`` groups; every group of
    ``num_heads // num_groups`` consecutive query heads shares one key head
    and one value head. ``num_groups == num_heads`` is equivalent to
    multi-head attention and ``num_groups == 1`` to multi-query attention.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads; must be divisible by ``num_groups``.
        num_groups: Number of shared key/value heads.
    """

    def __init__(self, hidden_size: int, num_heads: int, num_groups: int):
        super().__init__()
        assert hidden_size % num_heads == 0, \
            "hidden_size must be divisible by num_heads"
        # Each K/V head serves num_heads // num_groups query heads, so the
        # head count (not the hidden size) must divide evenly by the groups.
        assert num_heads % num_groups == 0, \
            "num_heads must be divisible by num_groups"
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = hidden_size // num_heads

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.num_groups * self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.num_groups * self.head_dim)
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        # True when this torch build ships the fused attention kernel.
        self.FlashAttention = hasattr(F, 'scaled_dot_product_attention')

    def forward(self, hidden_state: torch.Tensor, attention_mask: bool = True) -> torch.Tensor:
        """Apply grouped-query self-attention to ``hidden_state``.

        Args:
            hidden_state: Tensor of shape ``(batch_size, seq_len, hidden_size)``.
            attention_mask: When true, apply a causal (lower-triangular) mask.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)``.
        """
        batch_size, seq_len, hidden_size = hidden_state.size()
        causal = bool(attention_mask)
        scale = self.head_dim ** -0.5  # 1 / sqrt(d_k)
        heads_per_group = self.num_heads // self.num_groups

        q = self.q_linear(hidden_state)
        k = self.k_linear(hidden_state)
        v = self.v_linear(hidden_state)

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)

        # The key step of GQA: share each K/V head across its query heads.
        # k[:, :, None] inserts an axis:
        #   (B, groups, T, head_dim) -> (B, groups, 1, T, head_dim)
        # expand() then "repeats" it heads_per_group times as a view (no
        # data is copied), and reshape() folds (groups, heads_per_group)
        # into the head axis so k / v line up with q as (B, heads, T, head_dim).
        expanded_shape = (batch_size, self.num_groups, heads_per_group, seq_len, self.head_dim)
        merged_shape = (batch_size, self.num_heads, seq_len, self.head_dim)
        k = k[:, :, None, :, :].expand(*expanded_shape).reshape(*merged_shape)
        v = v[:, :, None, :, :].expand(*expanded_shape).reshape(*merged_shape)

        if self.FlashAttention:
            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=causal)
        else:
            scores = (q @ k.transpose(-1, -2)) * scale
            if causal:
                visible = torch.tril(
                    torch.ones(seq_len, seq_len, dtype=torch.bool, device=scores.device)
                ).view(1, 1, seq_len, seq_len)
                # Mask out the *future* positions (where `visible` is False).
                scores = scores.masked_fill(~visible, float('-inf'))
            attn = F.softmax(scores, dim=-1)
            y = attn @ v

        # Merge heads: (B, heads, T, head_dim) -> (B, T, hidden).
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        y = self.o_linear(y)
        return y

