In [109]:
import torch 
from torch import nn 
import torch.nn.functional as F
import numpy as np 
from typing import Tuple


# 多头注意力
## Multihead Attention (MHA)


In [27]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention (MHA) with an optional causal mask.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights and to
            the output projection. Defaults to 0.0 (no dropout).
    """

    def __init__(self, hidden_size, num_heads, dropout=0.0):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # The attention block needs the Wq / Wk / Wv projections, the output
        # projection Wo, and the dropout layers.
        self.q_linear = nn.Linear(hidden_size, hidden_size)   # Wq
        self.k_linear = nn.Linear(hidden_size, hidden_size)   # Wk
        self.v_linear = nn.Linear(hidden_size, hidden_size)   # Wv
        self.o_linear = nn.Linear(hidden_size, hidden_size)   # Wo
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)

        # True when this torch build ships the fused scaled-dot-product
        # attention kernel; otherwise forward() uses the manual implementation.
        self.FlashAttention = hasattr(F, 'scaled_dot_product_attention')

    def forward(self, hidden_state: torch.Tensor, attention_mask: bool = True) -> torch.Tensor:
        """Apply self-attention.

        Args:
            hidden_state: Tensor of shape (batch_size, seq_len, hidden_size).
            attention_mask: When truthy, apply a causal (lower-triangular)
                mask so each position only attends to itself and earlier
                positions.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        batch_size, seq_len, hidden_size = hidden_state.size()
        causal = bool(attention_mask)

        # 1. (batch, seq, hidden) -> (batch, seq, heads, head_dim)
        #    -> (batch, heads, seq, head_dim)
        def split_heads(t: torch.Tensor) -> torch.Tensor:
            return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_linear(hidden_state))
        k = split_heads(self.k_linear(hidden_state))
        v = split_heads(self.v_linear(hidden_state))

        if self.FlashAttention:
            # Causality is requested through is_causal; passing an explicit
            # attn_mask at the same time is rejected by PyTorch.
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=causal,
            )
        else:
            attn = (q @ k.transpose(-1, -2)) * self.head_dim ** -0.5
            if causal:
                visible = torch.ones(
                    seq_len, seq_len, dtype=torch.bool, device=hidden_state.device
                ).tril()
                # Future positions (above the diagonal) get -inf so softmax
                # assigns them zero weight.
                attn = attn.masked_fill(~visible, float('-inf'))
            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v

        # 2. Inverse of step 1: (batch, heads, seq, head_dim)
        #    -> (batch, seq, hidden). contiguous() is required before view().
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        # Output projection followed by (optional) residual dropout.
        y = self.resid_dropout(self.o_linear(y))
        return y



# 多查询注意力
## MultiQuery Attention (MQA)

In [25]:
# 这里我忽略掉dropout
# 具体处理参考上面MHA实现
class MultiQueryAttention(nn.Module):
    """Multi-query self-attention (MQA): all query heads share one K/V head.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Queries keep the full width; keys and values are projected to a
        # single head that every query head attends over.
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.head_dim)
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        self.FlashAttention = hasattr(F, 'scaled_dot_product_attention')

    def forward(self, hidden_state: torch.Tensor, attention_mask: bool = True):
        """Apply multi-query self-attention.

        Args:
            hidden_state: Tensor of shape (batch_size, seq_len, hidden_size).
            attention_mask: When truthy, apply a causal mask.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        batch_size, seq_len, hidden_size = hidden_state.size()
        scale = self.head_dim ** -0.5
        causal = bool(attention_mask)

        # Split q into heads; give the shared k/v a singleton head axis so the
        # three tensors line up as (batch, heads, seq, head_dim).
        q = self.q_linear(hidden_state).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(hidden_state).view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        v = self.v_linear(hidden_state).view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)

        if self.FlashAttention:
            # expand() returns a stride-0 view (no copy), giving the fused
            # kernel matching head counts without relying on implicit
            # broadcasting.
            k_all = k.expand(-1, self.num_heads, -1, -1)
            v_all = v.expand(-1, self.num_heads, -1, -1)
            y = F.scaled_dot_product_attention(q, k_all, v_all, dropout_p=0.0, is_causal=causal)
        else:
            # The @ operator broadcasts the singleton head axis of k and v.
            attn = (q @ k.transpose(-1, -2)) * scale
            if causal:
                visible = torch.ones(
                    seq_len, seq_len, dtype=torch.bool, device=hidden_state.device
                ).tril()
                attn = attn.masked_fill(~visible, float('-inf'))
            attn = F.softmax(attn, dim=-1)
            y = attn @ v
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        y = self.o_linear(y)
        return y
            
               
        


