https://colab.research.google.com/drive/1hc8G2WY_4P_0Tri-lZ0HmVDdX6MKy5LV?usp=sharing

In [None]:
# Install only required packages
!pip install torch torchvision diffusers transformers flask flask-cors xformers

import numpy as np
import torch
from torch.cuda import amp  # Add this import
from diffusers import AutoPipelineForText2Image
from flask import Flask, request, jsonify, make_response
from flask_cors import CORS
import base64
import io
import os
import gc
from datetime import datetime
import logging
from threading import Thread
import time
from transformers import set_seed
import subprocess
import sys
import requests  # Add missing requests import

# Replace cloudflared installation with localtunnel
!npm install -g localtunnel

# Configure the CUDA caching allocator through its environment variable.
# Set it before the first CUDA allocation so the allocator picks it up.
# NOTE: the 512 MB split size is configured only through this variable;
# torch.backends.cuda has no `max_memory_split_size` setting, so the former
# assignment to that attribute was a silent no-op and has been dropped.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:512"

# Module-level logger used by the request handlers for error reporting.
logger = logging.getLogger(__name__)

# Flask application plus a permissive CORS policy applied to every route.
app = Flask(__name__)
CORS(app, resources={
    r"/*": {
        "origins": "*",  # Allow all origins
        "methods": ["GET", "POST", "OPTIONS"],
        "allow_headers": ["*"],  # Allow all headers
        "expose_headers": ["*"],
        "max_age": 3600,  # Seconds a browser may cache a preflight result
        "supports_credentials": True
    }
})

@app.after_request
def after_request(response):
    """Attach CORS headers to every outgoing response.

    The caller's ``Origin`` is echoed back (with credentials allowed) so the
    browser front end can call this server from any site. When the request
    carries no ``Origin`` header the response falls back to ``*`` and does
    not advertise credentials, because browsers reject the combination of a
    wildcard origin with ``Access-Control-Allow-Credentials: true``.

    For credentialed requests browsers treat ``*`` in Allow-Methods and
    Allow-Headers as a literal value rather than a wildcard, so the methods
    are listed explicitly and the headers requested by the preflight are
    echoed back.

    NOTE(review): reflecting any origin with credentials disables the
    browser's same-origin protection for this API; acceptable only while the
    server holds no cookies or per-user secrets.

    Args:
        response: The Flask response about to be sent.

    Returns:
        The same response object with the CORS headers set.
    """
    origin = request.headers.get('Origin')
    requested_headers = request.headers.get('Access-Control-Request-Headers')

    response.headers.update({
        'Access-Control-Allow-Origin': origin or '*',
        'Access-Control-Allow-Methods': 'GET, POST, OPTIONS',
        # Only preflights send Access-Control-Request-Headers; echo it so the
        # grant also holds for credentialed requests.
        'Access-Control-Allow-Headers': requested_headers or '*',
        'Access-Control-Max-Age': '3600',
        'Access-Control-Expose-Headers': '*'
    })

    if origin:
        response.headers['Access-Control-Allow-Credentials'] = 'true'
        # The body is the same but the headers differ per origin, so caches
        # must key on it.
        response.vary.add('Origin')

    return response

def download_model():
    """Download the PrunaAI SDXL-Turbo checkpoint with ``huggingface-cli``.

    The files are placed in a directory named after the repository, created
    relative to the current working directory. Re-running is safe: an
    existing directory is reused.

    Returns:
        str: The name of the local directory holding the download.

    Raises:
        subprocess.CalledProcessError: If ``huggingface-cli`` exits with a
            non-zero status, so callers never receive the path of a failed
            or partial download.
        FileNotFoundError: If ``huggingface-cli`` is not installed / on PATH.
    """
    repo_name = "stabilityai-sdxl-turbo-turbo-tiny-green-smashed"

    # os.makedirs is portable and, unlike `mkdir`, does not fail on re-runs.
    os.makedirs(repo_name, exist_ok=True)

    subprocess.run(
        [
            "huggingface-cli", "download",
            f'PrunaAI/{repo_name}',
            "--local-dir", repo_name,
            "--local-dir-use-symlinks", "False",
        ],
        check=True,
    )
    return repo_name

