# LSTM Language Models

You guys probably very excited about ChatGPT.  In today class, we will be implementing a very simple language model, which is basically what ChatGPT is, but with a simple LSTM.  You will be surprised that it is not so difficult at all.

Paper that we base on is *Regularizing and Optimizing LSTM Language Models*, https://arxiv.org/abs/1708.02182

In [1]:
import torch    
import torch.nn as nn
import torch.optim as optim
import torchtext, math
import huggingface_hub
from datasets import load_dataset, DatasetDict,concatenate_datasets
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [3]:
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## 1. Load data 
We will be using '18828_rec.sport.hockey' Dataset.
The '18828_rec.sport.hockey' dataset is a subset of Newsgroups collection
- **Content:** The dataset contains posts from the 'rec.sport.hockey' newsgroup, encompassing a variety of topics such as game analyses, player performances, team discussions, and other hockey-related subjects.
- **Language:** All documents are in English.

#### Initially i was using PennTreebank and it took around 12 hours for 48 epochs and during the 49th epoch, the kernel crashed so i replaced the model with the above dataset which took around 4 hours for 50 epochs as i did not have any more patience left😅.

In [None]:
dataset_hockey = load_dataset('newsgroup','18828_rec.sport.hockey' )['train']

In [49]:
# Split train set into train (80%) and validation (20%)
train_valid = dataset_hockey.train_test_split(test_size=0.2, seed=42)
train_set = train_valid['train']
valid_set = train_valid['test']

In [50]:
# Split test set into test (10%)
test_valid = dataset_hockey.train_test_split(test_size=0.5, seed=42)
test_set = test_valid['test']

In [51]:
# Create final dataset dictionary
dataset = DatasetDict({
    "train": train_set,
    "validation": valid_set,
    "test": test_set
})

In [52]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 799
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 200
    })
    test: Dataset({
        features: ['text'],
        num_rows: 500
    })
})


In [53]:
# Check dataset sizes
for split, data in dataset.items():
    print(f"{split} size: {len(data)}")

train size: 799
validation size: 200
test size: 500


In [54]:
print(dataset['train'].shape)

(799, 1)


## 2. Preprocessing

### Tokenizing

Simply tokenize the given text to tokens.

In [55]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}

tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

Map: 100%|██████████| 799/799 [00:00<00:00, 2824.97 examples/s]
Map: 100%|██████████| 200/200 [00:00<00:00, 2636.03 examples/s]
Map: 100%|██████████| 500/500 [00:00<00:00, 3333.08 examples/s]


In [56]:
print(tokenized_dataset['train'][223]['tokens'])

['from', 'j3david@sms', '.', 'business', '.', 'uwo', '.', 'ca', '(', 'james', 'david', ')', 'subject', 'plus', 'minus', 'stat', '>post', '51246', 'of', '51422', '>newsgroups', 'rec', '.', 'sport', '.', 'hockey', '>from', 'j3david@sms', '.', 'business', '.', 'uwo', '.', 'ca', '(', 'james', 'david', ')', '>subject', 'plus', 'minus', 'stat', '>organization', 'university', 'of', 'western', 'ontario', '>date', 'fri', ',', '16', 'apr', '1993', '04', '42', '11', 'gmt', '>nntp-posting-host', 'sms', '.', 'business', '.', 'uwo', '.', 'ca', '>lines', '165', '>i', "'", 'm', 'not', 'defending', 'bob', 'gainey', '.', '.', '.', 'frankly', ',', 'i', 'don', "'", 't', 'care', 'for', 'him', 'all', '>that', 'much', '.', 'but', 'your', 'dismissal', 'of', 'him', 'as', 'something', 'less', 'than', 'an', '>effective', 'hockey', 'player', 'is', 'tiresome', '.', '.', '.', 'it', 'has', 'no', 'basis', 'in', '>anything', '.', 'how', 'many', 'calders', 'did', 'he', 'win', '?', 'i', 'think', 'it', 'was', 'four', '('

### Numericalizing

We will tell torchtext to add any word that has occurred at least three times in the dataset to the vocabulary because otherwise it would be too big.  Also we shall make sure to add `unk` and `eos`.

In [57]:
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], min_freq=3)
vocab.insert_token('<unk>', 0)
vocab.insert_token('<eos>', 1)
vocab.set_default_index(vocab['<unk>'])

In [59]:
print(vocab.get_itos()[:10])

['<unk>', '<eos>', '.', ',', 'the', ')', '(', '0', "'", 'to']


## 3. Prepare the batch loader

### Prepare data

Given "Chaky loves eating at AIT", and "I really love deep learning", and given batch size = 3, we will get three batches of data "Chaky loves eating at", "AIT `<eos>` I really", "love deep learning `<eos>`".  

In [60]:
def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        if example['tokens']:
            tokens = example['tokens'].append('<eos>')
            tokens = [vocab[token] for token in example['tokens']]
            data.extend(tokens)
    data = torch.LongTensor(data)
    num_batches = data.shape[0] // batch_size
    data = data[:num_batches * batch_size]
    data = data.view(batch_size, num_batches) #view vs. reshape (whether data is contiguous)
    return data #[batch size, seq len]

