# Nigerian Pidgin Next-Word Prediction: LSTM Model

This notebook trains an LSTM language model for next-word prediction on Nigerian Pidgin text.

**Data Sources:**
- NaijaSenti PCM dataset (Hugging Face)
- BBC Pidgin corpus (optional; the loading cell in Section 2 is commented out by default)

**Run this in Google Colab with GPU runtime for faster training.**

## 1. Setup & Dependencies

In [None]:
# Install dependencies
!pip install datasets torch torchvision torchaudio --quiet

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

from datasets import load_dataset
from collections import Counter
import numpy as np
import re
import math
from typing import List, Tuple, Dict
from tqdm.auto import tqdm

# Pick the compute device: CUDA when a GPU is visible to PyTorch, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Load Data

In [None]:
# Fetch the Nigerian Pidgin ("pcm") configuration of NaijaSenti from the Hub.
print("Loading NaijaSenti PCM dataset...")
dataset = load_dataset("mteb/NaijaSenti", "pcm")

# Pool the 'text' field of every split into one flat list; the language model
# is trained on raw text only, so every split is used as training text.
all_texts = []
for split in dataset:
    texts = [ex['text'] for ex in dataset[split]]
    all_texts += texts
    print(f"  {split}: {len(texts):,} texts")

print(f"\nTotal: {len(all_texts):,} texts")

In [None]:
# Optional: Clone and add BBC Pidgin corpus for more data
# Uncomment below to include BBC Pidgin articles

# !git clone https://github.com/keleog/bbc_pidgin_scraper.git
# import csv
# with open('bbc_pidgin_scraper/data/pidgin_corpus.csv', 'r', encoding='utf-8') as f:
#     reader = csv.DictReader(f)
#     for row in reader:
#         headline = row.get('headline', '').strip()
#         text = row.get('text', '').strip()
#         if headline and text:
#             all_texts.append(f"{headline}. {text}")
# print(f"With BBC Pidgin: {len(all_texts):,} texts")

## 3. Preprocessing

In [None]:
def clean_text(text: str) -> str:
    """Normalize raw text before tokenization.

    The text is lowercased, http(s) and www links and @mentions are removed,
    hashtags are unwrapped (the '#' is dropped, the word is kept), and every
    run of whitespace is collapsed to a single space. No other characters are
    touched, so Pidgin spellings and apostrophes survive unchanged.

    Returns:
        The cleaned string with leading and trailing whitespace trimmed.
    """
    # Applied strictly in this order: links first so their pieces are not
    # mistaken for mentions or hashtags, whitespace last to close the gaps.
    substitutions = (
        (r'https?://\S+', ''),   # http/https links
        (r'www\.\S+', ''),       # bare www links
        (r'@\w+', ''),           # @usernames
        (r'#(\w+)', r'\1'),      # hashtag -> plain word
        (r'\s+', ' '),           # whitespace runs -> one space
    )
    result = text.lower()
    for pattern, replacement in substitutions:
        result = re.sub(pattern, replacement, result)
    return result.strip()

def tokenize(text: str) -> List[str]:
    """Split text into word and punctuation tokens.

    A token is either a run of word characters and apostrophes (so
    contractions stay in one piece) or a single one of ``. , ! ? ; :``.
    Every other character acts as a separator and is discarded.
    """
    return [match.group(0) for match in re.finditer(r"[\w']+|[.,!?;:]", text)]

# Clean and tokenize every raw text, keeping only usable sentences.
print("Preprocessing...")
processed_sentences = []
for text in tqdm(all_texts):
    cleaned = clean_text(text)
    tokens = tokenize(cleaned)
    # Sentences shorter than 3 tokens give too little context to learn from.
    if len(tokens) < 3:
        continue
    processed_sentences.append(tokens)

print(f"Processed {len(processed_sentences):,} sentences")
print(f"Sample: {processed_sentences[0][:10]}")

## 4. Build Vocabulary

In [None]:
# Reserved tokens. Their order in `vocab` below fixes their indices, and
# PAD_TOKEN must stay first: index 0 is the padding value used by the
# collate function, the embedding layer and the loss.
PAD_TOKEN = '<PAD>'
UNK_TOKEN = '<UNK>'
SOS_TOKEN = '<SOS>'  # Start of sequence
EOS_TOKEN = '<EOS>'  # End of sequence

# Tally how often each token occurs across the corpus.
word_counts = Counter()
for sentence in processed_sentences:
    word_counts.update(sentence)

print(f"Total unique words: {len(word_counts):,}")
print(f"Top 20 words: {word_counts.most_common(20)}")

# Vocabulary = reserved tokens followed by every word seen at least MIN_FREQ
# times, most frequent first. Rarer words are later mapped to UNK_TOKEN.
MIN_FREQ = 2
vocab = [PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN] + [
    word for word, count in word_counts.most_common() if count >= MIN_FREQ
]

