In [1]:
import torch
from torch import nn
import random

In [2]:
# 建立Dataloader
max_ids=24
embed_dim=64
index_pad=0
index_bos=1
index_eos=2

def make_dataset(data_num):
    dataset=[]
    def myiter(start,max):
        while start <=max:
            yield start
            start+=2
    
    for _ in range(data_num):
        x_start=random.randint(3,6)
        x_stop=x_start+random.randint(1,3)*2
        x_ids=torch.tensor([index_bos]+list(range(x_start,x_stop,2))+[index_eos])
        y_start=x_stop
        y_ids=torch.tensor([index_bos]+[i for i in myiter(y_start,max_ids-1)]+[index_eos])
        dataset.append(
            {
                'x_ids':x_ids,
                'y_ids':y_ids
            }
        )
    def collate_fn(batch):
        x_ids=[b['x_ids'] for b in batch]
        y_ids=[b['y_ids'] for b in batch]
        x_ids=nn.utils.rnn.pad_sequence(x_ids,padding_value=index_pad)
        y_ids=nn.utils.rnn.pad_sequence(y_ids,padding_value=index_pad) # (seq_len,batch_size)
        
        batch={
                'x_ids':x_ids,
                'y_ids':y_ids
            }
        return batch
    return dataset,collate_fn
train_dataset,collate_fn=make_dataset(data_num=10000) # 构造一个y=x+1的序列
batch_size=16
train_dataloader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,collate_fn=collate_fn)

In [3]:
next(train_dataloader.__iter__())

