In [None]:
!pip install gradio

In [None]:
!pip install tiktoken

In [None]:
import nltk
# Fetch NLTK data packages up front. sent_tokenize (imported from
# nltk.tokenize further down) is the only NLTK function called in this file;
# presumably it is what needs 'punkt' / 'punkt_tab' - confirm against the
# installed NLTK version.
# NOTE(review): 'averaged_perceptron_tagger' and 'wordnet' are not referenced
# by any code visible in this file - confirm before removing.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')

In [None]:
from diffusers import StableDiffusionPipeline
import torch
import gradio as gr
import textwrap
import nltk
import spacy
import json
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from nltk.tokenize import sent_tokenize
from diffusers import DiffusionPipeline
import os
import requests
import random
import re
import tiktoken

# Groq API configuration.
# The key is taken from the GROQ_API_KEY environment variable so the secret
# does not have to be pasted into the notebook. When the variable is unset or
# empty, the original placeholder is used, so existing behaviour is unchanged.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY") or "YOUR_GROQ_API_KEY"
GROQ_API_URL = "https://api.groq.com/openai/v1/chat/completions"

# Determine the device to run the pipeline on: CUDA when torch reports it as
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the diffusion pipeline.
# Half precision and the "fp16" weight variant are requested only when CUDA is
# available; otherwise float32 with the default variant is loaded.
# NOTE: this executes at module level, so the model is loaded as soon as this
# cell/module runs.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    use_safetensors=True,
    variant="fp16" if torch.cuda.is_available() else None
).to(device)

# Load NLP Dependencies
# NOTE(review): 'punkt' is already downloaded in an earlier notebook cell, so
# this call looks redundant - confirm before removing.
nltk.download('punkt')
# spaCy pipeline; normalize_names() reads its PERSON entities.
nlp = spacy.load("en_core_web_sm")

# Font Configuration
# Paths of the fonts used for panel captions and error placeholders.
# FALLBACK_FONT is tried when loading DEFAULT_FONT fails.
DEFAULT_FONT = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
FALLBACK_FONT = "/usr/share/fonts/truetype/liberation/LiberationSans-Bold.ttf"

# Memory for Character Consistency
# CHARACTER_MEMORY: character name -> True, set by generate_image() when the
# name occurs in a rendered sentence.
# MEMORY: sentence -> generated image, set by generate_image().
# NOTE(review): neither dict is read by any code visible in this file.
CHARACTER_MEMORY = {}
MEMORY = {}

# Initialize tokenizer for counting tokens
encoding = tiktoken.get_encoding("cl100k_base")  # Using OpenAI's tokenizer as an approximation

def count_tokens(text):
    """Return how many tokens the module-level `encoding` produces for `text`."""
    return len(encoding.encode(text))

