# In [None]:
import logging
import os
from typing import List

import requests
from dotenv import load_dotenv
from fastapi import File, Form, HTTPException, UploadFile
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
 
# Load variables from a .env file into the process environment so the
# lookups below can see them.
load_dotenv()

# --- Configuration ---------------------------------------------------------
# VILA inference endpoint; receives the prompt plus media asset references.
INVOKE_URL = "https://ai.api.nvidia.com/v1/vlm/nvidia/vila"

# Bearer token sent with every NVIDIA API call. Refuse to import without it
# rather than failing later on the first request.
API_KEY = os.environ.get("API_KEY", "")
if API_KEY == "":
    raise RuntimeError("API_KEY is missing! Set it in your .env file or environment variables.")

# Asset-staging endpoint: media is uploaded here first, then referenced by ID.
NVCF_ASSET_URL = "https://api.nvcf.nvidia.com/v2/nvcf/assets"

# extension (lower-case, no dot) -> [MIME type, prompt tag name].
# The tag ("img"/"video") is used as the element name when a staged file is
# referenced inside the prompt text.
# NOTE(review): "image/jpg" is not a registered MIME type ("image/jpeg" is);
# kept unchanged in case the NVCF asset service expects it -- confirm.
SUPPORTED_LIST = dict(
    png=["image/png", "img"],
    jpg=["image/jpg", "img"],
    jpeg=["image/jpeg", "img"],
    mp4=["video/mp4", "video"],
)
 
# Helper Functions
def get_extension(filename: str) -> str:
    """Return the lower-case extension of *filename* without the leading dot.

    An empty string is returned when the name has no extension, or when
    *filename* is empty or ``None`` (an uploaded part may arrive without a
    filename). Callers therefore hit their "unsupported format" check
    instead of a ``TypeError`` from ``os.path.splitext``.

    >>> get_extension("clip.MP4")
    'mp4'
    """
    if not filename:
        return ""
    _, ext = os.path.splitext(filename)
    # splitext keeps the dot (".png"); strip it so keys match SUPPORTED_LIST.
    return ext[1:].lower()
 
def mime_type(ext: str) -> str:
    """Return the MIME type registered for extension *ext*.

    Raises KeyError when *ext* is not a key of SUPPORTED_LIST.
    """
    entry = SUPPORTED_LIST[ext]
    return entry[0]
 
def media_type(ext: str) -> str:
    """Return the prompt tag name ("img" or "video") for extension *ext*.

    Raises KeyError when *ext* is not a key of SUPPORTED_LIST.
    """
    entry = SUPPORTED_LIST[ext]
    return entry[1]
 
