In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sentence_transformers import SentenceTransformer, util

# Load the causal language model and its matching tokenizer from the
# "distilgpt2" checkpoint; both are used by generate_story_chunk().
MODEL_NAME = "distilgpt2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# Reuse the end-of-sequence token as the padding token so that
# tokenizer.pad_token_id is defined when it is passed to model.generate().
tokenizer.pad_token = tokenizer.eos_token

# Sentence-embedding model used for every semantic-similarity check in this
# file (repetition filtering, chunk validation, relevance scoring).
semantic_model = SentenceTransformer("paraphrase-MiniLM-L3-v2")

# Run the language model on the GPU when CUDA is available, otherwise on CPU.
# Only `model` is moved here; `semantic_model` is left on whatever device
# SentenceTransformer selected for it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


def generate_story_chunk(prompt, max_new_tokens=50, temperature=0.6, top_p=0.8,
                         max_input_tokens=256):
    """
    Generate a continuation of ``prompt`` with the language model.

    Args:
        prompt: Text the model should continue.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling probability mass passed to ``model.generate``.
        max_input_tokens: Size of the context window. When the prompt is
            longer, only its LAST ``max_input_tokens`` tokens are kept so the
            model always sees the most recent part of the story.

    Returns:
        Only the newly generated text (the prompt is not echoed back),
        with special tokens removed and surrounding whitespace stripped.
    """
    input_data = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=True,
    )
    # Truncate from the left: for a continuation task the end of the prompt
    # matters most, whereas the tokenizer's default truncation would cut it off.
    input_ids = input_data["input_ids"][:, -max_input_tokens:].to(device)
    attention_mask = input_data["attention_mask"][:, -max_input_tokens:].to(device)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=1.3,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.pad_token_id,
    )

    # For a decoder-only model the returned sequence starts with the input
    # tokens; skip them so callers that append the result to the story do not
    # duplicate the prompt.
    new_tokens = outputs[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()



def remove_repetition(text, all_previous_chunks, similarity_threshold=0.85):
    """
    Drop sentences of ``text`` that semantically repeat earlier content.

    ``text`` is split on ". " and each sentence is compared, by cosine
    similarity of its embedding, against (a) the sentences of ``text`` already
    kept and (b) every entry of ``all_previous_chunks``. A sentence is dropped
    when any similarity exceeds ``similarity_threshold``.

    Args:
        text: The freshly generated chunk to filter.
        all_previous_chunks: Chunks already accepted into the story.
        similarity_threshold: Cosine similarity above which a sentence counts
            as a repetition.

    Returns:
        The surviving sentences, stripped and re-joined with ". ".
    """
    sentences = text.split(". ")
    # One batched call for the whole chunk; the per-sentence vectors are
    # reused below instead of re-encoding kept sentences on every comparison.
    embeddings = semantic_model.encode(sentences, convert_to_tensor=True)

    # Encode the previous chunks once, not once per candidate sentence.
    # Skipped entirely when there is nothing to compare against.
    previous_embeddings = []
    if all_previous_chunks:
        previous_embeddings = list(
            semantic_model.encode(list(all_previous_chunks), convert_to_tensor=True)
        )

    def repeats_any(candidate, references):
        """Return True if ``candidate`` is too similar to any reference vector."""
        return any(
            util.cos_sim(candidate, reference).item() > similarity_threshold
            for reference in references
        )

    kept_sentences = []
    kept_embeddings = []
    for sentence, embedding in zip(sentences, embeddings):
        # Check within the current chunk first, then against earlier chunks.
        if repeats_any(embedding, kept_embeddings):
            continue
        if repeats_any(embedding, previous_embeddings):
            continue
        kept_sentences.append(sentence.strip())
        kept_embeddings.append(embedding)

    return ". ".join(kept_sentences).strip()


def clean_story_ending(text):
    """
    Trim an incomplete trailing sentence from ``text``.

    The text is cut after the last sentence terminator (".", "!" or "?",
    optionally followed by closing quotes or a closing parenthesis) that is
    itself followed by whitespace or the end of the text. Requiring that
    boundary keeps decimals such as "3.14" from being mistaken for a
    sentence end.

    Args:
        text: Story text that may end mid-sentence.

    Returns:
        The stripped text up to and including its last complete sentence. If
        no complete sentence is found, the stripped text is returned unchanged
        so that a single unfinished fragment is not discarded entirely.
    """
    import re  # local import: keeps this block self-contained

    stripped = text.strip()
    # Greedy ".*" makes the match end at the LAST qualifying terminator.
    match = re.match(r"(?s).*[.!?][\"'”’)]*(?=\s|\Z)", stripped)
    if match is None:
        return stripped
    return match.group(0)


def validate_chunk(chunk, current_story, similarity_threshold=0.75):
    """
    Decide whether ``chunk`` is different enough from the story so far.

    Both texts are embedded with the sentence model and compared by cosine
    similarity.

    Returns:
        True when the similarity is strictly below ``similarity_threshold``,
        False otherwise.
    """
    story_vector, chunk_vector = (
        semantic_model.encode(passage, convert_to_tensor=True)
        for passage in (current_story, chunk)
    )
    score = util.cos_sim(story_vector, chunk_vector).item()
    return score < similarity_threshold


def count_words(text):
    """
    Return how many whitespace-separated words ``text`` contains.
    """
    words = text.split()
    return len(words)


def build_story(prompt, word_count=150, max_failed_attempts=10):
    """
    Grow a story from ``prompt`` chunk by chunk until it reaches ``word_count``.

    Each iteration generates a chunk, filters out sentences that repeat
    earlier chunks, trims an unfinished trailing sentence, and appends the
    result only if more than five words survive.

    Args:
        prompt: The story idea; it becomes the beginning of the story.
        word_count: Target length in whitespace-separated words.
        max_failed_attempts: Number of CONSECUTIVE rejected chunks after which
            generation stops early. Without this bound the loop would never
            end when every generated chunk is filtered down to five words or
            fewer.

    Returns:
        The stripped story text. It may be shorter than ``word_count`` if
        generation stalled, and is just the prompt when the prompt already
        meets the target.
    """
    max_new_tokens = 40  # small chunks keep each iteration fast
    current_story = prompt
    all_previous_chunks = []  # accepted chunks, used for repetition checks
    failed_attempts = 0

    while count_words(current_story) < word_count:
        remaining_words = word_count - count_words(current_story)
        continuation_prompt = f"{current_story.strip()} What happens next?"
        chunk = generate_story_chunk(
            continuation_prompt,
            # Roughly two tokens per remaining word, capped at the chunk size.
            max_new_tokens=min(max_new_tokens, remaining_words * 2),
        )
        chunk = remove_repetition(chunk, all_previous_chunks)
        chunk = clean_story_ending(chunk)

        if count_words(chunk) > 5:
            current_story += " " + chunk
            all_previous_chunks.append(chunk)
            failed_attempts = 0  # progress was made; reset the stall counter
            continue

        # Chunk rejected as too short: give up once the model keeps stalling.
        failed_attempts += 1
        if failed_attempts >= max_failed_attempts:
            break

    return current_story.strip()


def evaluate_story(prompt, story):
    """
    Score how semantically close ``story`` is to ``prompt``.

    Returns:
        The cosine similarity of the two sentence embeddings, multiplied by
        100.
    """
    def embed(passage):
        return semantic_model.encode(passage, convert_to_tensor=True)

    prompt_vector = embed(prompt)
    story_vector = embed(story)
    return util.cos_sim(prompt_vector, story_vector).item() * 100


def main():
    """
    Run the interactive story generator.

    Prompts for a story idea and an approximate word count, generates the
    story, and prints it together with its relevance score and word count.
    An empty idea aborts with a warning. A word count that is not an integer,
    or is zero or negative, falls back to 150 words.
    """
    banner = "=" * 50
    print(banner)
    print("✨ Welcome to the AI Story Generator! ✨")
    print(banner)

    prompt = input("\nWhat should your story be about? (e.g., 'A young wizard finds a magic book'): ").strip()
    if not prompt:
        print("⚠️ Please enter a valid story idea.")
        return

    raw_word_count = input("\nHow many words should the story be (approx.)? (e.g., 100): ").strip()
    try:
        word_count = int(raw_word_count)
        # A non-positive target would silently return the bare prompt, so
        # treat it the same as unparseable input.
        if word_count <= 0:
            raise ValueError(f"word count must be positive, got {word_count}")
    except ValueError:
        print("⚠️ Invalid input. Using default length of 150 words.")
        word_count = 150

    print("\n🌟 Generating your story... Please wait! 🌟")
    story = build_story(prompt, word_count=word_count)
    relevance_score = evaluate_story(prompt, story)

    print("\nHere’s your story:\n")
    print(banner)
    print(story)
    print(banner)
    print(f"\n🔍 Relevance to prompt: {relevance_score:.2f}%")
    print(f"📊 Word Count (Story): {count_words(story)}")
    print("\n🌟 Thank you for using the AI Story Generator! 🌟")


# Start the interactive generator only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

✨ Welcome to the AI Story Generator! ✨



What should your story be about? (e.g., 'A young wizard finds a magic book'):  In a kingdom ruled by dragons, a blacksmith's apprentice forges a weapon powerful enough to challenge the beasts.

How many words should the story be (approx.)? (e.g., 100):  300



🌟 Generating your story... Please wait! 🌟
