In [1]:
import torch
from torch import nn 
from d2l import torch as d2l

class Seq2SeqEncoder(d2l.Encoder):
    """GRU-based encoder for sequence-to-sequence learning.

    Token indices are looked up in an embedding table and the resulting
    vectors are run through a (possibly multi-layer) GRU.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        """Create the embedding table and the recurrent layers.

        Args:
            vocab_size: number of entries in the embedding table.
            embed_size: width of each embedding vector (the GRU input size).
            num_hiddens: hidden size of the GRU.
            num_layers: number of stacked GRU layers.
            dropout: dropout value handed to ``nn.GRU``.
            **kwargs: passed through unchanged to the base-class constructor.
        """
        super().__init__(**kwargs)
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(
            input_size=embed_size,
            hidden_size=num_hiddens,
            num_layers=num_layers,
            dropout=dropout,
        )

    def forward(self, X, *args):
        """Encode a batch of token indices.

        Args:
            X: integer index tensor; the demo in this notebook passes a
                (batch_size, num_steps) tensor.
            *args: accepted but not used.

        Returns:
            The ``(output, state)`` pair produced by the GRU. For the
            (4, 7) demo input with 2 layers and 16 hidden units these have
            shapes (7, 4, 16) and (2, 4, 16) respectively.
        """
        embedded = self.embedding(X)
        # The GRU is built with its default layout (batch_first=False), so
        # swap the first two axes to put the time axis in front.
        time_major = embedded.permute(1, 0, 2)
        outputs, hidden = self.rnn(time_major)
        return outputs, hidden

# Smoke-test the encoder: a batch of 4 sequences, each 7 token indices long,
# goes through a 2-layer GRU with 16 hidden units.
encoder = Seq2SeqEncoder(
    vocab_size=10,
    embed_size=8,
    num_hiddens=16,
    num_layers=2,
)
encoder.eval()  # evaluation mode
X = torch.zeros((4, 7), dtype=torch.long)
output, state = encoder(X)
output.shape

torch.Size([7, 4, 16])

In [2]:
# Shape of the hidden state returned by the encoder call: one slice per GRU
# layer (num_layers=2), for a batch of 4 sequences and 16 hidden units.
state.shape

torch.Size([2, 4, 16])

In [3]:
class Seq2SeqDecoder(d2l.Decoder):
    """GRU-based decoder for sequence-to-sequence learning.

    At every time step the GRU input is the token embedding concatenated
    with a context vector taken from the last layer of the incoming hidden
    state; a linear layer maps the GRU outputs to vocabulary-sized scores.
    """

    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,
                 dropout=0, **kwargs):
        """Create the embedding table, the recurrent layers and the output layer.

        Args:
            vocab_size: number of embedding rows and of output scores.
            embed_size: width of each embedding vector.
            num_hiddens: hidden size of the GRU.
            num_layers: number of stacked GRU layers.
            dropout: dropout value handed to ``nn.GRU``.
            **kwargs: passed through unchanged to the base-class constructor.
        """
        super().__init__(**kwargs)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        # Embedding and context vector are concatenated along the feature
        # axis, so the GRU input is embed_size + num_hiddens wide.
        rnn_input_size = embed_size + num_hiddens
        self.rnn = nn.GRU(rnn_input_size, num_hiddens, num_layers,
                          dropout=dropout)
        self.dense = nn.Linear(num_hiddens, vocab_size)

    def init_state(self, enc_outputs, *args):
        """Return the element at index 1 of ``enc_outputs``.

        With the encoder defined in this notebook, that element is the
        encoder's hidden state. ``*args`` is accepted but not used.
        """
        return enc_outputs[1]

    def forward(self, X, state):
        """Decode a batch of token indices starting from ``state``.

        Args:
            X: integer index tensor; the demo in this notebook passes a
                (batch_size, num_steps) tensor.
            state: GRU hidden state; its last slice ``state[-1]`` is used
                as the context vector for every time step.

        Returns:
            ``(output, state)``: scores with the batch axis first, which is
            (4, 7, 10) in the demo, and the updated GRU hidden state.
        """
        embedded = self.embedding(X).permute(1, 0, 2)
        num_steps = embedded.shape[0]
        # Copy the last layer's hidden state once per time step so it can be
        # concatenated with every embedded token.
        context = state[-1].repeat(num_steps, 1, 1)
        rnn_input = torch.cat((embedded, context), dim=2)
        hidden_seq, state = self.rnn(rnn_input, state)
        # Project to vocabulary size, then move the batch axis back in front.
        scores = self.dense(hidden_seq).permute(1, 0, 2)
        return scores, state

In [5]:
# Smoke-test the decoder with the same hyperparameters as the encoder.
decoder = Seq2SeqDecoder(
    vocab_size=10,
    embed_size=8,
    num_hiddens=16,
    num_layers=2,
)
decoder.eval()  # evaluation mode
# Start decoding from the hidden state produced by encoding X.
state = decoder.init_state(encoder(X))
output, state = decoder(X, state)
(output.shape, state.shape)

(torch.Size([4, 7, 10]), torch.Size([2, 4, 16]))

In [8]:
def sequence_mask(X, valid_len, value=0):
    """Overwrite the entries of each row that lie past its valid length.

    Args:
        X: tensor whose first axis indexes sequences and whose second axis
            indexes positions. It is modified in place.
        valid_len: tensor with one length per sequence; position ``j`` of
            row ``i`` is kept when ``j < valid_len[i]``.
        value: what to write into the positions that are not kept.

    Returns:
        The same tensor object ``X``, after masking.
    """
    num_positions = X.size(1)
    positions = torch.arange(num_positions, dtype=torch.float32,
                             device=X.device)
    # Broadcast a (1, num_positions) row of indices against a column of
    # lengths to get a boolean "keep" matrix.
    keep = positions.unsqueeze(0) < valid_len.unsqueeze(1)
    X[~keep] = value
    return X

In [9]:
# Keep the first entry of row 0 and the first two entries of row 1; every
# other entry is overwritten with the default value 0 (X is modified in place).
X = torch.tensor([[1, 2, 3], [4, 5, 6]])
sequence_mask(X, torch.tensor([1, 2]))

tensor([[1, 0, 0],
        [4, 5, 0]])

In [12]:
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """Softmax cross-entropy loss that zeroes out positions past each
    sequence's valid length."""

    def forward(self, pred, label, valid_len):
        """Compute one masked loss value per sequence.

        Args:
            pred: scores; the demo in this notebook passes a
                (batch_size, num_steps, vocab_size) tensor.
            label: target indices; (batch_size, num_steps) in the demo.
            valid_len: one valid length per sequence.

        Returns:
            The per-position losses multiplied by a 0/1 mask and averaged
            over axis 1. Masked positions still count in the denominator of
            that average.

        Note:
            Sets ``self.reduction`` to ``'none'`` on every call.
        """
        # 1 at valid positions, 0 elsewhere.
        token_weights = sequence_mask(torch.ones_like(label), valid_len)
        self.reduction = 'none'
        # The parent loss expects the class axis in position 1.
        per_token_loss = super().forward(pred.permute(0, 2, 1), label)
        # Keep only the valid positions; invalid positions are all set to 0.
        return (per_token_loss * token_weights).mean(dim=1)
        

In [13]:
# Three sequences of 4 steps over 10 classes with identical scores for every
# class, using valid lengths 4, 2 and 0.
loss = MaskedSoftmaxCELoss()
loss(
    torch.ones(3, 4, 10),
    torch.ones((3, 4), dtype=torch.long),
    torch.tensor([4, 2, 0]),
)

tensor([2.3026, 1.1513, 0.0000])