In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import LSTM, Dense, Embedding
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import os
import logging
import pickle

class StoryGenerator:
    """Train and sample per-user word-level LSTM next-word generators.

    Each user gets their own model file (``models/story_generator_<user_id>.h5``)
    and tokenizer file (``tokenizers/tokenizer_<user_id>.pkl``). Training streams
    the source text in chunks, fits ONE tokenizer over the whole file, turns each
    ``'. '``-separated sentence into its prefix n-grams and trains a classifier
    that predicts the last word of each prefix.

    Args:
        sequence_length: Length every training n-gram is padded/truncated to.
            The model sees ``sequence_length - 1`` context tokens and predicts
            the final one.
        max_vocab_size: ``num_words`` passed to the Keras ``Tokenizer``; token
            indices >= this value are mapped to ``<OOV>`` when encoding.
    """

    def __init__(self, sequence_length=40, max_vocab_size=10000):
        self.sequence_length = sequence_length
        self.max_vocab_size = max_vocab_size
        self.initialize_directories()  # output dirs must exist before any save

        # Set by train_model() to the most recently trained tokenizer/model.
        self.tokenizer = None
        self.model = None

    def initialize_directories(self):
        """Create the ``models`` and ``tokenizers`` directories (relative to the CWD)."""
        os.makedirs('models', exist_ok=True)
        os.makedirs('tokenizers', exist_ok=True)

    def load_text_in_chunks(self, file_path, chunk_size=1024 * 1024):  # ~1M characters
        """Yield successive ``chunk_size``-character pieces of a UTF-8 text file.

        Boundaries are arbitrary and may split a word; ``_rejoin_at_sentence_boundaries``
        re-aligns them for training.
        """
        with open(file_path, 'r', encoding='utf-8') as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                yield chunk

    @staticmethod
    def _rejoin_at_sentence_boundaries(chunks):
        """Re-cut an iterable of raw text chunks so every piece ends at a sentence break.

        Text after the last ``'. '`` in a chunk is carried into the next chunk, so
        no word or sentence is split across pieces. Any non-empty remainder after
        the final chunk is yielded as the last piece.
        """
        carry = ''
        for chunk in chunks:
            text = carry + chunk
            cut = text.rfind('. ')
            if cut == -1:
                # No sentence break yet: keep accumulating.
                carry = text
                continue
            yield text[:cut + 1]        # keep the period, drop the separating space
            carry = text[cut + 2:]
        if carry:
            yield carry

    def _sentence_chunks(self, file_path):
        """Stream ``file_path`` as sentence-aligned pieces of roughly chunk size."""
        return self._rejoin_at_sentence_boundaries(self.load_text_in_chunks(file_path))

    @staticmethod
    def _vocab_size(tokenizer):
        """Number of output classes needed for ``tokenizer`` (index 0 is padding).

        ``texts_to_sequences`` never emits an index >= ``num_words``, so when the
        tokenizer has a ``num_words`` cap the class count is capped too.
        """
        size = len(tokenizer.word_index) + 1
        if tokenizer.num_words:
            size = min(size, tokenizer.num_words)
        return size

    def _make_sequences(self, text, tokenizer):
        """Encode ``text`` into padded n-gram inputs and integer next-word labels.

        Every ``'. '``-separated sentence with k tokens yields the k-1 prefixes
        ``tokens[:2] .. tokens[:k]``; they are left-padded (and, if longer,
        left-truncated) to ``sequence_length`` and the last column is the label.

        Returns:
            ``(X, y)`` with ``X`` of shape ``(n, sequence_length - 1)`` and ``y`` of
            shape ``(n,)`` holding word indices.
        """
        input_sequences = []
        # Encode all sentences in one call instead of one call per sentence.
        for token_list in tokenizer.texts_to_sequences(text.split('. ')):
            for i in range(1, len(token_list)):
                input_sequences.append(token_list[:i + 1])

        if not input_sequences:
            return (np.zeros((0, self.sequence_length - 1), dtype='int32'),
                    np.zeros((0,), dtype='int32'))

        padded = np.array(pad_sequences(input_sequences, maxlen=self.sequence_length, padding='pre'))
        return padded[:, :-1], padded[:, -1]

    def preprocess_text(self, text, user_id, tokenizer=None):
        """Build one-hot training data for ``text``.

        Args:
            text: Raw text; sentences are split on ``'. '``.
            user_id: Used to name the saved tokenizer file.
            tokenizer: Optional already-fitted tokenizer to reuse (not re-fitted,
                not saved). When omitted, a new tokenizer is fitted on ``text``
                alone and saved for ``user_id``.

        Returns:
            ``(X, y, total_words)``: ``X`` holds ``(n, sequence_length - 1)`` word
            indices, ``y`` is the ``(n, total_words)`` one-hot next word and
            ``total_words`` is the number of output classes.
        """
        fit_here = tokenizer is None
        if fit_here:
            tokenizer = Tokenizer(num_words=self.max_vocab_size, oov_token='<OOV>')
            tokenizer.fit_on_texts([text])
        total_words = self._vocab_size(tokenizer)

        X, labels = self._make_sequences(text, tokenizer)
        y = tf.keras.utils.to_categorical(labels, num_classes=total_words)

        if fit_here:
            self.save_tokenizer(tokenizer, user_id)

        return X, y, total_words

    def build_model(self, total_words, loss='categorical_crossentropy'):
        """Build and compile an Embedding -> LSTM(150) -> LSTM(100) -> softmax model.

        Args:
            total_words: Vocabulary size (embedding input dim and softmax width).
            loss: Keras loss name. Pass ``'sparse_categorical_crossentropy'`` when
                labels are integer word indices rather than one-hot rows.
        """
        model = Sequential()
        model.add(Embedding(total_words, 100, input_length=self.sequence_length-1))
        model.add(LSTM(150, return_sequences=True))
        model.add(LSTM(100))
        model.add(Dense(total_words, activation='softmax'))

        model.compile(loss=loss, optimizer='adam', metrics=['accuracy'])
        print("Model built successfully!")

        return model

    def train_model(self, user_id, text_file_path, epochs=30, batch_size=32):
        """Train and save a model + tokenizer for ``user_id`` from ``text_file_path``.

        The file is streamed twice in sentence-aligned chunks: first to fit a
        single tokenizer over the whole text (so word ids agree across chunks),
        then to build n-gram sequences with it. Labels stay as integer ids and
        the model uses sparse categorical cross-entropy, avoiding an
        ``(n_sequences, vocab)`` one-hot matrix.

        Model and tokenizer are saved together only after training succeeds, so
        a failed run never leaves a mismatched pair on disk. Returns ``None``.
        """
        print("Loading and preprocessing text data for a user:", user_id)

        # Pass 1: one shared vocabulary for the whole file.
        tokenizer = Tokenizer(num_words=self.max_vocab_size, oov_token='<OOV>')
        for text in self._sentence_chunks(text_file_path):
            tokenizer.fit_on_texts([text])
        total_words = self._vocab_size(tokenizer)

        # Pass 2: encode every chunk with that same tokenizer.
        all_X, all_y = [], []
        for text in self._sentence_chunks(text_file_path):
            X, y = self._make_sequences(text, tokenizer)
            all_X.append(X)
            all_y.append(y)

        if sum(len(X) for X in all_X) == 0:
            print(f"No training sequences found in {text_file_path}; nothing to train for user {user_id}.")
            return

        all_X = np.concatenate(all_X)
        all_y = np.concatenate(all_y)

        print(f"Building the model for user {user_id}...")
        model = self.build_model(total_words, loss='sparse_categorical_crossentropy')

        print(f"Training the model for user {user_id}...")
        try:
            model.fit(all_X, all_y, epochs=epochs, batch_size=batch_size, verbose=1)
        except Exception as e:
            print(f"An error occurred during training: {e}")
            return

        self.save_model(model, user_id)
        self.save_tokenizer(tokenizer, user_id)
        self.model, self.tokenizer = model, tokenizer

    def save_model(self, model, user_id):
        """Save ``model`` to ``models/story_generator_<user_id>.h5``."""
        model_path = f'models/story_generator_{user_id}.h5'
        model.save(model_path)
        print(f"Model for user {user_id} saved to {model_path}")

    def load_model(self, user_id):
        """Load the saved model for ``user_id``, or return ``None`` if it does not exist."""
        model_path = f'models/story_generator_{user_id}.h5'
        if os.path.exists(model_path):
            # Module-level keras ``load_model`` (methods are not in scope here).
            model = load_model(model_path)
            print(f"Model for user {user_id} loaded from {model_path}")
            return model
        else:
            print(f"No model found for user {user_id}. Please train a new model first.")
            return None

    def save_tokenizer(self, tokenizer, user_id):
        """Pickle ``tokenizer`` to ``tokenizers/tokenizer_<user_id>.pkl``."""
        tokenizer_path = f'tokenizers/tokenizer_{user_id}.pkl'
        with open(tokenizer_path, 'wb') as f:
            pickle.dump(tokenizer, f)
        print(f"Tokenizer for user {user_id} saved to {tokenizer_path}")

    def load_tokenizer(self, user_id):
        """Unpickle the tokenizer for ``user_id``, or return ``None`` if it does not exist.

        SECURITY: ``pickle.load`` executes arbitrary code; only load files this
        class wrote itself, never user-supplied ones.
        """
        tokenizer_path = f'tokenizers/tokenizer_{user_id}.pkl'
        if os.path.exists(tokenizer_path):
            with open(tokenizer_path, 'rb') as f:
                tokenizer = pickle.load(f)
            print(f"Tokenizer for user {user_id} loaded from {tokenizer_path}")
            return tokenizer
        else:
            print(f"No tokenizer found for user {user_id}. Please train a new model first.")
            return None

    def generate_story(self, user_id, seed_text, max_words=100):
        """Extend ``seed_text`` by up to ``max_words`` greedily predicted words.

        Each step encodes the running text, keeps its last ``sequence_length - 1``
        tokens and appends the highest-probability word. The padding index (0)
        and the ``<OOV>`` token are never chosen because neither is a real word.

        Returns:
            The extended text, or ``None`` if the user's model or tokenizer is missing.
        """
        model = self.load_model(user_id)
        tokenizer = self.load_tokenizer(user_id)

        if model is None or tokenizer is None:
            print(f"Model or tokenizer not found for user {user_id}. Please train the model first.")
            return None

        logging.getLogger('tensorflow').setLevel(logging.ERROR)

        # Class indices that must never be emitted as words.
        blocked = [0]
        oov_index = tokenizer.word_index.get(tokenizer.oov_token) if tokenizer.oov_token else None
        if oov_index is not None:
            blocked.append(oov_index)

        for _ in range(max_words):
            token_list = tokenizer.texts_to_sequences([seed_text])[0]
            token_list = pad_sequences([token_list], maxlen=self.sequence_length-1, padding='pre')
            probs = np.array(model.predict(token_list, verbose=0)[0], dtype=float)
            probs[[i for i in blocked if i < probs.shape[0]]] = -np.inf
            predicted = int(np.argmax(probs))

            output_word = tokenizer.index_word.get(predicted, "")
            if not output_word:
                break  # Stop if no valid word is predicted
            seed_text += " " + output_word

        return seed_text


# Shared generator instance (default sequence length and vocabulary size),
# used by the training and generation cells below.
story_gen = StoryGenerator(sequence_length=40, max_vocab_size=10000)



In [None]:
# Train a model for one user (user_id='user123').
# Point STORY_TEXT_PATH at the training corpus (or edit the fallback below);
# the "<path>" placeholder must be replaced before this cell can run.
path = os.environ.get("STORY_TEXT_PATH", "/<path>/Story_generation/grandma_stories.txt")
if not os.path.isfile(path):
    raise FileNotFoundError(f"Training text not found: {path!r}. Set STORY_TEXT_PATH or edit `path`.")
story_gen.train_model(user_id='user123', text_file_path=path, epochs=50)



In [None]:
# Continue the opening phrase with up to 100 generated words for the user.
story = story_gen.generate_story(
    user_id='user123',
    seed_text="Once upon a time",
    max_words=100,
)
print(story)