# Self-Attention

In [38]:
import torch
import torch.nn as nn
import math
from torch import repeat_interleave


In [33]:
class self_attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Q, K and V each come from their own ``nn.Linear(input_dim, input_dim)``,
    so the output keeps the input's last dimension.
    """

    def __init__(self, input_dim) -> None:
        super().__init__()
        # In this minimal version the attention width is tied to the embedding width.
        self.embedding_dim = self.attention_dim = input_dim
        # Built in q, k, v order (same parameter-init order and state_dict keys).
        self.q, self.k, self.v = (
            nn.Linear(self.embedding_dim, self.attention_dim) for _ in range(3)
        )

    def forward(self, x):
        """Attend over the sequence axis.

        Args:
            x: tensor of shape (bs, seqlen, input_dim); torch.bmm requires 3-D input.

        Returns:
            Tensor of shape (bs, seqlen, input_dim).
        """
        query, key, value = self.q(x), self.k(x), self.v(x)
        # Scale by sqrt(d) to keep the softmax logits in a reasonable range.
        scale = math.sqrt(self.attention_dim)
        attn = torch.softmax(torch.bmm(query, key.transpose(1, 2)) / scale, dim=-1)
        return torch.bmm(attn, value)



## 代码测试

In [34]:
# Smoke test: self-attention must preserve the (bs, seqlen, input_dim) shape.
bs, seqlen, input_dim = 2, 10, 64
x = torch.rand(bs, seqlen, input_dim)
attn = self_attention(input_dim)
output = attn(x)
print("x.shape:", x.shape)
print("out.shape:", output.shape)
# Fail loudly under Restart & Run All instead of relying on eyeballing the prints.
assert output.shape == x.shape, "self_attention should preserve the input shape"


x.shape: torch.Size([2, 10, 64])
out.shape: torch.Size([2, 10, 64])


# Multi-Head Attention

In [36]:
class MHA(nn.Module):
    """Multi-head self-attention.

    The ``input_dim`` embedding is split across ``head`` heads of
    ``attn_dim = input_dim // head`` features each. Every head runs scaled
    dot-product attention independently, and an output projection mixes the
    concatenated heads back to ``input_dim``.

    Args:
        head: number of attention heads; must divide ``input_dim``.
        input_dim: embedding size of the input (and of the output).
    """

    def __init__(self, head, input_dim) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.head = head
        assert self.input_dim % self.head == 0, (
            f"input_dim ({input_dim}) must be divisible by head ({head})"
        )
        self.attn_dim = self.input_dim // head
        self.wq = nn.Linear(self.input_dim, self.attn_dim * self.head)
        self.wk = nn.Linear(self.input_dim, self.attn_dim * self.head)
        self.wv = nn.Linear(self.input_dim, self.attn_dim * self.head)

        # Output projection that integrates information across heads:
        # concatenated heads (head * attn_dim) -> input_dim. Since
        # head * attn_dim == input_dim, parameter shapes match the old
        # (swapped) declaration exactly.
        self.out = nn.Linear(self.attn_dim * self.head, self.input_dim)

    def forward(self, x):
        """Apply multi-head self-attention.

        Args:
            x: tensor of shape (bs, seq, input_dim).

        Returns:
            Tensor of shape (bs, seq, input_dim).
        """
        bs, seq, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        # (bs, seq, head * attn_dim) -> (bs, seq, head, attn_dim) -> (bs, head, seq, attn_dim)
        xq = xq.view(bs, seq, self.head, self.attn_dim).transpose(1, 2)
        xk = xk.view(bs, seq, self.head, self.attn_dim).transpose(1, 2)
        xv = xv.view(bs, seq, self.head, self.attn_dim).transpose(1, 2)
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.attn_dim)  # (bs, head, seq, seq)
        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, xv)  # (bs, head, seq, attn_dim)
        # Back to (bs, seq, head * attn_dim). .contiguous() is required because
        # .view() needs contiguous memory and .transpose() makes the tensor non-contiguous.
        out = out.transpose(1, 2).contiguous().view(bs, seq, -1)
        return self.out(out)


## 代码测试

In [37]:
# Smoke test: 8 heads over a 64-dim embedding (attn_dim = 8); shape must be preserved.
bs, seqlen, head, input_dim = 2, 10, 8, 64
x = torch.rand(bs, seqlen, input_dim)
mha = MHA(head, input_dim)
output = mha(x)
print("x.shape:", x.shape)
print("out.shape:", output.shape)
# Fail loudly under Restart & Run All instead of relying on eyeballing the prints.
assert output.shape == x.shape, "MHA should preserve the input shape"

x.shape: torch.Size([2, 10, 64])
out.shape: torch.Size([2, 10, 64])


# Grouped-Query Attention (GQA)

In [None]:

class GQA(nn.Module):
    """Grouped-query attention (GQA).

    ``head`` query heads share ``kv_heads`` key/value heads. The query heads
    are split into ``kv_heads`` groups of ``group = head // kv_heads``
    consecutive heads, and every query head in a group attends with the same
    K/V head. With ``kv_heads == head`` this has the same structure as MHA;
    with ``kv_heads == 1`` all query heads share a single K/V head (MQA).

    Args:
        head: number of query heads; must divide ``input_dim``.
        input_dim: embedding size of the input (and of the output).
        kv_heads: number of key/value heads; must divide ``head``.
    """

    def __init__(self, head, input_dim, kv_heads) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.head = head
        assert self.input_dim % self.head == 0, (
            f"input_dim ({input_dim}) must be divisible by head ({head})"
        )
        self.head_dim = self.input_dim // self.head
        self.kv_heads = kv_heads
        assert self.head % self.kv_heads == 0, (
            f"head ({head}) must be divisible by kv_heads ({kv_heads})"
        )
        self.group = self.head // self.kv_heads
        self.wq = nn.Linear(self.input_dim, self.head_dim * self.head)

        # K/V are projected to kv_heads heads (not `head`), so wk/wv are
        # `group` times smaller than wq; the `head` query heads are split into
        # kv_heads groups of `group` heads each.
        self.wk = nn.Linear(self.input_dim, self.head_dim * self.kv_heads)
        self.wv = nn.Linear(self.input_dim, self.head_dim * self.kv_heads)

        # Output projection: concatenated heads (head * head_dim) -> input_dim.
        # Since head * head_dim == input_dim, parameter shapes match the old
        # (swapped) declaration exactly.
        self.wo = nn.Linear(self.head_dim * self.head, self.input_dim)

    def forward(self, x):
        """Apply grouped-query self-attention.

        Args:
            x: tensor of shape (bs, seq, input_dim).

        Returns:
            Tensor of shape (bs, seq, input_dim).
        """
        bs, seq, _ = x.shape
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bs, seq, self.head, self.head_dim).transpose(1, 2)  # (bs, head, seq, head_dim)
        xk = xk.view(bs, seq, self.kv_heads, self.head_dim).transpose(1, 2)  # (bs, kv_heads, seq, head_dim)
        xv = xv.view(bs, seq, self.kv_heads, self.head_dim).transpose(1, 2)  # (bs, kv_heads, seq, head_dim)
        # Repeat each K/V head `group` times consecutively ([k0, k1] -> [k0, k0, k1, k1])
        # so query head i pairs with K/V head i // group: (bs, kv_heads, ...) -> (bs, head, ...).
        xk = xk.repeat_interleave(self.group, dim=1)
        xv = xv.repeat_interleave(self.group, dim=1)
        scores = torch.matmul(xq, xk.transpose(-1, -2)) / math.sqrt(self.head_dim)  # (bs, head, seq, seq)
        weights = torch.softmax(scores, dim=-1)
        out = torch.matmul(weights, xv)  # (bs, head, seq, head_dim)
        # .contiguous() is needed because .view() requires contiguous memory after .transpose().
        out = out.transpose(1, 2).contiguous().view(bs, seq, -1)
        return self.wo(out)



## 代码测试

In [None]:
# Smoke test: 8 query heads sharing 4 K/V heads (group = 2); shape must be preserved.
bs, seq, input_dim, kv_heads, head = 2, 10, 64, 4, 8
x = torch.rand(bs, seq, input_dim)
gqa = GQA(head, input_dim, kv_heads)
out = gqa(x)
print("x.shape:", x.shape)
# Print this cell's result `out`, not `output` (a stale tensor from the MHA test cell).
print("out.shape:", out.shape)
assert out.shape == x.shape, "GQA should preserve the input shape"

### 值得注意的点

        xk = xk.repeat_interleave(self.group, dim=1)
        xv = xv.repeat_interleave(self.group, dim=1)

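这里使用了`.repeat_interleave()`沿第1维（头维度）把每个KV头连续复制`self.group`次（例如`[k0, k1] -> [k0, k0, k1, k1]`），从而使第`i`个Query头与第`i // self.group`个KV头配对。注意它与`.repeat()`不同：后者会整体平铺为`[k0, k1, k0, k1]`，头的对应关系就变了。这是实现复制的第一种方式。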

另一种方式是：

        # xk: (bs, kv_heads, seq, head_dim)
        xk = xk.unsqueeze(2)  # (bs, kv_heads, 1, seq, head_dim)
        xk = xk.expand(-1, -1, self.group, -1, -1)  # (bs, kv_heads, group, seq, head_dim)
        xk = xk.reshape(bs, -1, seq, self.head_dim)  # (bs, head, seq, head_dim)

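这里分别给出unsqueeze、expand、reshape三个函数的作用：

1. **unsqueeze()**: 在指定位置插入一个大小为1的新维度（例如`unsqueeze(2)`之后，新维度位于下标2），一般用于增加维度；
2. **expand()**: 将大小为1的维度广播到更大的尺寸（传入-1表示该维度保持不变），返回的是原数据的视图，并不复制数据；
3. **reshape()**: 重组tensor形状：如果新形状可以在不改变内存布局的情况下得到（例如tensor是连续的），则返回视图、不复制数据；否则会复制数据。因此在上面的写法中，`expand`得到的tensor是不连续的，最后一步`reshape`实际上仍会发生一次拷贝，最终和`repeat_interleave`一样会分配新的内存。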


In [71]:

# A 1-D int64 tensor holding the values 1..8.
x = torch.arange(1, 9)
x



tensor([1, 2, 3, 4, 5, 6, 7, 8])

In [72]:
# Lay the 8 elements out as 2 rows (the column count, 4, is inferred).
x = x.reshape(2, -1)
x

tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

In [73]:
# Insert a size-1 axis at dim 1 ((2, 4) -> (2, 1, 4)), then broadcast it to 2
# ((2, 1, 4) -> (2, 2, 4)) without copying: every row appears twice.
x = x[:, None, :].expand(-1, 2, -1)
x

tensor([[[1, 2, 3, 4],
         [1, 2, 3, 4]],

        [[5, 6, 7, 8],
         [5, 6, 7, 8]]])

In [None]:
# Equivalent to x.reshape(16): the expanded tensor from the previous cell is not
# contiguous, so a copy is made first and the 16 elements are then viewed as 1-D.
x = x.contiguous().view(16)
x