## CS310 Natural Language Processing
## Assignment 3. Recurrent Neural Networks for Language Modeling 

**Total points**: 

In this assignment, you will train a vanilla RNN-based language model on the Harry Potter text data. 

### 0. Import Necessary Libraries

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import re
import random
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
from utils import build_training_visualization

### 1. Build the Model

In [2]:
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

os.environ["CUDA_VISIBLE_DEVICES"] = "3"  
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

EMBEDDING_DIM = 128
HIDDEN_DIM = 256
NUM_LAYERS = 2
DROPOUT_RATE = 0.2
BATCH_SIZE = 64
SEQ_LENGTH = 50
LEARNING_RATE = 0.001
NUM_EPOCHS = 10
MIN_WORD_FREQ = 2

DATA_PATH = '/home/stu_12310401/nlp/SUSTech-NLP25/Ass3/Harry_Potter_all_books_preprocessed.txt'
MODEL_SAVE_PATH = '/home/stu_12310401/nlp/SUSTech-NLP25/Ass3/rnn_lm_model.pth'
VISUALIZATION_PATH = '/home/stu_12310401/nlp/SUSTech-NLP25/Ass3/rnn_lm_training.png'

Using device: cuda


In [3]:
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

In [None]:
def load_and_preprocess_data(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()
    
    
    tokens = word_tokenize(text)
    
    word_counts = Counter(tokens)
    vocab = ['<PAD>', '<UNK>', '<START>', '<END>'] + [word for word, count in word_counts.most_common() if count >= MIN_WORD_FREQ]
    word_to_idx = {word: idx for idx, word in enumerate(vocab)}
    idx_to_word = {idx: word for idx, word in enumerate(vocab)}
    
    text_indices = []
    for token in tokens:
        if token in word_to_idx:
            text_indices.append(word_to_idx[token])
        else:
            text_indices.append(word_to_idx['<UNK>'])
    
    return ' '.join(tokens), text_indices, word_to_idx, idx_to_word, len(vocab)


In [None]:
class TextDataset(Dataset):
    def __init__(self, text_indices, seq_length):
        self.text_indices = text_indices
        self.seq_length = seq_length
        
    def __len__(self):
        return len(self.text_indices) - self.seq_length
    
    def __getitem__(self, idx):
        sequence = self.text_indices[idx:idx+self.seq_length]
        target = self.text_indices[idx+1:idx+self.seq_length+1]
        
        return torch.tensor(sequence, dtype=torch.long), torch.tensor(target, dtype=torch.long)

In [None]:
class RNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout_rate):
        super(RNNModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout_rate if num_layers > 1 else 0)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.dropout = nn.Dropout(dropout_rate)
        
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        
    def forward(self, x, hidden=None):
        if hidden is None:
            hidden = self.init_hidden(x.size(0))
            
        embedded = self.embedding(x)
        embedded = self.dropout(embedded)
        
        output, hidden = self.rnn(embedded, hidden)
        output = self.dropout(output)
        
        output = self.fc(output)
        
        return output, hidden
    
    def init_hidden(self, batch_size):
        return torch.zeros(self.num_layers, batch_size, self.hidden_dim).to(device)

### 2. Train and Evaluate

In [None]:
def train(model, dataloader, criterion, optimizer, clip_value=5.0):
    model.train()
    total_loss = 0
    
    for inputs, targets in tqdm(dataloader, desc="Training"):
        inputs, targets = inputs.to(device), targets.to(device)
        
        hidden = model.init_hidden(inputs.size(0))
        
        optimizer.zero_grad()
        
        outputs, hidden = model(inputs, hidden)
        
        outputs = outputs.reshape(-1, outputs.size(2))
        targets = targets.reshape(-1)
        
        loss = criterion(outputs, targets)
        
        loss.backward()
        
        nn.utils.clip_grad_norm_(model.parameters(), clip_value)
        
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [None]:
def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Evaluating"):
            inputs, targets = inputs.to(device), targets.to(device)
            
            hidden = model.init_hidden(inputs.size(0))
            
            outputs, hidden = model(inputs, hidden)
            
            outputs = outputs.reshape(-1, outputs.size(2))
            targets = targets.reshape(-1)
            
            loss = criterion(outputs, targets)
            
            total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [None]:
def generate_text(model, seed_text, word_to_idx, idx_to_word, length=20, temperature=1.0):
    model.eval()
    
    seed_tokens = word_tokenize(seed_text)
    
    seed_indices = []
    for token in seed_tokens:
        if token in word_to_idx:
            seed_indices.append(word_to_idx[token])
        else:
            seed_indices.append(word_to_idx['<UNK>'])
    
    if len(seed_indices) < SEQ_LENGTH:
        seed_indices = [word_to_idx['<PAD>']] * (SEQ_LENGTH - len(seed_indices)) + seed_indices
    elif len(seed_indices) > SEQ_LENGTH:
        seed_indices = seed_indices[-SEQ_LENGTH:]
    
    current_indices = seed_indices.copy()
    generated_tokens = seed_tokens.copy()
    
    hidden = model.init_hidden(1)
    
    with torch.no_grad():
        for _ in range(length):
            x = torch.tensor([current_indices], dtype=torch.long).to(device)
            output, hidden = model(x, hidden)
            
            output = output[0, -1, :] / temperature
            probabilities = torch.softmax(output, dim=0)
            
            next_index = torch.multinomial(probabilities, 1).item()
            
            generated_tokens.append(idx_to_word[next_index])
            
            current_indices = current_indices[1:] + [next_index]
    
    return ' '.join(generated_tokens)

In [None]:
def generate_text_greedy(model, seed_text, word_to_idx, idx_to_word, length=20):
    model.eval()
    
    seed_tokens = word_tokenize(seed_text)
    
    seed_indices = []
    for token in seed_tokens:
        if token in word_to_idx:
            seed_indices.append(word_to_idx[token])
        else:
            seed_indices.append(word_to_idx['<UNK>'])
    
    if len(seed_indices) < SEQ_LENGTH:
        seed_indices = [word_to_idx['<PAD>']] * (SEQ_LENGTH - len(seed_indices)) + seed_indices
    elif len(seed_indices) > SEQ_LENGTH:
        seed_indices = seed_indices[-SEQ_LENGTH:]
    
    current_indices = seed_indices.copy()
    generated_tokens = seed_tokens.copy()
    
    hidden = model.init_hidden(1)
    
    with torch.no_grad():
        for _ in range(length):
            x = torch.tensor([current_indices], dtype=torch.long).to(device)
            output, hidden = model(x, hidden)
            
            output = output[0, -1, :]
            
            next_index = torch.argmax(output).item()
            
            generated_tokens.append(idx_to_word[next_index])
            
            current_indices = current_indices[1:] + [next_index]
    
    return ' '.join(generated_tokens)

