<a href="https://colab.research.google.com/github/abhi-0203/Optimized-AI-Image-Generator-for-Google-Colab/blob/main/optimized_genai_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install runtime dependencies: torch, xformers (upgraded from the PyTorch cu126
# wheel index), and pinned releases of diffusers and transformers.
!pip install torch
!pip install -U xformers --index-url https://download.pytorch.org/whl/cu126
!pip install diffusers==0.32.2
!pip install transformers==4.49

Looking in indexes: https://download.pytorch.org/whl/cu126
Collecting xformers
  Downloading https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (1.1 kB)
Downloading https://download.pytorch.org/whl/cu126/xformers-0.0.32.post2-cp39-abi3-manylinux_2_28_x86_64.whl (117.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.2/117.2 MB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xformers
Successfully installed xformers-0.0.32.post2
Collecting diffusers==0.32.2
  Downloading diffusers-0.32.2-py3-none-any.whl.metadata (18 kB)
Downloading diffusers-0.32.2-py3-none-any.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m29.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: diffusers
  Attempting uninstall: diffusers
    Found existing installation: diffusers 0.35.1
    Uninstalling diffusers-0.35.1:
      Successfully uninstalle

In [None]:
import gradio as gr
from diffusers import DiffusionPipeline, AutoencoderKL
import torch
import time
import gc
import os
from PIL import Image
import warnings

# Hide every warning of category UserWarning so the Colab cell output stays readable.
# Other warning categories are left untouched.
warnings.filterwarnings("ignore", category=UserWarning)

