In [2]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Embedding, LSTM, Dense
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
import tkinter as tk
from spellchecker import SpellChecker
from transformers import BertTokenizer, TFBertForMaskedLM

# Read the whole training corpus into memory
with open("data2.txt", "r", encoding='utf-8') as file:
    data = file.read()

# Fit the tokenizer on the full corpus so every word gets an integer id
tokenizer = Tokenizer()
tokenizer.fit_on_texts([data])

# Word frequency table: a plain-dict copy of the tokenizer's per-word counts
word_freq = dict(tokenizer.word_counts)

# Turn every line into its growing prefixes of length 2..n, so each sample is
# "some context words followed by the word to predict"
input_sequences = []
for line in data.split('\n'):
    encoded = tokenizer.texts_to_sequences([line])[0]
    input_sequences.extend(encoded[:end] for end in range(2, len(encoded) + 1))

# Left-pad every prefix to the length of the longest one
max_len = max(len(seq) for seq in input_sequences)
padded_input_sequence = pad_sequences(input_sequences, maxlen=max_len, padding='pre')

# Features are all tokens but the last; the label is the last token
x, y = padded_input_sequence[:, :-1], padded_input_sequence[:, -1]

# Next-word classifier: embedding -> LSTM -> softmax over the vocabulary
vocab_size = len(tokenizer.word_index) + 1  # one extra slot beyond the highest word id
lstm_model = Sequential([
    Embedding(vocab_size, 100, input_length=max_len - 1),
    LSTM(150),
    Dense(vocab_size, activation='softmax'),
])
lstm_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',  # labels in y are integer word ids
    metrics=['accuracy'],
)
lstm_model.fit(x, y, epochs=10)

# Pretrained BERT: tokenizer and masked-language-model, both from one checkpoint
BERT_CHECKPOINT = 'bert-base-uncased'
tokenizer_bert = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
model_bert = TFBertForMaskedLM.from_pretrained(BERT_CHECKPOINT)

# Function to correct typos
def correct_typos(text):
    """Return *text* with every whitespace-separated word spell-corrected.

    Args:
        text: The raw user input.

    Returns:
        The corrected words joined by single spaces. A word for which the
        spell checker has no suggestion is kept exactly as typed.
    """
    spell = SpellChecker()
    # correction() can return None when it finds no candidate; joining a None
    # would raise TypeError, so fall back to the original word.
    corrected_words = [spell.correction(word) or word for word in text.split()]
    return ' '.join(corrected_words)

# GUI prediction function with LSTM
def predict_words_lstm():
    """Button callback: spell-correct the entry and show the LSTM's next words.

    Reads ``input_word_entry``, replaces its content with the spell-corrected
    text, then greedily generates up to three successive words with
    ``lstm_model`` (each predicted token is fed back as context for the next
    step) and writes them into ``prediction_labels``. Labels without a
    prediction are cleared. Any error is printed instead of raised so a failed
    prediction does not propagate into the Tk event loop.
    """
    try:
        input_word = input_word_entry.get()
        corrected_word = correct_typos(input_word)
        input_word_entry.delete(0, tk.END)  # Clear the input field
        input_word_entry.insert(0, corrected_word)  # Update input field with corrected word

        num_predictions = 3
        context_len = max_len - 1  # the model was trained on inputs of this length

        predictions = []
        token_text = tokenizer.texts_to_sequences([corrected_word])[0]

        for _ in range(num_predictions):
            # Predict at the top of the loop so no model call is wasted after
            # the final word has been produced.
            padded_token_text = pad_sequences([token_text], maxlen=context_len, padding="pre")
            probabilities = lstm_model.predict(padded_token_text, verbose=0)
            predicted_index = int(np.argmax(probabilities, axis=1)[0])

            # The padding index has no entry in index_word; .get() avoids a
            # KeyError that would otherwise abort the whole prediction.
            prediction_word = tokenizer.index_word.get(predicted_index)
            if prediction_word:
                predictions.append(prediction_word)

            # Append the predicted token to the context and keep only the most
            # recent tokens that fit the model's input length.
            token_text.append(predicted_index)
            token_text = token_text[-context_len:]

        # Blank out labels that got no prediction so stale text never lingers.
        shown = predictions + [""] * (len(prediction_labels) - len(predictions))
        for label, prediction_word in zip(prediction_labels, shown):
            label.config(text=prediction_word, fg='black')  # Update prediction labels

    except Exception as e:
        print("An error occurred during LSTM prediction:", e)


