## Seq2seq
使用pytorch官方的LSTM和自制的数据集，实现Seq2seq任务的训练和预测。  
在这个demo里，输入长度和输出长度不一定一样。

In [38]:
import torch
from torch import nn
import random

In [None]:
# 建立Dataloader

from toy_datasets import get_dataset_AddSeq

# 这里只考虑batch_first=True的情况
data_num = 5000
dataset, collate_fn,(index_bos,index_eos,index_pad,index_add,vocab_size) = get_dataset_AddSeq(
    data_num=data_num
) # 一个带符号的个位数加法任务的数据集

train_size = 0.2
train_dataset = dataset[: int(data_num * train_size)]
eval_dataset = dataset[int(data_num * train_size) :]
batch_size = 16
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, collate_fn=collate_fn
)
eval_dataloader = torch.utils.data.DataLoader(
    dataset=eval_dataset, batch_size=batch_size, collate_fn=collate_fn
)

In [52]:
next(train_dataloader.__iter__())

{'src': tensor([[12,  9, 10,  2, 13],
         [12,  0, 10,  4, 13],
         [12,  2, 10,  3, 13],
         [12,  1, 10,  0, 13],
         [12,  9, 10,  6, 13],
         [12,  6, 10,  3, 13],
         [12,  0, 10,  1, 13],
         [12,  1, 10,  2, 13],
         [12,  6, 10,  2, 13],
         [12,  2, 10,  7, 13],
         [12,  9, 10,  1, 13],
         [12,  1, 10,  0, 13],
         [12,  0, 10,  6, 13],
         [12,  0, 10,  5, 13],
         [12,  8, 10,  7, 13],
         [12,  9, 10,  9, 13]]),
 'tgt': tensor([[12,  1,  1, 13],
         [12,  4, 13, 11],
         [12,  5, 13, 11],
         [12,  1, 13, 11],
         [12,  1,  5, 13],
         [12,  9, 13, 11],
         [12,  1, 13, 11],
         [12,  3, 13, 11],
         [12,  8, 13, 11],
         [12,  9, 13, 11],
         [12,  1,  0, 13],
         [12,  1, 13, 11],
         [12,  6, 13, 11],
         [12,  5, 13, 11],
         [12,  1,  5, 13],
         [12,  1,  8, 13]])}

In [None]:
class Encoder(nn.Module):
    def __init__(self,vocab_size,embed_dim,hidden_dim,num_layers):
        super().__init__()
        self.embedding=nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim)
        self.lstm=nn.LSTM(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers,batch_first=True)
    def forward(self,src):
        x=self.embedding(src)
        lstm_output,(hidden_states,cell)=self.lstm(x)
        return hidden_states,cell
    
class Decoder(nn.Module):
    def __init__(self,vocab_size,embed_dim,hidden_dim,num_layers):
        super().__init__()
        self.lstm=nn.LSTM(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers,batch_first=True)
        self.embedding=nn.Embedding(num_embeddings=vocab_size,embedding_dim=embed_dim)
        self.head=nn.Linear(hidden_dim,vocab_size) # 分类头
    def forward(self,tgt,hidden_states,cell):
      
        x=self.embedding(tgt)
        lstm_output,(hidden_states,cell)=self.lstm(x,(hidden_states,cell)) 
        output=self.head(lstm_output)
        return output,(hidden_states,cell) 
    
