In [10]:
import math
import torch 
import numpy as np
from torch import nn
import pandas as pd
from d2l import torch as d2l

In [3]:
class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network applied along the last input dimension.

    The same ``Linear -> ReLU -> Linear`` stack is applied to every position,
    so only the size of the last dimension changes.

    Args:
        ffn_num_inputs: size of the last dimension of the input.
        ffn_num_hiddens: width of the hidden layer.
        ffn_num_outputs: size of the last dimension of the output.
        **kwargs: forwarded to the ``nn.Module`` constructor.
    """

    def __init__(self, ffn_num_inputs, ffn_num_hiddens, ffn_num_outputs,
                 **kwargs):
        super().__init__(**kwargs)
        # Sub-modules are registered in the order dense1, relu, dense2 so the
        # parameter ordering (and seeded initialisation) stays the same.
        self.dense1 = nn.Linear(in_features=ffn_num_inputs,
                                out_features=ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(in_features=ffn_num_hiddens,
                                out_features=ffn_num_outputs)

    def forward(self, X):
        """Return ``dense2(relu(dense1(X)))``."""
        hidden = self.dense1(X)
        activated = self.relu(hidden)
        return self.dense2(activated)

逐位置前馈网络只改变输入的最后一个维度：`(batch_size, num_steps, ffn_num_inputs)` ==> `(batch_size, num_steps, ffn_num_outputs)`。

In [4]:
# Build a small FFN (4 -> 4 -> 8) and switch it to evaluation mode;
# nn.Module.eval() returns the module itself, so the call can be chained.
ffn = PositionWiseFFN(ffn_num_inputs=4, ffn_num_hiddens=4,
                      ffn_num_outputs=8).eval()
# Feed an all-ones (2, 3, 4) batch and display the first sample: every
# position gets the same output because every position has the same input.
ffn(torch.ones(2, 3, 4))[0]

tensor([[-0.3627, -0.5171, -0.2832, -0.0434, -0.1086,  0.1533,  0.3102, -0.0160],
        [-0.3627, -0.5171, -0.2832, -0.0434, -0.1086,  0.1533,  0.3102, -0.0160],
        [-0.3627, -0.5171, -0.2832, -0.0434, -0.1086,  0.1533,  0.3102, -0.0160]],
       grad_fn=<SelectBackward0>)

In [5]:
# Compare layer normalization (normalizes each row, i.e. across features)
# with batch normalization (in training mode: normalizes each column,
# i.e. across the batch) on the same 2x2 input.
ln, bn = nn.LayerNorm(2), nn.BatchNorm1d(2)

X = torch.tensor([[1., 2.],
                  [2., 3.]], dtype=torch.float32)
(ln(X), bn(X))

(tensor([[-1.0000,  1.0000],
         [-1.0000,  1.0000]], grad_fn=<NativeLayerNormBackward0>),
 tensor([[-1.0000, -1.0000],
         [ 1.0000,  1.0000]], grad_fn=<NativeBatchNormBackward0>))

In [6]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Args:
        normalized_shape: shape passed to ``nn.LayerNorm``.
        dropout: dropout probability applied to the sub-layer output ``Y``.
        **kwargs: forwarded to the ``nn.Module`` constructor.
    """

    def __init__(self, normalized_shape, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(normalized_shape)

    def forward(self, X, Y):
        """Return ``LayerNorm(dropout(Y) + X)``.

        ``X`` is the residual (sub-layer input) and ``Y`` the sub-layer
        output; both must be addable element-wise.
        """
        residual_sum = self.dropout(Y) + X
        return self.ln(residual_sum)

In [7]:
# In evaluation mode dropout is the identity, so the result is
# LayerNorm(ones + zeros): a constant input normalizes to all zeros.
add_norm = AddNorm(normalized_shape=[3, 4], dropout=0.5).eval()
add_norm(torch.ones(2, 3, 4), torch.zeros(2, 3, 4))


tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], grad_fn=<NativeLayerNormBackward0>)

