# LSTM Language Models

This notebook implements a very simple language model (the same basic idea that underlies ChatGPT), but using a simple LSTM.

The notebook is based on the paper *Regularizing and Optimizing LSTM Language Models*: https://arxiv.org/abs/1708.02182

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext, datasets, math
from tqdm import tqdm # progress bar

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Fixed seed plus deterministic cuDNN kernels, so repeated runs give the same results.
SEED = 1234
torch.backends.cudnn.deterministic = True
torch.manual_seed(SEED)

## 1. Load data - Wiki Text

We will use WikiText, a large corpus of text that is well suited to the language modeling task. This time, we will load it with the `datasets` library from HuggingFace.

In [None]:
# Load the WikiText-2 (raw) corpus through the HuggingFace `datasets` library.
# NOTE: the result is bound to the name `datasets`, which shadows the imported
# `datasets` module; later cells index this variable, so the name is kept.
# Importing `load_dataset` here (instead of calling `datasets.load_dataset`)
# keeps this cell re-runnable after the module name has been shadowed.
from datasets import load_dataset

datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')

In [None]:
# print a summary of the loaded dataset object
print (datasets)

In [None]:
# Access the text data inside the dataset.
# Indexing pattern: datasets[<split name>][<row index>][<feature name>]
print(datasets['train'][223]['text'])


In [None]:
print(datasets['train'].shape) # (rows, columns); original note: 36718 rows of text

## 2. Preprocessing

### Tokenization

Simply tokenize the given text to tokens.

In [None]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')


def tokenize_data(example, tokenizer):
    """Tokenize the 'text' field of one dataset row.

    Args:
        example: mapping with a 'text' entry holding the raw string.
        tokenizer: callable that turns a string into a list of tokens.

    Returns:
        A dict ``{'tokens': [...]}``; `datasets.map` expects a dict so that it
        can add the result as a new column.
    """
    return {'tokens': tokenizer(example['text'])}


# replace the raw 'text' column by a 'tokens' column in every split
tokenized_dataset = datasets.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

In [None]:
# bare expression: the notebook displays the tokenized dataset object
tokenized_dataset

In [None]:
# tokens of the same training row (223) whose raw text was printed earlier
print(tokenized_dataset['train'][223]['tokens'])

### Numericalization

We tell torchtext to add to the vocabulary only the words that occur at least three times in the training set; otherwise the vocabulary would be too big. We also make sure to add the special tokens `<unk>` (unknown word) and `<eos>` (end of sentence).

In [None]:
# Build the vocabulary from the training tokens, keeping words seen at least
# three times. Passing the special tokens via `specials` places '<unk>' at
# index 0 and '<eos>' at index 1 (same layout as inserting them afterwards) and
# does not fail if one of them already occurs in the corpus.
vocab = torchtext.vocab.build_vocab_from_iterator(
    tokenized_dataset['train']['tokens'],
    min_freq=3,
    specials=['<unk>', '<eos>'],
)
# any out-of-vocabulary token is mapped to '<unk>'
vocab.set_default_index(vocab['<unk>'])

In [None]:
# vocabulary size, special tokens included
print(len(vocab))

In [None]:
# first ten entries of the index-to-token list; '<unk>' and '<eos>' occupy indices 0 and 1
print(vocab.get_itos()[:10])

## 3. Prepare the batch loader

### Prepare data

Given the two sentences "Chaky loves eating at AIT" and "I really love deep learning", and a batch size of 3, the concatenated token stream (with `<eos>` appended to each sentence) is split into three rows of 4 tokens each:
- "Chaky loves eating at", 
- "AIT <eos> I really", 
- "love deep learning <eos>".

In [None]:
def get_data(dataset, vocab, batch_size):
    """Flatten a tokenized dataset into a [batch_size, num_batches] index tensor.

    Every non-empty example is numericalized with `vocab` and followed by the
    index of '<eos>'; all examples are concatenated into one long stream, the
    tail that does not fill a complete column is dropped, and the stream is
    laid out as `batch_size` contiguous rows.

    Args:
        dataset: iterable of mappings with a 'tokens' list (empty lists are skipped).
        vocab: object supporting ``vocab[token] -> int``.
        batch_size: number of rows of the returned tensor.

    Returns:
        torch.LongTensor of shape [batch_size, len(stream) // batch_size].
    """
    eos_index = vocab['<eos>']

    data = []
    for example in dataset:
        tokens = example['tokens']
        if tokens:
            # numericalize without mutating the caller's token list
            data.extend(vocab[token] for token in tokens)
            data.append(eos_index)

    # embedding layers expect integer (long) indices
    data = torch.LongTensor(data)

    num_batches = data.shape[0] // batch_size

    # drop the remainder so the stream divides evenly into batch_size rows
    data = data[:num_batches * batch_size]

    return data.view(batch_size, num_batches)