# Lookup tables in both directions.
word_to_idx = dict(zip(vocab, range(len(vocab))))
idx_to_word = {position: token for token, position in word_to_idx.items()}

VOCAB_SIZE = len(vocab)
print(f"\nVocabulary size: {VOCAB_SIZE:,}")

## 5. Create Dataset

In [None]:
class NextWordDataset(Dataset):
    """
    Next-word prediction samples built from tokenized sentences.

    Every position i >= 1 of a sentence yields one sample: the (at most
    ``seq_length``) token indices that precede position i as the input, and
    the index of the token at position i as the target. Words missing from
    ``word_to_idx`` are replaced by the UNK_TOKEN index.
    """
    def __init__(self, sentences: List[List[str]], word_to_idx: Dict, seq_length: int = 10):
        self.samples = []
        self.seq_length = seq_length
        fallback = word_to_idx[UNK_TOKEN]

        for words in sentences:
            ids = [word_to_idx.get(token, fallback) for token in words]
            # One (context, target) pair per position; the context window is
            # clipped at the start of the sentence.
            self.samples.extend(
                (ids[max(0, pos - seq_length):pos], ids[pos])
                for pos in range(1, len(ids))
            )

    def __len__(self):
        """Number of (context, target) samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return sample ``idx`` as (1-D long tensor of context, 0-D long tensor target)."""
        context, next_word = self.samples[idx]
        context_tensor = torch.tensor(context, dtype=torch.long)
        target_tensor = torch.tensor(next_word, dtype=torch.long)
        return context_tensor, target_tensor

def collate_fn(batch):
    """Merge (input_sequence, target) samples into one batch.

    Input sequences are right-padded with 0 up to the longest sequence in the
    batch; targets are stacked into a 1-D tensor.

    Returns:
        (inputs, targets) with inputs of shape (batch, max_len) and targets
        of shape (batch,).
    """
    sequences, labels = map(list, zip(*batch))
    padded = pad_sequence(sequences, batch_first=True, padding_value=0)
    return padded, torch.stack(labels)

In [None]:
# Hold out 10% of the sentences for validation. The split is done at sentence
# level (before windowing) so no sentence contributes to both sets.
from sklearn.model_selection import train_test_split

train_sentences, val_sentences = train_test_split(
    processed_sentences,
    test_size=0.1,
    random_state=42,
)

SEQ_LENGTH = 15   # maximum number of context tokens per sample
BATCH_SIZE = 128

train_dataset = NextWordDataset(train_sentences, word_to_idx, seq_length=SEQ_LENGTH)
val_dataset = NextWordDataset(val_sentences, word_to_idx, seq_length=SEQ_LENGTH)

# Only the training loader is shuffled; validation order stays fixed.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_fn,
)

print(f"Train samples: {len(train_dataset):,}")
print(f"Val samples: {len(val_dataset):,}")

## 6. LSTM Model

