In [None]:
!pip install transformers fastapi uvicorn pyngrok accelerate diffusers imageio-ffmpeg

In [None]:
!ngrok config add-authtoken <your-ngroktoken> # Add your ngrok token here

In [None]:
import base64
import os
import uuid
from io import BytesIO
from typing import Optional

import nest_asyncio
import requests
import torch
import uvicorn
from diffusers import DiffusionPipeline, I2VGenXLPipeline
from diffusers.utils import export_to_video, load_image
from fastapi import FastAPI, HTTPException
from fastapi.responses import FileResponse
from PIL import Image
from pydantic import BaseModel
from pyngrok import ngrok

app = FastAPI()

# Text-to-video model, loaded with fp16 weights. Model-level CPU offload plus
# VAE slicing are enabled to reduce GPU memory use.
text_to_video_pipeline = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
text_to_video_pipeline.enable_model_cpu_offload()
text_to_video_pipeline.enable_vae_slicing()

# Image-to-video model (I2VGen-XL), loaded with fp16 weights.
# Only sequential CPU offload is enabled. Model-level and sequential offload are
# alternative strategies in diffusers, and enabling sequential offload replaces
# any previously installed offload hooks. A preceding enable_model_cpu_offload()
# call was therefore redundant and has been removed. Sequential offload trades
# speed for the lowest GPU memory footprint.
image_to_video_pipeline = I2VGenXLPipeline.from_pretrained("ali-vilab/i2vgen-xl", torch_dtype=torch.float16, variant="fp16")
image_to_video_pipeline.enable_sequential_cpu_offload()
image_to_video_pipeline.vae.enable_slicing()
image_to_video_pipeline.vae.enable_tiling()

# request_id -> generated .mp4 filename, so a repeated request_id returns the
# already-generated video. Lives only in process memory: it is not persisted,
# never evicted, and grows by one entry per new request_id.
processed_requests = {}


class Prompt(BaseModel):
    """Request body accepted by the ``/predict/`` endpoint.

    At least one of ``text`` or ``image`` must be supplied; the endpoint
    rejects requests where both are missing or empty.
    """

    # Text prompt. Optional[...] lets clients omit it or send an explicit JSON null.
    text: Optional[str] = None
    # Base64-encoded image bytes. Optional for text-only requests.
    image: Optional[str] = None
    # Client-chosen key; a repeated id returns the previously generated video.
    request_id: str


@app.get("/")
def root():
    """Liveness probe: report that the inference server is up."""
    status_message = "FastAPI server for video generation inference is running!"
    return dict(message=status_message)


# Artifacts the image-to-video model is steered away from on every call.
_NEGATIVE_PROMPT = "Distorted, discontinuous, Ugly, blurry, low resolution, motionless, static, disfigured, disconnected limbs, Ugly faces, incomplete arms"


def _decode_image(image_base64):
    """Decode a base64-encoded image payload into an RGB PIL image.

    Raises:
        HTTPException: 400 if the payload is not valid base64 or not an
            image that PIL can open.
    """
    try:
        image_bytes = base64.b64decode(image_base64)
        return Image.open(BytesIO(image_bytes)).convert("RGB")
    except (ValueError, OSError) as e:
        # binascii.Error (bad padding) subclasses ValueError; PIL's
        # UnidentifiedImageError and truncated-image errors subclass OSError.
        raise HTTPException(status_code=400, detail=f"Invalid image payload: {e}") from e


def _generate_frames(prompt_text, image):
    """Run the matching pipeline and return the frames of the first video.

    With an ``image``, the image-to-video pipeline is used; ``prompt_text`` may
    be "" there. Without one, the text-to-video pipeline is used.
    """
    if image is None:
        return text_to_video_pipeline(prompt_text).frames[0]
    return image_to_video_pipeline(
        prompt=prompt_text,
        image=image,
        num_inference_steps=50,
        negative_prompt=_NEGATIVE_PROMPT,
        guidance_scale=9.0,
        # Fixed seed so identical inputs give reproducible output.
        generator=torch.Generator(device="cuda").manual_seed(8888),
    ).frames[0]


@app.post("/predict/")
def predict(request: Prompt):
    """Generate an MP4 video from a text and/or image prompt.

    If ``request.request_id`` was already processed and its file still exists,
    the cached video is returned without regenerating.

    Returns:
        FileResponse streaming the generated ``<uuid>.mp4`` file.

    Raises:
        HTTPException: 400 when neither text nor image is given, or the image
            payload cannot be decoded; 500 for any failure during generation.
    """
    try:
        request_id = request.request_id
        prompt_text = request.text or ""
        image_base64 = request.image

        # Log the request without the base64 payload, which can be very large.
        print(
            f"Received payload: request_id={request_id!r}, text={prompt_text!r}, "
            f"image_chars={len(image_base64) if image_base64 else 0}"
        )

        if not prompt_text and not image_base64:
            raise HTTPException(
                status_code=400, detail="At least one of 'text' or 'image' must be provided."
            )

        # Serve the cached result only if its file is still on disk; otherwise
        # FileResponse would fail after the handler returns.
        cached_filename = processed_requests.get(request_id)
        if cached_filename is not None and os.path.isfile(cached_filename):
            return FileResponse(cached_filename, media_type="video/mp4", filename=cached_filename)

        video_filename = f"{uuid.uuid4()}.mp4"
        image = _decode_image(image_base64) if image_base64 else None
        video = _generate_frames(prompt_text, image)

        export_to_video(video, video_filename, fps=10)
        processed_requests[request_id] = video_filename

        return FileResponse(video_filename, media_type="video/mp4", filename=video_filename)

    except HTTPException:
        # Client errors (400) must reach the caller unchanged, not become 500s.
        raise
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}") from e


def start_ngrok(port=8000):
    """Open an ngrok HTTP tunnel to a local port and print its public URL.

    Args:
        port: Local port to expose; must match the port the server listens on.

    Returns:
        The tunnel's public URL.
    """
    tunnel = ngrok.connect(port)
    # Print the bare URL rather than the tunnel object's repr.
    print(f"Public URL: {tunnel.public_url}")
    return tunnel.public_url

def start_server(host="0.0.0.0", port=8000):
    """Serve the FastAPI ``app`` with uvicorn on ``host``:``port``.

    nest_asyncio is applied first so uvicorn can run inside an environment
    that already has an event loop running (e.g. a notebook kernel).

    Args:
        host: Interface to bind; the default listens on all interfaces.
        port: TCP port; should match the port passed to ``start_ngrok``.
    """
    nest_asyncio.apply()
    uvicorn.run(app, host=host, port=port)

if __name__ == "__main__":
    # Open the tunnel first so its public URL is printed before the server
    # starts and takes over the process.
    for banner, launch in (
        ("Starting ngrok...", start_ngrok),
        ("Starting FastAPI server...", start_server),
    ):
        print(banner)
        launch()