# LSTM Language Models

You guys probably very excited about ChatGPT.  In today class, we will be implementing a very simple language model, which is basically what ChatGPT is, but with a simple LSTM.  You will be surprised that it is not so difficult at all.

Paper that we base on is *Regularizing and Optimizing LSTM Language Models*, https://arxiv.org/abs/1708.02182

In [34]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext, datasets, math
from tqdm import tqdm

In [35]:
# from convokit import Corpus, download
# corpus = Corpus(filename=download('diplomacy-corpus'))

In [36]:
import nltk
nltk.corpus.gutenberg.fileids()


['austen-emma.txt',
 'austen-persuasion.txt',
 'austen-sense.txt',
 'bible-kjv.txt',
 'blake-poems.txt',
 'bryant-stories.txt',
 'burgess-busterbrown.txt',
 'carroll-alice.txt',
 'chesterton-ball.txt',
 'chesterton-brown.txt',
 'chesterton-thursday.txt',
 'edgeworth-parents.txt',
 'melville-moby_dick.txt',
 'milton-paradise.txt',
 'shakespeare-caesar.txt',
 'shakespeare-hamlet.txt',
 'shakespeare-macbeth.txt',
 'whitman-leaves.txt']

In [37]:
dataset = datasets.load_dataset("text", "text", data_files={"train": "mobydick_train.txt","test":"mobydick_test.txt","validation":"mobydick_validate.txt"})

# https://www.gutenberg.org/ebooks/2701

In [38]:
corpus = nltk.corpus.gutenberg.words('shakespeare-hamlet.txt')

In [39]:
nltk.Text(nltk.corpus.gutenberg.words('shakespeare-caesar.txt'))



<Text: The Tragedie of Julius Caesar by William Shakespeare 1599>

In [40]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [41]:
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## 1. Load data - Wiki Text

We will be using wikitext which contains a large corpus of text, perfect for language modeling task.  This time, we will use the `datasets` library from HuggingFace to load.

In [42]:
# dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')

In [43]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 18183
    })
    test: Dataset({
        features: ['text'],
        num_rows: 2965
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1162
    })
})


In [44]:
dataset['train'][222]['text']

'CHAPTER 90. Heads or Tails.'

In [45]:
print(dataset['train'].shape)

(18183, 1)


## 2. Preprocessing

### Tokenizing

Simply tokenize the given text to tokens.

In [46]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}

tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

In [47]:
# print(tokenized_dataset['train'][100]['tokens'])

In [48]:
# tokenized_dataset['train']['tokens']

In [49]:
# corpus_token = tokenizer(corpus)

In [50]:
# corpus_token

### Numericalizing

We will tell torchtext to add any word that has occurred at least three times in the dataset to the vocabulary because otherwise it would be too big.  Also we shall make sure to add `unk` and `eos`.

In [51]:
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], min_freq=3)
vocab.insert_token('<unk>', 0)
vocab.insert_token('<eos>', 1)
vocab.set_default_index(vocab['<unk>'])

In [52]:
print(len(vocab))

6008


In [53]:
print(vocab.get_itos()[:30])

['<unk>', '<eos>', ',', 'the', '.', 'of', 'and', 'a', 'to', 'in', 'that', 'his', 'it', 'i', 'is', 'he', 'with', 'as', 'but', 'was', 'for', 'all', 'this', '!', 'at', '”', 'by', 'not', 'from', 'be']


## 3. Prepare the batch loader

### Prepare data

Given "Chaky loves eating at AIT", and "I really love deep learning", and given batch size = 3, we will get three batches of data "Chaky loves eating at", "AIT `<eos>` I really", "love deep learning `<eos>`".  

In [54]:
def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        if example['tokens']:
            tokens = example['tokens'].append('<eos>')
            tokens = [vocab[token] for token in example['tokens']]
            data.extend(tokens)
    data = torch.LongTensor(data)
    num_batches = data.shape[0] // batch_size
    data = data[:num_batches * batch_size]
    data = data.view(batch_size, num_batches) #view vs. reshape (whether data is contiguous)
    return data #[batch size, seq len]

In [55]:
# def get_data(token_corpus, vocab, batch_size):
#     data = []
#     tokens = token_corpus.append('<eos>')
#     tokens = [vocab[token] for token in token_corpus]
#     data.extend(tokens)
#     data = torch.LongTensor(data)
#     num_batches = data.shape[0] // batch_size
#     data = data[:num_batches * batch_size]
#     data = data.view(batch_size, num_batches) #view vs. reshape (whether data is contiguous)
#     return data #[batch size, seq len]

