<a href="https://colab.research.google.com/github/WanlinTu/Maggietu.githunb.io/blob/main/SpectrumAI_Stable_Diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U gradio



In [None]:
# ABA Therapy Image Generator - Prompt Engineering Tool
# Run this in Google Colab with a GPU runtime

# Install Dependencies
!pip install -q diffusers transformers accelerate gradio

import torch
from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler
import gradio as gr
import random
import json
import os
from PIL import Image
from datetime import datetime
from google.colab import files
import csv
# Memory housekeeping: empty PyTorch's CUDA cache and run the garbage
# collector, so leftovers from an earlier run of this cell are dropped before
# a new pipeline is loaded.
import gc
import torch
torch.cuda.empty_cache()
gc.collect()

# PyTorch runtime tuning.
# NOTE(review): cudnn.benchmark is normally a speed (autotuning) switch rather
# than a memory saver — confirm it is actually wanted here.
torch.backends.cudnn.benchmark = True
# NOTE(review): this allocator option is set after torch has been imported;
# presumably it is still picked up before the first CUDA allocation — verify.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
# Model selection: create_ui() reads this module-level value and passes it to
# PromptEngineeringTool, which forwards it to load_model().
model_type = "Stable Diffusion 1.5"  # Options: "SDXL", "Stable Diffusion 2.1", "Stable Diffusion 1.5"

def load_model(model_type):
    """Build a diffusion pipeline for ``model_type`` and place it on the GPU.

    Args:
        model_type: ``"SDXL"``, ``"Stable Diffusion 2.1"`` or
            ``"Stable Diffusion 1.5"``. Any other value is treated like
            ``"Stable Diffusion 1.5"``.

    Returns:
        The loaded pipeline, using a ``DPMSolverMultistepScheduler`` built
        from the pipeline's own scheduler config, moved to ``"cuda"``.
    """
    print(f"Loading {model_type} model. This may take a few minutes...")

    is_sdxl = model_type == "SDXL"

    if is_sdxl:
        # SDXL has its own pipeline class and ships fp16 safetensors weights.
        pipe = StableDiffusionXLPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            variant="fp16",
            use_safetensors=True,
        )
    else:
        # 2.1 and 1.5 share a pipeline class and differ only in checkpoint;
        # 1.5 doubles as the fallback for unrecognised model types.
        if model_type == "Stable Diffusion 2.1":
            checkpoint = "stabilityai/stable-diffusion-2-1"
        else:
            checkpoint = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(
            checkpoint,
            torch_dtype=torch.float16,
        )

    # Swap in the DPM-Solver multistep scheduler, reusing the existing config.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

    # Attention slicing is only switched on for the SDXL pipeline.
    if is_sdxl:
        pipe.enable_attention_slicing()

    return pipe.to("cuda")

