In [None]:
# Notebook-only step: the leading "!" hands the line to the shell (IPython/Colab
# syntax), so this is not valid in a plain .py script.
!pip install -q transformers torch gradio accelerate sentencepiece

# NOTE(review): this banner is printed unconditionally; presumably a failed
# install would not prevent it from appearing — confirm.
print("✅ Libraries installed successfully!")

import torch
import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    T5ForConditionalGeneration,
    T5Tokenizer
)
import warnings
warnings.filterwarnings('ignore')

# Pick CUDA when PyTorch reports it as available, otherwise fall back to CPU,
# and report the choice.
# NOTE(review): `device` is not referenced by the rest of the visible code (the
# pipelines below compute their own device index) — confirm it is still needed.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🚀 Using device: {device}")

class StoryGenerator:
    """Registry of text-generation pipelines, loaded on demand.

    Each ``load_*`` method loads one tokenizer/model pair and stores it in
    ``self.models`` under a short key ('gpt2', 'genre' or 'flan-t5'). Every
    entry is a dict with the keys 'tokenizer', 'model' and 'pipeline'.

    ``current_model`` starts as ``None`` and is not modified by the loaders.
    """

    def __init__(self):
        self.models = {}
        self.current_model = None

    def _register(self, key, task, tokenizer, model, ensure_pad_token=False):
        """Wrap ``model`` in a ``task`` pipeline and store it under ``key``.

        Args:
            key: Name used to look the entry up in ``self.models``.
            task: Pipeline task name passed to ``pipeline``.
            tokenizer: Tokenizer instance paired with ``model``.
            model: Loaded model instance.
            ensure_pad_token: When true and the tokenizer has no pad token,
                reuse its EOS token as the pad token before building the
                pipeline.
        """
        if ensure_pad_token and tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        self.models[key] = {
            'tokenizer': tokenizer,
            'model': model,
            'pipeline': pipeline(
                task,
                model=model,
                tokenizer=tokenizer,
                # First CUDA device when available, -1 (CPU) otherwise.
                device=0 if torch.cuda.is_available() else -1
            )
        }

    def load_gpt2_model(self):
        """Load GPT-2 model for story generation"""
        print("📚 Loading GPT-2 model...")
        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        model = AutoModelForCausalLM.from_pretrained("gpt2")

        self._register('gpt2', 'text-generation', tokenizer, model,
                       ensure_pad_token=True)
        print("✅ GPT-2 model loaded successfully!")

    def load_genre_story_model(self):
        """Load specialized genre-based story generation model"""
        print("📖 Loading genre-based story model...")
        model_name = "aspis/gpt2-genre-story-generation"

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)

        self._register('genre', 'text-generation', tokenizer, model,
                       ensure_pad_token=True)
        print("✅ Genre-based story model loaded successfully!")

    def load_flan_t5_model(self):
        """Load FLAN-T5 model for story generation"""
        print("🤖 Loading FLAN-T5 model...")
        model_name = "google/flan-t5-base"

        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5ForConditionalGeneration.from_pretrained(model_name)

        # The pad token is left as loaded for this model (no EOS fallback).
        self._register('flan-t5', 'text2text-generation', tokenizer, model)
        print("✅ FLAN-T5 model loaded successfully!")

# Initialize the story generator
story_gen = StoryGenerator()

def generate_story_gpt2(prompt, max_length=500, temperature=0.8, top_p=0.9, top_k=50):
    """Generate a story from ``prompt`` with the plain GPT-2 pipeline.

    The GPT-2 entry is loaded into ``story_gen`` the first time it is needed.
    All sampling arguments are forwarded to the pipeline unchanged, and the
    'generated_text' of the first returned sequence is the result.
    """
    # Lazy loading: only the first call triggers the model load.
    if 'gpt2' not in story_gen.models:
        story_gen.load_gpt2_model()

    generator = story_gen.models['gpt2']['pipeline']

    sampling_options = {
        'max_length': max_length,
        'temperature': temperature,
        'top_p': top_p,
        'top_k': top_k,
        'do_sample': True,
        'num_return_sequences': 1,
        # Pad with the EOS token id taken from the pipeline's tokenizer.
        'pad_token_id': generator.tokenizer.eos_token_id,
    }
    outputs = generator(prompt, **sampling_options)

    first_sequence = outputs[0]
    return first_sequence['generated_text']