class OptimizedImageGenerator:
    """Hold a single diffusers pipeline and manage its memory footprint.

    Attributes:
        pipe: The currently loaded pipeline, or ``None`` when nothing is loaded.
        current_model: Model ID that ``pipe`` was loaded from, or ``None``.
        device: ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.
        generator: Spare slot that is reset together with ``pipe``.
    """

    def __init__(self):
        self.pipe = None
        self.current_model = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.generator = None

        # Backend flags only make sense (and are only set) when a GPU is present.
        if torch.cuda.is_available():
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

    def check_memory(self):
        """Return a one-line, human-readable summary of CUDA memory usage.

        Returns:
            A string with allocated and reserved memory in GB, or a fixed
            message when CUDA is not available.
        """
        if torch.cuda.is_available():
            allocated = torch.cuda.memory_allocated() / 1024**3  # bytes -> GB
            reserved = torch.cuda.memory_reserved() / 1024**3    # bytes -> GB
            return f"🔍 GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB"
        return "CPU mode - No GPU memory tracking"

    def clear_memory(self):
        """Release the loaded pipeline and reclaim as much memory as possible.

        After this call ``pipe``, ``generator`` and ``current_model`` are all
        ``None``. Resetting ``current_model`` matters: otherwise ``load_model``
        would report the just-released model as "already loaded".
        """
        if getattr(self, "pipe", None) is not None:
            # Best effort: moving to CPU first frees device memory even if some
            # other reference keeps the pipeline object alive.
            try:
                self.pipe.to("cpu")
            except Exception as exc:
                print(f"⚠️ Could not move pipeline to CPU before release: {exc}")

        # Drop our references; rebinding is enough, no explicit `del` needed.
        self.pipe = None
        self.generator = None
        self.current_model = None

        gc.collect()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def _enable_efficient_attention(self):
        """Turn on xFormers attention, falling back to attention slicing."""
        enable_xformers = getattr(
            self.pipe, "enable_xformers_memory_efficient_attention", None
        )
        if enable_xformers is not None:
            try:
                enable_xformers()
                return
            except Exception as exc:
                # xFormers missing, incompatible, or no CUDA: use the fallback.
                print(f"⚠️ xFormers attention unavailable ({exc}); using attention slicing")

        enable_slicing = getattr(self.pipe, "enable_attention_slicing", None)
        if enable_slicing is not None:
            try:
                enable_slicing()
            except Exception as exc:
                # Purely an optimization; the pipeline still works without it.
                print(f"⚠️ Attention slicing unavailable: {exc}")

    def load_model(self, model_id, use_cpu_offload=True, use_sequential_offload=False):
        """Load ``model_id`` with memory optimizations and make it current.

        Args:
            model_id: Hugging Face model ID; surrounding whitespace is ignored.
            use_cpu_offload: Enable whole-model CPU offload (CUDA only).
            use_sequential_offload: Enable sequential CPU offload (CUDA only);
                takes precedence over ``use_cpu_offload``.

        Returns:
            A status string describing success or the failure reason. Errors
            are reported through the return value rather than raised.
        """
        model_id = (model_id or "").strip()
        if not model_id:
            return "❌ Please enter a model ID"

        # Only short-circuit when the pipeline is really still in memory.
        if model_id == self.current_model and self.pipe is not None:
            return f"⚠️ Model {model_id} already loaded"

        try:
            if self.pipe is not None:
                self.clear_memory()

            print(f"🔄 Loading {model_id}...")

            on_gpu = torch.cuda.is_available()
            load_kwargs = {
                # Half precision is a GPU optimization; use float32 on CPU.
                "torch_dtype": torch.float16 if on_gpu else torch.float32,
                "safety_checker": None,
                "requires_safety_checker": False,
                "low_cpu_mem_usage": True,
                "use_safetensors": True,
            }

            # torch_dtype already selects the weight dtype, so no extra cast
            # is performed after loading.
            self.pipe = DiffusionPipeline.from_pretrained(model_id, **load_kwargs)

            # Offloading needs an accelerator; without CUDA just stay on CPU.
            if on_gpu and use_sequential_offload and hasattr(self.pipe, 'enable_sequential_cpu_offload'):
                # Most memory efficient but slowest
                self.pipe.enable_sequential_cpu_offload()
            elif on_gpu and use_cpu_offload and hasattr(self.pipe, 'enable_model_cpu_offload'):
                # Good balance of speed and memory
                self.pipe.enable_model_cpu_offload()
            else:
                # Keep everything on the target device if memory allows
                self.pipe = self.pipe.to(self.device)

            self._enable_efficient_attention()

            # VAE slicing/tiling reduce peak memory when decoding large images.
            if hasattr(self.pipe, 'enable_vae_slicing'):
                self.pipe.enable_vae_slicing()
            if hasattr(self.pipe, 'enable_vae_tiling'):
                self.pipe.enable_vae_tiling()

            self.current_model = model_id
            memory_info = self.check_memory()

            return f"✅ Successfully loaded: {model_id}\n{memory_info}"

        except Exception as e:
            # Top-level boundary for the UI: release any partial state and
            # report the failure as text.
            self.clear_memory()
            error_msg = str(e)
            if "out of memory" in error_msg.lower():
                return f"❌ Out of memory loading {model_id}. Try enabling CPU offload or using a smaller model."
            return f"❌ Error loading {model_id}: {error_msg}"

# Single module-level OptimizedImageGenerator instance; the functions and UI
# callbacks in this file read and mutate it through the name `generator`.
generator = OptimizedImageGenerator()

