# LSTM Language Models

You guys probably very excited about ChatGPT.  In today class, we will be implementing a very simple language model, which is basically what ChatGPT is, but with a simple LSTM.  You will be surprised that it is not so difficult at all.

Paper that we base on is *Regularizing and Optimizing LSTM Language Models*, https://arxiv.org/abs/1708.02182

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchtext, datasets, math
from tqdm import tqdm


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.4.1 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/santhosh/Documents/DSAI/Semester 2/NLP/A2 Assignment/venv/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/santhosh/Documents/DSAI/Semester 2/NLP/A2 Assignment/venv/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/santhosh/Documents/DSAI/Semester 2/NLP/A2 Assignm

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [3]:
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

## 1. Load data - CNN Dailymail
The dataset used in this notebook is provided through the **Hugging Face Datasets** library and is maintained by the open-source research community.

**Dataset Repository:**  
`abisee/cnn_dailymail`

This dataset is derived from the CNN/Daily Mail news articles and is widely used for abstractive and extractive text summarization tasks in Natural Language Processing.

### Original Research

The preparation and use of this dataset for modern NLP summarization tasks were introduced in the following work:

> Abigail See, Peter J. Liu, and Christopher D. Manning (2017).  
> *Get To The Point: Summarization with Pointer-Generator Networks.*  
> Proceedings of the 55th Annual Meeting of the Association for Computational Linguistics (ACL).

The dataset is used in this project strictly for academic and research purposes.

## Dataset Description

The model is trained using the **CNN/DailyMail Dataset (Version 2.0.0)**.

### Source Material
The dataset consists of over **300,000 unique news articles** originally published by **CNN** and the **Daily Mail**, covering a wide range of topics and writing styles.

### Structure
Each dataset entry contains:
- A **full-length news article**, and  
- A set of **human-written highlights** that summarize the key points of the article.

These highlights serve as high-quality reference summaries for supervised learning.

### Version 2.0.0 Features
Version 2.0.0 of the dataset is a **non-anonymized** variant in which named entities are preserved.  
This makes it particularly suitable for training language models that aim to capture **natural human prose, entity references, and contextual coherence**.

### Application in This Project
For this assignment, the **highlights** column was renamed to **`text`** and used as the primary training corpus.  
To balance computational efficiency with linguistic diversity, the dataset was **filtered to a subset of 10,000 samples**, reducing training time while retaining stylistic and topical variation.


In [4]:
dataset = datasets.load_dataset("abisee/cnn_dailymail", "2.0.0")

# remove the 'article' and 'id' fields as we will not be using them in our task
dataset = dataset.remove_columns(['article', 'id']) 
dataset = dataset.rename_column("highlights", "text")



In [5]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['text'],
        num_rows: 11490
    })
})


In [6]:
from numpy.random import default_rng

rng = default_rng(seed=SEED)
# create a list of non-repeated indices of size 10000 and use it to select the training samples
select_idx = rng.choice(len(dataset['train']), size=10000, replace=False)
dataset['train'] = dataset['train'].filter(lambda example, idx: idx in select_idx, with_indices=True)

Filter: 100%|██████████| 287113/287113 [00:02<00:00, 115602.90 examples/s]


In [7]:
print(dataset['train'].shape)

(10000, 1)


## 2. Preprocessing

### Tokenization
Text preprocessing is performed using the **`basic_english` tokenizer** from the `torchtext` library.  
This tokenizer applies standard normalization steps such as lowercasing and basic punctuation handling, producing a consistent token representation across the corpus.


In [8]:
tokenizer = torchtext.data.utils.get_tokenizer('basic_english')

tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}

tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

Map: 100%|██████████| 10000/10000 [00:00<00:00, 10855.38 examples/s]


In [9]:
print(tokenized_dataset['train'][223]['tokens'])

['german', 'girl', ',', '17', ',', 'was', 'only', 'survivor', 'of', '1971', 'plane', 'crash', 'in', 'peruvian', 'rainforest', '.', 'juliane', 'koepcke', 'fell', 'more', 'than', '3km', 'into', 'jungle', 'attached', 'to', 'a', 'row', 'of', 'seats', '.', 'koepcke', 'suffered', 'minor', 'injuries', ',', 'survived', 'for', '10', 'days', 'alone', 'in', 'rainforest', '.', 'koepcke', 'haunted', 'by', 'ordeal', 'especially', 'when', 'confronted', 'with', 'other', 'air', 'disasters', '.']