In [13]:
class EncoderBlock(nn.Module):
    """One Transformer encoder block: self-attention and a position-wise FFN,
    each wrapped in an ``AddNorm`` (residual + layer norm).

    Args:
        key_size, query_size, value_size, num_hiddens, num_heads, dropout,
        use_bias: passed through to ``d2l.MultiHeadAttention``.
        norm_shape: shape handed to both ``AddNorm`` layers.
        ffn_num_inputs: input width of the position-wise FFN.
        ffn_num_hiddens: hidden width of the position-wise FFN.
        **kwargs: forwarded to the ``nn.Module`` constructor.

    The FFN output width is ``num_hiddens`` so the residual sum in the second
    ``AddNorm`` is shape-compatible.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_inputs, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        # Registration order: attention, addnorm1, ffn, addnorm2.
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size,
            num_hiddens, num_heads, dropout, use_bias)
        self.addnorm1 = AddNorm(norm_shape, dropout)
        self.ffn = PositionWiseFFN(
            ffn_num_inputs, ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(norm_shape, dropout)

    def forward(self, X, valid_lens):
        """Apply self-attention then the FFN, each with a residual AddNorm.

        ``X`` serves as queries, keys and values; ``valid_lens`` is handed
        unchanged to the attention layer.
        """
        attended = self.attention(X, X, X, valid_lens)
        Y = self.addnorm1(X, attended)
        transformed = self.ffn(Y)
        return self.addnorm2(Y, transformed)

In [18]:
# Sanity check: an encoder block must not change the shape of its input.
X = torch.ones(2, 100, 24)
valid_lens = torch.tensor([3, 2])
encoderblock = EncoderBlock(
    key_size=24, query_size=24, value_size=24, num_hiddens=24,
    norm_shape=[100, 24], ffn_num_inputs=24, ffn_num_hiddens=48,
    num_heads=8, dropout=0.5,
).eval()  # eval() returns the module itself
encoderblock(X, valid_lens).shape

torch.Size([2, 100, 24])

In [22]:
class TransformerEncoder(d2l.Encoder):
    """Transformer encoder: token embedding, positional encoding and a stack
    of ``num_layers`` ``EncoderBlock`` modules.

    Args:
        vocab_size: number of rows of the embedding table.
        num_hiddens: embedding width; also used to scale the embeddings.
        num_layers: number of encoder blocks stacked in ``self.blks``.
        dropout: used by the positional encoding and by every block.
        key_size, query_size, value_size, norm_shape, ffn_num_inputs,
        ffn_num_hiddens, num_heads, use_bias: passed to each ``EncoderBlock``.
        **kwargs: forwarded to the ``d2l.Encoder`` constructor.
    """

    def __init__(self, vocab_size, key_size, query_size, value_size,
                 num_hiddens, norm_shape, ffn_num_inputs, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)

        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
        self.blks = nn.Sequential()
        for layer_idx in range(num_layers):
            block = EncoderBlock(
                key_size, query_size, value_size, num_hiddens, norm_shape,
                ffn_num_inputs, ffn_num_hiddens, num_heads, dropout, use_bias)
            # Blocks are registered as "block0", "block1", ...
            self.blks.add_module(f"block{layer_idx}", block)

    def forward(self, X, valid_lens, *args):
        """Encode token indices ``X`` and return the last block's output.

        As a side effect, ``self.attention_weights`` is replaced by a list
        holding one attention-weight entry per block.
        """
        # Scale the embeddings by sqrt(num_hiddens) before adding the
        # positional encoding.
        scaled_embeddings = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(scaled_embeddings)
        self.attention_weights = [None for _ in range(len(self.blks))]
        for layer_idx, blk in enumerate(self.blks):
            X = blk(X, valid_lens)
            weights = blk.attention.attention.attention_weights
            self.attention_weights[layer_idx] = weights
        return X


In [23]:
# Two-layer encoder over a vocabulary of 200 tokens; the output keeps the
# (batch_size, num_steps) layout and adds a num_hiddens-sized last dimension.
# NOTE: `valid_lens` is reused from the EncoderBlock demo cell above.
encoder = TransformerEncoder(
    vocab_size=200, key_size=24, query_size=24, value_size=24,
    num_hiddens=24, norm_shape=[100, 24], ffn_num_inputs=24,
    ffn_num_hiddens=48, num_heads=8, num_layers=2, dropout=0.5,
).eval()  # eval() returns the module itself
encoder(torch.ones(2, 100, dtype=torch.long), valid_lens).shape

torch.Size([2, 100, 24])

In [None]:
class DecoderBlock(nn.Module):
    """The ``i``-th Transformer decoder block.

    Three sub-layers, each followed by an ``AddNorm`` (residual + layer norm):
    masked self-attention over the decoder inputs, encoder-decoder attention
    over the encoder outputs, and a position-wise feed-forward network.

    Args:
        key_size, query_size, value_size, num_hiddens, num_heads, dropout:
            passed to both ``d2l.MultiHeadAttention`` layers.
        norm_shape: shape handed to all three ``AddNorm`` layers.
        ffn_num_input: input width of the position-wise FFN.
        ffn_num_hiddens: hidden width of the position-wise FFN.
        i: index of this block; selects this block's slot in ``state[2]``.
        **kwargs: forwarded to the ``nn.Module`` constructor.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, i, **kwargs):
        super().__init__(**kwargs)
        self.i = i
        self.attention1 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm1 = d2l.AddNorm(norm_shape, dropout)
        self.attention2 = d2l.MultiHeadAttention(
            key_size, query_size, value_size, num_hiddens, num_heads, dropout)
        self.addnorm2 = d2l.AddNorm(norm_shape, dropout)
        self.ffn = d2l.PositionWiseFFN(ffn_num_input, ffn_num_hiddens,
                                       num_hiddens)
        # Fixed: was the misspelt `addnrom3`, built with the FFN's arguments;
        # AddNorm takes (norm_shape, dropout) like the two layers above.
        self.addnorm3 = d2l.AddNorm(norm_shape, dropout)

    def forward(self, X, state):
        """Run the block on ``X`` and return ``(output, state)``.

        ``state`` is indexed as: ``state[0]`` encoder outputs, ``state[1]``
        encoder valid lengths, ``state[2]`` a per-block list of cached
        decoder inputs. ``state[2][self.i]`` is updated in place.
        """
        enc_outputs, enc_valid_lens = state[0], state[1]
        # With no cache yet, the keys/values of self-attention are just the
        # current inputs (fixed: was None); otherwise append the current
        # inputs to the cached ones along the step axis.
        if state[2][self.i] is None:
            key_values = X
        else:
            # Fixed: the tensors to join must be passed together as a tuple.
            key_values = torch.cat((state[2][self.i], X), dim=1)
        state[2][self.i] = key_values
        if self.training:
            batch_size, num_steps, _ = X.shape
            # Shape (batch_size, num_steps); every row is [1, ..., num_steps],
            # so position t may only attend to the first t positions.
            dec_valid_lens = torch.arange(
                1, num_steps + 1, device=X.device).repeat(batch_size, 1)
        else:
            dec_valid_lens = None

        # Sub-layer 1: masked self-attention over the decoder inputs.
        X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
        Y = self.addnorm1(X, X2)
        # Sub-layer 2: encoder-decoder attention (queries from the decoder,
        # keys/values from the encoder outputs).
        Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
        Z = self.addnorm2(Y, Y2)
        # Sub-layer 3: position-wise feed-forward network.
        return self.addnorm3(Z, self.ffn(Z)), state


