In [None]:
import torch
import pandas as pd
import torch.nn as nn
from itertools import takewhile
from torch.utils.data import Dataset, DataLoader


class Vocab:
    """Mutable holder for one side (source or target) of the corpus.

    Every field starts out as ``None``; ``text_process`` fills them in:
    vocab      -- list of tokens, index = token id
    size       -- len(vocab)
    word2idx   -- token -> id mapping
    sequences  -- padded id sequences, one per corpus line
    seq_len    -- length every sequence is padded to
    """

    def __init__(self):
        # Chained assignment binds the targets left to right, in the order listed above.
        self.vocab = self.size = self.word2idx = self.sequences = self.seq_len = None


# 文本处理
def text_process(file_path, source, target):
    """Build source/target vocabularies from a CSV corpus and index-encode it.

    The CSV must have ``source`` and ``target`` text columns. Each string is
    split into characters and wrapped in ``<bos>``/``<eos>``. It is then mapped
    to indices (unknown characters -> ``<unk>``) and right-padded with
    ``<pad>`` to the longest sequence on its side.

    Args:
        file_path: path to the CSV file.
        source: object whose ``vocab``, ``size``, ``word2idx``, ``seq_len`` and
            ``sequences`` attributes are overwritten with source-side results.
        target: same, for the target side. The target vocabulary is fixed to
            the digits plus "-".

    Returns:
        The ``(source, target)`` objects, mutated in place.
    """
    # Read every cell as literal text: otherwise an all-digit column
    # (e.g. "19800130") is parsed as integers, and values such as "NA" become NaN.
    data = pd.read_csv(file_path, dtype=str, keep_default_na=False)
    specials = ["<pad>", "<unk>", "<bos>", "<eos>"]
    # sorted() makes the character -> index mapping identical across runs;
    # iteration order of a raw set of str depends on the per-process hash seed.
    source.vocab = specials + sorted(set("".join(data["source"])))
    target.vocab = specials + list("1234567890-")

    def encode(side, lines):
        """Fill side.size / word2idx / seq_len / sequences from an iterable of strings."""
        side.size = len(side.vocab)
        side.word2idx = {word: index for index, word in enumerate(side.vocab)}
        bos, eos, unk, pad = (side.word2idx[tok] for tok in ("<bos>", "<eos>", "<unk>", "<pad>"))
        indexed = [[bos] + [side.word2idx.get(ch, unk) for ch in line] + [eos] for line in lines]
        # default=0 keeps an empty corpus from raising inside max().
        side.seq_len = max((len(seq) for seq in indexed), default=0)
        side.sequences = [seq + [pad] * (side.seq_len - len(seq)) for seq in indexed]

    encode(source, data["source"])
    encode(target, data["target"])
    return source, target


# Build both vocabularies and the padded index sequences from the date corpus.
source, target = text_process(file_path="data/date.csv", source=Vocab(), target=Vocab())

In [None]:
# 构建数据集
class Seq2SeqDataset(Dataset):
    """Map-style dataset over pre-padded (source, target) index sequences.

    Both inputs are converted once to int64 tensors of shape
    [num_samples, seq_len]; item ``i`` is the pair of row ``i`` from each.
    """

    def __init__(self, source_idx, target_idx):
        self.source_idx, self.target_idx = torch.LongTensor(source_idx), torch.LongTensor(target_idx)

    def __len__(self):
        # Number of rows == number of samples.
        return self.source_idx.shape[0]

    def __getitem__(self, idx):
        pair = (self.source_idx[idx], self.target_idx[idx])
        return pair


def collate_fn(batch):
    """Stack a list of (source, target) samples into time-major tensors.

    Each returned tensor has shape [seq_len, batch_size], which is the layout the
    (non-batch-first) LSTMs in this notebook expect.
    """
    sources, targets = zip(*batch)
    return tuple(torch.stack(group).transpose(0, 1) for group in (sources, targets))

In [None]:
# 模型搭建
class Encoder(nn.Module):
    """Embed source token ids and run them through a single-layer LSTM."""

    def __init__(self, input_size, hidden_size):
        # input_size: size of the source vocabulary
        # hidden_size: width shared by the embeddings and the LSTM state
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, input):
        """input: [seq_len, batch_size] token ids.

        Returns (outputs [seq_len, batch_size, hidden_size],
        final hidden [1, batch_size, hidden_size],
        final cell [1, batch_size, hidden_size]).
        """
        lstm_out, final_state = self.lstm(self.embedding(input))
        final_hidden, final_cell = final_state
        return lstm_out, final_hidden, final_cell