# GUI prediction function with BERT
def predict_words_bert():
    """Button callback: spell-correct the entry and show BERT's next-word guesses.

    Reads ``input_word_entry``, replaces its content with the spell-corrected
    text, appends BERT's mask token and asks ``predict_next_word`` for
    candidates. Candidates equal to a double quote are dropped; the remaining
    ones fill ``prediction_labels`` in order and any label left over is
    cleared. Any error is printed instead of raised so a failed prediction
    does not propagate into the Tk event loop.
    """
    try:
        input_word = input_word_entry.get()
        corrected_word = correct_typos(input_word)
        input_word_entry.delete(0, tk.END)  # Clear the input field
        input_word_entry.insert(0, corrected_word)  # Update input field with corrected word

        text = corrected_word + ' ' + tokenizer_bert.mask_token  # Add mask token at the end

        # Filter out tokens equal to "
        predicted_words = [word for word in predict_next_word(text) if word != '"']

        # Filtering can leave fewer candidates than labels; pad with empty
        # strings so a previous prediction is never left on screen.
        shown = predicted_words[:len(prediction_labels)]
        shown += [""] * (len(prediction_labels) - len(shown))
        for label, prediction_word in zip(prediction_labels, shown):
            label.config(text=prediction_word, fg='black')  # Update prediction labels

    except Exception as e:
        print("An error occurred during BERT prediction:", e)


# Function to predict the next word using BERT
def predict_next_word(text, top_k=3):
    """Return BERT's most likely tokens for the masked position in *text*.

    Args:
        text: Input text. If it does not already contain the tokenizer's mask
            token, one is appended so the prediction targets the next word.
        top_k: Number of candidate tokens to return (default 3).

    Returns:
        A list of up to ``top_k`` token strings, most probable first. When the
        text contains several mask tokens, the first one is predicted.

    Raises:
        ValueError: If no mask token is present after encoding.
    """
    if top_k <= 0:
        return []

    mask_token = tokenizer_bert.mask_token
    if mask_token not in text:
        text = f"{text} {mask_token}"

    # encode() adds the special tokens around the text. They are left intact:
    # overwriting the last id with a mask would destroy the closing [SEP].
    input_ids = tokenizer_bert.encode(text, return_tensors='tf')

    # Locate the masked position in the (single) encoded sequence
    mask_positions = np.where(input_ids.numpy()[0] == tokenizer_bert.mask_token_id)[0]
    if mask_positions.size == 0:
        raise ValueError("no mask token found in the encoded input")
    masked_index = int(mask_positions[0])

    # Predict the masked token with BERT
    logits = model_bert(input_ids)[0]
    scores = logits[0, masked_index].numpy()

    # argsort is ascending: take the last top_k ids and flip to best-first
    top_ids = np.argsort(scores)[-top_k:][::-1]
    return tokenizer_bert.convert_ids_to_tokens(top_ids.tolist())

# Build the tkinter window: prompt, entry, two predict buttons, three result labels
root = tk.Tk()
root.title("Word Prediction")
root.configure(bg='lightblue')

input_word_label = tk.Label(root, text="Input Word:", font=("Palatino", 24, "bold"), fg="blue")
input_word_entry = tk.Entry(root, font=("Palatino", 22), bg="lightgray", width=60)

# Both buttons share the same look and differ only in caption and callback
button_style = dict(font=("Palatino", 18, "bold"), bg='black', fg='white')
predict_button_lstm = tk.Button(root, text="Predict LSTM", command=predict_words_lstm, **button_style)
predict_button_bert = tk.Button(root, text="Predict BERT", command=predict_words_bert, **button_style)

# One label per displayed prediction; the callbacks fill these in
prediction_labels = [
    tk.Label(root, text="", font=("Palatino", 24, "italic"), fg='darkblue')
    for _ in range(3)
]

# Stack the widgets top to bottom in display order
for widget in (
    input_word_label,
    input_word_entry,
    predict_button_lstm,
    predict_button_bert,
    *prediction_labels,
):
    widget.pack()

# Hand control to the Tk event loop (blocks until the window is closed)
root.mainloop()


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


All PyTorch model weights were used when initializing TFBertForMaskedLM.

All the weights of TFBertForMaskedLM were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForMaskedLM for predictions without further training.
