# LSTM Language Models

In [50]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import math
import datasets

from datasets import load_dataset
from tqdm import tqdm

In [13]:
print("Torch version:", torch.__version__)
print("Torchtext version:", torchtext.__version__)

Torch version: 2.1.1+cpu
Torchtext version: 0.16.1+cpu


In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [15]:
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## 1. Load data - Alice in Wonderland

In [33]:
dataset = datasets.load_dataset('myothiha/starwars')

README.md:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


(…)Heir to the Empire (by Timothy Zahn).txt:   0%|          | 0.00/687k [00:00<?, ?B/s]

(…) Dark Force Rising (by Timothy Zahn).txt:   0%|          | 0.00/764k [00:00<?, ?B/s]

(…)- The Last Command (by Timothy Zahn).txt:   0%|          | 0.00/820k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7860 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/8101 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/9236 [00:00<?, ? examples/s]

In [34]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 7860
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 8101
    })
    test: Dataset({
        features: ['text'],
        num_rows: 9236
    })
})


In [35]:
print(dataset['train'].shape)
print(dataset['validation'].shape)
print(dataset['test'].shape)

(7860, 1)
(8101, 1)
(9236, 1)


## 2. Preprocessing

### 1) Tokenizing

In [36]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english') 
# set up a tokenizer: splitting the text into tokens by spaces and punctuation
# e.g., Harry Potter > ['Harry', 'Potter']

tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}
# example: a single data entry (e.g., a sentence) from the dataset
# tokenizer(example['text']) takes the text from the example and breaks it into tokens
# the result is stored as 'tokens' in a dictionary

tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})
# dataset.map(tokenize_data) applies the tokenize_data function to every item in the dataset, tokenizing all sentences.
# remove_columns=['text']: only need the tokenized wirds after tokenizing
# fn_kwargs={'tokenizer': tokenizer} passes the tokenizer as a parameter to the tokenize_data function.

Map:   0%|          | 0/7860 [00:00<?, ? examples/s]

Map:   0%|          | 0/8101 [00:00<?, ? examples/s]

Map:   0%|          | 0/9236 [00:00<?, ? examples/s]

In [37]:
print(tokenized_dataset['train'][123]['tokens']) # tokens of the 123rd sentence in the training set

['yes', ',', 'i', 'see', ',', 'pellaeon', 'said', ',', 'not', 'entirely', 'truthfully', '.', 'admiral', ',', 'shouldn', "'", 't', 'we', 'be—', '?']


### 2) Numericalizing

In [38]:
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], min_freq=3)
# build_vocab_from_iterator: creates a vocabulary (a list of unique words) from the tokens in the dataset.
# tokenized_dataset['train']['tokens']: uses the tokens from the training set of the tokenized dataset.
# min_freq=3: only includes words that appear at least 3 times in the dataset; otherwise ignore.


vocab.insert_token('<unk>', 0)
vocab.insert_token('<eos>', 1)
# add special tokens <unk> and <eos> for unknown and end of sentence

vocab.set_default_index(vocab['<unk>'])
# if the model capture a word not in the vocab, it will use <unk> as an unknown word.

In [39]:
print(len(vocab))

3449


In [40]:
print(vocab.get_itos()[:10])
# itos stands for Index-to-String. 
# get_itos() returns a list of tokens, mapping each index in the vocabulary to its corresponding word.

['<unk>', '<eos>', '.', ',', 'the', "'", 'to', 'a', 'of', 'and']


## 3. Prepare the batch loader

In [41]:
def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset: # loop each example in the dataset
        if example['tokens']: # check if the example contains any tokens
            tokens = example['tokens'].append('<eos>') # add <eos> at the end
            tokens = [vocab[token] for token in example['tokens']] # map each word in example['tokens'] to its index in vocab
            data.extend(tokens) # add the elements of tokens to data instead of a single list of all elements.
    data = torch.LongTensor(data) # convert the data list into a tensor
    num_batches = data.shape[0] // batch_size # no. of batches
    data = data[:num_batches * batch_size] # to confirm the no. of tokens and all batches have the same no. of tokens
    data = data.view(batch_size, num_batches) # to reshape the data tensor
    return data # [batch size, seq len] # seq len: no. of tokens in each sequence within a batch.

# This function takes the dataset, tokenizes it, converts words to indices, 
# and then organizes the tokens into batches of a specified size.

