# Assignment
利用给定语料库（金庸小说语料如上链接）通过使用LSTM与Transformer来生成武侠小说, 并比较两种生成方法的优缺点。提示：可以直接从给定语料库训练或者在更大的语料库上训练并利用金庸小说数据进行微调。 (Deadline 30 April)

In [23]:
import os
import jieba
import re
import random
import opencc
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import json
import math
import tqdm

In [24]:
#定义LSTM模型
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        x = self.embedding(x)  # (batch_size, seq_len, embedding_dim)
        out, hidden = self.lstm(x, hidden)  # out: (batch_size, seq_len, hidden_dim)
        out = self.fc(out[:, -1, :])  # 只用最后一个时间步输出
        return out
    
#定义transformer模型
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.pe = pe.unsqueeze(0)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)].to(x.device)
        return x

class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, num_heads=4, hidden_dim=256, num_layers=2, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.pos_encoder = PositionalEncoding(embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout)
        decoder_layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=1)#一层decoder
        self.fc = nn.Linear(embed_dim, vocab_size)

    def generate_square_subsequent_mask(self, sz):
        # 下三角 mask (tgt_len, tgt_len)
        mask = torch.triu(torch.ones(sz, sz), diagonal=1).bool()
        return mask.to(next(self.parameters()).device)
    
    def forward(self,x,tgt):
        # print(f"Received args: x shape={x.shape}, tgt shape={tgt.shape}")
        # x shape: (batch, seq_len)
        # tgt shape: (batch, seq_len)
        x = self.embedding(x)                     # -> (batch, seq_len, embed_dim)
        x = self.pos_encoder(x)                   # 加位置编码
        x = x.transpose(0, 1)                     # -> (seq_len, batch, embed_dim)
        memory = self.transformer_encoder(x)           # -> (seq_len, batch, embed_dim)
        tgt = self.embedding(tgt)                 # -> (batch, seq_len, embed_dim)
        tgt = self.pos_encoder(tgt)               # 加位置编码
        tgt = tgt.transpose(0, 1)                 # -> (seq_len, batch, embed_dim)
        tgt_mask = self.generate_square_subsequent_mask(tgt.size(0))
        x = self.transformer_decoder(tgt, memory =memory,tgt_mask = tgt_mask)        # -> (seq_len, batch, embed_dim)
        x = self.fc(x[-1])                        # 只取最后一个时间步的输出
        return x

In [25]:
#定义数据集
class TextDataset(Dataset):
    def __init__(self, encoded_texts, length):
        self.inputs = []
        self.targets = []
        for i in range(len(encoded_texts)-length):
            if len(encoded_texts) > length:
                input_seq = encoded_texts[i:i+length]
                target_seq = encoded_texts[i+length]
                self.inputs.append(input_seq)
                self.targets.append(target_seq)
        self.inputs = torch.tensor(self.inputs, dtype=torch.long)
        self.targets = torch.tensor(self.targets, dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]
    
