Skip to content

Commit

Permalink
Merge pull request #3739 from BerriAI/litellm_add_imagen_support
Browse files Browse the repository at this point in the history
[FEAT] Async VertexAI Image Generation
  • Loading branch information
ishaan-jaff committed May 20, 2024
2 parents 1deb274 + f3eb832 commit 91a89eb
Show file tree
Hide file tree
Showing 7 changed files with 386 additions and 1 deletion.
16 changes: 16 additions & 0 deletions docs/my-website/docs/image_generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,4 +150,20 @@ response = image_generation(
model="bedrock/stability.stable-diffusion-xl-v0",
)
print(f"response: {response}")
```

## VertexAI - Image Generation Models

### Usage

Use this for image generation models on VertexAI

```python
response = litellm.image_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
)
print(f"response: {response}")
```
25 changes: 25 additions & 0 deletions docs/my-website/docs/providers/vertex.md
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,31 @@ All models listed [here](https://github.com/BerriAI/litellm/blob/57f37f743886a02
| text-embedding-preview-0409 | `embedding(model="vertex_ai/text-embedding-preview-0409", input)` |
| text-multilingual-embedding-preview-0409 | `embedding(model="vertex_ai/text-multilingual-embedding-preview-0409", input)` |

## Image Generation Models

### Usage

```python
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
)
```

**Generating multiple images**

Use the `n` parameter to specify how many images you want generated.
```python
response = await litellm.aimage_generation(
prompt="An olympic size swimming pool",
model="vertex_ai/imagegeneration@006",
vertex_ai_project="adroit-crow-413218",
vertex_ai_location="us-central1",
n=1,
)
```

## Extra

Expand Down
3 changes: 3 additions & 0 deletions litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,9 @@ def identify(event_details):
get_supported_openai_params,
get_api_base,
get_first_chars_messages,
ModelResponse,
ImageResponse,
ImageObject,
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
Expand Down
224 changes: 224 additions & 0 deletions litellm/llms/vertex_httpx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
import os, types
import json
from enum import Enum
import requests # type: ignore
import time
from typing import Callable, Optional, Union, List, Any, Tuple
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
import litellm, uuid
import httpx, inspect # type: ignore
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .base import BaseLLM


class VertexAIError(Exception):
    """Exception for failed Vertex AI requests.

    Carries the HTTP status code plus synthetic httpx request/response
    objects so upstream exception mappers can inspect an HTTP context.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Build a placeholder request/response pair; the real request object
        # is not available at the point where this error is raised.
        self.request = httpx.Request(
            method="POST", url=" https://cloud.google.com/vertex-ai/"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        # Initialize the base Exception with the human-readable message.
        super().__init__(self.message)

class VertexLLM(BaseLLM):
    """Vertex AI client using direct HTTP calls (currently: Imagen image generation).

    Resolves Google application-default credentials lazily and caches them on
    the instance for reuse across calls.
    """

    def __init__(self) -> None:
        super().__init__()
        self.access_token: Optional[str] = None
        self.refresh_token: Optional[str] = None
        self._credentials: Optional[Any] = None
        self.project_id: Optional[str] = None
        self.async_handler: Optional[AsyncHTTPHandler] = None

    def load_auth(self) -> Tuple[Any, str]:
        """Resolve Google application-default credentials and the project id.

        Returns:
            Tuple of (credentials, project_id); the credentials are already
            refreshed and hold a valid token.

        Raises:
            ValueError: if no project id could be resolved from the environment.
            TypeError: if the resolved project id is not a string.
        """
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]
        from google.auth.credentials import Credentials  # type: ignore[import-untyped]
        import google.auth as google_auth

        credentials, project_id = google_auth.default(
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )

        credentials.refresh(Request())

        if not project_id:
            raise ValueError("Could not resolve project_id")

        if not isinstance(project_id, str):
            raise TypeError(
                f"Expected project_id to be a str but got {type(project_id)}"
            )

        return credentials, project_id

    def refresh_auth(self, credentials: Any) -> None:
        """Refresh an existing Google credentials object in place."""
        from google.auth.transport.requests import Request  # type: ignore[import-untyped]

        credentials.refresh(Request())

    def _prepare_request(self, request: httpx.Request) -> None:
        """Attach a Bearer token to *request* unless it already has one."""
        access_token = self._ensure_access_token()

        if request.headers.get("Authorization"):
            # already authenticated, nothing for us to do
            return

        request.headers["Authorization"] = f"Bearer {access_token}"

    def _ensure_access_token(self) -> str:
        """Return a valid access token, loading or refreshing credentials as needed.

        Raises:
            RuntimeError: if credentials were loaded but no token is present.
        """
        # NOTE(review): self.access_token is returned when set, but nothing in
        # this class ever assigns it — confirm whether caching was intended.
        if self.access_token is not None:
            return self.access_token

        if not self._credentials:
            self._credentials, project_id = self.load_auth()
            if not self.project_id:
                self.project_id = project_id
        else:
            # Credentials already loaded; refresh in case the token expired.
            self.refresh_auth(self._credentials)

        if not self._credentials.token:
            raise RuntimeError("Could not resolve API token from the environment")

        assert isinstance(self._credentials.token, str)
        return self._credentials.token

    def image_generation(
        self,
        prompt: str,
        vertex_project: str,
        vertex_location: str,
        model: Optional[
            str
        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        logging_obj=None,
        model_response=None,
        aimg_generation=False,
    ):
        """Entry point for Vertex AI image generation.

        When ``aimg_generation`` is truthy, returns the coroutine produced by
        :meth:`aimage_generation` for the caller to await.

        NOTE(review): the synchronous path is not implemented — when
        ``aimg_generation`` is falsy this method falls through and returns
        ``None``. Confirm whether sync support should raise or be added.
        """
        if aimg_generation:
            response = self.aimage_generation(
                prompt=prompt,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                model=model,
                client=client,
                optional_params=optional_params,
                timeout=timeout,
                logging_obj=logging_obj,
                model_response=model_response,
            )
            return response

    async def aimage_generation(
        self,
        prompt: str,
        vertex_project: str,
        vertex_location: str,
        model_response: litellm.ImageResponse,
        model: Optional[
            str
        ] = "imagegeneration",  # vertex ai uses imagegeneration as the default model
        client: Optional[AsyncHTTPHandler] = None,
        optional_params: Optional[dict] = None,
        timeout: Optional[int] = None,
        logging_obj=None,
    ):
        """Call the Vertex AI ``:predict`` endpoint for image generation.

        Populates ``model_response.data`` with one ``litellm.ImageObject``
        (base64-encoded image) per returned prediction.

        Raises:
            VertexAIError: if the HTTP response status is not 200.
        """
        response = None
        if client is None:
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    _httpx_timeout = httpx.Timeout(timeout)
                    _params["timeout"] = _httpx_timeout
            else:
                # No timeout given: default to 10 min total with a 5s connect cap.
                _params["timeout"] = httpx.Timeout(timeout=600.0, connect=5.0)

            self.async_handler = AsyncHTTPHandler(**_params)  # type: ignore
        else:
            self.async_handler = client  # type: ignore

        # make POST request to
        # https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict
        url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:predict"

        """
        Docs link: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
        curl -X POST \
        -H "Authorization: Bearer $(gcloud auth print-access-token)" \
        -H "Content-Type: application/json; charset=utf-8" \
        -d {
            "instances": [
                {
                    "prompt": "a cat"
                }
            ],
            "parameters": {
                "sampleCount": 1
            }
        } \
        "https://us-central1-aiplatform.googleapis.com/v1/projects/PROJECT_ID/locations/us-central1/publishers/google/models/imagegeneration:predict"
        """
        auth_header = self._ensure_access_token()
        optional_params = optional_params or {
            "sampleCount": 1
        }  # default optional params

        request_data = {
            "instances": [{"prompt": prompt}],
            "parameters": optional_params,
        }

        # Log a reproducible curl command with the bearer token masked.
        request_str = f"\n curl -X POST \\\n -H \"Authorization: Bearer {auth_header[:10] + 'XXXXXXXXXX'}\" \\\n -H \"Content-Type: application/json; charset=utf-8\" \\\n -d {request_data} \\\n \"{url}\""
        logging_obj.pre_call(
            input=prompt,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
                "request_str": request_str,
            },
        )

        response = await self.async_handler.post(
            url=url,
            headers={
                "Content-Type": "application/json; charset=utf-8",
                "Authorization": f"Bearer {auth_header}",
            },
            data=json.dumps(request_data),
        )

        if response.status_code != 200:
            # Use the module's error type so the status code survives for
            # upstream exception mapping (bare Exception loses it).
            raise VertexAIError(status_code=response.status_code, message=response.text)
        """
        Vertex AI Image generation response example:
        {
            "predictions": [
                {
                    "bytesBase64Encoded": "BASE64_IMG_BYTES",
                    "mimeType": "image/png"
                },
                {
                    "mimeType": "image/png",
                    "bytesBase64Encoded": "BASE64_IMG_BYTES"
                }
            ]
        }
        """

        _json_response = response.json()
        _predictions = _json_response["predictions"]

        _response_data: List[litellm.ImageObject] = []
        for _prediction in _predictions:
            _bytes_base64_encoded = _prediction["bytesBase64Encoded"]
            image_object = litellm.ImageObject(b64_json=_bytes_base64_encoded)
            _response_data.append(image_object)

        model_response.data = _response_data

        return model_response
32 changes: 32 additions & 0 deletions litellm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
Expand Down Expand Up @@ -118,6 +119,7 @@
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
vertex_chat_completion = VertexLLM()
####### COMPLETION ENDPOINTS ################


Expand Down Expand Up @@ -3854,6 +3856,36 @@ def image_generation(
model_response=model_response,
aimg_generation=aimg_generation,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)
or optional_params.pop("vertex_ai_project", None)
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.pop("vertex_location", None)
or optional_params.pop("vertex_ai_location", None)
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
model_response = vertex_chat_completion.image_generation(
model=model,
prompt=prompt,
timeout=timeout,
logging_obj=litellm_logging_obj,
optional_params=optional_params,
model_response=model_response,
vertex_project=vertex_ai_project,
vertex_location=vertex_ai_location,
aimg_generation=aimg_generation,
)

return model_response
except Exception as e:
## Map to OpenAI Exception
Expand Down
33 changes: 33 additions & 0 deletions litellm/tests/test_image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,36 @@ async def test_aimage_generation_bedrock_with_optional_params():
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")


@pytest.mark.asyncio
async def test_aimage_generation_vertex_ai():
    """Smoke test for async Vertex AI (Imagen) image generation.

    Requires valid Vertex AI credentials. Provider-side rate limits and
    content-policy/safety rejections are tolerated rather than failed.
    """
    from test_amazing_vertex_completion import load_vertex_ai_credentials

    litellm.set_verbose = True

    load_vertex_ai_credentials()
    try:
        response = await litellm.aimage_generation(
            prompt="An olympic size swimming pool",
            model="vertex_ai/imagegeneration@006",
            vertex_ai_project="adroit-crow-413218",
            vertex_ai_location="us-central1",
            n=1,
        )
        # Every prediction should come back as an ImageObject with base64 data.
        assert response.data is not None
        assert len(response.data) > 0

        for d in response.data:
            assert isinstance(d, litellm.ImageObject)
            print("data in response.data", d)
            assert d.b64_json is not None
    except litellm.RateLimitError:
        pass  # provider throttling is not a test failure
    except litellm.ContentPolicyViolationError:
        pass  # Vertex AI randomly raises these errors - skip when they occur
    except Exception as e:
        if "Your task failed as a result of our safety system." in str(e):
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")
Loading

0 comments on commit 91a89eb

Please sign in to comment.