class Attention(nn.Module):
    """Additive attention: score = W(tanh(Wq(query) + Wk(keys)))."""

    def __init__(self, hidden_size):
        super().__init__()
        self.Wq = nn.Linear(hidden_size, hidden_size)  # projects the decoder state (query)
        self.Wk = nn.Linear(hidden_size, hidden_size)  # projects the encoder outputs (keys)
        self.W = nn.Linear(hidden_size, 1)  # collapses each combined vector to a scalar score

    def forward(self, encoder_outputs, decoder_hidden):
        """Return the attention-weighted sum of encoder outputs.

        encoder_outputs: [seq_len, batch_size, hidden_size]
        decoder_hidden:  [1, batch_size, hidden_size]
        returns:         [1, batch_size, hidden_size]
        """
        keys = encoder_outputs.transpose(0, 1)  # [batch_size, seq_len, hidden_size]
        query = decoder_hidden.transpose(0, 1)  # [batch_size, 1, hidden_size]
        energy = torch.tanh(self.Wq(query) + self.Wk(keys))  # query broadcasts over seq_len
        weights = self.W(energy).softmax(dim=1)  # [batch_size, seq_len, 1], normalized over seq_len
        # Per-sample matrix product: [batch, 1, seq_len] @ [batch, seq_len, hidden] -> [batch, 1, hidden]
        context = weights.transpose(1, 2).bmm(keys)
        return context.transpose(0, 1)


class Decoder(nn.Module):
    """One-step LSTM decoder that attends over the encoder outputs."""

    def __init__(self, output_size, hidden_size):
        # output_size: size of the target vocabulary (number of output logits)
        # hidden_size: width of the embeddings, attention context and LSTM state
        super().__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = Attention(hidden_size)
        self.lstm = nn.LSTM(hidden_size * 2, hidden_size)  # input = [embedding ; context]
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, cell, encoder_outputs):
        """Advance the decoder by one token.

        input: [1, batch_size] token ids. hidden/cell: [1, batch_size, hidden_size].
        Returns (logits [1, batch_size, output_size], new hidden, new cell).
        """
        token_vectors = self.embedding(input)
        # The context is computed from the state *before* this step's LSTM update.
        context = self.attention(encoder_outputs, hidden)
        lstm_input = torch.cat([token_vectors, context], dim=2)  # [1, batch_size, hidden_size*2]
        lstm_out, (hidden, cell) = self.lstm(lstm_input, (hidden, cell))
        return self.linear(lstm_out), hidden, cell


