# 🎨 AI Logo Generator with Stable Diffusion

This notebook builds an AI logo generator on top of Stable Diffusion XL and exposes it as a REST API that your Next.js application can call.

## Features:
- High-quality logo generation using Stable Diffusion XL
- REST API endpoint for integration
- Optimized prompts for logo design
- Base64 image output for easy integration


## 1. Install Dependencies


In [None]:
# Install required packages
# The `!` lines are IPython shell escapes (they run pip in a subprocess), so this
# cell only works inside Jupyter/Colab/IPython, not as a plain Python script.
# NOTE(review): torchvision/torchaudio are not imported by any cell shown in this
# notebook; they may be droppable — confirm before removing.
!pip install diffusers transformers accelerate torch torchvision torchaudio --quiet
!pip install flask flask-cors pillow --quiet
!pip install pyngrok --quiet

# NOTE(review): this message is printed unconditionally; it does not check pip's
# exit status, so it may appear even if an install above failed.
print("✅ Dependencies installed successfully!")


## 2. Import Libraries and Setup


In [None]:
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler
import base64
from io import BytesIO
from PIL import Image
import json
from flask import Flask, request, jsonify
from flask_cors import CORS
import threading
import time

# Summarize the runtime so it is obvious up front whether generation will use a GPU.
print("✅ Libraries imported successfully!")
environment_report = [
    f"🔥 PyTorch version: {torch.__version__}",
    f"🚀 CUDA available: {torch.cuda.is_available()}",
]
if torch.cuda.is_available():
    environment_report.append(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
for report_line in environment_report:
    print(report_line)


## 3. Load Stable Diffusion XL Model


In [None]:
# Load the Stable Diffusion XL base pipeline once; every generation below reuses `pipe`.
print("🔄 Loading Stable Diffusion XL model...")

model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🎯 Using device: {device}")

# Half-precision weights (and the matching "fp16" file variant) are only used on GPU;
# CPU inference runs in float32.
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    use_safetensors=True,
    variant="fp16" if device == "cuda" else None,
)

# Replace the default sampler with DPM-Solver multistep, which is intended to reach
# comparable quality in fewer denoising steps.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

if device == "cuda":
    # Reduce peak VRAM: compute attention in slices, and let model CPU offload move
    # each sub-model onto the GPU only while it runs. Offload manages device
    # placement itself, so the pipeline must NOT be moved with `.to("cuda")` first:
    # that would load every sub-model onto the GPU at once (defeating offloading and
    # possibly running out of memory) before offload moves them back to the CPU.
    pipe.enable_attention_slicing()
    pipe.enable_model_cpu_offload()
else:
    pipe = pipe.to(device)

print("✅ Model loaded successfully!")
# unet.config.sample_size is the *latent* side length; the VAE scale factor converts
# it to the pixel resolution the model was trained at.
native_side = pipe.unet.config.sample_size * pipe.vae_scale_factor
print(f"💾 Native resolution: ~{native_side}x{native_side}")


## 4. Logo Generation Functions


In [None]:
# Prompt fragments keyed by the option names the API accepts.
_STYLE_KEYWORDS = {
    "professional": "clean, corporate, sophisticated, elegant, modern, business-like",
    "minimalist": "simple, clean, minimal, geometric, uncluttered, essential elements only",
    "creative": "artistic, unique, innovative, bold, expressive, distinctive",
    "corporate": "formal, business-like, trustworthy, established, professional, authoritative",
}

_COLOR_KEYWORDS = {
    "modern": "contemporary color palette, trending colors, fresh and current",
    "vibrant": "bright, energetic, eye-catching colors, bold and dynamic",
    "monochrome": "black and white, grayscale, classic, timeless",
    "pastel": "soft, muted, gentle colors, subtle and refined",
}

_INDUSTRY_KEYWORDS = {
    "general": "versatile, adaptable design, universal appeal",
    "tech": "futuristic, digital, innovative, tech-forward, cutting-edge",
    "finance": "trustworthy, stable, professional, reliable, secure",
    "education": "friendly, approachable, knowledge-focused, inspiring",
    "marketing": "dynamic, creative, attention-grabbing, memorable",
}

