# Advanced Stable Diffusion Image Generation with Gradio UI

This notebook allows you to generate images from text prompts using Stable Diffusion with an enhanced user interface. It runs on a free Google Colab GPU and provides a Gradio interface for interaction.

**Instructions:**
1.  **Enable GPU:** Go to `Runtime` -> `Change runtime type` and select `GPU` (e.g., T4) as the hardware accelerator.
2.  **Run Cells:** Execute each cell in order.
    *   **Cell 1:** Installs necessary libraries.
    *   **Cell 2:** Imports libraries, defines style presets, helper functions, and the core image generation logic.
    *   **Cell 3:** Defines the Gradio User Interface structure.
    *   **Cell 4:** Loads the Stable Diffusion model (this may take a few minutes, especially on the first run as it downloads model weights). Choose the model to load here.
    *   **Cell 5:** Launches the Gradio interface. Click the public URL it provides to open the UI.

In [None]:
# @title 1. Install Dependencies
# Installs the third-party packages used by the later cells:
#   - diffusers, transformers, accelerate: the Stable Diffusion pipelines (Cells 2 and 4).
#   - gradio: the web UI (Cells 3 and 5).
#   - bitsandbytes: optional; no cell in this notebook imports it, so it can be
#     dropped from the install line to save time.
# The leading "!" is IPython/Colab shell syntax, so this cell only runs inside a notebook.
# NOTE(review): versions are unpinned; the UI relies on gr.DownloadButton, which only
# exists in recent Gradio releases — pin gradio if an upgrade breaks the UI.
!pip install diffusers transformers accelerate gradio bitsandbytes --quiet
print("Dependencies installed.")

## 2. Import Libraries, Define Presets, Helpers, and Generation Logic

In [None]:
# @title 2. Imports, Presets, Helpers, and Generation Logic
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler, AutoPipelineForText2Image
from PIL import Image
import random
import time
import os
import sys # For checking if in Colab
import traceback # For detailed error logging

# Environment report: confirm the runtime really has a CUDA GPU before loading a model.
_cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# --- Global Variables ---
# Active diffusers pipeline; populated by load_model_colab() (Cell 4) and read at generation time.
pipe = None
# Hub id of the model loaded into `pipe`; set during model loading (Cell 4).
current_model_id = None
# Good default for the free Colab tier. The original "runwayml/stable-diffusion-v1-5"
# repository was removed from the Hugging Face Hub; this is its community mirror.
DEFAULT_COLAB_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# For SDXL (needs considerably more VRAM; may not fit a free-tier T4 — consider A100/L4):
# DEFAULT_COLAB_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"

# --- Style Presets ---
# Each preset maps a display name to text appended to the prompt ("prompt_suffix")
# and text prepended to the negative prompt ("negative_prompt_prefix").
# Insertion order is the order shown in the UI dropdown; "None" disables styling.
STYLE_PRESETS = {
    name: {"prompt_suffix": suffix, "negative_prompt_prefix": negative}
    for name, suffix, negative in (
        ("None", "", ""),
        (
            "Realistic",
            "photorealistic, 4k, ultra detailed, cinematic lighting, professional photography",
            "cartoon, anime, drawing, sketch, stylized, illustration, painting, art",
        ),
        (
            "Anime",
            "anime style, key visual, vibrant, beautiful, detailed, official art, by makoto shinkai",
            "photorealistic, 3d render, ugly, disfigured, real life",
        ),
        (
            "Fantasy Art",
            "fantasy art, detailed, intricate, epic, trending on artstation, by greg rutkowski, Brom, Frank Frazetta",
            "photorealistic, modern, simple, photo",
        ),
        (
            "Digital Painting",
            "digital painting, concept art, smooth, sharp focus, illustration, detailed",
            "photo, 3d model, realism, ugly",
        ),
        (
            "3D Render",
            "3d render, octane render, blender, detailed, physically based rendering, vray",
            "2d, drawing, sketch, painting, illustration, flat",
        ),
    )
}

# --- Helper Functions ---
def get_random_seed():
    """Return a uniformly random seed in the unsigned 32-bit range [0, 2**32 - 1]."""
    upper_exclusive = 2**32
    return random.randrange(upper_exclusive)

def apply_style(prompt, style_name):
    """Merge a style preset into the user's prompt.

    Returns a ``(styled_prompt, negative_prefix)`` tuple. For the "None" style,
    or a name missing from STYLE_PRESETS, the prompt is returned untouched with
    an empty prefix. Otherwise the preset's suffix is appended after ", " (or
    used on its own when the prompt is blank), and leading/trailing commas and
    spaces are trimmed from both returned strings.
    """
    if style_name == "None":
        return prompt, ""
    preset = STYLE_PRESETS.get(style_name)
    if preset is None:
        return prompt, ""

    suffix = preset['prompt_suffix']
    core = prompt.strip()
    combined = f"{core}, {suffix}" if core else suffix
    return combined.strip(", "), preset['negative_prompt_prefix'].strip(", ")

# --- Google Drive Mounting (if in Colab) ---
def mount_google_drive():
    """Mount the user's Google Drive at /content/drive when running inside Colab.

    Returns True when the mount succeeds. Returns False outside Colab, or when
    mounting raises (the error is printed instead of propagated).
    """
    if 'google.colab' not in sys.modules:
        return False  # Drive mounting is a Colab-only feature.

    try:
        from google.colab import drive
        drive.mount('/content/drive')
        print("Google Drive mounted successfully at /content/drive")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        return False
    return True


# Becomes True once mount_google_drive() succeeds; the generation code reads and updates it.
GDRIVE_MOUNTED_SUCCESSFULLY = False
# Destination folder used by "Save to Google Drive".
GDRIVE_SAVE_PATH = "/content/drive/MyDrive/AI_Generated_Images/StableDiffusion/"

# --- Model Loading Function (Colab specific status updates) ---
def load_model_colab(model_id_to_load, use_float16=True, use_attention_slicing=False, status_gr_textbox_ref=None):
    """Load a text-to-image pipeline into the module-level ``pipe`` global.

    This is a generator: selected status messages are yielded as
    ``gr.update(value=...)`` so it can drive a status Textbox when wired up as a
    Gradio event handler. Every message is also printed. Callers outside
    Gradio must iterate it to completion for the load to happen.

    Args:
        model_id_to_load: Hugging Face Hub repo id (or local path) of the model.
        use_float16: Load weights as float16 when CUDA is available.
        use_attention_slicing: Call ``pipe.enable_attention_slicing()`` after
            loading, when the pipeline supports it.
        status_gr_textbox_ref: Unused; kept for backward compatibility with
            existing callers.

    Side effects:
        Sets the globals ``pipe`` and ``current_model_id``. On failure both are
        reset to None, and the error is reported as the final yielded message
        (the exception is not re-raised).
    """
    global pipe, current_model_id

    def update_status(message):
        """Print *message* and return it as a status-Textbox update."""
        print(message)
        return gr.update(value=message)

    if pipe is not None and current_model_id == model_id_to_load:
        yield update_status(f"Model '{model_id_to_load}' already loaded.")
        return

    yield update_status(f"Loading model: {model_id_to_load}...")

    pipeline_args = {}
    if torch.cuda.is_available() and use_float16:
        pipeline_args["torch_dtype"] = torch.float16
        update_status("Using float16 precision.")

    # Official SDXL repos are named like "stabilityai/stable-diffusion-xl-base-1.0".
    # Matching only "sdxl" missed them and sent them to StableDiffusionPipeline.
    model_id_lower = model_id_to_load.lower()
    model_is_sdxl = "sdxl" in model_id_lower or "stable-diffusion-xl" in model_id_lower

    try:
        if model_is_sdxl:
            yield update_status(f"Loading SDXL model: {model_id_to_load}...")
            # No scheduler override here: the repo's default scheduler is kept.
            pipe = AutoPipelineForText2Image.from_pretrained(model_id_to_load, **pipeline_args)
        else:
            yield update_status(f"Loading Stable Diffusion model: {model_id_to_load}...")
            scheduler = EulerDiscreteScheduler.from_pretrained(model_id_to_load, subfolder="scheduler")
            pipeline_args["scheduler"] = scheduler
            pipe = StableDiffusionPipeline.from_pretrained(model_id_to_load, **pipeline_args)

        if torch.cuda.is_available():
            update_status("Moving model to CUDA.")
            pipe = pipe.to("cuda")
        else:
            update_status("CUDA not available. Running on CPU (this will be very slow). Warning: May not work.")

        if use_attention_slicing and hasattr(pipe, "enable_attention_slicing"):
            update_status("Enabling attention slicing.")
            pipe.enable_attention_slicing()

        current_model_id = model_id_to_load
        yield update_status(f"Model '{model_id_to_load}' loaded successfully.")

    except Exception as e:
        error_message = f"Error loading model '{model_id_to_load}': {str(e)}"
        print(error_message)
        traceback.print_exc()
        pipe = None
        current_model_id = None
        if torch.cuda.is_available():
            # A partially loaded model (e.g. after an out-of-memory error) may have left
            # cached allocations behind; release them so a retry starts clean.
            torch.cuda.empty_cache()
        yield gr.update(value=f"{error_message}. Check console.")

# --- Image Generation Function ---
def _combine_negative_prompts(style_prefix, user_negative):
    """Join a style's negative prefix and the user's negative prompt with ", ".

    Blank or None parts are skipped. Returns None (not "") when nothing remains,
    so the pipeline is called without a negative prompt.
    """
    parts = [part.strip() for part in (style_prefix, user_negative) if part and part.strip()]
    return ", ".join(parts) or None


def _safe_filename_part(text):
    """Replace every character except letters, digits, '-' and '_' with '_'.

    Prevents user-supplied text from injecting path separators or ".." into
    generated filenames (the UI can be exposed through a public share link).
    """
    return "".join(c if c.isalnum() or c in "-_" else "_" for c in text)


def _save_image_to_gdrive(image, filename):
    """Save *image* as GDRIVE_SAVE_PATH/filename, mounting Drive first if needed.

    Returns a one-line status message for the info box. Mount and save
    failures are reported in that message rather than raised.
    """
    global GDRIVE_MOUNTED_SUCCESSFULLY

    if 'google.colab' not in sys.modules:
        message = "ℹ️ 'Save to Google Drive' is only available in Google Colab."
        print(message)
        return message

    if not GDRIVE_MOUNTED_SUCCESSFULLY:
        # NOTE(review): mounting from inside a Gradio request may be unable to show
        # Colab's auth prompt; mounting Drive in a notebook cell beforehand is more
        # reliable — TODO confirm.
        print("Attempting to mount Google Drive for saving...")
        GDRIVE_MOUNTED_SUCCESSFULLY = mount_google_drive()

    if not GDRIVE_MOUNTED_SUCCESSFULLY:
        message = "❌ Google Drive not mounted or mount failed. Cannot save."
        print(message)
        return message

    try:
        os.makedirs(GDRIVE_SAVE_PATH, exist_ok=True)
        gdrive_img_path = os.path.join(GDRIVE_SAVE_PATH, filename)
        image.save(gdrive_img_path)
    except Exception as e:
        message = f"❌ Error saving to Google Drive: {e}"
        print(message)
        return message

    print(f"💾 Image saved to Google Drive: {gdrive_img_path}")
    return f"Saved to GDrive: {gdrive_img_path}"


def generate_image_colab_fn(prompt, negative_prompt, style_name, num_inference_steps, guidance_scale, seed_value,
                            custom_filename_prefix_val, save_to_gdrive_val,
                            progress=gr.Progress(track_tqdm=True)):
    """Generate one image with the loaded pipeline and save it for download.

    Args:
        prompt: User prompt; the selected style's suffix is appended.
        negative_prompt: User negative prompt; the style's negative prefix is prepended.
        style_name: Key of STYLE_PRESETS ("None" disables styling).
        num_inference_steps: Number of denoising steps (coerced to int).
        guidance_scale: Classifier-free guidance scale (coerced to float).
        seed_value: Seed; a value that cannot be converted to int gets a random seed.
        custom_filename_prefix_val: Optional filename prefix; sanitized before use.
        save_to_gdrive_val: Also copy the image under GDRIVE_SAVE_PATH (Colab only).
        progress: Gradio progress tracker. ``track_tqdm=True`` lets Gradio mirror
            tqdm bars (such as the pipeline's per-step bar) in the UI.

    Returns:
        A 3-tuple ``(image, info_text, download_button_update)``. On failure
        ``image`` is None, ``info_text`` describes the error, and the download
        button update hides the button. Errors are reported, not raised.
    """
    if pipe is None:
        return None, "Model not loaded. Please run the model loading cell first.", gr.update(visible=False)

    try:
        progress(0, desc="🎨 Starting generation...")

        styled_prompt, style_negative_prefix = apply_style(prompt or "", style_name)
        final_negative_prompt = _combine_negative_prompts(style_negative_prefix, negative_prompt)

        progress(0.1, desc=f"📝 Prompt: {styled_prompt[:100]}...")
        print(f"Generating with Prompt: '{styled_prompt}'")
        if final_negative_prompt:
            print(f"Negative Prompt: '{final_negative_prompt}'")

        try:
            seed = int(seed_value)
        except (ValueError, TypeError, OverflowError):
            # Empty, non-numeric, NaN or infinite seed field: fall back to a fresh random seed.
            seed = get_random_seed()
        print(f"🌱 Seed: {seed}")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Inside the try: manual_seed raises for seeds outside torch's accepted range.
        generator = torch.Generator(device).manual_seed(seed)

        # A single pipeline call. Progress comes from the pipeline's own tqdm bar via
        # track_tqdm, so no hand-rolled step loop is needed.
        image = pipe(
            prompt=styled_prompt,
            negative_prompt=final_negative_prompt,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=float(guidance_scale),
            generator=generator,
        ).images[0]

        timestamp = time.strftime("%Y%m%d-%H%M%S")
        custom_prefix = _safe_filename_part((custom_filename_prefix_val or "").strip())
        filename_prefix_to_use = custom_prefix or _safe_filename_part(styled_prompt[:30]) or "image"
        model_tag = _safe_filename_part((current_model_id or "unknown_model").split('/')[-1])
        base_filename = f"{filename_prefix_to_use}_{model_tag}_{seed}_{timestamp}.png"

        print(f"✅ Image generation successful. Filename: {base_filename}")
        progress(1.0, desc="🎉 Generation Complete!")

        # The DownloadButton serves a file path, so keep a local copy.
        temp_dir = "temp_generated_colab_images"
        os.makedirs(temp_dir, exist_ok=True)
        local_img_path = os.path.join(temp_dir, base_filename)
        image.save(local_img_path)
        additional_info = [f"Saved locally for download: {local_img_path}"]

        if save_to_gdrive_val:
            additional_info.append(_save_image_to_gdrive(image, base_filename))

        info_text = f"🌱 Seed: {seed}\n🕒 Timestamp: {timestamp}\n🔧 Model: {current_model_id}\n🎨 Style: {style_name}\n📛 Filename: {base_filename}"
        full_info_text = "\n".join([info_text, *additional_info])
        return image, full_info_text, gr.update(value=local_img_path, visible=True)

    except Exception as e:
        traceback.print_exc()
        if torch.cuda.is_available():
            # Release cached blocks (e.g. after a CUDA out-of-memory error) so a retry has room.
            torch.cuda.empty_cache()
        error_message = f"❌ Error during generation: {str(e)}"
        print(error_message)
        progress(1.0, desc=error_message)
        return None, f"{error_message}. Check Colab console.", gr.update(visible=False)

print("Helper functions and generation logic defined.")

## 3. Define Gradio User Interface

This cell sets up the interactive web UI using Gradio. It includes input fields for prompts, style selection, sliders for generation parameters, seed control, and areas for displaying the generated image and status messages.

In [None]:
# @title 3. Define Gradio User Interface

# Single-element list so later cells (4 and 5) can reach the Status textbox that is
# created inside the Blocks context below.
status_textbox_ref = [None]


def create_gradio_ui_colab():
    """Build and return the Gradio Blocks app (it is not launched here).

    Also stores the Status textbox in ``status_textbox_ref[0]``.
    """
    global status_textbox_ref

    with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container { max-width: 98% !important; } footer {display: none !important}") as demo:
        gr.Markdown("# 🎨 Advanced Stable Diffusion UI (Colab)")
        status_textbox_ref[0] = gr.Textbox(label="Status", value="Interface loaded. Please load a model using Cell 4.", interactive=False, lines=2)

        with gr.Row(equal_height=False):
            with gr.Column(scale=2, min_width=400):
                gr.Markdown("## ⚙️ Input Controls")
                prompt_input = gr.Textbox(label="Enter your Prompt", lines=3, placeholder="e.g., A majestic lion in a futuristic city, neon lights, detailed fur")
                negative_prompt_input = gr.Textbox(label="Negative Prompt (what to avoid)", lines=2, placeholder="e.g., blurry, low quality, ugly, text, watermark, disfigured, cartoon")
                # "None" is always listed first. Filtering by name (instead of slicing
                # off the first key) means no real preset is dropped if the dict is reordered.
                style_choices = ["None"] + [name for name in STYLE_PRESETS if name != "None"]
                style_dropdown = gr.Dropdown(label="Artistic Style Preset", choices=style_choices, value="None")

                with gr.Row():
                    inference_steps_slider = gr.Slider(minimum=10, maximum=100, value=25, step=1, label="Inference Steps", info="More steps can improve detail but take longer.")
                    cfg_scale_slider = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.1, label="CFG Scale (Guidance)", info="How strongly the prompt guides the image.")

                with gr.Row():
                    seed_input = gr.Number(label="Seed", value=get_random_seed(), precision=0, minimum=0)
                    random_seed_button = gr.Button("🎲 Randomize Seed", scale=1, min_width=50)

                with gr.Accordion("Output Options", open=True):
                    custom_filename_input = gr.Textbox(label="Custom Filename Prefix (Optional)", placeholder="my_creation_prefix")
                    save_to_gdrive_checkbox = gr.Checkbox(label="Save to Google Drive", value=False, visible='google.colab' in sys.modules, info=f"Saves to {GDRIVE_SAVE_PATH} if Drive is mounted.")
                    if 'google.colab' not in sys.modules:
                        gr.Markdown("_(Google Drive option only available in Colab environment.)_")

                generate_button = gr.Button("🖼️ Generate Image", variant="primary")

            with gr.Column(scale=3, min_width=520):
                gr.Markdown("## 🖼️ Generated Image")
                image_output = gr.Image(label="Output Image", type="pil", height=512, show_label=False, show_download_button=False, visible=False)
                download_button = gr.DownloadButton(label="💾 Download Image", visible=False)
                info_output = gr.Textbox(label="Generation Info & Logs", lines=5, interactive=False)

        def _current_status():
            """Return status text from the live model state each time a page loads."""
            if pipe is None:
                return "⚠️ No model loaded. Run Cell 4 (Load the Model), then reload this page."
            return f"✅ Model '{current_model_id}' loaded. Ready to generate."

        def on_generate_wrapper(prompt, neg_prompt, style, steps, cfg, seed, filename_prefix, save_gdrive, progress=gr.Progress(track_tqdm=True)):
            """Generator handler: show a busy state, run generation, then publish results."""
            yield {
                generate_button: gr.update(interactive=False, value="⏳ Generating..."),
                status_textbox_ref[0]: gr.update(value="⏳ Generating image..."),
                image_output: gr.update(visible=False, value=None),
                download_button: gr.update(visible=False),
                info_output: ""
            }

            try:
                img, info_text, dl_button_update = generate_image_colab_fn(prompt, neg_prompt, style, steps, cfg, seed, filename_prefix, save_gdrive, progress=progress)
            except Exception as e:
                # generate_image_colab_fn normally reports its own errors. This guard makes
                # sure the Generate button is re-enabled even if it raises or returns an
                # unexpected shape.
                traceback.print_exc()
                img, info_text, dl_button_update = None, f"❌ Unexpected error: {e}", gr.update(visible=False)

            generated = img is not None
            yield {
                image_output: gr.update(value=img, visible=generated),
                info_output: info_text,
                generate_button: gr.update(interactive=True, value="🖼️ Generate Image"),
                status_textbox_ref[0]: gr.update(value="✅ Generation complete." if generated else "❌ Generation failed. Check console."),
                download_button: dl_button_update
            }

        generate_button.click(
            fn=on_generate_wrapper,
            inputs=[prompt_input, negative_prompt_input, style_dropdown, inference_steps_slider, cfg_scale_slider, seed_input, custom_filename_input, save_to_gdrive_checkbox],
            outputs=[generate_button, status_textbox_ref[0], image_output, download_button, info_output]
        )

        random_seed_button.click(fn=get_random_seed, inputs=None, outputs=seed_input)

        # Later cells assign status_textbox_ref[0].value after this UI is built, which
        # may not reach the browser. Recompute the status from live state on every page load.
        demo.load(fn=_current_status, inputs=None, outputs=status_textbox_ref[0])

        gr.Markdown("--- ")
        gr.Markdown("**Note:** If you encounter issues, check the Colab console output for error messages. "
                    "Ensure you have selected a GPU runtime (`Runtime` -> `Change runtime type` -> `GPU`).")
        gr.Markdown("--- ")
        gr.Markdown("### Running Locally with `app.py`:\n"
                    "1. Download `app.py` and `requirements.txt` (if provided, or install: `pip install gradio torch diffusers transformers accelerate Pillow`).\n"
                    "2. Ensure you have Python and pip installed.\n"
                    "3. If using a GPU, ensure CUDA drivers and PyTorch with CUDA support are installed.\n"
                    "4. Open your terminal or command prompt, navigate to the directory where you saved `app.py`.\n"
                    "5. Run the app: `gradio app.py` or `python app.py`.\n"
                    "   You can also pass arguments, e.g., `python app.py --model_id stabilityai/stable-diffusion-2-1-base --port 7861`.")

    return demo


gradio_app_instance = create_gradio_ui_colab()
print("Gradio UI defined. Status textbox referenced.")

## 4. Load the Stable Diffusion Model

This cell will download and load the pre-trained Stable Diffusion model. This can take several minutes, especially the first time you run it, as it needs to download the model weights (several gigabytes).

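**Choose your model below.** `stable-diffusion-v1-5/stable-diffusion-v1-5` is recommended for the free Colab tier (T4 GPU). It is the community mirror of the original `runwayml/stable-diffusion-v1-5`, which has been removed from the Hugging Face Hub. SDXL models such as `stabilityai/stable-diffusion-xl-base-1.0` produce higher-quality images but need considerably more VRAM and system RAM. They may run out of memory on a T4, so a larger GPU (e.g., an A100 or L4 on Colab Pro) is the safer choice.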

In [None]:
# @title 4. Load the Model (can take a few minutes)

# --- Model Configuration ---
MODEL_TO_LOAD = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # @param ["stable-diffusion-v1-5/stable-diffusion-v1-5", "stabilityai/stable-diffusion-xl-base-1.0", "dreamlike-art/dreamlike-photoreal-2.0", "prompthero/openjourney"] {"allow-input": true}
USE_FLOAT16_PRECISION = True  # @param {type:"boolean"}
ENABLE_ATTENTION_SLICING = True  # @param {type:"boolean"}
# Note: the original "runwayml/stable-diffusion-v1-5" repo was removed from the Hugging Face
# Hub; "stable-diffusion-v1-5/stable-diffusion-v1-5" is its community mirror.

import gc  # used below to drop the previous pipeline before loading a new one

print(f"Selected model for loading: {MODEL_TO_LOAD}")
print(f"Using float16 precision: {USE_FLOAT16_PRECISION}")
print(f"Enabling attention slicing: {ENABLE_ATTENTION_SLICING}")

# Free any previously loaded pipeline first, so the old and new weights are never
# held in memory at the same time.
if globals().get("pipe") is not None:
    print("Clearing existing model from memory...")
    pipe = None
    current_model_id = None
    gc.collect()  # collect reference cycles that might still hold GPU tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # hand cached, now-unused blocks back to the driver
    print("Existing model cleared.")

status_box = status_textbox_ref[0]
if status_box is not None:
    # Assigning .value only changes the component's stored value. It does not push a
    # live update to an already-open browser tab (that requires a Gradio event).
    status_box.value = f"Initiating load for {MODEL_TO_LOAD}..."
else:
    print("Warning: Gradio status textbox reference not found. Loading model without live UI status updates during load.")

# load_model_colab is a generator: iterating it performs the load. Each yield is a
# gr.update(...) payload, which Gradio represents as a plain dict with a "value" key.
for status_update in load_model_colab(MODEL_TO_LOAD, USE_FLOAT16_PRECISION, ENABLE_ATTENTION_SLICING, status_textbox_ref):
    if isinstance(status_update, dict) and "value" in status_update:
        print(f"Loading progress: {status_update['value']}")
    else:
        print(f"Unexpected yield from load_model_colab: {status_update}")

if pipe is None:
    print(f"Critical: Model {MODEL_TO_LOAD} FAILED to load. Check console output above for errors.")
    final_status_message = f"❌ Model {MODEL_TO_LOAD} FAILED to load. Check console."
else:
    print(f"Success: Model {MODEL_TO_LOAD} loaded and ready.")
    final_status_message = f"✅ Model '{MODEL_TO_LOAD}' loaded. Ready to generate."

if status_box is not None:
    status_box.value = final_status_message

## 5. Launch the Gradio App

Run the cell below to start the Gradio interface. 

**Important:** 
*   Make sure **Cell 4 (Load the Model)** has completed successfully before running this cell.
*   Click the **public URL** (it usually looks like `https://xxxx.gradio.live` or `https://xxxxxx.gradio.app`) that appears in the output to open the web UI in a new tab.

In [None]:
# @title 5. Launch Gradio UI
# Launch only when both prerequisites exist: a loaded pipeline (Cell 4) and a built UI (Cell 3).
_app_ready = pipe is not None and gradio_app_instance is not None

if not _app_ready:
    error_msg = "ERROR: Model not loaded or Gradio UI not defined. Cannot launch Gradio app."
    print(error_msg)
    if status_textbox_ref[0] is not None:
        status_textbox_ref[0].value = error_msg
    print("Please ensure Cell 3 (Define UI) and Cell 4 (Load Model) executed successfully.")
    # Point at exactly which prerequisite is missing.
    for _missing_obj, _hint in ((pipe, "- Model (pipe) is None."),
                                (gradio_app_instance, "- Gradio App (gradio_app_instance) is None.")):
        if _missing_obj is None:
            print(_hint)
else:
    final_launch_message = f"🚀 Launching Gradio app with model: {current_model_id}"
    print(final_launch_message)
    if status_textbox_ref[0] is not None:
        status_textbox_ref[0].value = final_launch_message

    # Start each session with an empty download folder so images from earlier runs
    # do not accumulate; cleanup failures are reported but not fatal.
    temp_image_dir = "temp_generated_colab_images"
    if os.path.exists(temp_image_dir):
        try:
            import shutil
            shutil.rmtree(temp_image_dir)
            print(f"Cleaned up old temp directory: {temp_image_dir}")
        except Exception as e:
            print(f"Warning: Could not clean up temp directory {temp_image_dir}: {e}")
    os.makedirs(temp_image_dir, exist_ok=True)

    # share=True requests a temporary public link; the queue backs gr.Progress updates.
    gradio_app_instance.queue().launch(debug=False, share=True)