# 语言模型

学习目标
- 学习语言模型，以及如何训练一个语言模型
- 学习torchtext的基本使用方法
    - 构建 vocabulary
    - word to index 和 index to word
- 学习torch.nn的一些基本模型
    - Linear
    - RNN
    - LSTM
    - GRU
- RNN的训练技巧
    - Gradient Clipping
- 如何保存和读取模型

In [26]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Pick the fastest available backend: a CUDA GPU first, then Apple's MPS,
# and finally the CPU. The getattr guard keeps this working on PyTorch builds
# that do not ship the ``torch.backends.mps`` module at all.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"device : {device}")

# NOTE(review): none of these three constants is referenced again in the
# visible cells; the DataLoader and the model are configured from the
# lowercase ``batch_size = 64`` and ``embedding_dim = 128`` defined later.
# They are kept in case other code relies on them.
BATCH_SIZE = 32
EMBEDDING_DIM = 650
MAX_VOCAB_SIZE = 50_000

device : mps


In [27]:
# Read the whole training corpus into memory as one string.
with open("data/nietzsche.txt", "r", encoding="utf8") as corpus_file:
    text = corpus_file.read()

# Define the tokenizer and build the vocabulary from the corpus tokens.
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer(text)
vocab = build_vocab_from_iterator([tokens], specials=["<unk>"])

# Any token missing from the vocabulary is mapped to the "<unk>" index.
unk_index = vocab["<unk>"]
vocab.set_default_index(unk_index)

def word_to_idx(word):
    """Return the integer index that the module-level ``vocab`` assigns to ``word``.

    The lookup is a plain ``vocab[word]``; how unknown words are handled is
    whatever ``vocab`` itself does for a missing key.
    """
    index = vocab[word]
    return index

def idx_to_word(idx):
    """Return the token that the module-level ``vocab`` stores at index ``idx``.

    Uses ``Vocab.lookup_token``: the vocabulary object produced by
    ``build_vocab_from_iterator`` has no ``itos`` attribute (that belonged to
    the legacy torchtext ``Vocab``), so ``vocab.itos[idx]`` raises
    ``AttributeError``.
    """
    return vocab.lookup_token(idx)
# Encode the whole corpus as a list of vocabulary indices.
data = list(map(word_to_idx, tokens))

## 定义Dataset和DataLoader

In [28]:
class TextDataset(Dataset):
    """Sliding-window dataset for next-token prediction.

    Item ``idx`` is a pair ``(x, y)`` of tensors: ``x`` holds ``seq_length``
    consecutive entries of ``data`` starting at ``idx`` and ``y`` holds the
    same window shifted one position to the right, i.e. the next-token
    target for every position of ``x``.

    Args:
        data: Indexable, sliceable sequence of token indices.
        seq_length: Number of tokens in each input window.
    """

    def __init__(self, data, seq_length) -> None:
        super().__init__()
        self.data = data
        self.seq_length = seq_length

    def __len__(self) -> int:
        # The last usable window needs one extra token after it for its
        # final target. Clamp at zero so a corpus no longer than seq_length
        # gives an empty dataset instead of making len() raise ValueError
        # on a negative value.
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx):
        window_end = idx + self.seq_length
        x = self.data[idx:window_end]
        y = self.data[idx + 1:window_end + 1]
        return torch.tensor(x), torch.tensor(y)
    
# Examples per mini-batch and number of tokens per training window.
batch_size = 64
seq_length = 30

dataset = TextDataset(data=data, seq_length=seq_length)
# Windows are drawn in random order, and a final incomplete batch is dropped
# so that every batch has exactly batch_size rows.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    drop_last=True,
    shuffle=True,
)


## 定义LSTM模型

In [29]:
class LSTMModel(nn.Module):
    """Language model: token embedding -> stacked LSTM -> linear layer over the vocabulary."""

    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers) -> None:
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_dim,
        )
        # batch_first=True: the LSTM consumes and produces tensors laid out
        # as [batch_size, seq_length, features].
        self.lstm = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden):
        """Run one forward pass.

        Args:
            x: Token indices, [batch_size, seq_length].
            hidden: LSTM state tuple as produced by ``init_hidden`` or by a
                previous call, each element [num_layers, batch_size, hidden_dim].

        Returns:
            A pair ``(logits, hidden)`` where ``logits`` is flattened to
            [batch_size * seq_length, vocab_size] and ``hidden`` is the
            updated LSTM state.
        """
        # [batch_size, seq_length] -> [batch_size, seq_length, embedding_dim]
        embedded = self.embedding(x)
        # -> [batch_size, seq_length, hidden_dim]
        lstm_out, hidden = self.lstm(embedded, hidden)
        # Merge the batch and time axes: [batch_size * seq_length, hidden_dim]
        flattened = lstm_out.reshape(-1, lstm_out.size(2))
        # -> [batch_size * seq_length, vocab_size]
        logits = self.fc(flattened)
        return logits, hidden

    def init_hidden(self, batch_size):
        """Return an all-zero ``(h, c)`` state for ``batch_size`` sequences.

        The tensors are created from an existing parameter so they share its
        dtype and device.
        """
        reference = next(self.parameters())
        state_shape = (self.lstm.num_layers, batch_size, self.lstm.hidden_size)
        h0 = reference.new_zeros(state_shape)
        c0 = reference.new_zeros(state_shape)
        return h0, c0
    
# Model hyper-parameters; the output layer is sized from the vocabulary.
num_layers = 2
embedding_dim = 128
hidden_dim = 256
vocab_size = len(vocab)
print(f"vocab_size : {vocab_size}")

# Build the network and move its parameters onto the selected device.
model = LSTMModel(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
)
model = model.to(device)

vocab_size : 11747


## 训练模型