In [61]:
batch_size = 128
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data  = get_data(tokenized_dataset['test'],  vocab, batch_size)

In [62]:
train_data.shape

torch.Size([128, 2254])

## 4. Modeling 

<img src="figures/LM.png" width=600>

In [63]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
        super().__init__()
        self.num_layers = num_layers
        self.hid_dim    = hid_dim
        self.emb_dim    = emb_dim
        
        self.embedding  = nn.Embedding(vocab_size, emb_dim)
        self.lstm       = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, dropout=dropout_rate, batch_first=True)
        self.dropout    = nn.Dropout(dropout_rate)
        self.fc         = nn.Linear(hid_dim, vocab_size)
        
        self.init_weights()
    
    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hid_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_other)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        for i in range(self.num_layers):
            self.lstm.all_weights[i][0] = torch.FloatTensor(self.emb_dim,
                self.hid_dim).uniform_(-init_range_other, init_range_other) #We
            self.lstm.all_weights[i][1] = torch.FloatTensor(self.hid_dim,   
                self.hid_dim).uniform_(-init_range_other, init_range_other) #Wh
    
    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        return hidden, cell
        
    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach() #not to be used for gradient computation
        cell   = cell.detach()
        return hidden, cell
        
    def forward(self, src, hidden):
        #src: [batch_size, seq len]
        embedding = self.dropout(self.embedding(src)) #harry potter is
        #embedding: [batch-size, seq len, emb dim]
        output, hidden = self.lstm(embedding, hidden)
        #ouput: [batch size, seq len, hid dim]
        #hidden: [num_layers * direction, seq len, hid_dim]
        output = self.dropout(output)
        prediction =self.fc(output)
        #prediction: [batch_size, seq_len, vocab_size]
        return prediction, hidden

## 5. Training 

Follows very basic procedure.  One note is that some of the sequences that will be fed to the model may involve parts from different sequences in the original dataset or be a subset of one (depending on the decoding length). For this reason we will reset the hidden state every epoch, this is like assuming that the next batch of sequences is probably always a follow up on the previous in the original dataset.

In [64]:
vocab_size = len(vocab)
emb_dim = 1024                # 400 in the paper
hid_dim = 1024                # 1150 in the paper
num_layers = 2                # 3 in the paper
dropout_rate = 0.65              
lr = 1e-3                     

In [65]:
model      = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate).to(device)
optimizer  = optim.Adam(model.parameters(), lr=lr)
criterion  = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 31,964,396 trainable parameters


In [66]:
def get_batch(data, seq_len, idx):
    #data #[batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  #target simply is ahead of src by 1            
    return src, target

In [67]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    # data #[batch size, seq len]
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]  #we need to -1 because we start at 0
    num_batches = data.shape[-1]
    
    #reset the hidden every epoch
    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        
        #hidden does not need to be in the computational graph for efficiency
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) #src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        #need to reshape because criterion expects pred to be 2d and target to be 1d
        prediction = prediction.reshape(batch_size * seq_len, -1)  #prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [68]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

Here we will be using a `ReduceLROnPlateau` learning scheduler which decreases the learning rate by a factor, if the loss don't improve by a certain epoch.