# Appended to every prompt to steer the model toward a single centered icon.
# NOTE(review): the pipeline presumably returns RGB images, so "transparent
# background" can at best encourage a plain background — confirm if alpha is needed.
_LOGO_KEYWORDS = (
    "logo design, brand identity, vector style, scalable, professional, high quality, "
    "detailed, centered composition, transparent background, no text, just icon/symbol"
)

# Negative prompt to avoid unwanted elements.
_NEGATIVE_PROMPT = (
    "text, words, letters, typography, watermark, signature, low quality, blurry, "
    "distorted, amateur, cartoon, childish, cluttered, busy, complex background"
)


def _descriptor(table, key):
    """Return the fragment for ``key`` in ``table`` (case/whitespace-insensitive), or ''."""
    if not isinstance(key, str):
        return ""
    return table.get(key.strip().lower(), "")


def enhance_logo_prompt(user_prompt, style="professional", color="modern", industry="general"):
    """Enhance a user prompt with logo-specific keywords for better results.

    Args:
        user_prompt: Free-text description of the logo.
        style: One of "professional", "minimalist", "creative", "corporate".
        color: One of "modern", "vibrant", "monochrome", "pastel".
        industry: One of "general", "tech", "finance", "education", "marketing".
            Option names are matched case-insensitively; unknown values simply
            contribute nothing (no empty ", ," fragments are emitted).

    Returns:
        tuple[str, str]: ``(positive_prompt, negative_prompt)``.
    """
    parts = [
        str(user_prompt).strip(),
        _descriptor(_STYLE_KEYWORDS, style),
        _descriptor(_COLOR_KEYWORDS, color),
        _descriptor(_INDUSTRY_KEYWORDS, industry),
        _LOGO_KEYWORDS,
    ]
    # Skip empty fragments so unknown options don't leave dangling separators.
    final_prompt = ", ".join(part for part in parts if part)
    return final_prompt, _NEGATIVE_PROMPT

# Flask's development server handles requests on multiple threads by default, while
# a diffusers pipeline mutates shared state (e.g. its scheduler's timesteps) on every
# call. Serialize generations so concurrent API requests cannot interleave in `pipe`.
_PIPE_LOCK = threading.Lock()

# Upper bounds for caller-supplied settings; they cap GPU memory and latency for
# requests coming in over the public API.
_MAX_SIDE = 2048
_MAX_STEPS = 150


def _validate_generation_args(width, height, num_inference_steps):
    """Coerce size/step arguments to ``int`` and check they are usable.

    The SDXL pipeline rejects image sides that are not multiples of 8, so that is
    checked here to fail early with a clear message.

    Returns:
        tuple[int, int, int]: ``(width, height, num_inference_steps)``.

    Raises:
        ValueError: if a value is not an integer or is out of range.
    """
    try:
        width, height, steps = int(width), int(height), int(num_inference_steps)
    except (TypeError, ValueError):
        raise ValueError("width, height and num_inference_steps must be integers") from None
    for label, side in (("width", width), ("height", height)):
        if side <= 0 or side > _MAX_SIDE or side % 8:
            raise ValueError(
                f"{label} must be a multiple of 8 between 8 and {_MAX_SIDE}, got {side}"
            )
    if not 1 <= steps <= _MAX_STEPS:
        raise ValueError(f"num_inference_steps must be between 1 and {_MAX_STEPS}, got {steps}")
    return width, height, steps