def generate_story_genre(prompt, genre="adventure", max_length=400, temperature=0.9):
    """Generate a story with the genre-conditioned GPT-2 pipeline.

    The prompt is prefixed with ``<BOS> <genre>`` before generation. When the
    full prefixed prompt appears in the generated text it is removed and the
    remainder is stripped of surrounding whitespace; otherwise the generated
    text is returned untouched.
    """
    # Lazy loading: only the first call triggers the model load.
    if 'genre' not in story_gen.models:
        story_gen.load_genre_story_model()

    generator = story_gen.models['genre']['pipeline']

    # Tag format used to condition the genre model.
    tagged_prompt = f"<BOS> <{genre}> {prompt}"

    sampling_options = {
        'max_length': max_length,
        'temperature': temperature,
        'do_sample': True,
        'top_p': 0.95,
        'top_k': 50,
        'repetition_penalty': 1.2,
        'pad_token_id': generator.tokenizer.eos_token_id,
    }
    outputs = generator(tagged_prompt, **sampling_options)

    story = outputs[0]['generated_text']
    if tagged_prompt not in story:
        return story

    # Drop the echoed, tagged prompt so only the continuation remains.
    return story.replace(tagged_prompt, "").strip()

def generate_story_flan_t5(prompt, max_length=400):
    """Generate a story from ``prompt`` with the FLAN-T5 pipeline.

    The prompt is wrapped in a fixed instruction sentence, sampled with
    temperature 0.8 and top-p 0.9, and the 'generated_text' of the first
    result is returned.
    """
    # Lazy loading: only the first call triggers the model load.
    if 'flan-t5' not in story_gen.models:
        story_gen.load_flan_t5_model()

    generator = story_gen.models['flan-t5']['pipeline']

    # Instruction-style wrapper around the user's prompt.
    instruction = f"Write a creative story based on this prompt: {prompt}"

    sampling_options = {
        'max_length': max_length,
        'temperature': 0.8,
        'do_sample': True,
        'top_p': 0.9,
    }
    outputs = generator(instruction, **sampling_options)

    return outputs[0]['generated_text']

def generate_story_openai(prompt, model="gpt-3.5-turbo", max_tokens=800):
    """Generate story using OpenAI API (requires API key)

    Works with both generations of the ``openai`` package: the client-based
    interface (``openai.OpenAI``) is used when the installed package provides
    it, otherwise the legacy ``openai.ChatCompletion`` call is used.

    Args:
        prompt: Story idea inserted into the user message.
        model: Chat model name sent with the request.
        max_tokens: Upper bound on generated tokens sent with the request.

    Returns:
        The content of the first returned choice, or a human-readable error
        string when the library is missing or the request fails. This
        function does not raise for those cases.
    """
    try:
        import openai

        # Credentials: export OPENAI_API_KEY, or assign
        # openai.api_key = "your-api-key-here" before calling this function.
        request = {
            "model": model,
            "messages": [
                {"role": "system", "content": "You are a creative story writer. Write engaging, imaginative stories based on the given prompts."},
                {"role": "user", "content": f"Write a creative story based on this prompt: {prompt}"}
            ],
            "max_tokens": max_tokens,
            "temperature": 0.8,
        }

        client_cls = getattr(openai, "OpenAI", None)
        if client_cls is not None:
            # Client-based API: forward a module-level key if one was set;
            # passing None lets the client resolve the key on its own.
            client = client_cls(api_key=getattr(openai, "api_key", None))
            response = client.chat.completions.create(**request)
        else:
            # Older package versions only expose the module-level call.
            response = openai.ChatCompletion.create(**request)

        return response.choices[0].message.content
    except ImportError:
        return "OpenAI library not installed. Please install with: !pip install openai"
    except Exception as e:
        # Deliberate catch-all: the result is shown to the user as text.
        return f"Error with OpenAI API: {str(e)}"


