# Transformers简介
## 1. 模型结构


<img src="img/transformer.jpeg" height="400" width="400" >

Transformer 本质上依旧是 encoder-decoder 框架，所以我们可以把它拆成两部分来了解。  
Encoder：对 input sequence 进行 embedding。  
Decoder：对上一时刻的 output 进行 embedding。

In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

## 2. 模型构造
### 2.1 一些基本的组件
包括 MultiHeadAttention、PositionalEncoding、PositionwiseFeedForward、LayerNorm 等。
### MultiHeadAttention

In [2]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are linearly projected, split into
    ``num_heads`` heads of size ``hidden_size // num_heads``, attended
    independently, then concatenated and projected back to ``hidden_size``.

    Args:
        hidden_size: Size of the last dimension of the inputs and outputs.
        num_heads: Number of attention heads; must divide ``hidden_size``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, dropout):
        super(MultiHeadAttention, self).__init__()
        # Fail early: otherwise the per-head ``view`` in ``forward`` would
        # raise a much less readable shape error.
        if hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.linear_q = nn.Linear(hidden_size, hidden_size)
        self.linear_k = nn.Linear(hidden_size, hidden_size)
        self.linear_v = nn.Linear(hidden_size, hidden_size)
        self.linear_out = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def _split_heads(self, x, batch_size):
        """Reshape [batch, seq_len, hidden] to [batch, num_heads, seq_len, head_size]."""
        return x.view(batch_size, -1, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Run attention.

        Args:
            query: Tensor of shape [batch_size, q_len, hidden_size].
            key: Tensor of shape [batch_size, k_len, hidden_size].
            value: Tensor of shape [batch_size, k_len, hidden_size].
            mask: Optional tensor broadcastable to
                [batch_size, num_heads, q_len, k_len]; positions where the
                mask equals 0 are excluded from attention.

        Returns:
            A tuple ``(output, attention)`` where ``output`` has shape
            [batch_size, q_len, hidden_size] and ``attention`` holds the
            (post-dropout) attention weights of shape
            [batch_size, num_heads, q_len, k_len].
        """
        batch_size = query.size(0)

        # linear projections, then split into multiple heads
        Q = self._split_heads(self.linear_q(query), batch_size)
        K = self._split_heads(self.linear_k(key), batch_size)
        V = self._split_heads(self.linear_v(value), batch_size)

        # scaled dot-product attention; math.sqrt replaces the original
        # np.sqrt, which referenced a module that was never imported
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.head_size)
        if mask is not None:
            # a large negative score becomes ~0 probability after softmax
            scores = scores.masked_fill(mask == 0, -1e9)
        attention = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(attention, V)  # [batch_size, num_heads, q_len, head_size]

        # merge heads back: [batch_size, q_len, hidden_size]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_size)

        # final linear projection
        output = self.linear_out(context)
        return output, attention


### PositionalEncoding

In [3]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of embeddings.

    Even feature indices hold ``sin(pos * w_i)`` and odd indices hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / hidden_size)``.

    Args:
        hidden_size: Size of the last (feature) dimension of the inputs.
        max_seq_len: Largest sequence length the precomputed table supports.
    """

    def __init__(self, hidden_size, max_seq_len):
        super(PositionalEncoding, self).__init__()
        self.hidden_size = hidden_size

        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, hidden_size, 2).float() * (-math.log(10000.0) / hidden_size)
        )
        angles = position * div_term  # [max_seq_len, ceil(hidden_size / 2)]

        # Build the table in a local variable: assigning ``self.pe`` first and
        # then calling ``register_buffer('pe', ...)`` raises KeyError because
        # the attribute already exists.
        pe = torch.zeros(max_seq_len, hidden_size)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd hidden_size there is one fewer odd column than even
        # columns, so trim the angles to match.
        pe[:, 1::2] = torch.cos(angles[:, : hidden_size // 2])

        # Registered as a buffer so it follows ``.to(device)`` and is saved in
        # the state dict without being a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_seq_len, hidden_size]

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape [batch_size, seq_len, hidden_size].

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            # Without this check a table of length 1 would broadcast silently.
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len]


### PositionwiseFeedForward

In [4]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Computes ``linear2(dropout(relu(linear1(x))))``.

    Args:
        hidden_size: Input and output feature size.
        feedforward_size: Feature size of the inner layer.
        dropout: Dropout probability applied after the ReLU activation.
    """

    def __init__(self, hidden_size, feedforward_size, dropout):
        super().__init__()
        self.linear1 = nn.Linear(hidden_size, feedforward_size)
        self.linear2 = nn.Linear(feedforward_size, hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Expand to ``feedforward_size``, activate, drop out, project back."""
        activated = torch.relu(self.linear1(x))
        return self.linear2(self.dropout(activated))


### LayerNorm

In [5]:
class LayerNorm(nn.Module):
    """Normalize the last dimension, then apply a learned scale and shift.

    The input is centered by its mean and divided by ``std + eps`` (using
    ``Tensor.std``), then multiplied by ``gamma`` and offset by ``beta``.

    Args:
        hidden_size: Size of the last dimension being normalized.
        eps: Small constant added to the standard deviation.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.eps = eps
        # gamma starts at one and beta at zero, so the initial transform is
        # the plain normalization.
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))

    def forward(self, x):
        """Return the normalized, scaled and shifted version of ``x``."""
        centered = x - x.mean(-1, keepdim=True)
        denominator = x.std(-1, keepdim=True) + self.eps
        return self.gamma * (centered / denominator) + self.beta


### 2.2 定义Encoder和Decoder

In [6]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention followed by a feed-forward network.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input (residual connection), and is then layer-normalized.

    Args:
        hidden_size: Feature size of the input and output.
        num_heads: Number of attention heads.
        feedforward_size: Inner size of the feed-forward network.
        dropout: Dropout probability shared by the sub-layers and residuals.
    """

    def __init__(self, hidden_size, num_heads, feedforward_size, dropout):
        super().__init__()
        self.self_attention = MultiHeadAttention(hidden_size, num_heads, dropout)
        self.feedforward = PositionwiseFeedForward(hidden_size, feedforward_size, dropout)
        self.norm1 = LayerNorm(hidden_size)
        self.norm2 = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        """Apply the block to ``x``; ``mask`` is forwarded to self-attention."""
        # self-attention sub-layer (attention weights are not needed here)
        attended, _weights = self.self_attention(x, x, x, mask=mask)
        x = self.norm1(x + self.dropout(attended))

        # position-wise feed-forward sub-layer
        transformed = self.feedforward(x)
        return self.norm2(x + self.dropout(transformed))


class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, encoder attention, feed-forward.

    Each sub-layer's output goes through dropout, is added to the sub-layer
    input (residual connection), and is then layer-normalized.

    Args:
        hidden_size: Feature size of the input and output.
        num_heads: Number of attention heads.
        feedforward_size: Inner size of the feed-forward network.
        dropout: Dropout probability shared by the sub-layers and residuals.
    """

    def __init__(self, hidden_size, num_heads, feedforward_size, dropout):
        super().__init__()
        self.self_attention = MultiHeadAttention(hidden_size, num_heads, dropout)
        self.encoder_attention = MultiHeadAttention(hidden_size, num_heads, dropout)
        self.feedforward = PositionwiseFeedForward(hidden_size, feedforward_size, dropout)
        self.norm1 = LayerNorm(hidden_size)
        self.norm2 = LayerNorm(hidden_size)
        self.norm3 = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Apply the block to ``x``.

        ``tgt_mask`` is passed to the self-attention over ``x``; ``src_mask``
        is passed to the attention whose keys and values are
        ``encoder_output``.
        """
        # masked self-attention sub-layer
        self_attended, _self_weights = self.self_attention(x, x, x, mask=tgt_mask)
        x = self.norm1(x + self.dropout(self_attended))

        # attention over the encoder output: queries come from the decoder
        cross_attended, _cross_weights = self.encoder_attention(
            x, encoder_output, encoder_output, mask=src_mask
        )
        x = self.norm2(x + self.dropout(cross_attended))

        # position-wise feed-forward sub-layer
        transformed = self.feedforward(x)
        return self.norm3(x + self.dropout(transformed))


class Encoder(nn.Module):
    """Stack of ``num_layers`` encoder layers preceded by positional encoding.

    Args:
        hidden_size: Feature size of the embedded input.
        num_layers: Number of stacked ``EncoderLayer`` modules.
        num_heads: Number of attention heads per layer.
        feedforward_size: Inner size of each layer's feed-forward network.
        dropout: Dropout probability passed to every layer.
        max_seq_len: Maximum sequence length for the positional encoding.
    """

    def __init__(self, hidden_size, num_layers, num_heads, feedforward_size, dropout, max_seq_len):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(hidden_size, num_heads, feedforward_size, dropout) for _ in range(num_layers)])
        # PositionalEncoding takes (hidden_size, max_seq_len); the original
        # call also passed ``dropout``, which raised a TypeError.
        self.pe = PositionalEncoding(hidden_size, max_seq_len)

    def forward(self, x, mask):
        """Add position encodings to ``x`` and run it through every layer.

        Args:
            x: Already-embedded input sequence.
            mask: Attention mask forwarded unchanged to each layer.

        Returns:
            The output of the last encoder layer.
        """
        x = self.pe(x)

        for layer in self.layers:
            x = layer(x, mask)

        return x


