In [27]:
!pip install tensorflow datasets




In [28]:
from datasets import load_dataset

# Fetch the first 10% of the WikiText-2 (v1) training split via `datasets.load_dataset`.
dataset = load_dataset(
    path="wikitext",
    name="wikitext-2-v1",
    split="train[:10%]",
)

# Show the dataset summary (features and row count).
print(dataset)



Dataset({
    features: ['text'],
    num_rows: 3672
})


In [29]:
# Import necessary libraries
from datasets import load_dataset
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
import tensorflow as tf

# Report visible GPUs. `tf.config.list_physical_devices` is the stable
# (non-experimental) API, available since TF 2.1.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# Number of leading rows used to build the vocabulary. Fitting on a subset keeps
# the vocabulary small, which bounds the width of the softmax output layer and of
# the one-hot label matrix built later (both sized len(tokenizer.word_index) + 1).
# Words that appear only outside this subset are dropped by `texts_to_sequences`,
# because no `oov_token` is configured.
TOKENIZER_FIT_ROWS = 1000

# Initialize the tokenizer
tokenizer = Tokenizer()

# Clamp to the dataset length so `select` does not fail on smaller splits.
small_subset = dataset.select(range(min(TOKENIZER_FIT_ROWS, len(dataset))))
tokenizer.fit_on_texts(small_subset['text'])


Num GPUs Available:  2


In [30]:
import matplotlib.pyplot as plt

# Plot training history
def plot_history(history):
    """Plot per-epoch accuracy and loss curves from a Keras training run.

    One figure is drawn for ``accuracy`` and one for ``loss``. The training curve
    is always plotted. The matching validation curve (``val_accuracy`` /
    ``val_loss``) is added whenever it is present in ``history.history``. That
    happens when ``fit`` was given ``validation_split`` or ``validation_data``.

    Args:
        history: object whose ``.history`` attribute maps metric names to lists
            of per-epoch values (e.g. the ``History`` returned by ``model.fit``).

    Raises:
        KeyError: if ``history.history`` lacks ``'accuracy'`` or ``'loss'``.
    """
    metrics = history.history
    for key, title, ylabel in (
        ('accuracy', 'Model accuracy', 'Accuracy'),
        ('loss', 'Model loss', 'Loss'),
    ):
        _, ax = plt.subplots()
        ax.plot(metrics[key], label='Train')
        val_key = f'val_{key}'
        if val_key in metrics:
            ax.plot(metrics[val_key], label='Validation')
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Epoch')
        ax.legend(loc='upper left')
        plt.show()

In [31]:
# Expand raw text into n-gram prefix training sequences.
def process_batch(batch, tokenizer):
    """Turn every text in ``batch`` into its growing token-id prefixes.

    Each text is encoded with ``tokenizer.texts_to_sequences``. For an encoded
    line ``[t0, t1, ..., tk]`` this emits ``[t0, t1]``, ``[t0, t1, t2]``, ...,
    ``[t0, ..., tk]``, so every prefix has at least two ids. The last id of each
    prefix serves as the label and the rest as the context. Texts that encode
    to fewer than two ids contribute nothing.

    Args:
        batch: mapping with a ``'text'`` entry holding an iterable of strings.
        tokenizer: object exposing ``texts_to_sequences`` (a fitted Keras
            ``Tokenizer`` here).

    Returns:
        list[list[int]]: all prefixes, in input order.
    """
    prefixes = []
    for text in batch['text']:
        token_ids = tokenizer.texts_to_sequences([text])[0]
        prefixes.extend(token_ids[:end] for end in range(2, len(token_ids) + 1))
    return prefixes

# Process the dataset in batches, expanding each line into n-gram prefixes.
batch_size = 100
input_sequences = []

for i in range(0, len(dataset), batch_size):
    batch = dataset.select(range(i, min(i + batch_size, len(dataset))))
    input_sequences.extend(process_batch(batch, tokenizer))

# Fail loudly here rather than training on an empty array further down.
assert input_sequences, "process_batch produced no sequences; check the tokenizer and dataset"