def get_gpu_memory():
    """Report memory usage of the current CUDA device in megabytes.

    The driver is queried exactly once so that ``free``, ``total`` and
    ``used`` describe the same instant and always satisfy
    ``used == (total_bytes - free_bytes) // 1024**2``.

    Returns:
        dict | None: ``{'free': int, 'total': int, 'used': int}`` in MB, or
        ``None`` when the query fails (for example when CUDA is unavailable).
    """
    try:
        free_bytes, total_bytes = torch.cuda.mem_get_info()
    except Exception:
        # Best effort: callers treat None as "memory information unavailable".
        return None

    mb = 1024 ** 2
    return {
        'free': free_bytes // mb,
        'total': total_bytes // mb,
        'used': (total_bytes - free_bytes) // mb,
    }

def optimize_gpu_memory():
    """Release unreferenced Python objects and return cached CUDA memory.

    Garbage is collected first so that the memory of tensors that just
    became unreachable can actually be handed back when the CUDA cache is
    emptied afterwards. On a machine without CUDA only the garbage
    collection runs.

    The earlier implementation also walked ``gc.get_objects()`` and ran
    ``del obj`` on every CUDA tensor; that only unbound the loop variable
    and freed nothing, while scanning every live object on each call, so it
    has been removed.
    """
    gc.collect()

    if not torch.cuda.is_available():
        return

    # Wait for queued kernels so their buffers are no longer in use.
    torch.cuda.synchronize(torch.cuda.current_device())
    torch.cuda.empty_cache()

def setup_pipeline():
    """Load the SDXL-Turbo text-to-image pipeline in inference configuration.

    Steps, in order: free GPU memory, seed the RNGs and switch cuDNN to
    deterministic mode, load the fp16 weights, enable xFormers attention,
    freeze the text encoder / UNet / VAE, move the text encoder to the CPU,
    and free GPU memory again.

    NOTE(review): the pipeline is loaded with ``device_map="balanced"`` and
    the text encoder is then moved to the CPU by hand; confirm that prompt
    encoding still runs with its inputs on the matching device.

    Returns:
        tuple: ``(pipeline, None)``; the second slot is always ``None``.
    """
    optimize_gpu_memory()

    # Reproducible runs: fixed seed, no cuDNN autotuning.
    set_seed(42)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    load_options = dict(
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
        low_cpu_mem_usage=True,
        device_map="balanced",
    )
    pipeline = AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo", **load_options
    )

    pipeline.enable_xformers_memory_efficient_attention()

    # Inference only: switch to eval mode and disable gradients.
    frozen_parts = [pipeline.text_encoder, pipeline.unet, pipeline.vae]
    for part in frozen_parts:
        if part is None:
            continue
        part.eval()
        if hasattr(part, 'requires_grad_'):
            part.requires_grad_(False)

    text_encoder = pipeline.text_encoder
    if text_encoder is not None:
        text_encoder.to('cpu')

    optimize_gpu_memory()
    return pipeline, None

def find_free_port(start_port=5001, max_attempts=10):
    """Return the first TCP port that can be bound, scanning upward.

    Ports ``start_port`` through ``start_port + max_attempts - 1`` are tried
    in order by binding a throw-away socket on all interfaces.

    Note that the probe socket is closed before returning, so another
    process could still claim the port before the caller binds it.

    Args:
        start_port: First port number to try.
        max_attempts: How many consecutive ports to try.

    Returns:
        int: The first port in the range that could be bound.

    Raises:
        RuntimeError: If none of the ports in the range could be bound.
    """
    import socket

    for port in range(start_port, start_port + max_attempts):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('', port))
                return port
        except (OSError, OverflowError):
            # OSError: port in use / not permitted.
            # OverflowError: port number outside 0-65535.
            continue

    # Report the range that was actually scanned (upper bound inclusive).
    last_port = start_port + max_attempts - 1
    raise RuntimeError(f"Could not find free port in range {start_port}-{last_port}")