#transformer所需数据集
class TransformerDataset(Dataset):
    def __init__(self, encoded_texts, length,shift):
        self.inputs = []
        self.targets = []
        for i in range(len(encoded_texts)-length):
            if len(encoded_texts) > length:
                input_seq = encoded_texts[i:i+length]
                target_seq = encoded_texts[i+shift:i+length+shift]
                # 这里的shift是为了让target_seq比input_seq向后滑动shift
                #检查长度，如果长度不够就跳过
                if len(target_seq) != length:
                    continue
                self.inputs.append(input_seq)
                self.targets.append(target_seq)
        self.inputs = torch.tensor(self.inputs, dtype=torch.long)
        self.targets = torch.tensor(self.targets, dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

In [26]:
comma_list  = [
    '3002', 'FF1F', 'FF01', '3010', '3011', 'FF0C', '3001', 'FF1B',
    'FF1A', '300C', '300D', '300E', '300F', '2019', '201C', '201D',
    '2018', 'FF08', 'FF09', '3014', '3015', '2026', '2013', 'FF0E',
    '2014', '300A', '300B', '3008', '3009'
]

def preprocess_text(text,keep_stopwords=True):

        converter = opencc.OpenCC('t2s.json')
        stopwords = [line.strip() for line in open('D:\homework2025\DLandNLP\DL-nlp2025-main\cn_stopwords.txt', 'r', encoding='utf-8').readlines()]
        #简体转繁体
        text = converter.convert(text)
        #分词
    
        #保留汉字\u4e00-\u9fa5和标点
        text = re.sub(r'[^\u4e00-\u9fa5' + ''.join([chr(int(i, 16)) for i in comma_list]) + ']+', ' ', text)
        words = list(jieba.cut(text))
        #去除停用词
        if not keep_stopwords:
            words = [word for word in words if word not in stopwords]
        return words

def getvocab(texts_folder,ind_file):
    files = []
    with open(ind_file, 'r', encoding='ANSI') as f:
        # 读取文件列表并加上.txt后缀
        line = f.readline()
        #用,分割
        files = line.split(',')


    # print(files)
    texts = []
    for file in [file.strip() + '.txt' for file in files]:
        #读取文件内容
        with open(os.path.join(texts_folder, file), 'r', encoding='ANSI') as f:
            text = f.read()
            texts.append(text)
    print(len(texts))
    #预处理
    #删除开头和结尾的‘本书来自www.cr173.com免费txt小说下载站 更多更新免费电子书请关注www.cr173.com’
    for i in range(len(texts)):
        texts[i] = texts[i].replace("本书来自www.cr173.com免费txt小说下载站\n更多更新免费电子书请关注www.cr173.com", "")

    for i in range(len(texts)):

        #删除空白符
        texts[i] = texts[i].replace("\u3000", "")
    def fix_chinese_text(text):
        # 移除句中换行（不在句号、问号、叹号后）
        text = re.sub(r'(?<![。！？])\n(?!\n)', '', text)  
        # 统一段落换行（确保段落间保留空行）
        text = re.sub(r'\n{2,}', '\n\n', text)  
        return text

    for i in range(len(texts)):
        texts[i] = fix_chinese_text(texts[i])
    #先分段
    paragraphs = []
    for text in texts:
        # 按照两个换行符分割文本
        paragraph = text.split('\n')
        paragraphs.extend(paragraph)
    all_words = []
    for text in paragraphs:
        words = preprocess_text(text,keep_stopwords=True)#保留停用词
        # print(words[:10])  # 打印前10个分词
        #去除空字符串
        words = [word for word in words if word !='']
        #去除空段落
        if len(words) < 1:
            continue
        #添加<BOS>和<EOS>
        words = ['<BOS>'] + words + ['<EOS>']
        all_words.extend(words)
    #建立词表
    vocab = sorted(set(all_words))
    #保存json词表
    with open('vocab.json', 'w', encoding='utf-8') as f:
        json.dump(vocab, f, ensure_ascii=False, indent=4)
    return all_words,vocab



In [27]:
#训练代码
def train_model(model,vocab_size,dataloader,device, criterion, optimizer, num_epochs=10):
    
    model.to(device)
    #模型名
    model_name = type(model).__name__
    print(f"Training {model_name} model with {num_epochs} epochs...")
    for epoch in range(num_epochs):
        time_bar = tqdm.tqdm(total=len(dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")
        model.train()
        total_loss = 0.0
        best_loss = float('inf')
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            if model_name == 'LSTMModel':
                outputs = model(inputs)
                loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
            else:
                outputs = model(inputs,targets)
                loss = criterion(outputs.view(-1, vocab_size), targets[:,-1])
            
            loss.backward()
            optimizer.step()
            time_bar.update(1)
            total_loss += loss.item()
        
        avg_loss = total_loss / len(dataloader)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), model_name+'_best_model.pth')
        print("*" * 20+'正在训练'+model_name+"*" * 20)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_loss:.4f}")

In [28]:
import torch
import torch.nn as nn
import torch.nn.init as init
import math
import numpy as np

