Next word prediction is a task in natural language processing (NLP) where the goal is to predict the most likely word to follow a given sequence of words. This involves training deep learning models, like LSTMs or Transformers, on large text datasets to understand word patterns and dependencies. By using embeddings to represent words in high-dimensional space, the model can generate meaningful predictions based on context, which is useful for applications like text autocompletion, chatbots, and language generation systems.

In [2]:
import pandas as pd
import numpy as np
import re

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical

from tensorflow.keras.layers import Embedding, LSTM, Dense, Dropout
from tensorflow.keras.models import Sequential

2024-07-26 09:49:16.339303: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-26 09:49:16.339435: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-26 09:49:16.626672: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Load the TED talk transcripts dataset from Kaggle
# (https://www.kaggle.com/datasets/miguelcorraljr/ted-ultimate-dataset).
talks = pd.read_csv("/path/to/ted-ultimate-dataset/2020-05-01/ted_talks_en.csv")

In [16]:
# Pull the transcript column out of the DataFrame as a plain Python list.
texts = talks['transcript'].to_list()

In [17]:
# Select 100 random TED talk transcripts reproducibly.
# `random.sample` draws from Python's built-in `random` module, so that is the
# generator that must be seeded; seeding NumPy alone does not affect it.
import random
random.seed(42)
# NumPy is seeded as well so any later np.random call is repeatable too.
np.random.seed(42)
texts = random.sample(texts, 100)

In [18]:
# Optional sanity check: inspect one random transcript, e.g. texts[np.random.randint(0, len(texts))]

