Skip to content

Commit

Permalink
chore: fix basic functionalities, add basic unit test, update build info
Browse files Browse the repository at this point in the history
  • Loading branch information
HanaokaYuzu committed Feb 11, 2024
1 parent 24d5efb commit b35d790
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 51 deletions.
11 changes: 11 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"python.testing.unittestArgs": [
"-v",
"-s",
"./tests",
"-p",
"test_*.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true
}
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ classifiers = [
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
requires-python = ">=3.7"
dependencies = [
Expand Down
138 changes: 100 additions & 38 deletions src/gemini/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,29 @@
from loguru import logger

from .consts import HEADERS
from .utils import running
from .types import Image, Candidate, ModelOutput


def running(func) -> callable:
"""
Decorator to check if client is running before making a request.
"""

async def wrapper(self: "GeminiClient", *args, **kwargs):
if not self.running:
await self.init(auto_close=self.auto_close, close_delay=self.close_delay)
if self.running:
return await func(self, *args, **kwargs)

raise Exception(
f"Invalid function call: GeminiClient.{func.__name__}. Client initialization failed."
)
else:
return await func(self, *args, **kwargs)

return wrapper


class GeminiClient:
"""
Async httpx client interface for gemini.google.com
Expand All @@ -26,56 +45,97 @@ class GeminiClient:
Dict of proxies
"""

__slots__ = ["running", "posttoken", "close_task", "client"]
__slots__ = [
"cookies",
"proxy",
"client",
"access_token",
"running",
"auto_close",
"close_delay",
"close_task",
]

def __init__(
self,
secure_1psid: str,
secure_1psidts: Optional[str] = None,
proxy: Optional[dict] = None,
):
self.cookies = {
"__Secure-1PSID": secure_1psid,
"__Secure-1PSIDTS": secure_1psidts,
}
self.proxy = proxy
self.client: AsyncClient | None = None
self.access_token: Optional[str] = None
self.running: bool = False
self.posttoken: Optional[str] = None
self.close_task: Optional[Task] = None
self.client: AsyncClient = AsyncClient(
timeout=20,
proxies=proxy,
follow_redirects=True,
headers=HEADERS,
cookies={
"__Secure-1PSID": secure_1psid,
"__Secure-1PSIDTS": secure_1psidts,
},
)
self.auto_close: bool = False
self.close_delay: int = 0
self.close_task: Task | None = None

async def init(self) -> None:
"""
Get SNlM0e value as posting token. Without this token posting will fail with 400 bad request.
async def init(
self, timeout: float = 30, auto_close: bool = False, close_delay: int = 300
) -> None:
"""
async with self.client:
response = await self.client.get("https://gemini.google.com/chat")
Get SNlM0e value as access token. Without this token posting will fail with 400 bad request.
if response.status_code != 200:
raise Exception(
f"Failed to initiate client. Request failed with status code {response.status_code}"
Parameters
----------
timeout: `int`, optional
Request timeout of the client in seconds. Used to limit the max waiting time when sending a request
auto_close: `bool`, optional
If `True`, the client will close connections and clear resource usage after a certain period
of inactivity. Useful for keep-alive services
close_delay: `int`, optional
Time to wait before auto-closing the client in seconds. Effective only if `auto_close` is `True`
"""
try:
self.client = AsyncClient(
timeout=timeout,
proxies=self.proxy,
follow_redirects=True,
headers=HEADERS,
cookies=self.cookies,
)
else:
match = re.search(r'"SNlM0e":"(.*?)"', response.text)
if match:
self.posttoken = match.group(1)
self.running = True
logger.success("Gemini client initiated successfully.")
else:

response = await self.client.get("https://gemini.google.com/app")

if response.status_code != 200:
raise Exception(
"Failed to initiate client. SNlM0e not found in response, make sure cookie values are valid."
f"Failed to initiate client. Request failed with status code {response.status_code}"
)

async def close_client(self, timeout=300) -> None:
else:
match = re.search(r'"SNlM0e":"(.*?)"', response.text)
if match:
self.access_token = match.group(1)
self.running = True
logger.success("Gemini client initiated successfully.")
else:
raise Exception(
"Failed to initiate client. SNlM0e not found in response, make sure cookie values are valid."
)

self.auto_close = auto_close
self.close_delay = close_delay
if self.auto_close:
await self.reset_close_task()
except Exception:
await self.close(0)
raise

async def close(self, wait: int | None = None) -> None:
"""
Close the client after a certain period of inactivity.
Close the client after a certain period of inactivity, or call manually to close immediately.
Parameters
----------
wait: `int`, optional
Time to wait before closing the client in seconds
"""
await asyncio.sleep(timeout)
await asyncio.sleep(wait is not None and wait or self.close_delay)
await self.client.aclose()
self.running = False

async def reset_close_task(self) -> None:
"""
Expand All @@ -84,7 +144,7 @@ async def reset_close_task(self) -> None:
if self.close_task:
self.close_task.cancel()
self.close_task = None
self.close_task = asyncio.create_task(self.close_client())
self.close_task = asyncio.create_task(self.close())

@running
async def generate_content(
Expand All @@ -108,11 +168,13 @@ async def generate_content(
"""
assert prompt, "Prompt cannot be empty."

await self.reset_close_task()
if self.auto_close:
await self.reset_close_task()

response = await self.client.post(
"https://gemini.google.com/_/GeminiChatUi/data/assistant.lamda.GeminiFrontendService/StreamGenerate",
"https://gemini.google.com/_/BardChatUi/data/assistant.lamda.BardFrontendService/StreamGenerate",
data={
"at": self.posttoken,
"at": self.access_token,
"f.req": json.dumps(
[None, json.dumps([[prompt], None, chat and chat.metadata])]
),
Expand Down
13 changes: 0 additions & 13 deletions src/gemini/utils.py

This file was deleted.

27 changes: 27 additions & 0 deletions tests/test_generate_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
import unittest

from gemini import GeminiClient


class TestGenerateContent(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.geminiclient = GeminiClient(
os.getenv("SECURE_1PSID") or "test_1psid",
os.getenv("SECURE_1PSIDTS") or "test_ipsidts",
)

@unittest.skipIf(
not (os.getenv("SECURE_1PSID") and os.getenv("SECURE_1PSIDTS")),
"Skipping test_success...",
)
async def test_success(self):
await self.geminiclient.init()
self.assertTrue(self.geminiclient.running)

response = await self.geminiclient.generate_content("Hello World!")
self.assertTrue(response.text)


if __name__ == "__main__":
unittest.main()

0 comments on commit b35d790

Please sign in to comment.