## Seq2seq-LSTM
使用pytorch官方的lstm和自制的数据集，实现Seq2seq任务的训练和预测。  
这里是一个很简单的demo，输入长度和输出长度是一样的。  
实现的功能：将输入序列的每个元素都加1。

In [12]:
import torch
from torch import nn

In [13]:
from toy_datasets import get_dataset_add1

vocab_size = 24
seq_len = 8
data_num = 3000
batch_first = True

dataset, collate_fn = get_dataset_add1(
    data_num=data_num, vocab_size=vocab_size, seq_len=seq_len, batch_first=batch_first
) # 一个加法任务的数据集
train_size = 0.2
train_dataset = dataset[: int(data_num * train_size)]
eval_dataset = dataset[int(data_num * train_size) :]
batch_size = 16
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, collate_fn=collate_fn
)
eval_dataloader = torch.utils.data.DataLoader(
    dataset=eval_dataset, batch_size=batch_size, collate_fn=collate_fn
)

In [None]:
train_dataset[:3]  # 看一下dataset，可以发现tgt是src+1

[{'src': tensor([16, 18,  6,  0, 18, 11, 14,  6]),
  'tgt': tensor([17, 19,  7,  1, 19, 12, 15,  7])},
 {'src': tensor([ 0,  0, 15,  1,  3,  9,  6,  1]),
  'tgt': tensor([ 1,  1, 16,  2,  4, 10,  7,  2])},
 {'src': tensor([17, 12, 20, 21, 10,  0,  4,  5]),
  'tgt': tensor([18, 13, 21, 22, 11,  1,  5,  6])}]

In [15]:
# 看一下dataloader
# dataloader是一个迭代器，不能直接train_dataloader[0]查看
next(train_dataloader.__iter__())

{'src': tensor([[16, 18,  6,  0, 18, 11, 14,  6],
         [ 0,  0, 15,  1,  3,  9,  6,  1],
         [17, 12, 20, 21, 10,  0,  4,  5],
         [17, 10,  9,  6, 16, 16, 14,  5],
         [18,  2,  2,  2,  0, 18,  9, 19],
         [ 1, 22, 18, 15,  0,  9, 14, 19],
         [16,  6,  7, 22,  3, 11, 21,  9],
         [16,  8, 17, 19, 12, 16,  6, 16],
         [ 7,  1, 21,  5,  2,  6, 22, 21],
         [ 1, 14, 18,  9,  6,  6, 11, 17],
         [17, 14,  1, 21,  2, 18,  5,  2],
         [11, 11,  0,  0,  8, 13,  0, 11],
         [12, 17, 15,  7, 21,  5,  1,  9],
         [15, 22, 15,  9,  3, 19, 12,  6],
         [ 8, 10, 20, 11, 16,  5,  0, 15],
         [ 7, 10, 18, 15,  6, 20, 15,  5]]),
 'tgt': tensor([[17, 19,  7,  1, 19, 12, 15,  7],
         [ 1,  1, 16,  2,  4, 10,  7,  2],
         [18, 13, 21, 22, 11,  1,  5,  6],
         [18, 11, 10,  7, 17, 17, 15,  6],
         [19,  3,  3,  3,  1, 19, 10, 20],
         [ 2, 23, 19, 16,  1, 10, 15, 20],
         [17,  7,  8, 23,  4, 12, 22, 

In [None]:
class MyLSTM(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim=128,
        lstm_hidden_dim=128,
        lstm_num_layers=3,
        batch_first=False,
    ):
        super(MyLSTM, self).__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size, embedding_dim=embed_dim
        )  # 嵌入
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=lstm_hidden_dim,
            num_layers=lstm_num_layers,
            batch_first=batch_first,
        )  # 使用torch的LSTM
        self.head = nn.Linear(lstm_hidden_dim, vocab_size)  # 线性分类头，根据输出的logits分类

    def forward(self, x):
        x = self.embedding(x)  # (batch_size,seq_len,hidden_dim)
        lstm_output, (hidden_states, cell) = self.lstm(x)  # (batch_size,seq_len,hidden_dim)
        output = self.head(lstm_output)
        return output  # (batch_size,seq_len,vocab_size)

In [17]:
lstm_hidden_dim = 128
embed_dim = 64
lstm_num_layers = 3
device = "cuda:0"
model = MyLSTM(
    vocab_size=vocab_size,
    embed_dim=embed_dim,
    lstm_hidden_dim=lstm_hidden_dim,
    lstm_num_layers=lstm_num_layers,
    batch_first=batch_first,
).to(device)
loss_func = nn.CrossEntropyLoss()  # 交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
def train(model, train_dataloader, loss_func, optimizer, device):
    losses = []
    model.train()
    for batch in train_dataloader:
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        logits = model(src)  # output是一个序列，每个元素是一个长度为vocab_size的logits向量
        logits = logits.view(-1, logits.shape[-1])
        tgt = tgt.view(-1)
        loss = loss_func(logits, tgt)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return sum(losses) / (len(losses) * batch_size)


@torch.inference_mode() # 验证的时候关闭梯度计算
def eval(model, eval_dataloader, loss_func, device):
    losses = []
    model.eval()
    for batch in eval_dataloader:
        src = batch["src"].to(device)
        tgt = batch["tgt"].to(device)
        logits = model(src)
        logits = logits.view(-1, logits.shape[-1])
        tgt = tgt.view(-1)
        loss = loss_func(logits, tgt)
        losses.append(loss.item())
    return sum(losses) / (len(losses) * batch_size)

In [19]:
epochs = 5
for epoch in range(epochs):
    train_loss = train(
        model=model,
        train_dataloader=train_dataloader,
        loss_func=loss_func,
        optimizer=optimizer,
        device=device,
    )
    eval_loss = eval(
        model=model, eval_dataloader=eval_dataloader, loss_func=loss_func, device=device
    )
    print(f"epoch:{epoch}, train_loss:{train_loss}, eval_loss:{eval_loss}")

epoch:0, train_loss:0.19413185237269653, eval_loss:0.17754725565512974
epoch:1, train_loss:0.0988553432061484, eval_loss:0.028643952620526155
epoch:2, train_loss:0.011248174650398525, eval_loss:0.004105844631170233
epoch:3, train_loss:0.002652483719621638, eval_loss:0.0018517607264220714
epoch:4, train_loss:0.0014275532969469694, eval_loss:0.001156220156699419


In [21]:
@torch.inference_mode()
def predict(sentence, device):
    x_tensor = torch.tensor(sentence).unsqueeze(0).to(device)
    logits = model(x_tensor)
    # print(logits.shape) # (1,seq_len,vocab_size)
    output_id = torch.argmax(logits, dim=-1).squeeze(0)  # 在logits那一维取argmax
    print(f"原序列 {sentence}")
    print(f"生成的序列 {output_id.tolist()}")


test_sentence = [4, 6, 21, 18, 0]
predict(test_sentence, device=device)

原序列 [4, 6, 21, 18, 0]
生成的序列 [5, 7, 22, 19, 1]