In [19]:
def clean_texts(text):
    """Return *text* without parenthesised asides, backslashes or extra spaces.

    The substitutions run in order: anything between ``(`` and the next ``)``
    is dropped, every backslash is dropped, and each run of whitespace is
    replaced by one space. Leading and trailing whitespace is trimmed last.
    """
    substitutions = (
        (r"\([^)]*\)", ""),  # parenthesised spans such as "(Laughter)"
        (r'\\', ''),         # literal backslashes
        (r'\s+', ' '),       # whitespace runs -> single space
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()

In [20]:
# Clean every sampled transcript.
texts = list(map(clean_texts, texts))

In [21]:
# Break each transcript into sentences: split on the whitespace that follows
# a '.', '!' or '?', trim each piece, and drop pieces that are empty.
sentence_boundary = re.compile(r'(?<=[.!?])\s+')

sentences_list = [
    piece.strip()
    for transcript in texts
    for piece in sentence_boundary.split(transcript)
    if piece.strip()
]

In [22]:
# Keep only sentences of at most 50 words (whitespace-separated tokens).
sentences = [
    candidate
    for candidate in sentences_list
    if len(candidate.split()) <= 50
]

In [23]:
# Number of sentences that survived the length filter.
len(sentences)

10558

In [24]:
import string

def remove_punc(text):
    """Lower-case *text*, strip ASCII punctuation and collapse whitespace.

    Every character in ``string.punctuation`` is deleted (not replaced by a
    space), so "it's" becomes "its". Whitespace runs are then reduced to a
    single space and the ends are trimmed.
    """
    without_punctuation = text.translate(str.maketrans('', '', string.punctuation))
    collapsed = re.sub(r'\s+', ' ', without_punctuation.lower())
    return collapsed.strip()

In [25]:
# Normalise every sentence (lower-case, punctuation removed).
sentences = list(map(remove_punc, sentences))

In [26]:
# Learn the word -> integer index vocabulary from the normalised sentences.
tokenizer = Tokenizer()
tokenizer.fit_on_texts(sentences)

# One slot more than the number of learned words: the Keras tokenizer never
# assigns index 0 to a word, and 0 is the value used for padding.
total_words = 1 + len(tokenizer.word_index)

In [27]:
# Build the training n-grams: for each tokenised sentence, emit every prefix
# that is at least two tokens long (context word(s) plus the word to predict).
input_sequences = [
    token_ids[:end]
    for token_ids in tokenizer.texts_to_sequences(sentences)
    for end in range(2, len(token_ids) + 1)
]

In [28]:
# Total number of n-gram training sequences.
len(input_sequences)

156102

In [29]:
# The longest n-gram sets the width every sequence is padded to.
maxlen = max(map(len, input_sequences))

In [30]:
# Left-pad every n-gram to the common length so that the final column of each
# row always holds the last (to-be-predicted) token.
padded_seq = pad_sequences(input_sequences, maxlen = maxlen, padding='pre')

In [31]:
# Every column except the last is the model input; the last column is the
# word index to predict.
X, Y = padded_seq[:, :-1], padded_seq[:, -1]

# One-hot encode the targets across the whole vocabulary.
# NOTE(review): this materialises a dense (num_sequences, total_words) array;
# integer labels with a sparse loss would need far less memory.
y = to_categorical(Y, num_classes=total_words)

In [32]:
# Show the input and one-hot target shapes, one per line.
print(X.shape, y.shape, sep='\n')

(156102, 49)
(156102, 12229)


In [33]:
# Stacked-LSTM language model: embedded tokens -> two recurrent layers ->
# softmax over the vocabulary.
model = Sequential([
    Embedding(total_words, 256),
    # First LSTM returns the full sequence so the second LSTM can consume it.
    LSTM(256, return_sequences=True),
    Dropout(0.1),
    LSTM(128),
    Dense(total_words, activation='softmax'),
])

In [34]:
# Multi-class objective over one-hot targets, optimised with Adam; accuracy is
# reported alongside the loss.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [35]:
# Print a layer-by-layer overview of the model.
model.summary()

In [None]:
# Train on the full set of n-gram sequences for up to 100 epochs with progress
# output; the value returned by fit() is kept for later inspection.
history = model.fit(X, y, epochs=100, verbose=1)

Epoch 1/100
[1m4879/4879[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 16ms/step - accuracy: 0.0577 - loss: 6.9069
Epoch 2/100
[1m4879/4879[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 16ms/step - accuracy: 0.1107 - loss: 5.9587
Epoch 3/100
[1m4879/4879[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 16ms/step - accuracy: 0.1271 - loss: 5.6043
Epoch 4/100
[1m4879/4879[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 16ms/step - accuracy: 0.1403 - loss: 5.3351
Epoch 5/100
[1m4879/4879[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 16ms/step - accuracy: 0.1498 - loss: 5.1144
Epoch 6/100
[1m4879/4879[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 16ms/step - accuracy: 0.1587 - loss: 4.9398
Epoch 7/100
[1m4879/4879[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 16ms/step - accuracy: 0.1637 - loss: 4.7741
Epoch 8/100
[1m4879/4879[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m79s[0m 16ms/step - accuracy: 0.1715 - loss: 4.6180


In [26]:
def predict_next_word(model, tokenizer, words):
    """Predict the single most likely word to follow ``words``.

    Args:
        model: Trained network whose ``predict`` returns, for each input
            sequence, one score per vocabulary index.
        tokenizer: Fitted tokenizer providing ``texts_to_sequences`` and the
            ``index_word`` lookup.
        words: Seed text to continue.

    Returns:
        The predicted next word as a string, so it can be appended directly
        to the seed text.
    """
    token_list = tokenizer.texts_to_sequences([words])[0]
    # Left-pad to the training input width: training rows were maxlen wide and
    # their last column was split off as the label, leaving maxlen - 1 inputs.
    token_list = pad_sequences([token_list], maxlen=maxlen - 1, padding='pre')

    # predict() returns a batch of shape (1, vocabulary size); take its only
    # row so the ranking below runs over vocabulary indices, not the batch axis.
    probabilities = model.predict(token_list)[0]

    # Index 0 is the padding value and has no entry in index_word, so rank
    # only the real word indices (1 and up).
    best_index = int(np.argmax(probabilities[1:])) + 1
    return tokenizer.index_word[best_index]

In [37]:
# Greedy generation: extend the seed phrase by ten predicted words, printing
# the growing sentence after each step.
text = "Make sure to"
for _ in range(10):
    word = predict_next_word(model, tokenizer, text)
    text = " ".join((text, word))
    print(text)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
Make sure to look
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
Make sure to look at
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
Make sure to look at the
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
Make sure to look at the same
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
Make sure to look at the same thing
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
Make sure to look at the same thing that
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
Make sure to look at the same thing that you
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
Make sure to look at the same thing that you have
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
Make sure to look at the same thing that you have to
[1m1/1[0m [32m━━━━━━━━━━━━━━━━