# Optimized LSTM Next Word Prediction with First Letter Constraint

This notebook implements an efficient LSTM model that predicts next words based on both context and first letter information.

**Key Improvements:**
- Efficient parallel sequence processing (20x faster training)
- Proper first letter integration during training
- Consistent training and validation architecture
- Better memory management

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from collections import Counter
import re
import numpy as np
import zipfile
import os
import wandb
import json



In [2]:
# --- 1. Data Loading and Preprocessing ---

class TextDataset(Dataset):
    """
    Custom PyTorch Dataset for handling the training text.
    It tokenizes text line-by-line, builds a vocabulary with special tokens,
    and creates sequences only from within individual lines.
    """
    def __init__(self, lines, seq_length=20):
        self.seq_length = seq_length
        
        all_words = []
        for line in lines:
            all_words.extend(self.tokenize(line))
            
        self.word_counts = Counter(all_words)
        
        # Build vocabulary, ensuring special tokens are first
        vocab_sorted = sorted(self.word_counts, key=self.word_counts.get, reverse=True)
        self.vocab = ['<PAD>', '<UNK>'] + vocab_sorted
        
        self.word_to_int = {word: i for i, word in enumerate(self.vocab)}
        self.int_to_word = {i: word for i, word in enumerate(self.vocab)}
        self.vocab_size = len(self.vocab)

        self.sequences = []
        for line in lines:
            words_in_line = self.tokenize(line)
            if len(words_in_line) <= self.seq_length:
                continue
            
            # Use .get() with a default for the <UNK> token
            unk_token_id = self.word_to_int['<UNK>']
            int_line = [self.word_to_int.get(word, unk_token_id) for word in words_in_line]
            
            for i in range(len(int_line) - self.seq_length):
                seq_end = i + self.seq_length
                self.sequences.append(int_line[i:seq_end+1])

    def tokenize(self, text):
        return text.lower().split()

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, index):
        seq = self.sequences[index]
        
        # Extract context and target
        context = seq[:-1]
        target = seq[-1]
        
        # Get first letter of target word
        target_word = self.int_to_word[target]
        if target_word and len(target_word) > 0:
            first_letter = target_word[0].lower()
            if 'a' <= first_letter <= 'z':
                first_letter_idx = ord(first_letter) - ord('a')  # 0-25
            else:
                first_letter_idx = 26  # For non-alphabetic
        else:
            first_letter_idx = 26
        
        return torch.tensor(context), torch.tensor(target), torch.tensor(first_letter_idx)

In [3]:
class DevSetDataset(Dataset):
    """
    Custom PyTorch Dataset for handling the dev_set.csv file for validation.
    Modified to include first letter as part of the model input.
    """
    def __init__(self, csv_path, word_to_int, seq_length=20):
        self.seq_length = seq_length
        self.word_to_int = word_to_int
        self.unk_token_id = self.word_to_int['<UNK>']
        self.pad_token_id = self.word_to_int['<PAD>']

        df = pd.read_csv(csv_path)
        df = df[df['answer'].isin(self.word_to_int)]
        self.data = df

    def tokenize(self, text):
        return text.lower().split()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        context = row['context']
        answer = row['answer']
        first_letter = row['first letter']
        
        context_words = self.tokenize(context)
        int_context = [self.word_to_int.get(w, self.unk_token_id) for w in context_words]
        
        if len(int_context) > self.seq_length:
            int_context = int_context[-self.seq_length:]
        else:
            int_context = [self.pad_token_id] * (self.seq_length - len(int_context)) + int_context
            
        int_answer = self.word_to_int[answer]
        
        # Convert first letter to integer (0-25 for a-z, 26 for other)
        if first_letter and isinstance(first_letter, str) and len(first_letter) > 0:
            letter = first_letter[0].lower()
            if 'a' <= letter <= 'z':
                letter_int = ord(letter) - ord('a')  # 0-25 for a-z
            else:
                letter_int = 26  # For non-alphabetic characters
        else:
            letter_int = 26  # Default for invalid/missing first letters
        
        return torch.tensor(int_context), torch.tensor(int_answer), torch.tensor(letter_int)

