# In [None]:
import asyncio
import io
import os
from typing import Any, Dict, List, Tuple

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, jsonify, request
from PIL import Image
from pyngrok import ngrok

class StoryToImageGenerator:
    """Render a story as a sequence of cartoon-style images.

    The story text is split into short scenes, each scene is turned into a
    Stable Diffusion prompt, and the resulting image is saved as
    ``scene_<n>.png`` inside ``output_dir``.  The outcome of the most recent
    run is kept on the instance in ``results`` (one status dict per scene)
    and ``image_paths`` (paths of the images that were saved).
    """

    # Directory images are written to when no ``output_dir`` is supplied.
    DEFAULT_OUTPUT_DIR = '/content/sample_data/'
    # Sentences are merged into one scene while their combined length stays
    # strictly below this many characters.
    MAX_SCENE_LENGTH = 200

    def __init__(self, output_dir: str = DEFAULT_OUTPUT_DIR):
        """Load the model and initialise empty result state.

        Args:
            output_dir: Directory the generated PNG files are saved to.

        Raises:
            RuntimeError: If the Stable Diffusion model cannot be loaded.
        """
        self.model = None
        self.output_dir = output_dir
        self.results = []
        self.image_paths = []
        self.setup_model()

    def setup_model(self):
        """Initialize the Stable Diffusion model.

        Loads ``CompVis/stable-diffusion-v1-4`` in float32 and moves it to
        CUDA when available, otherwise to the CPU.

        Raises:
            RuntimeError: If loading or moving the model fails; the original
                exception is attached as the cause.
        """
        try:
            self.model = StableDiffusionPipeline.from_pretrained(
                "CompVis/stable-diffusion-v1-4",
                torch_dtype=torch.float32
            )
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(device)
            print(f"Model loaded successfully on {device}")
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    async def generate_from_story(self, story: str) -> Tuple[List[Dict], List[str]]:
        """Generate and save one image per scene of ``story``.

        ``self.results`` and ``self.image_paths`` are reset at the start of
        every call and filled in as scenes complete.  A failure while
        rendering or saving a single scene is recorded as an ``'error'``
        entry and does not stop the remaining scenes.

        Args:
            story: The story text to illustrate.

        Returns:
            A ``(results, image_paths)`` tuple.  ``results`` holds one dict
            per scene with ``scene_number`` and ``status`` plus either
            ``filepath`` (success) or ``message`` (error); ``image_paths``
            lists the files that were saved.

        Raises:
            OSError: If ``output_dir`` cannot be created.
        """
        scenes = self._break_into_scenes(story)
        self.results = []  # Reset results
        self.image_paths = []

        # Saving would fail for every scene if the target directory is missing.
        os.makedirs(self.output_dir, exist_ok=True)
        loop = asyncio.get_running_loop()

        for scene_num, scene in enumerate(scenes, 1):
            try:
                prompt = self._create_cartoon_prompt(scene)
                # The pipeline call is a blocking, synchronous function: its
                # result cannot be awaited directly.  Run it in the default
                # executor and await that instead.
                output = await loop.run_in_executor(None, self.model, prompt)
                image = output.images[0]

                filename = f"scene_{scene_num}.png"
                filepath = os.path.join(self.output_dir, filename)
                image.save(filepath)

                self.image_paths.append(filepath)

                self.results.append({
                    'scene_number': scene_num,
                    'filepath': filepath,
                    'status': 'success'
                })
            except Exception as e:
                # Best effort: report the failure for this scene and carry on.
                self.results.append({
                    'scene_number': scene_num,
                    'status': 'error',
                    'message': str(e)
                })

        return self.results, self.image_paths

    def _break_into_scenes(self, story: str) -> List[str]:
        """Break down the story into individual scenes.

        Paragraphs are separated by line breaks and sentences by ``'.'``.
        Within a paragraph, consecutive sentences are joined with a single
        space for as long as their combined length stays below
        ``MAX_SCENE_LENGTH``; a scene never spans two paragraphs.  Blank
        paragraphs and empty sentences are dropped.

        Args:
            story: The story text.

        Returns:
            The scene descriptions in story order (empty for empty input).
        """
        scenes = []
        # Split paragraphs on line breaks.  Splitting them on '.' as well
        # would leave exactly one sentence per paragraph and make the
        # grouping below unreachable.
        for paragraph in story.splitlines():
            sentences = [s.strip() for s in paragraph.split('.') if s.strip()]
            current_scene = ""
            for sentence in sentences:
                if not current_scene:
                    current_scene = sentence
                elif len(current_scene) + len(sentence) < self.MAX_SCENE_LENGTH:
                    current_scene += " " + sentence
                else:
                    scenes.append(current_scene)
                    current_scene = sentence
            if current_scene:
                scenes.append(current_scene)
        return scenes

    def _create_cartoon_prompt(self, scene_description: str) -> str:
        """Create a cartoon-style prompt from scene description.

        Args:
            scene_description: Text describing the scene.

        Returns:
            The fixed style keywords followed by ``scene_description``.
        """
        base_style = "cartoon illustration, vibrant colors, smooth lines, cheerful atmosphere, digital art style"
        return f"{base_style}, {scene_description}"



# Flask application that exposes the generator over HTTP.
app = Flask(__name__)

# Single shared generator instance used by every route below; constructing
# it calls setup_model(), so the model is loaded when this module is run.
generator = StoryToImageGenerator()

@app.route('/')
def home():
    """Serve a static HTML greeting at the site root."""
    greeting = "<h1>Hello World!</h1>"
    return greeting

@app.route('/generate', methods=['POST'])
async def generate_images():
    """Generate images for the story posted as JSON.

    Expects a JSON object with a string ``story`` field.

    Returns:
        200 with ``{'status': 'success', 'results': [...]}`` once generation
        has run (individual scenes may still carry an ``'error'`` status),
        400 when the body is not a JSON object containing a string
        ``story``, or 500 with the error message if generation raises.
    """
    # silent=True makes a malformed or non-JSON body come back as None, so
    # it is reported as a 400 below rather than raising into a generic 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'story' not in data:
        return jsonify({
            'status': 'error',
            'message': 'Story parameter is required'
        }), 400

    story = data['story']
    if not isinstance(story, str):
        return jsonify({
            'status': 'error',
            'message': 'Story parameter must be a string'
        }), 400

    try:
        generator.results, generator.image_paths = await generator.generate_from_story(story)
    except Exception as e:
        return jsonify({
            'status': 'error',
            'message': str(e)
        }), 500

    return jsonify({
        'status': 'success',
        'results': generator.results
    })

@app.route('/getResult')
def get_result():
    """Return the generator's current per-scene results as JSON."""
    payload = {"result": generator.results}
    return jsonify(payload)

# '/gepath' is kept so existing clients keep working; '/getPath' is an alias
# that follows the naming of the sibling '/getResult' endpoint.
@app.route('/gepath')
@app.route('/getPath')
def get_path():
    """Return the generator's current list of saved image paths as JSON."""
    return jsonify({"result": generator.image_paths})