In [1]:
# Install runtime dependencies into the kernel's own environment (%pip, unlike !pip,
# targets the interpreter running this notebook). nest_asyncio is listed because the
# server cell imports it.
# NOTE(review): pin versions (e.g. diffusers==X.Y.Z) once a working set is known.
%pip install diffusers
%pip install -U peft
%pip install fastapi uvicorn pyngrok nest_asyncio


[notice] A new release of pip is available: 23.3.1 -> 24.2
[notice] To update, run: python -m pip install --upgrade pip

[notice] A new release of pip is available: 23.3.1 -> 24.2
[notice] To update, run: python -m pip install --upgrade pip

[notice] A new release of pip is available: 23.3.1 -> 24.2
[notice] To update, run: python -m pip install --upgrade pip


Load Ngrok Token

In [13]:
import os
from getpass import getpass

from pyngrok import ngrok

# Read the ngrok authtoken from the environment (or prompt for it) instead of pasting it
# into the notebook, where it would be saved with the file and shared with it.
ngrok_token = os.environ.get("NGROK_AUTHTOKEN") or getpass("ngrok authtoken: ")
ngrok.set_auth_token(ngrok_token)

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


Load 4 copies of the SDXL + LoRA pipeline for concurrent inference

In [3]:
import time
import torch
from diffusers import DiffusionPipeline
from huggingface_hub import hf_hub_download

def load_stable_diffusion(model_id: str, lora_weights: str, cpu_offload: bool = True):
    """
    Load a Stable Diffusion pipeline in float16 and attach the given LoRA weights.

    :param model_id: The ID (or local path) of the base pipeline to load.
    :param lora_weights: The path or name of the LoRA weights to load.
    :param cpu_offload: If True (default), enable model CPU offloading to save GPU memory;
        sub-models are placed on the GPU by diffusers only while needed. If False, the
        whole pipeline is moved to CUDA and kept there.
    :return: DiffusionPipeline: The loaded model pipeline.
    """
    pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)

    # Attach the trained LoRA weights before choosing device placement.
    pipe.load_lora_weights(lora_weights)

    if cpu_offload:
        # Offloading manages device placement itself. Moving the pipeline to CUDA first
        # (as this function used to) only copies every weight to the GPU for
        # enable_model_cpu_offload to move it back, per the diffusers memory docs.
        pipe.enable_model_cpu_offload()
    else:
        pipe = pipe.to("cuda")

    return pipe

def load_multiple_pipelines(model_id: str, lora_weights: str, count: int):
    """
    Build ``count`` Stable Diffusion pipelines, each with the LoRA weights attached.

    Every entry comes from its own call to ``load_stable_diffusion``.

    :param model_id: The ID of the model to load.
    :param lora_weights: The path or name of the LoRA weights to load.
    :param count: How many pipelines to create.
    :return: list: The pipelines, in creation order.
    """
    loaded = []
    for _ in range(count):
        loaded.append(load_stable_diffusion(model_id, lora_weights))
    return loaded

# Base model and LoRA adapter served by this notebook.
BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
LORA_WEIGHTS_ID = "WizzVard22/sdxl-lora-testing"
# Number of pipeline copies; must stay 4 to match the unpacking below.
NUM_PIPELINES = 4

# Load independent copies of the LoRA-augmented SDXL pipeline.
stability_pipes = load_multiple_pipelines(BASE_MODEL_ID, LORA_WEIGHTS_ID, NUM_PIPELINES)

# Unpack into the individual names used by the API routes.
stability_pipe1, stability_pipe2, stability_pipe3, stability_pipe4 = stability_pipes

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Create a local directory for saving generated images

In [4]:
import os

# Folder generated images are written to; it sits under ./data, which is served as static files.
directory = 'data/generated_images/'
# Create the directory itself. The former os.path.dirname(directory) only worked because of
# the trailing slash; without it, only 'data/' would have been created.
os.makedirs(directory, exist_ok=True)

Create a FastAPI instance

In [5]:
from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse

app = FastAPI()

# Let browser clients from any origin call the API.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests, and Starlette
# may echo the caller's Origin instead when allow_credentials=True. None of the routes shown
# read cookies or auth, so confirm whether credentials need to be allowed at all.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=['*'],
    allow_headers=['*'],
    allow_methods=['*'],
)

# Expose everything under ./data (generated images included) at /data/<path>.
app.mount(path="/data", app=StaticFiles(directory="data"), name="data")

Import Statements

In [6]:
import asyncio
import threading
import uuid
from concurrent.futures import ProcessPoolExecutor
from random import randint

import nest_asyncio
import torch
import uvicorn
from pyngrok import ngrok

Directory Setup and Port Configuration

In [7]:
# Re-ensure the output folder exists so this cell also works when run on its own.
directory = 'data/generated_images/'
# Create the directory itself rather than os.path.dirname(directory), which silently
# depends on the trailing slash.
os.makedirs(directory, exist_ok=True)
# Local port uvicorn listens on and ngrok forwards to.
ngrok_port = 8000

Image Generation Function