{'x_ids': tensor([[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
         [ 4,  5,  5,  5,  6,  5,  5,  5,  6,  4,  3,  6,  6,  6,  5,  4],
         [ 6,  2,  7,  7,  8,  7,  7,  7,  8,  6,  2,  2,  8,  2,  7,  2],
         [ 8,  0,  9,  2, 10,  2,  2,  2,  2,  8,  0,  0, 10,  0,  9,  0],
         [ 2,  0,  2,  0,  2,  0,  0,  0,  0,  2,  0,  0,  2,  0,  2,  0]]),
 'y_ids': tensor([[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
         [10,  7, 11,  9, 12,  9,  9,  9, 10, 10,  5,  8, 12,  8, 11,  6],
         [12,  9, 13, 11, 14, 11, 11, 11, 12, 12,  7, 10, 14, 10, 13,  8],
         [14, 11, 15, 13, 16, 13, 13, 13, 14, 14,  9, 12, 16, 12, 15, 10],
         [16, 13, 17, 15, 18, 15, 15, 15, 16, 16, 11, 14, 18, 14, 17, 12],
         [18, 15, 19, 17, 20, 17, 17, 17, 18, 18, 13, 16, 20, 16, 19, 14],
         [20, 17, 21, 19, 22, 19, 19, 19, 20, 20, 15, 18, 22, 18, 21, 16],
         [22, 19, 23, 21,  2, 21, 21, 21, 22, 22, 17, 20,  2, 20, 23, 18],
     

In [4]:
class Encoder(nn.Module):
    def __init__(self,max_ids,embed_dim,hidden_dim,num_layers):
        super().__init__()
        self.embedding=nn.Embedding(num_embeddings=max_ids,embedding_dim=embed_dim)
        self.lstm=nn.LSTM(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers)
    def forward(self,input_ids):
        x=self.embedding(input_ids) # (seq_len,batch_size,hidden_dim)
        lstm_output,(hidden_states,cell)=self.lstm(x) # (seq_len,batch_size,hidden_dim)
        return hidden_states,cell # (num_layers,batch_size,hidden_dim)
    
class Decoder(nn.Module):
    def __init__(self,max_ids,embed_dim,hidden_dim,num_layers):
        super().__init__()
        self.lstm=nn.LSTM(input_size=embed_dim,hidden_size=hidden_dim,num_layers=num_layers)
        self.embedding=nn.Embedding(num_embeddings=max_ids,embedding_dim=embed_dim)
        self.head=nn.Linear(hidden_dim,max_ids)
    def forward(self,input_ids,hidden_states,cell):
        # input_ids:(seq_len,batch_size)
        x=self.embedding(input_ids)
        lstm_output,(hidden_states,cell)=self.lstm(x,(hidden_states,cell)) # (seq_len,batch_size,hidden_dim)
        output=self.head(lstm_output) #(seq_len,batch_size,max_ids)
        return output,(hidden_states,cell) 
    
class Seq2Seq(nn.Module):
    def __init__(self,max_ids,embed_dim,hidden_dim,num_layers,teacher_ratio):
        super().__init__()
        self.teacher_ratio=teacher_ratio
        self.encoder=Encoder(max_ids=max_ids,embed_dim=embed_dim,hidden_dim=hidden_dim,num_layers=num_layers)
        self.decoder=Decoder(max_ids=max_ids,embed_dim=embed_dim,hidden_dim=hidden_dim,num_layers=num_layers)
    def forward(self,input_ids,tgt_ids):
        # input_ids:(seq_len,batch_size)
        # tgt_ids:(seq_len,batch_size)
        hidden_states,cell=self.encoder(input_ids)
        last_output_ids=tgt_ids[0,:].unsqueeze(0) # Decoder起始输入的都是第一个token
        tgt_seq_len=tgt_ids.shape[0]
        output=[]
        for i in range(1,tgt_seq_len):
            #print(last_output_ids.shape)
            logits,(hidden_states,cell)=self.decoder(last_output_ids,hidden_states,cell)
            output.append(logits.squeeze(0))
            if random.random() < self.teacher_ratio:
                last_output_ids=tgt_ids[i].unsqueeze(0) # teacher forcing
            else :
                last_output_ids=torch.argmax(logits,dim=-1)
        return torch.stack(output)

In [5]:
from tqdm import tqdm

    
hidden_dim=128
num_layers=3
device='cuda:0'
teacher_ratio=0.5
model=Seq2Seq(max_ids=max_ids,embed_dim=embed_dim,hidden_dim=hidden_dim,num_layers=num_layers,teacher_ratio=teacher_ratio).to(device)
model.train()
loss_func=nn.CrossEntropyLoss(ignore_index=index_pad) # 其实不ignore也行，因为tgt中的padding token一定在eos后
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
epoch=3

for i in range(epoch):
    with tqdm(total=len(train_dataloader), desc=f"Epoch {i+1}/{epoch}") as pbar:
        for batch in train_dataloader:
            x_ids=batch['x_ids'].to(device) # (src_seq_len,batch_size)
            y_ids=batch['y_ids'].to(device) # (tgt_seq_len,batch_size)
            y_output=model(x_ids,y_ids) # (tgt_seq_len-1,batch_size)
            y_ids=y_ids[1:] # (tgt_seq_len-1,batch_size) 
            y_output=y_output.view(-1,y_output.shape[-1])
            y_ids=y_ids.view(-1)
            loss=loss_func(y_output,y_ids)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad() 

            pbar.update(1)
            pbar.set_postfix({"loss": loss.item()/batch_size})

Epoch 1/3: 100%|██████████| 625/625 [00:14<00:00, 42.04it/s, loss=0.0015] 
Epoch 2/3: 100%|██████████| 625/625 [00:14<00:00, 44.43it/s, loss=0.000598]
Epoch 3/3: 100%|██████████| 625/625 [00:14<00:00, 44.60it/s, loss=0.000362]


In [12]:
# 0 pad,1 bos ,2 eos
model.eval()
encoder_input_ids= [index_bos]+[2,4]+[index_eos]
input_ids=[torch.tensor(encoder_input_ids).to(device)]
padded_input_ids=nn.utils.rnn.pad_sequence(input_ids,padding_value=index_pad)

hidden_states,cell=model.encoder(padded_input_ids)


res_ids=[]
res_ids.append(index_bos)
decoder_input_ids=[index_bos]
max_len=16

with torch.no_grad():
    for _ in range(max_len):
        decoder_input_tensor=torch.LongTensor([decoder_input_ids[-1]]).to(device).unsqueeze(0)
        output,(hidden_states,cell)=model.decoder(decoder_input_tensor,hidden_states,cell)
        
        output_id = torch.argmax(output,dim=2).item() # dim=-1
        decoder_input_ids.append(output_id)  
        if output_id==index_eos:
            break
decoder_input_ids

[1, 6, 8, 10, 12, 14, 16, 18, 20, 22, 2]