class PromptEngineeringTool:
    """Build prompts, generate images with a diffusion pipeline, and log ratings.

    Attributes:
        model_type: Name of the model passed to ``load_model``.
        pipe: The pipeline returned by ``load_model``.
        prompt_history: In-memory list of rows saved during this session.
        background_styles: Maps a background key to its prompt fragment.
        style_modifiers: Maps a visual-style key to its prompt fragment.
        log_file: Path of the CSV file that rated prompts are appended to.
    """

    def __init__(self, model_type="SDXL"):
        """Load the pipeline for ``model_type`` and prepare the CSV log."""
        self.model_type = model_type
        self.pipe = load_model(model_type)
        self.prompt_history = []
        self.background_styles = {
            "plain_white": "with a plain white background",
            "plain_neutral": "with a plain neutral background",
            "simple": "with a simple, clean background",
            "minimal": "with a minimalist background, isolated",
            "detailed": "with a natural background"
        }
        self.style_modifiers = {
            "realistic": "photorealistic, detailed photograph",
            "cartoon": "simple cartoon style, child-friendly",
            "illustrated": "educational illustration style",
            "colorful": "bright, colorful, child-friendly",
            "simplified": "simplified, clear, high contrast"
        }

        # Create a prompt history file
        self._create_prompt_log()

    def _create_prompt_log(self):
        """Set ``self.log_file`` and create the CSV with a header row if missing."""
        self.log_file = 'aba_prompt_history.csv'

        # An existing file is left untouched so earlier ratings are kept.
        if not os.path.exists(self.log_file):
            # Explicit UTF-8 so non-ASCII concepts/notes are written the same
            # way regardless of the platform's default encoding.
            with open(self.log_file, 'w', newline='', encoding='utf-8') as file:
                writer = csv.writer(file)
                writer.writerow(['Timestamp', 'Concept', 'Prompt', 'Background', 'Style', 'Rating', 'Notes'])

    def create_prompt(self, concept, background_style="plain_white", style="realistic", custom_modifiers=""):
        """Assemble the text prompt for a concept.

        Args:
            concept: The subject to depict (e.g. ``"red apple"``).
            background_style: Key into ``self.background_styles``; an unknown
                key contributes nothing to the prompt.
            style: Key into ``self.style_modifiers``; an unknown key
                contributes nothing to the prompt.
            custom_modifiers: Optional free text appended at the end;
                blank or whitespace-only text is ignored.

        Returns:
            The prompt string. Empty fragments are omitted, so the result
            never contains dangling ``" ,"`` or ``", ,"`` sequences.
        """
        concept = str(concept).strip()
        background = self.background_styles.get(background_style, "")
        style_text = self.style_modifiers.get(style, "")
        custom_modifiers = (custom_modifiers or "").strip()

        if self.model_type == "SDXL":
            # SDXL responds well to detailed prompts
            subject = f"A clear image of a {concept}"
            suffixes = ["educational", "high quality", "perfect for children"]
        else:
            # Simpler prompt for SD models
            subject = f"A {concept}"
            suffixes = ["educational"]

        # The background reads as part of the subject clause (no comma).
        if background:
            subject = f"{subject} {background}"

        parts = [subject, style_text, *suffixes, custom_modifiers]
        return ", ".join(part for part in parts if part)

    @staticmethod
    def _resolve_base_seed(seed):
        """Return the integer base seed for a batch.

        ``None`` or ``0`` means "pick one at random". Any other value is
        coerced to ``int``. The global ``random`` generator is deliberately
        NOT reseeded: doing so would make every later "random" batch draw the
        same base seed.
        """
        if seed is not None:
            seed = int(seed)
        if seed:
            return seed
        return random.randint(1, 100000)

    def generate_images(self, prompt, num_images=4, seed=None):
        """Generate ``num_images`` images for ``prompt``.

        Args:
            prompt: Text prompt passed to the pipeline.
            num_images: How many images to produce; coerced to ``int`` so a
                float coming from a UI slider is accepted.
            seed: Base seed; ``None`` or ``0`` selects a random base seed.
                Image ``i`` uses ``base_seed + i``.

        Returns:
            A list with one entry per requested image. If generating an image
            raises, the error is printed and a blank white 512x512 image is
            used in its place.
        """
        base_seed = self._resolve_base_seed(seed)

        # Sampling settings differ between SDXL and the other models.
        if self.model_type == "SDXL":
            sampling = {"num_inference_steps": 25, "guidance_scale": 7.0}
        else:
            sampling = {"num_inference_steps": 30, "guidance_scale": 7.5}

        images = []
        for i in range(int(num_images)):
            try:
                # A distinct seed per image gives variety while staying
                # reproducible from the base seed.
                generator = torch.Generator("cuda").manual_seed(base_seed + i)
                image = self.pipe(prompt, generator=generator, **sampling).images[0]
            except Exception as e:
                # Best effort: keep the batch size stable with a placeholder.
                print(f"Error generating image {i+1}: {e}")
                image = Image.new('RGB', (512, 512), color='white')
            images.append(image)

        return images

    def save_prompt(self, concept, prompt, background, style, rating, notes):
        """Record a rated prompt in memory and in the CSV log.

        After appending the row, a download of the CSV is attempted; if that
        fails the row is still saved and a message is printed instead.

        Returns:
            A short confirmation message containing the concept.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Add to prompt history list
        entry = [timestamp, concept, prompt, background, style, rating, notes]
        self.prompt_history.append(entry)

        # Save to CSV
        with open(self.log_file, 'a', newline='', encoding='utf-8') as file:
            writer = csv.writer(file)
            writer.writerow(entry)

        # Download the file every time it's updated. This is best effort, but
        # only ordinary errors are swallowed (not KeyboardInterrupt/SystemExit).
        try:
            files.download(self.log_file)
        except Exception:
            print("File saved but couldn't auto-download. You can download it manually.")

        return f"Prompt saved: {concept}"

    def download_history(self):
        """Trigger a download of the prompt history CSV and return a status message."""
        files.download(self.log_file)
        return "Downloading prompt history"

# Create and launch the UI
def create_ui():
    """Build the Gradio Blocks app for prompt engineering.

    Creates a ``PromptEngineeringTool`` for the module-level ``model_type``,
    lays out a "Generate & Test Prompts" tab and a "Prompt History" tab, and
    wires the buttons to handlers that call into the tool.

    Returns:
        The ``gr.Blocks`` app (not yet launched).
    """
    # Initialize the tool
    tool = PromptEngineeringTool(model_type)

    with gr.Blocks(title=f"ABA Therapy Prompt Engineering Tool - {model_type}") as app:
        gr.Markdown(f"# ABA Therapy Prompt Engineering Tool\n## Using {model_type} model")

        with gr.Tab("Generate & Test Prompts"):
            with gr.Row():
                with gr.Column(scale=1):
                    # Basic inputs
                    concept = gr.Textbox(label="Concept", placeholder="cat, red apple, happy child, etc.")

                    background = gr.Radio(
                        choices=list(tool.background_styles.keys()),
                        value="plain_white",
                        label="Background Style"
                    )

                    style = gr.Radio(
                        choices=list(tool.style_modifiers.keys()),
                        value="realistic",
                        label="Visual Style"
                    )

                    custom_mods = gr.Textbox(
                        label="Custom Modifiers (optional)",
                        placeholder="bright colors, simple shapes, etc."
                    )

                    # Advanced settings
                    with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Slider(minimum=0, maximum=10000, value=0, step=1, label="Seed (0 for random)")
                        num_images = gr.Slider(minimum=1, maximum=8, value=4, step=1, label="Number of Images")

                    # Buttons
                    preview_btn = gr.Button("Preview Prompt")
                    generate_btn = gr.Button("Generate Images", variant="primary")

                with gr.Column(scale=2):
                    # Output areas
                    prompt_preview = gr.Textbox(label="Generated Prompt")
                    gallery = gr.Gallery(label="Generated Images", columns=2, height="auto")

                    # Evaluation
                    with gr.Row():
                        rating = gr.Slider(minimum=1, maximum=5, value=3, step=0.5, label="Quality Rating (1-5)")
                        notes = gr.Textbox(label="Notes about these images", placeholder="What worked/didn't work?")

                    save_btn = gr.Button("Save Prompt & Rating")
                    status = gr.Textbox(label="Status")

        with gr.Tab("Prompt History"):
            download_btn = gr.Button("Download Prompt History CSV")
            history_status = gr.Textbox(label="Status")

            gr.Markdown("""
            ## Prompt Engineering Tips for ABA Therapy Images

            1. **Clear subjects**: Use "a clear image of X" rather than just "X"
            2. **Background control**: Specify the background clearly
            3. **Educational focus**: Add "educational" to keep images appropriate
            4. **Simplification**: For younger children, add "simplified" or "clear"
            5. **Consistency**: To generate similar images, use the same seed value
            6. **Diversity**: Try different styles for the same concept
            7. **Specificity**: Be very specific about what you want to see
            8. **Plain backgrounds**: Work best for younger children (easier focus)
            """)

        # Event handlers
        def preview_prompt(concept_input, bg, style_input, custom):
            """Return the prompt for the current inputs, or a hint if the concept is blank."""
            if not concept_input.strip():
                return "Please enter a concept first"
            prompt = tool.create_prompt(concept_input, bg, style_input, custom)
            return prompt

        def generate_images_from_inputs(concept_input, bg, style_input, custom, seed_val, num_imgs):
            """Return (prompt, images, status message) for the Generate button."""
            if not concept_input.strip():
                # A component's `.value` is only its construction-time default,
                # not the live textbox content, so returning it would wipe the
                # preview. gr.update() with no arguments leaves it unchanged.
                return gr.update(), [], "Please enter a concept first"

            prompt = tool.create_prompt(concept_input, bg, style_input, custom)

            try:
                images = tool.generate_images(prompt, num_imgs, seed_val)
                return prompt, images, f"Generated {len(images)} images"
            except Exception as e:
                return prompt, [], f"Error: {str(e)}"

        def save_prompt_rating(concept_input, prompt_text, bg, style_input, rating_val, notes_text):
            """Save the current prompt with its rating and return a status message."""
            if not concept_input.strip() or not prompt_text.strip():
                return "Missing concept or prompt"

            try:
                message = tool.save_prompt(concept_input, prompt_text, bg, style_input, rating_val, notes_text)
                return message
            except Exception as e:
                return f"Error saving: {str(e)}"

        def download_csv():
            """Trigger the history download and return a status message."""
            try:
                return tool.download_history()
            except Exception as e:
                return f"Error downloading: {str(e)}"

        # Connect event handlers
        preview_btn.click(
            preview_prompt,
            inputs=[concept, background, style, custom_mods],
            outputs=[prompt_preview]
        )

        generate_btn.click(
            generate_images_from_inputs,
            inputs=[concept, background, style, custom_mods, seed, num_images],
            outputs=[prompt_preview, gallery, status]
        )

        save_btn.click(
            save_prompt_rating,
            inputs=[concept, prompt_preview, background, style, rating, notes],
            outputs=[status]
        )

        download_btn.click(download_csv, inputs=[], outputs=[history_status])

        return app

# Launch the application.
# create_ui() constructs PromptEngineeringTool, which loads the model, so this
# line is where the model load happens.
app = create_ui()
# Enable the request queue (max_size=20), then start the server with
# debug=True and share=True.
app.queue(max_size=20).launch(debug=True, share=True)

Loading Stable Diffusion 1.5 model. This may take a few minutes...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://4a7707e8868abb8a9a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]