# In [None]:
from torch import nn
import math
from torch import Tensor
import torch
import torch.nn.functional as F

# In [None]:
class Embedding(nn.Module):
    """Token embedding scaled by ``sqrt(hidden_dim)``.

    Maps integer token ids to dense vectors. It then multiplies them by
    ``sqrt(hidden_dim)``, as in the original Transformer paper.

    Args:
        vocab_size: number of rows in the lookup table (distinct token ids).
        hidden_dim: size of each embedding vector.
    """

    def __init__(self, vocab_size, hidden_dim):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it the assignment below raises AttributeError.
        super().__init__()
        # Lookup table of shape (vocab_size, hidden_dim).
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        self.sqrt_factor = math.sqrt(hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Embed token ids.

        Args:
            x: tensor of token ids, e.g. (batch_size, max_seq_length). It is
               cast to long because nn.Embedding requires integer indices.

        Returns:
            Tensor of shape ``x.shape + (hidden_dim,)``.
        """
        return self.embedding(x.long()) * self.sqrt_factor
        

# In [None]:
class positionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input, followed by dropout.

    Even feature columns hold ``sin(pos / 10000^(2i/hidden_dim))``.
    Odd feature columns hold the matching ``cos``.

    Args:
        max_positions: longest sequence length the table supports.
        hidden_dim: feature size; must match the last axis of the input.
        drop_prob: dropout probability applied after adding the encoding.
    """

    def __init__(self, max_positions, hidden_dim, drop_prob):
        super().__init__()
        self.dropout = nn.Dropout(drop_prob)

        positions = torch.arange(max_positions, dtype=torch.float).unsqueeze(1)  # (max_positions, 1)
        even_dims = torch.arange(0, hidden_dim, 2, dtype=torch.float)  # (ceil(hidden_dim / 2),)
        # 1 / 10000^(2i / hidden_dim), computed in log space.
        div_term = torch.exp(even_dims * (-math.log(10000.0)) / hidden_dim)

        # Broadcast (max_positions, 1) * (ceil(hidden_dim / 2),) -> (max_positions, ceil(hidden_dim / 2)).
        angles = positions * div_term
        pe = torch.zeros(max_positions, hidden_dim)
        pe[:, 0::2] = torch.sin(angles)
        # When hidden_dim is odd there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(angles[:, : hidden_dim // 2])

        # Store as (1, max_positions, hidden_dim) so it broadcasts over the batch.
        # A buffer is saved in state_dict() and moves with .to(device).
        # It is not returned by parameters(), so the optimizer never updates it.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encoding for the first ``x.size(1)`` positions, then apply dropout.

        Assumes x is (batch, seq_len, hidden_dim); TODO confirm with callers.

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_positions``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_positions {self.pe.size(1)}"
            )
        return self.dropout(x + self.pe[:, :seq_len])