def load_and_migrate_embedding(
    model,
    checkpoint_path,
    old_vocab: dict,
    new_vocab: dict,
    embedding_dim: int,
    pad_token: str = "<PAD>"
):
    """
    从旧模型checkpoint中加载embedding参数，迁移到新模型词表上，
    并重初始化输出层fc（linear）。
    """

    # === 加载旧模型参数 ===
    print(f"📥 正在加载 checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]  # 兼容多种保存格式

    # === 提取旧embedding参数 ===
    if "embedding.weight" not in checkpoint:
        raise KeyError("❌ checkpoint 中未找到 embedding.weight")

    old_embedding_weight = checkpoint["embedding.weight"].cpu().numpy()
    old_vocab_size = old_embedding_weight.shape[0]

    print(f"✅ 旧embedding大小: {old_embedding_weight.shape}, 旧词表大小: {len(old_vocab)}")

    # === 创建新的embedding矩阵 ===
    new_vocab_size = len(new_vocab)
    new_embedding_matrix = np.random.normal(0, 1, (new_vocab_size, embedding_dim)).astype(np.float32)

    num_copied = 0
    for word, new_idx in new_vocab.items():
        old_idx = old_vocab.get(word, None)
        if old_idx is not None and old_idx < old_vocab_size:
            new_embedding_matrix[new_idx] = old_embedding_weight[old_idx]
            num_copied += 1

    print(f"🔁 迁移embedding成功：{num_copied}/{new_vocab_size} 个token")

    # === 替换embedding层 ===
    pad_idx = new_vocab.get(pad_token, 0)
    model.embedding = nn.Embedding.from_pretrained(
        torch.tensor(new_embedding_matrix),
        freeze=False,
        padding_idx=pad_idx
    )

    # === 重新初始化 fc 层 ===
    hidden_dim = model.fc.in_features
    model.fc = nn.Linear(hidden_dim, new_vocab_size)

    init.kaiming_uniform_(model.fc.weight, a=math.sqrt(5))
    if model.fc.bias is not None:
        fan_in, _ = init._calculate_fan_in_and_fan_out(model.fc.weight)
        bound = 1 / math.sqrt(fan_in)
        init.uniform_(model.fc.bias, -bound, bound)

    print(f"✅ fc 层已重新初始化，输出维度变为：{new_vocab_size}")


In [29]:
texts_folder = r"D:\homework2025\DLandNLP\texts"
ind_file = r"D:\homework2025\DLandNLP\texts\inf.txt"

words,vocab = getvocab(texts_folder,ind_file)
print(words[:10])
print(len(vocab))

16
['<BOS>', '白马', '啸', '西风', '得', '得', '得', '，', '得', '得']
158631


In [30]:
#
special_tokens_add =['<PAD>', '<UNK>']
vocab = special_tokens_add + vocab
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for idx, word in enumerate(vocab)}
encoded_texts = [word2idx[word] for word in words]
# print(encoded_texts[:10])




In [33]:
dataset = TextDataset(encoded_texts, length=8)
batch_size = 64
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
transformer_dataset = TransformerDataset(encoded_texts, length=8,shift=1)
transformer_dataloader = DataLoader(transformer_dataset, batch_size=batch_size, shuffle=True)
#设置设备为GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")
#定义模型参数
embedding_dim = 64
hidden_dim = 128
num_layers = 2

#定义模型
model = LSTMModel(len(vocab), embedding_dim, hidden_dim, num_layers).to(device)
model2 = TransformerModel(len(vocab), embed_dim=embedding_dim, num_heads=4, hidden_dim=hidden_dim, num_layers=num_layers).to(device)
#定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [36]:
#加载旧词表
with open('pretrain/vocab2.json', 'r', encoding='utf-8') as f:
    old_vocab = json.load(f)
word2idx_old = {word: idx for idx, word in enumerate(old_vocab)}
#加载模型参数
old_LSTMmodel_path = 'pretrain/LSTMModel_best_model.pth'
old_transformer_model_path = 'pretrain/TransformerModel_best_model.pth'
#加载LSTM模型参数
load_and_migrate_embedding(
    model,
    old_LSTMmodel_path,
    word2idx_old,
    word2idx,
    embedding_dim=embedding_dim,
    pad_token='<PAD>'
)
#加载transformer模型参数
load_and_migrate_embedding(
    model2,
    old_transformer_model_path,
    word2idx_old,
    word2idx,
    embedding_dim=embedding_dim,
    pad_token='<PAD>')

📥 正在加载 checkpoint: pretrain/LSTMModel_best_model.pth


RuntimeError: PytorchStreamReader failed reading zip archive: failed finding central directory

In [19]:
#训练模型
train_model(model, len(vocab), dataloader, device, criterion, optimizer, num_epochs=10)
# train_model(model2, len(vocab), transformer_dataloader, device, criterion, optimizer, num_epochs=10)

Training LSTMModel model with 10 epochs...


Epoch 1/10:   0%|          | 0/44018 [00:00<?, ?batch/s]

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [14]:
optimizer = optim.Adam(model2.parameters(), lr=0.001)
train_model(model2, len(vocab), transformer_dataloader, device, criterion, optimizer, num_epochs=10)

