# 练习3.1

**Comparing SelfAttention_v1 and SelfAttention_v2**  
Your task is to correctly assign the weights from an instance of
SelfAttention_v2 to an instance of SelfAttention_v1. To do this, you need
to understand the relationship between the weights in both versions. (Hint:
nn.Linear stores the weight matrix in a transposed form.) After the
assignment, you should observe that both instances produce the same outputs. 

In [3]:
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    """Single-head scaled dot-product self-attention with raw weight matrices.

    The query, key and value projections are stored as ``nn.Parameter``
    matrices of shape ``(d_in, d_out)`` and applied as ``x @ W``.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of each query/key/value vector, and of each context vector.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        # Learnable projection matrices. They are created in query, key, value
        # order so that seeded runs keep producing the same initial values.
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and ``d_out`` as
            its last dimension.
        """
        # Project the inputs into key, query and value space.
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        # Swap only the last two dimensions: for 2-D input this is identical
        # to ``keys.T``, but it also works when leading batch dimensions are
        # present (``.T`` would reverse every dimension).
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each row with softmax.
        attn_weights = torch.softmax(
            attn_scores / (keys.shape[-1] ** 0.5), dim=-1
        )

        # Each context vector is the attention-weighted sum of the values.
        context_vec = attn_weights @ values

        return context_vec
        

In [2]:
class SelfAttention_v2(nn.Module):
    """Single-head scaled dot-product self-attention built on ``nn.Linear``.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of each query/key/value vector, and of each context vector.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Learnable projections. They are created in query, key, value order
        # so that seeded runs keep producing the same initial weights.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        """Compute one context vector per input token.

        Args:
            x: Tensor of shape ``(num_tokens, d_in)``, optionally with leading
                batch dimensions, i.e. ``(..., num_tokens, d_in)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and ``d_out`` as
            its last dimension.
        """
        # Project the inputs into key, query and value space.
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Swap only the last two dimensions: for 2-D input this is identical
        # to ``keys.T``, but it also works when leading batch dimensions are
        # present (``.T`` would reverse every dimension).
        attn_scores = queries @ keys.transpose(-2, -1)
        # Scale by sqrt(d_out) before normalising each row with softmax.
        attn_weights = torch.softmax(attn_scores / (keys.shape[-1] ** 0.5), dim=-1)

        # Each context vector is the attention-weighted sum of the values.
        context_vec = attn_weights @ values

        return context_vec

In [18]:
# Test inputs: one 3-dimensional embedding per token
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)
d_in = inputs.shape[1]
d_out = 2

torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
# nn.Linear keeps its weight as (d_out, d_in) ...
print(sa_v2.W_query.weight.shape)
print(sa_v2.W_key.weight.shape)
print(sa_v2.W_value.weight.shape)

sa_v1 = SelfAttention_v1(d_in, d_out)
# ... whereas SelfAttention_v1 keeps its matrices as (d_in, d_out).
print(sa_v1.W_query.shape)
print(sa_v1.W_key.shape)
print(sa_v1.W_value.shape)

# Copy the transposed nn.Linear weights into sa_v1. `.T` returns a view that
# shares storage with the source weight, so detach and clone it first;
# wrapping the view directly would make both modules alias the same memory,
# and an in-place update to one would silently change the other.
sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T.detach().clone())
sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T.detach().clone())
sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T.detach().clone())

# With matching weights both implementations should print the same tensor.
print("*" * 20, "sa_v2", "*" * 20)
print(sa_v2(inputs))
print("*" * 20, "sa_v1", "*" * 20)
print(sa_v1(inputs))


torch.Size([2, 3])
torch.Size([2, 3])
torch.Size([2, 3])
torch.Size([3, 2])
torch.Size([3, 2])
torch.Size([3, 2])
******************** sa_v2 ********************
tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)
******************** sa_v1 ********************
tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)


# 练习3.2

**Returning 2-dimensional embedding vectors**  
Change the input arguments for the MultiHeadAttentionWrapper(...,
num_heads=2) call such that the output context vectors are 2-dimensional
instead of 4-dimensional while keeping the setting num_heads=2. Hint: You
don't have to modify the class implementation; you just have to change one of
the other input arguments.

In [19]:
class CausalAttention(nn.Module):
    """Single-head causal self-attention over batched sequences, with dropout.

    Args:
        d_in: Size of each input embedding.
        d_out: Size of each query/key/value vector, and of each context vector.
        context_length: Side length of the square causal mask that is built
            once at construction time.
        dropout: Drop probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out

        # Projections are created in query, key, value order so that seeded
        # runs keep producing the same initial weights.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Dropout applied to the attention weights in forward().
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        # Stored as a buffer rather than as a learnable parameter.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Return causally masked context vectors.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, d_out)``.
        """
        # Unpacking enforces a 3-D input; only the sequence length is needed.
        _, seq_len, _ = x.shape

        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Pairwise query-key scores for every sequence in the batch.
        scores = queries @ keys.transpose(1, 2)

        # Crop the mask to the actual sequence length and set every future
        # position to -inf so softmax assigns it zero weight.
        causal = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(causal, -torch.inf)

        # Scale by sqrt(d_out), normalise each row, then apply dropout.
        scale = keys.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        # Attention-weighted sum of the value vectors.
        return weights @ values

In [20]:
class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``CausalAttention`` heads.

    Each head produces ``d_out`` features and the heads' outputs are
    concatenated along the last dimension.

    Args:
        d_in: Size of each input embedding.
        d_out: Output size of each individual head.
        context_length: Forwarded to every ``CausalAttention`` head.
        dropout: Forwarded to every ``CausalAttention`` head.
        num_heads: Number of heads to create.
        qkv_bias: Forwarded to every ``CausalAttention`` head.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Build the heads one at a time; each is registered in the ModuleList.
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results feature-wise."""
        per_head_outputs = [head(x) for head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

In [32]:
# Six token embeddings, three features each
inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # Your     (x^1)
    [0.55, 0.87, 0.66],  # journey  (x^2)
    [0.57, 0.85, 0.64],  # starts   (x^3)
    [0.22, 0.58, 0.33],  # with     (x^4)
    [0.77, 0.25, 0.10],  # one      (x^5)
    [0.05, 0.80, 0.55],  # step     (x^6)
])