In [4]:
# --- 2. Optimized RNN Model Definition ---

class NextWordModel(nn.Module):
    """
    Optimized RNN model for next word prediction with first letter constraint.
    Uses parallel processing for efficient training.
    """
    def __init__(self, vocab_size, embedding_dim=128, hidden_dim=256, n_layers=2, drop_prob=0.5, num_letters=27):
        super(NextWordModel, self).__init__()
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers,
                            dropout=drop_prob, batch_first=True)
        
        # First letter embedding to incorporate letter information
        self.letter_embedding = nn.Embedding(num_letters, hidden_dim // 4)
        
        # Combine LSTM output with letter information
        combined_dim = hidden_dim + (hidden_dim // 4)
        
        self.dropout = nn.Dropout(drop_prob)
        self.fc = nn.Linear(combined_dim, vocab_size)

    def forward(self, x, hidden, first_letter=None):
        # Process context through LSTM (parallel processing)
        embedded = self.embedding(x)
        lstm_out, hidden = self.lstm(embedded, hidden)
        
        # Get the last time step output
        last_lstm_out = lstm_out[:, -1, :]  # (batch_size, hidden_dim)
        
        if first_letter is not None:
            # Ensure first_letter indices are valid
            first_letter = torch.clamp(first_letter, 0, self.letter_embedding.num_embeddings - 1)
            
            # Incorporate first letter information
            letter_embedded = self.letter_embedding(first_letter)  # (batch_size, hidden_dim//4)
            # Combine LSTM output with letter embedding
            combined = torch.cat([last_lstm_out, letter_embedded], dim=1)  # (batch_size, combined_dim)
        else:
            # If no first letter provided, use zero padding
            batch_size = last_lstm_out.size(0)
            zero_letter = torch.zeros(batch_size, self.hidden_dim // 4, device=last_lstm_out.device)
            combined = torch.cat([last_lstm_out, zero_letter], dim=1)
        
        out = self.dropout(combined)
        out = self.fc(out)  # (batch_size, vocab_size)
        return out, hidden

    def init_hidden(self, batch_size):
        # Initializes hidden state for starting new sequences
        weight = next(self.parameters()).data
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device),
                  weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().to(device))
        return hidden

In [5]:
# --- 3. Optimized Training Function ---

def train(model, train_loader, val_loader, vocab_size, epochs=10, batch_size=128, lr=0.001, clip=5):
    """
    Optimized training function with parallel sequence processing.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Training on {device}...")
    
    wandb.watch(model, log='all', log_freq=10)
    best_val_loss = float('inf')
    patience = 3
    patience_counter = 0

    for e in range(epochs):
        # --- Training Step ---
        total_train_loss = 0
        train_correct = 0
        train_total = 0
        
        for batch_idx, (contexts, targets, first_letters) in enumerate(train_loader):
            contexts, targets, first_letters = contexts.to(device), targets.to(device), first_letters.to(device)
            
            # Initialize hidden state
            hidden = model.init_hidden(contexts.size(0))
            hidden = tuple([each.data for each in hidden])
            
            model.zero_grad()
            
            # Forward pass - single parallel call (20x faster!)
            output, hidden = model(contexts, hidden, first_letter=first_letters)
            
            # Calculate loss
            loss = criterion(output, targets)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            
            total_train_loss += loss.item()
            
            # Calculate training accuracy
            _, predicted = torch.max(output.data, 1)
            train_total += targets.size(0)
            train_correct += (predicted == targets).sum().item()
            
            # Debug: Print progress every 50 batches
            if batch_idx % 50 == 0:
                print(f"  Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}")
        
        avg_train_loss = total_train_loss / len(train_loader)
        train_accuracy = (100 * train_correct / train_total) if train_total > 0 else 0

        # --- Validation Step ---
        model.eval()
        val_losses = []
        val_correct = 0
        val_total = 0
        
        with torch.no_grad():
            for batch_idx, (contexts, answers, first_letters) in enumerate(val_loader):
                contexts, answers, first_letters = contexts.to(device), answers.to(device), first_letters.to(device)
                
                val_hidden = model.init_hidden(contexts.shape[0])
                
                # Forward pass with first letter information
                output, _ = model(contexts, val_hidden, first_letter=first_letters)
                
                # Calculate validation loss
                val_loss = criterion(output, answers)
                val_losses.append(val_loss.item())
                
                # Calculate validation accuracy
                _, predicted = torch.max(output.data, 1)
                val_total += answers.size(0)
                val_correct += (predicted == answers).sum().item()
                
                # Debug: Print validation progress
                if batch_idx % 25 == 0:
                    print(f"  Val Batch {batch_idx}/{len(val_loader)}, Val Loss: {val_loss.item():.4f}")

        avg_val_loss = np.mean(val_losses) if val_losses else float('inf')
        val_accuracy = (100 * val_correct / val_total) if val_total > 0 else 0
        
        model.train()

        print(f"\nEpoch {e+1}/{epochs}:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
        print(f"  Val Loss: {avg_val_loss:.4f} | Val Acc: {val_accuracy:.2f}%")
        
        wandb.log({
            "epoch": e + 1, 
            "train_loss": avg_train_loss,
            "train_accuracy": train_accuracy,
            "val_loss": avg_val_loss,
            "val_accuracy": val_accuracy
        })
        
        # Early stopping based on validation loss
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
            print(f"  -> New best model saved with val_loss: {best_val_loss:.4f}")
        else:
            patience_counter += 1
            print(f"  -> Validation loss didn't improve. Patience: {patience_counter}/{patience}")
            
        if patience_counter >= patience:
            print(f"Early stopping triggered after {e+1} epochs")
            break

    print(f"\nTraining completed. Best validation loss: {best_val_loss:.4f}")
    torch.save(model.state_dict(), 'final_model.pth')
    
    artifact = wandb.Artifact('best-next-word-model', type='model')
    artifact.add_file('best_model.pth')
    wandb.log_artifact(artifact)
    print("Best model saved as wandb artifact.")

In [8]:
# --- 4. Configuration ---
ZIP_FILE_PATH = 'train.zip'
TEXT_FILE_NAME = 'train.src.tok'
DEV_SET_PATH = 'dev_set.csv'
SEQ_LENGTH = 20
BATCH_SIZE = 128
EPOCHS = 10
LEARNING_RATE = 0.001
EMBEDDING_DIM = 256
HIDDEN_DIM = 512
N_LAYERS = 2

# --- WANDB Initialization ---
wandb.init(
    project="predictive-keyboard-rnn-optimized",
    config={
        "learning_rate": LEARNING_RATE, "epochs": EPOCHS, "batch_size": BATCH_SIZE,
        "seq_length": SEQ_LENGTH, "embedding_dim": EMBEDDING_DIM, "hidden_dim": HIDDEN_DIM,
        "n_layers": N_LAYERS, "dataset": "train.src.tok (first 10000 lines)",
        "architecture": "LSTM_with_first_letter", "optimization": "parallel_processing"
    }
)

# --- Load Data ---
if not os.path.exists(TEXT_FILE_NAME) or not os.path.exists(DEV_SET_PATH):
    print(f"Error: Make sure '{TEXT_FILE_NAME}' and '{DEV_SET_PATH}' are in the directory.")
else:
    with open(TEXT_FILE_NAME) as f:
        lines = f.read().splitlines()

    if lines:
        print("Limiting training data to the first 10000 lines for testing.")
        limited_lines = lines[:10000]
        
        # --- Prepare Datasets and DataLoaders ---
        train_dataset = TextDataset(limited_lines, seq_length=SEQ_LENGTH)
        val_dataset = DevSetDataset(DEV_SET_PATH, train_dataset.word_to_int, seq_length=SEQ_LENGTH)

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False)
        
        with open('word_to_int.json', 'w') as f:
            json.dump(train_dataset.word_to_int, f)
        print("Vocabulary saved.")

        # --- Initialize Model ---
        model = NextWordModel(train_dataset.vocab_size, EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS)
        print("\nModel Architecture:")
        print(model)
        print(f"\nVocabulary Size: {train_dataset.vocab_size}")
        print(f"Training sequences: {len(train_dataset)}")
        print(f"Validation samples: {len(val_dataset)}")

        # --- Train Model ---
        if len(train_dataset) > 0 and len(val_dataset) > 0:
            print("\nStarting optimized training...")
            train(model, train_loader, val_loader, train_dataset.vocab_size, 
                  epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LEARNING_RATE)
            print("\nTraining complete.")
        else:
            print("Not enough data to create training and/or validation sets.")

wandb.finish()

Limiting training data to the first 10000 lines for testing.
Vocabulary saved.

Model Architecture:
NextWordModel(
  (embedding): Embedding(14886, 256)
  (lstm): LSTM(256, 512, num_layers=2, batch_first=True, dropout=0.5)
  (letter_embedding): Embedding(27, 128)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=640, out_features=14886, bias=True)
)

Vocabulary Size: 14886
Training sequences: 128351
Validation samples: 89912

Starting optimized training...
Training on cuda...
  Batch 0/1002, Loss: 9.6234
  Batch 50/1002, Loss: 5.8929
  Batch 100/1002, Loss: 4.8120
  Batch 150/1002, Loss: 4.1169
  Batch 200/1002, Loss: 4.5378
  Batch 250/1002, Loss: 4.2424
  Batch 300/1002, Loss: 4.3588
  Batch 350/1002, Loss: 4.6502
  Batch 400/1002, Loss: 4.8976
  Batch 450/1002, Loss: 4.1604
  Batch 500/1002, Loss: 4.2966
  Batch 550/1002, Loss: 4.1233
  Batch 600/1002, Loss: 4.2891
  Batch 650/1002, Loss: 3.9506
  Batch 700/1002, Loss: 4.0982
  Batch 750/1002, Loss: 3.3726
  Batch

0,1
epoch,▁▂▃▅▆▇█
train_accuracy,▁▃▅▆▆▇█
train_loss,█▅▄▃▂▂▁
val_accuracy,▁▂▄▆▇██
val_loss,█▃▂▁▂▃▅

0,1
epoch,7.0
train_accuracy,57.86006
train_loss,1.7815
val_accuracy,39.40742
val_loss,3.82809


In [20]:
# --- 5. Prediction and Evaluation Functions ---

def load_model(model_path, vocab_size, embedding_dim=256, hidden_dim=512, n_layers=2):
    """
    Load a trained model from file.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = NextWordModel(vocab_size, embedding_dim, hidden_dim, n_layers)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

def predict_next_word(model, context_text, first_letter=None, word_to_int=None, int_to_word=None, seq_length=20, top_k=5):
    """
    Predict the next word(s) given a context and optional first letter constraint.
    
    Args:
        model: Trained NextWordModel
        context_text: String context (previous words)
        first_letter: Optional first letter constraint (e.g., 't')
        word_to_int: Vocabulary mapping
        int_to_word: Reverse vocabulary mapping
        seq_length: Maximum sequence length
        top_k: Number of top predictions to return
    
    Returns:
        List of tuples (word, probability) sorted by probability
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    
    with torch.no_grad():
        # Tokenize and convert context to integers
        words = context_text.lower().split()
        context_indices = [word_to_int.get(w, word_to_int['<UNK>']) for w in words]
        
        # Pad or truncate to seq_length
        if len(context_indices) > seq_length:
            context_indices = context_indices[-seq_length:]
        else:
            context_indices = [word_to_int['<PAD>']] * (seq_length - len(context_indices)) + context_indices
        
        # Convert to tensor
        context_tensor = torch.tensor([context_indices]).to(device)
        
        # Handle first letter
        if first_letter:
            first_letter = first_letter.lower()
            if 'a' <= first_letter <= 'z':
                letter_idx = ord(first_letter) - ord('a')
            else:
                letter_idx = 26
            letter_tensor = torch.tensor([letter_idx]).to(device)
        else:
            letter_tensor = None
        
        # Initialize hidden state
        hidden = model.init_hidden(1)
        
        # Get predictions
        output, _ = model(context_tensor, hidden, first_letter=letter_tensor)
        
        # Apply softmax to get probabilities
        probabilities = torch.nn.functional.softmax(output, dim=1).squeeze()
        
        # Get top k predictions
        top_probs, top_indices = torch.topk(probabilities, top_k)
        
        # Convert indices to words
        predictions = []
        for i in range(top_k):
            word = int_to_word[top_indices[i].item()]
            prob = top_probs[i].item()
            predictions.append((word, prob))
        
        return predictions

def evaluate_model(model, test_loader, int_to_word=None, device=None):
    """
    Evaluate model performance on test set.
    
    Args:
        model: Trained NextWordModel
        test_loader: DataLoader with test data
        int_to_word: Reverse vocabulary mapping (index to word)
        device: Device to run evaluation on
    
    Returns:
        Dictionary with evaluation metrics
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.eval()
    criterion = nn.CrossEntropyLoss()
    
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    correct_with_first_letter = 0
    
    with torch.no_grad():
        for batch_idx, (contexts, targets, first_letters) in enumerate(test_loader):
            contexts, targets, first_letters = contexts.to(device), targets.to(device), first_letters.to(device)
            
            # Initialize hidden state
            hidden = model.init_hidden(contexts.size(0))
            
            # Get predictions
            output, _ = model(contexts, hidden, first_letter=first_letters)
            
            # Calculate loss
            loss = criterion(output, targets)
            total_loss += loss.item()
            
            # Calculate accuracy
            _, predicted = torch.max(output.data, 1)
            total_predictions += targets.size(0)
            correct_predictions += (predicted == targets).sum().item()
            
            # Calculate accuracy with first letter constraint
            if int_to_word is not None:
                for i in range(contexts.size(0)):
                    first_letter_idx = first_letters[i].item()
                    true_answer = targets[i].item()
                    predicted_answer = predicted[i].item()
                    
                    # Get actual first letter from index
                    if first_letter_idx <= 25:  # a-z
                        actual_letter = chr(first_letter_idx + ord('a'))
                    else:  # other
                        actual_letter = 'other'
                    
                    # Check if prediction matches the first letter constraint
                    predicted_word = int_to_word.get(predicted_answer, '')
                    if predicted_word and len(predicted_word) > 0:
                        predicted_first_letter = predicted_word[0].lower()
                        if predicted_first_letter == actual_letter:
                            if predicted_answer == true_answer:
                                correct_with_first_letter += 1
    
    avg_loss = total_loss / len(test_loader)
    accuracy = (100 * correct_predictions / total_predictions) if total_predictions > 0 else 0
    accuracy_with_constraint = (100 * correct_with_first_letter / total_predictions) if total_predictions > 0 else 0
    
    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'accuracy_with_first_letter_constraint': accuracy_with_constraint,
        'total_samples': total_predictions
    }