def _wait_for_tunnel_url(process, timeout):
    """Return the URL announced on the tunnel's output, or None.

    A background thread reads ``process.stdout`` so the wait can honour
    ``timeout`` even when the tunnel prints nothing; a direct ``readline()``
    would block indefinitely. The thread keeps draining the pipe for the
    life of the process so the tunnel can never stall on a full pipe.

    Args:
        process: ``subprocess.Popen`` object with a text-mode ``stdout`` pipe.
        timeout: Maximum number of seconds to wait for the URL line.

    Returns:
        str | None: The announced URL, or ``None`` if the timeout elapsed or
        the output stream closed first.
    """
    import queue

    lines = queue.Queue()
    collecting = True

    def pump():
        for line in process.stdout:
            if collecting:
                lines.put(line)
        lines.put(None)  # Sentinel: stream closed (process exited)

    Thread(target=pump, daemon=True).start()

    deadline = time.monotonic() + timeout
    try:
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                return None
            try:
                line = lines.get(timeout=remaining)
            except queue.Empty:
                return None
            if line is None:
                return None

            print(f"Tunnel output: {line.strip()}")
            if 'your url is:' in line.lower():
                return line.split('is: ')[-1].strip()
    finally:
        # From here on the pump only discards output instead of queueing it.
        collecting = False


def setup_tunnel(port):
    """Start a LocalTunnel for ``port`` and return its verified public URL.

    Any tunnel left over from an earlier run is stopped first. A new tunnel
    is started with a subdomain derived from the current time, its URL is
    read from the tunnel's output (waiting at most 30 seconds), and the URL
    is verified by requesting ``<url>/health`` up to three times.

    Args:
        port: Local port the tunnel should forward to.

    Returns:
        str: The public tunnel URL.

    Raises:
        Exception: If no URL is announced in time or the health check never
            succeeds. The tunnel process is terminated before re-raising.
    """
    # Bound up front so the error handler can test it even when the failure
    # happens before the tunnel process is started.
    process = None
    try:
        print("\n=== Setting up LocalTunnel ===")

        # Stop stale tunnels only. The pattern matches the package name or an
        # executable path ending in "/lt"; the previous bare "lt" pattern
        # matched any process whose command line merely contained "lt".
        os.system("pkill -f 'localtunnel|/lt( |$)'")
        time.sleep(2)

        # Subdomain derived from the wall-clock time (HHMMSS).
        timestamp = datetime.now().strftime('%H%M%S')
        subdomain = f"story-gen-{timestamp}"

        # Argument list instead of a shell string; stderr is merged into
        # stdout so a single reader drains everything the tunnel prints.
        process = subprocess.Popen(
            ['npx', 'localtunnel', '--port', str(port), '--subdomain', subdomain],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            bufsize=1,
        )

        url = _wait_for_tunnel_url(process, timeout=30)
        if not url:
            raise Exception("Failed to get tunnel URL within timeout")

        # Verify tunnel is working
        test_url = f"{url}/health"
        print(f"Testing connection to {test_url}")

        attempts = 3
        for attempt in range(attempts):
            try:
                # NOTE(review): verify=False disables TLS certificate checks;
                # confirm this is still required for the tunnel host.
                response = requests.get(test_url,
                                        headers={'Accept': 'application/json'},
                                        timeout=5,
                                        verify=False)
                if response.ok:
                    print("✓ Tunnel connection verified")
                    return url
                print(f"Retry - Connection test failed: HTTP {response.status_code}")
            except Exception as e:
                print(f"Retry - Connection test failed: {e}")

            # Pause between attempts for both failure kinds, but not after
            # the last one.
            if attempt < attempts - 1:
                time.sleep(2)

        raise Exception("Could not verify tunnel connection")

    except Exception as e:
        print(f"✗ Tunnel setup failed: {str(e)}")
        if process is not None and process.poll() is None:
            process.terminate()
        raise