# In [None]:
def attention(query, key, value, mask=None):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query: tensor (..., L_q, d_k); d_k is its last axis.
        key: tensor (..., L_k, d_k).
        value: tensor (..., L_k, d_v).
        mask: optional tensor broadcastable to (..., L_q, L_k). Positions
            where ``mask == 0`` get the lowest representable score, so the
            softmax gives them (effectively) zero weight.

    Returns:
        Tensor (..., L_q, d_v).
    """
    sqrt_dim = query.shape[-1] ** 0.5
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt_dim  # (..., L_q, L_k)

    if mask is not None:
        # Use the dtype's minimum rather than a fixed -1e9.
        # -1e9 overflows (and masked_fill raises) for float16 scores.
        scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

    # Softmax over the key axis, so each query's weights sum to 1 (the attention map).
    weight = F.softmax(scores, dim=-1)
    return torch.matmul(weight, value)

# About the mask:
# Training works on batches of sequences, so shorter sequences are padded to a common length. The padding
# must not take part in the attention calculation, so an attention mask holds zeros at the padding positions
# after the end of each sequence (e.g. a length-3 sequence padded to length 5 has mask [1, 1, 1, 0, 0]).
# Masked positions receive a large negative score, so the softmax assigns them (effectively) zero probability.

# In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with separate query/key/value/output projections.

    Queries are projected from ``x``; keys and values are projected from ``y``.
    Passing the same tensor for both gives self-attention. Passing another
    stream (e.g. the encoder output) as ``y`` gives cross-attention.

    Args:
        hidden_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        drop_prob: dropout probability applied to the output projection.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_dim, num_heads, drop_prob):
        super().__init__()
        if hidden_dim % num_heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dropout = nn.Dropout(drop_prob)
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dim_head = hidden_dim // num_heads

        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.out = nn.Linear(hidden_dim, hidden_dim)

    def _split_heads(self, t, batch_size):
        """Reshape (batch, seq, hidden_dim) -> (batch, num_heads, seq, dim_head)."""
        return t.view(batch_size, -1, self.num_heads, self.dim_head).transpose(1, 2)

    def forward(self, x, y, mask):
        """Attend from ``x`` to ``y``.

        Args:
            x: query source, presumably (batch, L_q, hidden_dim); confirm with callers.
            y: key/value source, presumably (batch, L_k, hidden_dim).
            mask: optional mask; 0 marks positions to ignore. A head axis is
                inserted at dim 1, so it is shared by all heads. Assumes
                (batch, L_q or 1, L_k); TODO confirm.

        Returns:
            Tensor (batch, L_q, hidden_dim).
        """
        batch_size = x.size(0)
        # Each input goes through its own projection before being split into heads.
        query = self._split_heads(self.query(x), batch_size)
        key = self._split_heads(self.key(y), batch_size)
        value = self._split_heads(self.value(y), batch_size)

        if mask is not None:
            mask = mask.unsqueeze(1)

        attn = attention(query, key, value, mask)
        # Concatenate the heads back into hidden_dim:
        # (batch, num_heads, L_q, dim_head) -> (batch, L_q, hidden_dim)
        attn = attn.transpose(1, 2).contiguous().view(batch_size, -1, self.hidden_dim)

        return self.dropout(self.out(attn))


# In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer.

    Linear(hidden_dim -> ffd_dim) -> ReLU -> Dropout
    -> Linear(ffd_dim -> hidden_dim) -> Dropout
    """

    def __init__(self, hidden_dim, ffd_dim, drop_prob):
        super().__init__()
        expand = nn.Linear(hidden_dim, ffd_dim)
        # inplace=True lets ReLU overwrite the Linear output instead of
        # allocating a new tensor, saving memory.
        activation = nn.ReLU(inplace=True)
        project = nn.Linear(ffd_dim, hidden_dim)
        self.ffd = nn.Sequential(
            expand, activation, nn.Dropout(drop_prob), project, nn.Dropout(drop_prob)
        )

    def forward(self, x):
        """Apply the feed-forward stack; the output has the same shape as ``x``."""
        out = self.ffd(x)
        return out

# In [None]:
class EncoderBlock(nn.Module):
    """One pre-LayerNorm Transformer encoder block.

    The block computes ``x = x + SelfAttn(LN1(x))`` and then
    ``x = x + FFN(LN2(x))``.
    """

    def __init__(self, hidden_dim, num_heads, drop_prob, ffd_dim):
        super().__init__()
        self.LayerNorm1 = nn.LayerNorm(hidden_dim)
        self.attention = MultiHeadAttention(hidden_dim, num_heads, drop_prob)
        self.LayerNorm2 = nn.LayerNorm(hidden_dim)
        self.feedforward = FeedForward(hidden_dim, ffd_dim, drop_prob)

    def forward(self, x, x_mask):
        """Apply both residual sub-layers.

        Args:
            x: encoder stream, presumably (batch, seq_len, hidden_dim); confirm with callers.
            x_mask: mask forwarded to self-attention; 0 marks positions to ignore.
        """
        x = x + self.layer1(x, x_mask)
        x = x + self.layer2(x)
        return x

    def layer1(self, x, x_mask):
        """Self-attention sub-layer (the residual is added in forward)."""
        x = self.LayerNorm1(x)
        return self.attention(x, x, x_mask)

    def layer2(self, x, x_mask=None):
        """Feed-forward sub-layer (the residual is added in forward).

        ``x_mask`` is unused. It is optional so that both ``layer2(x)``
        (as forward calls it) and ``layer2(x, x_mask)`` work.
        """
        x = self.LayerNorm2(x)
        return self.feedforward(x)