def generate_logo(prompt, style="professional", color="modern", industry="general",
                  width=512, height=512, num_inference_steps=30,
                  guidance_scale=7.5, seed=None):
    """Generate a logo with Stable Diffusion XL and return it as base64 PNG.

    Args:
        prompt: Free-text logo description (enhanced via ``enhance_logo_prompt``).
        style, color, industry: Option names forwarded to ``enhance_logo_prompt``.
        width, height: Output size in pixels; positive multiples of 8.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength (default 7.5).
        seed: Optional integer for reproducible output; ``None`` for random.

    Returns:
        dict: ``{"success": True, "image": <base64 PNG>, "prompt": <enhanced prompt>,
        "dimensions": "WxH"}`` on success, otherwise
        ``{"success": False, "error": <message>}``. Errors are reported, not raised.
    """
    try:
        width, height, num_inference_steps = _validate_generation_args(
            width, height, num_inference_steps
        )

        # Enhance the prompt
        enhanced_prompt, negative_prompt = enhance_logo_prompt(prompt, style, color, industry)

        print(f"🎨 Generating logo with prompt: {enhanced_prompt[:100]}...")

        # A CPU generator works regardless of where the model runs (incl. CPU offload).
        generator = None if seed is None else torch.Generator(device="cpu").manual_seed(int(seed))

        # No torch.autocast here: the weights are already loaded in the intended dtype,
        # and autocast would silently switch a CPU run to bfloat16 (and may override
        # the pipeline's own precision handling on GPU).
        with _PIPE_LOCK:
            image = pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                num_images_per_prompt=1,
                generator=generator,
            ).images[0]

        # Convert to base64
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()

        print("✅ Logo generated successfully!")

        return {
            "success": True,
            "image": img_str,
            "prompt": enhanced_prompt,
            "dimensions": f"{width}x{height}",
        }

    except Exception as e:
        # Top-level boundary for callers (API routes): report instead of raising.
        print(f"❌ Error generating logo: {str(e)}")
        return {
            "success": False,
            "error": str(e),
        }

print("✅ Logo generation functions defined!")


## 5. Test Logo Generation


In [None]:
# Smoke-test the generator once before exposing it over HTTP.
print("🧪 Testing logo generation...")

sample_request = dict(
    prompt="modern tech startup",
    style="professional",
    color="modern",
    industry="tech",
    width=512,
    height=512,
)
test_result = generate_logo(**sample_request)

if not test_result["success"]:
    print(f"❌ Test generation failed: {test_result['error']}")
else:
    print("✅ Test generation successful!")
    print(f"📏 Dimensions: {test_result['dimensions']}")
    print(f"📝 Prompt used: {test_result['prompt'][:100]}...")

    # Decode the base64 PNG and render a 256px-wide preview inline.
    from IPython.display import Image as IPImage, display
    import base64

    img_data = base64.b64decode(test_result["image"])
    display(IPImage(data=img_data, width=256))


## 6. Create Flask API Server


In [None]:
# Flask application that exposes the logo generator over HTTP.
app = Flask(__name__)
# Enable CORS for all routes so a browser-based frontend can call the API.
CORS(app)


def health_check():
    """Report liveness, the device the model runs on, and the server time."""
    payload = dict(
        status="healthy",
        model_loaded=True,
        device=device,
        timestamp=time.time(),
    )
    return jsonify(payload)


# Equivalent to decorating with @app.route('/health', methods=['GET']).
app.add_url_rule('/health', view_func=health_check, methods=['GET'])

# Largest accepted image side for API requests.
_API_MAX_SIDE = 2048


def _read_dimension(data, key, default=512):
    """Return ``data[key]`` as a positive multiple of 8 no larger than ``_API_MAX_SIDE``.

    Raises:
        ValueError: with a client-facing message if the value is unusable.
    """
    raw = data.get(key, default)
    if isinstance(raw, bool):
        raise ValueError(f"'{key}' must be an integer")
    try:
        value = int(raw)
    except (TypeError, ValueError):
        raise ValueError(f"'{key}' must be an integer") from None
    if value <= 0 or value > _API_MAX_SIDE or value % 8:
        raise ValueError(f"'{key}' must be a multiple of 8 between 8 and {_API_MAX_SIDE}")
    return value


