In [124]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader, Dataset
import numpy as np


# Module-level hyperparameter constants.
# NOTE(review): none of these three names is referenced by the code visible in
# this notebook (later cells define their own lowercase `batch_size` and
# `embedding_dim`) — confirm whether they are still needed.
BATCH_SIZE = 32          # batch size constant
EMBEDDING_DIM = 650      # embedding width constant
MAX_VOCAB_SIZE = 50_000  # vocabulary size cap; not passed to the vocab builder in the visible code

In [125]:
# Read the whole training corpus into memory as a single string.
with open("data/nietzsche.txt", "r", encoding="utf8") as f:
    text = f.read()

# Define the tokenizer and the vocabulary.
tokenizer = get_tokenizer("basic_english")
tokens = tokenizer(text)  # flat list of tokens for the entire corpus
# The full token list is passed as one item of the iterator; "<unk>" is
# registered as a special token.
vocab = build_vocab_from_iterator([tokens], specials=["<unk>"])
# Make "<unk>" the default index used for tokens that are not in the vocabulary.
vocab.set_default_index(vocab["<unk>"])

def word_to_idx(word):
    """Look up ``word`` in the module-level ``vocab`` and return its index."""
    index = vocab[word]
    return index

def idx_to_word(idx):
    """Return the token string stored at position ``idx`` of the module-level ``vocab``.

    Uses ``Vocab.lookup_token`` because the ``Vocab`` object produced by
    ``build_vocab_from_iterator`` exposes its index-to-token mapping through
    ``lookup_token`` / ``get_itos()`` rather than through an ``itos`` attribute.

    Args:
        idx: Integer index of the token to look up.

    Returns:
        The token (str) associated with ``idx``.
    """
    return vocab.lookup_token(idx)
# Convert the token stream into a list of vocabulary indices.
data = list(map(word_to_idx, tokens))

## 定义DataSet和DataLoader

