# Train a Text Generation Model Using LSTM and TensorFlow

Welcome to this notebook on text generation using LSTM and TensorFlow. This notebook offers a straightforward approach to understanding and implementing a simple LSTM-based text generation model.

In [None]:
from datasets import load_dataset
import numpy as np
import pandas as pd 
from collections import Counter
import matplotlib.pyplot as plt      
from sklearn.model_selection import train_test_split
import random
import re                                  
import string         
import nltk   
from nltk.tokenize import word_tokenize                     
from nltk.corpus import stopwords  
from nltk.tokenize import sent_tokenize
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import to_categorical
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM, Dropout, Dense
from keras.layers import Attention
from tensorflow.keras.models import load_model


In [None]:
# Load the first 50% of the "bookcorpus" training split via datasets.load_dataset.
dataset = load_dataset(
    "bookcorpus",
    split='train[:50%]',
)
# Bare expression so the notebook displays a summary of the loaded dataset.
dataset

In [None]:
# Convert the dataset to a pandas DataFrame and keep only its 'text' column.
ds = dataset.to_pandas()['text']

In [None]:
_STOP_WORDS_CACHE = None


def _english_stop_words():
    """Return NLTK's English stop words as a frozenset, loading them only once."""
    global _STOP_WORDS_CACHE
    if _STOP_WORDS_CACHE is None:
        _STOP_WORDS_CACHE = frozenset(stopwords.words('english'))
    return _STOP_WORDS_CACHE


def text_processing(text_example):
    """Lower-case, clean and tokenize one text, dropping English stop words.

    Links and '@' signs are deleted, every character other than ASCII
    letters, digits, spaces and periods is replaced by a space, and the
    result is split with ``nltk.word_tokenize``. Periods are kept, so '.'
    can appear as a token.

    The NLTK resources used by ``word_tokenize`` and the stop-word corpus
    must already be available (e.g. downloaded beforehand).

    Args:
        text_example: The raw text as a ``str``.

    Returns:
        A list of token strings with English stop words removed.
    """
    # `stopwords` itself is the NLTK corpus reader, not a collection of words,
    # so membership has to be tested against the actual word list.
    stop_words = _english_stop_words()

    # convert all letters to lower case
    example = text_example.lower()

    # Remove links and '@' signs. The dot in 'www.' is escaped so that it only
    # matches a literal period; '#' and every other symbol are blanked by the
    # substitution below.
    example = re.sub(r'http\S+|www\.\S+|@', '', example)

    # Replace other non-alphanumeric characters (except spaces and periods)
    example = re.sub(r'[^a-zA-Z0-9 .]', ' ', example)

    # Tokenize the sentence and drop blank tokens and stop words
    return [
        token
        for token in word_tokenize(example)
        if token.strip() and token not in stop_words
    ]


# Sanity check: show the first raw sample next to its cleaned, tokenized form.
print('original text: ',ds[0])
print('processing text : ',text_processing(ds[0]))
    

In [None]:
# Replace every raw text in the Series with its list of cleaned tokens.
ds=ds.apply(text_processing)

In [None]:
def check_token(token, tokenized_sentences):
    """Report whether `token` occurs in at least one of the given sentences.

    Args:
        token: The value to look for.
        tokenized_sentences: Iterable of sentences; each one is tested with
            the ``in`` operator.

    Returns:
        True as soon as one sentence contains `token`, otherwise False
        (including when there are no sentences at all).
    """
    return any(token in words for words in tokenized_sentences)

# Check whether a '#' token is present anywhere in the preprocessed corpus.
token = '#'
exists = check_token(token, ds)

print(f"Does the token '{token}' exist in the tokenized sentences? {exists}")


In [None]:
# Fit a Keras Tokenizer on the tokenized corpus; '<OOV>' is registered as the
# out-of-vocabulary token.
tokenizer = Tokenizer(oov_token='<OOV>')
tokenizer.fit_on_texts(ds)
# One more than the number of known words: Keras word indices start at 1 and
# index 0 stays reserved (it is the value used for padding).
total_words = len(tokenizer.word_index)+1

In [None]:
# Flatten the corpus into a single token list and print how many distinct
# tokens it contains.
all_words = [word for sentence in ds for word in sentence]
print(len(Counter(all_words)))

In [None]:
# Print the tokenizer's vocabulary size followed by total_words (that size + 1).
print(len(tokenizer.word_index))
print(total_words)

In [None]:
def generate_tokenized_sequences(sentences, total_words, max_len=5):
    """Build next-word training pairs from tokenized sentences.

    Every prefix of a sentence that holds at least two tokens becomes one
    sample: the prefix is encoded with the module-level `tokenizer`, cut or
    left-padded with zeros to `max_len` ids, and then split into the first
    `max_len - 1` ids (the input) and the last id (the target).

    NOTE: the one-hot target matrix has ``number_of_samples * total_words``
    entries, so it becomes very large for a big corpus or vocabulary.

    Args:
        sentences: Iterable of sentences, each a list of token strings.
        total_words: Number of classes used to one-hot encode the targets.
        max_len: Length of each padded sequence, target included.
            Defaults to 5.

    Returns:
        A tuple ``(inputs, targets)``: `inputs` is a 2-D integer array with
        ``max_len - 1`` columns and `targets` is the one-hot encoded array
        with `total_words` columns.

    Raises:
        ValueError: If `max_len` is smaller than 2.
    """
    if max_len < 2:
        raise ValueError(
            "max_len must be at least 2 (one input token plus the target)"
        )

    sequences = []

    for sentence in sentences:
        # Encode each token exactly once instead of re-encoding every prefix.
        # A token yields one id, or none if the tokenizer drops it.
        per_token_ids = tokenizer.texts_to_sequences(
            [[token] for token in sentence]
        )

        prefix_ids = []
        for position, ids in enumerate(per_token_ids):
            prefix_ids.extend(ids)
            if position >= 1:
                # pad_sequences truncates from the front by default, so only
                # the last max_len ids of a prefix can survive; keep just those.
                sequences.append(prefix_ids[-max_len:])

    # Pad the sequences
    sequences = pad_sequences(sequences, maxlen=max_len, padding='pre')

    # Split the padded sequences into inputs and targets
    inputs = sequences[:, :-1]
    targets = to_categorical(sequences[:, -1], num_classes=total_words)

    return inputs, targets