In [8]:
def generate_image_sync(prompt, pipe, output_dir="data/generated_images", num_inference_steps=25):
    """
    Generate one image for ``prompt`` with ``pipe`` and save it as a PNG.

    Runs synchronously; the API route dispatches it to executor threads.

    :param prompt: The text prompt for image generation.
    :param pipe: A loaded diffusers pipeline. It is called as
        ``pipe(prompt, num_inference_steps=...)`` and its first output image is saved.
    :param output_dir: Directory the PNG is written to (created if missing). The default
        is the folder served under the ``/data`` static mount.
    :param num_inference_steps: Number of denoising steps passed to the pipeline.
    :return: Path of the saved image, e.g. ``data/generated_images/<hex>.png``.
    :raises Exception: Anything raised by the pipeline or while saving propagates unchanged.
    """
    os.makedirs(output_dir, exist_ok=True)
    # uuid4 names do not realistically collide. The old randint(1, 100000) names did, which
    # let concurrent generations overwrite each other's files and return duplicate paths.
    image_path = f"{output_dir.rstrip('/')}/{uuid.uuid4().hex}.png"
    try:
        pipe(prompt, num_inference_steps=num_inference_steps).images[0].save(image_path)
    finally:
        # Release cached GPU memory even when generation fails.
        torch.cuda.empty_cache()
    return image_path

FastAPI Routes

In [10]:
@app.get("/")
async def home():
    """Root endpoint returning a fixed greeting, handy as a liveness check."""
    greeting = "Hello, this is a test API"
    return {"message": greeting}

# One lock per pipeline object. Each request runs all pipelines on executor threads, so two
# overlapping requests would otherwise drive the same pipeline (and its scheduler, which
# keeps per-call state) from two threads at once.
_PIPE_LOCKS = {}


def _generate_exclusive(prompt, pipe, lock):
    """Run ``generate_image_sync`` on ``pipe`` while holding that pipeline's lock.

    Executed on an executor thread; it blocks that thread (not the event loop) until any
    other request using the same pipeline has finished.
    """
    with lock:
        return generate_image_sync(prompt, pipe)


@app.post("/generate_image")
async def generate_image(request: Request):
    """
    Generate one image per loaded pipeline for the prompt in the JSON request body.

    Expects a JSON object such as ``{"prompt": "..."}``; a missing prompt falls back to "".

    :return: JSONResponse ``{"images": [path, ...]}`` with one relative file path per
             pipeline (servable under the ``/data`` static mount). Returns 400 for a
             malformed body and 500 if generation fails.
    """
    try:
        data = await request.json()
    except ValueError:  # includes json.JSONDecodeError and bad encodings
        return JSONResponse(status_code=400, content={"error": "Request body must be valid JSON"})
    if not isinstance(data, dict):
        return JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})

    prompt = data.get("prompt", "")
    if not isinstance(prompt, str):
        return JSONResponse(status_code=400, content={"error": "'prompt' must be a string"})
    # NOTE(review): clients may send a "model" field, but every request uses the four
    # pipelines loaded above, so it is not read.

    try:
        pipes = [stability_pipe1, stability_pipe2, stability_pipe3, stability_pipe4]
        # Resolved on the event-loop thread with no await in between, so lock creation is race-free.
        locks = [_PIPE_LOCKS.setdefault(id(pipe), threading.Lock()) for pipe in pipes]

        loop = asyncio.get_running_loop()
        tasks = [
            loop.run_in_executor(None, _generate_exclusive, prompt, pipe, lock)
            for pipe, lock in zip(pipes, locks)
        ]
        image_paths = await asyncio.gather(*tasks)
        print(f"Image paths: {image_paths}")

        response = {"images": image_paths}

        print(f"Response saved with {len(image_paths)} images.")
        return JSONResponse(content=response)
    except Exception as e:
        print(f"Error during image generation: {e}")
        return JSONResponse(status_code=500, content={"error": "Internal Server Error"})

Running the App with Ngrok

In [None]:
# Custom function to run the app, handling event loop manually if needed
def run_app(port=None):
    """
    Expose the FastAPI app through an ngrok tunnel and serve it with uvicorn (blocking).

    The tunnel is closed when the server stops, including after a kernel interrupt.

    :param port: Local port for uvicorn and the tunnel; defaults to ``ngrok_port``.
    """
    port = ngrok_port if port is None else port
    # NOTE(review): 0.0.0.0 listens on all interfaces; 127.0.0.1 suffices if traffic only
    # arrives through the ngrok tunnel.
    config = uvicorn.Config(app, host="0.0.0.0", port=port)
    server = uvicorn.Server(config)

    # Allow nested event loops so the server can start inside the notebook's running loop.
    nest_asyncio.apply()

    tunnel = ngrok.connect(port)
    print(f" * ngrok tunnel \"{tunnel.public_url}\" -> \"http://127.0.0.1:{port}\"")
    try:
        server.run()
    finally:
        # Close the tunnel so re-running this cell does not leave stale tunnels open.
        ngrok.disconnect(tunnel.public_url)


if __name__ == '__main__':
    run_app()