[FEAT] Async VertexAI Image Generation #3739

Merged · 17 commits · May 20, 2024
16 changes: 16 additions & 0 deletions docs/my-website/docs/image_generation.md
@@ -150,4 +150,20 @@ response = image_generation(
model="bedrock/stability.stable-diffusion-xl-v0",
)
print(f"response: {response}")
```

## VertexAI - Image Generation Models

### Usage

Use this for image generation models hosted on VertexAI.

```python
response = litellm.image_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
)
print(f"response: {response}")
```
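
The response follows OpenAI's image-generation response format, so each entry in `response.data` carries the image as base64. A minimal sketch for saving the first image to disk (the filename is a placeholder, and at least one prediction is assumed):

```python
import base64

# decode the base64 payload returned in response.data
# (each entry is a litellm.ImageObject with a `b64_json` field)
image_bytes = base64.b64decode(response.data[0].b64_json)
with open("pool.png", "wb") as f:
    f.write(image_bytes)
```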
25 changes: 25 additions & 0 deletions docs/my-website/docs/providers/vertex.md
@@ -508,6 +508,31 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |

## Image Generation Models

### Usage

```python
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
)
```

**Generating multiple images**

Use the `n` parameter to specify how many images you want generated:
```python
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
n=1,
)
```
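
Per the routing logic added in `litellm/main.py` in this PR, the project and location can also be picked up from the `VERTEXAI_PROJECT` / `VERTEXAI_LOCATION` environment variables (or the `litellm.vertex_project` / `litellm.vertex_location` module settings) instead of per-call kwargs. A sketch, assuming application-default credentials are already configured:

```python
import os

# configure project/location via environment instead of per-call kwargs
os.environ["VERTEXAI_PROJECT"] = "adroit-crow-413218"
os.environ["VERTEXAI_LOCATION"] = "us-central1"

response = await litellm.aimage_generation(
    prompt="An olympic size swimming pool",
    model="vertex_ai/imagegeneration@006",
)
```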

## Extra

Expand Down
3 changes: 3 additions & 0 deletions litellm/__init__.py
@@ -724,6 +724,9 @@ def identify(event_details):
get_supported_openai_params,
get_api_base,
get_first_chars_messages,
ModelResponse,
ImageResponse,
ImageObject,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
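
These re-exports make the response types importable from the package root. A small sketch of building one by hand, mirroring how the Vertex handler and tests in this diff use them:

```python
import litellm

# assemble an ImageResponse the way the Vertex handler does
image = litellm.ImageObject(b64_json="BASE64_IMG_BYTES")
response = litellm.ImageResponse()
response.data = [image]

assert isinstance(response.data[0], litellm.ImageObject)
```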
224 changes: 224 additions & 0 deletions litellm/llms/vertex_httpx.py
@@ -0,0 +1,224 @@
import os, types
import json
from enum import Enum
import requests # type: ignore
import time
from typing import Callable, Optional, Union, List, Any, Tuple
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM


class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
        self.request = httpx.Request(
            method="POST", url="https://cloud.google.com/vertex-ai/"
        )
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs


class VertexLLM(BaseLLM):
def __init__(self) -> None:
super().__init__()
self.access_token: Optional[str] = None
self.refresh_token: Optional[str] = None
self._credentials: Optional[Any] = None
self.project_id: Optional[str] = None
self.async_handler: Optional[AsyncHTTPHandler] = None

def load_auth(self) -> Tuple[Any, str]:
from google.auth.transport.requests import Request # type: ignore[import-untyped]
from google.auth.credentials import Credentials # type: ignore[import-untyped]
import google.auth as google_auth

credentials, project_id = google_auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)

credentials.refresh(Request())

if not project_id:
raise ValueError("Could not resolve project_id")

if not isinstance(project_id, str):
raise TypeError(
f"Expected project_id to be a str but got {type(project_id)}"
)

return credentials, project_id

def refresh_auth(self, credentials: Any) -> None:
from google.auth.transport.requests import Request # type: ignore[import-untyped]

credentials.refresh(Request())

def _prepare_request(self, request: httpx.Request) -> None:
access_token = self._ensure_access_token()

if request.headers.get("Authorization"):
# already authenticated, nothing for us to do
return

request.headers["Authorization"] = f"Bearer {access_token}"

    def _ensure_access_token(self) -> str:
        if self.access_token is not None:
            return self.access_token

        if not self._credentials:
            # first call: load application-default credentials and cache the project id
            self._credentials, project_id = self.load_auth()
            if not self.project_id:
                self.project_id = project_id
        else:
            # later calls: refresh the cached credentials in place
            self.refresh_auth(self._credentials)

if not self._credentials.token:
raise RuntimeError("Could not resolve API token from the environment")

assert isinstance(self._credentials.token, str)
return self._credentials.token

def image_generation(
self,
prompt: str,
vertex_project: str,
vertex_location: str,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
model_response=None,
aimg_generation=False,
):
        if aimg_generation:  # only the async path is implemented in this PR
response = self.aimage_generation(
prompt=prompt,
vertex_project=vertex_project,
vertex_location=vertex_location,
model=model,
client=client,
optional_params=optional_params,
timeout=timeout,
logging_obj=logging_obj,
model_response=model_response,
)
return response

async def aimage_generation(
self,
prompt: str,
vertex_project: str,
vertex_location: str,
model_response: litellm.ImageResponse,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
):
response = None
if client is None:
_params = {}
if timeout is not None:
                if isinstance(timeout, (int, float)):
_httpx_timeout = httpx.Timeout(timeout)
_params["timeout"] = _httpx_timeout
else:
_params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

self.async_handler = AsyncHTTPHandler(**_params) # type: ignore
else:
self.async_handler = client # type: ignore

# make POST request to
# https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

"""
Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
curl -X POST \
-H "Authorization: Bearer $(gcloud auth print-access-token)" \
-H "Content-Type: application/json; charset=utf-8" \
        -d '{
          "instances": [
            {
              "prompt": "a cat"
            }
          ],
          "parameters": {
            "sampleCount": 1
          }
        }' \
"https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
"""
auth_header = self._ensure_access_token()
optional_params = optional_params or {
"sampleCount": 1
} # default optional params

request_data = {
"instances": [{"prompt": prompt}],
"parameters": optional_params,
}

request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
logging_obj.pre_call(
input=prompt,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
"request_str": request_str,
},
)

response = await self.async_handler.post(
url=url,
headers={
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {auth_header}",
},
data=json.dumps(request_data),
)

        if response.status_code != 200:
            raise VertexAIError(status_code=response.status_code, message=response.text)
"""
Vertex AI Image generation response example:
{
"predictions": [
{
"bytesBase64Encoded": "BASE64_IMG_BYTES",
"mimeType": "image/png"
},
{
"mimeType": "image/png",
"bytesBase64Encoded": "BASE64_IMG_BYTES"
}
]
}
"""

_json_response = response.json()
_predictions = _json_response["predictions"]

_response_data: List[litellm.ImageObject] = []
for _prediction in _predictions:
_bytes_base64_encoded = _prediction["bytesBase64Encoded"]
image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
_response_data.append(image_object)

model_response.data = _response_data

return model_response
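
For reference, a hedged sketch of driving the new handler directly rather than through `litellm.main` routing. The `_NoopLogger` stub and the project id are hypothetical stand-ins; some logging object is required because `aimage_generation` calls `logging_obj.pre_call`, and application-default credentials must be configured locally:

```python
import asyncio

import litellm
from litellm.llms.vertex_httpx import VertexLLM


class _NoopLogger:
    # minimal stand-in for litellm's logging object
    def pre_call(self, input, api_key, additional_args):
        pass


async def main():
    handler = VertexLLM()
    response = await handler.aimage_generation(
        prompt="An olympic size swimming pool",
        vertex_project="adroit-crow-413218",  # placeholder project id
        vertex_location="us-central1",
        model_response=litellm.ImageResponse(),
        model="imagegeneration@006",
        logging_obj=_NoopLogger(),
    )
    print(f"{len(response.data)} image(s) returned")


asyncio.run(main())
```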
32 changes: 32 additions & 0 deletions litellm/main.py
@@ -79,6 +79,7 @@
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
@@ -118,6 +119,7 @@
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################


@@ -3854,6 +3856,36 @@ def image_generation(
model_response=model_response,
aimg_generation=aimg_generation,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
model_response = vertex_chat_completion.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
aimg_generation=aimg_generation,
)

return model_response
except Exception as e:
## Map to OpenAI Exception
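
Given the fallback chain above, the project and location can also be configured once at module level instead of on every call. A short sketch (assuming the `VERTEXAI_*` environment variables are otherwise unset):

```python
import asyncio

import litellm

# module-level settings picked up by the fallback chain in image_generation()
litellm.vertex_project = "adroit-crow-413218"
litellm.vertex_location = "us-central1"


async def main():
    response = await litellm.aimage_generation(
        prompt="An olympic size swimming pool",
        model="vertex_ai/imagegeneration@006",
    )
    print(f"response: {response}")


asyncio.run(main())
```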
33 changes: 33 additions & 0 deletions litellm/tests/test_image_generation.py
@@ -169,3 +169,36 @@ async def test_aimage_generation_bedrock_with_optional_params():
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")


@pytest.mark.asyncio
async def test_aimage_generation_vertex_ai():
from test_amazing_vertex_completion import load_vertex_ai_credentials

litellm.set_verbose = True

load_vertex_ai_credentials()
try:
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
n=1,
)
assert response.data is not None
assert len(response.data) > 0

for d in response.data:
assert isinstance(d, litellm.ImageObject)
print("data in response.data", d)
assert d.b64_json is not None
    except litellm.RateLimitError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # Vertex AI occasionally raises these errors - skip when they occur
except Exception as e:
if "Your task failed as a result of our safety system." in str(e):
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")