In [None]:
class LSTMLanguageModel(nn.Module):
    """
    LSTM-based language model for next-word prediction.

    Architecture:
    - Embedding layer (index 0 is the padding index)
    - LSTM layer(s)
    - Dropout
    - Linear output layer

    Args:
        vocab_size: Number of entries in the vocabulary (size of the output layer).
        embed_dim: Dimension of the token embeddings.
        hidden_dim: Dimension of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability, applied between LSTM layers (only when
            ``num_layers > 1``) and before the output layer.
    """
    def __init__(
        self,
        vocab_size: int,
        embed_dim: int = 256,
        hidden_dim: int = 512,
        num_layers: int = 2,
        dropout: float = 0.3
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            embed_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between layers
            dropout=dropout if num_layers > 1 else 0
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def forward(self, x):
        """Return next-word logits for each sequence in the batch.

        Args:
            x: Long tensor of token indices, shape (batch, seq_len). Shorter
                sequences must be RIGHT-padded with the padding index (0),
                which is what ``pad_sequence`` produces.

        Returns:
            Tensor of shape (batch, vocab_size) with unnormalized scores.
        """
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        lstm_out, _ = self.lstm(embedded)  # (batch, seq_len, hidden_dim)

        # Read the output at the last REAL token of each row. Taking
        # lstm_out[:, -1] would, for right-padded rows, return the state after
        # the LSTM has also consumed the padding positions, so the prediction
        # would depend on how long the other sequences in the batch are.
        pad_idx = self.embedding.padding_idx
        lengths = (x != pad_idx).sum(dim=1).clamp(min=1)  # all-pad row -> position 0
        batch_idx = torch.arange(x.size(0), device=x.device)
        last_out = lstm_out[batch_idx, lengths - 1]  # (batch, hidden_dim)

        out = self.dropout(last_out)
        logits = self.fc(out)  # (batch, vocab_size)
        return logits

    def predict_next_words(self, context: str, word_to_idx: Dict, idx_to_word: Dict, top_k: int = 5):
        """Predict the most likely next words for a context string.

        The context is cleaned and tokenized with the notebook's
        ``clean_text`` / ``tokenize`` helpers. The model is switched to eval
        mode and left in it.

        Args:
            context: Raw text typed so far.
            word_to_idx: Mapping from word to vocabulary index.
            idx_to_word: Mapping from vocabulary index to word.
            top_k: Maximum number of predictions to return.

        Returns:
            A list of at most ``top_k`` ``(word, probability)`` tuples in
            descending probability order, never containing special tokens.
            Empty when the context has no tokens or ``top_k <= 0``.
        """
        self.eval()

        # Tokenize and convert to indices
        tokens = tokenize(clean_text(context))
        unk_idx = word_to_idx[UNK_TOKEN]
        indices = [word_to_idx.get(t, unk_idx) for t in tokens]

        if not indices or top_k <= 0:
            return []

        # Create input tensor on the same device as the model
        x = torch.tensor([indices], dtype=torch.long, device=next(self.parameters()).device)

        with torch.no_grad():
            probs = torch.softmax(self(x), dim=-1)[0]

        # Special tokens are filtered out below, so fetch that many extra
        # candidates; otherwise the caller gets fewer than top_k words whenever
        # e.g. <UNK> ranks highly. Also never ask topk for more than exists.
        special_tokens = {PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN}
        num_candidates = min(top_k + len(special_tokens), probs.size(0))
        top_probs, top_indices = torch.topk(probs, num_candidates)

        predictions = []
        for prob, idx in zip(top_probs.tolist(), top_indices.tolist()):
            word = idx_to_word.get(idx, UNK_TOKEN)
            if word in special_tokens:
                continue
            predictions.append((word, prob))
            if len(predictions) == top_k:
                break

        return predictions

In [None]:
# Model and optimization settings
EMBED_DIM = 256        # token embedding size
HIDDEN_DIM = 512       # LSTM hidden state size
NUM_LAYERS = 2         # stacked LSTM layers
DROPOUT = 0.3
LEARNING_RATE = 0.001
NUM_EPOCHS = 10

# Build the network and move it to the selected device
model = LSTMLanguageModel(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    dropout=DROPOUT,
)
model = model.to(device)

# Report the number of trainable parameters and the layer layout
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model parameters: {num_params:,}")
print(model)

## 7. Training

In [None]:
# Cross-entropy over the vocabulary; targets equal to 0 (padding) are ignored.
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Halve the learning rate when the monitored loss has not improved for
# more than 2 consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
)