# In [None]:
class Encoder(nn.Module):
    """Stack of encoder blocks followed by a final LayerNorm.

    Args:
        hidden_dim, num_heads, drop_prob, ffd_dim: forwarded to each EncoderBlock.
        num_blocks: number of stacked blocks (default 1, the previous fixed depth).
    """

    def __init__(self, hidden_dim, num_heads, drop_prob, ffd_dim, num_blocks=1):
        super().__init__()
        self.blocks = nn.ModuleList(
            [EncoderBlock(hidden_dim, num_heads, drop_prob, ffd_dim) for _ in range(num_blocks)]
        )
        # Blocks normalize their sub-layer inputs (pre-LN), so the stack's
        # output still needs one final normalization.
        self.layernorm = nn.LayerNorm(hidden_dim)

    def forward(self, x, x_mask):
        """Run ``x`` through every block, then normalize the result."""
        for block in self.blocks:
            x = block(x, x_mask)
        return self.layernorm(x)


# In [None]:
class DecoderBlock(nn.Module):
    """One pre-LayerNorm Transformer decoder block.

    It has three residual sub-layers, each normalizing its input first:
    1. self-attention over the decoder stream (masked by ``y_mask``);
    2. cross-attention from the decoder stream to the encoder output ``x``;
    3. a position-wise feed-forward network.
    """

    def __init__(self, hidden_dim, num_heads, drop_prob, ffd_dim):
        super().__init__()
        self.LayerNorm1 = nn.LayerNorm(hidden_dim)
        self.self_attn = MultiHeadAttention(hidden_dim, num_heads, drop_prob)
        self.LayerNorm2 = nn.LayerNorm(hidden_dim)
        self.cross_attn = MultiHeadAttention(hidden_dim, num_heads, drop_prob)
        self.ffd = FeedForward(hidden_dim, ffd_dim, drop_prob)
        self.LayerNorm3 = nn.LayerNorm(hidden_dim)

    def forward(self, y, y_mask, x, x_mask):
        """Apply the three residual sub-layers to the decoder stream ``y``."""
        out = y + self.layer1(y, y_mask)
        out = out + self.layer2(out, x, x_mask)
        return out + self.layer3(out)

    def layer1(self, y, y_mask):
        """Self-attention sub-layer (without the residual)."""
        normed = self.LayerNorm1(y)
        return self.self_attn(normed, normed, y_mask)

    def layer2(self, y, x, x_mask):
        """Cross-attention sub-layer: queries from ``y``, keys/values from ``x``."""
        return self.cross_attn(self.LayerNorm2(y), x, x_mask)

    def layer3(self, y):
        """Feed-forward sub-layer (without the residual)."""
        return self.ffd(self.LayerNorm3(y))



# In [None]:
class Decoder(nn.Module):
    """Stack of decoder blocks followed by a final LayerNorm.

    Args:
        hidden_dim, num_heads, drop_prob, ffd_dim: forwarded to each DecoderBlock.
        num_blocks: number of stacked blocks (default 1).
    """

    def __init__(self, hidden_dim, num_heads, drop_prob, ffd_dim, num_blocks=1):
        super().__init__()
        # num_blocks is an int, so iterate over range(num_blocks), not the int itself.
        self.blocks = nn.ModuleList(
            [DecoderBlock(hidden_dim, num_heads, drop_prob, ffd_dim) for _ in range(num_blocks)]
        )
        self.LayerNorm = nn.LayerNorm(hidden_dim)

    def forward(self, x, y, x_mask, y_mask):
        """Decode ``y`` while attending to the encoder output ``x``.

        Returns:
            The normalized decoder stream, same shape as ``y``.
        """
        for block in self.blocks:
            y = block(y, y_mask, x, x_mask)
        # The final normalization is applied once after the whole stack,
        # not after every block.
        return self.LayerNorm(y)
            