def query_groq_api(prompt, model="llama3-70b-8192", timeout=60):
    """Send a single-turn chat completion request to the Groq API.

    Args:
        prompt: Text sent as the only user message.
        model: Groq model identifier.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout `requests` can block forever on a stalled connection.

    Returns:
        The content string of the first choice, or None when the request
        fails or the response does not have the expected shape. Failures are
        printed, not raised.
    """
    headers = {
        "Authorization": f"Bearer {GROQ_API_KEY}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": 0.7,
        "max_tokens": 1000
    }

    try:
        response = requests.post(
            GROQ_API_URL, headers=headers, json=payload, timeout=timeout
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as e:
        # Network/HTTP errors, non-JSON bodies and unexpected payload shapes
        # are all reported the same way: log and signal failure with None.
        print(f"Error querying Groq API: {e}")
        return None

def normalize_names(story):
    """Collect candidate character names from the story.

    Two sources are combined:
      * spaCy entities labelled PERSON;
      * a fixed list of animal words. For each animal mentioned as a whole
        word (optionally plural), phrases such as "Rex the dog" and the given
        name in "the dog named Rex" are recorded. If no such name is found,
        the capitalised animal word itself (e.g. "Dog") is recorded.

    Args:
        story: Full story text.

    Returns:
        Dict mapping each name as written to its lowercase form.
    """
    doc = nlp(story)
    name_map = {}

    # Named entities that could be people.
    for ent in doc.ents:
        if ent.label_ == "PERSON":
            name_map[ent.text] = ent.text.lower()

    animals = ["dog", "cat", "horse", "bird", "fish", "bear", "lion", "tiger"]
    for animal in animals:
        # Whole-word test so "cat" does not fire on "located" or "lion" on
        # "million".
        if not re.search(rf"\b{animal}s?\b", story, re.IGNORECASE):
            continue

        # Case-insensitivity is scoped to the fixed words only. The proper
        # name must really start with a capital letter, otherwise any phrase
        # like "walked the dog" would be captured as a name.
        pattern = re.compile(
            rf"(?P<phrase>[A-Z][a-z]+ (?i:(?:the )?{animal})\b)"
            rf"|(?i:\bthe {animal} named) (?P<given>[A-Z][a-z]+)"
        )

        recorded = False
        for match in pattern.finditer(story):
            name = match.group("phrase") or match.group("given")
            # A sentence-initial "The dog" satisfies the first alternative but
            # is not a name.
            if name.lower() == f"the {animal}":
                continue
            name_map[name] = name.lower()
            recorded = True

        if not recorded:
            # No proper name found: use the animal type itself as the name.
            name_map[animal.capitalize()] = animal

    return name_map

def generate_simple_outfit(character_type):
    """Generate a simple, randomly chosen visual description for a character.

    Args:
        character_type: Case-insensitive type name. Recognised values are
            "human"; "animal", "bird", "insect", "dog", "cat"; and "object",
            "other". None or any other value yields an empty dict.

    Returns:
        For humans: dict with "outfit", "hair", "top" and "bottom".
        For animals: dict with "body_color", "species_details" and
        "simple_color".
        For objects: dict with "key_features".
        Otherwise: {}.
    """
    kind = (character_type or "").strip().lower()

    if kind == "human":
        basic_colors = ["red", "blue", "green", "yellow", "black", "white", "purple", "brown", "gray", "pink"]

        # (template, garment coloured with color1, garment coloured with color2).
        # Keeping the garments next to the template lets "top"/"bottom" always
        # name a piece of clothing that actually appears in the outfit.
        outfit_templates = [
            ("{color1} shirt with {color2} pants", "shirt", "pants"),
            ("{color1} t-shirt and {color2} jeans", "t-shirt", "jeans"),
            ("{color1} sweater and {color2} trousers", "sweater", "trousers"),
            ("{color1} dress with {color2} belt", "dress", "belt"),
            ("{color1} blouse with {color2} skirt", "blouse", "skirt"),
            ("{color1} jacket over {color2} shirt", "jacket", "shirt"),
            ("plain {color1} shirt and {color2} shorts", "shirt", "shorts"),
        ]

        # Pick two distinct colours.
        color1 = random.choice(basic_colors)
        color2 = random.choice([c for c in basic_colors if c != color1])

        template, top_garment, bottom_garment = random.choice(outfit_templates)
        outfit = template.format(color1=color1, color2=color2)

        hair_colors = ["black", "brown", "blonde", "red", "gray", "white"]
        hair_types = ["short", "long", "curly", "straight", "wavy"]
        hair = f"{random.choice(hair_types)} {random.choice(hair_colors)} hair"

        return {
            "outfit": outfit,
            "hair": hair,
            "top": f"{color1} {top_garment}",
            "bottom": f"{color2} {bottom_garment}"
        }

    animal_colors = {
        "dog": ["brown", "black", "white", "tan", "golden"],
        "cat": ["orange", "black", "white", "gray", "brown"],
        "bird": ["blue", "red", "yellow", "green", "black"],
        "insect": ["green", "black", "red", "brown", "yellow"],
        "animal": ["brown", "gray", "tan", "white", "black"]
    }

    # Every key of the palette table is accepted, so the species-specific
    # palettes ("dog", "cat") are reachable as well as the generic ones.
    if kind in animal_colors:
        animal_patterns = ["solid", "spotted", "striped", "patched", "mixed"]
        palette = animal_colors[kind]

        color = random.choice(palette)
        pattern = random.choice(animal_patterns)

        if pattern == "solid":
            body_color = color
        else:
            second_color = random.choice([c for c in palette if c != color])
            body_color = f"{color} with {second_color} {pattern} pattern"

        return {
            "body_color": body_color,
            "species_details": f"typical {kind} features",
            # Single-colour form used when building prompts.
            "simple_color": color
        }

    if kind in ("object", "other"):
        colors = ["red", "blue", "green", "yellow", "brown", "gray", "black", "white"]
        material = ["wooden", "metal", "plastic", "glass", "stone", "fabric"]
        return {
            "key_features": f"{random.choice(colors)} {random.choice(material)}"
        }

    return {}

def analyze_story_with_groq(story):
    """Extract the story's characters via Groq and attach visual descriptions.

    The model is asked for a JSON list of characters (name, gender, type).
    Names are de-duplicated case-insensitively and, where normalize_names()
    knows the name, rewritten to the spelling used in the story. Each
    character then receives a "visual_description" from
    generate_simple_outfit().

    If the API returns nothing, or its answer cannot be parsed, characters
    are derived locally with create_fallback_character_info().

    Args:
        story: Full story text.

    Returns:
        {"characters": [character dict, ...]}
    """
    # First normalize all character names to avoid duplicates
    name_map = normalize_names(story)
    print(f"Normalized name map: {name_map}")

    prompt = f"""
    Analyze this story and extract all characters.

    For each character, provide a basic description.

    Return response in this exact JSON format:
    {{
        "characters": [
            {{
                "name": "Character name",
                "gender": "male/female/other",
                "type": "human/animal/bird/insect/object/other"
            }},
            ...
        ]
    }}

    IMPORTANT REQUIREMENTS:
    1. Extract ALL characters mentioned in the story
    2. Use simple categories for character types
    3. List each unique character ONCE only
    4. If gender is unclear, use "unknown"

    Story:
    {story}
    """

    response = query_groq_api(prompt)
    if not response:
        # query_groq_api signals failure by returning None; use the local
        # extractor rather than giving up on characters altogether.
        print("No response from Groq API; using fallback character extraction")
        return {"characters": create_fallback_character_info(story)}

    try:
        cleaned = response.strip()
        # Remove any markdown code block markers
        if cleaned.startswith('```json'):
            cleaned = cleaned[7:]
        elif cleaned.startswith('```'):
            cleaned = cleaned[3:]
        if cleaned.endswith('```'):
            cleaned = cleaned[:-3]

        # Extract the JSON object if it is embedded in other text
        json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
        if json_match:
            cleaned = json_match.group(0)

        data = json.loads(cleaned)
        raw_characters = data.get("characters") or []
        print(f"Extracted {len(raw_characters)} characters")

        # lowercase name -> spelling from the story (first spelling wins)
        story_spelling = {}
        for original_name, normalized in name_map.items():
            story_spelling.setdefault(normalized, original_name)

        # De-duplicate by lowercase name; a later duplicate replaces an
        # earlier one.
        unique_characters = {}
        for char in raw_characters:
            if not isinstance(char, dict):
                continue
            char_name = (char.get("name") or "").strip()
            if not char_name:
                # A nameless entry cannot be matched against sentence text.
                continue
            key = char_name.lower()
            char["name"] = story_spelling.get(key, char_name)
            unique_characters[key] = char

        final_characters = []
        for char in unique_characters.values():
            # A missing type is treated as human, and stored, so the type
            # label agrees with the description generated for it.
            char["type"] = char.get("type") or "human"
            visual_description = generate_simple_outfit(char["type"])
            char["visual_description"] = visual_description
            final_characters.append(char)

            print(f"Generated outfit for {char.get('name')}: {visual_description}")

        return {"characters": final_characters}
    except (ValueError, TypeError, AttributeError, KeyError) as e:
        # ValueError covers json.JSONDecodeError; the others cover a payload
        # whose structure is not the requested one.
        print(f"Error parsing Groq API response: {e}")
        print(f"Raw response: {response[:200]}...")
        return {"characters": create_fallback_character_info(story)}

def create_fallback_character_info(story):
    """Build basic character records without the Groq API.

    Every name returned by normalize_names() becomes one character. A name is
    typed "animal" when it contains one of the known animal words as a whole
    word, otherwise "human". Gender is always "unknown".

    Args:
        story: Full story text.

    Returns:
        List of dicts with "name", "gender", "type" and "visual_description".
    """
    name_map = normalize_names(story)

    # Whole-word match: a person called "Catherine" must not be typed as an
    # animal just because the name contains "cat".
    animal_word = re.compile(r"\b(?:dog|cat|horse|bird|fish|bear|lion|tiger)\b")

    characters = []
    processed_names = set()

    for original, normalized in name_map.items():
        if normalized in processed_names:
            continue
        processed_names.add(normalized)

        char_type = "animal" if animal_word.search(normalized) else "human"

        characters.append({
            "name": original,
            "gender": "unknown",
            "type": char_type,
            "visual_description": generate_simple_outfit(char_type)
        })

    return characters

def enhance_sentence_with_groq(sentence, story_context, character_info=None):
    """Ask Groq to rewrite a sentence as a short image description.

    Args:
        sentence: Story sentence to visualise.
        story_context: Accepted for interface compatibility; not used here.
        character_info: Optional {"characters": [...]} dict. Human and
            animal/bird/insect characters are summarised for the model.

    Returns:
        The model's description, stripped of surrounding double quotes and cut
        to at most 20 words. The original sentence is returned when the API
        gives no answer or an error occurs.
    """
    # Build a one-line summary per character for the model.
    context_parts = []
    if character_info and 'characters' in character_info:
        for char in character_info.get('characters', []):
            char_type = (char.get('type') or '').lower()
            visual = char.get('visual_description') or {}
            if char_type == 'human':
                context_parts.append(f"{char.get('name')}: {visual.get('outfit', 'no outfit')}. ")
            elif char_type in ['animal', 'bird', 'insect']:
                # Prefer the single-colour form. Only fall back to the first
                # word of body_color when needed, and tolerate it being absent
                # (an eagerly evaluated default would raise IndexError).
                simple_color = visual.get('simple_color')
                if not simple_color:
                    body_words = (visual.get('body_color') or '').split()
                    simple_color = body_words[0] if body_words else ''
                summary = f"{simple_color} {char.get('type')}".strip()
                context_parts.append(f"{char.get('name')}: {summary}. ")
    character_context = "".join(context_parts)

    prompt = f"""
    Transform this sentence into a CLEAR, SIMPLE image description for comic generation.

    Character information:
    {character_context}

    IMPORTANT RULES:
    - Focus ONLY on the MAIN ACTION in the sentence
    - Use SIMPLE, DIRECT language that's easy to visualize
    - MAXIMUM 15 words
    - Include only key characters and their primary action
    - Start with the main character or action
    - Be extremely specific about what characters are DOING
    - Include Indian context elements when relevant (settings, objects, etc.)
    - Use active voice ("person does thing" not "thing is done")
    - For animals, use simple descriptions like "brown dog" instead of complex patterns
    Sentence: {sentence}

    Provide ONLY the image description without explanation or additional text.
    """

    try:
        enhanced = query_groq_api(prompt)
        if not enhanced:
            return sentence

        enhanced = enhanced.strip()

        # Remove quotation marks if present
        if enhanced.startswith('"') and enhanced.endswith('"'):
            enhanced = enhanced[1:-1]

        # Hard cap, slightly above the 15 words requested in the prompt.
        words = enhanced.split()
        if len(words) > 20:
            enhanced = " ".join(words[:20])

        return enhanced
    except Exception as e:
        # Best effort: any failure here must not stop panel generation.
        print(f"Error enhancing sentence with Groq: {e}")
        return sentence

def extract_character_concise_description(character):
    """Return a short (at most 50 characters) visual description of a character.

    Args:
        character: Character dict. "type" selects the format and
            "visual_description" supplies the details. Missing or null values
            are tolerated.

    Returns:
        * human: the outfit, or "<top> with <bottom>" / whichever of the two
          exists when there is no outfit;
        * animal/bird/insect: "<color> color <type>", using "simple_color" or
          else the first word of "body_color"; just the type when no colour
          is known;
        * object/other: the stripped "key_features";
        * anything else, including a missing type: "".
    """
    # .get instead of indexing: a character without "type" (e.g. omitted by
    # the LLM) yields an empty description instead of raising KeyError.
    char_type = (character.get("type") or "").lower()
    visual = character.get("visual_description") or {}
    desc = ""

    if char_type == "human":
        outfit = (visual.get("outfit") or "").strip()
        if outfit:
            desc = outfit
        else:
            top = (visual.get("top") or "").strip()
            bottom = (visual.get("bottom") or "").strip()
            if top and bottom:
                desc = f"{top} with {bottom}"
            else:
                desc = top or bottom

    elif char_type in ("animal", "bird", "insect"):
        color = visual.get("simple_color") or ""
        if not color:
            # Fall back to the first colour word of the longer description.
            body_words = (visual.get("body_color") or "").split()
            color = body_words[0] if body_words else ""
        desc = f"{color} color {char_type}" if color else char_type

    elif char_type in ("object", "other"):
        desc = (visual.get("key_features") or "").strip()

    return desc[:50]  # Keep descriptions short for clarity

def incorporate_character_details(text, character_info):
    """Annotate character names in `text` with their short visual description.

    The first whole-word, case-insensitive occurrence of each character name
    becomes "Name (wearing <description>)" for humans and
    "Name (<description>)" for every other type. If the text already contains
    "Name (...) in|with|wearing <description>", every such span is collapsed
    into the annotated form instead.

    Args:
        text: Prompt text to annotate.
        character_info: {"characters": [...]} dict. When it is falsy or has
            no "characters" key, `text` is returned unchanged.

    Returns:
        The annotated text.
    """
    if not character_info or 'characters' not in character_info:
        return text

    # lowercase name -> character. Characters without a usable name are
    # skipped: an empty name would compile to r'\b\b', which matches at the
    # first word boundary and injects a stray "(wearing ...)" there.
    by_name = {}
    for char in character_info.get('characters') or []:
        key = (char.get('name') or '').strip().lower()
        if key:
            by_name[key] = char

    modified_text = text
    for char_name, char_data in by_name.items():
        pattern = re.compile(r'\b' + re.escape(char_name) + r'\b', re.IGNORECASE)
        match = pattern.search(modified_text)
        if not match:
            continue

        description = extract_character_concise_description(char_data)
        if not description:
            continue

        # Keep the capitalisation used in the text.
        actual_name = match.group(0)
        if (char_data.get('type') or '').lower() == 'human':
            replacement = f"{actual_name} (wearing {description})"
        else:
            replacement = f"{actual_name} ({description})"

        # The replacement is passed as a function so backslashes or group
        # references in it are inserted literally.
        replaced = 0
        if description.lower() in modified_text.lower():
            # Collapse an existing "Name (...) wearing <description>" span so
            # the description is not stated twice.
            desc_pattern = re.compile(r'\b' + re.escape(actual_name) + r'\b\s*\([^)]*\)\s*(?:in|with|wearing)\s+' +
                                      re.escape(description), re.IGNORECASE)
            modified_text, replaced = desc_pattern.subn(lambda _m: replacement, modified_text)

        if not replaced:
            modified_text = pattern.sub(lambda _m: replacement, modified_text, count=1)

    return modified_text

def create_image_prompt(sentence, character_info, style, max_tokens=60):
    """Create a simple prompt for Stable Diffusion.

    The prompt is "<sentence with character details>. <style> style, clear
    image". When it exceeds `max_tokens` (as measured by count_tokens), words
    are removed from the end of the sentence part until it fits. The style
    suffix is always kept, so a long sentence cannot silently drop the
    requested style.

    Args:
        sentence: Image description for the panel.
        character_info: {"characters": [...]} dict used to annotate names.
        style: Style name, e.g. "manga". "manga" and "cartoon" are matched
            case-insensitively and lowercased; a blank style is omitted.
        max_tokens: Token budget for the final prompt.

    Returns:
        The prompt string.
    """
    sentence_with_chars = incorporate_character_details(sentence, character_info)

    style_name = (style or "").strip()
    if style_name.lower() in ("manga", "cartoon"):
        style_name = style_name.lower()
    suffix = f"{style_name} style, clear image" if style_name else "clear image"

    prompt = f"{sentence_with_chars}. {suffix}"
    token_count = count_tokens(prompt)

    if token_count > max_tokens:
        words = sentence_with_chars.split()
        # Always keep at least one word of the sentence.
        while len(words) > 1:
            words.pop()
            prompt = f"{' '.join(words)}. {suffix}"
            token_count = count_tokens(prompt)
            if token_count <= max_tokens:
                break
        print(f"Prompt was too long. Reduced to: {prompt}")

    print(f"Final image prompt ({token_count} tokens): {prompt}")
    return prompt

def generate_image(sentence, story_context, character_info, style):
    """Generate one 768x768 panel image for a sentence.

    The sentence is first rewritten by enhance_sentence_with_groq(), turned
    into a prompt by create_image_prompt(), and rendered with the module-level
    `pipe`.

    Side effects: on success the image is stored in MEMORY under the original
    sentence, and every named character whose name occurs in the sentence is
    flagged in CHARACTER_MEMORY.

    Args:
        sentence: Story sentence to illustrate.
        story_context: Full story text, forwarded to the enhancer.
        character_info: {"characters": [...]} dict, or None.
        style: Style name for the prompt.

    Returns:
        A PIL image. If `pipe` is None a plain gray image is returned; if
        rendering raises, a gray image carrying the error message is returned.
    """
    if pipe is None:
        print("[Error] Model is not available. Returning a blank image.")
        return Image.new("RGB", (768, 768), "gray")

    print(f"Processing sentence: {sentence}")

    # Step 1: Get a more detailed, action-focused description for the image
    enhanced_sentence = enhance_sentence_with_groq(sentence, story_context, character_info)
    print(f"Enhanced action description: {enhanced_sentence}")

    # Step 2: Create final prompt for Stable Diffusion
    full_prompt = create_image_prompt(enhanced_sentence, character_info, style)

    try:
        with torch.inference_mode():
            # Negative prompt to steer away from common artefacts
            negative_prompt = "deformed, bad anatomy, disfigured, poorly drawn face, mutation, mutated, extra limb, ugly, poorly drawn hands, missing limb, floating limbs, disconnected limbs, malformed hands, blurry, watermark, text, grainy"

            image = pipe(
                prompt=full_prompt,
                negative_prompt=negative_prompt,
                height=768,
                width=768,
                guidance_scale=8.0,  # Slightly higher guidance scale for better prompt adherence
                num_inference_steps=40  # More steps for better quality
            ).images[0]
    except Exception as e:
        print(f"Error generating image: {e}")
        # Gray placeholder carrying the error message.
        img = Image.new("RGB", (768, 768), "gray")
        draw = ImageDraw.Draw(img)
        font = None
        for font_path in (DEFAULT_FONT, FALLBACK_FONT):
            try:
                font = ImageFont.truetype(font_path, 24)
                break
            except OSError:
                continue
        if font is None:
            # Neither font file exists on this system.
            font = ImageFont.load_default()
        message = textwrap.fill(f"Image generation failed: {str(e)}", width=45)
        draw.text((50, 350), message, font=font, fill="white")
        return img

    # Bookkeeping happens only after a successful render and is written so it
    # cannot fail (character_info may be None), so a finished image is never
    # discarded because of it.
    sentence_lower = sentence.lower()
    for char in (character_info or {}).get('characters', []):
        name = char.get('name') or ''
        if name and name.lower() in sentence_lower:
            CHARACTER_MEMORY[name] = True

    # NOTE: MEMORY keeps every generated image for the lifetime of the process.
    MEMORY[sentence] = image
    return image

def create_comic_panel(image, narration):
    """Create a 768x900 comic panel: the image on top, narration text below.

    Args:
        image: PIL image; it is resized to 768x768.
        narration: Caption text, wrapped at 70 characters and drawn in black
            starting at (20, 780).

    Returns:
        A new white-background PIL image containing the picture and caption.
    """
    panel_width, panel_height = 768, 900
    panel = Image.new("RGB", (panel_width, panel_height), "white")
    panel.paste(image.resize((768, 768)), (0, 0))

    # Try the configured fonts in order. Fall back to Pillow's built-in font
    # so a system without either file still produces a panel.
    font = None
    for font_path in (DEFAULT_FONT, FALLBACK_FONT):
        try:
            font = ImageFont.truetype(font_path, 22)
            break
        except OSError:
            continue
    if font is None:
        font = ImageFont.load_default()

    draw = ImageDraw.Draw(panel)
    wrapped_text = textwrap.fill(narration, width=70)
    draw.text((20, 780), wrapped_text, font=font, fill="black")
    return panel

def create_comic_strip(images, narrations):
    """Lay panels out on a two-column grid and save the result.

    Panels are built with create_comic_panel() from each (image, narration)
    pair and placed left-to-right, top-to-bottom with 20 px spacing.

    Args:
        images: Sequence of PIL images.
        narrations: Sequence of caption strings, paired with `images` by
            position.

    Returns:
        The path of the saved file, "comic_strip.png" (overwritten on every
        call).
    """
    panel_width, panel_height = 768, 900
    columns, spacing = 2, 20

    # Ceiling division, correct for any column count. At least one row is
    # kept so an empty story yields a blank strip instead of a negative
    # canvas height.
    rows = max(1, -(-len(images) // columns))
    strip_width = columns * panel_width + (columns - 1) * spacing
    strip_height = rows * panel_height + (rows - 1) * spacing
    comic_strip = Image.new("RGB", (strip_width, strip_height), "white")

    for idx, (image, narration) in enumerate(zip(images, narrations)):
        panel = create_comic_panel(image, narration)
        row, col = divmod(idx, columns)
        comic_strip.paste(panel, (col * (panel_width + spacing), row * (panel_height + spacing)))

    output_path = "comic_strip.png"
    comic_strip.save(output_path)
    return output_path

def generate_comic(story, style):
    """Generate a comic strip image file from a story.

    Steps: split the story into sentences, analyse characters with
    analyze_story_with_groq(), render one image per sentence with
    generate_image(), then assemble the panels with create_comic_strip().
    Progress is printed to the console.

    Args:
        story: Full story text.
        style: Style name for the image prompts.

    Returns:
        Path of the saved comic strip, as returned by create_comic_strip().
    """
    print("\n--- Comic Generation Process ---\n")

    # Split the story into sentences
    sentences = sent_tokenize(story)
    print(f"[1] Split Story into {len(sentences)} sentences.")

    print("[2] Analyzing story with Groq API for character details...")
    character_info = analyze_story_with_groq(story)

    characters = (character_info or {}).get('characters') or []
    if characters:
        print(f"[3] Character Analysis Complete: {len(characters)} characters detected")
        for char in characters:
            print(f"  - {char.get('name', 'Unknown')}: {char.get('type', 'unknown')}")
            # Types are compared case-insensitively, as elsewhere in the file.
            char_type = str(char.get('type') or '').lower()
            visual = char.get('visual_description') or {}
            if char_type == 'human':
                print(f"    Human details: Outfit: {visual.get('outfit', 'N/A')}")
            elif char_type in ['animal', 'bird', 'insect']:
                print(f"    Animal details: {visual.get('body_color', 'N/A')}")

            concise_desc = extract_character_concise_description(char)
            print(f"    Simple description for prompts: \"{concise_desc}\"")
    else:
        print("[3] Character Analysis Failed or No Characters Found. Proceeding with basic processing.")
        character_info = {"characters": []}

    # Generate images for each sentence
    images = []
    for idx, sentence in enumerate(sentences):
        print(f"\n[4.{idx+1}] Processing Sentence: {sentence}")

        try:
            image = generate_image(sentence, story, character_info, style)
        except Exception as e:
            # One failed panel must not abort the whole strip.
            print(f"Failed to generate image for sentence: {sentence}")
            print(f"Error: {e}")
            image = Image.new("RGB", (768, 768), "gray")
            draw = ImageDraw.Draw(image)
            font = None
            for font_path in (DEFAULT_FONT, FALLBACK_FONT):
                try:
                    font = ImageFont.truetype(font_path, 24)
                    break
                except OSError:
                    continue
            if font is None:
                # Neither font file exists on this system.
                font = ImageFont.load_default()
            message = textwrap.fill(f"Failed to generate: {sentence[:50]}...", width=45)
            draw.text((50, 350), message, font=font, fill="white")
        images.append(image)

    print("\n[5] Image Generation Complete. Creating Comic Strip...")
    comic_path = create_comic_strip(images, sentences)
    print(f"[6] Comic Strip Generated: {comic_path}")
    return comic_path

def main():
    """Build the Gradio interface around generate_comic and launch it."""
    story_input = gr.Textbox(label="Story", placeholder="Enter your story here...")
    style_input = gr.Textbox(
        label="Comic Style",
        placeholder="e.g., realistic, cartoon, manga",
        value="manga",
    )
    strip_output = gr.Image(label="Generated Comic Strip")

    app = gr.Interface(
        fn=generate_comic,
        inputs=[story_input, style_input],
        outputs=strip_output,
        title="AI-Driven Comic Strip Generator",
        description="Turn your stories into comic strips using AI. Characters will be rendered with consistent visual attributes and accurate actions.",
    )
    # share=True and debug=True are passed through to Gradio's launch().
    app.launch(share=True, debug=True)

# Start the Gradio app only when this module is the entry point.
if __name__ == "__main__":
    main()