feat(gemini.py): support google-genai system instruction #2925

Merged · 1 commit · Apr 10, 2024

litellm/llms/gemini.py (22 additions, 8 deletions)

```diff
@@ -6,7 +6,8 @@
 from litellm.utils import ModelResponse, get_secret, Choices, Message, Usage
 import litellm
 import sys, httpx
-from .prompt_templates.factory import prompt_factory, custom_prompt
+from .prompt_templates.factory import prompt_factory, custom_prompt, get_system_prompt
+from packaging.version import Version
 
 
 class GeminiError(Exception):
@@ -103,6 +104,13 @@ async def __aiter__(self):
                 break
 
 
+def supports_system_instruction():
+    import google.generativeai as genai
+
+    gemini_pkg_version = Version(genai.__version__)
+    return gemini_pkg_version >= Version("0.5.0")
+
+
 def completion(
     model: str,
     messages: list,
@@ -124,7 +132,7 @@ def completion(
             "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
         )
     genai.configure(api_key=api_key)
-
+    system_prompt = ""
     if model in custom_prompt_dict:
         # check if the model has a registered custom prompt
         model_prompt_details = custom_prompt_dict[model]
@@ -135,6 +143,7 @@ def completion(
             messages=messages,
         )
     else:
+        system_prompt, messages = get_system_prompt(messages=messages)
         prompt = prompt_factory(
             model=model, messages=messages, custom_llm_provider="gemini"
         )
@@ -166,7 +175,11 @@ def completion(
     )
     ## COMPLETION CALL
     try:
-        _model = genai.GenerativeModel(f"models/{model}")
+        _params = {"model_name": "models/{}".format(model)}
+        _system_instruction = supports_system_instruction()
+        if _system_instruction and len(system_prompt) > 0:
+            _params["system_instruction"] = system_prompt
+        _model = genai.GenerativeModel(**_params)
         if stream == True:
             if acompletion == True:
```
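For reference, the branch above resolves to a plain google-generativeai call like the following once the installed package is 0.5.0 or newer. This is a sketch of the effect, not code from the PR; the API key, model name, and prompts are placeholders:

```python
import google.generativeai as genai

genai.configure(api_key="YOUR_API_KEY")  # placeholder key

# What the new code path effectively constructs when a system prompt is present:
model = genai.GenerativeModel(
    model_name="models/gemini-pro",
    system_instruction="You are a terse assistant.",
)
response = model.generate_content(contents="Explain HTTP caching in one sentence.")
print(response.text)
```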

```diff
@@ -213,11 +226,12 @@ async def async_streaming():
                 encoding=encoding,
             )
         else:
-            response = _model.generate_content(
-                contents=prompt,
-                generation_config=genai.types.GenerationConfig(**inference_params),
-                safety_settings=safety_settings,
-            )
+            params = {
+                "contents": prompt,
+                "generation_config": genai.types.GenerationConfig(**inference_params),
+                "safety_settings": safety_settings,
+            }
+            response = _model.generate_content(**params)
     except Exception as e:
         raise GeminiError(
             message=str(e),
```
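Taken together, the gemini.py changes mean a `system` message sent through litellm's Gemini route is forwarded as a native system instruction rather than folded into the prompt text, whenever the installed SDK supports it. A minimal usage sketch, assuming a configured Gemini API key; the model alias and messages are illustrative:

```python
import litellm

response = litellm.completion(
    model="gemini/gemini-pro",
    messages=[
        {"role": "system", "content": "You are a terse assistant."},
        {"role": "user", "content": "Summarize HTTP caching in one sentence."},
    ],
)
print(response.choices[0].message.content)
```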
litellm/llms/prompt_templates/factory.py (14 additions, 1 deletion)

```diff
@@ -959,7 +959,20 @@ def parse_xml_params(xml_content, json_schema: Optional[dict] = None):
     return params
 
 
-###
+### GEMINI HELPER FUNCTIONS ###
+
+
+def get_system_prompt(messages):
+    system_prompt_indices = []
+    system_prompt = ""
+    for idx, message in enumerate(messages):
+        if message["role"] == "system":
+            system_prompt += message["content"]
+            system_prompt_indices.append(idx)
+    if len(system_prompt_indices) > 0:
+        for idx in reversed(system_prompt_indices):
+            messages.pop(idx)
+    return system_prompt, messages
 
 
 def convert_openai_message_to_cohere_tool_result(message):
```
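A quick illustration of the helper's contract: it concatenates the content of every `system` message, removes those messages from the list in place (popping in reverse so earlier indices stay valid), and returns both values. The sample messages here are made up:

```python
messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
    {"role": "system", "content": " Be brief."},
]

system_prompt, messages = get_system_prompt(messages=messages)

assert system_prompt == "You are helpful. Be brief."
assert messages == [{"role": "user", "content": "Hi"}]
```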
litellm/llms/vertex_ai.py (1 addition, 1 deletion)

```diff
@@ -417,7 +417,7 @@ def completion(
         return async_completion(**data)
 
     if mode == "vision":
-        print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
+        print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
         print_verbose(f"\nProcessing input messages = {messages}")
         tools = optional_params.pop("tools", None)
         prompt, images = _gemini_vision_convert_messages(messages=messages)
```