## Seq2seq-RNN
使用pytorch官方的RNN和自制的数据集，实现Seq2seq任务的训练和预测。  
这里是一个很简单的demo，输入长度和输出长度是一样的。  
实现的功能：重复输入序列。

In [2]:
import torch
from torch import nn

In [3]:
from toy_datasets import get_dataset_repeat

vocab_size = 24
seq_len = 8
data_num = 3000
batch_first = True

dataset, collate_fn = get_dataset_repeat(
    data_num=data_num, vocab_size=vocab_size, seq_len=seq_len, batch_first=batch_first
) # 一个加法任务的数据集
train_size = 0.2
train_dataset = dataset[: int(data_num * train_size)]
eval_dataset = dataset[int(data_num * train_size) :]
batch_size = 16
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, collate_fn=collate_fn
)
eval_dataloader = torch.utils.data.DataLoader(
    dataset=eval_dataset, batch_size=batch_size, collate_fn=collate_fn
)

In [4]:
next(train_dataloader.__iter__())

{'src': tensor([[17, 15,  2, 17,  7, 13, 21,  1],
         [22,  6,  4, 22,  2, 16, 16, 16],
         [10,  4,  5,  9, 19, 21, 16, 15],
         [ 6, 18, 14, 10, 18,  6,  1, 11],
         [ 1, 10, 18,  3, 12, 19, 17,  0],
         [ 5, 10,  6, 13, 22,  1,  7, 22],
         [ 0, 15,  7,  1, 17,  8, 14, 22],
         [17, 14, 18,  6, 17,  8,  3,  5],
         [ 0, 15,  6, 13, 20,  2, 14, 13],
         [13, 15, 17, 19,  7, 15, 20, 14],
         [10, 22, 16,  4,  9, 16, 14,  3],
         [10,  7,  9,  2,  5,  9,  6, 14],
         [18, 18, 10, 10, 15, 16, 15,  3],
         [ 5,  0, 15,  3, 10, 15, 21,  4],
         [12, 20,  2,  0,  6,  3,  8,  2],
         [ 2,  2, 17,  6,  2,  7,  6,  8]]),
 'tgt': tensor([[17, 15,  2, 17,  7, 13, 21,  1],
         [22,  6,  4, 22,  2, 16, 16, 16],
         [10,  4,  5,  9, 19, 21, 16, 15],
         [ 6, 18, 14, 10, 18,  6,  1, 11],
         [ 1, 10, 18,  3, 12, 19, 17,  0],
         [ 5, 10,  6, 13, 22,  1,  7, 22],
         [ 0, 15,  7,  1, 17,  8, 14, 

In [5]:
class MyRNN(nn.Module):
    def __init__(self,vocab_size,embed_dim=128,rnn_hidden_dim=128,rnn_num_layers=3):
        super().__init__()
        self.embedding=nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim)
        self.rnn=nn.RNN(input_size=embed_dim,hidden_size=rnn_hidden_dim,num_layers=rnn_num_layers)
        self.head=nn.Linear(rnn_hidden_dim,vocab_size)
    def forward(self,x):
        x=self.embedding(x) # (seq_len,batch_size,hidden_dim)
        rnn_output,hidden_states=self.rnn(x) # (seq_len,batch_size,hidden_dim)
        output=self.head(rnn_output)
        return output # (seq_len,batch_size,max_ids)
    
rnn_hidden_dim=128
rnn_num_layers=3
embed_dim = 64
device='cuda:0'
model=MyRNN(vocab_size=vocab_size,embed_dim=embed_dim,rnn_hidden_dim=rnn_hidden_dim,rnn_num_layers=rnn_num_layers).to(device)
loss_func=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)

In [6]:
def train(model, train_dataloader, loss_func, optimizer, device):
    losses = []
    for batch in train_dataloader:
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        logits = model(src)  # output是一个序列，每个元素是一个长度为vocab_size的logits向量
        logits = logits.view(-1, logits.shape[-1])
        tgt = tgt.view(-1)
        loss = loss_func(logits, tgt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return sum(losses) / (len(losses) * batch_size)


@torch.inference_mode() # 验证的时候关闭梯度计算
def eval(model, eval_dataloader, loss_func, device):
    losses = []
    for batch in eval_dataloader:
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        logits = model(src)
        logits = logits.view(-1, logits.shape[-1])
        tgt = tgt.view(-1)
        loss = loss_func(logits, tgt)
        losses.append(loss.item())
    return sum(losses) / (len(losses) * batch_size)

In [7]:
epochs = 5
for epoch in range(epochs):
    train_loss = train(
        model=model,
        train_dataloader=train_dataloader,
        loss_func=loss_func,
        optimizer=optimizer,
        device=device,
    )
    eval_loss = eval(
        model=model, eval_dataloader=eval_dataloader, loss_func=loss_func, device=device
    )
    print(f"epoch:{epoch}, train_loss:{train_loss}, eval_loss:{eval_loss}")

epoch:0, train_loss:0.11418168428108881, eval_loss:0.0277207300812006
epoch:1, train_loss:0.008777372134653362, eval_loss:0.0024535128567367793
epoch:2, train_loss:0.0016124067570720065, eval_loss:0.001135591024843355
epoch:3, train_loss:0.0008881887203499087, eval_loss:0.0007181299885269254
epoch:4, train_loss:0.0005925360118884495, eval_loss:0.0005046538570119689


In [8]:
@torch.inference_mode()
def predict(sentence, device):
    x_tensor = torch.tensor(sentence).unsqueeze(0).to(device)
    logits = model(x_tensor)
    # print(logits.shape) # (1,seq_len,vocab_size)
    output_id = torch.argmax(logits, dim=-1).squeeze(0)  # 在logits那一维取argmax
    print(f"原序列 {sentence}")
    print(f"生成的序列 {output_id.tolist()}")


test_sentence = [4, 6, 21, 18, 0]
predict(test_sentence, device=device)

原序列 [4, 6, 21, 18, 0]
生成的序列 [4, 6, 21, 18, 0]