class Seq2Seq(nn.Module):
    """Encoder-decoder wrapper that decodes the target one token at a time."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """Produce per-step logits for ``target`` conditioned on ``source``.

        Args:
            source: [src_len, batch_size] token ids.
            target: [tgt_len, batch_size] token ids; ``target[0]`` (the <bos>
                column in this notebook's data) is the first decoder input.
            teacher_forcing_ratio: per-step probability of feeding the
                ground-truth token instead of the model's own argmax.

        Returns:
            [tgt_len, batch_size, target_vocab_size] logits. Row 0 stays all
            zeros because nothing is predicted for the first token.
        """
        batch_size = source.size(1)
        target_len = target.size(0)
        target_vocab_size = self.decoder.linear.out_features
        outputs = torch.zeros(target_len, batch_size, target_vocab_size, device=source.device)
        encoder_outputs, hidden, cell = self.encoder(source)
        decoder_input = target[0, :]
        for t in range(1, target_len):
            output, hidden, cell = self.decoder(decoder_input.unsqueeze(0), hidden, cell, encoder_outputs)
            # output is [1, batch_size, vocab]: drop the step axis and fill only row t.
            outputs[t] = output.squeeze(0)
            # Teacher forcing: with the given probability, feed the true token next.
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            decoder_input = target[t] if teacher_force else output.argmax(2).squeeze(0)
        return outputs

In [None]:
# 模型训练
def train(model, source, target, batch_size, lr, num_epoch, device):
    """Train ``model`` with Adam and token-level cross-entropy.

    Args:
        model: module called as ``model(x, y)`` that returns logits
            [seq_len, batch_size, vocab]. In this notebook that is ``Seq2Seq``.
        source, target: vocab objects whose ``sequences`` hold equal-length
            index lists; ``target.word2idx["<pad>"]`` is excluded from the loss.
        batch_size: mini-batch size (batches are reshuffled every epoch).
        lr: Adam learning rate.
        num_epoch: number of passes over the data.
        device: device the batches are moved to; the model must already be there.

    Prints a progress bar and the mean batch loss for every epoch.
    """
    model = model.train()
    dataset = Seq2SeqDataset(source.sequences, target.sequences)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    num_batches = max(len(dataloader), 1)  # avoids division by zero on an empty corpus
    criterion = nn.CrossEntropyLoss(ignore_index=target.word2idx["<pad>"])  # padding does not count
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(num_epoch):
        loss_accumulate = 0.0
        for batch_count, (x, y) in enumerate(dataloader):
            x, y = x.to(device), y.to(device)
            output = model(x, y)  # [seq_len, batch_size, output_size]
            # Row 0 corresponds to the <bos> input: Seq2Seq.forward never predicts it and
            # leaves all-zero logits there, so scoring it would only add a constant
            # log(vocab) term and dilute the mean over real tokens.
            logits = output[1:].reshape(-1, output.size(-1))  # [(seq_len-1) * batch_size, output_size]
            labels = y[1:].reshape(-1)  # [(seq_len-1) * batch_size]
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_accumulate += loss.item()
            print(f"\repoch:{epoch:0>2}[{'='*(int((batch_count+1) / num_batches * 50)):<50}]", end="")
        print(f" loss:{loss_accumulate/num_batches:.6f}")


hidden_size = 64  # embedding / hidden-state width
batch_size = 256  # mini-batch size
lr = 1e-3  # learning rate
num_epoch = 2  # number of training epochs
# Seed torch's global RNG before anything random happens (weight init, DataLoader
# shuffling, teacher-forcing draws) so Restart & Run All reproduces the same run.
torch.manual_seed(0)
encoder = Encoder(source.size, hidden_size)
decoder = Decoder(target.size, hidden_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Seq2Seq(encoder, decoder).to(device)
train(model, source, target, batch_size, lr, num_epoch, device)

In [None]:
# 模型预测
def predict(model, x, source, target, device, max_len=30):
    """Greedy-decode a batch of raw input strings.

    Args:
        model: module exposing ``encoder`` and ``decoder`` with the call
            signatures used by ``Seq2Seq``.
        x: iterable of raw source strings; each character is looked up in
            ``source.word2idx`` (unknown characters -> ``<unk>``).
        source, target: vocab objects from ``text_process``.
        device: device the model lives on.
        max_len: maximum number of decoding steps.

    Returns:
        int64 tensor [steps + 1, batch_size] on ``device``. Row 0 is ``<bos>``;
        decoding stops early once every sequence has emitted ``<eos>``.

    Raises:
        ValueError: if ``x`` is empty.
    """
    model.eval()
    src_pad, src_unk = source.word2idx["<pad>"], source.word2idx["<unk>"]
    src_bos, src_eos = source.word2idx["<bos>"], source.word2idx["<eos>"]
    x_idx = [[src_bos] + [source.word2idx.get(ch, src_unk) for ch in line] + [src_eos] for line in x]
    if not x_idx:
        raise ValueError("predict() needs at least one input string")
    # Pad to the training length, but never truncate: longer inputs set the width
    # (otherwise the rows would be ragged and torch.tensor would fail).
    seq_len = max([source.seq_len] + [len(line) for line in x_idx])
    x_sequences = torch.tensor(
        [line + [src_pad] * (seq_len - len(line)) for line in x_idx], dtype=torch.long, device=device
    ).transpose(0, 1)  # [seq_len, batch_size]

    tgt_bos, tgt_eos = target.word2idx["<bos>"], target.word2idx["<eos>"]
    with torch.no_grad():
        encoder_outputs, hidden, cell = model.encoder(x_sequences)
        batch_size = encoder_outputs.size(1)
        pred = torch.full((1, batch_size), tgt_bos, dtype=torch.long, device=device)
        # Keep the "has emitted <eos>" mask on the same device as the predictions
        # so the update below also works on CUDA.
        finished = torch.zeros(1, batch_size, dtype=torch.bool, device=device)
        for _ in range(max_len):
            output, hidden, cell = model.decoder(pred[-1:], hidden, cell, encoder_outputs)
            this_pred = output.argmax(2)  # [1, batch_size]
            finished |= this_pred == tgt_eos
            pred = torch.vstack([pred, this_pred])
            if finished.all():
                break
    return pred


# Greedy-decode a few hand-written dates given in different formats.
x = ["19800130", "二零零零年一月一日", "July 3, 1996", "1.11.1979"]
pred = predict(model, x, source, target, device).transpose(0, 1)  # one row per input
# Skip the leading <bos> column and keep tokens up to (not including) the first <eos>.
pred = [
    "".join(
        target.vocab[token]
        for token in takewhile(lambda token: token != target.word2idx["<eos>"], row[1:].tolist())
    )
    for row in pred
]
print(pred)