In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import math

# ATTN
1、除以 $\sqrt{d_k}$ 是因为 q 和 k 点积后方差变为 d_k（假设各分量相互独立、均值为 0、方差为 1），缩放后方差恢复为 1，避免 softmax 进入梯度很小的饱和区
2、softmax将scores转换为概率分布

In [2]:
def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Args:
        query: Tensor whose last dimension is the feature size used for scaling.
        key: Tensor whose last two dimensions are swapped and multiplied with
            `query` to produce the attention logits.
        value: Tensor that is multiplied by the attention weights.
        mask: Optional tensor; every position where ``mask == 0`` has its
            logit replaced by -1e9 before the softmax.
        dropout: Optional callable applied to the attention weights.

    Returns:
        A ``(output, weights)`` tuple: the weighted sum over `value` and the
        (possibly dropped-out) softmax weights.
    """
    # Scale by sqrt(d_k) so the logits keep a stable magnitude.
    scale = math.sqrt(query.size(-1))
    logits = torch.matmul(query, key.transpose(-2, -1)) / scale

    if mask is not None:
        # A very negative logit becomes ~0 probability after softmax.
        logits = logits.masked_fill(mask == 0, -1e9)

    weights = F.softmax(logits, dim=-1)
    if dropout is not None:
        weights = dropout(weights)

    context = torch.matmul(weights, value)
    return context, weights

# Multi-Head ATTN
1、在 `transpose`/`permute` 等张量操作之后调用 `.contiguous()` 保持内存连续性，以满足内存访问效率和后续网络层的要求
2、

In [3]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on the module-level `attention` helper.

    The inputs are linearly projected, split into `n_head` heads of size
    `d_model // n_head`, attended per head, merged back and projected by
    `fc_out`.

    Args:
        d_model: Size of the model (embedding) dimension.
        n_head: Number of attention heads; must divide `d_model`.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If `d_model` is not divisible by `n_head`.
    """

    def __init__(self, d_model, n_head, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        if d_model % n_head != 0:
            # Otherwise the per-head reshape in forward() cannot succeed.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})"
            )
        self.d_model = d_model
        self.n_head = n_head
        self.head_dim = d_model // n_head
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)

        self.fc_out = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, bs):
        """Reshape (bs, seq, d_model) into (bs, n_head, seq, head_dim)."""
        return projected.view(bs, -1, self.n_head, self.head_dim).transpose(1, 2)

    def forward(self, x, key=None, value=None, mask=None):
        """Apply multi-head attention.

        Args:
            x: Query input of shape (bs, q_len, d_model). When `key` and
                `value` are omitted this is plain self-attention over `x`.
            key: Optional key input (bs, k_len, d_model); defaults to `x`.
            value: Optional value input (bs, k_len, d_model); defaults to `key`.
            mask: Optional mask where 0 marks positions to hide. A 3-D mask
                (bs, q_len or 1, k_len) gets a head axis inserted so it
                broadcasts over heads; other ranks are passed through as-is.

        Returns:
            Tensor of shape (bs, q_len, d_model).
        """
        if key is None:
            key = x
        if value is None:
            value = key

        bs = x.shape[0]
        q = self._split_heads(self.q(x), bs)
        k = self._split_heads(self.k(key), bs)
        v = self._split_heads(self.v(value), bs)

        if mask is not None and mask.dim() == 3:
            # (bs, L, k_len) -> (bs, 1, L, k_len): broadcast across heads.
            mask = mask.unsqueeze(1)

        # `attention` returns (output, weights); only the output is needed.
        out, _ = attention(q, k, v, mask=mask, dropout=self.dropout)

        # (bs, n_head, q_len, head_dim) -> (bs, q_len, d_model)
        out = out.permute(0, 2, 1, 3).contiguous().flatten(2)

        return self.fc_out(out)

# FFN