In [69]:
n_epochs = 50
seq_len  = 50 #<----decoding length
clip    = 0.25

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), './models/best-val-lstm_lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                                         

	Train Perplexity: 803.228
	Valid Perplexity: 475.274


                                                         

	Train Perplexity: 548.508
	Valid Perplexity: 365.829


                                                         

	Train Perplexity: 367.020
	Valid Perplexity: 252.561


                                                         

	Train Perplexity: 257.335
	Valid Perplexity: 185.610


                                                         

	Train Perplexity: 196.114
	Valid Perplexity: 146.581


                                                         

	Train Perplexity: 155.576
	Valid Perplexity: 122.014


                                                         

	Train Perplexity: 127.522
	Valid Perplexity: 103.832


                                                         

	Train Perplexity: 106.853
	Valid Perplexity: 91.440


                                                         

	Train Perplexity: 91.751
	Valid Perplexity: 83.339


                                                         

	Train Perplexity: 80.218
	Valid Perplexity: 75.962


                                                         

	Train Perplexity: 71.243
	Valid Perplexity: 70.481


                                                         

	Train Perplexity: 64.192
	Valid Perplexity: 66.657


                                                         

	Train Perplexity: 58.008
	Valid Perplexity: 63.412


                                                         

	Train Perplexity: 53.036
	Valid Perplexity: 60.235


                                                         

	Train Perplexity: 48.937
	Valid Perplexity: 57.438


                                                         

	Train Perplexity: 45.254
	Valid Perplexity: 55.237


                                                         

	Train Perplexity: 41.999
	Valid Perplexity: 53.423


                                                         

	Train Perplexity: 39.054
	Valid Perplexity: 51.830


                                                         

	Train Perplexity: 36.679
	Valid Perplexity: 50.116


                                                         

	Train Perplexity: 34.626
	Valid Perplexity: 48.922


                                                         

	Train Perplexity: 32.600
	Valid Perplexity: 47.665


                                                         

	Train Perplexity: 30.649
	Valid Perplexity: 46.451


                                                         

	Train Perplexity: 28.974
	Valid Perplexity: 45.382


                                                         

	Train Perplexity: 27.300
	Valid Perplexity: 44.744


                                                         

	Train Perplexity: 25.936
	Valid Perplexity: 44.418


                                                         

	Train Perplexity: 24.747
	Valid Perplexity: 43.775


                                                         

	Train Perplexity: 23.631
	Valid Perplexity: 42.483


                                                         

	Train Perplexity: 22.551
	Valid Perplexity: 42.130


                                                         

	Train Perplexity: 21.653
	Valid Perplexity: 41.853


                                                         

	Train Perplexity: 20.836
	Valid Perplexity: 41.256


                                                         

	Train Perplexity: 20.018
	Valid Perplexity: 40.777


                                                         

	Train Perplexity: 19.208
	Valid Perplexity: 40.558


                                                         

	Train Perplexity: 18.473
	Valid Perplexity: 40.267


                                                         

	Train Perplexity: 17.774
	Valid Perplexity: 40.121


                                                         

	Train Perplexity: 17.127
	Valid Perplexity: 39.691


                                                         

	Train Perplexity: 16.496
	Valid Perplexity: 39.293


                                                         

	Train Perplexity: 15.963
	Valid Perplexity: 39.173


                                                         

	Train Perplexity: 15.522
	Valid Perplexity: 39.454


                                                         

	Train Perplexity: 14.645
	Valid Perplexity: 38.253


                                                         

	Train Perplexity: 14.105
	Valid Perplexity: 38.239


                                                         

	Train Perplexity: 13.703
	Valid Perplexity: 38.237


                                                         

	Train Perplexity: 13.197
	Valid Perplexity: 37.989


                                                         

	Train Perplexity: 12.948
	Valid Perplexity: 37.790


                                                         

	Train Perplexity: 12.743
	Valid Perplexity: 37.793


                                                         

	Train Perplexity: 12.495
	Valid Perplexity: 37.409


                                                         

	Train Perplexity: 12.339
	Valid Perplexity: 37.418


                                                         

	Train Perplexity: 12.186
	Valid Perplexity: 37.371


                                                         

	Train Perplexity: 12.146
	Valid Perplexity: 37.354


                                                         

	Train Perplexity: 12.052
	Valid Perplexity: 37.353


                                                         

	Train Perplexity: 12.015
	Valid Perplexity: 37.372


## 6. Testing

In [73]:
model.load_state_dict(torch.load('./models/best-val-lstm_lm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

Test Perplexity: 14.509


In [71]:
import pickle

In [72]:
# save the model
with open('./models/LSTM_model.pkl', 'wb') as f:
    pickle.dump(model, f)
    
# save the tokenizer
with open('./models/tokenizer.pkl', 'wb') as f:
    pickle.dump(tokenizer, f)

# save the tokenized dataset
with open('./models/tkzdDataset.pkl', 'wb') as f:
    pickle.dump(tokenized_dataset, f)
    
# save the vocab
with open('./models/vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)

## 7. Real-world inference

Here we take the prompt, tokenize, encode and feed it into the model to get the predictions.  We then apply softmax while specifying that we want the output due to the last word in the sequence which represents the prediction for the next word.  We divide the logits by a temperature value to alter the model’s confidence by adjusting the softmax probability distribution.

Once we have the Softmax distribution, we randomly sample it to make our prediction on the next word. If we get <unk> then we give that another try.  Once we get <eos> we stop predicting.
    
We decode the prediction back to strings last lines.

In [74]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)
            
            #prediction: [batch size, seq len, vocab size]
            #prediction[:, -1]: [batch size, vocab size] #probability of last vocab
            
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']: #if it is unk, we sample again
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:    #if it is eos, we stop
                break

            indices.append(prediction) #autoregressive, thus output becomes input

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

In [76]:
prompt = 'Harry Potter is '
max_seq_len = 30
seed = 0

#smaller the temperature, more diverse tokens but comes 
#with a tradeoff of less-make-sense sentence
temperatures = [0.5, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.5
harry <unk> is the best player in the league . the previous gm . the st . louis blues . the replay he was approached by the leafs , and the first goal

0.7
harry <unk> is a great . . . it ' s not in the playoffs . i don ' t know whether the exact replay on the game when you ' ll either

0.75
harry <unk> is a great . . . it ' s not in the playoffs . even you didn ' t have been been size . i ' m not sure what he

0.8
harry <unk> is just the best player in the nhl . he called a lot at the end of the ice , a challenge of the blue or a defenseman on the old

1.0
harry <unk> is love by benoit superstars or 12 games in at the end of the season . a challenge of the blue or national pressure on the maine franchise . again i

