In [76]:
import torch
import torch.nn as nn
import dltools

In [58]:
class Encoder(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def forward(self,X,*args):
        return NotImplementedError
class Decoder(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    def __init_state__(self,eng_output,*args):
        return NotImplementedError
    def forward(self,X,*args):
        return NotImplementedError



In [59]:
class Seq2SeqEncoder(Encoder):
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn = nn.GRU(embed_size,num_hiddens,num_layers,dropout=dropout if num_layers > 1 else 0.0)
    def forward(self,X,*args):
        X = self.embedding(X)
        X = X.permute(1,0,2)
        enc_output,enc_state = self.rnn(X)
        return enc_output,enc_state




In [64]:
class Seq2SeqDecoder(Decoder):
    def __init__(self,vocab_size,embed_size,num_hiddens,num_layer,dropout=0.0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.embedding = nn.Embedding(vocab_size,embed_size)
        self.rnn = nn.GRU(embed_size + num_hiddens,num_hiddens,num_layer,dropout=dropout if num_layer >1 else 0.0)
        self.dense = nn.Linear(num_hiddens,vocab_size)
    def init_state(self, eng_output, *args):
        return eng_output[1]
    def forward(self, X,state):
        X = self.embedding(X).permute(1,0,2)
      
        context = state[-1].repeat(X.shape[0],1,1)
        X_and_context = torch.cat([X,context],dim=2)
        output,new_state = self.rnn(X_and_context)
        logits = self.dense(output).permute(1,0,2)
        return logits,new_state
    


In [81]:
class EncoderDecoder(nn.Module):
    def __init__(self,encoder,decoder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def forward(self,enc_X,dec_X,*args):
        X = self.encoder(enc_X)
        state = self.decoder.init_state(X)
        print( state[-1].shape)
        return self.decoder(dec_X,state)

        

In [67]:
# ==== 超参数 ====
src_vocab = 200
tgt_vocab = 150
embed_size = 8
hidden_size = 16
num_layer = 2

# ==== 模块实例 ====
encoder = Seq2SeqEncoder(
    vocab_size=src_vocab,
    embed_size=embed_size,
    num_hiddens=hidden_size,
    num_layers=num_layer
)
decoder = Seq2SeqDecoder(
    vocab_size=tgt_vocab,
    embed_size=embed_size,
    num_hiddens=hidden_size,
    num_layer=num_layer  # 注意这里叫 num_layer（单数），值与 encoder 对齐
)
net = EncoderDecoder(encoder, decoder)

# ==== 伪造一批数据 ====
batch_size = 4
src_len = 7
tgt_len = 6
enc_X = torch.zeros((batch_size, src_len), dtype=torch.long)  # 假装是源句
# 训练时 dec_X 通常是 <bos> + gold[:-1]；这里只做连通性测试
dec_X = torch.zeros((batch_size, tgt_len), dtype=torch.long)

# ==== 前向 ====
logits, dec_state = net(enc_X, dec_X)

print("encoder → (time_outputs, final_state) shapes:")
enc_time_out, enc_final_state = encoder(enc_X)
print("  time_outputs:", enc_time_out.shape)      # [S, B, H] = [7, 4, 16]
print("  final_state:", enc_final_state.shape)    # [L, B, H] = [2, 4, 16]

print("\ndecoder/net outputs:")
print("  logits:", logits.shape)                  # [B, T, V] = [4, 6, 150]
print("  dec_state:", dec_state.shape)            # [L, B, H] = [2, 4, 16]

# 可选断言（自检失败会抛错）
assert enc_time_out.shape == (src_len, batch_size, hidden_size)
assert enc_final_state.shape == (num_layer, batch_size, hidden_size)
assert logits.shape == (batch_size, tgt_len, tgt_vocab)
assert dec_state.shape == (num_layer, batch_size, hidden_size)


torch.Size([4, 16])
encoder → (time_outputs, final_state) shapes:
  time_outputs: torch.Size([7, 4, 16])
  final_state: torch.Size([2, 4, 16])

decoder/net outputs:
  logits: torch.Size([4, 6, 150])
  dec_state: torch.Size([2, 4, 16])


In [None]:
def sequence_mask(X,valid_len,value = 0):
    max_len = X.size(1)
    mask = torch.arange(max_len,device=X.device)[None] < valid_len[:,None]
    X[~mask] =value
    return X


tensor([[1, 2, 3, 0, 0],
        [6, 7, 0, 0, 0]])


In [73]:
class MaskedSeqMaxSEloss(nn.CrossEntropyLoss):
    def forward(self,pred,label,valid_len):
        weight = torch.ones_like(label)
        weight = sequence_mask(weight,valid_len)
        self.reduction ='none'
        unweighted_loss = super().forward(pred.permute(0,2,1),label)
        weighted_loss = (unweighted_loss * weight).mean(dim=1)
        return weighted_loss
        

In [None]:
def train_seq2seq(net,train_iter,lr,num_epochs,device,tgt_vocab):
    def xavier_init_weight(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if 'weight' in param:
                    nn.init.xavier_uniform_(m._parameters[param])
    net.apply(xavier_init_weight)
    net = net = net.to(device)
    optimizer = torch.optim.Adam(net.parameters(),lr=lr)
    loss = MaskedSeqMaxSEloss()
    animater = dltools.Animator(xlabel='epoch',ylabel='loss',xlim=[10,num_epochs])
    for batch in train_iter:
        timer = dltools.Timer()
        metric = dltools.Accumulator(2)
        optimizer.zero_grad()
        X,X_valid_len,Y,Y_valid_len = [x.to(device) for x in batch]
        bos = torch.tensor([tgt_vocab['<bos>']]*Y.shape[0],device=device).reshape(-1,1)
        dec_input = torch.cat([bos,Y[:,:-1]],1)
        Y_pre,_=net(X,dec_input,X_valid_len)
        l = loss(Y_pre,Y,Y_valid_len)
        l.sum().backward()
        optimizer.step()
        with torch.no_grad():
            metric.add(l.sum(),num_token)
    if (epoch +1 )% 10 ==0:
        animator.add(epoch+1,(metric[0]/metric[1]))
    print(f'loss {metric[0]/metric[1]:.3f},{metric[1]/timer.stop():.1f}',f'tokens/sec on {str(device)}')


In [84]:
embed_size,num_hiddens,num_layers,dropout = 32,32,2,0.1
batch_size,num_steps = 64,10
lr,num_epoches,device = 0.005,100,dltools.try_gpu()
train_iter,src_vocab,tgt_vocab = dltools.load_data_nmt(batch_size,num_steps)
encoder = Seq2SeqEncoder(len(src_vocab),embed_size,num_hiddens,num_layers,dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab),embed_size,num_hiddens,num_layers,dropout)
net = EncoderDecoder(encoder,decoder)
train_seq2seq(net,train_iter,lr,num_epoches,device,tgt_vocab)

torch.Size([64, 32])
tensor([2.1040, 2.6693, 2.1642, 2.1670, 3.2551, 2.6721, 2.1586, 1.6176, 3.2182,
        2.1417, 2.1398, 1.6139, 2.7057, 2.1426, 2.6892, 2.6750, 1.5876, 2.7035,
        1.6113, 2.1604, 2.6914, 1.6142, 2.1656, 2.1364, 1.6152, 2.7010, 1.6057,
        2.7059, 2.6964, 2.6809, 2.1419, 1.5995, 1.6197, 1.6010, 2.1467, 2.1670,
        1.6158, 2.1178, 2.1450, 1.6075, 2.7082, 1.5907, 2.1250, 2.1173, 1.6053,
        2.1488, 2.6436, 2.6693, 3.2181, 2.6771, 3.2543, 2.6770, 2.1488, 2.1358,
        2.1453, 2.6741, 1.5991, 1.5995, 2.1191, 3.1952, 2.1414, 3.7958, 2.6956,
        3.2589], device='cuda:0', grad_fn=<MeanBackward1>)
torch.Size([64, 32])
tensor([1.5436, 2.6159, 1.5567, 1.5755, 1.5431, 2.0937, 2.6548, 2.0854, 2.6092,
        2.6196, 3.2108, 2.0641, 2.6331, 2.6521, 1.5804, 2.6183, 3.0936, 2.5996,
        1.5469, 2.6280, 2.6460, 2.1269, 2.0716, 1.5671, 3.6931, 2.0728, 1.5732,
        1.5508, 2.6365, 1.5600, 2.6254, 1.5505, 2.5971, 1.6007, 2.0861, 2.6233,
        2.0794, 1.5