In [1]:
!pip install gradio transformers torch --upgrade


Collecting gradio
  Downloading gradio-5.46.0-py3-none-any.whl.metadata (16 kB)
Collecting gradio-client==1.13.0 (from gradio)
  Downloading gradio_client-1.13.0-py3-none-any.whl.metadata (7.1 kB)
Downloading gradio-5.46.0-py3-none-any.whl (60.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.3/60.3 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading gradio_client-1.13.0-py3-none-any.whl (325 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m325.0/325.0 kB[0m [31m29.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: gradio-client, gradio
  Attempting uninstall: gradio-client
    Found existing installation: gradio_client 1.12.1
    Uninstalling gradio_client-1.12.1:
      Successfully uninstalled gradio_client-1.12.1
  Attempting uninstall: gradio
    Found existing installation: gradio 5.44.1
    Uninstalling gradio-5.44.1:
      Successfully uninstalled gradio-5.44.1
Successfully installed gradio-5.46.0 gradio-clie

In [7]:
import gradio as gr
from transformers import pipeline
import tempfile
import os

# Hugging Face text-generation pipeline around a GPT-2 checkpoint fine-tuned to
# write stories conditioned on a genre token (PyTorch backend).
STORY_MODEL_ID = "pranavpsv/gpt2-genre-story-generator"
generator = pipeline(
    task="text-generation",
    model=STORY_MODEL_ID,
    framework="pt",
)

# UI genre label -> genre control token expected by the fine-tuned model.
# Every token is the label wrapped in angle brackets, with "-" written as "_"
# (so "sci-fi" becomes "<sci_fi>").
genre_tokens = {
    label: "<" + label.replace("-", "_") + ">"
    for label in (
        "superhero",
        "action",
        "drama",
        "horror",
        "thriller",
        "sci-fi",
    )
}

# Story generation: build the model's control prefix, sample a continuation,
# and save the finished story to a .txt file for the download button.
def generate_story(prompt, genre):
    """Generate a genre-conditioned story and write it to a temporary file.

    Args:
        prompt: Opening text typed by the user. May be empty or None; the model
            then writes a story from the control tokens alone.
        genre: A key of ``genre_tokens`` such as ``"sci-fi"``. ``None`` (no
            dropdown selection) or an unknown value means no genre token.

    Returns:
        ``(story_text, file_path)`` for the two Gradio outputs: the story as
        shown to the user (control tokens removed) and the path of a UTF-8
        ``.txt`` file holding the same text.
    """
    # Input format: "<BOS> <genre_token> optional prompt".
    # NOTE(review): "<BOS>" comes from the model card's documented input
    # format, not from this notebook -- re-check it if the checkpoint changes.
    genre_prefix = genre_tokens.get(genre, "")
    control_prefix = " ".join(token for token in ("<BOS>", genre_prefix) if token)
    user_text = (prompt or "").strip()
    # Join only the non-empty parts with single spaces, so a missing genre or
    # empty prompt does not leave a stray space in the model input.
    full_prompt = f"{control_prefix} {user_text}" if user_text else control_prefix

    # max_new_tokens bounds only the generated continuation; max_length would
    # also count the prompt, so a long prompt could leave no room for a story.
    # do_sample=True is needed for temperature to have any effect.
    result = generator(
        full_prompt,
        max_new_tokens=150,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.8,
    )
    story_text = result[0]["generated_text"]

    # With the pipeline's default return_full_text=True the prompt is echoed at
    # the start of generated_text; drop the control tokens so the user sees
    # only their own text followed by the continuation.
    if story_text.startswith(control_prefix):
        story_text = story_text[len(control_prefix):].lstrip()

    # delete=False: Gradio reads the file after this function has returned.
    # NOTE(review): these files are never removed and stay in the system temp
    # directory.
    with tempfile.NamedTemporaryFile(
        mode="w", encoding="utf-8", prefix="story_", suffix=".txt", delete=False
    ) as temp_file:
        temp_file.write(story_text)
        temp_file_path = temp_file.name

    # One value per Gradio output component: (Textbox, File).
    return story_text, temp_file_path

# Gradio UI: prompt + genre in, story text + downloadable .txt file out.
# The genre choices are derived from genre_tokens so the dropdown can never
# offer a genre without a model token (or omit one that has a token).
genre_choices = list(genre_tokens)

iface = gr.Interface(
    fn=generate_story,
    inputs=[
        gr.Textbox(
            label="Enter your prompt",
            placeholder="e.g., A bee sees a rose flower"
        ),
        # Pre-select the first genre so generate_story always receives a
        # genre the model has a token for, instead of None.
        gr.Dropdown(
            choices=genre_choices,
            value=genre_choices[0],
            label="Select Genre"
        ),
    ],
    outputs=[
        gr.Textbox(label="Generated Story", lines=10),
        gr.File(label="Download Story")
    ],
    title="AI Story Generator (Fine-Tuned)",
    description="This version uses a specialized model to generate stories that correctly follow the selected genre."
)

# share=True also serves the app on a temporary public gradio.live URL.
iface.launch(share=True)


config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/510M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/510M [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/166 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/203 [00:00<?, ?B/s]

Device set to use cuda:0


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://e8298e461ee8721283.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