# 组查询注意力
## Group Query Attention (GQA)

In [None]:
class GroupQueryAttention(nn.Module):
    """Grouped-query self-attention (GQA).

    Query heads are partitioned into ``num_groups`` groups; every head in a
    group shares the same key/value head.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads; must be divisible by ``num_groups``.
        num_groups: Number of key/value heads.
    """

    def __init__(self, hidden_size, num_heads, num_groups):
        super().__init__()
        assert hidden_size % num_heads == 0
        # Each group must own the same whole number of query heads.
        assert num_heads % num_groups == 0
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = hidden_size // num_heads

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.num_groups * self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.num_groups * self.head_dim)
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        self.FlashAttention = hasattr(F, 'scaled_dot_product_attention')

    def _repeat_kv(self, t: torch.Tensor) -> torch.Tensor:
        """Repeat each K/V group so it lines up with its query heads.

        (batch, num_groups, seq, head_dim) -> (batch, num_heads, seq, head_dim).
        Query head ``h`` ends up paired with group ``h // heads_per_group``.
        """
        batch_size, num_groups, seq_len, head_dim = t.shape
        heads_per_group = self.num_heads // num_groups
        # t[:, :, None] inserts a new axis after the group axis; expand()
        # "copies" the data heads_per_group times as a view that shares
        # memory. The final reshape folds (group, repeat) into one head axis.
        expanded = t[:, :, None, :, :].expand(batch_size, num_groups, heads_per_group, seq_len, head_dim)
        return expanded.reshape(batch_size, self.num_heads, seq_len, head_dim)

    def forward(self, hidden_state: torch.Tensor, attention_mask: bool = True):
        """Apply grouped-query self-attention.

        Args:
            hidden_state: Tensor of shape (batch_size, seq_len, hidden_size).
            attention_mask: When truthy, apply a causal mask.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        batch_size, seq_len, hidden_size = hidden_state.size()
        scale = self.head_dim ** -0.5
        causal = bool(attention_mask)

        q = self.q_linear(hidden_state)
        k = self.k_linear(hidden_state)
        v = self.v_linear(hidden_state)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        k = self._repeat_kv(k)
        v = self._repeat_kv(v)

        if self.FlashAttention:
            y = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=causal)
        else:
            attn = (q @ k.transpose(-1, -2)) * scale
            if causal:
                visible = torch.ones(
                    seq_len, seq_len, dtype=torch.bool, device=hidden_state.device
                ).tril()
                attn = attn.masked_fill(~visible, float('-inf'))
            attn = F.softmax(attn, dim=-1)
            y = attn @ v

        # Merge the heads back: (batch, heads, seq, head_dim) -> (batch, seq, hidden).
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_size)
        y = self.o_linear(y)
        return y



# RoPE

In [None]:
def compute_freqs_cis(dim: int, seqlen:int, base:float=10000.0) -> torch.Tensor:
    """Precompute the complex RoPE rotation factors.

    For each position ``m`` in ``[0, seqlen)`` and each frequency
    ``theta_i = base ** (-i / dim)`` with ``i`` in ``range(0, dim, 2)``, the
    result holds the unit-magnitude complex number ``exp(1j * m * theta_i)``.

    Returns:
        Complex tensor with one row per position and one column per
        frequency (``dim // 2`` columns when ``dim`` is even).
    """
    exponents = torch.arange(0, dim, 2) / dim
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seqlen)
    # Outer product of positions and frequencies gives the rotation angles.
    angles = torch.outer(positions, inv_freq)
    # Polar form: magnitude 1, phase = angle.
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)

def reshape_for_broadcast(freqs_cis: torch.Tensor, x:torch.Tensor) -> torch.Tensor:
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result
    keeps those two sizes on axis 1 and on the last axis and uses size 1 on
    every other axis, e.g. [seqlen, dim//2] -> [1, seqlen, 1, dim//2] for a
    4-D ``x``.

    Raises:
        AssertionError: If the shape of ``freqs_cis`` does not match ``x``.
    """
    last_axis = x.ndim - 1
    expected = (x.shape[1], x.shape[-1])
    assert freqs_cis.shape == expected, "matrixs dont match."
    target_shape = [1] * x.ndim
    target_shape[1] = x.shape[1]
    target_shape[last_axis] = x.shape[-1]
    return freqs_cis.view(*target_shape)

def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with rotary position embeddings (RoPE).

    Args:
        xq: Queries of shape (batch_size, seq_len, num_heads, head_dim).
        xk: Keys of shape (batch_size, seq_len, num_kv_heads, head_dim).
        freqs_cis: Complex rotation factors of shape (seq_len, head_dim // 2).

    Returns:
        The rotated ``(xq, xk)`` with the same shapes and dtypes as the inputs.
    """
    # Pair up adjacent features: [..., head_dim] -> [..., head_dim // 2, 2].
    # The leading shape must be unpacked; passing the torch.Size object as a
    # single argument raises a TypeError. The rotation is computed in float32
    # and cast back to the input dtype at the end.
    xq_pairs = xq.float().reshape(*xq.shape[:-1], -1, 2)
    xk_pairs = xk.float().reshape(*xk.shape[:-1], -1, 2)
    # Treat each (even, odd) pair as one complex number.
    xq_complex = torch.view_as_complex(xq_pairs)
    xk_complex = torch.view_as_complex(xk_pairs)
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_complex)
    # Complex multiplication applies the rotation; flatten(3) merges the
    # (head_dim // 2, 2) axes back into head_dim.
    xq_out = torch.view_as_real(xq_complex * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_complex * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

# 下面是伪代码，只说明RoPE如何加入到注意力层中
class Lite_attention(nn.Module):
    """Sketch showing where RoPE plugs into an attention layer.

    Only the projection and rotation steps are implemented; the attention
    computation itself is left out, so ``forward`` returns nothing and exposes
    its intermediate results as attributes.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads. The per-head width must be even
            because RoPE rotates features in pairs.
        seq_len: Maximum sequence length for which rotation factors are
            precomputed.
    """

    def __init__(self, hidden_size, num_heads, seq_len):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = hidden_size//num_heads
        assert self.head_dim % 2 == 0
        self.q_linear = nn.Linear(hidden_size, hidden_size)   # Wq
        self.k_linear = nn.Linear(hidden_size, hidden_size)   # Wk
        self.v_linear = nn.Linear(hidden_size, hidden_size)   # Wv
        self.o_linear = nn.Linear(hidden_size, hidden_size)   # Wo
        self.freqs_cis = compute_freqs_cis(self.head_dim, seq_len)

    def forward(self, x):
        """Project ``x`` to q/k/v and rotate q and k.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_size), with
                ``seq_len`` no larger than the length given to ``__init__``.

        The per-head projections are stored as ``self.xq``, ``self.xk`` and
        ``self.xv``; the rotated queries and keys as ``self.xq_`` and
        ``self.xk_``.
        """
        batch_size, seq_len, _ = x.shape
        self.xq = self.q_linear(x)
        self.xk = self.k_linear(x)
        self.xv = self.v_linear(x)
        self.xq = self.xq.view(batch_size, seq_len, self.num_heads, -1)
        self.xk = self.xk.view(batch_size, seq_len, self.num_heads, -1)
        self.xv = self.xv.view(batch_size, seq_len, self.num_heads, -1)
        # Use only the factors for the positions actually present; the table
        # was built for the maximum length and must match seq_len exactly.
        freqs_cis = self.freqs_cis[:seq_len].to(x.device)
        self.xq_, self.xk_ = apply_rope(self.xq, self.xk, freqs_cis)
        ## The attention computation would follow here....
        ## Attention calculation......
        ##




# Attention with KV Cache


In [125]:
# 这里继续忽略掉dropout
class MultiHeadAttention_with_KVCache(nn.Module):
    """Causal multi-head self-attention with a key/value cache.

    Every call appends the keys and values of the new tokens to the cache and
    attends over everything cached so far, so tokens can be fed incrementally.
    Call ``reset_cache()`` before starting a new sequence.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size%num_heads == 0
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, hidden_size)
        self.v_linear = nn.Linear(hidden_size, hidden_size)
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        self.k_cache = self.v_cache = None
        self.total_len = 0

    def reset_cache(self):
        """Drop the cached keys/values so the next call starts a new sequence."""
        self.k_cache = self.v_cache = None
        self.total_len = 0

    def forward(self, x):
        """Attend from the new tokens over all cached plus new tokens.

        Args:
            x: New tokens, shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        batch_size, seq_len, _ = x.shape
        q = self.q_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # KV cache trades memory for speed. Each of K and V holds
        # (t + o) * hidden_size values per batch element (4 bytes each in
        # float32), where t is the prompt length and o the number of tokens
        # generated so far.
        if self.k_cache is not None:
            k = torch.cat([self.k_cache, k], dim=2)
            v = torch.cat([self.v_cache, v], dim=2)
        self.k_cache, self.v_cache = k, v
        self.total_len = k.size(2)

        # Query i sits at absolute position past_len + i and may see keys
        # 0..past_len + i, hence a (seq_len, total_len) lower-triangular mask
        # shifted right by past_len.
        past_len = self.total_len - seq_len
        visible = torch.ones(
            seq_len, self.total_len, dtype=torch.bool, device=x.device
        ).tril(diagonal=past_len)

        atten = (q @ k.transpose(-1, -2)) * self.head_dim ** -0.5
        atten = atten.masked_fill(~visible, float('-inf'))
        y = F.softmax(atten, dim=-1) @ v
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        return self.o_linear(y)
    

class MultiQueryAttention_with_KVCache(nn.Module):
    """Causal multi-query self-attention with a key/value cache.

    All query heads share one key/value head. In training mode the cache is
    bypassed; in eval mode each call appends the new keys/values to the cache
    and attends over everything cached so far. Call ``reset_cache()`` before
    starting a new sequence.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of query heads.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.head_dim  = hidden_size // num_heads

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.head_dim)
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        self.k_cache = self.v_cache = None
        self.total_len = 0

    def reset_cache(self):
        """Drop the cached keys/values so the next call starts a new sequence."""
        self.k_cache = self.v_cache = None
        self.total_len = 0

    # Handles both training (no cache) and inference (with cache).
    def forward(self, x):
        """Attend from the new tokens over all visible keys.

        Args:
            x: New tokens, shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        batch_size, seq_len, _ = x.shape
        q = self.q_linear(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # The shared k/v carry a singleton head axis: (batch, 1, seq, head_dim).
        k = self.k_linear(x).view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)
        v = self.v_linear(x).view(batch_size, seq_len, 1, self.head_dim).transpose(1, 2)

        # Each of K and V holds (t + o) * head_dim values per batch element
        # (4 bytes each in float32), t being the prompt length and o the
        # number of generated tokens: num_heads times smaller than MHA.
        if not self.training:
            if self.k_cache is not None:
                # Concatenate along the sequence axis.
                k = torch.cat([self.k_cache, k], dim=2)
                v = torch.cat([self.v_cache, v], dim=2)
            self.k_cache, self.v_cache = k, v
            self.total_len = k.size(2)

        kv_len = k.size(2)
        past_len = kv_len - seq_len
        # Query i sits at absolute position past_len + i and may see keys
        # 0..past_len + i. For single-token decoding every entry is visible,
        # so the mask is a no-op there.
        visible = torch.ones(
            seq_len, kv_len, dtype=torch.bool, device=x.device
        ).tril(diagonal=past_len)

        # Matmul broadcasts the singleton head axis of k and v over all heads.
        atten = (q @ k.transpose(-1, -2)) * self.head_dim ** -0.5
        atten = atten.masked_fill(~visible, float('-inf'))
        atten = F.softmax(atten, dim=-1)
        y = atten @ v
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        return self.o_linear(y)

class GroupQueryAttention_with_KVCache(nn.Module):
    """Causal grouped-query self-attention with a key/value cache.

    Query heads are partitioned into ``num_groups`` groups that each share one
    key/value head. The cache stores the un-expanded ``num_groups`` heads, so
    it is ``num_heads / num_groups`` times smaller than an MHA cache. Call
    ``reset_cache()`` before starting a new sequence.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_groups: Number of key/value heads.
        num_heads: Number of query heads; must be divisible by ``num_groups``.
    """

    def __init__(self, hidden_size, num_groups, num_heads):
        super().__init__()
        # Each group must own the same whole number of query heads.
        assert hidden_size%num_heads==0 and num_heads%num_groups==0
        self.num_heads = num_heads
        self.num_groups = num_groups
        self.head_dim = hidden_size//self.num_heads

        self.q_linear = nn.Linear(hidden_size, hidden_size)
        self.k_linear = nn.Linear(hidden_size, self.num_groups*self.head_dim)
        self.v_linear = nn.Linear(hidden_size, self.num_groups*self.head_dim)
        self.o_linear = nn.Linear(hidden_size, hidden_size)
        self.total_len = 0
        self.k_cache = self.v_cache = None

    def reset_cache(self):
        """Drop the cached keys/values so the next call starts a new sequence."""
        self.k_cache = self.v_cache = None
        self.total_len = 0

    def _repeat_kv(self, t):
        """(batch, num_groups, seq, head_dim) -> (batch, num_heads, seq, head_dim).

        Query head ``h`` ends up paired with group ``h // heads_per_group``.
        """
        batch_size, num_groups, kv_len, head_dim = t.shape
        heads_per_group = self.num_heads // num_groups
        # expand() is a memory-sharing view; the reshape folds (group, repeat)
        # into a single head axis.
        expanded = t[:, :, None, :, :].expand(batch_size, num_groups, heads_per_group, kv_len, head_dim)
        return expanded.reshape(batch_size, self.num_heads, kv_len, head_dim)

    def forward(self, x):
        """Attend from the new tokens over all cached plus new tokens.

        Args:
            x: New tokens, shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        batch_size, seq_len, _ = x.shape
        q, k, v = self.q_linear(x), self.k_linear(x), self.v_linear(x)
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_groups, self.head_dim).transpose(1, 2)

        # Cache the compact (num_groups) representation and only expand to
        # num_heads for the computation below.
        if self.k_cache is not None:
            k = torch.cat([self.k_cache, k], dim=2)
            v = torch.cat([self.v_cache, v], dim=2)
        self.k_cache, self.v_cache = k, v
        self.total_len = k.size(2)

        k = self._repeat_kv(k)
        v = self._repeat_kv(v)

        # Query i sits at absolute position past_len + i and may see keys
        # 0..past_len + i.
        past_len = self.total_len - seq_len
        visible = torch.ones(
            seq_len, self.total_len, dtype=torch.bool, device=x.device
        ).tril(diagonal=past_len)

        atten = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        atten = atten.masked_fill(~visible, float('-inf'))
        y = F.softmax(atten, dim=-1) @ v
        y = y.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)

        return self.o_linear(y)
        