In [56]:
# token = get_data(corpus_token,vocab,batch_size=128)

In [57]:
# token.shape

In [58]:
batch_size = 128
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data  = get_data(tokenized_dataset['test'],  vocab, batch_size)

In [59]:
train_data.shape

torch.Size([128, 1715])

## 4. Modeling 

<img src="figures/LM.png" width=600>

In [60]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
        super().__init__()
        self.num_layers = num_layers
        self.hid_dim    = hid_dim
        self.emb_dim    = emb_dim
        
        self.embedding  = nn.Embedding(vocab_size, emb_dim)
        self.lstm       = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, dropout=dropout_rate, batch_first=True)
        self.dropout    = nn.Dropout(dropout_rate)
        self.fc         = nn.Linear(hid_dim, vocab_size)
        
        self.init_weights()
    
    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hid_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_other)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        for i in range(self.num_layers):
            self.lstm.all_weights[i][0] = torch.FloatTensor(self.emb_dim,
                self.hid_dim).uniform_(-init_range_other, init_range_other) #We
            self.lstm.all_weights[i][1] = torch.FloatTensor(self.hid_dim,   
                self.hid_dim).uniform_(-init_range_other, init_range_other) #Wh
    
    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        return hidden, cell
        
    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach() #not to be used for gradient computation
        cell   = cell.detach()
        return hidden, cell
        
    def forward(self, src, hidden):
        #src: [batch_size, seq len]
        embedding = self.dropout(self.embedding(src)) #harry potter is
        #embedding: [batch-size, seq len, emb dim]
        output, hidden = self.lstm(embedding, hidden)
        #ouput: [batch size, seq len, hid dim]
        #hidden: [num_layers * direction, seq len, hid_dim]
        output = self.dropout(output)
        prediction =self.fc(output)
        #prediction: [batch_size, seq_len, vocab_size]
        return prediction, hidden

## 5. Training 

Follows very basic procedure.  One note is that some of the sequences that will be fed to the model may involve parts from different sequences in the original dataset or be a subset of one (depending on the decoding length). For this reason we will reset the hidden state every epoch, this is like assuming that the next batch of sequences is probably always a follow up on the previous in the original dataset.

In [61]:
vocab_size = len(vocab)
emb_dim = 1024                # 400 in the paper
hid_dim = 1024                # 1150 in the paper
num_layers = 2                # 3 in the paper
dropout_rate = 0.65              
lr = 1e-3                     

In [62]:
model      = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate).to(device)
optimizer  = optim.Adam(model.parameters(), lr=lr)
criterion  = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 29,103,992 trainable parameters


In [63]:
def get_batch(data, seq_len, idx):
    #data #[batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  #target simply is ahead of src by 1            
    return src, target

In [64]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    # data #[batch size, seq len]
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]  #we need to -1 because we start at 0
    num_batches = data.shape[-1]
    
    #reset the hidden every epoch
    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        
        #hidden does not need to be in the computational graph for efficiency
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) #src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        #need to reshape because criterion expects pred to be 2d and target to be 1d
        prediction = prediction.reshape(batch_size * seq_len, -1)  #prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [65]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

Here we will be using a `ReduceLROnPlateau` learning scheduler which decreases the learning rate by a factor, if the loss don't improve by a certain epoch.

In [66]:
n_epochs = 50
seq_len  = 50 #<----decoding length
clip    = 0.25


lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model,train_data , optimizer, criterion, 
                batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                                         

	Train Perplexity: 498.610
	Valid Perplexity: 259.241


                                                         

	Train Perplexity: 340.425
	Valid Perplexity: 249.278


                                                         

	Train Perplexity: 301.527
	Valid Perplexity: 208.440


                                                         

	Train Perplexity: 247.346
	Valid Perplexity: 185.468


                                                         

	Train Perplexity: 222.987
	Valid Perplexity: 174.084


                                                         

	Train Perplexity: 205.186
	Valid Perplexity: 163.020


                                                         

	Train Perplexity: 187.880
	Valid Perplexity: 152.531


                                                         

	Train Perplexity: 168.239
	Valid Perplexity: 137.687


                                                         

	Train Perplexity: 151.036
	Valid Perplexity: 132.028


                                                         

	Train Perplexity: 140.238
	Valid Perplexity: 127.934


                                                         

	Train Perplexity: 132.082
	Valid Perplexity: 123.730


                                                         

	Train Perplexity: 124.798
	Valid Perplexity: 127.163


                                                         

	Train Perplexity: 116.591
	Valid Perplexity: 123.186


                                                         

	Train Perplexity: 111.926
	Valid Perplexity: 122.025


                                                         

	Train Perplexity: 107.766
	Valid Perplexity: 120.240


                                                         

	Train Perplexity: 104.016
	Valid Perplexity: 121.105


                                                         

	Train Perplexity: 100.615
	Valid Perplexity: 118.340


                                                         

	Train Perplexity: 98.711
	Valid Perplexity: 117.635


                                                         

	Train Perplexity: 96.869
	Valid Perplexity: 117.235


                                                         

	Train Perplexity: 95.039
	Valid Perplexity: 116.571


                                                         

	Train Perplexity: 93.344
	Valid Perplexity: 116.621


                                                         

	Train Perplexity: 91.558
	Valid Perplexity: 116.049


                                                         

	Train Perplexity: 90.544
	Valid Perplexity: 115.837


                                                         

	Train Perplexity: 89.944
	Valid Perplexity: 115.532


                                                         

	Train Perplexity: 89.161
	Valid Perplexity: 115.413


                                                         

	Train Perplexity: 88.379
	Valid Perplexity: 115.341


                                                         

	Train Perplexity: 87.709
	Valid Perplexity: 115.416


                                                         

	Train Perplexity: 86.946
	Valid Perplexity: 115.709


                                                         

	Train Perplexity: 86.614
	Valid Perplexity: 114.972


                                                         

	Train Perplexity: 86.270
	Valid Perplexity: 114.867


                                                         

	Train Perplexity: 85.784
	Valid Perplexity: 114.833


                                                         

	Train Perplexity: 85.616
	Valid Perplexity: 115.049


                                                         

	Train Perplexity: 85.621
	Valid Perplexity: 115.065


                                                         

	Train Perplexity: 85.644
	Valid Perplexity: 115.014


                                                         

	Train Perplexity: 85.544
	Valid Perplexity: 115.037


                                                         

	Train Perplexity: 85.454
	Valid Perplexity: 115.036


                                                         

	Train Perplexity: 85.507
	Valid Perplexity: 115.023


                                                         

	Train Perplexity: 85.352
	Valid Perplexity: 115.028


                                                         

	Train Perplexity: 85.423
	Valid Perplexity: 115.027


                                                         

	Train Perplexity: 85.584
	Valid Perplexity: 115.029


                                                         

	Train Perplexity: 85.408
	Valid Perplexity: 115.029


                                                         

	Train Perplexity: 85.583
	Valid Perplexity: 115.029


                                                         

	Train Perplexity: 85.476
	Valid Perplexity: 115.030


                                                         

	Train Perplexity: 85.319
	Valid Perplexity: 115.030


                                                         

	Train Perplexity: 85.484
	Valid Perplexity: 115.031


                                                         

	Train Perplexity: 85.555
	Valid Perplexity: 115.031


                                                         

	Train Perplexity: 85.517
	Valid Perplexity: 115.031


                                                         

	Train Perplexity: 85.473
	Valid Perplexity: 115.032


                                                         

	Train Perplexity: 85.303
	Valid Perplexity: 115.032


                                                          

	Train Perplexity: 85.519
	Valid Perplexity: 115.032


## 6. Testing

In [67]:
model.load_state_dict(torch.load('best-val-lstm_lm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

Test Perplexity: 119.170


## 7. Real-world inference

Here we take the prompt, tokenize, encode and feed it into the model to get the predictions.  We then apply softmax while specifying that we want the output due to the last word in the sequence which represents the prediction for the next word.  We divide the logits by a temperature value to alter the model’s confidence by adjusting the softmax probability distribution.

Once we have the Softmax distribution, we randomly sample it to make our prediction on the next word. If we get <unk> then we give that another try.  Once we get <eos> we stop predicting.
    
We decode the prediction back to strings last lines.

In [68]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)
            
            #prediction: [batch size, seq len, vocab size]
            #prediction[:, -1]: [batch size, vocab size] #probability of last vocab
            
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']: #if it is unk, we sample again
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:    #if it is eos, we stop
                break

            indices.append(prediction) #autoregressive, thus output becomes input

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

In [72]:
prompt = 'moby dick is '
max_seq_len = 30
seed = 0

#smaller the temperature, more diverse tokens but comes 
#with a tradeoff of less-make-sense sentence
temperatures = [0.5, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.5
moby dick is a way

0.7
moby dick is a very

0.75
moby dick is a cape

0.8
moby dick is crawl ,

1.0
moby dick is crawl ,

