<a href="https://colab.research.google.com/github/WanlinTu/Maggietu.githunb.io/blob/main/SpectrumAI_Stable_Diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Very Simple ABA Therapy Image Generator
# Run this in Google Colab with a GPU runtime

# Install Dependencies
!pip install -q diffusers transformers accelerate gradio

import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
import gradio as gr
import random
import os
from PIL import Image
from datetime import datetime
from google.colab import files
import csv
import gc

# CUDA caching-allocator options. PyTorch reads PYTORCH_CUDA_ALLOC_CONF when
# the allocator is first used, so it is set before the model is moved to the
# GPU. Capping the split size at 128 MB is meant to reduce fragmentation of
# large cached blocks.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

# Release cached GPU memory and collect Python garbage (e.g. left over from an
# earlier run of this cell).
torch.cuda.empty_cache()
gc.collect()

# Let cuDNN benchmark and cache the fastest convolution algorithms. This is a
# speed setting, not a memory-saving one.
torch.backends.cudnn.benchmark = True

def load_model(model_id="runwayml/stable-diffusion-v1-5", device="cuda"):
    """Load a Stable Diffusion pipeline in float16 and move it to `device`.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of the checkpoint.
            Defaults to the SD 1.5 repo this notebook was written against.
            NOTE(review): the runwayml repo has reportedly been taken down from
            the Hub; if loading fails, try the mirror
            "stable-diffusion-v1-5/stable-diffusion-v1-5" — confirm.
        device: Device the pipeline is moved to. Weights are loaded as
            float16, which is intended for GPU inference.

    Returns:
        The pipeline, with a DPMSolverMultistepScheduler (built from the
        original scheduler's config) and attention slicing enabled.
    """
    print(f"Loading Stable Diffusion model '{model_id}'. This may take a few minutes...")

    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
    )

    # Swap in the DPM-Solver multistep scheduler, reusing the checkpoint's
    # scheduler config so the noise schedule stays compatible.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    # Attention slicing computes attention in chunks: lower peak VRAM at a
    # small speed cost.
    pipe.enable_attention_slicing()

    return pipe.to(device)

