In [2]:
# Training corpus for the next-character models below. It must be a single
# ``str``: CharacterDataset builds its vocabulary from ``set(text)``, so a list
# wrapper would make the whole paragraph one "character". Adjacent literals are
# concatenated by the parser; each one except the last ends with a space.
text = (
    'Next character prediction is a fundamental task in the field of natural language processing (NLP) that involves predicting the next character in a sequence of text based on the characters that precede it. '
    'This task is essential for various applications, including text auto-completion, spell checking, and even in the development of sophisticated AI models capable of generating human-like text. '
    'At its core, next character prediction relies on statistical models or deep learning algorithms to analyze a given sequence of text and predict which character is most likely to follow. '
    'These predictions are based on patterns and relationships learned from large datasets of text during the training phase of the model. '
    'One of the most popular approaches to next character prediction involves the use of Recurrent Neural Networks (RNNs), and more specifically, a variant called Long Short-Term Memory (LSTM) networks. '
    'RNNs are particularly well-suited for sequential data like text, as they can maintain information in "memory" about previous characters to inform the prediction of the next character. '
    'LSTM networks enhance this capability by being able to remember long-term dependencies, making them even more effective for next character prediction tasks. '
    'Training a model for next character prediction involves feeding it large amounts of text data, allowing it to learn the probability of each characters appearance following a sequence of characters. '
    'During this training process, the model adjusts its parameters to minimize the difference between its predictions and the actual outcomes, thus improving its predictive accuracy over time. '
    'Once trained, the model can be used to predict the next character in a given piece of text by considering the sequence of characters that precede it. '
    'This can enhance user experience in text editing software, improve efficiency in coding environments with auto-completion features, and enable more natural interactions with AI-based chatbots and virtual assistants. '
    'In summary, next character prediction plays a crucial role in enhancing the capabilities of various NLP applications, making text-based interactions more efficient, accurate, and human-like. '
    'Through the use of advanced machine learning models like RNNs and LSTMs, next character prediction continues to evolve, opening new possibilities for the future of text-based technology.'
)

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm

class CharacterDataset(Dataset):
    """Sliding-window dataset of (input, target) character-index sequences.

    Sample ``i`` pairs the window ``text[i:i + sequence_length]`` with the same
    window shifted one character to the right, so position ``j`` of the target
    is the character that follows position ``j`` of the input.
    """

    def __init__(self, text, sequence_length=10):
        """
        Initialize the dataset with a text and the sequence length.
        Args:
        text (str): The full text to train on.
        sequence_length (int): The number of characters in each input sequence.
        Raises:
        ValueError: if sequence_length is smaller than 1.
        """
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.text = text
        self.sequence_length = sequence_length
        # Vocabulary: every distinct character, sorted so indices are deterministic.
        self.characters = sorted(set(text))
        # Map characters to indices and vice versa
        self.char_to_idx = {ch: idx for idx, ch in enumerate(self.characters)}
        self.idx_to_char = dict(enumerate(self.characters))

    def __len__(self):
        """
        The length of the dataset is the number of sequences that can be formed,
        which is the total number of characters minus the sequence length.
        """
        # Clamp at zero when the text is not longer than one window.
        return max(0, len(self.text) - self.sequence_length)

    def __getitem__(self, index):
        """
        Get the input sequence and the target sequence from the text.
        Args:
        index (int): The index of the sequence in the dataset. Negative values
            count from the end, as for Python sequences.
        Returns:
        tuple: (input sequence, target sequence) where both are tensors of character indices.
        Raises:
        IndexError: if index is outside [-len(self), len(self)).
        """
        size = len(self)
        if index < 0:
            index += size
        # Slicing never fails on its own, so without this check iterating the
        # dataset directly (sequence protocol) would never terminate.
        if not 0 <= index < size:
            raise IndexError(f"index out of range for dataset of length {size}")
        inputs = self.text[index:index + self.sequence_length]
        targets = self.text[index + 1:index + self.sequence_length + 1]
        input_indices = torch.tensor([self.char_to_idx[ch] for ch in inputs], dtype=torch.long)
        target_indices = torch.tensor([self.char_to_idx[ch] for ch in targets], dtype=torch.long)
        return input_indices, target_indices

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    ``forward`` adds ``encoding[:, :x.size(1)]`` (shape ``(1, seq_len, d_model)``)
    to ``x``, so ``x`` is expected to be ``(batch, seq_len, d_model)`` with
    ``seq_len <= max_len``.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        encoding[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # trim div_term to match instead of failing with a shape mismatch.
        encoding[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # A registered buffer follows the module through .to(device)/.cuda()/
        # .double(); a plain attribute would stay on the CPU and break forward
        # on GPU. persistent=False keeps the state_dict keys unchanged.
        self.register_buffer("encoding", encoding.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.encoding[:, :x.size(1)]

class TransformerModel(nn.Module):
    """Character-level next-character model built on ``nn.TransformerDecoder``.

    ``forward`` takes a ``(batch, seq_len)`` tensor of character indices and
    returns ``(batch, seq_len, vocab_size)`` logits, where position ``t`` scores
    the character that follows input position ``t``.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_decoder_layers=3, dim_feedforward=2048):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # batch_first=True: the embedding and positional encoding produce
        # (batch, seq, d_model). With the default (False) the layers would treat
        # the batch axis as the sequence and attend across unrelated samples.
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward, batch_first=True),
            num_decoder_layers,
        )
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.d_model = d_model

    @staticmethod
    def _causal_mask(size, device):
        """Return a (size, size) additive mask that blocks attention to later positions."""
        return torch.triu(torch.full((size, size), float("-inf"), device=device), diagonal=1)

    def forward(self, src):
        """Return next-character logits for every position of ``src``."""
        seq_len = src.size(1)
        x = self.embedding(src) * self.d_model ** 0.5
        x = self.pos_encoder(x)
        # The sequence is fed as both target and memory, so BOTH attentions
        # need the causal mask; otherwise position t can read position t + 1,
        # which is exactly its training target.
        # NOTE(review): an encoder-only stack (nn.TransformerEncoder + mask)
        # would avoid the redundant cross-attention; kept as a decoder so
        # existing state_dicts still load.
        mask = self._causal_mask(seq_len, x.device)
        output = self.transformer_decoder(x, x, tgt_mask=mask, memory_mask=mask)
        return self.fc_out(output)

def train(model, data_loader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module mapping an input batch to logits whose last dimension is
            the number of classes.
        data_loader: Iterable of ``(inputs, targets)`` batches; it does not
            need to support ``len()``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking flattened ``(N, classes)`` logits and ``(N,)`` targets.
        device: Device the batches are moved to.

    Returns:
        float: average of ``loss.item()`` over the batches of this epoch.

    Raises:
        ValueError: if ``data_loader`` yields no batches (e.g. the dataset is
            empty because the text is shorter than the sequence length).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for inputs, targets in data_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        output = model(inputs)
        # reshape rather than view: it also works when the model returns a
        # non-contiguous tensor.
        loss = criterion(output.reshape(-1, output.size(-1)), targets.reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("data_loader produced no batches; is the dataset empty?")
    return total_loss / num_batches

def main(text=None, sequence_lengths=(10, 20, 30), epochs=10, batch_size=64, lr=0.001):
    """Train one TransformerModel per sequence length and report final losses.

    Args:
        text: Training corpus. ``None`` (the default) falls back to the
            module-level ``text``, as the original no-argument call did. A
            non-``str`` iterable of strings (such as a list-wrapped corpus) is
            joined into a single string.
        sequence_lengths: Window lengths to train a separate model for.
        epochs: Training epochs per model.
        batch_size: DataLoader batch size.
        lr: Adam learning rate.

    Returns:
        dict: sequence length -> loss of the final epoch (NaN when ``epochs`` is 0).

    Raises:
        ValueError: if no text is given and no module-level ``text`` exists.
    """
    if text is None:
        text = globals().get("text")
    if text is None:
        raise ValueError("no training text given and no module-level `text` defined")
    if not isinstance(text, str):
        # CharacterDataset needs individual characters, not list elements.
        text = "".join(text)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = {}

    for seq_len in sequence_lengths:
        dataset = CharacterDataset(text, sequence_length=seq_len)
        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        model = TransformerModel(vocab_size=len(dataset.characters)).to(device)
        optimizer = optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()

        # Defined up front so results are well-formed even when epochs == 0.
        loss = float("nan")
        for epoch in range(epochs):
            loss = train(model, data_loader, optimizer, criterion, device)
            print(f'Epoch {epoch+1}, Loss: {loss}')
        results[seq_len] = loss

    print(results)
    return results

if __name__ == "__main__":
    # Demo of CharacterDataset windows. Dedicated names are used because, in a
    # notebook, this cell runs at module scope: rebinding `text` here would
    # silently replace the training corpus that main() reads.
    demo_text = "Here is a simple example."
    demo_sequence_length = 10
    demo_dataset = CharacterDataset(demo_text, demo_sequence_length)
    print(f'Dataset length: {len(demo_dataset)}')  # Outputs the length of the dataset
    for i in range(len(demo_dataset)):
        x, y = demo_dataset[i]
        print(f'Input sequence indices: {x}')
        print(f'Target sequence indices: {y}')

Dataset length: 15
Input sequence indices: tensor([ 2,  4,  9,  4,  0,  5, 10,  0,  3,  0])
Target sequence indices: tensor([ 4,  9,  4,  0,  5, 10,  0,  3,  0, 10])
Input sequence indices: tensor([ 4,  9,  4,  0,  5, 10,  0,  3,  0, 10])
Target sequence indices: tensor([ 9,  4,  0,  5, 10,  0,  3,  0, 10,  5])
Input sequence indices: tensor([ 9,  4,  0,  5, 10,  0,  3,  0, 10,  5])
Target sequence indices: tensor([ 4,  0,  5, 10,  0,  3,  0, 10,  5,  7])
Input sequence indices: tensor([ 4,  0,  5, 10,  0,  3,  0, 10,  5,  7])
Target sequence indices: tensor([ 0,  5, 10,  0,  3,  0, 10,  5,  7,  8])
Input sequence indices: tensor([ 0,  5, 10,  0,  3,  0, 10,  5,  7,  8])
Target sequence indices: tensor([ 5, 10,  0,  3,  0, 10,  5,  7,  8,  6])
Input sequence indices: tensor([ 5, 10,  0,  3,  0, 10,  5,  7,  8,  6])
Target sequence indices: tensor([10,  0,  3,  0, 10,  5,  7,  8,  6,  4])
Input sequence indices: tensor([10,  0,  3,  0, 10,  5,  7,  8,  6,  4])
Target sequence indices: t