# Stack the same sequence twice to form a batch of two samples
batch = torch.stack([inputs, inputs])
print(batch.shape)

# Take the sequence length and embedding size from the batch itself
context_length, d_in = batch.shape[1:]
d_out = 1
ca = CausalAttention(d_in, d_out, context_length, dropout=0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

torch.Size([2, 6, 3])
tensor([[[-0.4798],
         [-0.4475],
         [-0.4361],
         [-0.3765],
         [-0.3470],
         [-0.3309]],

        [[-0.4798],
         [-0.4475],
         [-0.4361],
         [-0.3765],
         [-0.3470],
         [-0.3309]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 1])


In [34]:
# Exercise 3.2: the wrapper concatenates the heads' outputs, so with
# num_heads=2 each head must emit d_out=1 features for a 2-dim result.
context_length = batch.shape[1]
d_in = 3
d_out = 1
mha = MultiHeadAttentionWrapper(
    d_in=d_in,
    d_out=d_out,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.1786,  0.5817],
         [-0.0753,  0.5727],
         [-0.0342,  0.5687],
         [-0.0138,  0.4944],
         [ 0.0254,  0.4620],
         [ 0.0106,  0.4427]],

        [[-0.1786,  0.5817],
         [-0.0753,  0.5727],
         [-0.0342,  0.5687],
         [-0.0138,  0.4944],
         [ 0.0254,  0.4620],
         [ 0.0106,  0.4427]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


# 练习3.3

**Initializing GPT-2 size attention modules**  
Using the MultiHeadAttention class, initialize a multi-head attention
module that has the same number of attention heads as the smallest GPT-2
model (12 attention heads). Also ensure that you use the respective input and
output embedding sizes similar to GPT-2 (768 dimensions). Note that the
smallest GPT-2 model supports a context length of 1024 tokens.

In [24]:
# Multi-head attention implemented by splitting shared projection weights
# into heads, instead of stacking separate single-head modules.
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a final output projection.

    Args:
        d_in: Size of each input embedding.
        d_out: Total output size across all heads; must be divisible by
            ``num_heads``.
        context_length: Side length of the square causal mask that is built
            once at construction time.
        dropout: Drop probability applied to the attention weights.
        num_heads: Number of attention heads.
        qkv_bias: Whether the query/key/value projections carry a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        # Every head gets an equal slice of the d_out features.
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Projections are created in query, key, value, out_proj order so
        # that seeded runs keep producing the same initial weights.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # Output projection that mixes the heads; keeps the size at d_out.
        self.out_proj = nn.Linear(d_out, d_out)

        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer("mask", future_positions)

    def forward(self, x):
        """Return context vectors of shape ``(batch, num_tokens, d_out)``.

        Args:
            x: Tensor of shape ``(batch, num_tokens, d_in)``.
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(projected):
            # (batch, tokens, d_out) -> (batch, num_heads, tokens, head_dim)
            per_head = projected.view(
                batch_size, seq_len, self.num_heads, self.head_dim
            )
            return per_head.transpose(1, 2)

        keys = split_heads(self.W_key(x))
        queries = split_heads(self.W_query(x))
        values = split_heads(self.W_value(x))

        # Per-head scores, shape (batch, num_heads, tokens, tokens).
        scores = queries @ keys.transpose(2, 3)

        # Crop the mask to the actual sequence length and hide future tokens.
        causal = self.mask.bool()[:seq_len, :seq_len]
        scores.masked_fill_(causal, -torch.inf)

        # Scale by sqrt(head_dim), normalise each row, then apply dropout.
        scale = self.head_dim ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        # (batch, num_heads, tokens, head_dim) -> (batch, tokens, num_heads, head_dim)
        merged = (weights @ values).transpose(1, 2)
        # The transpose leaves the tensor non-contiguous; view() needs
        # contiguous memory to flatten the heads back into d_out features.
        merged = merged.contiguous().view(batch_size, seq_len, self.d_out)

        # Final linear transformation of the concatenated heads.
        return self.out_proj(merged)


In [37]:
torch.manual_seed(123)

# Sizes from the exercise: 768-dim embeddings, 1024-token context, 12 heads
d_in = d_out = 768
context_length = 1024
num_heads = 12  # 12 heads x 64 dims per head = 768

# Ten random token embeddings, duplicated to form a batch of two
large_inputs = torch.rand(10, d_in)
large_batch = torch.stack((large_inputs, large_inputs))
print(large_inputs.shape)
print(large_batch.shape)

mha = MultiHeadAttention(
    d_in, d_out, context_length, dropout=0.0, num_heads=num_heads
)

context_vecs = mha(large_batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

torch.Size([10, 768])
torch.Size([2, 10, 768])
tensor([[[ 0.1412,  0.0380,  0.2516,  ...,  0.1747, -0.3599, -0.0996],
         [ 0.2090,  0.0488,  0.2684,  ...,  0.1145, -0.2759, -0.0632],
         [ 0.1183,  0.0207,  0.2602,  ...,  0.1041, -0.2878, -0.0919],
         ...,
         [ 0.1387,  0.0279,  0.2362,  ...,  0.1131, -0.2243, -0.0805],
         [ 0.1103,  0.0187,  0.2680,  ...,  0.1130, -0.2266, -0.0812],
         [ 0.1139,  0.0234,  0.2802,  ...,  0.0983, -0.2193, -0.1011]],

        [[ 0.1412,  0.0380,  0.2516,  ...,  0.1747, -0.3599, -0.0996],
         [ 0.2090,  0.0488,  0.2684,  ...,  0.1145, -0.2759, -0.0632],
         [ 0.1183,  0.0207,  0.2602,  ...,  0.1041, -0.2878, -0.0919],
         ...,
         [ 0.1387,  0.0279,  0.2362,  ...,  0.1131, -0.2243, -0.0805],
         [ 0.1103,  0.0187,  0.2680,  ...,  0.1130, -0.2266, -0.0812],
         [ 0.1139,  0.0234,  0.2802,  ...,  0.0983, -0.2193, -0.1011]]],
       grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 10, 768])

In [43]:
# Number of trainable parameters
def count_parameters(model):
    """Return the total element count of ``model``'s parameters that require grad."""
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad=False are not counted.
        if param.requires_grad:
            total += param.numel()
    return total

### Parameter count worked out by hand
# W_query, W_key, W_value (no bias): 3 * d_in * d_out = 3 * 768 * 768 = 1769472
# out_proj (weight + bias): 768 * 768 + 768 = 590592

print(count_parameters(mha))
print(3 * 768 * 768 + 768 * (768 + 1))

2360064
2360064
