# 🚀 Enhanced Logo Generator with Stable Diffusion XL + ControlNet

This notebook implements an advanced logo generation pipeline using:
- **Stable Diffusion XL** for high-quality image generation
- **ControlNet** for shape consistency (circles, squares, etc.)
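- **Upscaling** (2× Lanczos resampling via Pillow) for higher-resolution output
- **Vector Conversion** for scalable SVG output (the API accepts a `convert_to_svg` flag, but this notebook does not perform the SVG conversion itself)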
- **Flask API** with ngrok for easy integration

## 🎯 Features
- Generate logos with specific shapes (circle, square, hexagon, triangle)
- Professional prompt enhancement
- Optional 2× upscaling for crisp details
- Multiple variations generation
- REST API for easy integration

## 1. Install Dependencies

In [None]:
!pip install diffusers transformers accelerate torch torchvision torchaudio --quiet
!pip install controlnet-aux opencv-python --quiet
!pip install flask flask-cors pillow --quiet
!pip install pyngrok --quiet
!pip install opencv-python-headless --quiet
!pip install scikit-image --quiet

print("✅ Dependencies installed successfully!")

## 2. Import Libraries and Setup

In [None]:
import torch
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, DPMSolverMultistepScheduler
from controlnet_aux import CannyDetector
import base64
from io import BytesIO
from PIL import Image, ImageDraw, ImageFilter
import json
from flask import Flask, request, jsonify
from flask_cors import CORS
import threading
import time
import cv2
import numpy as np
from skimage import transform
import subprocess
import os

# Report the runtime environment so GPU availability is visible up front.
print("✅ Libraries imported successfully!")
print(f"🔥 PyTorch version: {torch.__version__}")

cuda_ready = torch.cuda.is_available()
print(f"🚀 CUDA available: {cuda_ready}")
if cuda_ready:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"🎮 GPU: {gpu_name}")

## 3. Load Models

In [None]:
# Load Stable Diffusion XL with ControlNet
print("🔄 Loading Stable Diffusion XL + ControlNet...")

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🎯 Using device: {device}")

# Half precision on GPU, full precision on CPU. Computed once so the
# ControlNet and the base pipeline are always loaded with the same dtype.
model_dtype = torch.float16 if device == "cuda" else torch.float32

# Load ControlNet for shape consistency
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0",
    torch_dtype=model_dtype
)

# Load SDXL pipeline with ControlNet
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=model_dtype,
    use_safetensors=True,
    variant="fp16" if device == "cuda" else None
)

# Optimize scheduler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

if device == "cuda":
    # Enable memory optimizations.
    # The pipeline is deliberately NOT moved to the GPU with `.to("cuda")`
    # first: model CPU offload handles device placement itself, and moving
    # the whole pipeline to the GPU beforehand cancels the memory saving.
    pipe.enable_attention_slicing()
    pipe.enable_model_cpu_offload()
else:
    pipe = pipe.to(device)

# Load Canny detector for edge control
canny_detector = CannyDetector()

print("✅ All models loaded successfully!")

## 4. Logo Generation Functions

In [None]:
def create_logo_control_image(shape_type="circle", size=512):
    """
    Create the ControlNet conditioning image for a logo shape.

    A black outline of the requested shape (plus a small centre marker) is
    drawn on a white square canvas and then passed through the module-level
    ``canny_detector`` to obtain an edge map.

    Args:
        shape_type: One of "circle", "square", "hexagon" or "triangle".
            Any other value leaves the canvas blank, so the resulting edge
            map carries no shape guidance.
        size: Side length in pixels of the square canvas.

    Returns:
        A PIL image built from the detector's edge map.
    """
    # Create a white background
    img = Image.new('RGB', (size, size), 'white')
    draw = ImageDraw.Draw(img)
    center = size // 2

    def polygon_points(sides, start_deg=0.0):
        # Vertices of a regular polygon centred on the canvas.
        radius = size // 3
        step_deg = 360.0 / sides
        vertices = []
        for i in range(sides):
            angle = np.deg2rad(start_deg + i * step_deg)
            vertices.append((
                float(center + radius * np.cos(angle)),
                float(center + radius * np.sin(angle)),
            ))
        return vertices

    def center_dot():
        # Small centre marker, kept from the original design as a symmetry cue.
        draw.ellipse([center-2, center-2, center+2, center+2], fill='black')

    if shape_type == "circle":
        margin = size // 8
        draw.ellipse([margin, margin, size-margin, size-margin], outline='black', width=3)
        center_dot()

    elif shape_type == "square":
        margin = size // 6
        draw.rectangle([margin, margin, size-margin, size-margin], outline='black', width=3)
        # The square uses a square centre marker rather than a round one.
        draw.rectangle([center-2, center-2, center+2, center+2], fill='black')

    elif shape_type == "hexagon":
        draw.polygon(polygon_points(6), outline='black', width=3)
        center_dot()

    elif shape_type == "triangle":
        # Start at -90 degrees so one vertex points straight up.
        draw.polygon(polygon_points(3, start_deg=-90.0), outline='black', width=3)
        center_dot()

    # Convert to Canny edge detection.
    # The detector resizes to 512 px by default; pass the requested size so
    # the edge map is produced at the canvas resolution instead.
    # NOTE(review): relies on controlnet_aux's CannyDetector accepting
    # `detect_resolution` / `image_resolution` — confirm for the installed version.
    img_np = np.array(img)
    canny_image = canny_detector(
        img_np,
        detect_resolution=size,
        image_resolution=size,
    )

    return Image.fromarray(canny_image)

def enhance_logo_prompt(user_prompt, style="professional", color="modern", industry="general", shape="circle"):
    """
    Enhance user prompt with logo-specific keywords.

    Args:
        user_prompt: The caller's description of the logo.
        style: Key into the style descriptors ("professional", "minimalist",
            "creative", "corporate").
        color: Key into the colour descriptors ("modern", "vibrant",
            "monochrome", "pastel").
        industry: Key into the industry descriptors ("general", "tech",
            "finance", "education", "marketing").
        shape: Key into the shape descriptors ("circle", "square",
            "hexagon", "triangle").

    Returns:
        A ``(final_prompt, negative_prompt)`` tuple of strings. Unknown keys
        and an empty ``user_prompt`` contribute nothing to the prompt, so no
        empty ", ," fragments are produced.
    """
    # Style descriptors
    style_keywords = {
        "professional": "clean, corporate, sophisticated, elegant, modern, business-like",
        "minimalist": "simple, clean, minimal, geometric, uncluttered, essential elements only",
        "creative": "artistic, unique, innovative, bold, expressive, distinctive",
        "corporate": "formal, business-like, trustworthy, established, professional, authoritative"
    }

    # Color descriptors
    color_keywords = {
        "modern": "contemporary color palette, trending colors, fresh and current",
        "vibrant": "bright, energetic, eye-catching colors, bold and dynamic",
        "monochrome": "black and white, grayscale, classic, timeless",
        "pastel": "soft, muted, gentle colors, subtle and refined"
    }

    # Industry descriptors
    industry_keywords = {
        "general": "versatile, adaptable design, universal appeal",
        "tech": "futuristic, digital, innovative, tech-forward, cutting-edge",
        "finance": "trustworthy, stable, professional, reliable, secure",
        "education": "friendly, approachable, knowledge-focused, inspiring",
        "marketing": "dynamic, creative, attention-grabbing, memorable"
    }

    # Shape descriptors
    shape_keywords = {
        "circle": "circular, round, centered, balanced, harmonious",
        "square": "geometric, structured, stable, solid, professional",
        "hexagon": "modern, tech-forward, innovative, dynamic",
        "triangle": "bold, directional, dynamic, energetic, forward-thinking"
    }

    # Logo-specific keywords appended to every prompt
    logo_keywords = "logo design, brand identity, vector style, scalable, professional, high quality, detailed, centered composition, transparent background, no text, just icon/symbol, symmetrical, balanced"

    # Assemble the prompt from its fragments, skipping empty ones so an
    # unrecognised option does not leave dangling separators behind.
    fragments = [
        user_prompt,
        style_keywords.get(style, ''),
        color_keywords.get(color, ''),
        industry_keywords.get(industry, ''),
        shape_keywords.get(shape, ''),
        logo_keywords,
    ]
    final_prompt = ", ".join(str(fragment) for fragment in fragments if fragment)

    # Negative prompt
    negative_prompt = "text, words, letters, typography, watermark, signature, low quality, blurry, distorted, amateur, cartoon, childish, cluttered, busy, complex background, asymmetrical, unbalanced"

    return final_prompt, negative_prompt

def _letterbox_control_image(control_image, width, height):
    """Centre a square control image on a black ``width`` x ``height`` canvas."""
    side = min(width, height)
    if control_image.size != (side, side):
        control_image = control_image.resize((side, side), Image.Resampling.LANCZOS)
    # Black is the "no edge" value of an edge map, so padding adds no guidance.
    canvas = Image.new(control_image.mode, (width, height), 0)
    canvas.paste(control_image, ((width - side) // 2, (height - side) // 2))
    return canvas


def generate_logo_with_controlnet(prompt, style="professional", color="modern", industry="general", shape="circle", width=512, height=512, num_inference_steps=30):
    """
    Generate a logo using Stable Diffusion XL + ControlNet.

    Args:
        prompt: The caller's description of the logo.
        style, color, industry, shape: Options forwarded to
            ``enhance_logo_prompt``; ``shape`` also selects the control image.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        On success, a dict with ``success=True``, the generated ``image``,
        the enhanced ``prompt``, the ``shape`` and a ``"WxH"`` ``dimensions``
        string. On any exception, a dict with ``success=False`` and the
        ``error`` message; the exception is not re-raised.
    """
    try:
        # Enhance the prompt
        enhanced_prompt, negative_prompt = enhance_logo_prompt(prompt, style, color, industry, shape)

        print(f"🎨 Generating logo with prompt: {enhanced_prompt[:100]}...")
        print(f"🔷 Shape control: {shape}")

        # Create the control image from the shorter side. For non-square
        # output it is letterboxed rather than stretched, so the guide shape
        # keeps its proportions (a circle stays a circle).
        control_image = create_logo_control_image(shape, min(width, height))
        if width != height:
            control_image = _letterbox_control_image(control_image, width, height)

        # Generate image with ControlNet
        with torch.autocast(device):
            image = pipe(
                prompt=enhanced_prompt,
                negative_prompt=negative_prompt,
                image=control_image,
                controlnet_conditioning_scale=0.8,
                width=width,
                height=height,
                num_inference_steps=num_inference_steps,
                guidance_scale=7.5,
                num_images_per_prompt=1
            ).images[0]

        print("✅ Logo generated with ControlNet!")

        return {
            "success": True,
            "image": image,
            "prompt": enhanced_prompt,
            "shape": shape,
            "dimensions": f"{width}x{height}"
        }

    except Exception as e:
        # Top-level boundary for the generation step: report instead of raising
        # so the API layer can turn the failure into an error response.
        print(f"❌ Error generating logo: {str(e)}")
        return {
            "success": False,
            "error": str(e)
        }

def upscale_image(image, scale=2):
    """
    Upscale an image with Pillow's LANCZOS resampling.

    Args:
        image: The PIL image to enlarge.
        scale: Positive scale factor. Fractional factors are supported; the
            target size is rounded to whole pixels.

    Returns:
        The resized image, or the original ``image`` unchanged if resizing
        fails for any reason (best-effort behaviour).
    """
    try:
        print(f"🔍 Upscaling image by {scale}x...")

        if scale <= 0:
            raise ValueError(f"scale must be positive, got {scale}")

        # Pillow needs integer pixel sizes, so round the scaled dimensions
        # (and never go below one pixel).
        src_width, src_height = image.size
        new_size = (
            max(1, round(src_width * scale)),
            max(1, round(src_height * scale)),
        )

        # Upscale using PIL's LANCZOS resampling
        upscaled = image.resize(new_size, Image.Resampling.LANCZOS)

        print(f"✅ Image upscaled to {upscaled.size}")

        return upscaled

    except Exception as e:
        print(f"❌ Error upscaling image: {str(e)}")
        return image  # Return original if upscaling fails

print("✅ Enhanced logo generation functions defined!")

## 5. Setup ngrok Tunnel

In [None]:
# Setup ngrok tunnel
import os
from getpass import getpass

from pyngrok import ngrok

# Never hard-code the authtoken in the notebook: anything committed here is
# readable by everyone with access to the file. Read it from the
# NGROK_AUTHTOKEN environment variable, or prompt for it interactively.
# SECURITY: the token that used to be embedded in this cell should be treated
# as leaked and revoked in the ngrok dashboard.
ngrok_auth_token = os.environ.get("NGROK_AUTHTOKEN") or getpass("Enter your ngrok authtoken: ")
ngrok.set_auth_token(ngrok_auth_token)

# Create ngrok tunnel. `connect` returns a tunnel object; keep only its
# public URL string so the endpoint URLs printed below are usable as-is.
public_url = ngrok.connect(5000).public_url
print(f"🌐 Public URL: {public_url}")
print(f"🔗 API Endpoint: {public_url}/generate-logo")
print(f"📋 Copy this URL to your Next.js app: {public_url}")

## 6. Create Flask API Server

In [None]:
# Create the Flask application that hosts the logo-generation endpoints.
app = Flask(__name__)
# Enable CORS (flask-cors defaults) so a separate web front end — the notebook
# mentions a Next.js app — can call this API from the browser.
CORS(app)

def image_to_base64(image):
    """Encode a PIL image as PNG and return the bytes as a base64 text string."""
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()

@app.route('/health', methods=['GET'])
def health_check():
    """Report that the service is up, along with the compute device in use."""
    status_payload = dict(
        status="healthy",
        service="Enhanced Logo Generator",
        models_loaded=True,
        device=device,
    )
    return jsonify(status_payload)

def _json_error(message, status):
    """Build the API's standard error response with the given HTTP status."""
    return jsonify({
        "success": False,
        "error": message
    }), status


def _parse_dimension(raw_value, field_name):
    """Validate a requested pixel dimension and return it as an int.

    Raises:
        ValueError: If the value is not a whole number or not a positive
            multiple of 8.
    """
    # bool is a subclass of int; reject it explicitly so `true` is not read as 1.
    if isinstance(raw_value, bool) or not isinstance(raw_value, (int, float, str)):
        raise ValueError(f"'{field_name}' must be an integer")
    try:
        number = float(raw_value)
    except ValueError:
        raise ValueError(f"'{field_name}' must be an integer") from None
    if not number.is_integer():
        raise ValueError(f"'{field_name}' must be an integer")
    pixels = int(number)
    if pixels <= 0 or pixels % 8:
        raise ValueError(f"'{field_name}' must be a positive multiple of 8")
    return pixels


def _read_generation_params(data):
    """Extract the generation options shared by both endpoints from the request JSON."""
    return {
        "prompt": data.get('prompt', 'logo'),
        "style": data.get('style', 'professional'),
        "color": data.get('color', 'modern'),
        "industry": data.get('industry', 'general'),
        "width": _parse_dimension(data.get('width', 512), 'width'),
        "height": _parse_dimension(data.get('height', 512), 'height'),
    }


def _generate_logo_payload(params, shape, upscale):
    """Generate one logo and return its JSON-serialisable result dict."""
    result = generate_logo_with_controlnet(shape=shape, **params)

    if not result['success']:
        return {
            "success": False,
            "error": result['error'],
            "shape": shape
        }

    image = result['image']
    if upscale:
        image = upscale_image(image, scale=2)

    return {
        "success": True,
        "image": image_to_base64(image),
        "prompt": result['prompt'],
        "shape": result['shape'],
        # Report the size of the image actually returned, so the value stays
        # correct after upscaling.
        "dimensions": f"{image.width}x{image.height}",
        "upscaled": upscale
    }


@app.route('/generate-logo', methods=['POST'])
def generate_logo():
    """Generate a single logo.

    Expects a JSON object with optional keys ``prompt``, ``style``, ``color``,
    ``industry``, ``shape``, ``width``, ``height``, ``upscale`` and
    ``convert_to_svg``. Responds with the base64-encoded PNG and metadata,
    400 for a malformed request, or 500 if generation fails.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return _json_error("Request body must be a JSON object", 400)

        try:
            params = _read_generation_params(data)
        except ValueError as exc:
            return _json_error(str(exc), 400)

        shape = data.get('shape', 'circle')
        upscale = bool(data.get('upscale', False))
        convert_to_svg = data.get('convert_to_svg', False)

        print(f"🎨 Generating logo: {params['prompt']} ({shape})")

        payload = _generate_logo_payload(params, shape, upscale)
        if not payload['success']:
            return _json_error(payload['error'], 500)

        # NOTE(review): no SVG conversion is performed in this handler; the
        # flag is only echoed back to the caller.
        payload['vector_ready'] = convert_to_svg
        return jsonify(payload)

    except Exception as e:
        # Last-resort boundary so unexpected failures still yield a JSON reply.
        print(f"❌ API Error: {str(e)}")
        return _json_error(str(e), 500)


@app.route('/generate-logo-variations', methods=['POST'])
def generate_logo_variations():
    """Generate one logo per requested shape.

    Expects a JSON object with the same options as ``/generate-logo`` except
    that ``shapes`` (a list, defaulting to all four shapes) replaces
    ``shape``. Each variation reports its own success or failure; the overall
    response is successful even if individual variations fail.
    """
    try:
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return _json_error("Request body must be a JSON object", 400)

        try:
            params = _read_generation_params(data)
        except ValueError as exc:
            return _json_error(str(exc), 400)

        shapes = data.get('shapes', ['circle', 'square', 'hexagon', 'triangle'])
        # A bare string would otherwise be iterated character by character.
        if isinstance(shapes, str):
            shapes = [shapes]
        if not isinstance(shapes, list):
            return _json_error("'shapes' must be a list of shape names", 400)
        upscale = bool(data.get('upscale', False))

        print(f"🎨 Generating {len(shapes)} logo variations: {params['prompt']}")

        variations = [
            _generate_logo_payload(params, shape, upscale)
            for shape in shapes
        ]

        return jsonify({
            "success": True,
            "variations": variations,
            "total_generated": sum(1 for v in variations if v['success'])
        })

    except Exception as e:
        # Last-resort boundary so unexpected failures still yield a JSON reply.
        print(f"❌ API Error: {str(e)}")
        return _json_error(str(e), 500)

print("✅ Flask API server defined!")
print("🚀 Ready to start server...")

## 7. Start the Server

In [None]:
# Announce where the API can be reached, then hand control over to Flask.
_banner = "=" * 50
_endpoints = (
    ("Health Check", "/health"),
    ("Generate Logo", "/generate-logo"),
    ("Generate Variations", "/generate-logo-variations"),
)

print("🚀 Starting Enhanced Logo Generator API...")
print(f"🌐 Public URL: {public_url}")
print(f"🔗 API Endpoint: {public_url}/generate-logo")
print("📋 Copy this URL to your Next.js app!")
print("\n" + _banner)
print("🎯 Available Endpoints:")
for _label, _route in _endpoints:
    print(f"  • {_label}: {public_url}{_route}")
print(_banner)
print("\n🔄 Server is running... Press Ctrl+C to stop")

# Serve on all interfaces at port 5000, the port the ngrok tunnel forwards to.
app.run(host='0.0.0.0', port=5000, debug=False)