In [None]:
import torch
from torch import nn
import random

In [82]:
# 建立Dataloader
max_ids=24
seq_len=8
embed_dim=64
def make_dataset(data_num,seq_len):
    dataset=[]
    for _ in range(data_num):
        x_ids=torch.tensor([random.randint(0,15) for i in range(seq_len)])
        y_ids=x_ids+1
        dataset.append(
            {
                'x_ids':x_ids,
                'y_ids':y_ids
            }
        )
    def collate_fn(batch):
        x_ids=[b['x_ids'] for b in batch]
        y_ids=[b['y_ids'] for b in batch]
        x_ids=nn.utils.rnn.pad_sequence(x_ids)
        y_ids=nn.utils.rnn.pad_sequence(y_ids) # (seq_len,batch_size)
        # 把seq_len放在第一维的原因是让pytorch可以对每个timestep的结果进行并行计算而不是对每个完整序列的结果进行并行计算。
        # 因为seq2seq的任务需要对齐的是每个timestep的结果。别的任务一般都是对齐每条数据的结果。
        batch={
                'x_ids':x_ids,
                'y_ids':y_ids
            }
        return batch
    return dataset,collate_fn
train_dataset,collate_fn=make_dataset(data_num=1000,seq_len=seq_len) # 构造一个y=x+1的序列
batch_size=16
train_dataloader=torch.utils.data.DataLoader(dataset=train_dataset,batch_size=batch_size,collate_fn=collate_fn)

In [83]:
next(train_dataloader.__iter__())

{'x_ids': tensor([[ 6,  4, 15, 12,  1,  6,  7,  3, 12, 14,  3, 12, 10, 13,  1,  2],
         [ 8, 10,  2,  4,  4, 15,  0,  5,  6,  6, 15,  7, 11, 13, 15, 14],
         [ 3,  3,  9,  7, 15, 14,  7,  4, 14,  5,  1,  4, 11,  3,  0,  9],
         [15, 15,  5, 10, 15, 11,  1,  1,  2, 12, 13,  6,  7, 12,  5,  3],
         [ 9, 11, 15,  0,  8,  5,  1, 14,  5,  9,  6, 12, 12,  1, 11,  6],
         [15,  2,  2,  0,  8,  7,  2,  0, 12,  1,  6, 14,  6,  0, 13, 12],
         [12,  3, 12,  0,  8, 13,  6, 13,  8,  5,  3,  1,  4,  1, 10,  7],
         [11, 10,  2, 10,  1, 15,  5, 11, 15, 12, 12,  6,  5, 10, 13, 10]]),
 'y_ids': tensor([[ 7,  5, 16, 13,  2,  7,  8,  4, 13, 15,  4, 13, 11, 14,  2,  3],
         [ 9, 11,  3,  5,  5, 16,  1,  6,  7,  7, 16,  8, 12, 14, 16, 15],
         [ 4,  4, 10,  8, 16, 15,  8,  5, 15,  6,  2,  5, 12,  4,  1, 10],
         [16, 16,  6, 11, 16, 12,  2,  2,  3, 13, 14,  7,  8, 13,  6,  4],
         [10, 12, 16,  1,  9,  6,  2, 15,  6, 10,  7, 13, 13,  2, 12,  7],
     

In [84]:
from tqdm import tqdm
class MyRNN(nn.Module):
    def __init__(self,max_ids,embed_dim,rnn_hidden_dim,rnn_num_layers):
        super().__init__()
        self.embedding=nn.Embedding(num_embeddings=max_ids,embedding_dim=embed_dim)
        self.rnn=nn.RNN(input_size=embed_dim,hidden_size=rnn_hidden_dim,num_layers=rnn_num_layers)
        self.head=nn.Linear(rnn_hidden_dim,max_ids)
    def forward(self,x):
        x=self.embedding(x) # (seq_len,batch_size,hidden_dim)
        rnn_output,hidden_states=self.rnn(x) # (seq_len,batch_size,hidden_dim)
        output=self.head(rnn_output)
        return output # (seq_len,batch_size,max_ids)
    
rnn_hidden_dim=128
rnn_num_layers=3
device='cuda:0'
model=MyRNN(max_ids=max_ids,embed_dim=embed_dim,rnn_hidden_dim=rnn_hidden_dim,rnn_num_layers=rnn_num_layers).to(device)
loss_func=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
epoch=3

for i in range(epoch):
    with tqdm(total=len(train_dataloader), desc=f"Epoch {i+1}/{epoch}") as pbar:
        for batch in train_dataloader:
            x_ids=batch['x_ids'].to(device)
            y_ids=batch['y_ids'].to(device)  
            y_output=model(x_ids)
            y_ids=y_ids.unsqueeze(2)
            y_output=y_output.view(-1,y_output.shape[-1])
            y_ids=y_ids.view(-1)
            loss=loss_func(y_output,y_ids)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad() 

            pbar.update(1)
            pbar.set_postfix({"loss": loss.item()/batch_size})

Epoch 1/3:   0%|          | 0/63 [00:00<?, ?it/s]

Epoch 1/3: 100%|██████████| 63/63 [00:00<00:00, 142.28it/s, loss=0.00225]
Epoch 2/3: 100%|██████████| 63/63 [00:00<00:00, 157.54it/s, loss=0.000645]
Epoch 3/3: 100%|██████████| 63/63 [00:00<00:00, 198.93it/s, loss=0.000355]


In [None]:
x_test=[4,6,8,4,0]
x_tensor=torch.tensor(x_test).to(device)
logits=model(x_tensor) # (seq_len,max_ids)
print(logits.shape)
print(logits[0])
output_ids=torch.argmax(logits,dim=1) #在max_ids那一维取argmax
output_ids

torch.Size([5, 24])
tensor([-0.8360, -0.0113,  0.6309,  0.8853, -3.6343,  8.5278,  0.8091,  0.9241,
        -0.5562, -2.6830,  0.3736, -1.2310,  1.3955, -0.9789,  0.5743,  0.9856,
        -1.6160, -0.9963, -0.3596, -0.9936, -0.7572, -0.8647, -1.0270, -0.6180],
       device='cuda:0', grad_fn=<SelectBackward0>)


tensor([5, 7, 9, 5, 1], device='cuda:0')