# Next Character Predictor using LSTM in PyTorch

This notebook demonstrates how to build an LSTM-based next-character predictor. It uses a sample text (you can replace it with your own corpus) to build a character vocabulary, prepares training sequences, defines the LSTM model, trains the model to predict the next character given a sequence of previous characters, and finally generates text based on a seed string.

Adjust hyperparameters (e.g., sequence length, embedding size, hidden size, learning rate, number of epochs) as needed.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import sys
import matplotlib.pyplot as plt

# Pick the compute device: use a CUDA GPU when one is available, else the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

In [None]:
# -------------------------------
# 1. Data Preparation
# -------------------------------

# Demonstration corpus; swap in your own text if desired
text = (
    "In the beginning God created the heaven and the earth. "
    "And the earth was without form, and void; and darkness was upon the face of the deep. "
    "And the Spirit of God moved upon the face of the waters. "
    "And God said, Let there be light: and there was light."
)

# Character vocabulary: every distinct character of the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)
print("Unique characters:", vocab_size)
print("Characters:", chars)

# Lookup tables: index -> character, and its inverse character -> index
idx2char = dict(enumerate(chars))
char2idx = {ch: i for i, ch in idx2char.items()}

# Encode the whole corpus as an array of vocabulary indices
text_indices = np.array(list(map(char2idx.__getitem__, text)))

# Context window size: the previous 10 characters are used to predict the next one
sequence_length = 10

def create_sequences(data, seq_length):
    """
    Slice `data` into overlapping windows paired with the element that follows each.

    Args:
        data: Indexable, sliceable sequence of encoded items
        seq_length: Number of items in every input window

    Returns:
        Tuple `(inputs, targets)` of NumPy arrays: `inputs[k]` is
        `data[k:k+seq_length]` and `targets[k]` is `data[k+seq_length]`.
        Both are empty when `len(data) <= seq_length`.
    """
    # One window per start position that still leaves a following element
    window_starts = range(len(data) - seq_length)
    windows = [data[start:start + seq_length] for start in window_starts]
    followers = [data[start + seq_length] for start in window_starts]
    return np.array(windows), np.array(followers)

# Cut the encoded corpus into (context window, next character) training pairs
inputs, targets = create_sequences(data=text_indices, seq_length=sequence_length)
print("Total sequences:", len(inputs))

In [None]:
# -------------------------------
# 2. Dataset and DataLoader
# -------------------------------

from torch.utils.data import Dataset, DataLoader

class TextDataset(Dataset):
    """
    Dataset of (input window, target index) pairs for next-character prediction.

    Both arrays are copied into `torch.long` tensors once at construction, so
    indexing returns tensor views/elements without further conversion.
    """

    def __init__(self, inputs, targets):
        """
        Args:
            inputs: Array-like of index windows, one row per sample
            targets: Array-like of target indices, one entry per sample

        Raises:
            ValueError: If `inputs` and `targets` hold a different number of samples
        """
        self.inputs = torch.tensor(inputs, dtype=torch.long)
        self.targets = torch.tensor(targets, dtype=torch.long)
        # Fail fast: a mismatch would otherwise surface as an IndexError in the
        # middle of an epoch, or as silently ignored trailing targets.
        if len(self.inputs) != len(self.targets):
            raise ValueError(
                f"inputs and targets must have the same length, "
                f"got {len(self.inputs)} and {len(self.targets)}"
            )

    def __len__(self):
        """Return the number of samples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the `(input window, target)` pair stored at position `idx`."""
        return self.inputs[idx], self.targets[idx]

# Serve the training pairs as shuffled mini-batches
batch_size = 64
dataset = TextDataset(inputs, targets)
data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

In [None]:
# -------------------------------
# 3. Define the LSTM Model
# -------------------------------