def run_server():
    """Start the Flask app, expose it through a tunnel, and block forever.

    Sequence: report the GPU, pick a free port, start Flask in a daemon
    thread, run a local health check, establish the tunnel (up to three
    attempts), print the public URL, then sleep until interrupted.

    Raises:
        Exception: If the tunnel cannot be established or any other start-up
            step fails. Tunnel processes are stopped before re-raising.
    """
    # Stops only LocalTunnel processes: matches the package name or an
    # executable path ending in "/lt". The previous bare "lt" pattern matched
    # any process whose command line merely contained "lt".
    kill_tunnels_cmd = "pkill -f 'localtunnel|/lt( |$)'"

    try:
        print("\n1. Checking GPU...")
        if not torch.cuda.is_available():
            print("⚠️ WARNING: No GPU detected!")
        else:
            print(f"✓ Found GPU: {torch.cuda.get_device_name(0)}")
            print(f"✓ Memory: {get_gpu_memory()}")

        print("\n2. Starting Flask server...")
        # Find available port
        port = find_free_port(5001)
        print(f"Found available port: {port}")

        def run_flask():
            app.run(
                host='0.0.0.0',
                port=port,
                debug=False,
                use_reloader=False,
                threaded=True
            )

        # Daemon thread: it must not keep the process alive on shutdown.
        server = Thread(target=run_flask)
        server.daemon = True
        server.start()
        time.sleep(2)  # Give the server time to start listening
        print("✓ Flask server running")

        # Test local connection (informational only; failure is not fatal)
        try:
            test_response = requests.get(f'http://localhost:{port}/health', timeout=5)
            print(f"✓ Local health check: {test_response.status_code}")
        except Exception as e:
            print(f"⚠️ Local health check failed: {e}")

        print("\n3. Setting up tunnel...")
        tunnel_url = None
        max_attempts = 3

        for attempt in range(max_attempts):
            try:
                tunnel_url = setup_tunnel(port)
                if tunnel_url:
                    break
            except Exception as e:
                if attempt < max_attempts - 1:
                    print(f"\nRetrying tunnel setup ({attempt + 1}/{max_attempts})...")
                    time.sleep(5)
                else:
                    # Chain the cause so the original traceback is kept.
                    raise Exception(
                        f"Failed to establish tunnel after {max_attempts} attempts: {str(e)}"
                    ) from e

        if not tunnel_url:
            raise Exception("Failed to get tunnel URL")

        print("\n=== SDXL Server Ready ===")
        print(f"Server URL: {tunnel_url}")
        print("Copy this URL to use in the Story Generator app")

        # Keep the notebook cell running until interrupted
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down server...")
            os.system(kill_tunnels_cmd)

    except Exception as e:
        print(f"\nError: {str(e)}")
        os.system(kill_tunnels_cmd)
        raise

@app.route('/health', methods=['GET', 'OPTIONS'])
def health_check():
    """Report whether the service can currently generate images.

    OPTIONS requests are answered by ``handle_preflight``. For GET:

    * 503 when no CUDA device is available;
    * 503 when GPU memory cannot be read or less than 2000 MB is free;
    * 200 with GPU details and a timestamp otherwise;
    * 500 if the check itself raises.
    """
    if request.method == 'OPTIONS':
        return handle_preflight()

    try:
        if not torch.cuda.is_available():
            no_gpu = {
                'status': 'unhealthy',
                'service': 'sdxl',
                'error': 'GPU not available',
                'gpu_available': False
            }
            return jsonify(no_gpu), 503

        memory = get_gpu_memory()
        has_headroom = bool(memory) and memory['free'] >= 2000  # at least 2GB free
        if not has_headroom:
            low_memory = {
                'status': 'unhealthy',
                'service': 'sdxl',
                'error': 'Insufficient GPU memory',
                'gpu_available': True,
                'gpu_info': memory
            }
            return jsonify(low_memory), 503

        healthy = jsonify({
            'status': 'healthy',
            'service': 'sdxl',
            'gpu_available': True,
            'gpu_info': memory,
            'timestamp': datetime.now().isoformat()
        })

        # CORS and caching headers for the success response
        extra_headers = (
            ('Access-Control-Allow-Origin', '*'),
            ('Access-Control-Allow-Methods', 'GET, OPTIONS'),
            ('Access-Control-Allow-Headers', '*'),
            ('Content-Type', 'application/json'),
            ('Cache-Control', 'no-cache'),
        )
        healthy.headers.update(dict(extra_headers))
        return healthy

    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        failure = {
            'status': 'error',
            'service': 'sdxl',
            'error': str(e)
        }
        return jsonify(failure), 500

