# **Implementation Plan: LSTM-based Next Word Prediction (Text-to-Text Generation)**

---

### 🔹 **Step 1: Load the Dataset**

**What we do**:
Load the **WikiText-2** dataset (configuration `wikitext-2-raw-v1`) from Hugging Face `datasets` or TorchText. See the [dataset description](https://www.salesforce.com/blog/the-wikitext-long-term-dependency-language-modeling-dataset/) for background.

**Dataset Description**:

* A cleaned, curated subset of **Wikipedia articles**.
* Maintains natural sentence structure and punctuation.
* Specifically designed for **language modeling** and **next-token prediction** tasks.

**Statistics**:

| Split | # Articles | # Tokens    | Vocab Size |
| ----- | ---------- | ----------- | ---------- |
| Train | \~600      | \~2 million | \~33,000   |
| Valid | \~60       | \~217,000   | \~12,000   |
| Test  | \~60       | \~245,000   | \~12,000   |

**Why this step**:
Gives us real-world, high-quality English text suitable for training a language model (a next-word predictor).

---

### 🔹 **Step 2: Text Cleaning (Optional)**

**What we do**:

* Remove empty lines, special characters, or unnecessary headers.
* Normalize case if needed (e.g., lowercase all text).

**Examples**:

| Raw Line                      | Cleaned Output             |
| ----------------------------- | -------------------------- |
| `= Valkyria Chronicles III =` | *(removed)*                |
| `The story takes place...`    | `The story takes place...` |
| ` `                           | *(removed blank line)*     |


**Why**:

* Prepares clean, meaningful input for the tokenizer.
* Ensures the model isn’t biased by formatting artifacts.
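A minimal sketch of this cleaning step follows. It assumes WikiText section headers can be recognized as lines that start and end with `=` (a heuristic, not an official rule of the dataset). The notebook's `clean_text` below only strips whitespace and drops blank lines; the `drop_headers` flag here is a hypothetical addition.

```python
def is_header(line):
    # Heuristic (assumption): WikiText headers look like "= Title =" or "= = Sub = ="
    return line.startswith("=") and line.endswith("=")


def clean_lines(raw_lines, drop_headers=True):
    """Strip whitespace, drop blank lines and, optionally, section headers."""
    cleaned = []
    for raw in raw_lines:
        line = raw.strip()
        if not line:
            continue  # blank line
        if drop_headers and is_header(line):
            continue  # section header such as "= Valkyria Chronicles III ="
        cleaned.append(line)
    return cleaned


print(clean_lines(["= Valkyria Chronicles III =", " ", " The story takes place... "]))
```

Punctuation inside normal sentences is left untouched, which matches the notebook's decision to keep it.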

---

### 🔹 **Step 3: Tokenization**

**What we do**:

* Break down each line of text into individual **tokens** (usually words or subwords).
* Use a simple whitespace tokenizer or TorchText's `basic_english` tokenizer (lowercasing + punctuation splitting).

**Examples**:

| Input Sentence            | Tokenized Output                                 |
| ------------------------- | ------------------------------------------------ |
| `The cat sat on the mat.` | `['the', 'cat', 'sat', 'on', 'the', 'mat', '.']` |
| `He didn't know.`         | `['he', 'didn', "'", 't', 'know', '.']`          |

**Why**:

* Language models work on tokens—not raw strings.
* Tokenization defines the unit of prediction (e.g., word vs subword).
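The sketch below contrasts plain whitespace splitting with a regex tokenizer that separates punctuation. The regex is the same pattern the notebook's `simple_tokenizer` uses in Step 3; the function names here are illustrative only.

```python
import re


def whitespace_tokenize(text):
    # Splits only on whitespace: punctuation stays glued to words ("mat.")
    return text.lower().split()


def regex_tokenize(text):
    # Runs of word characters, or any single character that is neither
    # a word character nor whitespace (i.e. punctuation)
    return re.findall(r"\w+|[^\w\s]", text.lower())


print(whitespace_tokenize("The cat sat on the mat."))
print(regex_tokenize("The cat sat on the mat."))
print(regex_tokenize("He didn't know."))
```

With whitespace splitting, `mat.` and `mat` would become two unrelated vocabulary entries; the regex version maps both to `mat` plus a separate `.` token.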

---

### 🔹 **Step 4: Vocabulary Creation**

**What we do**:

* Build a **vocabulary dictionary**: `token → index` and `index → token`.
* Reserve special tokens: `<pad>` (padding) and `<unk>` (out-of-vocabulary words).

**Examples**:

| Token   | Index |
| ------- | ----- |
| `<pad>` | 0     |
| `<unk>` | 1     |
| `the`   | 2     |
| `cat`   | 3     |
| `sat`   | 4     |
| `.`     | 5     |


**Why**:

* Converts words into numerical IDs (**integers**) that can be processed by neural networks.
* Allows consistent mapping from token strings to embeddings.
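A sketch of frequency-ordered vocabulary building with the reserved ids. The `min_freq` threshold is an assumption, not part of the notebook: the notebook keeps every token, so `<unk>` never appears in training and its embedding is never updated. Dropping rare tokens gives `<unk>` real training examples.

```python
from collections import Counter


def build_vocab(tokenized_lines, min_freq=1):
    """Map tokens to ids, most frequent first; ids 0 and 1 are reserved."""
    counts = Counter(tok for line in tokenized_lines for tok in line)
    vocab = {"<pad>": 0, "<unk>": 1}
    for token, freq in counts.most_common():
        if freq < min_freq:
            break  # most_common() is sorted, so every later token is rarer
        if token in vocab:
            continue  # never overwrite the reserved special tokens
        vocab[token] = len(vocab)
    inv_vocab = {idx: tok for tok, idx in vocab.items()}
    return vocab, inv_vocab


lines = [["the", "cat", "sat", "."], ["the", "cat", "."], ["the"]]
vocab, inv_vocab = build_vocab(lines)
print(vocab)
```

`Counter.most_common()` orders tokens with equal counts by first appearance, so the ids are deterministic for a given corpus.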

---

### 🔹 **Step 5: Numericalization**

**What we do**:

* Convert the tokenized corpus into a long list of integers using the vocabulary.

**Example**:

| Tokens                  | Encoded     |
| ----------------------- | ----------- |
| `['the', 'cat', 'sat']` | `[2, 3, 4]` |


**Why**:

* Neural networks only understand numbers.
* Numericalization allows us to feed sequences into embedding layers.
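The toy encoder/decoder below uses the vocabulary from the Step 4 table to show the `<unk>` fallback. It is an illustration only; the notebook's `numericalization` function applies the same `vocab.get(token, vocab["<unk>"])` lookup.

```python
# Toy vocabulary copied from the Step 4 table
vocab = {"<pad>": 0, "<unk>": 1, "the": 2, "cat": 3, "sat": 4, ".": 5}
inv_vocab = {idx: tok for tok, idx in vocab.items()}


def encode(tokens):
    # Unknown tokens fall back to the <unk> id instead of raising KeyError
    return [vocab.get(tok, vocab["<unk>"]) for tok in tokens]


def decode(ids):
    # Map ids back to token strings for display
    return [inv_vocab.get(i, "<unk>") for i in ids]


print(encode(["the", "cat", "sat"]))
print(decode(encode(["the", "dog", "sat"])))
```

The round trip is lossy for out-of-vocabulary words: `dog` comes back as `<unk>`.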

---

### 🔹 **Step 6: Create (Input, Target) Training Pairs**

**What we do**:

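* For a given sequence length `n`, slide a window over the token sequence:

  * Input = `n` consecutive tokens (e.g., with `n = 4`: `"the cat sat on"`)
  * Target = the token that immediately follows (`"the cat sat on" → "the"`)

  With `n = 3` and the sentence `the cat sat on the mat`:

  | Input (X)               | Target (y) |
  | ----------------------- | ---------- |
  | `['the', 'cat', 'sat']` | `'on'`     |
  | `['cat', 'sat', 'on']`  | `'the'`    |
  | `['sat', 'on', 'the']`  | `'mat'`    |

  Then numericalize:

  * `X = [2, 3, 4]`
  * `y = 6` (assuming `'on'` is assigned index 6, the next free index after the Step 4 table)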


**Why**:

* This forms the supervised data to train next-word prediction.
* The model learns to predict the next word given a context window.
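A compact sketch of the sliding-window construction, equivalent in logic to the notebook's `create_input_target_pairs` but applied to a single sequence. It works on token strings or ids alike.

```python
def make_pairs(seq, n):
    """Slide a window of n items over seq; the item right after the window is the target."""
    return [(seq[i:i + n], seq[i + n]) for i in range(len(seq) - n)]


tokens = ["the", "cat", "sat", "on", "the", "mat"]
pairs = make_pairs(tokens, 3)
for x, y in pairs:
    print(x, "->", y)
```

A line of `L` tokens yields `L - n` pairs, and lines with `n` or fewer tokens yield none, so very short lines contribute nothing to training.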

---

### 🔹 **Step 7: Dataset and DataLoader (Batching)**

**What we do**:

* Wrap input-output pairs in a custom `Dataset` class.
* Use PyTorch’s `DataLoader` to batch and shuffle the data.

**Why**:

* Batching lets the model process many sequences in parallel (especially on a GPU), which speeds up training.
* DataLoader automates shuffling, batching, and parallel loading.
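The notebook's `NextWordDataset` builds a tensor for every item on each access. Because every input has the same length, one alternative (not the notebook's approach) is to convert all pairs to tensors once and wrap them in PyTorch's `TensorDataset`. The toy pairs below are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy pairs in the (input_seq, target) format produced in Step 6
pairs = [([2, 3, 4], 6), ([3, 4, 6], 2), ([4, 6, 2], 7)]

# Convert once up front: one [num_pairs, seq_len] tensor and one [num_pairs] tensor
inputs = torch.tensor([x for x, _ in pairs], dtype=torch.long)
targets = torch.tensor([y for _, y in pairs], dtype=torch.long)

dataset = TensorDataset(inputs, targets)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

# Each batch is an (X, y) tuple; the last batch may be smaller than batch_size
batch_shapes = [(X.shape, y.shape) for X, y in loader]
print(batch_shapes)
```

The trade-off is that all pairs live in memory as one tensor; in exchange, batches no longer require per-item Python-to-tensor conversion.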

---

### 🔹 **Step 8: Define the LSTM Model**

**What we do**:

* Build a simple neural network:

  * Embedding layer → LSTM → Fully Connected layer → Vocabulary size logits

**Why**:

* The LSTM carries context across time steps through its hidden and cell states, so its prediction depends on word order.
* Embedding layer maps word IDs to dense vectors.
* Output layer predicts the next word by scoring each token in the vocab.
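The shape trace below mirrors `LSTMModel.forward` from Step 8 with deliberately tiny, made-up sizes so each tensor shape is easy to follow.

```python
import torch
import torch.nn as nn

# Small, made-up sizes so the shapes are easy to read
vocab_size, embedding_dim, hidden_dim = 10, 4, 8
batch_size, seq_len = 2, 5

embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
fc = nn.Linear(hidden_dim, vocab_size)

x = torch.randint(1, vocab_size, (batch_size, seq_len))  # token ids
embeds = embedding(x)                  # [batch, seq_len, embedding_dim]
lstm_out, (h_n, c_n) = lstm(embeds)    # lstm_out: [batch, seq_len, hidden_dim]
last = lstm_out[:, -1, :]              # [batch, hidden_dim]
logits = fc(last)                      # [batch, vocab_size]: one score per vocab word

print(embeds.shape, lstm_out.shape, logits.shape)
```

For a single-layer, unidirectional LSTM, `lstm_out[:, -1, :]` equals the final hidden state `h_n[-1]`, so either can feed the output layer.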

---

### 🔹 **Step 9: Train the Model**

**What we do**:

* Loop over batches, compute the loss (`nn.CrossEntropyLoss`), backpropagate, and update the weights with an optimizer (e.g., Adam).

**Why**:

* The model minimizes the cross-entropy between its predicted distribution and the actual next tokens.
* Training pushes the model to assign high probability to the correct next word.
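A worked, pure-Python version of the per-example quantity that `nn.CrossEntropyLoss` computes: the negative log of the softmax probability of the target. The logits are a made-up 3-word example.

```python
import math


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed with the log-sum-exp trick."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


logits = [2.0, 1.0, 0.1]                 # scores for a toy 3-word vocabulary
loss_correct = cross_entropy(logits, 0)  # target is the highest-scoring word
loss_wrong = cross_entropy(logits, 2)    # target is the lowest-scoring word
print(round(loss_correct, 4), round(loss_wrong, 4))
```

The loss is small (about 0.417) when the target already has the highest score and large (about 2.317) when it has the lowest, so minimizing it raises the target's score relative to the others. With the default `reduction="mean"`, PyTorch averages this value over the batch.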

---

### 🔹 **Step 10: Evaluate the Model (Optional)**

**What we do**:

* Measure **perplexity** on validation data.
* Try generating predictions for a few sample input sequences.

**Why**:

* Perplexity is a standard metric for next-word prediction.
* Sample outputs show how well the model generalizes to unseen inputs.
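Perplexity is the exponential of the mean per-token cross-entropy (natural log). The sketch below shows the usual intuition: a model that guesses uniformly over `V` words has perplexity exactly `V`, and a perfect model has perplexity 1. The `V = 33_000` value is only a round number near the training vocabulary size above.

```python
import math


def perplexity(token_nlls):
    """exp of the mean per-token negative log-likelihood (natural log)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


V = 33_000
uniform = perplexity([math.log(V)] * 100)  # uniform guessing: NLL = log(V) per token
perfect = perplexity([0.0] * 5)            # probability 1 on every target: NLL = 0
print(round(uniform), perfect)
```

The notebook averages per-batch mean losses; because the last batch can be smaller, that is a close approximation of the per-token mean rather than exactly equal to it.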

---

### 🔹 **Step 11: Save the Model (Optional)**

**What we do**:

* Save model weights and vocabulary for future use.

**Why**:

* Enables reuse for inference, fine-tuning, or switching to a Transformer later.
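Saving the vocabulary as JSON has one gotcha worth showing: JSON object keys are always strings, so an `index → token` mapping comes back with string keys. This is why the Inference cell converts keys with `int(k)` after loading.

```python
import json

inv_vocab = {0: "<pad>", 1: "<unk>", 2: "the"}

# JSON object keys are always strings, so integer keys come back as "0", "1", ...
restored = json.loads(json.dumps(inv_vocab))

# Convert the keys back to int after loading
restored_int = {int(k): v for k, v in restored.items()}
print(list(restored.keys()), restored_int == inv_vocab)
```

Since `inv_vocab` can always be rebuilt from `vocab` (`{i: t for t, i in vocab.items()}`), saving only `vocab.json` is also a reasonable choice.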

---

# **Step 0: Install Python packages**

In [None]:
!pip install -U datasets huggingface_hub fsspec

# **Step 1: Load the Dataset**

In [None]:
from datasets import load_dataset

In [None]:
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
print(dataset)


In [None]:
# Access splits
train_data = dataset["train"]
valid_data = dataset["validation"]
test_data = dataset["test"]

In [None]:
# Show sample lines
for i in range(5):
    print(train_data[i]["text"])

# **Step 2: Text Cleaning**
1. Remove leading/trailing whitespace and filter out empty lines.
2. Do not remove punctuation:
  
  a. The model needs to learn how punctuation affects meaning and sentence boundaries.

  b. Punctuation such as `.`, `,`, `?`, and `!` is part of the language structure and influences what the next word could be.

In [None]:
# Step 2: Text Cleaning (no tokenization)


def clean_text(data_split):
    """Strip leading/trailing whitespace and drop empty lines."""
    cleaned_lines = []
    for entry in data_split:
        line = entry["text"].strip()
        if line:  # skip empty lines
            cleaned_lines.append(line)
    return cleaned_lines

cleaned_lines_train = clean_text(train_data)
cleaned_lines_valid = clean_text(valid_data)
cleaned_lines_test = clean_text(test_data)
# Show first 5 cleaned lines
for i in range(5):
    print(cleaned_lines_train[i])


for i in range(5):
    print(cleaned_lines_valid[i])


for i in range(5):
    print(cleaned_lines_test[i])

# **Step 3: Tokenization**

In [None]:
import re

def simple_tokenizer(text):
    # Lowercase, split on word boundaries and keep punctuation
    return re.findall(r"\w+|[^\w\s]", text.lower(), re.UNICODE)

# Tokenize cleaned lines
tokenized_lines_train = [simple_tokenizer(line) for line in cleaned_lines_train if line.strip()]
tokenized_lines_valid = [simple_tokenizer(line) for line in cleaned_lines_valid if line.strip()]
tokenized_lines_test = [simple_tokenizer(line) for line in cleaned_lines_test if line.strip()]

# Show first 5 tokenized lines
for i in range(5):
    print(tokenized_lines_train[i])
for i in range(5):
    print(tokenized_lines_valid[i])
for i in range(5):
    print(tokenized_lines_test[i])



# **Step 4: Vocabulary Creation**

In [None]:
from collections import Counter

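# Build the vocabulary from the training split only, so that validation/test
# words never seen during training map to <unk> instead of leaking into the vocab
tokenized_lines = tokenized_lines_train

print("Training lines used for vocabulary:", len(tokenized_lines))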

# Flatten all tokens into a single list
all_tokens = [token for line in tokenized_lines for token in line]

# Count token frequencies
token_freq = Counter(all_tokens)

# Build vocabulary: start indexing from 2 (reserve 0 for <pad>, 1 for <unk>)
vocab = {token: idx + 2 for idx, (token, _) in enumerate(token_freq.most_common())}
vocab["<pad>"] = 0
vocab["<unk>"] = 1

# Create inverse vocabulary for decoding later
inv_vocab = {idx: token for token, idx in vocab.items()}

# Show a few items
print("Vocabulary size:", len(vocab))
print("Sample vocab entries:", list(vocab.items())[:10])


# **Step 5: Numericalization**

In [None]:
# Step 5: Numericalization

# Convert each token in tokenized lines to its corresponding index
def numericalization(tokenized_lines):
    numericalized_lines = []
    for line in tokenized_lines:
        # Unknown tokens fall back to the <unk> index
        encoded = [vocab.get(token, vocab["<unk>"]) for token in line]
        numericalized_lines.append(encoded)
    return numericalized_lines


numericalized_lines_train = numericalization(tokenized_lines_train)
numericalized_lines_valid = numericalization(tokenized_lines_valid)
numericalized_lines_test = numericalization(tokenized_lines_test)

# Show first 5 numericalized lines
for i in range(5):
    print(numericalized_lines_train[i])
for i in range(5):
    print(numericalized_lines_valid[i])
for i in range(5):
    print(numericalized_lines_test[i])


# **Step 6: Create (Input, Target) Training Pairs**

In [None]:
# Define context window size (sequence length)
sequence_length = 5  # You can change this

# Function to convert numericalized lines into (input, target) pairs
def create_input_target_pairs(numericalized_lines, sequence_length):
    pairs = []
    for line in numericalized_lines:
        if len(line) > sequence_length:
            for i in range(len(line) - sequence_length):
                input_seq = line[i:i + sequence_length]
                target = line[i + sequence_length]
                pairs.append((input_seq, target))
    return pairs

# Apply to each dataset split
train_pairs = create_input_target_pairs(numericalized_lines_train, sequence_length)
valid_pairs = create_input_target_pairs(numericalized_lines_valid, sequence_length)
test_pairs  = create_input_target_pairs(numericalized_lines_test, sequence_length)

# Show first 5 (input, target) pairs from train
for i in range(5):
    print("Input:", train_pairs[i][0], "→ Target:", train_pairs[i][1])


# **Step 7: Wrap into Dataset and DataLoader**

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

# Custom Dataset class
class NextWordDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        input_seq, target = self.pairs[idx]
        return torch.tensor(input_seq, dtype=torch.long), torch.tensor(target, dtype=torch.long)

# Create datasets
train_dataset = NextWordDataset(train_pairs)
valid_dataset = NextWordDataset(valid_pairs)
test_dataset  = NextWordDataset(test_pairs)

# DataLoaders
batch_size = 64

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size)
test_loader  = DataLoader(test_dataset, batch_size=batch_size)

# Check one batch
for X, y in train_loader:
    print("Input batch shape:", X.shape)
    print("Target batch shape:", y.shape)
    print("Example input:", X[0])
    print("Example target:", y[0])
    break


# **Step 8: Define the LSTM Model**

In [None]:
import torch.nn as nn

class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        # x: [batch_size, sequence_length]
        embeds = self.embedding(x)  # [batch_size, sequence_length, embedding_dim]
        lstm_out, _ = self.lstm(embeds)  # [batch_size, sequence_length, hidden_dim]
        last_output = lstm_out[:, -1, :]  # Take the last output for next-word prediction
        logits = self.fc(last_output)  # [batch_size, vocab_size]
        return logits

# Define model with hyperparameters
embedding_dim = 128
hidden_dim = 256
vocab_size = len(vocab)

model = LSTMModel(vocab_size, embedding_dim, hidden_dim)

# Print model summary
print(model)


# **Step 9: Train the Model**

In [None]:
import torch.optim as optim

# Set device (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
import math
from tqdm import tqdm
# Training loop
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch {epoch+1} [Training]")

    for inputs, targets in train_loader_tqdm:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        train_loader_tqdm.set_postfix(loss=loss.item())

    avg_train_loss = total_loss / len(train_loader)
    train_perplexity = math.exp(avg_train_loss)

    # Validation loop
    model.eval()
    val_loss = 0
    with torch.no_grad():
        val_loader_tqdm = tqdm(valid_loader, desc=f"Epoch {epoch+1} [Validation]")
        for inputs, targets in val_loader_tqdm:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()
            val_loader_tqdm.set_postfix(loss=loss.item())

    avg_val_loss = val_loss / len(valid_loader)
    val_perplexity = math.exp(avg_val_loss)

    print(f"\nEpoch {epoch+1} Summary:")
    print(f"Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.2f}")
    print(f"Val   Loss: {avg_val_loss:.4f}, Val   Perplexity: {val_perplexity:.2f}\n")


# **Step 10: Evaluate the Model**

In [None]:
import math
from tqdm import tqdm

# Step 10: Evaluate the Model on Test Set
def evaluate_model(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        test_loader_tqdm = tqdm(dataloader, desc="Testing")
        for inputs, targets in test_loader_tqdm:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()
            test_loader_tqdm.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(dataloader)
    perplexity = math.exp(avg_loss)
    return avg_loss, perplexity

# Run evaluation
test_loss, test_perplexity = evaluate_model(model, test_loader, criterion)

print(f"\nTest Loss: {test_loss:.4f}")
print(f"Test Perplexity: {test_perplexity:.2f}")


# **Step 11: Save the Model and Vocab**

In [None]:
import torch
import json
import os

# Create output directory if not exists
os.makedirs("model_output", exist_ok=True)

# Save model state
model_path = "model_output/lstm_nextword_model.pt"
torch.save(model.state_dict(), model_path)

# Save vocab as JSON
vocab_path = "model_output/vocab.json"
with open(vocab_path, "w", encoding="utf-8") as f:
    json.dump(vocab, f)

# Save inverse vocab as well (optional, for decoding)
inv_vocab_path = "model_output/inv_vocab.json"
with open(inv_vocab_path, "w", encoding="utf-8") as f:
    json.dump(inv_vocab, f)

print(f"✅ Model saved to: {model_path}")
print(f"✅ Vocab saved to: {vocab_path}")


# **Inference**

In [None]:
import torch
import torch.nn as nn
import json

# --- Load saved vocab and model ---
with open("model_output/vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)
with open("model_output/inv_vocab.json", "r", encoding="utf-8") as f:
    inv_vocab = {int(k): v for k, v in json.load(f).items()}

vocab_size = len(vocab)
embedding_dim = 128
hidden_dim = 256
sequence_length = 5

# Define the same LSTM model architecture
class LSTMModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(LSTMModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        embeds = self.embedding(x)
        lstm_out, _ = self.lstm(embeds)
        last_output = lstm_out[:, -1, :]
        logits = self.fc(last_output)
        return logits

# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LSTMModel(vocab_size, embedding_dim, hidden_dim)
model.load_state_dict(torch.load("model_output/lstm_nextword_model.pt", map_location=device))
model.to(device)
model.eval()

# --- Text generation ---
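import re

def generate_next_words(prompt, num_words):
    # Tokenize the prompt exactly as in training (Step 3), so punctuation such as
    # "mat." is split into the same tokens the model saw ("mat", ".")
    tokens = re.findall(r"\w+|[^\w\s]", prompt.lower())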
    indices = [vocab.get(tok, vocab["<unk>"]) for tok in tokens]

    for _ in range(num_words):
        # Pad/truncate to fixed sequence length
        input_seq = indices[-sequence_length:]
        if len(input_seq) < sequence_length:
            input_seq = [vocab["<pad>"]] * (sequence_length - len(input_seq)) + input_seq
        input_tensor = torch.tensor([input_seq], dtype=torch.long).to(device)

        # Predict next word
        with torch.no_grad():
            logits = model(input_tensor)
            predicted_index = torch.argmax(logits, dim=1).item()

        indices.append(predicted_index)

    # Convert all indices back to tokens
    output_tokens = [inv_vocab.get(idx, "<unk>") for idx in indices]
    return " ".join(output_tokens)




In [None]:
# --- Example ---
user_prompt = "India is a land of"
N = 10
output = generate_next_words(user_prompt, N)
print("\nGenerated text:")
print(output)
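`generate_next_words` above uses greedy `argmax`, which tends to repeat the same high-frequency phrases and can emit `<pad>` or `<unk>`. A common alternative (not part of the original notebook) is temperature-scaled sampling restricted to the top-k candidates, with special tokens banned. This pure-Python sketch works on a plain list of logits; `sample_next` is a hypothetical helper.

```python
import math
import random


def sample_next(logits, temperature=1.0, top_k=None, banned_ids=(), rng=random):
    """Pick a next-token id from raw logits (a list of floats)."""
    # Drop banned ids (e.g. <pad>, <unk>) and sort candidates by score, best first
    candidates = [(z, i) for i, z in enumerate(logits) if i not in banned_ids]
    candidates.sort(reverse=True)
    if top_k is not None:
        candidates = candidates[:top_k]
    # Temperature-scaled softmax weights over the surviving candidates
    scaled = [z / temperature for z, _ in candidates]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    ids = [i for _, i in candidates]
    return rng.choices(ids, weights=weights, k=1)[0]


logits = [5.0, 1.0, 0.5, 0.2]  # toy scores; ids 0 and 1 play the role of <pad>/<unk>
print(sample_next(logits, top_k=1, banned_ids={0, 1}))
```

To try it in `generate_next_words`, the `torch.argmax` line could be replaced by something like `sample_next(logits[0].tolist(), temperature=0.8, top_k=50, banned_ids={vocab["<pad>"], vocab["<unk>"]})`; the temperature and `top_k` values there are arbitrary starting points. Lower temperatures move the output closer to greedy decoding.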

# **Exercise 1**
Write inference code that generates complete sentences.