class LSTMNextChar(nn.Module):
    """
    Next-character classifier: embedding -> LSTM -> linear layer over the vocabulary.

    Only the LSTM output at the final time step is used for the prediction.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):
        """
        Args:
            vocab_size: Number of distinct characters (embedding rows and output classes)
            embed_size: Dimension of each character embedding
            hidden_size: Dimension of the LSTM hidden state
            num_layers: Number of stacked LSTM layers
        """
        super().__init__()
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(
            embed_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """
        Args:
            x: Index tensor of shape (batch_size, sequence_length)
            hidden: Optional initial LSTM state, passed straight to the LSTM

        Returns:
            Tuple `(logits, hidden)`: `logits` has shape (batch_size, vocab_size)
            and `hidden` is the state returned by the LSTM.
        """
        embedded = self.embed(x)  # (batch_size, sequence_length, embed_size)
        sequence_out, hidden = self.lstm(embedded, hidden)  # (batch_size, seq_length, hidden_size)
        final_step = sequence_out[:, -1]  # keep only the last time step
        return self.fc(final_step), hidden

# Model hyperparameters
embed_size, hidden_size, num_layers = 32, 128, 1

# Build the network, then move its parameters to the selected device
model = LSTMNextChar(vocab_size, embed_size, hidden_size, num_layers)
model = model.to(device)
print(model)

In [None]:
# -------------------------------
# 4. Training Setup and Loop
# -------------------------------

# Training configuration: cross-entropy over the vocabulary, optimised with Adam
num_epochs = 200
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.003)

model.train()
for epoch in range(num_epochs):
    # Sum of the per-batch losses seen during this epoch
    total_loss = 0
    for batch_inputs, batch_targets in data_loader:
        batch_inputs = batch_inputs.to(device)
        batch_targets = batch_targets.to(device)

        # One optimisation step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        logits, _ = model(batch_inputs)
        loss = criterion(logits, batch_targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report only on every 20th epoch
    if (epoch + 1) % 20 != 0:
        continue
    avg_loss = total_loss / len(data_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

print("Training complete.")

In [None]:
# -------------------------------
# 5. Text Generation Function
# -------------------------------

def generate_text(model, seed_text, length=100, temperature=1.0):
    """
    Generate text by repeatedly sampling the next character from the model.

    Every step feeds the most recent characters (at most `sequence_length`)
    to the model with a fresh recurrent state. The whole window is re-fed on
    each step, so carrying the previous LSTM state forward as well would make
    the model process the overlapping characters more than once.

    Args:
        model: Trained LSTM model; it is switched to eval mode
        seed_text: Initial text string. Ideally at least `sequence_length`
            characters; a shorter seed is accepted and the context window
            grows with the generated characters until it reaches
            `sequence_length`.
        length: Number of characters to generate
        temperature: Controls randomness (higher -> more random); must be > 0

    Returns:
        Generated text string: `seed_text` followed by `length` sampled characters

    Raises:
        ValueError: If `temperature` is not positive, or `seed_text` is empty
            while characters are requested
        KeyError: If `seed_text` contains a character that is not in `char2idx`
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if length > 0 and not seed_text:
        raise ValueError("seed_text must contain at least one character")

    model.eval()
    # Context window as a plain list of indices (last `sequence_length` characters)
    window = [char2idx[ch] for ch in seed_text[-sequence_length:]]
    pieces = [seed_text]

    # Inference only: skip autograd bookkeeping for every step
    with torch.no_grad():
        for _ in range(length):
            input_seq = torch.tensor([window], dtype=torch.long, device=device)
            logits, _ = model(input_seq)
            probs = torch.softmax(logits.squeeze(0) / temperature, dim=-1)
            # float32 probabilities can miss NumPy's sum-to-1 tolerance, so
            # sample from a renormalised float64 copy
            probs = probs.double().cpu().numpy()
            probs /= probs.sum()
            next_idx = int(np.random.choice(len(probs), p=probs))
            pieces.append(idx2char[next_idx])

            # Slide the window: append the sampled index, keep the newest entries
            window = (window + [next_idx])[-sequence_length:]

    return "".join(pieces)

# Generate sample text
# The seed must be at least `sequence_length` (10) characters long so that the
# first prediction sees a full-length context window, as in training.
seed = "And God said"
generated_text = generate_text(model, seed, length=200, temperature=0.8)
print("Generated Text:\n")
print(generated_text)