# Build the training arrays from the tokenized corpus and display them.
inputs, targets = generate_tokenized_sequences(ds,total_words)

print("Inputs:", inputs)
print("Targets:", targets)


In [None]:
# Print the first few inputs and targets, decoded back to words.
# Index 0 is the only id missing from tokenizer.index_word: it is the value
# pad_sequences fills with, so it is shown as '<PAD>' ('<OOV>' has its own index).
for i in range(min(5, len(inputs))):
    input_words = [tokenizer.index_word.get(idx, '<PAD>') for idx in inputs[i]]
    target_word = tokenizer.index_word.get(targets[i].argmax(), '<PAD>')

    print(f"Input {i+1}: {inputs[i]} ({input_words})")
    print(f"Target {i+1}: {targets[i]} ({target_word})")

In [None]:
def check_sequences_length(sequences):
    """Return True when every sequence has the same length.

    Args:
        sequences: A sized, indexable collection of sized items (for example
            a list of lists or a 2-D array).

    Returns:
        True if all items share the length of the first one, or if
        `sequences` is empty; False otherwise.
    """
    # An empty collection has no mismatching lengths; without this guard,
    # sequences[0] would raise IndexError.
    if len(sequences) == 0:
        return True

    # Use the length of the first sequence as the reference
    expected_length = len(sequences[0])

    # Check if all sequences have the same length
    return all(len(sequence) == expected_length for sequence in sequences)

# Verify that every row of the padded inputs has the same length.
print(check_sequences_length(inputs))  # Outputs: True


In [None]:
# Number of output classes: tokenizer vocabulary size plus one (the same
# expression that defines total_words).
vocab_size =len(tokenizer.word_index)+1
# Length of one input row, i.e. how many token ids are fed to the model per sample.
maxlen=len(inputs[0])
maxlen

In [None]:
# Next-word classifier: token embeddings -> two stacked bidirectional LSTMs
# -> dense hidden layer -> softmax over the vocabulary.
model = Sequential()

# Disabled alternative for the embedding layer:
#     Embedding(input_dim=vocab_size, output_dim=embedding_dim, trainable=False)
model.add(Embedding(total_words, 300))

# The first recurrent layer returns the whole sequence so that the second
# recurrent layer receives one vector per time step.
model.add(Bidirectional(LSTM(units=128, return_sequences=True, dropout=0.2)))
model.add(Bidirectional(LSTM(units=128, dropout=0.2)))

model.add(Dense(units=128, activation='relu'))
# A Dropout(0.5) layer at this position is currently disabled.
model.add(Dense(units=vocab_size, activation='softmax'))

# Targets are one-hot vectors, hence categorical cross-entropy.
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

In [None]:
# Train on the in-memory arrays for 10 epochs with mini-batches of 300 samples.
batch_size = 300
history = model.fit(inputs, targets,batch_size=batch_size,epochs=10)

In [None]:
def data_generator(inputs, targets, batch_size):
    """Yield ``(batch_inputs, batch_targets)`` pairs forever, cycling the data.

    Each pass walks the data in order in steps of `batch_size`; the final
    batch of a pass is shorter when the number of samples is not a multiple
    of `batch_size`. After the last batch the generator starts over.

    Args:
        inputs: Sliceable collection of samples (list, array, ...).
        targets: Sliceable collection of targets, same length as `inputs`.
        batch_size: Positive number of samples per batch.

    Yields:
        Tuples ``(inputs[i:i + batch_size], targets[i:i + batch_size])``.

    Raises:
        ValueError: Raised when the first batch is requested if `batch_size`
            is not positive, `inputs` is empty, or `inputs` and `targets`
            differ in length.
    """
    num_samples = len(inputs)

    # Without these guards a negative batch_size or empty data would make the
    # endless loop below spin forever without ever yielding a batch.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if num_samples == 0:
        raise ValueError("inputs is empty; there is nothing to yield")
    if len(targets) != num_samples:
        raise ValueError(
            f"inputs and targets differ in length: {num_samples} != {len(targets)}"
        )

    while True:  # Loop forever, so the generator never runs out of data
        for start in range(0, num_samples, batch_size):
            stop = start + batch_size
            yield inputs[start:stop], targets[start:stop]


In [None]:

batch_size = 2860 # batch size
generator = data_generator(inputs, targets, batch_size)

# Number of batches per epoch, rounded up: the generator also yields the final,
# shorter batch of each pass, so floor division would leave it out of the epoch
# (and would give 0 steps when there are fewer samples than batch_size).
steps_per_epoch = (len(inputs) + batch_size - 1) // batch_size

# NOTE: `model` is the same object used by the earlier fit call, so this call
# continues training it rather than starting from fresh weights.
history = model.fit(generator, steps_per_epoch=steps_per_epoch, epochs=10)
