Commit
Merge pull request #3739 from BerriAI/litellm_add_imagen_support
[FEAT] Async VertexAI Image Generation
Showing 7 changed files with 386 additions and 1 deletion.
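Below is a minimal sketch of how the new async path could be exercised through litellm's top-level aimage_generation helper. The "vertex_ai/imagegeneration" model string and the vertex_project / vertex_location keyword arguments are assumptions about how the router forwards requests to the handler added in this PR; they are not taken from the diff itself.

# Hypothetical usage via litellm's public async image-generation helper.
# The "vertex_ai/" model prefix and the vertex_project / vertex_location kwargs
# are assumptions about routing, not values shown in this diff.
import asyncio
import litellm

async def main():
    response = await litellm.aimage_generation(
        prompt="an astronaut riding a horse",
        model="vertex_ai/imagegeneration",   # assumed Vertex AI model prefix
        vertex_project="my-gcp-project",     # placeholder GCP project id
        vertex_location="us-central1",       # placeholder region
    )
    # each entry in response.data is expected to carry a base64-encoded image
    print(f"received {len(response.data)} image(s)")

asyncio.run(main())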
@@ -0,0 +1,224 @@
import os, types
import json
from enum import Enum
import requests  # type: ignore
import time
from typing import Callable, Optional, Union, List, Any, Tuple
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect  # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM


class VertexAIError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs


class VertexLLM(BaseLLM):
    def __init__(self) -> None:
        super().__init__()
        self.access_token: Optional[str] = None
        self.refresh_token: Optional[str] = None
        self._credentials: Optional[Any] = None
        self.project_id: Optional[str] = None
        self.async_handler: Optional[AsyncHTTPHandler] = None

    def load_auth(self) -> Tuple[Any, str]:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
        import google.auth as google_auth

        credentials, project_id = google_auth.default(
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )

        credentials.refresh(Request())

        if not project_id:
            raise ValueError("Could not resolve project_id")

        if not isinstance(project_id, str):
            raise TypeError(
                f"Expected project_id to be a str but got {type(project_id)}"
            )

        return credentials, project_id

    def refresh_auth(self, credentials: Any) -> None:
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]

        credentials.refresh(Request())

    def _prepare_request(self, request: httpx.Request) -> None:
        access_token = self._ensure_access_token()

        if request.headers.get("Authorization"):
            # already authenticated, nothing for us to do
            return

        request.headers["Authorization"] = f"Bearer {access_token}"

    def _ensure_access_token(self) -> str:
        if self.access_token is not None:
            return self.access_token

        if not self._credentials:
            self._credentials, project_id = self.load_auth()
            if not self.project_id:
                self.project_id = project_id
        else:
            self.refresh_auth(self._credentials)

        if not self._credentials.token:
            raise RuntimeError("Could not resolve API token from the environment")

        assert isinstance(self._credentials.token, str)
        return self._credentials.token

    def image_generation(
        self,
        prompt: str,
        vertex_project: str,
        vertex_location: str,
        model: Optional[
            str
        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        logging_obj=None,
        model_response=None,
        aimg_generation=False,
    ):
        if aimg_generation == True:
            response = self.aimage_generation(
                prompt=prompt,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                model=model,
                client=client,
                optional_params=optional_params,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
            )
            return response
        # note: only the async path is implemented here; a sync call falls
        # through and returns None

    async def aimage_generation(
        self,
        prompt: str,
        vertex_project: str,
        vertex_location: str,
        model_response: litellm.ImageResponse,
        model: Optional[
            str
        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        logging_obj=None,
    ):
        response = None
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

            self.async_handler = AsyncHTTPHandler(**_params)  # type: ignore
        else:
            self.async_handler = client  # type: ignore

        # make POST request to
        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

        """
        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
        curl -X POST \
        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
        -H "Content-Type: application/json; charset=utf-8" \
        -d {
            "instances": [
                {
                    "prompt": "a cat"
                }
            ],
            "parameters": {
                "sampleCount": 1
            }
        } \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
        auth_header = self._ensure_access_token()
        optional_params = optional_params or {
            "sampleCount": 1
        }  # default optional params

        request_data = {
            "instances": [{"prompt": prompt}],
            "parameters": optional_params,
        }

        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )

        response = await self.async_handler.post(
            url=url,
            headers={
                "Content-Type": "application/json; charset=utf-8",
                "Authorization": f"Bearer {auth_header}",
            },
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            raise Exception(f"Error: {response.status_code} {response.text}")
        """
        Vertex AI Image generation response example:
        {
            "predictions": [
                {
                    "bytesBase64Encoded": "BASE64_IMG_BYTES",
                    "mimeType": "image/png"
                },
                {
                    "mimeType": "image/png",
                    "bytesBase64Encoded": "BASE64_IMG_BYTES"
                }
            ]
        }
        """

        _json_response = response.json()
        _predictions = _json_response["predictions"]

        _response_data: List[litellm.ImageObject] = []
        for _prediction in _predictions:
            _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
            image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
            _response_data.append(image_object)

        model_response.data = _response_data

        return model_response
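For reference, here is a hedged sketch of driving the handler directly, outside the litellm router. The import path, the no-argument litellm.ImageResponse() constructor, and the _StubLogger helper are assumptions for illustration; aimage_generation only requires an object exposing pre_call(), as the diff above shows. Application Default Credentials must be available (for example via gcloud auth application-default login) for _ensure_access_token() to resolve a token.

# Hedged sketch of calling the new handler directly (bypassing the router).
# The import path and the stub logger are assumptions; aimage_generation only
# needs an object that exposes pre_call(), as shown in the diff above.
import asyncio
import litellm
from litellm.llms.vertex_httpx import VertexLLM  # assumed location of the new class

class _StubLogger:
    def pre_call(self, input, api_key, additional_args=None):
        # the handler logs a masked curl command before sending the request
        print((additional_args or {}).get("request_str", ""))

async def main():
    handler = VertexLLM()
    image_response = await handler.aimage_generation(
        prompt="a watercolor painting of a lighthouse",
        vertex_project="my-gcp-project",         # placeholder GCP project id
        vertex_location="us-central1",           # placeholder region
        model_response=litellm.ImageResponse(),  # assumed no-arg constructor
        optional_params={"sampleCount": 2},      # forwarded verbatim as Imagen "parameters"
        logging_obj=_StubLogger(),
    )
    for img in image_response.data:
        print("base64 payload length:", len(img.b64_json))

asyncio.run(main())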