### Vocabulary Construction
A vocabulary is constructed from the tokenized corpus with a **minimum frequency threshold of 3**, ensuring that rare tokens are excluded to reduce noise.  
Special tokens are included to handle edge cases:
- `<unk>` for unknown or out-of-vocabulary words  
- `<eos>` to mark the end of a sentence  

The resulting vocabulary contains **14,174 unique tokens**.

In [10]:
vocab = torchtext.vocab.build_vocab_from_iterator(tokenized_dataset['train']['tokens'], min_freq=3)
vocab.insert_token('<unk>', 0)
vocab.insert_token('<eos>', 1)
vocab.set_default_index(vocab['<unk>'])

In [11]:
print(len(vocab))

14174


In [12]:
print(vocab.get_itos()[:10])

['<unk>', '<eos>', '.', 'the', ',', "'", 'to', 'in', 'of', 'a']


In [13]:
# save vocab
torch.save(vocab, './model/vocab')

## 3. Prepare the batch loader

### Batching Strategy
For efficient training, the dataset is converted into a **continuous stream of tokens**, which is then segmented into batches of size **128**.  
This batching strategy is commonly used in language modeling tasks, as it preserves sequential structure while enabling parallel computation.

In [14]:
def get_data(dataset, vocab, batch_size):
    data = []
    for example in dataset:
        if example['tokens']:
            tokens = example['tokens'].append('<eos>')
            tokens = [vocab[token] for token in example['tokens']]
            data.extend(tokens)
    data = torch.LongTensor(data)
    num_batches = data.shape[0] // batch_size
    data = data[:num_batches * batch_size]
    data = data.view(batch_size, num_batches) #view vs. reshape (whether data is contiguous)
    return data #[batch size, seq len]

In [15]:
batch_size = 128
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data  = get_data(tokenized_dataset['test'],  vocab, batch_size)

In [16]:
train_data.shape

torch.Size([128, 4395])

## 4. Modeling 

In [17]:
class LSTMLanguageModel(nn.Module):
    def __init__(self, vocab_size, emb_dim, hid_dim, num_layers, dropout_rate):
        super().__init__()
        self.num_layers = num_layers
        self.hid_dim    = hid_dim
        self.emb_dim    = emb_dim
        
        self.embedding  = nn.Embedding(vocab_size, emb_dim)
        self.lstm       = nn.LSTM(emb_dim, hid_dim, num_layers=num_layers, dropout=dropout_rate, batch_first=True)
        self.dropout    = nn.Dropout(dropout_rate)
        self.fc         = nn.Linear(hid_dim, vocab_size)
        
        self.init_weights()
    
    def init_weights(self):
        init_range_emb = 0.1
        init_range_other = 1/math.sqrt(self.hid_dim)
        self.embedding.weight.data.uniform_(-init_range_emb, init_range_other)
        self.fc.weight.data.uniform_(-init_range_other, init_range_other)
        self.fc.bias.data.zero_()
        for i in range(self.num_layers):
            self.lstm.all_weights[i][0] = torch.FloatTensor(self.emb_dim,
                self.hid_dim).uniform_(-init_range_other, init_range_other) #We
            self.lstm.all_weights[i][1] = torch.FloatTensor(self.hid_dim,   
                self.hid_dim).uniform_(-init_range_other, init_range_other) #Wh
    
    def init_hidden(self, batch_size, device):
        hidden = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        cell   = torch.zeros(self.num_layers, batch_size, self.hid_dim).to(device)
        return hidden, cell
        
    def detach_hidden(self, hidden):
        hidden, cell = hidden
        hidden = hidden.detach() #not to be used for gradient computation
        cell   = cell.detach()
        return hidden, cell
        
    def forward(self, src, hidden):
        #src: [batch_size, seq len]
        embedding = self.dropout(self.embedding(src)) #harry potter is
        #embedding: [batch-size, seq len, emb dim]
        output, hidden = self.lstm(embedding, hidden)
        #ouput: [batch size, seq len, hid dim]
        #hidden: [num_layers * direction, seq len, hid_dim]
        output = self.dropout(output)
        prediction =self.fc(output)
        #prediction: [batch_size, seq_len, vocab_size]
        return prediction, hidden

## 5. Training 