def generate_story(prompt, model_choice, genre="adventure", max_length=500, temperature=0.8, top_p=0.9, top_k=50):
    """Main function to generate stories based on model choice

    Args:
        prompt: Story idea; ``None``, empty or whitespace-only values are
            rejected with a warning message.
        model_choice: One of "GPT-2 (Fast)", "Genre-Based GPT-2",
            "FLAN-T5 (Creative)" or "OpenAI GPT (API Required)".
        genre: Genre tag, used only by the genre-based model.
        max_length: Length limit forwarded to the local models (coerced to int).
        temperature: Sampling temperature for the GPT-2 based models.
        top_p: Nucleus sampling value, used only by plain GPT-2.
        top_k: Top-k sampling value, used only by plain GPT-2 (coerced to int).

    Returns:
        The generated story, or a message string describing why no story was
        produced. Errors raised during generation are reported as text rather
        than propagated.
    """
    # Guard against None as well as blank text so the check itself cannot raise.
    if not prompt or not prompt.strip():
        return "⚠️ Please enter a prompt to generate a story!"

    try:
        # Count-like settings are normalised to int in case the caller hands
        # them over as floats.
        max_length = int(max_length)
        top_k = int(top_k)

        if model_choice == "GPT-2 (Fast)":
            return generate_story_gpt2(prompt, max_length, temperature, top_p, top_k)

        elif model_choice == "Genre-Based GPT-2":
            return generate_story_genre(prompt, genre, max_length, temperature)

        elif model_choice == "FLAN-T5 (Creative)":
            return generate_story_flan_t5(prompt, max_length)

        elif model_choice == "OpenAI GPT (API Required)":
            return generate_story_openai(prompt)

        else:
            return "❌ Invalid model choice!"

    except Exception as e:
        # Deliberate catch-all: the result is shown to the user as text.
        return f"❌ Error generating story: {str(e)}"