def train_epoch(model, loader, criterion, optimizer):
    """Run one training pass over ``loader`` and return the mean loss.

    Args:
        model: Network to train; batches are moved to the device of its parameters.
        loader: Iterable of ``(inputs, targets)`` tensor batches.
        criterion: Loss function; assumed to use mean reduction. When it has an
            ``ignore_index`` attribute, targets equal to it are not counted.
        optimizer: Optimizer updating ``model``'s parameters.

    Returns:
        The loss averaged over all counted targets of the epoch (not over
        batches, so a smaller final batch is not over-weighted).

    Raises:
        ValueError: If the loader yields no countable targets.
    """
    model.train()
    # Follow the model rather than a module-level global so the function works
    # wherever the model lives.
    model_device = next(model.parameters()).device
    ignore_index = getattr(criterion, "ignore_index", None)

    total_loss = 0.0
    total_targets = 0
    for inputs, targets in tqdm(loader, desc="Training"):
        inputs, targets = inputs.to(model_device), targets.to(model_device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()

        # Gradient clipping keeps exploding LSTM gradients in check
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # The batch loss is a mean over its counted targets; weight it by that
        # count so the epoch figure is a true per-target average.
        if ignore_index is None:
            batch_targets = targets.numel()
        else:
            batch_targets = int((targets != ignore_index).sum().item())
        if batch_targets:
            total_loss += loss.item() * batch_targets
            total_targets += batch_targets

    if total_targets == 0:
        raise ValueError("train_epoch received no training targets (empty loader?)")
    return total_loss / total_targets

def evaluate(model, loader, criterion):
    """Compute the mean loss and perplexity of ``model`` over ``loader``.

    Args:
        model: Network to evaluate; it is put in eval mode and batches are
            moved to the device of its parameters.
        loader: Iterable of ``(inputs, targets)`` tensor batches.
        criterion: Loss function; assumed to use mean reduction. When it has an
            ``ignore_index`` attribute, targets equal to it are not counted.

    Returns:
        ``(avg_loss, perplexity)`` where ``avg_loss`` is averaged over all
        counted targets (not over batches) and ``perplexity`` is
        ``exp(avg_loss)``, or ``inf`` when that overflows.

    Raises:
        ValueError: If the loader yields no countable targets.
    """
    model.eval()
    # Follow the model rather than a module-level global so the function works
    # wherever the model lives.
    model_device = next(model.parameters()).device
    ignore_index = getattr(criterion, "ignore_index", None)

    total_loss = 0.0
    total_targets = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(model_device), targets.to(model_device)
            loss = criterion(model(inputs), targets)

            # Each batch loss is a mean over its counted targets. Averaging
            # those means directly over-weights a smaller final batch and
            # skews perplexity, so weight by the number of targets instead.
            if ignore_index is None:
                batch_targets = targets.numel()
            else:
                batch_targets = int((targets != ignore_index).sum().item())
            if batch_targets:
                total_loss += loss.item() * batch_targets
                total_targets += batch_targets

    if total_targets == 0:
        raise ValueError("evaluate received no targets (empty loader?)")

    avg_loss = total_loss / total_targets
    try:
        perplexity = math.exp(avg_loss)
    except OverflowError:
        # A diverged model can push the loss past what exp() can represent
        perplexity = float('inf')
    return avg_loss, perplexity

In [None]:
# Training loop: train for NUM_EPOCHS, validating after every epoch and
# checkpointing whenever the validation loss improves.
print("Starting training...")
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    train_loss = train_epoch(model, train_loader, criterion, optimizer)
    val_loss, val_ppl = evaluate(model, val_loader, criterion)

    # ReduceLROnPlateau lowers the learning rate when val_loss stops improving
    scheduler.step(val_loss)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f} | Val Perplexity: {val_ppl:.2f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save({
            'model_state_dict': model.state_dict(),
            'word_to_idx': word_to_idx,
            'idx_to_word': idx_to_word,
            'vocab_size': VOCAB_SIZE,
            # Architecture settings: without these the state dict cannot be
            # loaded into a freshly constructed model outside this notebook.
            'embed_dim': EMBED_DIM,
            'hidden_dim': HIDDEN_DIM,
            'num_layers': NUM_LAYERS,
            'dropout': DROPOUT,
            'seq_length': SEQ_LENGTH,
        }, 'lstm_pidgin_model.pt')
        print("  Saved best model!")

print("\nTraining complete!")

## 8. Inference - Next Word Prediction

In [None]:
# Test predictions: common Pidgin openers to sanity-check the trained model
test_contexts = [
    "i dey",
    "wetin you",
    "na the",
    "how far",
    "e don",
    "you no",
    "make we",
    "dem dey",
]

print("Next-Word Predictions (LSTM):")
print("=" * 50)
for context in test_contexts:
    predictions = model.predict_next_words(context, word_to_idx, idx_to_word, top_k=5)
    pred_str = ", ".join([f"{w} ({p:.2%})" for w, p in predictions])
    # The arrow was mis-encoded (UTF-8 bytes read as cp1252) and printed as
    # three garbage characters; restored to the intended U+2192.
    print(f"'{context}' → {pred_str}")

In [None]:
# Interactive prediction
def predict_interactive():
    """Prompt for contexts in a loop and print the model's top-5 next words.

    Uses the notebook-level ``model``, ``word_to_idx`` and ``idx_to_word``.
    The loop ends when the user types 'quit' (case-insensitive, surrounding
    whitespace ignored) or when input is closed or interrupted.
    """
    while True:
        try:
            context = input("\nEnter context (or 'quit'): ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupted cell ends the session cleanly
            # instead of dumping a traceback.
            break

        context = context.strip()
        if context.lower() == 'quit':
            break
        if not context:
            # Nothing to predict from; ask again.
            continue

        predictions = model.predict_next_words(context, word_to_idx, idx_to_word, top_k=5)
        print("Predictions:")
        if not predictions:
            print("  (no predictions)")
        for word, prob in predictions:
            print(f"  {word}: {prob:.2%}")

# Uncomment to use:
# predict_interactive()

## 9. Save Model for Download

In [None]:
# Download trained model.
# google.colab only exists inside Colab; anywhere else the checkpoint is
# already on disk, so report its location instead of failing on the import.
try:
    from google.colab import files
except ImportError:
    print("google.colab is not available; the checkpoint is saved at 'lstm_pidgin_model.pt'")
else:
    files.download('lstm_pidgin_model.pt')

## Model Comparison Notes

| Model | Context Window | Expected Perplexity (not measured here) | Notes |
|-------|---------------|------------|-------|
| Trigram | 2 words | Higher | Fast, interpretable |
| LSTM | Up to 15 tokens (`SEQ_LENGTH`) | Lower | Captures longer patterns |
| Transformer | Full sequence | Lowest | Best quality, slower |

**Next Steps:**
- Try Transformer model for comparison
- Use subword tokenization (BPE) for better OOV handling
- Fine-tune on more domain-specific data