In [126]:
class TextDataset(Dataset):
    """Sliding-window next-token dataset over a flat sequence of token ids.

    Item ``i`` is the pair ``(x, y)`` with ``x = data[i : i + seq_length]`` and
    ``y = data[i + 1 : i + seq_length + 1]``, i.e. ``y`` is ``x`` shifted one
    position to the right.

    Args:
        data: Sequence of token ids supporting ``len()`` and slicing.
        seq_length: Number of tokens in every input/target window.
    """

    def __init__(self, data, seq_length) -> None:
        super().__init__()
        self.data = data
        self.seq_length = seq_length

    def __len__(self):
        """Return the number of complete windows (0 if the data is too short)."""
        # Clamp at 0: a negative value would make len() raise ValueError.
        return max(0, len(self.data) - self.seq_length)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensors for window ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
                Raising here also lets plain ``for x, y in dataset`` iteration
                terminate instead of producing truncated windows forever.
        """
        num_windows = len(self)
        if idx < 0:
            idx += num_windows
        if not 0 <= idx < num_windows:
            raise IndexError(f"index out of range for dataset with {num_windows} windows")

        x = self.data[idx:idx + self.seq_length]
        y = self.data[idx + 1:idx + self.seq_length + 1]
        return torch.tensor(x), torch.tensor(y)
    
seq_length = 30  # tokens per input/target window
batch_size = 64  # windows per batch

# Wrap the index list in the sliding-window dataset and batch it.
dataset = TextDataset(data, seq_length)
# shuffle=True draws windows in random order; drop_last=True discards the final
# incomplete batch so every batch has exactly `batch_size` rows.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# Sanity check: fetch one batch and print the input/target shapes.
x, y = next(iter(dataloader))
print(x.shape, y.shape)

torch.Size([64, 30]) torch.Size([64, 30])


In [127]:
import torch.nn.functional as F

def g(x):
    """Element-wise positive activation.

    Returns ``x + 0.5`` where ``x >= 0`` and ``sigmoid(x)`` elsewhere, so every
    output element is strictly positive.
    """
    shifted = x + 0.5
    squashed = torch.sigmoid(x)
    non_negative = x >= 0
    return torch.where(non_negative, shifted, squashed)
def log_g(x):
    """Element-wise logarithm of ``g(x)``, computed without forming ``g(x)`` first.

    For ``x >= 0`` this is ``log(relu(x) + 0.5)``; otherwise it is
    ``-softplus(-x)``, which equals ``log(sigmoid(x))``.
    """
    log_of_shifted = torch.log(F.relu(x) + 0.5)
    log_of_sigmoid = -F.softplus(-x)
    non_negative = x >= 0
    return torch.where(non_negative, log_of_shifted, log_of_sigmoid)

def parallel_scan_log(log_coeffs, log_values):
    """Evaluate the recurrence ``h_t = a_t * h_{t-1} + b_t`` for all ``t`` in log space.

    Args:
        log_coeffs: ``log(a_t)``, shape (batch_size, seq_len, hidden_size).
        log_values: ``log`` of ``[h_0, b_1, ..., b_T]`` concatenated along dim 1,
            shape (batch_size, seq_len + 1, hidden_size).

    Returns:
        ``h_1 .. h_T`` (not in log space), shape (batch_size, seq_len, hidden_size).
    """
    # Running sum of log-coefficients, with a zero row prepended for position 0.
    cumulative_log_coeffs = torch.cumsum(log_coeffs, dim=1)
    a_star = F.pad(cumulative_log_coeffs, (0, 0, 1, 0))

    # log of the cumulative sum of values rescaled by the accumulated coefficients.
    rescaled_log_values = log_values - a_star
    log_h0_plus_b_star = torch.logcumsumexp(rescaled_log_values, dim=1)

    hidden_with_h0 = torch.exp(a_star + log_h0_plus_b_star)
    # Drop position 0, which corresponds to h_0 itself.
    return hidden_with_h0[:, 1:]

class ParallelLogMiniGRU(nn.Module):
    """Minimal GRU layer evaluated over the whole sequence with a log-space scan.

    With ``z_t = sigmoid(linear_z(x_t))`` and ``h~_t = g(linear_h(x_t))`` the
    layer computes ``h_t = (1 - z_t) * h_{t-1} + z_t * h~_t`` for every time
    step at once, starting from ``g(h_0)``.

    Args:
        embedding_dim: Size of the last dimension of the input.
        num_hiddens: Size of the hidden state.
        device: Stored on the instance; not used by the layer's own computations.
    """

    def __init__(self, embedding_dim, num_hiddens, device):
        super().__init__()

        # Gate projection first, candidate projection second (creation order
        # is kept so parameter initialisation consumes the RNG identically).
        self.linear_z = nn.Linear(embedding_dim, num_hiddens)
        self.linear_h = nn.Linear(embedding_dim, num_hiddens)

        self.embedding_dim = embedding_dim
        self.device = device
        self.num_hiddens = num_hiddens

    def forward(self, x, h_0):
        """Run the layer over a full sequence.

        Args:
            x: Input of shape [batch_size, seq_len, embedding_dim].
            h_0: Initial state of shape [batch_size, 1, num_hiddens].

        Returns:
            Tuple ``(h, h_last)``: all hidden states
            [batch_size, seq_len, num_hiddens] and the final one kept as
            [batch_size, 1, num_hiddens].
        """
        gate_logits = self.linear_z(x)
        # -softplus(-k) == log(sigmoid(k)) == log(z)
        log_update_gate = -F.softplus(-gate_logits)
        # -softplus(k) == log(1 - sigmoid(k)) == log(1 - z)
        log_keep_gate = -F.softplus(gate_logits)

        log_initial = log_g(h_0)
        log_candidate = log_g(self.linear_h(x))

        # Prepend the initial state so the scan sees [h_0, z_1*h~_1, ..., z_T*h~_T].
        scan_inputs = torch.cat([log_initial, log_update_gate + log_candidate], dim=1)
        hidden_states = parallel_scan_log(log_keep_gate, scan_inputs)

        last_state = hidden_states[:, -1].unsqueeze(1)
        return hidden_states, last_state

    def begin_state(self, batch_size, device):
        """Return an all-zero initial state of shape [batch_size, 1, num_hiddens]."""
        flat_zeros = torch.zeros((batch_size, self.num_hiddens), device=device)
        return flat_zeros.unsqueeze(1)

In [128]:
class GRULauangeModel(nn.Module):
    """Language model: token embedding -> ParallelLogMiniGRU -> vocabulary logits.

    Args:
        vocab_size: Number of entries in the embedding table and output logits.
        embedding_dim: Width of the token embeddings.
        hidden_dim: Width of the recurrent hidden state.
        device: Device used when creating the initial hidden state.
        num_layers: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim, device, num_layers=None) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.miniGru = ParallelLogMiniGRU(embedding_dim, hidden_dim, device)
        self.output = nn.Linear(hidden_dim, vocab_size)

        self.device = device

    def forward(self, x, hidden):
        """Compute next-token logits for every position.

        Args:
            x: Token ids, [batch_size, seq_length].
            hidden: Initial recurrent state, passed straight to the minGRU layer.

        Returns:
            Tuple ``(logits, hidden)``: logits flattened to
            [batch_size * seq_length, vocab_size] and the state returned by the
            recurrent layer.
        """
        # [batch_size, seq_length] -> [batch_size, seq_length, embedding_dim]
        embedded = self.embedding(x)
        # -> [batch_size, seq_length, hidden_dim], plus the final state
        states, hidden = self.miniGru(embedded, hidden)
        # Merge batch and time so the projection sees one row per token.
        flat_states = states.reshape(-1, states.size(2))
        logits = self.output(flat_states)
        return logits, hidden

    def init_hidden(self, batch_size):
        """Return the recurrent layer's initial state for ``batch_size`` sequences."""
        return self.miniGru.begin_state(batch_size, self.device)