class Seq2Seq(nn.Module):
    def __init__(self,vocab_size,embed_dim,hidden_dim,num_layers,teacher_ratio):
        super().__init__()
        self.teacher_ratio=teacher_ratio # teacher forcing的概率
        self.encoder=Encoder(vocab_size=vocab_size,embed_dim=embed_dim,hidden_dim=hidden_dim,num_layers=num_layers)
        self.decoder=Decoder(vocab_size=vocab_size,embed_dim=embed_dim,hidden_dim=hidden_dim,num_layers=num_layers)
    def forward(self,src,tgt):
        batch_size,tgt_seq_len=tgt.shape
        hidden_states,cell=self.encoder(src) # encoder先对输入进行编码，得到hidden_states和cell
        last_output_ids=tgt[:,0].unsqueeze(1) # batch中，每个序列的第一个token作为起始输入
        # (batch_size,1)
        output=[]
        for i in range(1,tgt_seq_len):
            # 输入保存着的hidden_states、cell和这次的输入，预测下一个token的logits
            logits,(hidden_states,cell)=self.decoder(last_output_ids,hidden_states,cell) 
            # logits: (batch_size,1,vocab_size)
            output.append(logits.squeeze(1))
            if random.random() < self.teacher_ratio:
                last_output_ids=tgt[:,i].unsqueeze(1) # teacher forcing，取真实的target
            else :
                last_output_ids=torch.argmax(logits,dim=-1) # 取预测的target
        return torch.stack(output,dim=1)
    
    @torch.inference_mode()
    def predict(self,src,tgt,max_seq_len=10,index_eos=13):
        batch_size,tgt_seq_len=tgt.shape
        hidden_states,cell=self.encoder(src)
        last_output_ids=tgt[:,0].unsqueeze(1) # (batch_size,1)
        output=[]
        for _ in range(tgt_seq_len,max_seq_len):
            logits,(hidden_states,cell)=self.decoder(last_output_ids,hidden_states,cell)
            last_output_ids=torch.argmax(logits,dim=-1)
            output.append(last_output_ids.item())
            if last_output_ids.item()==index_eos: # 预测到了eos，就停止预测
                break
        return output

In [None]:
from tqdm import tqdm

    
hidden_dim=128
embed_dim=64
num_layers=3
device='cuda:0'
teacher_ratio=0.5
model=Seq2Seq(vocab_size=vocab_size,embed_dim=embed_dim,hidden_dim=hidden_dim,num_layers=num_layers,teacher_ratio=teacher_ratio).to(device)
loss_func=nn.CrossEntropyLoss(ignore_index=index_pad) # 在训练时，忽略pad的loss
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)

In [43]:
def train(model, train_dataloader, loss_func, optimizer, device):
    losses = []
    for batch in train_dataloader:
        src=batch['src'].to(device)
        tgt=batch['tgt'].to(device)
        tgt_output=model(src,tgt)
        tgt=tgt[:,1:] # 由于预测的是下一个token，所以需要去掉第一个token
        tgt_output=tgt_output.view(-1,tgt_output.shape[-1])
        tgt=tgt.reshape(-1)
        loss=loss_func(tgt_output,tgt)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad() 
        losses.append(loss.item())
    return sum(losses) / (len(losses) * batch_size)


@torch.inference_mode() # 验证的时候关闭梯度计算
def eval(model, eval_dataloader, loss_func, device):
    losses = []
    for batch in eval_dataloader:
        src=batch['src'].to(device)
        tgt=batch['tgt'].to(device)
        tgt_output=model(src,tgt)
        tgt=tgt[:,1:]
        tgt_output=tgt_output.view(-1,tgt_output.shape[-1])
        tgt=tgt.reshape(-1)
        loss=loss_func(tgt_output,tgt)
        losses.append(loss.item())
    return sum(losses) / (len(losses) * batch_size)

In [44]:
epochs = 10
for epoch in range(epochs):
    train_loss = train(
        model=model,
        train_dataloader=train_dataloader,
        loss_func=loss_func,
        optimizer=optimizer,
        device=device,
    )
    eval_loss = eval(
        model=model, eval_dataloader=eval_dataloader, loss_func=loss_func, device=device
    )
    print(f"epoch:{epoch}, train_loss:{train_loss}, eval_loss:{eval_loss}")

epoch:0, train_loss:0.10153274072541131, eval_loss:0.07112701869010925
epoch:1, train_loss:0.06079302958789326, eval_loss:0.05005219842493534
epoch:2, train_loss:0.04143774811001051, eval_loss:0.03180935980379582
epoch:3, train_loss:0.03020581407915978, eval_loss:0.025679446585476398
epoch:4, train_loss:0.01945081050138152, eval_loss:0.0145473104827106


In [None]:
# 个位数加法
a=9
b=5
def add(a,b):
    input_list=[index_bos]+[a,index_add,b]+[index_eos]
    tgt_list=[index_bos]
    src=torch.tensor(input_list).to(device).unsqueeze(0)
    tgt=torch.tensor(tgt_list).to(device).unsqueeze(0)
    output=model.predict(src,tgt,max_seq_len=10)
    print(f'输入序列 {src}')
    print(f'输出序列 {tgt_list+output}')
    print(f'{a} + {b}={output[:-1]}')
add(a,b)

输入序列 tensor([[12,  9, 10,  5, 13]], device='cuda:0')
输出序列 [12, 1, 4, 13]
9 + 5=[1, 4]