In [10]:
def calculate_perplexity(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    total_words = 0
    
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Calculating Perplexity"):
            inputs, targets = inputs.to(device), targets.to(device)
            hidden = model.init_hidden(inputs.size(0))
            outputs, _ = model(inputs, hidden)
            
            outputs = outputs.reshape(-1, outputs.size(2))
            targets = targets.reshape(-1)
            
            loss = criterion(outputs, targets)
            total_loss += loss.item() * targets.size(0)
            total_words += targets.size(0)
    
    avg_loss = total_loss / total_words
    perplexity = torch.exp(torch.tensor(avg_loss)).item()
    return perplexity

In [None]:
print("Loading and preprocessing data...")
text, text_indices, word_to_idx, idx_to_word, vocab_size = load_and_preprocess_data(DATA_PATH)
print(f"Vocabulary size: {vocab_size}")

dataset = TextDataset(text_indices, SEQ_LENGTH)

total_size = len(dataset)
train_size = int(0.9 * total_size)
val_size = int(0.05 * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, 
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Loading and preprocessing data...
Vocabulary size: 18365


In [12]:
model = RNNModel(vocab_size, EMBEDDING_DIM, HIDDEN_DIM, NUM_LAYERS, DROPOUT_RATE).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

## Attention 
Because I actually use a .py to run the code, I will load the model directly, but you may need to skip the load cell and train the model .

由于我在服务器上训练，为了避免训练中断，我要使用nohup命令。所以实际运行时我将code汇总到一个py脚本来训练模型，并保存checkpoint，所以在notebook上直接加载了权重。

In [25]:
# 从检查点加载模型
if os.path.exists(MODEL_SAVE_PATH):
    print(f"Loading checkpoint from {MODEL_SAVE_PATH}")
    checkpoint = torch.load(MODEL_SAVE_PATH)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    train_loss = checkpoint['train_loss']
    val_loss = checkpoint['val_loss']
    model.eval()  
    print(model)
    print(f"train_loss is {train_loss}")


Loading checkpoint from /home/stu_12310401/nlp/SUSTech-NLP25/Ass3/rnn_lm_model.pth
RNNModel(
  (embedding): Embedding(18365, 128)
  (rnn): RNN(128, 256, num_layers=2, batch_first=True, dropout=0.2)
  (fc): Linear(in_features=256, out_features=18365, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)
train_loss is 4.019659431076344


  checkpoint = torch.load(MODEL_SAVE_PATH)


In [None]:
# skiped
print("Starting training...")
train_losses = []
val_losses = []

for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    
    # Train
    train_loss = train(model, train_dataloader, criterion, optimizer)
    train_losses.append(train_loss)
    
    # Evaluate
    val_loss = evaluate(model, val_dataloader, criterion)
    val_losses.append(val_loss)
    
    print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
    
    # Generate sample text
    seed_text = text[:SEQ_LENGTH]
    generated_text = generate_text(model, seed_text, word_to_idx, idx_to_word, length=30)
    print(f"Generated Text Sample:\n{generated_text}\n")
    
    # Save model checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, MODEL_SAVE_PATH)

In [None]:
# skiped
plt.figure(figsize=(10, 5))
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.savefig(VISUALIZATION_PATH)

train_metrics = {'loss': train_losses}
validation_metrics = {'loss': val_losses}
build_training_visualization('RNN Language Model', train_metrics, train_losses, validation_metrics, VISUALIZATION_PATH)


In [19]:
seed_text = "Harry Potter "
generated_text = generate_text(model, seed_text, word_to_idx, idx_to_word, length=50, temperature=0.8)
print(f"Final Generated Text Sample:\n{generated_text}")

Final Generated Text Sample:
Harry Potter would he have had to do it in in Harrys mind he felt a thrill of terror .then hatred of having the confusion has been placed upon the office itself in the castle .The last of the castle had been plastered silent and silent but by elves who was now


In [24]:
prefixes=['Harry look','Hermione open','Ron run','Magic is','Professor Dumbledore']
print("Generating sentences using greedy search:")
print("-" * 50)

for prefix in prefixes:
    generated_text = generate_text_greedy(model, prefix, word_to_idx, idx_to_word, length=15)
    print(f"Prefix: '{prefix}'")
    print(f"Generated: '{generated_text}'")
    print("-" * 50)

Generating sentences using greedy search:
--------------------------------------------------
Prefix: 'Harry look'
Generated: 'Harry look as though he had been forced to say that he was not to mention it'
--------------------------------------------------
Prefix: 'Hermione open'
Generated: 'Hermione open the door of the hall and the door opened and Harry saw the sound of'
--------------------------------------------------
Prefix: 'Ron run'
Generated: 'Ron run off the castle and the castle was completely empty and silent as they were all'
--------------------------------------------------
Prefix: 'Magic is'
Generated: 'Magic is a lot of <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>'
--------------------------------------------------
Prefix: 'Professor Dumbledore'
Generated: 'Professor Dumbledore who was unsticking his lemon <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>'
--------------------------------------------------


In [20]:
print("\nCalculating final perplexity scores...")
# train_perplexity = calculate_perplexity(model, train_dataloader, criterion)
# val_perplexity = calculate_perplexity(model, val_dataloader, criterion)
test_perplexity = calculate_perplexity(model, test_dataloader, criterion)

print(f"\nFinal Perplexity Scores:")
# print(f"Train Perplexity: {train_perplexity:.2f}")
# print(f"Validation Perplexity: {val_perplexity:.2f}")
print(f"Test Perplexity: {test_perplexity:.2f}")


Calculating final perplexity scores...


Calculating Perplexity: 100%|██████████| 864/864 [00:05<00:00, 154.17it/s]


Final Perplexity Scores:
Test Perplexity: 32.94



