In [5]:
!pip install torch torchvision torchaudio
!pip install diffusers
!pip install transformers
!pip install fastapi uvicorn
!pip install pillow
!pip install python-multipart
!pip install pyngrok




In [None]:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from diffusers import StableDiffusionPipeline
import torch
from io import BytesIO
import base64
from PIL import Image
import asyncio
from pyngrok import ngrok
import uvicorn

# --- Configuration ---
MODEL_ID = "runwayml/stable-diffusion-v1-5"
USE_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if USE_CUDA else "cpu"
DTYPE = torch.float16 if USE_CUDA else torch.float32

# --- Global Pipeline ---
pipeline = None

# --- Asynchronous Pipeline Loading ---
async def load_pipeline():
    global pipeline
    try:
        pipeline = StableDiffusionPipeline.from_pretrained(
            MODEL_ID, torch_dtype=DTYPE
        ).to(DEVICE)
        pipeline.enable_attention_slicing()  # Optimizes memory usage
        print(f"Model loaded successfully on {DEVICE}")
    except Exception as e:
        print(f"Error loading pipeline: {e}")
        raise RuntimeError("Failed to load the model pipeline")

# --- API Models ---
class ImageGenerationRequest(BaseModel):
    prompt: str
    num_images: int = 1
    steps: int = 50
    guidance_scale: float = 7.5

class ImageGenerationResponse(BaseModel):
    images: list[str]

# --- FastAPI Application ---
app = FastAPI()

@app.on_event("startup")
async def on_startup():
    await load_pipeline()

def generate_images(request: ImageGenerationRequest) -> list[str]:
    try:
        # Generate images
        output = pipeline(
            prompt=request.prompt,
            num_images_per_prompt=request.num_images,
            num_inference_steps=request.steps,
            guidance_scale=request.guidance_scale,
        )
        images = output.images

        # Convert to Base64 and save locally
        base64_images = []
        for idx, image in enumerate(images):
            # Convert to Base64
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            base64_images.append(img_str)
        return base64_images
    except Exception as e:
        print(f"Error during image generation: {e}")
        raise HTTPException(status_code=500, detail="Image generation failed")

@app.post("/generate", response_model=ImageGenerationResponse)
async def generate_image_endpoint(request: ImageGenerationRequest):
    # Ensure pipeline is loaded
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Model pipeline is not ready")
    images = generate_images(request)
    return {"images": images}

# --- Expose FastAPI via ngrok ---
ngrok.set_auth_token('2qAlnMNFjxT1TJj8jjVtOyttiuu_2bDcGegeLL1Nvhce1SoE9')  # Set your ngrok auth token here
public_url = ngrok.connect(8000)
print(f"FastAPI app is live at: {public_url}")

# --- Run the FastAPI Application ---
if __name__ == "__main__":
    uvicorn.run(app="__main__:app", host="0.0.0.0", port=8000, reload=True)


        on_event is deprecated, use lifespan event handlers instead.

        Read more about it in the
        [FastAPI docs for Lifespan Events](https://fastapi.tiangolo.com/advanced/events/).
        
  @app.on_event("startup")


FastAPI app is live at: NgrokTunnel: "https://ae2f-34-168-237-2.ngrok-free.app" -> "http://localhost:8000"


INFO:     Will watch for changes in these directories: ['/content']
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
INFO:     Started reloader process [577] using StatReload


In [None]:
import requests
import base64
from PIL import Image
import io

# Replace with your generated ngrok URL
url = "http://xxxxxx.ngrok.io/generate"
data = {
    "prompt": "A beautiful sunset over the ocean",
    "num_images": 1,
    "steps": 50,
    "guidance_scale": 7.5
}

response = requests.post(url, json=data)
images = response.json()

# Display the generated image
image_data = base64.b64decode(images['images'][0])
image = Image.open(io.BytesIO(image_data))
image.show()