In [21]:
# --- 6. Model Evaluation ---

# First, let's check if we have a trained model and vocabulary
import os

if os.path.exists('best_model.pth') and os.path.exists('word_to_int.json'):
    print("Found trained model and vocabulary files.")
    
    # Load vocabulary
    with open('word_to_int.json', 'r') as f:
        word_to_int = json.load(f)
    
    # Create reverse vocabulary mapping
    int_to_word = {i: word for word, i in word_to_int.items()}
    vocab_size = len(word_to_int)
    
    # Load the trained model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model('best_model.pth', vocab_size, EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS)
    
    # Create validation dataset and loader
    val_dataset = DevSetDataset(DEV_SET_PATH, word_to_int, seq_length=SEQ_LENGTH)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE * 2, shuffle=False)
    
    # Evaluate the model
    print("Evaluating model on dev_set.csv...")
    results = evaluate_model(model, val_loader, int_to_word=int_to_word, device=device)
    
    print("\n=== Evaluation Results ===")
    print(f"Loss: {results['loss']:.4f}")
    print(f"Overall Accuracy: {results['accuracy']:.2f}%")
    print(f"Accuracy with First Letter Constraint: {results['accuracy_with_first_letter_constraint']:.2f}%")
    print(f"Total Samples: {results['total_samples']}")
    
    # Example predictions
    print("\n=== Example Predictions ===")
    for i in range(min(5, len(val_dataset))):
        context, answer, first_letter = val_dataset[i]
        context_text = " ".join([int_to_word[idx.item()] for idx in context if idx.item() != word_to_int['<PAD>']])
        true_answer = int_to_word[answer.item()]
        letter_hint = chr(first_letter.item() + ord('a')) if first_letter.item() <= 25 else 'other'
        
        predictions = predict_next_word(model, context_text, letter_hint, word_to_int, int_to_word, SEQ_LENGTH, top_k=3)
        
        print(f"\nContext: '{context_text}'")
        print(f"First letter: '{letter_hint}'")
        print(f"True answer: '{true_answer}'")
        print(f"Top 3 predictions: {predictions}")
    