In [None]:
batch_size = 128

# batchify every split with the same vocabulary and batch size
train_data, valid_data, test_data = (
    get_data(tokenized_dataset[split_name], vocab, batch_size)
    for split_name in ('train', 'validation', 'test')
)

In [None]:
train_data.shape # [batch_size, num_batches]: 128 rows, each a contiguous slice of the token stream (original note: 16214 tokens per row, not verified here)

## 4. Modeling

In [None]:
class LSTMLanguageModel(nn.Module):
    """Embedding -> multi-layer LSTM -> linear projection onto the vocabulary.

    Dropout with the same rate is applied to the embeddings, between LSTM
    layers, and to the LSTM output before the final projection.
    """

    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
        """Build the layers and initialise their weights.

        Args:
            vocab_size: number of tokens (embedding rows and output logits).
            emb_dim: embedding dimension (LSTM input size).
            hid_dim: LSTM hidden size.
            num_layers: number of stacked LSTM layers.
            dropout_rate: dropout probability used throughout the model.
        """
        super().__init__()
        self.num_layers = num_layers
        self.hid_dim    = hid_dim
        self.emb_dim    = emb_dim

        self.embedding  = nn.Embedding(vocab_size, emb_dim)
        self.lstm       = nn.LSTM(emb_dim, hid_dim, num_layers = num_layers,
                                  dropout = dropout_rate, batch_first = True)
        self.dropout    = nn.Dropout(dropout_rate)
        self.fc         = nn.Linear(hid_dim, vocab_size)

        self.init_weights()

    def init_weights(self):
        """Re-initialise embedding, output layer and LSTM weight matrices in place.

        Embedding weights are drawn from U(-0.1, 0.1); the output layer and the
        LSTM input-hidden / hidden-hidden matrices from U(-1/sqrt(hid_dim),
        1/sqrt(hid_dim)); the output bias is zeroed. LSTM biases are left at
        their default initialisation.
        """
        init_range_emb   = 0.1
        init_range_other = 1 / math.sqrt(self.hid_dim)

        # In-place updates on the registered parameters. Assigning new tensors
        # into `self.lstm.all_weights[i][j]` would only modify a temporary list
        # and leave the LSTM untouched.
        with torch.no_grad():
            self.embedding.weight.uniform_(-init_range_emb, init_range_emb)
            self.fc.weight.uniform_(-init_range_other, init_range_other)
            self.fc.bias.zero_()
            for layer in range(self.num_layers):
                for prefix in ('weight_ih', 'weight_hh'):  # We, Wh
                    weight = getattr(self.lstm, f'{prefix}_l{layer}')
                    weight.uniform_(-init_range_other, init_range_other)

    def init_hidden(self, batch_size, device):
        """Return zero-filled (hidden, cell) states of shape [num_layers, batch_size, hid_dim]."""
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim, device = device)
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim, device = device)
        return hidden, cell

    def detach_hidden(self, hidden):
        """Detach the (hidden, cell) tuple from the autograd graph.

        Keeps the state values while stopping gradients from flowing back into
        earlier batches.
        """
        hidden, cell = hidden
        return hidden.detach(), cell.detach()

    def forward(self, src, hidden):
        """Compute next-token logits for every position of `src`.

        Args:
            src: token indices, [batch_size, seq_len].
            hidden: (hidden, cell) tuple as returned by `init_hidden` or a previous call.

        Returns:
            prediction: logits, [batch_size, seq_len, vocab_size].
            hidden: updated (hidden, cell) tuple.
        """
        embedding = self.dropout(self.embedding(src))
        # embedding: [batch_size, seq_len, emb_dim]

        output, hidden = self.lstm(embedding, hidden)
        # output: [batch_size, seq_len, hid_dim]
        # hidden, cell: [num_layers, batch_size, hid_dim] each

        output = self.dropout(output)

        prediction = self.fc(output)
        # prediction: [batch_size, seq_len, vocab_size]

        return prediction, hidden

## 5. Training

Note that a sequence fed to the model may span parts of several sequences from the original dataset, or be only a fragment of one (depending on the decoding length). For this reason we reset the hidden state only once per epoch and carry it over from one batch to the next; this amounts to assuming that each batch of sequences is a continuation of the previous one in the original dataset.

In [None]:
# Model and optimisation hyperparameters (paper values noted alongside).
lr           = 1e-3
dropout_rate = 0.65
num_layers   = 2     # 3 in the paper
hid_dim      = 1024  # 1150 in the paper
emb_dim      = 1024  # 400 in the paper
vocab_size   = len(vocab)