Training TransformerModel model with 10 epochs...




********************正在训练TransformerModel********************
Epoch [1/10], Loss: 1.3719



Epoch 1/10: 100%|██████████| 174602/174602 [1:12:31<00:00, 40.13batch/s]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A


********************正在训练TransformerModel********************
Epoch [2/10], Loss: 1.2459


Epoch 2/10: 100%|██████████| 174602/174602 [1:09:15<00:00, 42.02batch/s]


********************正在训练TransformerModel********************
Epoch [3/10], Loss: 1.2323



Epoch 3/10: 100%|██████████| 174602/174602 [1:07:55<00:00, 42.84batch/s]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A


********************正在训练TransformerModel********************
Epoch [4/10], Loss: 1.2212


Epoch 4/10: 100%|██████████| 174602/174602 [1:07:38<00:00, 43.02batch/s]


********************正在训练TransformerModel********************
Epoch [5/10], Loss: 1.2207



Epoch 5/10: 100%|██████████| 174602/174602 [1:07:34<00:00, 43.06batch/s]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A


********************正在训练TransformerModel********************
Epoch [6/10], Loss: 1.2209


Epoch 6/10: 100%|██████████| 174602/174602 [1:07:45<00:00, 42.94batch/s]


********************正在训练TransformerModel********************
Epoch [7/10], Loss: 1.2230



Epoch 7/10: 100%|██████████| 174602/174602 [1:07:48<00:00, 42.91batch/s]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A


********************正在训练TransformerModel********************
Epoch [8/10], Loss: 1.2226


Epoch 8/10: 100%|██████████| 174602/174602 [1:07:59<00:00, 42.80batch/s]


********************正在训练TransformerModel********************
Epoch [9/10], Loss: 1.2221



Epoch 9/10: 100%|██████████| 174602/174602 [1:07:33<00:00, 43.07batch/s]

[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A


********************正在训练TransformerModel********************
Epoch [10/10], Loss: 1.2223





In [37]:
def top_k_sampling(logits, k=10, temperature=1.0):
    logits = logits / temperature
    top_k_logits, top_k_indices = torch.topk(logits, k)
    probs = F.softmax(top_k_logits, dim=-1)
    next_token = top_k_indices[torch.multinomial(probs, 1)]
    return next_token.item()

def generate_text(model, start_words, word2idx, idx2word, length=50, temperature=1.0):
    model.eval()
    generated = start_words[:]
    input_seq = torch.tensor([word2idx[word] for word in start_words], dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        for _ in range(length):
            output = model(input_seq)  # (1, seq_len, vocab_size)
            logits = output.squeeze(0)
            next_word_idx = top_k_sampling(logits, k=10, temperature=temperature)
            generated.append(idx2word[next_word_idx])
            input_seq = torch.cat([input_seq, torch.tensor([[next_word_idx]], device=device)], dim=1)

    return ''.join(generated)

In [39]:
#测试生成
start_text = '张无忌立于山顶,'
#分词
start_words = preprocess_text(start_text,keep_stopwords=True)
#加载模型参数
model.load_state_dict(torch.load('LSTMModel_best_model.pth'))
#生成
generated_text = generate_text(model, start_words, word2idx, idx2word, length=50,temperature=0.8)
print(generated_text)


张无忌立于山顶 死灰「倒轻响的金头  固守。贞妃固守，杨将军固守，计杨将军固守。杨将军杨将军那实青面獠牙固守赌神固守，赌神固守赌神 固守，贞妃 狄云奇 大不大，赌神青面獠牙固守固守赌神固守。杨将军固守，


In [42]:
#测试生成
start_text = '张无忌立于山顶,'
#分词
start_words = preprocess_text(start_text,keep_stopwords=True)
#加载模型参数
model2.load_state_dict(torch.load('TransformerModel_best_model.pth'))
#生成
generated_text = generate_text(model, start_words, word2idx, idx2word, length=50,temperature=0.7)
print(generated_text)

张无忌立于山顶 死灰 来两穴，一善，杨将军贞妃固守，贞妃贞妃从狄云奇从贞妃固守固守赌神固守，贞妃固守。贞妃贞妃固守，杨将军贞妃固守。贞妃固守贞妃固守贞妃固守狄云奇 金头贞妃固守 固守贞妃固守贞妃固守