def generate_image(
    prompt,
    negative_prompt="blurry, low quality, distorted, deformed",
    steps=25,
    guidance=7.5,
    width=1024,
    height=1024,
    seed=None,
    use_random_seed=True,
    batch_size=1
):
    """Run the loaded pipeline once and return the first generated image.

    Args:
        prompt: Text prompt; must contain non-whitespace characters.
        negative_prompt: Text describing what to avoid in the image.
        steps: Number of inference steps (converted to ``int``).
        guidance: Guidance scale (converted to ``float``).
        width: Output width in pixels (converted to ``int``).
        height: Output height in pixels (converted to ``int``).
        seed: Seed to use when ``use_random_seed`` is false.
        use_random_seed: When true, or when ``seed`` is ``None``, a
            time-derived seed is used instead.
        batch_size: Passed as ``num_images_per_prompt``; only the first image
            of the batch is returned.

    Returns:
        ``(image, message)`` where ``image`` is the first result or ``None``
        on failure, and ``message`` is a status string. Errors are reported
        through ``message`` rather than raised.
    """
    if generator.pipe is None:
        return None, "⚠️ Please load a model first!"

    # Guard against None as well as empty/whitespace-only prompts.
    if not prompt or not prompt.strip():
        return None, "⚠️ Please enter a prompt!"

    try:
        print(f"Pre-generation: {generator.check_memory()}")

        if use_random_seed or seed is None:
            # Microsecond clock folded into the positive 32-bit signed range.
            seed = int(time.time() * 1000000) % 2147483647

        # Seeded generator so the reported seed reproduces the image.
        torch_generator = torch.Generator(device=generator.device).manual_seed(int(seed))

        # Release cached blocks before the memory-heavy call.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        gen_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "num_inference_steps": int(steps),
            "guidance_scale": float(guidance),
            "width": int(width),
            "height": int(height),
            "generator": torch_generator,
            "num_images_per_prompt": int(batch_size)
        }

        start_time = time.time()
        result = generator.pipe(**gen_kwargs)
        generation_time = time.time() - start_time

        image = result.images[0] if result.images else None

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        success_msg = f"✅ Generated in {generation_time:.2f}s\nSeed: {seed}\n{generator.check_memory()}"

        return image, success_msg

    except Exception as e:
        error_msg = str(e)
        if "out of memory" in error_msg.lower():
            # Aggressive cleanup on OOM. The pipeline is gone afterwards, so
            # forget the model ID too; otherwise reloading the same model
            # would be rejected as "already loaded".
            generator.clear_memory()
            generator.current_model = None
            return None, "❌ Out of memory during generation. Try reducing image size, steps, or batch size."
        return None, f"❌ Generation failed: {error_msg}"

def clear_all_memory():
    """Release the loaded pipeline on user request and return a status string.

    The remembered model ID is reset as well: once the pipeline has been
    released, the same model must be loadable again instead of being
    reported as "already loaded".
    """
    generator.clear_memory()
    generator.current_model = None
    return "🧹 Memory cleared!"

# Custom CSS passed to gr.Blocks(css=...): sets the container font and gives
# elements with class `gr-button` a purple background/border (same colors on hover).
css = """
.gradio-container {
    font-family: 'IBM Plex Sans', sans-serif;
}
.gr-button {
    color: white;
    border-color: #9D5CFF;
    background: #9D5CFF;
}
.gr-button:hover {
    border-color: #9D5CFF;
    background: #9D5CFF;
}
"""