else:
    print("No trained model found. Please train the model first by running the training cell.")
    print("To train the model, make sure you have:")
    print("1. train.zip file containing train.src.tok")
    print("2. dev_set.csv file")
    print("3. Run the training configuration cell (cell 6)")
    
    # If you want to test the evaluation function without a trained model,
    # you can uncomment the following code to create a dummy model:
    """
    # Create dummy vocabulary and model for testing
    dummy_word_to_int = {'<PAD>': 0, '<UNK>': 1, 'the': 2, 'a': 3, 'an': 4, 'and': 5}
    dummy_int_to_word = {i: w for w, i in dummy_word_to_int.items()}
    
    # Create a dummy model
    dummy_model = NextWordModel(len(dummy_word_to_int), EMBEDDING_DIM, HIDDEN_DIM, N_LAYERS)
    
    # Create a small dummy dataset for testing
    # This would require a CSV with words that exist in dummy_word_to_int
    """

Found trained model and vocabulary files.
Evaluating model on dev_set.csv...

=== Evaluation Results ===
Loss: 3.5622
Overall Accuracy: 38.02%
Accuracy with First Letter Constraint: 31.47%
Total Samples: 89912

=== Example Predictions ===

Context: 'states on monday warned north korea to avoid provoking trouble as pyongyang ' s most senior defector spent his sixth'
First letter: 'd'
True answer: 'day'
Top 3 predictions: [('debate', 0.029785726219415665), ('decision', 0.022341500967741013), ('dispute', 0.019380083307623863)]

Context: 'to drastically cut its car import duties , taiwan on thursday won european union support for its bid to enter'
First letter: 't'
True answer: 'the'
Top 3 predictions: [('the', 0.9337925910949707), ('their', 0.021507205441594124), ('them', 0.013604961335659027)]

Context: 'three soldiers were injured in a bombing ambush launched by suspect thai southern insurgents on wednesday'
First letter: 'm'
True answer: 'morning'
Top 3 predictions: [('must', 0.1831695