In [4]:
import torch
import torch.nn as nn

In [18]:
def qkv_attention_value(q, k, v, mask = False):

    k = torch.transpose(k, dim0 = -2, dim1 = -1)
    scores = torch.matmul(q, k)

    if isinstance(mask, bool):
        if mask:
            _shape = scores.shape
            mask = torch.ones((_shape[-2], _shape[-1]))
            mask = mask[None][None]
        else:
            mask = None
    if mask is not None:
        scores = scores + mask

    alpha = torch.softmax(scoresm, dim = -1)

    v = torch.matmul(alpha, v)
    return v

In [32]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, hidden_size, num_header):
        super(MultiHeadSeflAttention, self).__init__()
        assert hidden_size % num_header == 0, f"header的数目无法整除:{hidden_size}, {num_header}"

        self.hidden_size = hidden_size
        self.num_header = num_header
        self.wq = nn.Sequential(
            nn.Linear(in_features = self.hidden_size, out_features = self.hidden_size)
        )
        self.wk = nn.Sequential(
            nn.Linear(in_features = self.hidden_size, out_features = self.hidden_size)
        )
        self.wv = nn.Sequential(
            nn.Linear(in_features = self.hidden_size, out_features = self.hidden_size)
        )
        self.wo = nn.Sequential(
            nn.Linear(in_features = self.hidden_size, out_features = self.hidden_size),
            nn.ReLU()
        )

    def split(self, vs):
        n, t, e = vs.shape
        vs = torch.reshape(vs, shape = (n, t, self.num_header, e // self.num_header))
        vs = torch.permute(vs, dims = (0,2,1,3))
        return vs
    def forward(self, x, attention_mask = None, **kwargs):
        q = self.wq(x)  # [n,t,e]
        k = self.wk(x)  # [n,t,e]
        v = self.wv(x)  # [n,t,e]
        q = self.split(q)  # [n,t,e] --> [n,h,t,v]  e=h*v h就是head的数目，v就是每个头中self-attention的维度大小
        k = self.split(k)  # [n,t,e] --> [n,h,t,v]  e=h*v
        v = self.split(v)  # [n,t,e] --> [n,h,t,v]  e=h*v

        v = qkv_attention_value(q, k, v, attention_mask)

        v = torch.permute(v, dims=(0, 2, 1, 3))  # [n,h,t,v] --> [n,t,h,v]
        n, t, _, _ = v.shape
        v = torch.reshape(v, shape=(n, t, -1))  # [n,t,h,v] -> [n,t,e]
        v = self.wo(v)  # 多个头之间的特征组合合并
        return v

class MultiHeadEncoderDecoderAttention(nn.Module):
    def __init__(self, hidden_size, num_header):
        super(MultiHeadEncoderDecoderAttention, self).__init__()
        assert hidden_size % num_header == 0, f"header的数目没办法整除:{hidden_size}, {num_header}"

        self.hidden_size = hidden_size  # 就是向量维度大小，也就是E
        self.num_header = num_header  # 头的数目

        self.wo = nn.Sequential(
            nn.Linear(in_features=self.hidden_size, out_features=self.hidden_size),
            nn.ReLU()
        )

    def split(self, vs):
        n, t, e = vs.shape
        vs = torch.reshape(vs, shape=(n, t, self.num_header, e // self.num_header))
        vs = torch.permute(vs, dims=(0, 2, 1, 3))
        return vs

    def forward(self, q, encoder_k, encoder_v, encoder_attention_mask, **kwargs):
        """
        编码器解码器attention
        :param q: [N,T1,E]
        :param encoder_k: [N,T2,E]
        :param encoder_v: [N,T2,E]
        :param encoder_attention_mask: [N,1,T2,T2]
        :return: [N,T1,E]
        """
        q = self.split(q)  # [n,t,e] --> [n,h,t,v]  e=h*v h就是head的数目，v就是每个头中self-attention的维度大小
        k = self.split(encoder_k)  # [n,t,e] --> [n,h,t,v]  e=h*v
        v = self.split(encoder_v)  # [n,t,e] --> [n,h,t,v]  e=h*v

        # 计算attention value
        v = qkv_attention_value(q, k, v, mask=encoder_attention_mask)

        # 5. 输出
        v = torch.permute(v, dims=(0, 2, 1, 3))  # [n,h,t,v] --> [n,t,h,v]
        n, t, _, _ = v.shape
        v = torch.reshape(v, shape=(n, t, -1))  # [n,t,h,v] -> [n,t,e]
        v = self.wo(v)  # 多个头之间的特征组合合并
        return v

In [22]:
class FFN(nn.Module):
    def __init__(self, hidden_size):
        super(FFN, self).__init__()

        self.ffn = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.ReLU(),
            nn.Linear(4 * hidden_size, hidden_size)
        )

    def forward(self, x, **kwargs):
        return self.ffn(x)

In [24]:
class ResidualsNorm(nn.Module):
    def __init__(self, block, hidden_size):
        super(ResidualsNorm, self).__init__()
        self.block = block
        self.norm = nn.LayerNorm(normalized_shape=hidden_size)
        self.relu = nn.ReLU()

    def forward(self, x, **kwargs):
        z = self.block(x, **kwargs)
        z = self.relu(x + z)
        z = self.norm(z)
        return z

In [26]:
class TransformerEncoderLayers(nn.Module):
    def __init__(self, hidden_size, num_header, encoder_layers):
        super(TransformerEncoderLayers, self).__init__()

        layers = []
        for i in range(encoder_layers):
            layer = [
                ResidualsNorm(
                    block=MultiHeadSelfAttention(hidden_size=hidden_size, num_header=num_header),
                    hidden_size=hidden_size
                ),
                ResidualsNorm(
                    block=FFN(hidden_size=hidden_size),
                    hidden_size=hidden_size
                )
            ]
            layers.extend(layer)
        self.layers = nn.ModuleList(layers)

    def forward(self, x, attention_mask):
        attention_mask = torch.unsqueeze(attention_mask, dim=1)  # 增加header维度
        for layer in self.layers:
            x = layer(x, attention_mask=attention_mask)
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_header, max_seq_length, encoder_layers):
        super(TransformerEncoder, self).__init__()

        self.input_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
        self.position_emb = nn.Embedding(num_embeddings=max_seq_length, embedding_dim=hidden_size)
        self.layers = TransformerEncoderLayers(hidden_size, num_header, encoder_layers)

    def forward(self, input_token_ids, input_position_ids, input_mask):
        """
        前向过程
        :param input_token_ids: [N,T] long类型的token id
        :param input_position_ids: [N,T] long类型的位置id
        :param input_mask: [N,T,T] float类型的mask矩阵
        :return:
        """
        # 1. 获取token的embedding
        inp_embedding = self.input_emb(input_token_ids)  # [N,T,E]

        # 2. 获取位置embedding
        position_embedding = self.position_emb(input_position_ids)

        # 3. 合并embedding
        emd = inp_embedding + position_embedding

        # 4. 输入到attention提取特征
        feat_emd = self.layers(emd, attention_mask=input_mask)

        return feat_emd


class TransformerDecoderLayers(nn.Module):
    def __init__(self, hidden_size, num_header, decoder_layers):
        super(TransformerDecoderLayers, self).__init__()

        self.wk = nn.Linear(hidden_size, hidden_size)
        self.wv = nn.Linear(hidden_size, hidden_size)

        layers = []
        for i in range(decoder_layers):
            layer = [
                ResidualsNorm(
                    block=MultiHeadSelfAttention(hidden_size=hidden_size, num_header=num_header),
                    hidden_size=hidden_size
                ),
                ResidualsNorm(
                    block=MultiHeadEncoderDecoderAttention(hidden_size=hidden_size, num_header=num_header),
                    hidden_size=hidden_size
                ),
                ResidualsNorm(
                    block=FFN(hidden_size=hidden_size),
                    hidden_size=hidden_size
                )
            ]
            layers.extend(layer)
        self.layers = nn.ModuleList(layers)

    def forward(self, x, encoder_outputs=None, encoder_attention_mask=None, attention_mask=None):
        """
        :param x: [N,T2,E]
        :param encoder_outputs: [N,T1,E]
        :param encoder_attention_mask: [N,1,T1]
        :param attention_mask: [N,T2,T2]
        :return:
        """
        attention_mask = torch.unsqueeze(attention_mask, dim=1)  # 增加header维度 [N,T2,T2] -> [N,1,T2,T2]
        encoder_attention_mask = torch.unsqueeze(encoder_attention_mask, dim=1)  # 增加header维度 [N,1,T1] -> [N,1,1,T1]
        k = self.wk(encoder_outputs)  # [N,T1,E]
        v = self.wv(encoder_outputs)  # [N,T1,E]

        for layer in self.layers:
            x = layer(
                x,
                encoder_k=k, encoder_v=v, encoder_attention_mask=encoder_attention_mask,
                attention_mask=attention_mask
            )
        return x


class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_header, max_seq_length, decoder_layers):
        super(TransformerDecoder, self).__init__()

        self.input_emb = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
        self.position_emb = nn.Embedding(num_embeddings=max_seq_length, embedding_dim=hidden_size)
        self.layers = TransformerDecoderLayers(hidden_size, num_header, decoder_layers)

    def forward(self, input_token_ids, input_position_ids, input_mask, encoder_outputs, encoder_attention_mask):
        """
        前向过程
        :param input_token_ids: [N,T] long类型的token id
        :param input_position_ids: [N,T] long类型的位置id
        :param input_mask: [N,T,T] float类型的mask矩阵
        :param encoder_outputs: [N,T1,E] 编码器的输出状态信息
        :param encoder_attention_mask: [N,T1,T1] 编码器的输入mask信息
        :return:
        """
        if self.training:
            # 1. 获取token的embedding
            inp_embedding = self.input_emb(input_token_ids)  # [N,T,E]

            # 2. 获取位置embedding
            position_embedding = self.position_emb(input_position_ids)

            # 3. 合并embedding
            emd = inp_embedding + position_embedding

            # 4. 输入到attention提取特征
            feat_emd = self.layers(
                emd, encoder_outputs=encoder_outputs,
                encoder_attention_mask=encoder_attention_mask, attention_mask=input_mask
            )

            return feat_emd
        else:
            raise ValueError("当前模拟代码不实现推理过程，仅实现training过程")


class Transformer(nn.Module):
    def __init__(self):
        super(Transformer, self).__init__()

    def forward(self, encoder_input_ids, encoder_lengths, label_ids=None, label_lengths=None):
        """
        :param encoder_input_ids: 编码器输入token id: [N,T1]
        :param encoder_lengths: 编码器输入的文本实际长度值, [N,] 也可以直接外部传入mask等信息，但是至少要有代码实现从length到mask的转换
        :param label_ids: 模型预测期望输出的token id: [N,T2]
        :param label_lengths: 模型预测期望输出文本实际长度, [N,]
        :return:
        """
        pass

In [28]:
def transformer():
    encoder = TransformerEncoder(vocab_size=1000, hidden_size=512, num_header=8, max_seq_length=1024, encoder_layers=6)
    decoder = TransformerDecoder(vocab_size=1000, hidden_size=512, num_header=8, max_seq_length=1024, decoder_layers=6)

    input_token_ids = torch.tensor([
        [100, 102, 108, 253, 125],  # 第一个样本实际长度为5
        [254, 125, 106, 0, 0]  # 第二个样本实际长度为3
    ])
    input_position_ids = torch.tensor([
        [0, 1, 2, 3, 4],
        [0, 1, 2, 3, 4]
    ])
    input_mask = torch.tensor([
        [
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0, 0.0],
        ],
        [
            [0.0, 0.0, 0.0, -10000.0, -10000.0],
            [0.0, 0.0, 0.0, -10000.0, -10000.0],
            [0.0, 0.0, 0.0, -10000.0, -10000.0],
            [-10000.0, -10000.0, -10000.0, 0.0, -10000.0],
            [-10000.0, -10000.0, -10000.0, -10000.0, 0.0],
        ],
    ])
    encoder_attention_mask = torch.tensor([
        [
            [0.0, 0.0, 0.0, 0.0, 0.0]  # 表示第一个样本的解码器中第一个时刻和编码器的各个时刻之间的mask值
        ],
        [
            [0.0, 0.0, 0.0, -10000.0, -10000.0]  # 是因为编码器的输入中，最后两个位置是填充
        ],
    ])

    input_decoder_token_ids = torch.tensor([
        [251, 235, 124, 321, 25, 68],
        [351, 235, 126, 253, 0, 0]
    ])
    input_decoder_position_ids = torch.tensor([
        [0, 1, 2, 3, 4, 5],
        [0, 1, 2, 3, 4, 5]
    ])
    input_decoder_mask = torch.tensor([
        [
            [0.0, -10000.0, -10000.0, -10000.0, -10000.0, -10000.0],
            [0.0, 0.0, -10000.0, -10000.0, -10000.0, -10000.0],
            [0.0, 0.0, 0.0, -10000.0, -10000.0, -10000.0],
            [0.0, 0.0, 0.0, 0.0, -10000.0, -10000.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, -10000.0],
            [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        ],
        [
            [0.0, -10000.0, -10000.0, -10000.0, -10000.0, -10000.0],
            [0.0, 0.0, -10000.0, -10000.0, -10000.0, -10000.0],
            [0.0, 0.0, 0.0, -10000.0, -10000.0, -10000.0],
            [0.0, 0.0, 0.0, 0.0, -10000.0, -10000.0],
            [-10000.0, -10000.0, -10000.0, -10000.0, 0.0, -10000.0],
            [-10000.0, -10000.0, -10000.0, -10000.0, -10000.0, 0.0]
        ],
    ])

    encoder_outputs = encoder(input_token_ids, input_position_ids, input_mask)
    print(encoder_outputs.shape)

    decoder_outputs = decoder(
        input_token_ids=input_decoder_token_ids,
        input_position_ids=input_decoder_position_ids,
        input_mask=input_decoder_mask,
        encoder_outputs=encoder_outputs,
        encoder_attention_mask=encoder_attention_mask
    )
    print(decoder_outputs.shape)

def model_structure(model):
    blank = ' '
    print('-' * 90)
    print('|' + ' ' * 11 + 'weight name' + ' ' * 10 + '|' \
          + ' ' * 15 + 'weight shape' + ' ' * 15 + '|' \
          + ' ' * 3 + 'number' + ' ' * 3 + '|')
    print('-' * 90)
    num_para = 0
    type_size = 1  # 如果是浮点数就是4
 
    for index, (key, w_variable) in enumerate(model.named_parameters()):
        if len(key) <= 30:
            key = key + (30 - len(key)) * blank
        shape = str(w_variable.shape)
        if len(shape) <= 40:
            shape = shape + (40 - len(shape)) * blank
        each_para = 1
        for k in w_variable.shape:
            each_para *= k
        num_para += each_para
        str_num = str(each_para)
        if len(str_num) <= 10:
            str_num = str_num + (10 - len(str_num)) * blank
 
        print('| {} | {} | {} |'.format(key, shape, str_num))
    print('-' * 90)
    print('The total number of parameters: ' + str(num_para))
    print('The parameters of Model {}: {:4f}M'.format(model._get_name(), num_para * type_size / 1000 / 1000))
    print('-' * 90)


In [34]:
if __name__ == '__main__':
    transformer()
    net=TransformerDecoder(1,8,8,20,5)
    print(net)

TypeError: super(type, obj): obj must be an instance or subtype of type