def upload_asset(media_file: UploadFile, description: str) -> str:
    """Stage *media_file* with the NVCF asset service and return its asset ID.

    Two HTTP calls are made:
      1. POST to NVCF_ASSET_URL to create the asset. The JSON reply carries
         an ``uploadUrl`` and an ``assetId``.
      2. PUT the file's bytes to that ``uploadUrl``.

    Raises:
        ValueError: the file's extension is not in SUPPORTED_LIST.
        requests.HTTPError: either call returned an error status.
    """
    extension = get_extension(media_file.filename)
    if extension not in SUPPORTED_LIST:
        raise ValueError(f"Unsupported file format: {extension}")
    content_type = mime_type(extension)

    # Step 1: create the asset and obtain where to send the bytes.
    create_resp = requests.post(
        NVCF_ASSET_URL,
        headers={
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
        json={"contentType": content_type, "description": description},
        timeout=30,
    )
    create_resp.raise_for_status()
    grant = create_resp.json()
    upload_url, asset_id = grant["uploadUrl"], grant["assetId"]

    # Step 2: send the raw bytes. The description travels as an
    # x-amz-meta-* header, which suggests an S3-style upload URL.
    put_resp = requests.put(
        upload_url,
        data=media_file.file.read(),
        headers={
            "x-amz-meta-nvcf-asset-description": description,
            "content-type": content_type,
        },
        timeout=300,  # generous: media (especially video) can be large
    )
    put_resp.raise_for_status()
    return asset_id
 
def delete_asset(asset_id: str):
    """Delete the staged NVCF asset *asset_id*.

    Raises requests.HTTPError if the service answers with an error status.
    """
    target = f"{NVCF_ASSET_URL}/{asset_id}"
    auth = {"Authorization": f"Bearer {API_KEY}"}
    requests.delete(target, headers=auth, timeout=30).raise_for_status()
 
def chat_with_media_vila(
    query: str,
    files: List[UploadFile],
    stream: bool = False,
    max_tokens: int = 1024,
    temperature: float = 0.2,
    top_p: float = 0.7,
    seed: int = 50,
    num_frames_per_inference: int = 8,
    request_timeout: float = 300.0,
):
    """Send *query* plus the media *files* to NVIDIA's VILA model.

    Every file is validated first; nothing is uploaded for an invalid request.
    Each file is then staged as an NVCF asset and referenced from the prompt
    with an inline ``<img .../>`` / ``<video .../>`` tag. All staged assets
    are deleted again once the call finishes, whether it succeeded or not.

    Args:
        query: User prompt text.
        files: Media to attach. Extensions must be keys of SUPPORTED_LIST, and
            a video must be the only file.
        stream: Ask for a ``text/event-stream`` response instead of one JSON
            body.
        max_tokens, temperature, top_p, seed, num_frames_per_inference:
            Copied unchanged into the request payload.
        request_timeout: Seconds ``requests`` waits on the inference call.
            Without it the call could block indefinitely.

    Returns:
        Non-streaming: the decoded JSON response body.
        Streaming: ``{"stream": [...]}`` with every response line decoded to
        ``str`` so the result is JSON-serialisable.
        On any failure: ``{"error": "<message>"}``. Exceptions are not
        propagated to the caller.
    """
    asset_list: List[str] = []
    try:
        extensions = _validate_media_files(files)

        media_tags = []
        for media_file, ext in zip(files, extensions):
            asset_id = upload_asset(media_file, "Reference media file")
            asset_list.append(asset_id)
            media_tags.append(
                f'<{media_type(ext)} src="data:{mime_type(ext)};asset_id,{asset_id}" />'
            )
        media_content = "".join(media_tags)

        asset_seq = ",".join(asset_list)
        headers = {
            "Authorization": f"Bearer {API_KEY}",
            "Content-Type": "application/json",
            "NVCF-INPUT-ASSET-REFERENCES": asset_seq,
            "NVCF-FUNCTION-ASSET-IDS": asset_seq,
            "Accept": "text/event-stream" if stream else "application/json",
        }

        messages = [{"role": "user", "content": f"{query} {media_content}"}]
        payload = {
            "max_tokens": max_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "seed": seed,
            "num_frames_per_inference": num_frames_per_inference,
            "messages": messages,
            "stream": stream,
            "model": "nvidia/vila",
        }

        # `with` releases the connection even if reading the body fails.
        with requests.post(
            INVOKE_URL,
            headers=headers,
            json=payload,
            stream=stream,
            timeout=request_timeout,
        ) as response:
            if stream:
                # iter_lines yields bytes; decode so callers can JSON-encode it.
                # The list is built here, before `finally` deletes the assets.
                lines = [
                    line.decode("utf-8", errors="replace")
                    for line in response.iter_lines()
                ]
                return {"stream": lines}
            return response.json()

    except Exception as e:
        # Deliberate: failures are reported to the caller as data.
        return {"error": str(e)}

    finally:
        _delete_assets_best_effort(asset_list)


def _validate_media_files(files: List[UploadFile]) -> List[str]:
    """Return each file's extension, raising ValueError for invalid input."""
    extensions = []
    for media_file in files:
        ext = get_extension(media_file.filename)
        if ext not in SUPPORTED_LIST:
            raise ValueError(f"Unsupported file format: {ext}")
        if media_type(ext) == "video" and len(files) > 1:
            raise ValueError("Only a single video file is supported.")
        extensions.append(ext)
    return extensions


def _delete_assets_best_effort(asset_ids: List[str]) -> None:
    """Delete every staged asset, logging (never raising) individual failures.

    This runs inside a ``finally`` block. An exception here would replace the
    real result, so it is logged instead. Leaked assets stay visible in logs.
    """
    for asset_id in asset_ids:
        try:
            delete_asset(asset_id)
        except Exception:  # best-effort: keep deleting the remaining assets
            logging.getLogger(__name__).warning(
                "Failed to delete NVCF asset %s", asset_id, exc_info=True
            )

# In [None]:
@app.post("/nvidia-vila/")
def video_text(
    query: str = Form(...),
    files: List[UploadFile] = File(...),
    stream: bool = Form(False),
    max_tokens: int = Form(1024),
    temperature: float = Form(0.2),
    top_p: float = Form(0.7),
    seed: int = Form(50),
    num_frames_per_inference: int = Form(8),
):
    """Multipart endpoint: forward a query and media files to VILA.

    All form fields are passed straight through to ``chat_with_media_vila``,
    and its result is returned as a JSON response.

    NOTE(review): ``chat_with_media_vila`` reports its own failures as an
    ``{"error": ...}`` dict, which this endpoint returns with HTTP 200.
    Only exceptions that escape it become a 500. Consider mapping error
    results to a 4xx/5xx status.
    """
    try:
        result = chat_with_media_vila(
            query=query,
            files=files,
            stream=stream,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            seed=seed,
            num_frames_per_inference=num_frames_per_inference,
        )
        # JSONResponse uses json.dumps, which rejects bytes (e.g. raw lines
        # from a streamed response). jsonable_encoder converts such values
        # to JSON-compatible types first.
        return JSONResponse(content=jsonable_encoder(result))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Internal server error: {e}") from e