In [3]:
import numpy as np
import torch
from torch import Tensor
from typing import Optional, Any, Union, Callable
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time

**参考**  
[知乎:self-attention](https://zhuanlan.zhihu.com/p/455399791)

**nn.ModuleList**  
nn.ModuleList是pytorch中的一个容器，允许将多个模块组合在一起，形成一个更大的模块。
特点：  
1. **动态构建模块列表**：可以在创建后，随时动态添加模块到列表中，对于需要在运行时根据条件构建不同子模块的情况非常有用。  
2. **不会自动跟踪模块**：nn.ModuleList不会自动跟踪其包含的模块，意味着它的参数不会自动添加到父模块的参数列表中。需要自己手动管理这些参数。

In [24]:
import torch.nn as nn

class CustomModel(nn.Module):
    def __init__(self, layer_sizes):
        #在构造函数中初始化模型
        super(CustomModel, self).__init__()

        #创建一个神经网络层列表，每一个元素都是一个线性层(nn.Linear)
        #layer_sizes 是一个包含每个层的输入和输出维度的列表
        #循环遍历layer_sizes，创建并添加线性层到layers
        self.layers = nn.ModuleList([nn.Linear(layer_sizes[i], layer_sizes[i+1]) for i in range(len(layer_sizes)-1)])
        print(self.layers)

    def forward(self, x):
        #定义前向传播函数forward
        #输入x表示模型的输入数据

        #遍历 self.layers中的线性层，依次将输入x传递给每一层
        for layer in self.layers:
            x = layer(x)

        #返回最终输出，x包含了前向传播后的模型输出
        return x

mlist = CustomModel([10, 20, 30, 40])
mlist

ModuleList(
  (0): Linear(in_features=10, out_features=20, bias=True)
  (1): Linear(in_features=20, out_features=30, bias=True)
  (2): Linear(in_features=30, out_features=40, bias=True)
)


CustomModel(
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): Linear(in_features=20, out_features=30, bias=True)
    (2): Linear(in_features=30, out_features=40, bias=True)
  )
)

In [13]:
q,k,v = Tensor(), Tensor(), Tensor()
mm = [q,k,v]
zip(m, mm)

TypeError: zip argument #1 must support iteration

In [35]:
class MultiHeadedAttention(nn.Module):
    def __init__(self,
                 num_heads: int,  #训练多头个数
                 d_model: int,     #输入token的embedding维数
                 dropout: float = 0.1):
        super(MultiHeadedAttention, self).__init__()
        assert d_model % num_heads == 0, "d_model must be a multiple of num_heads"
        # Assume v_dim always equals k_dim
        self.k_dim = d_model // num_heads #每一个W的维度数  
        self.num_heads = num_heads  #训练的attention的多头数
        #W^Q, W^K, W^V, W^O，其维度(d_model, d_k))
        self.proj_weights = clones(nn.Linear(d_model, d_model), 4) 
        print("###########")
        print(self.proj_weights)
        self.attention_score = None #attention_score
        self.dropout = nn.Dropout(p=dropout)

    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                mask: Optional[Tensor] = None):
        """
        Args:
            query: shape (batch_size, seq_len, d_model)
            key: shape (batch_size, seq_len, d_model)
            value: shape (batch_size, seq_len, d_model)
            mask: shape (batch_size, seq_len, seq_len).Since we assume all data use a same mask, so
                here the shape also equals to (1, seq_len, seq_len)

        Returns:
            out: shape (batch_size, seq_len, d_model).The output of a multihead attention layer
        """
        if mask is not None:
            # Same mask applied to all heads
            mask = mask.unsqueeze(1)   #unsqueeze() 增加张量维数，接受参数表示在哪个维度之前增加新维度
        batch_size = query.size(0)


        print(query)
        print(key)
        print(value)
        m = zip(self.proj_weights, [query, key, value])
        print(m)

        # 1) Apply W^Q, W^K, W^V to generate new query, key, value
        query, key, value\
            = [proj_weight(x).view(batch_size, -1, self.num_heads, self.k_dim).transpose(1,2)
               for proj_weight, x in zip(self.proj_weights, [query, key, value])]
        print("------------------------")
        
        # 2) Calculate attention score and the out
        out, self.aatention_score = attention(query, key, value, mask=mask, dropout=self.dropout)
        
        # 3)Concat output
        out = out.transpose(1,2).contiguous()\
            .view(batch_size, -1, self.num_heads * self.k_dim)
        
        # 4) Apply W^o to get the final output
        out = self.proj_weights[-1](out)

        return out
    
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

#定义attention的计算逻辑
def attention(query:Tensor,
              key:Tensor,
              value:Tensor, 
              mask:Optional[Tensor] = None,
              dropout:Optional[Callable] = None):
    """
    Define how to calculate attention score
    Args:
        query: shape (batch_size, num_heads, seq_len, k_dim)
        key: shape (batch_size, num_heads, seq_len, k_dim)
        value: shape (batch_size, num_heads, seq_len, v_dim)
        mask: shape (batch_size, num_heads, seq_len, seq_len). Since our assumption, here the shape is
        (1, 1, seq_len, seq_len)
    """
    #取出w^Q, w^K, w^V的维度，以便后续计算attention score时，需要scaled by sqrt(k_dim)
    k_dim = query.size(-1)

    #shape (seq_len, seq_len), row:token, col:that token's attention score
    #transpose(x, y) x,y转置
    #Q*K^T,
    #query shape(batch_size, num_heads, seq_len, k_dim) 
    #key^T shape(batch_size, num_heads, k_dim, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(k_dim)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    
    attention_score = F.softmax(scores, dim=-1)

    if dropout is not None:
        attention_score = dropout(attention_score)  

    out = torch.matmul(attention_score, value)

    return out, attention_score
    

In [36]:
if __name__ == "__main__":
    # Test MultiHeadedAttention
    batch_size = 6
    seq_len = 3
    d_model = 8
    num_heads = 2
    #mask = None
    mask = torch.tril(torch.ones((seq_len, seq_len)), diagonal = 0).unsqueeze(0)
    

    input = torch.rand(batch_size, seq_len, d_model)
    multi_attn = MultiHeadedAttention(num_heads = num_heads, d_model = d_model, dropout = 0.1)
    out = multi_attn(query = input, key = input, value = input, mask = mask)
    print(out.shape)

###########
ModuleList(
  (0): Linear(in_features=8, out_features=8, bias=True)
  (1): Linear(in_features=8, out_features=8, bias=True)
  (2): Linear(in_features=8, out_features=8, bias=True)
  (3): Linear(in_features=8, out_features=8, bias=True)
)
<zip object at 0x00000247863D6CC8>
------------------------
torch.Size([6, 3, 8])
