Skip to content

Commit

Permalink
feat: add support for pathlib.Path objects as generation input
Browse files Browse the repository at this point in the history
  • Loading branch information
HanaokaYuzu committed May 25, 2024
1 parent e978a5f commit b5388e2
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ asyncio.run(main())
### Generate contents from image

Gemini supports image recognition and generating contents from images. Optionally, you can pass images in a list of file data in `bytes` or their paths in `str` to `GeminiClient.generate_content` together with text prompt.
Gemini supports image recognition and generating content from images. Optionally, you can pass images to `GeminiClient.generate_content` together with a text prompt, as a list whose items are either file data in `bytes` or file paths given as `str` or `pathlib.Path`.

```python
async def main():
Expand Down
13 changes: 7 additions & 6 deletions src/gemini_webapi/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import functools
import asyncio
from asyncio import Task
from pathlib import Path
from typing import Any, Optional

from httpx import AsyncClient, ReadTimeout
Expand Down Expand Up @@ -252,7 +253,7 @@ async def start_auto_refresh(self) -> None:
async def generate_content(
self,
prompt: str,
images: list[bytes | str] | None = None,
images: list[bytes | str | Path] | None = None,
chat: Optional["ChatSession"] = None,
) -> ModelOutput:
"""
Expand All @@ -262,8 +263,8 @@ async def generate_content(
----------
prompt: `str`
Prompt provided by user.
images: `list[bytes | str]`, optional
List of image file data in bytes or file paths in string.
images: `list[bytes | str | Path]`, optional
List of image file paths or file data in bytes.
chat: `ChatSession`, optional
Chat data to retrieve conversation history. If None, will automatically generate a new chat id when sending post request.
Expand Down Expand Up @@ -485,7 +486,7 @@ def __setattr__(self, name: str, value: Any) -> None:
self.rcid = value.rcid

async def send_message(
self, prompt: str, images: list[bytes | str] | None = None,
self, prompt: str, images: list[bytes | str | Path] | None = None,
) -> ModelOutput:
"""
Generates contents with prompt.
Expand All @@ -495,8 +496,8 @@ async def send_message(
----------
prompt: `str`
Prompt provided by user.
images: `list[bytes | str]`, optional
List of image file data in bytes or file paths in string.
images: `list[bytes | str | Path]`, optional
List of image file paths or file data in bytes.
Returns
-------
Expand Down
8 changes: 5 additions & 3 deletions src/gemini_webapi/utils/upload_file.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from pathlib import Path

from httpx import AsyncClient
from pydantic import validate_call

from ..constants import Endpoint, Headers


@validate_call
async def upload_file(file: bytes | str, proxies: dict | None = None) -> str:
async def upload_file(file: bytes | str | Path, proxies: dict | None = None) -> str:
"""
Upload a file to Google's server and return its identifier.
Parameters
----------
file : `bytes` | `str`
file : `bytes` | `str` | `Path`
File data in bytes, or path to the file to be uploaded.
proxies: `dict`, optional
Dict of proxies.
Expand All @@ -28,7 +30,7 @@ async def upload_file(file: bytes | str, proxies: dict | None = None) -> str:
If the upload request failed.
"""

if isinstance(file, str):
if not isinstance(file, bytes):
with open(file, "rb") as f:
file = f.read()

Expand Down
3 changes: 2 additions & 1 deletion tests/test_client_features.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import unittest
import logging
from pathlib import Path

from loguru import logger

Expand Down Expand Up @@ -29,7 +30,7 @@ async def test_successful_request(self):
@logger.catch(reraise=True)
async def test_upload_image(self):
response = await self.geminiclient.generate_content(
"Describe the image", images=["assets/banner.png"]
"Describe these images", images=[Path("assets/banner.png"), "assets/favicon.png"]
)
self.assertTrue(response.text)
logger.debug(response.text)
Expand Down

0 comments on commit b5388e2

Please sign in to comment.