feat: less strict typings (#315)
David-Kristek committed Feb 7, 2024
1 parent a7333f4 commit fd32f12
Showing 9 changed files with 25 additions and 49 deletions.
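The pattern is the same across all nine files: the deprecated `typing.List` and `typing.Dict` aliases are replaced by the built-in generics `list` and `dict` (valid in annotations since Python 3.9, per PEP 585), and the removed `prompts_to_strings` helper gives way to the existing `cast_list` utility at its call sites. A minimal before/after sketch of the annotation change (illustrative only, not code from this repository):

from typing import List, Optional


def generate_old(prompts: List[str], stop: Optional[List[str]] = None) -> List[str]:
    # old style: typing.List alias, deprecated since Python 3.9
    return prompts


def generate_new(prompts: list[str], stop: Optional[list[str]] = None) -> list[str]:
    # new style: built-in generic per PEP 585; Optional still comes from typing
    return prompts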
10 changes: 0 additions & 10 deletions src/genai/_utils/general.py
@@ -182,13 +182,3 @@ def hash_params(**kwargs):
 
 def single_execution(fn: T) -> T:
     return functools.cache(fn)  # type: ignore
-
-
-def prompts_to_strings(prompts: Union[list[str], str, None]) -> list[str]:
-    if prompts is None:
-        return []
-
-    if not isinstance(prompts, list):
-        return [prompts]
-
-    return prompts
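The call sites of the deleted helper now use `cast_list` from the same module (see the service diffs below). Its implementation is not shown in this diff; a plausible sketch, assuming it merely wraps a bare value in a list:

from typing import TypeVar, Union

T = TypeVar("T")


def cast_list(value: Union[T, list[T]]) -> list[T]:
    # hypothetical reconstruction of genai._utils.general.cast_list:
    # pass lists through unchanged, wrap anything else in a one-element list
    return value if isinstance(value, list) else [value]

Unlike `prompts_to_strings`, such a helper would not special-case `None`, which explains the `if inputs else []` guard added in generation_service.py below.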
12 changes: 6 additions & 6 deletions src/genai/extensions/huggingface/agent.py
@@ -8,7 +8,7 @@
 )
 
 
-from typing import List, Optional
+from typing import Optional
 
 from genai._utils.general import to_model_instance
 from genai.client import Client
@@ -25,7 +25,7 @@ def __init__(
         parameters: Optional[TextGenerationParameters] = None,
         chat_prompt_template: Optional[str] = None,
         run_prompt_template: Optional[str] = None,
-        additional_tools: Optional[List[str]] = None,
+        additional_tools: Optional[list[str]] = None,
     ):
         super().__init__(
             chat_prompt_template=chat_prompt_template,
@@ -36,14 +36,14 @@ def __init__(
         self.model = model
         self.parameters = parameters
 
-    def generate_one(self, prompt, stop):
+    def generate_one(self, prompt: str, stop: Optional[list[str]] = None):
         return self._generate([prompt], stop)[0]
 
-    def generate_many(self, prompts, stop):
+    def generate_many(self, prompts: list[str], stop: Optional[list[str]] = None):
         return self._generate(prompts, stop)
 
-    def _generate(self, prompts: List[str], stop: Optional[List[str]] = None) -> List[str]:
-        final_results: List[str] = []
+    def _generate(self, prompts: list[str], stop: Optional[list[str]] = None) -> list[str]:
+        final_results: list[str] = []
         if len(prompts) == 0:
             return final_results
 
4 changes: 2 additions & 2 deletions src/genai/extensions/langchain/chat_llm.py
@@ -2,7 +2,7 @@
 
 import logging
 from pathlib import Path
-from typing import Any, Dict, Iterator, Optional, Union
+from typing import Any, Iterator, Optional, Union
 
 from pydantic import ConfigDict
 from pydantic.v1 import validator
@@ -137,7 +137,7 @@ def load_from_file(cls, file: Union[str, Path], *, client: Client):
         return cls(**config, client=client)
 
     @property
-    def _identifying_params(self) -> Dict[str, Any]:
+    def _identifying_params(self) -> dict[str, Any]:
         return {
             "model_id": self.model_id,
             "prompt_id": self.prompt_id,
16 changes: 8 additions & 8 deletions src/genai/extensions/langchain/llm.py
@@ -4,7 +4,7 @@
 import logging
 from functools import partial
 from pathlib import Path
-from typing import Any, Iterator, List, Optional, Union
+from typing import Any, Iterator, Optional, Union
 
 from pydantic import ConfigDict
 from pydantic.v1 import validator
@@ -137,7 +137,7 @@ def _prepare_stream_request(self, **kwargs):
     def _call(
         self,
         prompt: str,
-        stop: Optional[List[str]] = None,
+        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
@@ -146,8 +146,8 @@ def _call(
 
     def _generate(
         self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
+        prompts: list[str],
+        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
@@ -207,8 +207,8 @@ def _generate(
 
     async def _agenerate(
         self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
+        prompts: list[str],
+        stop: Optional[list[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
@@ -219,7 +219,7 @@ async def _agenerate(
     def _stream(
         self,
         prompt: str,
-        stop: Optional[List[str]] = None,
+        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[CustomGenerationChunk]:
@@ -261,5 +261,5 @@ def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
             )
         )
 
-    def get_token_ids(self, text: str) -> List[int]:
+    def get_token_ids(self, text: str) -> list[int]:
         raise NotImplementedError("API does not support returning token ids.")
4 changes: 2 additions & 2 deletions src/genai/extensions/llama_index/llm.py
@@ -1,7 +1,7 @@
 import asyncio
 import logging
 from functools import partial
-from typing import Any, List, Optional, Sequence
+from typing import Any, Optional, Sequence
 
 from genai import Client
 from genai._types import EnumLike
@@ -57,7 +57,7 @@ def to_genai_message(message: ChatMessage) -> BaseMessage:
     raise ValueError(f"Got unknown message type {message}")
 
 
-def to_genai_messages(messages: Sequence[ChatMessage]) -> List[BaseMessage]:
+def to_genai_messages(messages: Sequence[ChatMessage]) -> list[BaseMessage]:
     return [to_genai_message(msg) for msg in messages]
 
 
8 changes: 1 addition & 7 deletions src/genai/prompt/prompt_service.py
@@ -1,13 +1,7 @@
 from typing import Optional, Union
 
 from genai._types import EnumLike, EnumLikeOrEnumLikeList, ModelLike
-from genai._utils.general import (
-    cast_list,
-    to_enum,
-    to_enum_optional,
-    to_model_instance,
-    to_model_optional,
-)
+from genai._utils.general import cast_list, to_enum, to_enum_optional, to_model_instance, to_model_optional
 from genai._utils.service import (
     BaseService,
     BaseServiceConfig,
9 changes: 3 additions & 6 deletions src/genai/text/generation/generation_service.py
@@ -7,11 +7,7 @@
 from genai._types import ModelLike
 from genai._utils.api_client import ApiClient
 from genai._utils.async_executor import execute_async
-from genai._utils.general import (
-    prompts_to_strings,
-    to_model_instance,
-    to_model_optional,
-)
+from genai._utils.general import cast_list, to_model_instance, to_model_optional
 from genai._utils.service import (
     BaseService,
     BaseServiceConfig,
@@ -47,6 +43,7 @@
 
 __all__ = ["GenerationService", "BaseConfig", "BaseServices", "CreateExecutionOptions"]
 
+
 from genai._utils.http_client.retry_transport import BaseRetryTransport
 from genai._utils.limiters.base_limiter import BaseLimiter
 from genai._utils.limiters.external_limiter import ConcurrencyResponse, ExternalLimiter
@@ -138,7 +135,7 @@ def create(
         To limit number of concurrent requests or change execution procedure, see 'execute_options' parameter.
         """
         metadata = get_service_action_metadata(self.create)
-        prompts: list[str] = prompts_to_strings(inputs)
+        prompts = cast_list(inputs) if inputs else []
         parameters_formatted = to_model_optional(parameters, TextGenerationParameters)
         moderations_formatted = to_model_optional(moderations, ModerationParameters)
         template_formatted = to_model_optional(data, PromptTemplateData)
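Assuming the `cast_list` sketch above, a quick check that the new expression matches the removed helper for the input shapes the API accepts (hypothetical snippet, not from the repository):

def normalize(inputs):
    # the new pattern from generation_service.py
    return cast_list(inputs) if inputs else []


assert normalize(None) == []
assert normalize("hello") == ["hello"]
assert normalize(["a", "b"]) == ["a", "b"]

One edge case does shift: an empty string is falsy, so `normalize("")` now yields `[]`, whereas the deleted `prompts_to_strings("")` returned `[""]`.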
8 changes: 2 additions & 6 deletions src/genai/text/moderation/moderation_service.py
@@ -2,11 +2,7 @@
 
 from genai._types import ModelLike
 from genai._utils.async_executor import execute_async
-from genai._utils.general import (
-    prompts_to_strings,
-    to_model_instance,
-    to_model_optional,
-)
+from genai._utils.general import cast_list, to_model_instance, to_model_optional
 from genai._utils.http_client.httpx_client import AsyncHttpxClient
 from genai._utils.service import (
     BaseService,
@@ -109,7 +105,7 @@ async def handler(input: str, http_client: AsyncHttpxClient, *_) -> TextModerationCreateResponse:
             return TextModerationCreateResponse(**http_response.json())
 
         yield from execute_async(
-            inputs=prompts_to_strings(inputs),
+            inputs=cast_list(inputs),
             handler=handler,
             http_client=self._get_async_http_client,
             ordered=execution_options_formatted.ordered,
3 changes: 1 addition & 2 deletions src/genai/text/tokenization/tokenization_service.py
@@ -6,7 +6,6 @@
     batch_by_size_constraint,
     cast_list,
     merge_objects,
-    prompts_to_strings,
     to_model_instance,
 )
 from genai._utils.http_client.httpx_client import AsyncHttpxClient
@@ -80,7 +79,7 @@ def create(
         options = to_model_instance([self.config.create_execution_options, execution_options], CreateExecutionOptions)
         parameters_validated = to_model_instance(parameters, TextTokenizationParameters)
         batches = batch_by_size_constraint(
-            prompts_to_strings(prompts),
+            prompts,
             max_size_bytes=self._api_client.config.max_payload_size_bytes,
             max_chunk_size=options.batch_size or len(prompts),
         )