In [31]:
# Loss function, optimiser and the gradient-clipping threshold.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
clip = 5  # maximum gradient norm passed to clip_grad_norm_

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    # Start every epoch from a zero LSTM state; it is then carried from one
    # batch to the next.
    # NOTE(review): the DataLoader shuffles, so consecutive batches are
    # unrelated windows and the carried state is not a true continuation.
    # Confirm this is intended.
    hidden = model.init_hidden(batch_size)

    for inputs, targets in dataloader:
        # Detach the carried state from the previous batch's autograd graph
        # (same values, no gradient history) so backward() stops at the
        # start of the current batch.
        hidden = tuple(h.detach() for h in hidden)
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        output, hidden = model(inputs, hidden)
        # Flatten the targets to batch_size * seq_length entries so they
        # line up with the flattened model output.
        flat_targets = targets.view(-1)
        loss = criterion(output, flat_targets)
        loss.backward()

        # Gradient clipping
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

    # NOTE(review): this reports the loss of the last mini-batch of the
    # epoch only, not an average over the epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

# Save the trained weights
torch.save(model.state_dict(), "lstm_model.pth")

Epoch 1/10, Loss: 4.367384910583496
Epoch 2/10, Loss: 3.279768705368042
Epoch 3/10, Loss: 2.437866449356079
Epoch 4/10, Loss: 2.0008645057678223
Epoch 5/10, Loss: 1.626522421836853
Epoch 6/10, Loss: 1.4504024982452393
Epoch 7/10, Loss: 1.1741880178451538
Epoch 8/10, Loss: 1.0635603666305542
Epoch 9/10, Loss: 0.8578975200653076
Epoch 10/10, Loss: 0.7697316408157349


## 读取模型和检验模型

In [32]:
# Load the model: rebuild the architecture (left on the CPU, no .to(device))
# and restore the trained weights from disk.
model = LSTMModel(vocab_size, embedding_dim, hidden_dim, num_layers)
# The checkpoint was saved from a model that had been moved to `device`, which
# may be an accelerator backend. map_location="cpu" deserialises the tensors
# straight onto the CPU, so loading also works where that backend is absent.
model.load_state_dict(torch.load('lstm_model.pth', map_location="cpu"))
model.eval()

LSTMModel(
  (embedding): Embedding(11747, 128)
  (lstm): LSTM(128, 256, num_layers=2, batch_first=True)
  (fc): Linear(in_features=256, out_features=11747, bias=True)
)

In [40]:
def idx_to_word(idx):
    """Return the token that the module-level ``vocab`` stores at integer index ``idx``."""
    token = vocab.lookup_token(idx)
    return token

In [57]:
# Generate text with the trained model

def generate_text(model, start_text, vocab, tokenizer, index_to_word, word_to_index, gen_length=100):
    """Sample ``gen_length`` tokens from ``model`` as a continuation of ``start_text``.

    The prompt is fed through the model once to prime the recurrent state.
    After that only the newly sampled token is fed at each step, because the
    carried ``hidden`` state already summarises everything seen so far.
    Re-feeding earlier tokens together with the carried state would make the
    model process them more than once.

    Args:
        model: Module exposing ``init_hidden(batch_size)`` and
            ``model(input, hidden) -> (logits, hidden)``, where ``logits`` has
            one row per input position. Switched to eval mode as a side effect.
        start_text: Prompt string; must tokenize to at least one token.
        vocab: Unused; kept so existing call sites keep working.
        tokenizer: Callable splitting a string into tokens.
        index_to_word: Callable mapping a vocabulary index to its token.
        word_to_index: Callable mapping a token to its vocabulary index.
        gen_length: Number of tokens to sample.

    Returns:
        ``start_text`` followed by the sampled tokens, separated by single spaces.

    Raises:
        ValueError: If ``start_text`` produces no tokens.
    """
    model.eval()
    indices = [word_to_index(token) for token in tokenizer(start_text)]
    if not indices:
        raise ValueError("start_text must contain at least one token")

    # Create inputs on the device the model lives on, so generation works
    # whether the model is on the CPU or on an accelerator.
    device = next(model.parameters()).device
    input_seq = torch.tensor(indices, dtype=torch.long, device=device).unsqueeze(0)

    hidden = model.init_hidden(1)  # batch_size = 1
    generated_words = []

    # No gradients are needed for sampling; without no_grad the autograd
    # graph would grow with every step through the carried hidden state.
    with torch.no_grad():
        for _ in range(gen_length):
            # output: [batch_size * seq_length, vocab_size]
            output, hidden = model(input_seq, hidden)
            # Distribution over the vocabulary at the last position.
            output_dist = nn.functional.softmax(output[-1], dim=-1)
            # Sampling from the multinomial distribution, rather than always
            # taking the arg-max, makes the generated text more varied.
            top_index = torch.multinomial(output_dist, 1).item()
            generated_words.append(index_to_word(top_index))
            # Next step: feed only the token that was just sampled.
            input_seq = torch.tensor([[top_index]], dtype=torch.long, device=device)

    return " ".join([start_text, *generated_words])

# Prompt the reloaded model and sample 20 additional tokens.
start_text = "The meaning of life is"
generate_text(
    model=model,
    start_text=start_text,
    vocab=vocab,
    tokenizer=tokenizer,
    index_to_word=idx_to_word,
    word_to_index=word_to_idx,
    gen_length=20,
)

'The meaning of life is unbearable . a powerful , even against the ancient greeks ( or non-spirit ) of the race . a man'

In [58]:
import os

# Delete the checkpoint written by the training cell. A missing file is
# tolerated so this cell can be re-run without raising FileNotFoundError.
try:
    os.remove("lstm_model.pth")
except FileNotFoundError:
    pass