class Decoder(nn.Module):
    """Stack of ``num_layers`` decoder layers preceded by positional encoding.

    Args:
        hidden_size: Feature size of the embedded input.
        num_layers: Number of stacked ``DecoderLayer`` modules.
        num_heads: Number of attention heads per layer.
        feedforward_size: Inner size of each layer's feed-forward network.
        dropout: Dropout probability passed to every layer.
        max_seq_len: Maximum sequence length for the positional encoding.
    """

    def __init__(self, hidden_size, num_layers, num_heads, feedforward_size, dropout, max_seq_len):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(hidden_size, num_heads, feedforward_size, dropout) for _ in range(num_layers)])
        # PositionalEncoding takes (hidden_size, max_seq_len); the original
        # call also passed ``dropout``, which raised a TypeError.
        self.pe = PositionalEncoding(hidden_size, max_seq_len)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        """Add position encodings to ``x`` and run it through every layer.

        Args:
            x: Already-embedded target sequence.
            encoder_output: Encoder result that each layer attends over.
            src_mask: Mask forwarded to each layer's encoder attention.
            tgt_mask: Mask forwarded to each layer's self-attention.

        Returns:
            The output of the last decoder layer.
        """
        x = self.pe(x)

        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)

        return x


### 2.3 定义Transformer模型