Follows very basic procedure.  One note is that some of the sequences that will be fed to the model may involve parts from different sequences in the original dataset or be a subset of one (depending on the decoding length). For this reason we will reset the hidden state every epoch, this is like assuming that the next batch of sequences is probably always a follow up on the previous in the original dataset.

In [None]:
vocab_size = len(vocab)
emb_dim = 512               
hid_dim = 512              
num_layers = 3               
dropout_rate = 0.65              
lr = 1e-3                   

In [19]:
model      = LSTMLanguageModel(vocab_size, emb_dim, hid_dim, num_layers, dropout_rate).to(device)
optimizer  = optim.Adam(model.parameters(), lr=lr)
criterion  = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

The model has 20,832,094 trainable parameters


The **LSTM Language Model** consists of approximately **20.8 million trainable parameters** and is composed of the following components:

### Embedding Layer
The embedding layer maps discrete token indices into **512-dimensional continuous vectors**, allowing the model to learn dense semantic representations of words.

### LSTM Layers
The core of the model comprises **three stacked LSTM layers**, each with a hidden dimension of **512**.  
Stacking multiple LSTM layers enables the model to capture increasingly abstract and hierarchical patterns in the input sequence.

### Dropout
A **dropout rate of 0.65** is applied between layers to reduce overfitting by randomly deactivating a portion of neurons during training, thereby encouraging better generalization.

### Output Projection Layer
A fully connected **linear layer** projects the final LSTM hidden states back into the **vocabulary space**, producing logits used to predict the next token in the sequence.

In [20]:
def get_batch(data, seq_len, idx):
    #data #[batch size, bunch of tokens]
    src    = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]  #target simply is ahead of src by 1            
    return src, target

In [21]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    # data #[batch size, seq len]
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]  #we need to -1 because we start at 0
    num_batches = data.shape[-1]
    
    #reset the hidden every epoch
    hidden = model.init_hidden(batch_size, device)
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):
        optimizer.zero_grad()
        
        #hidden does not need to be in the computational graph for efficiency
        hidden = model.detach_hidden(hidden)

        src, target = get_batch(data, seq_len, idx) #src, target: [batch size, seq len]
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction, hidden = model(src, hidden)               

        #need to reshape because criterion expects pred to be 2d and target to be 1d
        prediction = prediction.reshape(batch_size * seq_len, -1)  #prediction: [batch size * seq len, vocab size]  
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

In [22]:
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = model.init_hidden(batch_size, device)

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            hidden = model.detach_hidden(hidden)
            src, target = get_batch(data, seq_len, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction, hidden = model(src, hidden)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches

## Training Configuration

### Optimizer
The model is trained using the **Adam optimizer** with a learning rate of **1e-3 (0.001)**.  
Adam is chosen for its adaptive learning rate capabilities, which help stabilize and accelerate convergence during training.

### Learning Rate Scheduler
A **ReduceLROnPlateau** scheduler is employed to automatically decrease the learning rate when the validation loss stops improving.  
This strategy allows the model to fine-tune its parameters and avoid stagnation during later stages of training.

### Loss Function
Training is optimized using **Cross-Entropy Loss**, which is well-suited for language modeling tasks where the objective is to predict the next token from a large vocabulary.

### Evaluation Metric
Model performance during training is monitored using **Perplexity**, a standard metric for language models that measures how well the model predicts a sequence of tokens.


In [23]:
n_epochs = 20
seq_len  = 50 #<----decoding length
clip    = 0.25

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

best_valid_loss = float('inf')

for epoch in range(n_epochs):
    train_loss = train(model, train_data, optimizer, criterion, 
                batch_size, seq_len, clip, device)
    valid_loss = evaluate(model, valid_data, criterion, batch_size, 
                seq_len, device)

    lr_scheduler.step(valid_loss)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'best-val-lstm_lm.pt')

    print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
    print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')

                                                         

	Train Perplexity: 1265.043
	Valid Perplexity: 764.255


                                                         

	Train Perplexity: 946.504
	Valid Perplexity: 752.419


                                                         

	Train Perplexity: 879.165
	Valid Perplexity: 632.546


                                                         

	Train Perplexity: 704.028
	Valid Perplexity: 517.577


                                                         

	Train Perplexity: 590.239
	Valid Perplexity: 433.915


                                                         

	Train Perplexity: 516.964
	Valid Perplexity: 388.976


                                                         

	Train Perplexity: 465.581
	Valid Perplexity: 354.507


                                                         

	Train Perplexity: 424.476
	Valid Perplexity: 327.593


                                                         

	Train Perplexity: 391.424
	Valid Perplexity: 307.082


                                                         

	Train Perplexity: 365.786
	Valid Perplexity: 291.278


                                                         

	Train Perplexity: 342.936
	Valid Perplexity: 278.772


                                                         

	Train Perplexity: 324.339
	Valid Perplexity: 268.201


                                                         

	Train Perplexity: 308.377
	Valid Perplexity: 258.599


                                                         

	Train Perplexity: 293.791
	Valid Perplexity: 250.468


                                                         

	Train Perplexity: 280.457
	Valid Perplexity: 244.108


                                                         

	Train Perplexity: 268.350
	Valid Perplexity: 237.294


                                                         

	Train Perplexity: 257.487
	Valid Perplexity: 232.416


                                                         

	Train Perplexity: 246.921
	Valid Perplexity: 227.253


                                                         

	Train Perplexity: 237.211
	Valid Perplexity: 222.920


                                                         

	Train Perplexity: 228.283
	Valid Perplexity: 219.622


