# Story Generator - Image Server
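This notebook runs the Story Generator's image-generation server. It serves SDXL-Turbo (`stabilityai/sdxl-turbo`) through a small Flask API and exposes it publicly through an ngrok tunnel. It needs a CUDA GPU runtime and a Colab secret named `NGROK_AUTH_TOKEN`.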

In [None]:
# Install dependencies
# NOTE(review): the packages are unpinned, so each run installs whatever the
# latest releases are; consider pinning versions for reproducibility.
!pip install torch torchvision diffusers transformers flask flask-cors pyngrok

# Read the ngrok auth token from the Colab secret 'NGROK_AUTH_TOKEN' and
# register it with the ngrok CLI. IPython substitutes the Python variable
# NGROK_TOKEN for `$NGROK_TOKEN` before running the shell command.
# NOTE(review): this puts the token on a shell command line; pyngrok's
# `ngrok.set_auth_token(NGROK_TOKEN)` would avoid shelling out. Also confirm
# the installed ngrok CLI still accepts the `authtoken` subcommand (newer
# CLIs document `ngrok config add-authtoken`).
from google.colab import userdata
NGROK_TOKEN = userdata.get('NGROK_AUTH_TOKEN')
!ngrok authtoken $NGROK_TOKEN

# Clear the install/auth output printed above so the cell output stays short.
from IPython.display import clear_output
clear_output()

# Sanity check: the server cell moves the model to "cuda", so a GPU runtime
# is required. Print whether PyTorch can see one, and its name if so.
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name()}")

In [None]:
import base64
import io
import threading

import torch
from diffusers import DiffusionPipeline
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
from pyngrok import ngrok

app = Flask(__name__)
# Enable flask-cors on every route so a browser front end served from a
# different origin can call this API (flask-cors defaults allow any origin).
CORS(app)

# Load SDXL-Turbo once at start-up and move it to the GPU. The "fp16" weight
# variant is loaded as float16; every request handler reuses this pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    variant="fp16",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe for the image server.

    Returns:
        JSON ``{'status': 'healthy', 'gpu_available': <bool>}``, where
        ``gpu_available`` is ``torch.cuda.is_available()`` at request time.
    """
    body = {
        'status': 'healthy',
        'gpu_available': torch.cuda.is_available(),
    }
    return jsonify(body)

# The diffusers pipeline keeps per-call state (e.g. scheduler timesteps) on the
# shared `pipe` object, and Flask's development server handles requests on
# multiple threads by default, so model calls are serialized with this lock.
_pipe_lock = threading.Lock()


@app.route('/generate_frame', methods=['POST'])
def generate_frame():
    """Generate a single image from a text prompt with SDXL-Turbo.

    Expects a JSON object body containing a non-empty string ``prompt``.

    Returns:
        200 with ``{'success': True, 'image': 'data:image/jpeg;base64,...'}``;
        400 with ``{'success': False, 'error': ...}`` when the body is not a
        JSON object, or the prompt is missing/blank/not a string;
        500 with ``{'success': False, 'error': str(exc)}`` if generation or
        encoding fails.
    """
    # silent=True yields None (instead of raising) for a missing or malformed
    # JSON body, so bad requests get a 400 rather than landing in the 500 path.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'success': False, 'error': 'Request body must be a JSON object'}), 400

    prompt = data.get('prompt', '')
    if not isinstance(prompt, str):
        return jsonify({'success': False, 'error': 'prompt must be a string'}), 400
    if not prompt.strip():
        return jsonify({'success': False, 'error': 'No prompt provided'}), 400

    try:
        # SDXL-Turbo is meant to be sampled in one step with classifier-free
        # guidance disabled, hence num_inference_steps=1 and guidance_scale=0.0.
        with _pipe_lock:
            image = pipe(
                prompt=prompt,
                num_inference_steps=1,
                guidance_scale=0.0,
            ).images[0]

        # Encode as JPEG and wrap in a data URI the client can use directly.
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
    except Exception as e:
        # Request boundary: log the traceback server-side and keep the
        # existing JSON error contract for the client.
        app.logger.exception('Image generation failed')
        return jsonify({'success': False, 'error': str(e)}), 500

    return jsonify({
        'success': True,
        'image': f'data:image/jpeg;base64,{img_str}'
    })

# Start the server: open a public ngrok HTTP tunnel to the local Flask port,
# then run Flask in the foreground (app.run blocks this cell until interrupted).
PORT = 5000

ngrok.kill()  # Kill any existing tunnels
public_url = ngrok.connect(addr=str(PORT), proto="http")
# ngrok.connect returns an NgrokTunnel object; print only its URL, not its repr.
print(f'Server running at: {public_url.public_url}')

try:
    app.run(host='0.0.0.0', port=PORT)
finally:
    # Close the tunnel when the server stops (e.g. the cell is interrupted)
    # so it is not left open pointing at a dead port.
    ngrok.disconnect(public_url.public_url)