# Longest prefix over the whole corpus. Computed once after the loop: taking
# max() per batch would raise ValueError for any batch that yields no sequences
# (e.g. a batch made only of blank or one-word lines).
max_sequence_len = max(len(seq) for seq in input_sequences)

# Left-pad to a common length so the last column always holds the target token.
# pad_sequences already returns a NumPy array, so no extra np.array copy is needed.
input_sequences = pad_sequences(input_sequences, maxlen=max_sequence_len, padding='pre')

# Predictors: every token except the last; label: the last token.
xs, labels = input_sequences[:, :-1], input_sequences[:, -1]
# NOTE(review): this dense one-hot matrix has shape (num_sequences, vocab_size)
# and can be very large. Keeping integer `labels` and switching the model loss to
# `sparse_categorical_crossentropy` would avoid materialising it.
ys = tf.keras.utils.to_categorical(labels, num_classes=len(tokenizer.word_index) + 1)


In [32]:
# Inference: greedy next-word prediction.
def predict_next_word(input_text):
    """Return the single most likely next word for ``input_text``.

    Uses the notebook-level ``tokenizer``, ``max_sequence_len`` and ``model``,
    which are looked up when the function is called. The text is encoded and
    left-padded to ``max_sequence_len - 1`` ids (the model's input length). The
    argmax of the predicted distribution is then mapped back to a word.

    Returns:
        str: the predicted word, or ``'<Unknown>'`` when the winning index has no
        entry in ``tokenizer.index_word``.
    """
    token_ids = tokenizer.texts_to_sequences([input_text])[0]
    model_input = pad_sequences([token_ids], maxlen=max_sequence_len - 1, padding='pre')
    probabilities = model.predict(model_input, verbose=0)
    best_index = np.argmax(probabilities)
    return tokenizer.index_word.get(best_index, '<Unknown>')

In [36]:
# Stacked-LSTM next-word model: Embedding -> LSTM(512) -> LSTM(128) -> softmax over the vocabulary.
vocab_size = len(tokenizer.word_index) + 1  # +1: index 0 is the padding value used by pad_sequences

# Seed Python, NumPy and TensorFlow RNGs so weight init, dropout and shuffling
# are repeatable (GPU kernels may still add some nondeterminism).
tf.keras.utils.set_random_seed(42)

model = tf.keras.Sequential([
    # Declare the input length with an Input layer. Embedding's `input_length`
    # argument is deprecated in Keras 3.
    tf.keras.Input(shape=(max_sequence_len - 1,)),
    tf.keras.layers.Embedding(vocab_size, 200),
    tf.keras.layers.LSTM(512, return_sequences=True),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.LSTM(128),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(vocab_size, activation='softmax')
])

optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

# Stop when validation loss stops improving and keep the best epoch's weights,
# so a long run neither overfits nor has to be interrupted by hand.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=3, restore_best_weights=True
)

# validation_split holds out the *last* 20% of (xs, ys), i.e. the tail of the
# corpus, selected before Keras shuffles the training portion.
history = model.fit(
    xs, ys,
    epochs=20,
    batch_size=64,
    verbose=1,
    validation_split=0.2,
    callbacks=[early_stopping],
)

# Save the model (the .h5 suffix selects the legacy HDF5 format).
model.save('/kaggle/working/next_word_model.h5')

KeyboardInterrupt: 

In [None]:
# Interactive demo: repeatedly read a seed text and print the predicted next word.
while True:
    try:
        seed_text = input("Enter a seed text (or type 'exit' to quit): ").strip().lower()
    except (EOFError, KeyboardInterrupt):
        # Closed stdin or a kernel interrupt ends the demo cleanly instead of
        # leaving a traceback in the notebook.
        break

    # Exit loop if user enters 'exit'
    if seed_text == 'exit':
        break

    # Blank input would be padded to all zeros and yield a meaningless prediction.
    if not seed_text:
        print("Please enter some text (or 'exit' to quit).\n")
        continue

    next_word = predict_next_word(seed_text)
    print(f"Predicted next word: {next_word}\n")


Enter a seed text (or type 'exit' to quit):  hi


Predicted next word: egyptian

