In [24]:
import torch
import math
import torch.nn as nn


class MultiHeadAttention(nn.Module):
    def __init__(self, head_nums:int, dim:int, drop_radio:float=0.1):
        super().__init__()
        self.head_nums = head_nums
        self.dim = dim
        self.head_dim = dim // head_nums

        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

        self.out_proj = nn.Linear(dim, dim)
        self.att_drop = nn.Dropout(drop_radio)

    def forward(self, x, attention_mask:torch.Tensor=None):
        # x shape (batch_size, seq_len, dim)
        bs, seq_len, dim = x.shape
        # QKV shape (batch_size, seq_len, dim)
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # (batch_size, head_nums, seq_len, head_dim)
        q_state = Q.reshape(bs, seq_len, self.head_nums, self.head_dim).transpose(1, 2)
        k_state = K.reshape(bs, seq_len, self.head_nums, self.head_dim).transpose(1, 2)
        v_state = V.reshape(bs, seq_len, self.head_nums, self.head_dim).transpose(1, 2)


        att_val = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        # attention_mask (batch, head_nums, seq, seq)
        if attention_mask is not None:
            att_val = att_val.masked_fill(
                attention_mask == 0,
                float('-inf')
            )
        
        # (batch_size, head_nums, seq, seq)
        att_weight = torch.softmax(att_val, dim=-1)
        print(att_weight)

        # dropout
        att_weight = self.att_drop(att_weight)

        # output_state (batch_size, nums_head, seq_len, head_dim)
        output_state = att_weight @ v_state

        # output (batch_size, seq_len, dim)
        output = output_state.transpose(1, 2).reshape(bs, seq_len, dim)

        output = self.out_proj(output)
        return output     

if __name__ == '__main__':
    x = torch.randn(3, 3, 128)
    model = MultiHeadAttention(8, 128)
    attention_mask = torch.tensor(
        [
            [1, 0, 0],
            [1, 1, 0],
            [1, 1, 0]
        ]
    ).unsqueeze(1).unsqueeze(2).expand(3, 8, 3, 3)

    # 用repeat的方式扩展attention_mask
    attention_mask = torch.tensor(
        [
            [1, 0, 0],
            [1, 1, 0],
            [1, 1, 0]
        ]
    ).unsqueeze(1).unsqueeze(2).repeat(1, 8, 3, 1)
    
    print(model(x, attention_mask).shape)
    

tensor([[[[1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000]],

         [[1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000]]],


        [[[0.2903, 0.7097, 0.0000],
          [0.5406, 0.4594, 0.0000],
          [0.5025, 0.4975, 0.0000]],

        

In [41]:
import torch
import torch.nn as nn
import math

# nums_key_value_head设置为1就是MQA，设置为head_nums就是MHA
class GroupQueryAttention(nn.Module):
    def __init__(self, dim:int, head_nums:int, nums_key_value_head:int, drop_radio:float=0.1):
        # nums_key_valu_head 相当于是KV的head数
        super().__init__()

        assert dim % head_nums == 0 # 满足整除
        assert head_nums % nums_key_value_head == 0 # N个query head为一组
        self.dim = dim
        self.head_dim = dim // head_nums
        self.head_nums = head_nums
        self.nums_key_value_head = nums_key_value_head
        
        self.q_proj = nn.Linear(dim, self.head_nums * self.head_dim)
        self.k_proj = nn.Linear(dim, nums_key_value_head * self.head_dim)
        self.v_proj = nn.Linear(dim, nums_key_value_head * self.head_dim)

        self.att_weight_drop = nn.Dropout(drop_radio)
        self.o_proj = nn.Linear(dim, dim)

    def forward(self, x, attention_mask:torch.Tensor=None):
        # x shape (batch_size, seq_len, dim)
        batch_size, seq_len, dim = x.shape

        # QKV
        # Q (batch_size, seq_len, head_nums*head_dim)
        Q = self.q_proj(x)
        # KV (batch_size, seq_len, nums_key_value_head*head_dim)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # q_state (batch_size, head_nums, seq_len, head_dim)
        q_state = Q.reshape(batch_size, seq_len, self.head_nums, self.head_dim).transpose(1, 2)
        # k_state, v_state (batch_size, nums_key_value_head, seq_len, head_dim)
        k_state = K.reshape(batch_size, seq_len, self.nums_key_value_head, self.head_dim).transpose(1, 2)
        v_state = V.reshape(batch_size, seq_len, self.nums_key_value_head, self.head_dim).transpose(1, 2)

        # k,v repeat 广播操作 (batch_size, head_nums, seq_len, head_dim)
        k_state = k_state.repeat_interleave(
            self.head_nums // self.nums_key_value_head, dim=1)
        v_state = v_state.repeat_interleave(
            self.head_nums // self.nums_key_value_head, dim=1)

        # attention_val (batch_size, head_nums, seq_len, seq_len)
        attention_val = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        # attention_mask
        if attention_mask is not None:
            attention_val = attention_val.masked_fill(
                attention_mask == 0,
                float('-inf')
            )
        attention_weight = torch.softmax(attention_val, dim=-1) 
        print(attention_weight)

        # dropout
        attention_weight = self.att_weight_drop(attention_weight)

        # output_state (batch_size, head_nums, seq_len, head_dim)
        output_state = attention_weight @ v_state
        
        # (batch_size, seq_len, dim)
        output = output_state.transpose(1, 2).reshape(batch_size, seq_len, -1)

        output = self.o_proj(output)

        return output
    
if __name__ == '__main__':
    x = torch.rand(3, 2, 128)

    # (batch_size, head_nums, seq_len, seq_len) -> 3, 8, 2, 2
    attention_mask = torch.tensor(
        [
            [1, 1],
            [1, 0],
            [1, 0]
        ]
    ).unsqueeze(1).unsqueeze(2).repeat(1, 8, 2, 1)

    attention_mask = torch.tensor(
        [
            [1, 1],
            [1, 0],
            [1, 0]
        ]
    ).unsqueeze(1).unsqueeze(2).expand(3, 8, 2, 2)
    model = GroupQueryAttention(dim=128,head_nums=8, nums_key_value_head=2)
    print(model(x, attention_mask=attention_mask).shape)

tensor([[[[0.4864, 0.5136],
          [0.4827, 0.5173]],

         [[0.4970, 0.5030],
          [0.5078, 0.4922]],

         [[0.5348, 0.4652],
          [0.5167, 0.4833]],

         [[0.5250, 0.4750],
          [0.5303, 0.4697]],

         [[0.5142, 0.4858],
          [0.4925, 0.5075]],

         [[0.4912, 0.5088],
          [0.4937, 0.5063]],

         [[0.4628, 0.5372],
          [0.4619, 0.5381]],

         [[0.5207, 0.4793],
          [0.5397, 0.4603]]],


        [[[1.0000, 0.0000],
          [1.0000, 0.0000]],

         [[1.0000, 0.0000],
          [1.0000, 0.0000]],

         [[1.0000, 0.0000],
          [1.0000, 0.0000]],

         [[1.0000, 0.0000],
          [1.0000, 0.0000]],

         [[1.0000, 0.0000],
          [1.0000, 0.0000]],

         [[1.0000, 0.0000],
          [1.0000, 0.0000]],

         [[1.0000, 0.0000],
          [1.0000, 0.0000]],

         [[1.0000, 0.0000],
          [1.0000, 0.0000]]],


        [[[1.0000, 0.0000],
          [1.0000, 0.0000]],

         [