In [42]:
batch_size = 32
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data  = get_data(tokenized_dataset['test'],  vocab, batch_size)

In [46]:
print(f"Train Data Shape: {train_data.shape}")
print(f"Validation Data Shape: {valid_data.shape}")
print(f"Test Data Shape: {test_data.shape}")

Train Data Shape: torch.Size([32, 4619])
Validation Data Shape: torch.Size([32, 5157])
Test Data Shape: torch.Size([32, 5547])


## 4. Modeling

In [47]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
        super().__init__()
        self.num_layers = num_layers
        self.hid_dim    = hid_dim
        self.emb_dim    = emb_dim
        
        self.embedding  = nn.Embedding(vocab_size, emb_dim)
        self.lstm       = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, dropout=dropout_rate, batch_first=True)
        self.dropout    = nn.Dropout(dropout_rate)
        self.fc         = nn.Linear(hid_dim, vocab_size)
        
        self.init_weights()
    
    def init_weights(self): # initialize weights for all layers (embedding, LSTM, fully connected layer)
        init_range_emb = 0.1 # for embedding
        init_range_other = 1/math.sqrt(self.hid_dim) # for LSTM and fc
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_other) # randomly in uniform dist.
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_() # initialize the bias of fc to 0
        for i in range(self.num_layers): # loop each layer of the LSTM
            self.lstm.all_weights[i][0] = torch.FloatTensor(self.emb_dim,
                self.hid_dim).uniform_(-init_range_other, init_range_other) # We # the weights between the embedding and LSTM layers
            self.lstm.all_weights[i][1] = torch.FloatTensor(self.hid_dim,   
                self.hid_dim).uniform_(-init_range_other, init_range_other) # Wh # the weights for the hidden-to-hidden connections within the LSTM
    
    def init_hidden(self, batch_size, device): # will be used when the model processes new batches
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device) # zeros tensor
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device) # zeros tensor
        return hidden, cell
        
    def detach_hidden(self, hidden): # detach: the model will stop remembering the history of those states during the learning process
        hidden, cell = hidden
        hidden = hidden.detach() # not to be used for gradient computation
        cell   = cell.detach()
        return hidden, cell
        
    def forward(self, src, hidden):
        # src: [batch_size, seq len]
       
        embedding = self.dropout(self.embedding(src)) # e.g., harry potter is
        # embedding: [batch-size, seq len, emb dim]
        # dropout: to prevent overfitting

        output, hidden = self.lstm(embedding, hidden)
        # output: [batch size, seq len, hid dim]
        # hidden: [num_layers * direction, seq len, hid_dim]
       
        output = self.dropout(output) # to prevent overfitting
       
        prediction =self.fc(output)
        # prediction: [batch_size, seq_len, vocab_size]
       
        return prediction, hidden
        # return the predictions for the next word and the updated hidden state

## 5. Training

In [48]:
vocab_size = len(vocab)
emb_dim = 1024                # 400 in the paper
hid_dim = 1024                # 1150 in the paper
num_layers = 2                # 3 in the paper
dropout_rate = 0.65              
lr = 1e-3                     

In [51]:
model      = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate).to(device)
optimizer  = optim.Adam(model.parameters(), lr=lr) # Adam optimizer to adjust the model's weights during training
criterion  = nn.CrossEntropyLoss() # loss function
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) # the number of trainable parameters in the model
print(f'The model has {num_params:,} trainable parameters')

The model has 23,860,601 trainable parameters


In [52]:
def get_batch(data, seq_len, idx):
    # data # [batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  # target simply is ahead of src by 1            
    return src, target

In [53]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    # data #[batch size, seq len]
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches - 1) % seq_len]  # we need to -1 because we start at 0
    num_batches = data.shape[-1]
    
    # reset the hidden every epoch
    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        
        # hidden does not need to be in the computational graph for efficiency
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) # src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        # need to reshape because criterion expects pred to be 2d and target to be 1d
        prediction = prediction.reshape(batch_size * seq_len, -1)  # prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1) # reshape into 1d
        loss = criterion(prediction, target)
        
        loss.backward() # backpropagate
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip) # clip gradients to prevent explosion (limits gradients above a threshold)
        optimizer.step() # update the model params using the optimizer
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches # avr loss