# In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer producing unnormalized output-vocabulary scores.

    Args:
        hidden_dim: model width.
        num_heads: attention heads per attention layer.
        drop_prob: dropout probability used throughout.
        ffd_dim: hidden size of the feed-forward sub-layers.
        input_vocab_size: source vocabulary size.
        max_positions: longest sequence the positional encodings support.
        output_vocab_size: target vocabulary size (width of the projection).
    """

    def __init__(self, hidden_dim, num_heads, drop_prob, ffd_dim, input_vocab_size, max_positions, output_vocab_size):
        super().__init__()
        self.input_embedding_layer = Embedding(input_vocab_size, hidden_dim)
        self.input_position_encoding = positionalEncoding(max_positions, hidden_dim, drop_prob)
        self.output_embedding_layer = Embedding(output_vocab_size, hidden_dim)
        self.output_pos_encoding = positionalEncoding(max_positions, hidden_dim, drop_prob)
        self.encoder = Encoder(hidden_dim, num_heads, drop_prob, ffd_dim)
        # Decoder needs an explicit depth. A single block mirrors the
        # encoder, which is built with one block here.
        self.decoder = Decoder(hidden_dim, num_heads, drop_prob, ffd_dim, num_blocks=1)
        self.projection = nn.Linear(hidden_dim, output_vocab_size)

        # Xavier-initialize every weight matrix (dim > 1); biases and
        # LayerNorm vectors keep their defaults.
        for para in self.parameters():
            if para.dim() > 1:
                nn.init.xavier_uniform_(para)

    def forward(self, x, y, x_mask, y_mask):
        """Encode source ``x``, then decode target ``y`` against it.

        Returns:
            Logits over the output vocabulary (the projection output; no softmax applied).
        """
        memory = self.encode(x, x_mask)
        return self.decode(memory, x_mask, y, y_mask)

    def encode(self, x, x_mask):
        """Embed, position-encode and run the encoder over source token ids ``x``."""
        x = self.input_embedding_layer(x)
        x = self.input_position_encoding(x)
        return self.encoder(x, x_mask)

    def decode(self, x, x_mask, y, y_mask):
        """Embed target token ids ``y``, decode against encoder output ``x``, and project to logits."""
        y = self.output_embedding_layer(y)
        y = self.output_pos_encoding(y)
        y = self.decoder(x, y, x_mask, y_mask)
        return self.projection(y)

    # Backward-compatible alias for the original misspelled method name.
    deocde = decode



# In [None]:
# nn.LayerNorm supplement
class LayerNorm_ty(nn.Module):
    """Hand-written equivalent of ``nn.LayerNorm`` over the last dimension.

    Args:
        eps: small constant added to the variance for numerical stability.
        dimension: size of the last (feature) axis; ``gamma``/``beta`` have this shape.
    """

    def __init__(self, eps, dimension):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dimension))   # learnable scale
        self.beta = nn.Parameter(torch.zeros(dimension))   # learnable shift
        self.eps = eps

    def forward(self, x):
        """Normalize each feature vector (last axis), then scale and shift.

        Statistics are taken over dim=-1. Each position's feature vector is
        therefore normalized on its own, independent of other positions and
        of the batch; this is what distinguishes LayerNorm from BatchNorm.
        """
        mean = x.mean(dim=-1, keepdim=True)
        # LayerNorm uses the biased (population) variance: divide by N, not N - 1.
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta