## 2. 使用LSTM进行语言建模

<img src='img/rnn-lm.jpg' width=500>

reference:
- https://pytorch.org/tutorials/beginner/transformer_tutorial.html
- https://pytorch.org/tutorials/advanced/dynamic_quantization_tutorial.html?highlight=language%20modeling
- https://github.com/pytorch/examples/tree/master/word_language_model


### 设置超参数

一般，这些超参数是通过 argparse.ArgumentParser 从命令行中获取的。

In [1]:
import argparse

hparams = argparse.Namespace(**{
    'batch_size': 16,
    'learning_rate': 5,
    'max_grad_norm': 1.,
    'bptt': 32,  # sequence_length
    'dropout': 0.2,
    'embedding_dim': 200,
    'hidden_dim': 200,
    'n_layers': 4,
    'tie_weights': True,
    'seed': 42,
    'num_train_epochs': 20,
    'lm_data_dir': 'data/PennTreebank',
    'model_save_path': 'data/save_model/lstm_lm.path',
    'temperature': 1.
})

hparams

Namespace(batch_size=16, bptt=32, dropout=0.2, embedding_dim=200, hidden_dim=200, learning_rate=5, lm_data_dir='data/PennTreebank', max_grad_norm=1.0, model_save_path='data/save_model/lstm_lm.path', n_layers=4, num_train_epochs=20, seed=42, temperature=1.0, tie_weights=True)

### 加载数据

- 加载数据
- tokenize
- 建立词表，并将词映射为id
- 将数据打包到batch中

In [2]:
from PennTreebankCorpus import Corpus


corpus = Corpus(root_dir=hparams.lm_data_dir)
corpus.load_datasets()
train_data, val_data, test_data = corpus.preprocess(hparams.batch_size)

hparams.vocab_size = corpus.vocab_size()

hparams.vocab_size

1. Tokenization
---------------------
train: tokenizing ...
val: tokenizing ...
test: tokenizing ...
Done!


2. Build Vocab
Done!


3. To ids
Done!


4. batchify
Done!




9924

In [3]:
train_data.size()
# 加载词向量

torch.Size([60659, 16])

### 建立模型

In [4]:
import torch
from torch import nn

class LSTMModel(nn.Module):
    
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams    

        self.drop = nn.Dropout(hparams.dropout)
        self.embedding = nn.Embedding(hparams.vocab_size, hparams.embedding_dim)
        self.rnn_layer = nn.LSTM(
            hparams.embedding_dim,
            hparams.hidden_dim, 
            hparams.n_layers,
            dropout=hparams.dropout)

        self.decoder = nn.Linear(hparams.hidden_dim, hparams.vocab_size)
        if hparams.tie_weights:
            assert hparams.embedding_dim == hparams.hidden_dim
            self.decoder.weight = self.embedding.weight

        # 参数初始化
        self.init_weights()

    def init_weights(self):
        init_range = 0.1
        self.embedding.weight.data.uniform_(-init_range, init_range)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-init_range, init_range)

    def forward(self, input, hidden):

        # (L, B) => (L, B, embedding_dim)
        emb = self.embedding(input)        
        emb = self.drop(emb)

        # emb (L, B, embedding_dim) => output (L, B, hidden_dim)
        output, hidden = self.rnn_layer(emb, hidden)
        output = self.drop(output)
        
        # (L, B, hidden_dim) => (L, B, vocab_size)
        # 为每个位置预测下个词
        decoded = self.decoder(output)
        return decoded, hidden

    def init_hidden(self, batch_size):
        """
        初始化 第一个隐状态和细胞状态
        """
        weight = next(self.parameters())
        return (weight.new_zeros(self.hparams.n_layers, batch_size, self.hparams.hidden_dim),
                weight.new_zeros(self.hparams.n_layers, batch_size, self.hparams.hidden_dim))


In [5]:
model = LSTMModel(hparams)
model.cuda()


LSTMModel(
  (drop): Dropout(p=0.2, inplace=False)
  (embedding): Embedding(9924, 200)
  (rnn_layer): LSTM(200, 200, num_layers=4, dropout=0.2)
  (decoder): Linear(in_features=200, out_features=9924, bias=True)
)

In [6]:
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=hparams.learning_rate, momentum=0.9)

### 训练

In [7]:
def get_batch(source, i, hparams):
    seq_len = min(hparams.bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len]
    return data, target


def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""

    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

In [8]:
from tqdm import tqdm
import math

def train(model, train_data, loss_func, optimizer, epoch_idx, hparams):
    model.train()
    hidden = model.init_hidden(hparams.batch_size)
    
    pbar = tqdm(range(0, train_data.size(0)-1, hparams.bptt))
    pbar.set_description(f'Epoch {epoch_idx}')

    
    for i in pbar:
        # left_context: (L, B)
        # target: (L, B)
        left_context, targets = get_batch(train_data, i, hparams)

        left_context = left_context.cuda()
        targets = targets.cuda()

        output, hidden = model(left_context, hidden)
        
        optimizer.zero_grad()
        loss = loss_func(output.view(-1, hparams.vocab_size), targets.view(-1))
        loss.backward()

        # 梯度裁剪，防止梯度爆炸
        torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.max_grad_norm)
        optimizer.step()

        # 混淆度：评估语言模型的一个指标
        ppl = math.exp(loss.item())
        pbar.set_postfix(loss=loss.item(), ppl=ppl)

        # 中断梯度
        hidden = repackage_hidden(hidden)

In [9]:
def evaluate(model, test_val_data, loss_func, hparams):
    model.eval()
    hidden = model.init_hidden(hparams.batch_size)
    total_loss = 0.

    hidden = model.init_hidden(hparams.batch_size)
    with torch.no_grad():
        n_steps = 0
        pbar = tqdm(range(0, test_val_data.size(0)-1, hparams.bptt))
        pbar.set_description('Valid')
        for i in pbar:
            left_context, targets = get_batch(test_val_data, i, hparams)
            
            left_context = left_context.cuda()
            targets = targets.cuda()

            output, hidden = model(left_context, hidden)
            hidden = repackage_hidden(hidden)
            loss = loss_func(output.view(-1, hparams.vocab_size), targets.view(-1))
            total_loss += loss.item()

            n_steps += 1
    avg_loss = total_loss / n_steps
    ppl = math.exp(avg_loss)
    return avg_loss, ppl


In [10]:
best_val_loss = None
learning_rate = hparams.learning_rate

for epoch_idx in range(hparams.num_train_epochs):
    train(model, train_data, loss_func, optimizer, epoch_idx+1, hparams)
    val_loss, ppl = evaluate(model, val_data, loss_func, hparams)
    print(f'\r[Validation] loss: {val_loss:.4f}, PPL: {ppl:.0f}, LR: {learning_rate}                       ')
    
    if not best_val_loss or val_loss < best_val_loss:
        torch.save(model.state_dict(), hparams.model_save_path)
        print(F'save model to {hparams.model_save_path}\n')
        best_val_loss = val_loss
    else:
        for param_group in optimizer.param_groups:
            learning_rate /= 4
            param_group['lr'] = learning_rate


Epoch 1: 100%|████████████████████████████████████████████████| 1896/1896 [00:16<00:00, 113.00it/s, loss=5.27, ppl=194]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 414.81it/s]


[Validation] loss: 5.0441, PPL: 155, LR: 5                       
save model to data/save_model/lstm_lm.path



Epoch 2: 100%|████████████████████████████████████████████████| 1896/1896 [00:16<00:00, 113.94it/s, loss=5.08, ppl=162]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 418.24it/s]


[Validation] loss: 4.8678, PPL: 130, LR: 5                       
save model to data/save_model/lstm_lm.path



Epoch 3: 100%|████████████████████████████████████████████████| 1896/1896 [00:16<00:00, 112.98it/s, loss=5.01, ppl=150]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 412.54it/s]


[Validation] loss: 4.8040, PPL: 122, LR: 5                       
save model to data/save_model/lstm_lm.path



Epoch 4: 100%|████████████████████████████████████████████████| 1896/1896 [00:17<00:00, 110.96it/s, loss=4.93, ppl=139]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 417.09it/s]


[Validation] loss: 4.7751, PPL: 119, LR: 5                       
save model to data/save_model/lstm_lm.path



Epoch 5: 100%|████████████████████████████████████████████████| 1896/1896 [00:16<00:00, 111.82it/s, loss=4.92, ppl=137]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 385.25it/s]


[Validation] loss: 4.7862, PPL: 120, LR: 5                       


Epoch 6: 100%|████████████████████████████████████████████████| 1896/1896 [00:15<00:00, 120.39it/s, loss=4.76, ppl=117]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 414.81it/s]


[Validation] loss: 4.5971, PPL: 99, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 7: 100%|████████████████████████████████████████████████| 1896/1896 [00:16<00:00, 113.28it/s, loss=4.63, ppl=103]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 413.67it/s]


[Validation] loss: 4.5595, PPL: 96, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 8: 100%|████████████████████████████████████████████████| 1896/1896 [00:16<00:00, 116.56it/s, loss=4.63, ppl=102]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 414.81it/s]


[Validation] loss: 4.5301, PPL: 93, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 9: 100%|███████████████████████████████████████████████| 1896/1896 [00:16<00:00, 114.33it/s, loss=4.53, ppl=93.1]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 383.30it/s]


[Validation] loss: 4.5083, PPL: 91, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 10: 100%|███████████████████████████████████████████████| 1896/1896 [00:15<00:00, 123.51it/s, loss=4.5, ppl=89.8]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 419.40it/s]


[Validation] loss: 4.4906, PPL: 89, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 11: 100%|██████████████████████████████████████████████| 1896/1896 [00:15<00:00, 124.20it/s, loss=4.51, ppl=90.9]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 421.74it/s]


[Validation] loss: 4.4777, PPL: 88, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 12: 100%|██████████████████████████████████████████████| 1896/1896 [00:15<00:00, 125.53it/s, loss=4.44, ppl=84.5]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 418.24it/s]


[Validation] loss: 4.4633, PPL: 87, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 13: 100%|██████████████████████████████████████████████| 1896/1896 [00:15<00:00, 123.54it/s, loss=4.41, ppl=82.3]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 417.09it/s]


[Validation] loss: 4.4512, PPL: 86, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 14: 100%|██████████████████████████████████████████████| 1896/1896 [00:15<00:00, 122.66it/s, loss=4.41, ppl=82.3]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 411.42it/s]


[Validation] loss: 4.4379, PPL: 85, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 15: 100%|████████████████████████████████████████████████| 1896/1896 [00:16<00:00, 117.07it/s, loss=4.33, ppl=76]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 414.81it/s]


[Validation] loss: 4.4300, PPL: 84, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 16: 100%|██████████████████████████████████████████████| 1896/1896 [00:16<00:00, 117.08it/s, loss=4.33, ppl=76.2]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 413.67it/s]


[Validation] loss: 4.4186, PPL: 83, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 17: 100%|██████████████████████████████████████████████| 1896/1896 [00:16<00:00, 115.27it/s, loss=4.35, ppl=77.2]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 397.39it/s]


[Validation] loss: 4.4146, PPL: 83, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 18: 100%|██████████████████████████████████████████████| 1896/1896 [00:15<00:00, 123.42it/s, loss=4.32, ppl=75.3]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 418.19it/s]


[Validation] loss: 4.4049, PPL: 82, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 19: 100%|██████████████████████████████████████████████| 1896/1896 [00:15<00:00, 124.28it/s, loss=4.39, ppl=80.5]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 421.70it/s]


[Validation] loss: 4.3901, PPL: 81, LR: 1.25                       
save model to data/save_model/lstm_lm.path



Epoch 20: 100%|██████████████████████████████████████████████| 1896/1896 [00:15<00:00, 118.71it/s, loss=4.27, ppl=71.5]
Valid: 100%|████████████████████████████████████████████████████████████████████████| 151/151 [00:00<00:00, 421.71it/s]

[Validation] loss: 4.3925, PPL: 81, LR: 1.25                       





加载模型

In [14]:
model = LSTMModel(hparams)
model.load_state_dict(torch.load(hparams.model_save_path))
model.eval()

LSTMModel(
  (drop): Dropout(p=0.2, inplace=False)
  (embedding): Embedding(9924, 200)
  (rnn_layer): LSTM(200, 200, num_layers=4, dropout=0.2)
  (decoder): Linear(in_features=200, out_features=9924, bias=True)
)

In [19]:
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)


input_sentence = ['<sos>', 'we', 'have', 'no', 'useful', 'information']

input_ = torch.tensor([corpus.vocab[word] for word in input_sentence])
input_ = input_.view(-1, 1)

hidden = model.init_hidden(1)

with torch.no_grad():
    for i in range(20):
        # (4, 1, vocab_size)
        output, hidden = model(input_, hidden)
        # (1, vocab_size)
        # 最后一个位置的预测结果
        output = output[-1]
        # (vocab_size)
        word_weights = output.squeeze().div(hparams.temperature).exp()
        # (10, )
        topk_word = word_weights.topk(k=10)[1]
        topk_word = [word_idx for word_idx in topk_word if word_idx not in (corpus.vocab['<eos>'], corpus.vocab['<unk>'])]

        print('[next_word] candidates:', end=' ')
        for word_idx in topk_word:
            word = corpus.vocab.get_itos()[word_idx]
#             word = corpus.vocab.itos[word_idx.item()]

            print(f'{word}', end=' ')
        print()

        input_word = input()
        if input_word == '.':
            break
        word_idx = corpus.vocab[input_word]
        input_sentence.append(input_word)
        print(' '.join(input_sentence))
        # word_idx = topk_word[0]
        input_= torch.tensor([word_idx]).view(1, 1)

[next_word] candidates: in and on of for to than with says 
in
<sos> we have no useful information in
[next_word] candidates: the a this recent which their our any n 
the
<sos> we have no useful information in the
[next_word] candidates: past world company u future market stock next country 
.


In [27]:
torch.manual_seed(hparams.seed)
torch.cuda.manual_seed(hparams.seed)


# input_ = torch.randint(hparams.vocab_size, (1, 1), dtype=torch.long)

input_ = torch.tensor(corpus.vocab['<sos>'])
input_ = input_.view(-1, 1)
hidden = model.init_hidden(1)

word = corpus.vocab.itos[input_.item()]
print(word, end=' ')

with torch.no_grad():
    for i in range(100):
        output, hidden = model(input_, hidden)

        word_weights = output.squeeze().div(1).exp()
        word_weights[corpus.vocab['<unk>']] = 0.
        word_idx = torch.multinomial(word_weights, 1)[0]

        input_.fill_(word_idx)
        word = corpus.vocab.itos[word_idx]
        print(word, end=' ')
        if word == '<eos>':
            print()

<sos> <eos> 
<sos> it has ruled that no magazine saying might draw new professional for any part to big insurance <eos> 
<sos> a white house spokesman called the rules of congress with appropriations safety and milk system by gte ' s market <eos> 
<sos> such losses are to likely said stop estimated the japanese sides <eos> 
<sos> the increasing start in recent months are n ' t expected to join the united aid program <eos> 
<sos> congress was n ' t likely to kill approved general government confidence in the wake of a federal bank <eos> 
<sos> the new jersey meeting 

### 拓展

1. LSTM得到的句子表示本身就可以用于下游任务
1. LSTM的并行性不够友好，难以用于大规模语料上的训练 => Transformer
2. 文本生成的解码策略：Greedy Search vs Beam Search

refernece: https://huggingface.co/blog/how-to-generate