## Training Result Summary

| Metric                  | Value                                   |
|-------------------------|-----------------------------------------|
| Device Used             | CPU                                     |
| Training Loss           | 5.4306 (from final perplexity)          |
| Training Perplexity     | 228.283                                 |
| Total Trainable Params  | 20,832,094                              |


## 6. Testing

In [26]:
model.load_state_dict(torch.load('./model/best-val-lstm_lm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

Test Perplexity: 220.573


In [None]:
# Evaluate on test set
print(f"Test Loss: {test_loss:.3f}")

Test Loss: 5.396


## Performance Results

After 20 epochs of training, the model achieved the following performance:

| Dataset Split | Perplexity |
|--------------|------------|
| Training     | 228.283    |
| Validation   | 219.622    |
| Test         | 220.573    |


## 7. Inference and Text Generation

The trained model supports **autoregressive text generation**, where each subsequent token is predicted based on the previously generated tokens and an initial input prompt.

### Temperature Scaling
A **temperature parameter** is used during inference to control the randomness of token sampling:

- **Low Temperature (e.g., 0.5):**  
  Produces more confident and predictable text by favoring high-probability tokens.

- **High Temperature (e.g., 1.0):**  
  Encourages greater diversity and creativity by allowing lower-probability tokens to be sampled more frequently.

This mechanism enables a trade-off between coherence and variability in generated text.


In [28]:
def generate(prompt, max_seq_len, temperature, model, tokenizer, vocab, device, seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    model.eval()
    tokens = tokenizer(prompt)
    indices = [vocab[t] for t in tokens]
    batch_size = 1
    hidden = model.init_hidden(batch_size, device)
    with torch.no_grad():
        for i in range(max_seq_len):
            src = torch.LongTensor([indices]).to(device)
            prediction, hidden = model(src, hidden)
            
            #prediction: [batch size, seq len, vocab size]
            #prediction[:, -1]: [batch size, vocab size] #probability of last vocab
            
            probs = torch.softmax(prediction[:, -1] / temperature, dim=-1)  
            prediction = torch.multinomial(probs, num_samples=1).item()    
            
            while prediction == vocab['<unk>']: #if it is unk, we sample again
                prediction = torch.multinomial(probs, num_samples=1).item()

            if prediction == vocab['<eos>']:    #if it is eos, we stop
                break

            indices.append(prediction) #autoregressive, thus output becomes input

    itos = vocab.get_itos()
    tokens = [itos[i] for i in indices]
    return tokens

In [32]:
prompt = 'He is a'
max_seq_len = 30
seed = 0

#smaller the temperature, more diverse tokens but comes 
#with a tradeoff of less-make-sense sentence
temperatures = [0.5, 0.7, 0.75, 0.8, 1.0]
for temperature in temperatures:
    generation = generate(prompt, max_seq_len, temperature, model, tokenizer, 
                          vocab, device, seed)
    print(str(temperature)+'\n'+' '.join(generation)+'\n')

0.5
he is a ' bad '

0.7
he is a lot of an year . he says he was not guilty to life through a road to the central day .

0.75
he is a lot of an year . he says he was not guilty to pounds to ' t die before the central press ?

0.8
he is a lot of an year . he says he was not guilty to pounds to ' t die before the central press ?

1.0
he is a boost of an high range . the new jersey has been selected by some six years .