In [4]:
class FeedForwardLayer(nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear.

    The hidden width is `d_model * forward_expansion`; the output width is
    `d_model` again.
    """

    def __init__(self, d_model, forward_expansion):
        super().__init__()
        hidden_dim = d_model * forward_expansion
        self.w1 = nn.Linear(d_model, hidden_dim)
        self.w2 = nn.Linear(hidden_dim, d_model)

    def forward(self, x):
        """Expand `x` with `w1`, apply ReLU, and project back with `w2`."""
        hidden = self.w1(x)
        activated = F.relu(hidden)
        return self.w2(activated)

In [5]:
class PositionEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding added to the input.

    A table of shape (1, max_len, d_model) is precomputed and registered as
    the non-trainable buffer `pe`: even feature indices hold sines, odd
    feature indices hold cosines.

    Args:
        d_model: Feature dimension of the input (odd values are supported).
        max_len: Number of positions to precompute; inputs passed to
            `forward` must not be longer than this along dimension 1.
        base: Base of the geometric frequency progression. The default keeps
            this file's historical value of 100000.0.
            NOTE(review): "Attention Is All You Need" uses 10000.0; pass
            ``base=10000.0`` to match the paper.
    """

    def __init__(self, d_model, max_len=1000, base=100000.0):
        super(PositionEmbedding, self).__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(base) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return `x` plus the encoding for its first `x.size(1)` positions.

        `pe` is a buffer, so it carries no gradient and follows the module
        across `.to(...)` calls; no extra wrapping is needed.
        """
        return x + self.pe[:, : x.size(1)]
    
# Build an encoding table for d_model=256 (default max_len) and display the
# registered `pe` buffer as this cell's output.
model = PositionEmbedding(256)
model.pe

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0000e+00,
           0.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  5.4030e-01,  7.9194e-01,  ...,  1.0000e+00,
           1.0941e-05,  1.0000e+00],
         [ 9.0930e-01, -4.1615e-01,  9.6711e-01,  ...,  1.0000e+00,
           2.1882e-05,  1.0000e+00],
         ...,
         [-8.9797e-01, -4.4006e-01,  1.7700e-01,  ...,  9.9993e-01,
           1.0908e-02,  9.9994e-01],
         [-8.5547e-01,  5.1785e-01,  8.8749e-01,  ...,  9.9993e-01,
           1.0919e-02,  9.9994e-01],
         [-2.6461e-02,  9.9965e-01,  9.0684e-01,  ...,  9.9993e-01,
           1.0930e-02,  9.9994e-01]]])

# Encoder

In [6]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a feed-forward sub-layer.

    Each sub-layer is wrapped as ``dropout(LayerNorm(sublayer(x) + x))``.

    Args:
        embed_size: Model (embedding) dimension.
        head: Number of attention heads.
        forward_expansion: Hidden-width multiplier of the feed-forward layer.
        dropout: Dropout probability used for the attention module and for
            the output of both sub-layers.
    """

    def __init__(self, embed_size, head, forward_expansion, dropout):
        super(TransformerBlock, self).__init__()

        # Pass `dropout` through so the attention module honours the
        # configured rate instead of silently using its own default.
        self.attn = MultiHeadAttention(embed_size, head, dropout)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = FeedForwardLayer(embed_size, forward_expansion)
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask=None):
        """Run attention and feed-forward with residual connections.

        Args:
            query: Query input; also the residual branch of the attention
                sub-layer.
            key: Key input handed to the attention module.
            value: Value input handed to the attention module.
            mask: Optional attention mask handed to the attention module.

        Returns:
            Tensor with the same shape as `query`.
        """
        attn_out = self.attn(query, key, value, mask)

        x = self.dropout(self.norm1(attn_out + query))
        ff_out = self.feed_forward(x)
        return self.dropout(self.norm2(ff_out + x))

In [7]:
class Encoder(nn.Module):
    """Stack of `num_layers` TransformerBlock modules applied in sequence.

    Every block is called with the running hidden state as query, key and
    value, together with the supplied mask.
    """

    def __init__(
        self,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout=0.1,
    ):
        super().__init__()

        blocks = []
        for _ in range(num_layers):
            blocks.append(
                TransformerBlock(embed_size, heads, forward_expansion, dropout)
            )
        self.layers = nn.ModuleList(blocks)

    def forward(self, x, mask):
        """Feed `x` through every block and return the final hidden state."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, hidden, hidden, mask)
        return hidden

# Decoder

In [8]:
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention, then cross-attention + FFN.

    The self-attention sub-layer lives in `attn`/`norm1`. Cross-attention
    over the encoder output and the feed-forward sub-layer (each with its
    own residual connection and LayerNorm) are delegated to `transformer`,
    so the two attention stages use separate weights.

    Args:
        embed_size: Model (embedding) dimension.
        heads: Number of attention heads.
        forward_expansion: Hidden-width multiplier of the feed-forward layer.
        dropout: Dropout probability.
    """

    def __init__(self, embed_size, heads, forward_expansion, dropout=0.1):
        super(DecoderBlock, self).__init__()
        self.norm1 = nn.LayerNorm(embed_size)
        self.attn = MultiHeadAttention(embed_size, heads, dropout)
        self.transformer = TransformerBlock(embed_size, heads, forward_expansion, dropout)
        self.dropout = nn.Dropout(dropout)
        # No separate final LayerNorm is kept here: `transformer` already
        # ends with its own LayerNorm, so another one would be redundant.

    def forward(self, x, value, key, src_mask, trg_mask):
        """Decode one layer.

        Args:
            x: Target-side hidden state.
            value: Encoder output used as attention values.
            key: Encoder output used as attention keys.
            src_mask: Mask applied in the cross-attention stage.
            trg_mask: Mask applied in the self-attention stage.

        Returns:
            Tensor with the same shape as `x`.
        """
        # Self-attention over the target sequence.
        self_attn = self.attn(x, x, x, trg_mask)
        query = self.dropout(self.norm1(self_attn + x))
        # Cross-attention + feed-forward; note the (query, key, value) order.
        return self.transformer(query, key, value, src_mask)