# Model hyperparameters and instantiation.
vocab_size = len(vocab)
print(f"vocab_size : {vocab_size}")
hidden_dim = 256
embedding_dim = 128
num_layers = 2  # forwarded to the model constructor
# MPS cannot run some of the operations used here, so the choice is CUDA when
# available and CPU otherwise. (The former MPS assignment was immediately
# overwritten by this line and has been removed.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device : {device}")

model = GRULauangeModel(vocab_size, embedding_dim, hidden_dim, device, num_layers).to(device)

vocab_size : 11747
device : cpu


## 训练模型

In [129]:
from tqdm import tqdm
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
clip = 5  # maximum gradient norm used for gradient clipping

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    hidden = model.init_hidden(batch_size)

    # len(dataloader) is the exact integer number of batches per epoch, so the
    # bar ends at 100%. The `with` block closes the bar at the end of every
    # epoch, including when training is interrupted.
    with tqdm(total=len(dataloader)) as pbar:
        for inputs, targets in dataloader:
            # Cut the autograd graph so gradients do not flow into earlier batches.
            # NOTE(review): the dataloader is built with shuffle=True, so the
            # carried state does not come from the text preceding this batch —
            # confirm that carrying it over is intended.
            hidden = hidden.detach()
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            output, hidden = model(inputs, hidden)
            # targets.view(-1) flattens targets to batch_size * seq_length to
            # match the flattened logits.
            loss = criterion(output, targets.view(-1))
            loss.backward()

            # Gradient clipping.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            pbar.set_description(f"Loss: {loss.item()}")
            pbar.update(1)

    # Reports the loss of the last batch of the epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

# Save the model weights. The file name is kept unchanged so existing loading
# code keeps working, although the model is a minGRU rather than an LSTM.
torch.save(model.state_dict(), "lstm_model.pth")

Loss: 5.81491756439209:  11%|█         | 186/1758.140625 [00:23<03:20,  7.85it/s]  

KeyboardInterrupt: 