def handle_preflight():
    """Build the empty response that answers a CORS preflight (OPTIONS).

    Returns:
        An empty response carrying the allow-origin, allow-methods,
        allow-headers and max-age CORS headers.
    """
    preflight_headers = {
        'Access-Control-Allow-Origin': '*',
        'Access-Control-Allow-Methods': 'GET, OPTIONS',
        'Access-Control-Allow-Headers': '*',
        'Access-Control-Max-Age': '3600',
    }
    preflight = make_response()
    preflight.headers.update(preflight_headers)
    return preflight

@app.route('/generate_image', methods=['POST', 'OPTIONS'])
def generate_image():
    """Generate one 384x384 image from a text prompt.

    Request body (JSON object): ``{"prompt": "<text>"}``.

    Responses:
        200: ``{"success": true, "image": "data:image/jpeg;base64,..."}``
        400: the body is not a JSON object, or ``prompt`` is missing, empty
             or not a string.
        503: less than 4000 MB of GPU memory is free even after cleanup.
        500: any other failure, including an uninitialized pipeline.

    NOTE(review): ``run_server`` starts Flask with ``threaded=True``, so
    concurrent requests would share the single global pipeline; confirm that
    is safe or serialize generation.
    """
    if request.method == 'OPTIONS':
        response = make_response()
        response.headers.update({
            'Access-Control-Allow-Origin': '*',
            'Access-Control-Allow-Headers': 'Content-Type',
            'Access-Control-Allow-Methods': 'POST',
            'Access-Control-Max-Age': '3600'
        })
        return response

    try:
        # silent=True yields None instead of raising on a missing or
        # malformed body, so client errors are reported as 400 rather than
        # falling through to the generic 500 handler.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'error': 'Request body must be a JSON object'
            }), 400

        prompt = data.get('prompt', '')
        print(f"Received image generation request with prompt: {prompt}")  # Debug logging

        if not prompt or not isinstance(prompt, str):
            return jsonify({'success': False, 'error': 'Prompt is required'}), 400

        # Fail fast, before any GPU work, if no pipeline has been loaded.
        # globals().get avoids a NameError when the name was never assigned.
        active_pipeline = globals().get('pipeline')
        if active_pipeline is None:
            raise ValueError("Pipeline not initialized")

        # Need at least 4GB free; try one cleanup pass before giving up.
        # A None reading means "unknown" and does not block generation.
        gpu_mem = get_gpu_memory()
        if gpu_mem and gpu_mem['free'] < 4000:
            optimize_gpu_memory()
            gpu_mem = get_gpu_memory()
            if gpu_mem and gpu_mem['free'] < 4000:
                return jsonify({
                    'success': False,
                    'error': 'Insufficient GPU memory available'
                }), 503

        with torch.amp.autocast('cuda', dtype=torch.float16):
            with torch.inference_mode():
                optimize_gpu_memory()
                print("Generating image...")  # Debug logging

                # Restore the expected placement: text encoder on the CPU,
                # UNet and VAE on the GPU.
                for name, component in [
                    ('text_encoder', getattr(active_pipeline, 'text_encoder', None)),
                    ('unet', getattr(active_pipeline, 'unet', None)),
                    ('vae', getattr(active_pipeline, 'vae', None))
                ]:
                    if component is not None and hasattr(component, 'device'):
                        if name == 'text_encoder' and component.device.type != 'cpu':
                            component.to('cpu')
                        elif name != 'text_encoder' and component.device.type != 'cuda':
                            component.to('cuda')

                image = active_pipeline(
                    prompt=prompt,
                    num_inference_steps=1,
                    guidance_scale=0.0,
                    width=384,
                    height=384,
                    negative_prompt="low quality, blurry, distorted",
                ).images[0]
                optimize_gpu_memory()
                print("Image generated successfully")  # Debug logging

        # The former NaN check was removed: the array of the image saved
        # below holds integer pixel values, which can never be NaN.
        with io.BytesIO() as buffered:
            image.save(buffered, format="JPEG", optimize=True, quality=85)
            img_str = base64.b64encode(buffered.getvalue()).decode()

        # Cleanup
        del image
        torch.cuda.empty_cache()
        gc.collect()

        print("Sending response")  # Debug logging
        return jsonify({
            'success': True,
            'image': f'data:image/jpeg;base64,{img_str}'
        })

    except Exception as e:
        print(f"Error generating image: {str(e)}")  # Debug logging
        optimize_gpu_memory()
        # logger.exception records the traceback along with the message.
        logger.exception(f"Image generation error: {str(e)}")
        return jsonify({
            'success': False,
            'error': f"Generation failed: {str(e)}"
        }), 500