In [54]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad(): # disable gradient calculations to save memory and speed up evaluation
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1) # reshape into 1d

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches  # avr loss

In [56]:
n_epochs = 30
seq_len  = 30 #<----decoding length
clip    = 0.25 # if gradients exceed 0.25, they will be clipped.

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)
# if the validation loss plateaus (doesn’t improve), the lr will be reduced by a factor of 0.5
# patience=0: it will reduce the lr immediately when no improvement is seen

best_valid_loss = float('inf') # intialize best val loss to inf

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                                           

	Train Perplexity: 175.991
	Valid Perplexity: 120.382


                                                           

	Train Perplexity: 127.953
	Valid Perplexity: 99.953


                                                           

	Train Perplexity: 105.477
	Valid Perplexity: 88.962


                                                           

	Train Perplexity: 92.029
	Valid Perplexity: 82.713


                                                           

	Train Perplexity: 82.017
	Valid Perplexity: 78.179


                                                           

	Train Perplexity: 74.724
	Valid Perplexity: 75.677


                                                           

	Train Perplexity: 68.778
	Valid Perplexity: 73.137


                                                           

	Train Perplexity: 63.752
	Valid Perplexity: 71.700


                                                           

	Train Perplexity: 59.157
	Valid Perplexity: 70.894


                                                           

	Train Perplexity: 55.362
	Valid Perplexity: 70.042


                                                           

	Train Perplexity: 51.672
	Valid Perplexity: 70.165


                                                           

	Train Perplexity: 47.035
	Valid Perplexity: 68.543


                                                           

	Train Perplexity: 44.624
	Valid Perplexity: 68.569


                                                           

	Train Perplexity: 42.425
	Valid Perplexity: 67.987


                                                           

	Train Perplexity: 41.371
	Valid Perplexity: 67.807


                                                           

	Train Perplexity: 40.374
	Valid Perplexity: 67.954


                                                           

	Train Perplexity: 39.161
	Valid Perplexity: 67.443


                                                           

	Train Perplexity: 38.655
	Valid Perplexity: 67.447


                                                           

	Train Perplexity: 38.152
	Valid Perplexity: 67.328


                                                           

	Train Perplexity: 37.732
	Valid Perplexity: 67.348


                                                           

	Train Perplexity: 37.519
	Valid Perplexity: 67.381


                                                           

	Train Perplexity: 37.449
	Valid Perplexity: 67.390


                                                           

	Train Perplexity: 37.325
	Valid Perplexity: 67.402


                                                           

	Train Perplexity: 37.173
	Valid Perplexity: 67.400


                                                           

	Train Perplexity: 37.163
	Valid Perplexity: 67.401


                                                           

	Train Perplexity: 37.011
	Valid Perplexity: 67.399


                                                           

	Train Perplexity: 37.238
	Valid Perplexity: 67.399


                                                           

	Train Perplexity: 37.120
	Valid Perplexity: 67.400


                                                           

	Train Perplexity: 37.121
	Valid Perplexity: 67.399


                                                           

	Train Perplexity: 37.067
	Valid Perplexity: 67.399


## 6. Testing

In [57]:
model.load_state_dict(torch.load('best-val-lstm_lm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

Test Perplexity: 64.319


## 7. Inference

In [58]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens] # convert the list of tokens into indices
    batch_size = 1 # generating text for a single prompt
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)
            
            # prediction: [batch size, seq len, vocab size]
            # prediction[:, -1]: [batch size, vocab size] # probability of last vocab
            
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']: # if it is unk, we sample again
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:    # if it is eos, we stop
                break

            indices.append(prediction) # autoregressive, thus output becomes input

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices] # convert the list of indices back into the corresponding tokens
    return tokens

In [75]:
prompt = 'yoda'
max_seq_len = 30 # output including the given prompt
seed = 0

# smaller the temperature, more diverse tokens but comes with a tradeoff of less-make-sense sentence
temperatures = [0.3, 0.5, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.3
yoda ' s voice .

0.5
yoda ' s eyes . it ' s no right , he said .

0.75
yoda ' s hand . i ' ve never met to yourself .

0.8
yoda ' s hand was close solidly .

1.0
yoda ' s quiet flight were than easing . he didn ' t look enough running enough to notice or secure kept this way . he glanced at lando . you



In [68]:
import pickle
with open('vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)