In [7]:
class Transformer(nn.Module):
    """Encoder-decoder model mapping source and target token ids to logits.

    Args:
        src_vocab_size: Number of entries in the source embedding table.
        tgt_vocab_size: Number of entries in the target embedding table; also
            the size of the output layer.
        hidden_size: Embedding and model feature size.
        num_layers: Number of layers in both the encoder and the decoder.
        num_heads: Number of attention heads per layer.
        feedforward_size: Inner size of each feed-forward network.
        dropout: Dropout probability used throughout the model.
        max_src_len: Maximum source length for the encoder's position table.
        max_tgt_len: Maximum target length for the decoder's position table.
    """

    def __init__(self, src_vocab_size, tgt_vocab_size, hidden_size, num_layers, num_heads, feedforward_size, dropout, max_src_len, max_tgt_len):
        super().__init__()
        self.embedding_src = nn.Embedding(src_vocab_size, hidden_size)
        self.embedding_tgt = nn.Embedding(tgt_vocab_size, hidden_size)
        self.encoder = Encoder(hidden_size, num_layers, num_heads, feedforward_size, dropout, max_src_len)
        self.decoder = Decoder(hidden_size, num_layers, num_heads, feedforward_size, dropout, max_tgt_len)
        self.output_layer = nn.Linear(hidden_size, tgt_vocab_size)

    def forward(self, src_input, tgt_input, src_mask, tgt_mask):
        """Embed both inputs, encode the source, decode the target, project.

        Returns:
            The output layer applied to the decoder result, i.e. one score
            per target-vocabulary entry in the last dimension.
        """
        src_vectors = self.embedding_src(src_input)
        tgt_vectors = self.embedding_tgt(tgt_input)

        memory = self.encoder(src_vectors, src_mask)
        decoded = self.decoder(tgt_vectors, memory, src_mask, tgt_mask)

        return self.output_layer(decoded)