@app.route('/generate-logo', methods=['POST'])
def api_generate_logo():
    """
    API endpoint for logo generation
    Expected JSON payload:
    {
        "prompt": "logo description",
        "style": "professional|minimalist|creative|corporate",
        "color": "modern|vibrant|monochrome|pastel",
        "industry": "general|tech|finance|education|marketing",
        "width": 512,
        "height": 512
    }

    Responses:
        200 with the ``generate_logo`` result on success;
        400 for a missing/invalid body, prompt, width or height;
        500 if generation fails.
    """
    try:
        # silent=True: a malformed or non-JSON body yields None (-> 400) instead of
        # raising an exception that would otherwise surface as a 500.
        data = request.get_json(silent=True)

        if not isinstance(data, dict) or 'prompt' not in data:
            return jsonify({
                "success": False,
                "error": "Prompt is required"
            }), 400

        prompt = data['prompt']
        if not isinstance(prompt, str) or not prompt.strip():
            return jsonify({
                "success": False,
                "error": "Prompt must be a non-empty string"
            }), 400

        # Extract parameters
        style = data.get('style', 'professional')
        color = data.get('color', 'modern')
        industry = data.get('industry', 'general')
        try:
            width = _read_dimension(data, 'width')
            height = _read_dimension(data, 'height')
        except ValueError as err:
            return jsonify({"success": False, "error": str(err)}), 400

        print(f"🎨 API Request - Prompt: {prompt}, Style: {style}, Color: {color}")

        # Generate logo
        result = generate_logo(
            prompt=prompt,
            style=style,
            color=color,
            industry=industry,
            width=width,
            height=height
        )

        if result['success']:
            return jsonify(result)
        return jsonify(result), 500

    except Exception as e:
        print(f"❌ API Error: {str(e)}")
        return jsonify({
            "success": False,
            "error": f"Internal server error: {str(e)}"
        }), 500

@app.route('/generate-variations', methods=['POST'])
def api_generate_variations():
    """
    API endpoint for generating multiple logo variations

    Expected JSON payload: {"prompt": "logo description", "industry": "general"}
    ("industry" is optional and defaults to "general").

    Renders one logo per predefined style/color pairing. Responses:
        200 ``{"success": True, "variations": [...], "total": N, "errors": [...]}``
            when at least one variation succeeded; ``errors`` lists any that failed;
        400 for a missing/invalid body or prompt;
        500 if every variation failed (``errors`` explains why) or on unexpected errors.
    """
    try:
        # silent=True: a malformed or non-JSON body yields None (-> 400) instead of raising.
        data = request.get_json(silent=True)

        if not isinstance(data, dict) or 'prompt' not in data:
            return jsonify({
                "success": False,
                "error": "Prompt is required"
            }), 400

        prompt = data['prompt']
        if not isinstance(prompt, str) or not prompt.strip():
            return jsonify({
                "success": False,
                "error": "Prompt must be a non-empty string"
            }), 400

        industry = data.get('industry', 'general')

        # Define variations
        variations = [
            {"style": "minimalist", "color": "modern"},
            {"style": "creative", "color": "vibrant"},
            {"style": "professional", "color": "monochrome"},
            {"style": "corporate", "color": "pastel"},
        ]
        total_variations = len(variations)

        results = []
        errors = []

        for variation_id, variation in enumerate(variations, start=1):
            print(f"🎨 Generating variation {variation_id}/{total_variations}...")

            result = generate_logo(
                prompt=prompt,
                style=variation['style'],
                color=variation['color'],
                industry=industry
            )

            if result['success']:
                results.append({
                    "id": variation_id,
                    "image": result['image'],
                    "style": variation['style'],
                    "color": variation['color'],
                    "prompt": result['prompt']
                })
            else:
                # Record failures instead of silently dropping them.
                errors.append({
                    "id": variation_id,
                    "style": variation['style'],
                    "color": variation['color'],
                    "error": result.get('error', 'Unknown error')
                })

        if not results:
            return jsonify({
                "success": False,
                "error": "All variations failed to generate",
                "errors": errors
            }), 500

        return jsonify({
            "success": True,
            "variations": results,
            "total": len(results),
            "errors": errors
        })

    except Exception as e:
        print(f"❌ Variations API Error: {str(e)}")
        return jsonify({
            "success": False,
            "error": f"Internal server error: {str(e)}"
        }), 500

print("✅ Flask API server defined!")


## 6.5. Set Up ngrok Authentication


