In [1]:
import collections
import math
import torch
from torch import nn
from d2l import torch as d2l

In [2]:
class Seq2SeqEncoder(d2l.Encoder):
    """GRU-based encoder for sequence-to-sequence learning.

    Token indices are embedded and then fed through a multi-layer GRU.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        """Build the embedding layer and the recurrent layer.

        Args:
            vocab_size: number of rows in the embedding table.
            embed_size: width of each embedding vector (GRU input size).
            num_hiddens: GRU hidden size.
            num_layers: number of stacked GRU layers.
            dropout: dropout value handed to ``nn.GRU``.
            **kwargs: forwarded to the base-class constructor.
        """
        super().__init__(**kwargs)
        # Embedding layer.
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            input_size=embed_size,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            dropout=dropout,
        )

    def forward(self, X, *args):
        """Encode a batch of token indices.

        Args:
            X: token indices of shape (batch_size, num_steps).
            *args: accepted but not used.

        Returns:
            ``(output, state)`` as produced by the GRU:
            output has shape (num_steps, batch_size, num_hiddens) and
            state has shape (num_layers, batch_size, num_hiddens).
        """
        # After embedding: (batch_size, num_steps, embed_size).
        embedded = self.embedding(X)
        # The recurrent layer expects the time step on the first axis.
        time_major = embedded.permute(1, 0, 2)
        # No initial state is supplied, so the GRU starts from zeros.
        output, state = self.rnn(time_major)
        return output, state

In [4]:
# Shape check: a two-layer encoder on a batch of 4 sequences, 7 steps each.
encoder = Seq2SeqEncoder(
    vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2
)
encoder.eval()
# All-zero token indices are enough to inspect the output shapes.
X = torch.zeros(4, 7, dtype=torch.long)
output, state = encoder(X)
(output.shape, state.shape)

(torch.Size([7, 4, 16]), torch.Size([2, 4, 16]))

In [9]:
class Seq2SeqDecoder(d2l.Decoder):
    """GRU-based decoder for sequence-to-sequence learning.

    At every time step the embedded input token is concatenated with a
    context vector taken from the last layer of the incoming hidden state.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        """Build the embedding, recurrent and output layers.

        Args:
            vocab_size: embedding table size and width of the output layer.
            embed_size: width of each embedding vector.
            num_hiddens: GRU hidden size.
            num_layers: number of stacked GRU layers.
            dropout: dropout value handed to ``nn.GRU``.
            **kwargs: forwarded to the base-class constructor.
        """
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # The GRU input is the embedding concatenated with the context,
        # hence embed_size + num_hiddens features per step.
        self.rnn = nn.GRU(
            input_size=embed_size + num_hiddens,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        """Return the second element of ``enc_outputs`` as the initial state."""
        return enc_outputs[1]

    def forward(self, X, state):
        """Decode a batch of token indices given a hidden state.

        Args:
            X: token indices of shape (batch_size, num_steps).
            state: hidden state; its last entry along the first axis is
                used as the context vector.

        Returns:
            ``(output, state)`` where output has shape
            (batch_size, num_steps, vocab_size) and state is the hidden
            state returned by the GRU.
        """
        # Embed and move the time step to the first axis.
        embedded = self.embedding(X).permute(1, 0, 2)
        num_steps = embedded.shape[0]
        # Repeat the last layer's state once per time step.
        context = state[-1].repeat(num_steps, 1, 1)
        rnn_input = torch.cat((embedded, context), dim=2)
        hidden_seq, state = self.rnn(rnn_input, state)
        # Project to vocabulary scores and restore batch-first layout.
        output = self.dense(hidden_seq).permute(1, 0, 2)
        return output, state

In [10]:
# Shape check: decode the same all-zero batch X used for the encoder,
# starting from the state produced by the encoder.
decoder = Seq2SeqDecoder(
    vocab_size=10, embed_size=8, num_hiddens=16, num_layers=2
)
decoder.eval()
state = decoder.init_state(encoder(X))
output, state = decoder(X, state)
(output.shape, state.shape)

(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))

In [11]:
def sequence_mask(X, valid_len, value=0):
    """Overwrite the entries of each sequence that lie past its valid length.

    Note that ``X`` is modified in place and the same tensor is returned.

    Args:
        X: tensor whose axis 1 is the sequence axis.
        valid_len: tensor of valid lengths, one per row of ``X``.
        value: fill value written to the masked positions.

    Returns:
        ``X`` itself, with positions at index >= valid length set to ``value``.
    """
    num_steps = X.shape[1]
    positions = torch.arange(num_steps, dtype=torch.float32, device=X.device)
    # Broadcast (1, num_steps) against (batch, 1) to get a (batch, num_steps)
    # boolean grid that is True where the position is still valid.
    keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)
    X[torch.logical_not(keep)] = value
    return X

# 2-D example: keep 1 item in the first row and 2 items in the second row.
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
sequence_mask(X, valid_len=torch.tensor([1, 2]))

tensor([[1, 0, 0],
        [4, 5, 0]])

In [12]:
# 3-D example: whole trailing vectors past the valid length are set to -1.
X = torch.ones((2, 3, 4))
sequence_mask(X, valid_len=torch.tensor([1, 2]), value=-1)

tensor([[[ 1.,  1.,  1.,  1.],
         [-1., -1., -1., -1.],
         [-1., -1., -1., -1.]],

        [[ 1.,  1.,  1.,  1.],
         [ 1.,  1.,  1.,  1.],
         [-1., -1., -1., -1.]]])