@app.route('/verify_connection', methods=['POST'])
def verify_connection():
    """Acknowledge a connection probe from the front end.

    Request body (JSON object): ``{"backend_url": "...", "timestamp": ...}``.
    The timestamp is echoed back unchanged.

    Responses:
        200: ``{"success": true, "message": "Connection verified", ...}``
        400: the body is not a JSON object or ``backend_url`` is missing/empty.
        500: unexpected failure.
    """
    try:
        # silent=True returns None for a missing or malformed body, so a bad
        # request is answered with 400 instead of an AttributeError-driven 500.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({
                'success': False,
                'error': 'Request body must be a JSON object'
            }), 400

        backend_url = data.get('backend_url')
        timestamp = data.get('timestamp')

        if not backend_url:
            return jsonify({
                'success': False,
                'error': 'No backend URL provided'
            }), 400

        return jsonify({
            'success': True,
            'message': 'Connection verified',
            'timestamp': timestamp
        })

    except Exception as e:
        logger.error(f"Verification error: {e}")
        return jsonify({
            'success': False,
            'error': str(e)
        }), 500

@app.route('/', methods=['GET'])
def root():
    """Liveness probe: reply with a fixed status and the current time."""
    now = datetime.now().isoformat()
    payload = {
        'status': 'alive',
        'timestamp': now
    }
    return jsonify(payload)

@app.route('/debug', methods=['GET'])
def debug():
    """Return diagnostic details about the GPU and the server process."""
    has_cuda = torch.cuda.is_available()

    gpu_section = {
        'available': has_cuda,
        'device_count': torch.cuda.device_count() if has_cuda else 0,
        'memory': get_gpu_memory(),
        'device_name': torch.cuda.get_device_name(0) if has_cuda else None
    }
    server_section = {
        'timestamp': datetime.now().isoformat(),
        'pid': os.getpid(),
        'python_version': sys.version
    }

    return jsonify({
        'status': 'debug',
        'gpu': gpu_section,
        'server': server_section
    })

if __name__ == '__main__':
    # Check CUDA availability before starting
    if not torch.cuda.is_available():
        print("ERROR: No GPU available. This server requires CUDA GPU support.")
        # sys.exit raises SystemExit and so reliably stops execution here;
        # the bare `exit` helper comes from the interactive environment.
        sys.exit(1)

    # Module-level names read by the request handlers. They are defined up
    # front so the handlers never hit an undefined name. The `global`
    # statement was removed: at module level it has no effect.
    pipeline = None
    autocast_context = None

    # Cap this process at 70% of the GPU's memory
    torch.cuda.set_per_process_memory_fraction(0.7)
    torch.cuda.empty_cache()

    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.enabled = True

    # Load the model before serving. Previously the pipeline stayed None, so
    # every /generate_image request failed with "Pipeline not initialized".
    # setup_pipeline() returns a (pipeline, None) pair. It also sets
    # cudnn.benchmark = False and cudnn.deterministic = True, which take
    # precedence over the benchmark flag set above.
    pipeline, autocast_context = setup_pipeline()

    run_server()
