#1

In [3]:
from google.colab import drive
drive.mount('./mount')


Drive already mounted at ./mount; to attempt to forcibly remount, call drive.mount("./mount", force_remount=True).


process data

In [2]:

import os
import re
from collections import Counter
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

# 检查是否有GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# 定义语料库路径
corpus_path = './mount/My Drive/Colab Notebooks/BH/1-2/DL/3/chinese_corpus/'

# # 读取所有txt文件
# texts = []
# for file_name in os.listdir(corpus_path):
#     if file_name.endswith('.txt'):
#         with open(os.path.join(corpus_path, file_name), 'r', encoding='utf-8') as file:
#             texts.append(file.read())

# 读取‘特定’文件
texts = []
for file_name in os.listdir(corpus_path):
    if '越女剑' in file_name and file_name.endswith('.txt'):
        with open(os.path.join(corpus_path, file_name), 'r', encoding='utf-8') as file:
            texts.append(file.read())


# 合并所有文本
corpus = "\n".join(texts)

# 使用正则表达式分词
tokenizer = Counter(re.findall(r'\b\w+\b', corpus))
word_index = {word: idx + 1 for idx, (word, _) in enumerate(tokenizer.items())}
total_words = len(word_index) + 1

# 分句
sentences = corpus.split('\n')

# 创建输入和输出序列
input_sequences = []
for sentence in sentences:
    token_list = [word_index[word] for word in re.findall(r'\b\w+\b', sentence)]
    for i in range(1, len(token_list)):
        n_gram_sequence = token_list[:i+1]
        input_sequences.append(n_gram_sequence)

# 填充序列
max_sequence_len = max(len(seq) for seq in input_sequences)
input_sequences = pad_sequence([torch.tensor(seq) for seq in input_sequences], batch_first=True, padding_value=0)

# 创建训练数据
xs, labels = input_sequences[:, :-1], input_sequences[:, -1]
ys = torch.nn.functional.one_hot(labels, num_classes=total_words).float()

class TextDataset(Dataset):
    def __init__(self, xs, ys):
        self.xs = xs
        self.ys = ys

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        return self.xs[idx], self.ys[idx]

dataset = TextDataset(xs, ys)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


Using device: cpu


define and train model

In [3]:
import torch.nn as nn
import torch.optim as optim

class TextGenerationModel(nn.Module):
    def __init__(self, total_words, embed_dim, hidden_dim):
        super(TextGenerationModel, self).__init__()
        self.embedding = nn.Embedding(total_words, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, total_words)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = self.fc(x[:, -1, :])
        return x

model = TextGenerationModel(total_words, 64, 20).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10  # 为了快速测试，可以将epoch数量减少
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        targets = torch.argmax(targets, axis=1)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


Epoch [1/10], Loss: 3.6302
Epoch [2/10], Loss: 0.2892
Epoch [3/10], Loss: 0.0848
Epoch [4/10], Loss: 0.0421
Epoch [5/10], Loss: 0.0263
Epoch [6/10], Loss: 0.0174
Epoch [7/10], Loss: 0.0353
Epoch [8/10], Loss: 0.0384
Epoch [9/10], Loss: 0.0080
Epoch [10/10], Loss: 0.0103


generate text

In [6]:
def generate_text_seq2seq(seed_text, next_words, model, max_sequence_len):
    model.eval()
    words = re.findall(r'\b\w+\b', seed_text)
    for _ in range(next_words):
        token_list = [word_index[word] for word in words if word in word_index]

        # 保证 token_list 不为空
        if not token_list:
            continue

        token_list = torch.tensor(token_list).unsqueeze(0).to(device)

        # 填充序列
        token_list = pad_sequence([token_list], batch_first=True, padding_value=0).to(dtype=torch.long).to(device)
        token_list = token_list[:, -max_sequence_len+1:]

        # 保证输入是2D或3D
        if len(token_list.shape) == 3:
            token_list = token_list.squeeze(0)

        # 跳过序列长度为0的情况
        if token_list.shape[1] == 0:
            continue

        predicted = model(token_list)
        predicted = torch.argmax(predicted, axis=-1).item()

        # 获取预测的词
        output_word = list(word_index.keys())[list(word_index.values()).index(predicted)]
        words.append(output_word)
    return ' '.join(words)

# 生成文本
print(generate_text_seq2seq("在下", 50, model, max_sequence_len))


在下