In [None]:
# Build the network on the selected device, then the optimizer and loss.
model = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# count only the parameters that receive gradients
num_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f'the model has {num_params:,} trainable parameters')

In [None]:
def get_batch (data, seq_len, idx):
    """Cut one (input, target) window out of the batched token matrix.

    `data` is indexed as [row, position]. The input covers positions
    idx .. idx+seq_len-1 of every row; the target is the same window shifted
    one position to the right, i.e. the next token for each input token.
    """
    stop = idx + seq_len
    inputs  = data[:, slice(idx, stop)]
    targets = data[:, slice(idx + 1, stop + 1)]
    return inputs, targets

In [None]:
def train (model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    """Run one training epoch over `data` and return the average loss.

    Args:
        model: language model exposing `init_hidden`, `detach_hidden` and a
            forward pass ``model(src, hidden) -> (prediction, hidden)``.
        data: token-index tensor indexed as [row, position].
        optimizer: optimizer updating `model`'s parameters.
        criterion: loss taking (2-D predictions, 1-D targets).
        batch_size: number of rows used to size the initial hidden state.
        seq_len: number of positions per training window.
        clip: maximum gradient norm.
        device: device the batches and hidden state are moved to.

    Returns:
        Sum of (batch loss * seq_len) divided by the trimmed number of positions.
    """
    epoch_loss = 0
    model.train()

    # Trim the tail so that (positions - 1) is a multiple of seq_len: every
    # window then has exactly seq_len inputs and seq_len shifted targets.
    num_batches = data.shape[-1]
    data        = data[:, :num_batches - (num_batches - 1) % seq_len]
    num_batches = data.shape[-1]

    # the hidden state is reset once per epoch and carried across windows
    hidden = model.init_hidden(batch_size, device)

    for idx in tqdm(range(0, num_batches - 1, seq_len), desc = 'Training: ', leave = False):
        optimizer.zero_grad()

        # keep the state values but cut the graph, so backprop stops at this window
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) # src, target: [rows, seq_len]
        src, target = src.to(device), target.to(device)
        prediction, hidden = model(src, hidden)

        # criterion expects 2-D predictions and 1-D targets
        prediction = prediction.reshape(-1, prediction.shape[-1])
        target     = target.reshape(-1)
        loss       = criterion(prediction, target)

        loss.backward()
        # in-place variant; `clip_grad_norm` (no underscore) is deprecated
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len

    # NOTE: num_batches - 1 positions are scored; the divisor is kept as
    # num_batches so the value stays comparable with `evaluate`.
    return epoch_loss / num_batches

In [None]:
def evaluate (model, data, criterion, batch_size, seq_len, device):
    """Compute the average loss of `model` over `data` without updating it.

    Args:
        model: language model exposing `init_hidden`, `detach_hidden` and a
            forward pass ``model(src, hidden) -> (prediction, hidden)``.
        data: token-index tensor indexed as [row, position].
        criterion: loss taking (2-D predictions, 1-D targets).
        batch_size: number of rows used to size the initial hidden state.
        seq_len: number of positions per evaluation window.
        device: device the batches and hidden state are moved to.

    Returns:
        Sum of (batch loss * seq_len) divided by the trimmed number of positions.
    """
    epoch_loss = 0
    model.eval() # disables dropout

    # Trim the tail so that (positions - 1) is a multiple of seq_len: every
    # window then has exactly seq_len inputs and seq_len shifted targets.
    num_batches = data.shape[-1]
    data        = data[:, :num_batches - (num_batches - 1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden      = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)

            prediction, hidden = model(src, hidden)
            # criterion expects 2-D predictions and 1-D targets
            prediction = prediction.reshape(-1, prediction.shape[-1])
            target     = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len

    return epoch_loss / num_batches

In [None]:
# Training schedule.
clip     = 0.25  # maximum gradient norm
seq_len  = 50    # decoding length
n_epochs = 50

# Track the best validation loss seen so far; halve the learning rate as soon
# as the validation loss fails to improve (patience = 0).
best_valid_loss = float('inf')
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion,
                       batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion,
                          batch_size, seq_len, device)

    lr_scheduler.step(valid_loss)

    # checkpoint only when the validation loss improves
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

    # perplexity = exp(average cross-entropy)
    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

## 6. Testing

In [None]:
# restore the checkpoint saved by the training loop (lowest validation loss) and score the test split
model.load_state_dict(torch.load('best-val-lstm_lm.pt', map_location = device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
# perplexity = exp(average cross-entropy)
print(f'Test perplexity: {math.exp(test_loss):.3f}')

## 7. Real-world inference

In [None]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed = None):
    
    if seed is not None:
        torch.manual_seed(seed)
        
    model.eval()
    
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)