In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from sklearn.metrics import confusion_matrix, classification_report
from collections import Counter
import re

PATH = "data\sherlock-holm.es_stories_plain-text_advs.txt"

# Read the text file
with open(PATH, 'r', encoding='utf-8') as file:
    text = file.read()

In [3]:
# Tokenize the text
def tokenize(text):
    text = text.lower()
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    words = text.split()
    return words

tokens = tokenize(text)
word_counts = Counter(tokens)
vocab = sorted(word_counts, key=word_counts.get, reverse=True)
word_to_index = {word: index + 1 for index, word in enumerate(vocab)}
index_to_word = {index + 1: word for index, word in enumerate(vocab)}
total_words = len(word_to_index) + 1

In [4]:
# Create input-output pairs
input_sequences = []
for line in text.split('\n'):
    token_list = [word_to_index[word] for word in tokenize(line) if word in word_to_index]
    for i in range(1, len(token_list)):
        n_gram_sequence = token_list[:i+1]
        input_sequences.append(n_gram_sequence)

In [5]:
# Pad the sequences
max_sequence_len = max([len(seq) for seq in input_sequences])
input_sequences = np.array([np.pad(seq, (max_sequence_len - len(seq), 0), mode='constant') for seq in input_sequences])

In [10]:
# Split the sequences into input (X) and output (y)
X = input_sequences[:, :-1]
y = input_sequences[:, -1]

# Ensure y is an integer tensor
y_tensor = torch.tensor(y, dtype=torch.long)

# Convert output to one-hot encoded vectors
y_one_hot = torch.nn.functional.one_hot(y_tensor, num_classes=total_words)

# Convert back to NumPy array if necessary
y_one_hot = y_one_hot.numpy()

In [11]:
# Create a custom Dataset class
class TextDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.long)
        self.y = torch.tensor(y, dtype=torch.float)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

dataset = TextDataset(X, y)

# Split dataset into training and validation sets (90% training, 10% validation)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Define the model
class NextWordPredictor(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, output_dim):
        super(NextWordPredictor, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = self.embedding(x)
        x, _ = self.lstm(x)
        x = x[:, -1, :]
        x = self.fc(x)
        return x

model = NextWordPredictor(total_words, 200, 256, total_words)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [21]:
# Train the model
epochs = 2
for epoch in range(epochs):
    model.train()
    for inputs, labels in train_dataloader:
        outputs = model(inputs)
        # Ensure labels are in the correct shape and convert to indices
        if labels.dim() > 1:
            labels = labels.argmax(dim=1)
        # Convert labels to Long tensor
        labels = labels.long()
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item()}')

torch.save(model.state_dict(), "model_state.pth")

Epoch 1/2, Loss: 5.359200477600098
Epoch 2/2, Loss: 5.804989814758301


In [23]:
# Evaluate the model on the validation set
model.eval()
correct_predictions = 0
total_predictions = 0

with torch.no_grad():
    for inputs, labels in val_dataloader:
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        # Ensure labels are in the correct shape and convert to indices
        if labels.dim() > 1:
            labels = labels.argmax(dim=1)
        # Convert labels to Long tensor
        labels = labels.long()
        actual = labels
        correct_predictions += (predicted == actual).sum().item()
        total_predictions += labels.size(0)

accuracy = correct_predictions / total_predictions
print(f'Validation Accuracy: {accuracy:.4f}')

perplexity  = torch.exp(loss)
print('Loss:', loss, 'PP:', perplexity)

Validation Accuracy: 0.1425
Loss: tensor(5.8050, grad_fn=<NllLossBackward0>) PP: tensor(331.9518, grad_fn=<ExpBackward0>)


In [27]:
import tkinter as tk 
from tkinter import messagebox
model.eval()

# Extract sentences from the validation dataset for tkinter app
val_sentences = []
for idx in range(int(0.1 * len(val_dataset))):
    input_sequence, _ = val_dataset[idx]
    sentence = ' '.join([index_to_word[index.item()] for index in input_sequence if index.item() in index_to_word])
    val_sentences.append(sentence)