In [9]:
class Decoder(nn.Module):
    """Stack of `num_layers` DecoderBlock modules applied in sequence.

    Every block receives the running target-side state plus the encoder
    output (passed twice, as the block's second and third arguments) and the
    two masks.
    """

    def __init__(
        self,
        embed_size,
        num_layers,
        heads,
        forward_expansion,
        dropout=0.1,
    ):
        super().__init__()
        blocks = []
        for _ in range(num_layers):
            blocks.append(
                DecoderBlock(embed_size, heads, forward_expansion, dropout)
            )
        self.layers = nn.ModuleList(blocks)
        # Stored on the module; not applied inside forward().
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_out, src_mask, trg_mask):
        """Feed `x` through every block and return the final hidden state."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_out, encoder_out, src_mask, trg_mask)
        return hidden

# Transformer

In [10]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with learned position embeddings.

    Args:
        src_vocab_size: Size of the source vocabulary.
        trg_vocab_size: Size of the target vocabulary (also the output width).
        src_pad_idx: Source token id that is masked out as padding.
        trg_pad_idx: Target padding token id (stored; not used for masking).
        embed_size: Model (embedding) dimension.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        forward_expansion: Hidden-width multiplier of the feed-forward layers.
        heads: Number of attention heads.
        dropout: Dropout probability.
        max_length: Number of rows in the position-embedding tables; input
            sequences must not be longer than this.
        device: Device on which position indices and masks are created.
    """

    def __init__(
        self,
        src_vocab_size,
        trg_vocab_size,
        src_pad_idx,
        trg_pad_idx,
        embed_size=512,
        num_encoder_layers=6,
        num_decoder_layers=6,
        forward_expansion=4,
        heads=8,
        dropout=0,
        max_length=100,
        device="cpu",
    ):
        super(Transformer, self).__init__()

        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

        self.encoder = Encoder(
            embed_size,
            num_encoder_layers,
            heads,
            forward_expansion,
            dropout,
        )
        self.decoder = Decoder(
            embed_size,
            num_decoder_layers,
            heads,
            forward_expansion,
            dropout,
        )

        self.src_word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.src_position_embedding = nn.Embedding(max_length, embed_size)
        self.trg_word_embedding = nn.Embedding(trg_vocab_size, embed_size)
        self.trg_position_embedding = nn.Embedding(max_length, embed_size)

        self.fc_out = nn.Linear(embed_size, trg_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def make_src_mask(self, src):
        """Return a boolean padding mask of shape (N, 1, src_len).

        Entries are True where `src` differs from `src_pad_idx`.
        """
        src_mask = (src != self.src_pad_idx).unsqueeze(1)
        # (N, 1, src_len)
        return src_mask.to(self.device)

    def make_trg_mask(self, trg):
        """Return a lower-triangular (causal) mask of shape (N, trg_len, trg_len).

        Entry [n, i, j] is 1 when position i may attend to position j
        (j <= i) and 0 otherwise.
        """
        N, trg_len = trg.shape
        causal = torch.tril(torch.ones((trg_len, trg_len), device=self.device))
        trg_mask = causal.expand(N, trg_len, trg_len)
        # Without this return the decoder would receive None and could
        # attend to future target tokens.
        return trg_mask

    def forward(self, src, trg):
        """Compute target-vocabulary logits.

        Args:
            src: Integer tensor of source token ids, shape (N, src_len).
            trg: Integer tensor of target token ids, shape (N, trg_len).

        Returns:
            Tensor of shape (N, trg_len, trg_vocab_size).
        """
        N, src_seq_length = src.shape
        N, trg_seq_length = trg.shape
        src_positions = (
            torch.arange(0, src_seq_length, device=self.device)
            .unsqueeze(0)
            .expand(N, src_seq_length)
        )

        trg_positions = (
            torch.arange(0, trg_seq_length, device=self.device)
            .unsqueeze(0)
            .expand(N, trg_seq_length)
        )

        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        # Encoder pass
        x = self.dropout(
            self.src_word_embedding(src) + self.src_position_embedding(src_positions)
        )
        encoder_out = self.encoder(x, src_mask)
        # Decoder pass
        x = self.dropout(
            self.trg_word_embedding(trg) + self.trg_position_embedding(trg_positions)
        )
        decoder_out = self.decoder(x, encoder_out, src_mask, trg_mask)

        out = self.fc_out(decoder_out)

        return out