In [None]:
# Configure the ngrok authtoken without hard-coding it in the notebook.
# SECURITY: a real-looking authtoken used to be written inline in this cell; treat
# it as leaked and revoke/rotate it in the ngrok dashboard.
import os
from getpass import getpass

from pyngrok import ngrok

# Prefer the NGROK_AUTHTOKEN environment variable; otherwise prompt without echoing
# the secret into the notebook output.
ngrok_token = os.environ.get("NGROK_AUTHTOKEN", "").strip()
if not ngrok_token:
    ngrok_token = getpass("Enter your ngrok authtoken: ").strip()
if not ngrok_token:
    raise RuntimeError(
        "An ngrok authtoken is required: set NGROK_AUTHTOKEN or enter it when prompted."
    )
ngrok.set_auth_token(ngrok_token)

print("✅ ngrok authentication configured!")


## 7. Set Up the ngrok Tunnel


In [None]:
# Expose the local Flask port (5000) through a public ngrok tunnel.
from pyngrok import ngrok

# ngrok.connect() returns an NgrokTunnel object. In pyngrok its str() is a
# descriptive label ('NgrokTunnel: "<url>" -> "<local addr>"'), not the bare URL,
# so read .public_url explicitly before building endpoint URLs from it.
tunnel = ngrok.connect(5000)
public_url = tunnel.public_url
print(f"🌐 Public URL: {public_url}")
print(f"🔗 API Endpoint: {public_url}/generate-logo")
print(f"🔗 Variations Endpoint: {public_url}/generate-variations")
print(f"🔗 Health Check: {public_url}/health")

# Store the URL for easy access
api_url = str(public_url)
print(f"\n📋 Copy this URL to your Next.js app: {api_url}")


## 8. Start the Server


In [None]:
# Launch the Flask development server on port 5000 (the port the ngrok tunnel forwards).
print("🚀 Starting Flask server...")
print(f"🌐 Server will be available at: {api_url}")
print("\n📝 API Endpoints:")
for http_verb, route in (
    ("POST", "/generate-logo"),
    ("POST", "/generate-variations"),
    ("GET ", "/health"),
):
    print(f"   {http_verb} {api_url}{route}")
print("\n⚠️  Keep this cell running to maintain the server!")

# Blocks for as long as the server runs; the reloader is off because it does not
# work inside a notebook kernel.
app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)


## 7 (repeat). Set Up the ngrok Tunnel — duplicate of section 7 above; run only if you need to recreate the tunnel


In [None]:
# Expose the local Flask port (5000) through a public ngrok tunnel.
# NOTE: this cell repeats the earlier tunnel cell; running both opens two tunnels.
from pyngrok import ngrok

# ngrok.connect() returns an NgrokTunnel object. In pyngrok its str() is a
# descriptive label ('NgrokTunnel: "<url>" -> "<local addr>"'), not the bare URL,
# so read .public_url explicitly before building endpoint URLs from it.
tunnel = ngrok.connect(5000)
public_url = tunnel.public_url
print(f"🌐 Public URL: {public_url}")
print(f"🔗 API Endpoint: {public_url}/generate-logo")
print(f"🔗 Variations Endpoint: {public_url}/generate-variations")
print(f"🔗 Health Check: {public_url}/health")

# Store the URL for easy access
api_url = str(public_url)
print(f"\n📋 Copy this URL to your Next.js app: {api_url}")


## 8 (repeat). Start the Server — duplicate of section 8 above


In [None]:
# Launch the Flask development server on port 5000 (the port the ngrok tunnel forwards).
print("🚀 Starting Flask server...")
print(f"🌐 Server will be available at: {api_url}")
print("\n📝 API Endpoints:")
endpoint_lines = [
    f"   POST {api_url}/generate-logo",
    f"   POST {api_url}/generate-variations",
    f"   GET  {api_url}/health",
]
for endpoint_line in endpoint_lines:
    print(endpoint_line)
print("\n⚠️  Keep this cell running to maintain the server!")

# Blocks until interrupted; the reloader is disabled for notebook use.
app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