class SimpleImageGenerator:
    """Generate Stable Diffusion images and keep a CSV log of rated prompts.

    The pipeline is loaded once in ``__init__`` via ``load_model()``. Each
    saved prompt is appended as a row to ``log_file``.
    """

    def __init__(self, log_file='aba_prompt_history.csv'):
        """Load the model and make sure the prompt log exists.

        Args:
            log_file: Path of the CSV file that ``save_prompt`` appends to.
        """
        self.pipe = load_model()
        self.log_file = log_file
        self._ensure_log_file()

    def _ensure_log_file(self):
        """Create the CSV log with its header row if it does not exist yet."""
        if os.path.exists(self.log_file):
            return
        with open(self.log_file, 'w', newline='', encoding='utf-8') as file:
            csv.writer(file).writerow(['Timestamp', 'Prompt', 'Rating', 'Notes', 'Seed'])

    @staticmethod
    def _resolve_seeds(seed, num_images):
        """Return ``num_images`` consecutive integer seeds.

        A ``None`` or zero ``seed`` means "pick a random base seed in
        [1, 100000]". UI widgets may deliver numbers as floats (e.g. 5.0), so
        both arguments are coerced to ``int``: ``range`` and
        ``torch.Generator.manual_seed`` expect integers.
        """
        if seed is not None and seed != 0:
            base_seed = int(seed)
        else:
            base_seed = random.randint(1, 100000)
        return [base_seed + offset for offset in range(int(num_images))]

    def generate_images(self, prompt, num_images=4, seed=None):
        """Generate ``num_images`` images for ``prompt``.

        Image ``i`` uses seed ``base_seed + i`` so each one can be reproduced
        individually. If one image fails, a blank 512x512 white placeholder
        takes its slot and generation continues with the remaining seeds.
        (The pipeline's safety checker may also return black images; those
        are not treated as errors here.)

        Returns:
            ``(images, seeds_used)``: the images and the int seed used for
            each, in the same order.
        """
        seeds_used = self._resolve_seeds(seed, num_images)
        images = []

        for index, current_seed in enumerate(seeds_used):
            try:
                generator = torch.Generator("cuda").manual_seed(current_seed)
                image = self.pipe(
                    prompt,
                    num_inference_steps=30,
                    guidance_scale=7.5,
                    generator=generator,
                ).images[0]
            except Exception as e:
                # Best effort: report and keep the gallery length stable.
                print(f"Error generating image {index+1}: {e}")
                image = Image.new('RGB', (512, 512), color='white')
            images.append(image)

        return images, seeds_used

    def save_prompt(self, prompt, rating, notes, seed):
        """Append the prompt with its rating, notes and seed to the CSV log.

        Afterwards, a browser download of the log is attempted via
        ``google.colab.files``. If that fails (e.g. outside Colab), the
        failure is reported and the saved file stays on disk.

        Returns:
            A short status message for the UI.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Recreate the header if the log was removed during the session.
        self._ensure_log_file()
        with open(self.log_file, 'a', newline='', encoding='utf-8') as file:
            csv.writer(file).writerow([timestamp, prompt, rating, notes, seed])

        try:
            files.download(self.log_file)
        except Exception as e:
            print(f"File saved, download manually if needed. ({e})")

        return "Prompt saved"

# Create the UI
def create_ui():
    """Build the Gradio Blocks app around one SimpleImageGenerator.

    The generator (and therefore the model) is created once here. Both button
    handlers close over that same instance.

    Returns:
        The ``gr.Blocks`` app; the caller queues and launches it.
    """
    # Initialize the generator (loads the model once for the whole UI)
    image_generator = SimpleImageGenerator()

    with gr.Blocks(title="Simple ABA Therapy Image Generator") as app:
        gr.Markdown("# Simple ABA Therapy Image Generator")

        with gr.Row():
            with gr.Column(scale=1):
                # Direct prompt input
                prompt = gr.Textbox(
                    label="Enter Your Prompt",
                    placeholder="A clear educational image of a cat with a plain white background, photorealistic, high quality",
                    lines=5
                )

                # Simple controls. precision=0 makes gr.Number return an int,
                # which torch's generator seeding requires.
                seed = gr.Number(label="Seed (0 for random)", value=0, precision=0)
                num_images = gr.Slider(minimum=1, maximum=6, value=4, step=1, label="Number of Images")

                # Generate button
                generate_btn = gr.Button("Generate Images", variant="primary")

            with gr.Column(scale=2):
                # Output area
                gallery = gr.Gallery(label="Generated Images", columns=2, height="auto")
                seeds_output = gr.Textbox(label="Seeds Used (for reproducibility)")

                # Evaluation
                with gr.Row():
                    rating = gr.Slider(minimum=1, maximum=5, value=3, step=0.5, label="Quality Rating (1-5)")
                    notes = gr.Textbox(label="Notes", placeholder="What worked/didn't work?")

                save_btn = gr.Button("Save Prompt & Rating")
                status = gr.Textbox(label="Status")

        # Event handlers
        def generate_images_direct(prompt_text, seed_val, num_imgs):
            """Generate images; returns (images, status message, seeds string)."""
            if not prompt_text or not prompt_text.strip():
                return [], "Please enter a prompt", ""

            try:
                # Widgets may deliver floats; the generator needs ints.
                seed_int = None if seed_val is None else int(seed_val)
                images, used_seeds = image_generator.generate_images(
                    prompt=prompt_text,
                    num_images=int(num_imgs),
                    seed=seed_int,
                )
            except Exception as e:
                return [], f"Error: {e}", ""

            seeds_str = ", ".join(map(str, used_seeds))
            return images, f"Generated {len(images)} images", seeds_str

        def _first_seed(seeds_txt):
            """Parse the first seed of an 'a, b, c' string; 0 if missing or invalid."""
            first = (seeds_txt or "").split(",")[0].strip()
            try:
                return int(first)  # accepts negative seeds, unlike str.isdigit
            except ValueError:
                return 0

        def save_prompt_rating(prompt_text, rating_val, notes_text, seeds_txt):
            """Log the current prompt with its rating, notes and first seed."""
            if not prompt_text or not prompt_text.strip():
                return "Missing prompt"

            try:
                return image_generator.save_prompt(
                    prompt=prompt_text,
                    rating=rating_val,
                    notes=notes_text,
                    seed=_first_seed(seeds_txt),
                )
            except Exception as e:
                return f"Error saving: {e}"

        # Connect event handlers
        generate_btn.click(
            generate_images_direct,
            inputs=[prompt, seed, num_images],
            outputs=[gallery, status, seeds_output]
        )

        save_btn.click(
            save_prompt_rating,
            inputs=[prompt, rating, notes, seeds_output],
            outputs=[status]
        )

        return app

# Build the UI (this loads the model), then serve it.
# - queue(max_size=20): at most 20 pending events are held in the queue.
# - share=True: publishes a temporary public gradio.live URL with no auth,
#   so anyone with the link can run generations on this runtime.
# - debug=True: keeps the cell running and streams errors/logs here.
app = create_ui()
app.queue(max_size=20)
app.launch(debug=True, share=True)

Loading Stable Diffusion 1.5 model. This may take a few minutes...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://fa41dce38d89d7d5c8.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

Potential NSFW content was detected in one or more images. A black image will be returned instead. Try again with a different prompt and/or seed.


  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]