# Create the Gradio interface: a two-column layout (controls on the left,
# output and memory monitor on the right) followed by the event wiring.
with gr.Blocks(title="🚀 Optimized AI Image Generator - Colab Edition", css=css, theme=gr.themes.Soft()) as demo:

    # Page header
    gr.Markdown("""
    # 🚀 Optimized AI Image Generator for Google Colab

    **Key Optimizations:**
    - 🧠 Smart memory management with automatic cleanup
    - 🔄 CPU/GPU offloading options for memory-constrained environments
    - ⚡ Efficient attention mechanisms (xFormers/slicing)
    - 🎯 VAE optimizations for large image generation
    - 📊 Real-time memory monitoring
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # Model Management Section
            gr.Markdown("### 🎯 Model Management")

            # Model ID fed to generator.load_model by the Load button
            model_url = gr.Textbox(
                label="Hugging Face Model ID",
                placeholder="stabilityai/stable-diffusion-xl-base-1.0",
                value="runwayml/stable-diffusion-v1-5"
            )

            # Offload options; passed to load_model in this order after model_url
            with gr.Row():
                use_cpu_offload = gr.Checkbox(
                    label="Enable CPU Offloading",
                    value=True,
                    info="Recommended for Colab (saves VRAM)"
                )
                use_sequential_offload = gr.Checkbox(
                    label="Sequential Offloading",
                    value=False,
                    info="Maximum memory saving (slower)"
                )

            load_btn = gr.Button("🔄 Load Model", variant="primary", size="lg")
            clear_btn = gr.Button("🧹 Clear Memory", variant="secondary")

            # Receives the status strings returned by load_model / clear_all_memory
            model_status = gr.Markdown("**Status:** No model loaded")

            # Generation Parameters
            gr.Markdown("### ⚙️ Generation Settings")

            prompt = gr.Textbox(
                label="Prompt",
                placeholder="A beautiful sunset over mountains, detailed, artistic",
                lines=3
            )

            negative_prompt = gr.Textbox(
                label="Negative Prompt",
                value="blurry, low quality, distorted, deformed, ugly, bad anatomy",
                lines=2
            )

            with gr.Row():
                steps = gr.Slider(10, 50, value=25, step=1, label="Inference Steps")
                guidance = gr.Slider(1.0, 20.0, value=7.5, step=0.5, label="Guidance Scale")

            # Dimensions move in steps of 64 between 512 and 1536
            with gr.Row():
                width = gr.Slider(512, 1536, value=1024, step=64, label="Width")
                height = gr.Slider(512, 1536, value=1024, step=64, label="Height")

            # The seed value is only used by generate_image when "Random Seed" is unchecked
            with gr.Row():
                seed = gr.Number(label="Seed (optional)", precision=0, value=42)
                random_seed = gr.Checkbox(label="Random Seed", value=True)
                batch_size = gr.Slider(1, 4, value=1, step=1, label="Batch Size")

            generate_btn = gr.Button("🎨 Generate Image", variant="primary", size="lg")

        with gr.Column(scale=1):
            # Output Section
            gr.Markdown("### 🖼️ Generated Image")

            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                height=600,
                show_download_button=True
            )

            # Receives the message half of generate_image's (image, message) result
            status = gr.Markdown("**Ready to generate!**")

            # Memory Monitor
            gr.Markdown("### 📊 Memory Monitor")
            memory_display = gr.Markdown("Click 'Check Memory' to see current usage")
            memory_check_btn = gr.Button("📊 Check Memory")

    # Tips Section
    gr.Markdown("""
    ### 💡 Colab Usage Tips

    1. **Memory Management**: Enable CPU offloading if you encounter out-of-memory errors
    2. **Performance**: Start with smaller images (768x768) and fewer steps (20-25) for faster generation
    3. **Quality**: Use negative prompts to improve image quality
    4. **Stability**: Clear memory between model switches to prevent crashes
    5. **Colab Limits**: Free tier has session limits - save your generated images!
    """)

    # Event handlers
    # Load button: inputs map positionally to load_model(model_id, use_cpu_offload, use_sequential_offload)
    load_btn.click(
        fn=generator.load_model,
        inputs=[model_url, use_cpu_offload, use_sequential_offload],
        outputs=model_status,
        show_progress=True
    )

    # Clear button: no inputs; the returned string replaces the model status text
    clear_btn.click(
        fn=clear_all_memory,
        outputs=model_status
    )

    # Generate button: inputs map positionally to generate_image's parameters
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt, negative_prompt, steps, guidance,
            width, height, seed, random_seed, batch_size
        ],
        outputs=[output_image, status],
        show_progress=True
    )

    # Memory button: shows the string returned by generator.check_memory
    memory_check_btn.click(
        fn=generator.check_memory,
        outputs=memory_display
    )

# Launch configuration optimized for Colab
if __name__ == "__main__":
    # For Colab, use share=True to get public URL
    demo.launch(
        share=True,           # Creates shareable link
        debug=True,          # Debug mode is ON (NOTE(review): per Gradio docs this blocks the cell and surfaces errors in Colab output - confirm this is intended)
        show_error=True,      # Show errors for debugging
        server_name="0.0.0.0", # Allow external connections
        server_port=7860,     # Standard port
        inbrowser=True,       # Auto-open in browser
        inline=False          # Don't inline in notebook
    )