Skip to content

Commit

Permalink
feat: use fallback text if generation result contains card content only
Browse files Browse the repository at this point in the history
close #17
  • Loading branch information
HanaokaYuzu committed May 30, 2024
1 parent 2fab029 commit fa54d00
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 13 deletions.
17 changes: 12 additions & 5 deletions src/gemini_webapi/client.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import json
import functools
import asyncio
import functools
import json
import re
from asyncio import Task
from pathlib import Path
from typing import Any, Optional

from httpx import AsyncClient, ReadTimeout

from .types import WebImage, GeneratedImage, Candidate, ModelOutput
from .exceptions import AuthError, APIError, TimeoutError, GeminiError
from .constants import Endpoint, Headers
from .exceptions import AuthError, APIError, TimeoutError, GeminiError
from .types import WebImage, GeneratedImage, Candidate, ModelOutput
from .utils import (
upload_file,
rotate_1psidts,
Expand Down Expand Up @@ -359,6 +360,10 @@ async def generate_content(
try:
candidates = []
for candidate in body[4]:
text = candidate[1][0]
if re.match(r"^http://googleusercontent.com/card_content/\d+$", text):
text = candidate[22] and candidate[22][0] or text

web_images = (
candidate[12]
and candidate[12][1]
Expand All @@ -373,6 +378,7 @@ async def generate_content(
]
or []
)

generated_images = (
candidate[12]
and candidate[12][7]
Expand All @@ -391,10 +397,11 @@ async def generate_content(
]
or []
)

candidates.append(
Candidate(
rcid=candidate[0],
text=candidate[1][0],
text=text,
web_images=web_images,
generated_images=generated_images,
)
Expand Down
16 changes: 8 additions & 8 deletions tests/test_client_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,14 @@ async def test_upload_image(self):
response = await self.geminiclient.generate_content(
"Describe these images", images=[Path("assets/banner.png"), "assets/favicon.png"]
)
self.assertTrue(response.text)
logger.debug(response.text)

@logger.catch(reraise=True)
async def test_continuous_conversation(self):
chat = self.geminiclient.start_chat()
response1 = await chat.send_message("Briefly introduce Europe")
self.assertTrue(response1.text)
logger.debug(response1.text)
response2 = await chat.send_message("What's the population there?")
self.assertTrue(response2.text)
logger.debug(response2.text)

@logger.catch(reraise=True)
Expand All @@ -61,12 +58,10 @@ async def test_chatsession_with_image(self):
chat = self.geminiclient.start_chat()
response1 = await chat.send_message(
"What's the difference between these two images?",
images=["assets/pic1.png", "assets/pic2.png"],
images=["assets/banner.png", "assets/favicon.png"],
)
self.assertTrue(response1.text)
logger.debug(response1.text)
response2 = await chat.send_message("Tell me more.")
self.assertTrue(response2.text)
logger.debug(response2.text)

@logger.catch(reraise=True)
Expand All @@ -89,20 +84,25 @@ async def test_ai_image_generation(self):
self.assertTrue(image.url)
logger.debug(image)

@logger.catch(reraise=True)
async def test_card_content(self):
response = await self.geminiclient.generate_content(
"How is today's weather?"
)
logger.debug(response.text)

@logger.catch(reraise=True)
async def test_extension_google_workspace(self):
response = await self.geminiclient.generate_content(
"@Gmail What's the latest message in my mailbox?"
)
self.assertTrue(response.text)
logger.debug(response)

@logger.catch(reraise=True)
async def test_extension_youtube(self):
response = await self.geminiclient.generate_content(
"@Youtube What's the lastest activity of Taylor Swift?"
)
self.assertTrue(response.text)
logger.debug(response)

@logger.catch(reraise=True)
Expand Down

0 comments on commit fa54d00

Please sign in to comment.