# Initialize tkinter app
class NextWordApp:
    def __init__(self, master, val_sentences):
        self.master = master
        self.val_sentences = val_sentences
        self.current_sentence_index = 0
        self.correct_count = 0
        self.total_count = 0

        self.label = tk.Label(master, text=self.val_sentences[self.current_sentence_index])
        self.label.pack()

        self.correct_button = tk.Button(master, text="Correct", command=self.correct_callback)
        self.correct_button.pack(side=tk.LEFT, padx=10)

        self.wrong_button = tk.Button(master, text="Wrong", command=self.wrong_callback)
        self.wrong_button.pack(side=tk.RIGHT, padx=10)

    def predict_next_word(self):
        seed_text = self.val_sentences[self.current_sentence_index]
        token_list = [word_to_index[word] for word in tokenize(seed_text) if word in word_to_index]
        token_list = np.pad(token_list, (max_sequence_len - len(token_list), 0), mode='constant')
        token_list = torch.tensor(token_list[-max_sequence_len+1:], dtype=torch.long).unsqueeze(0)

        with torch.no_grad():
            predicted = model(token_list).argmax(dim=1).item()

        output_word = index_to_word[predicted]
        return output_word

    def correct_callback(self):
        self.total_count += 1
        self.correct_count += 1
        self.current_sentence_index += 1
        if self.current_sentence_index < len(self.val_sentences):
            self.label.config(text=self.val_sentences[self.current_sentence_index])
        else:
            self.show_accuracy()

    def wrong_callback(self):
        self.total_count += 1
        self.current_sentence_index += 1
        corrected_word = messagebox.askstring("Correct the prediction", "Please provide the correct next word:")
        if corrected_word:
            self.val_sentences[self.current_sentence_index] += " " + corrected_word.lower()
            self.label.config(text=self.val_sentences[self.current_sentence_index])
        else:
            self.val_sentences[self.current_sentence_index] += " " + self.predict_next_word()
            self.label.config(text=self.val_sentences[self.current_sentence_index])

        if self.current_sentence_index >= len(self.val_sentences):
            self.show_accuracy()

    def show_accuracy(self):
        accuracy = self.correct_count / self.total_count if self.total_count > 0 else 0
        accuracy_percent = accuracy * 100
        messagebox.showinfo("Accuracy", f"Accuracy: {accuracy_percent:.2f}%")

# Create the main window
root = tk.Tk()
root.title("Next Word Prediction App")

# Initialize the app with validation sentences
app = NextWordApp(root, val_sentences)

# Run the application
root.mainloop()

Exception in Tkinter callback
Traceback (most recent call last):
  File "s:\anaconda3\Lib\tkinter\__init__.py", line 1948, in __call__
    return self.func(*args)
           ^^^^^^^^^^^^^^^^
  File "C:\Users\Alvaro\AppData\Local\Temp\ipykernel_17608\2386581990.py", line 54, in wrong_callback
    corrected_word = messagebox.askstring("Correct the prediction", "Please provide the correct next word:")
                     ^^^^^^^^^^^^^^^^^^^^
AttributeError: module 'tkinter.messagebox' has no attribute 'askstring'


In [None]:
import time

# Evaluate the model on the validation set
model.eval()
correct_predictions = 0
total_predictions = 0

with torch.no_grad():
    for inputs, labels in val_dataloader:
        outputs = model(inputs)
        predicted = outputs.argmax(dim=1)
        actual = labels.argmax(dim=1)

        # Display results
        for i in range(len(inputs)):
            input_sentence = " ".join([index_to_word[idx.item()] for idx in inputs[i] if idx.item() != 0])
            predicted_word = index_to_word[predicted[i].item()]
            actual_word = index_to_word[actual[i].item()]
            correct = predicted[i].item() == actual[i].item()
            print(f"Sentence: {input_sentence}")
            print(f"Predicted: {predicted_word}")
            print(f"Actual: {actual_word}")
            print(f"Correct: {correct}")
            print()

            if correct:
                correct_predictions += 1
            total_predictions += 1

            time.sleep(3)  # Wait for 1 second between predictions

accuracy = correct_predictions / total_predictions
print(f'Validation Accuracy: {accuracy:.4f}')

In [None]:
# Generate predictions
def predict_next_words(model, tokenizer, seed_text, next_words):
    for _ in range(next_words):
        token_list = [word_to_index[word] for word in tokenize(seed_text)]
        token_list = np.pad(token_list, (max_sequence_len - len(token_list), 0), mode='constant')
        token_list = torch.tensor(token_list[-max_sequence_len+1:], dtype=torch.long).unsqueeze(0)

        with torch.no_grad():
            predicted = model(token_list).argmax(dim=1).item()

        output_word = index_to_word[predicted]
        seed_text += " " + output_word

    return seed_text

seed_text = "Scotland"
next_words = 1
print(predict_next_words(model, word_to_index, seed_text, next_words))