def create_gradio_interface():
    """Build the Gradio Blocks UI for the story generator and return it.

    The function only assembles the interface (inputs, output box, example
    prompts and event handlers); it does not call ``launch`` itself, that is
    left to the caller.

    Returns:
        The ``gr.Blocks`` object created here.
    """

    with gr.Blocks(title="🎭 AI Story Generator", theme=gr.themes.Soft()) as interface:

        # Page header
        gr.HTML("""
        <div style="text-align: center; padding: 20px;">
            <h1>🎭 AI Story Generator</h1>
            <p>Transform your single-line prompts into captivating stories!</p>
        </div>
        """)

        with gr.Row():
            # Left column: all user inputs and the generate button
            with gr.Column(scale=1):
                # Input components
                prompt_input = gr.Textbox(
                    label="📝 Story Prompt",
                    placeholder="Enter your story idea here... (e.g., 'A detective finds a mysterious key')",
                    lines=3
                )

                # These labels must match the strings compared in generate_story
                model_choice = gr.Dropdown(
                    choices=[
                        "GPT-2 (Fast)",
                        "Genre-Based GPT-2",
                        "FLAN-T5 (Creative)",
                        "OpenAI GPT (API Required)"
                    ],
                    value="GPT-2 (Fast)",
                    label="🤖 Model Selection"
                )

                # Hidden initially; revealed by update_genre_visibility below
                genre_choice = gr.Dropdown(
                    choices=["adventure", "romance", "mystery-&-detective", "fantasy",
                            "humor-&-comedy", "paranormal", "science-fiction"],
                    value="adventure",
                    label="🎨 Genre (for Genre-Based model)",
                    visible=False
                )

                # Advanced settings (collapsed by default)
                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    max_length = gr.Slider(
                        minimum=100, maximum=1000, value=500, step=50,
                        label="📏 Max Length"
                    )
                    temperature = gr.Slider(
                        minimum=0.1, maximum=2.0, value=0.8, step=0.1,
                        label="🌡️ Creativity (Temperature)"
                    )
                    top_p = gr.Slider(
                        minimum=0.1, maximum=1.0, value=0.9, step=0.1,
                        label="🎯 Focus (Top-p)"
                    )
                    top_k = gr.Slider(
                        minimum=10, maximum=100, value=50, step=10,
                        label="🔢 Vocabulary Size (Top-k)"
                    )

                generate_btn = gr.Button("✨ Generate Story", variant="primary", size="lg")

            # Right column: generated text
            with gr.Column(scale=2):
                # Output component
                story_output = gr.Textbox(
                    label="📖 Generated Story",
                    lines=20,
                    max_lines=30,
                    show_copy_button=True
                )

        # Example prompts; each example fills only the prompt textbox
        gr.Examples(
            examples=[
                ["A robot discovers emotions for the first time"],
                ["Two strangers get stuck in an elevator during a power outage"],
                ["A child finds a door in their basement that leads to another world"],
                ["The last bookstore on Earth refuses to close"],
                ["A time traveler accidentally changes the wrong historical event"]
            ],
            inputs=[prompt_input],
            label="💡 Example Prompts"
        )

        # Show/hide genre selection based on model choice
        def update_genre_visibility(model):
            # Visible only while the genre-based model is selected
            return gr.update(visible=(model == "Genre-Based GPT-2"))

        model_choice.change(
            fn=update_genre_visibility,
            inputs=[model_choice],
            outputs=[genre_choice]
        )

        # Generate story on button click; the input order mirrors the
        # positional parameters of generate_story
        generate_btn.click(
            fn=generate_story,
            inputs=[prompt_input, model_choice, genre_choice, max_length, temperature, top_p, top_k],
            outputs=[story_output]
        )

        # Footer
        gr.HTML("""
        <div style="text-align: center; padding: 20px; color: #666;">
            <p>🚀 Powered by Hugging Face Transformers & Gradio</p>
            <p>💡 Tip: Try different models and settings to get varied storytelling styles!</p>
        </div>
        """)

    return interface


# Build the interface at module level so `interface` remains available as a
# module-level name.
print("🎉 Creating Gradio interface...")
interface = create_gradio_interface()

# Start the server only when this file is executed directly (script or
# notebook cell), so that importing it does not open a public server.
if __name__ == "__main__":
    print("🚀 Launching AI Story Generator...")
    # NOTE(review): share=True together with server_name="0.0.0.0" exposes the
    # app beyond localhost without authentication — confirm this is intended
    # outside of a throwaway Colab session.
    interface.launch(
        share=True,  # Creates public link for sharing
        debug=True,  # Enable debug mode
        server_name="0.0.0.0",  # Allow external access
        server_port=7860
    )



def quick_test():
    """Smoke-test GPT-2 generation with a fixed prompt.

    Prints the prompt, then either a success line followed by the first 100
    characters of the generated story, or a failure line with the error.
    Nothing is returned and no exception escapes.
    """
    sample_prompt = "A mysterious door appears in the garden"
    print(f"🧪 Testing with prompt: '{sample_prompt}'")

    try:
        story_text = generate_story_gpt2(sample_prompt, max_length=200)
        print("✅ Test successful!")
        preview = story_text[:100]
        print(f"📖 Generated story preview: {preview}...")
    except Exception as err:
        # Report the problem instead of raising; this is a diagnostic helper.
        print(f"❌ Test failed: {err}")



# Closing status banner.
# NOTE(review): quick_test() is defined above but not invoked anywhere in the
# visible code — confirm whether it should be called before this banner.
print("\n🎉 Setup complete! Your AI Story Generator is ready to use!")
print("📝 Enter a prompt and watch the magic happen!")


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m64.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m52.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m34.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━