diff --git a/effectful/handlers/llm/encoding.py b/effectful/handlers/llm/encoding.py new file mode 100644 index 00000000..fe8b7e21 --- /dev/null +++ b/effectful/handlers/llm/encoding.py @@ -0,0 +1,239 @@ +import base64 +import io +import typing +from abc import ABC, abstractmethod +from collections.abc import Callable + +import pydantic +from litellm import ( + ChatCompletionImageUrlObject, + OpenAIMessageContentListBlock, +) +from PIL import Image + +from effectful.ops.syntax import _CustomSingleDispatchCallable + + +def _pil_image_to_base64_data(pil_image: Image.Image) -> str: + buf = io.BytesIO() + pil_image.save(buf, format="PNG") + return base64.b64encode(buf.getvalue()).decode("utf-8") + + +def _pil_image_to_base64_data_uri(pil_image: Image.Image) -> str: + return f"data:image/png;base64,{_pil_image_to_base64_data(pil_image)}" + + +class EncodableAs[T, U](ABC): + t: type[U] + + def __init__(self, *args, **kwargs): + pass + + @classmethod + @abstractmethod + def encode(cls, vl: T) -> U: + pass + + @classmethod + @abstractmethod + def decode(cls, vl: U) -> T: + pass + + @classmethod + def serialize(cls, value: U) -> list[OpenAIMessageContentListBlock]: + return [{"type": "text", "text": str(value)}] + + +class Encodable[T](EncodableAs[T, type]): + t = type + + +@_CustomSingleDispatchCallable +def type_to_encodable_type[T]( + __dispatch: Callable[[type[T]], Callable[..., Encodable[T]]], ty: type[T] +) -> Encodable[T]: + origin_ty = typing.get_origin(ty) or ty + return __dispatch(origin_ty)(ty) + + +@type_to_encodable_type.register(object) +def _type_encodable_type_base[T](ty: type[T]) -> Encodable[T]: + class BaseEncodable(EncodableAs[T, T]): + t: type[T] = ty + + @classmethod + def encode(cls, vl: T) -> T: + return vl + + @classmethod + def decode(cls, vl: T) -> T: + return vl + + return typing.cast(Encodable[T], BaseEncodable()) + + +@type_to_encodable_type.register(pydantic.BaseModel) +def _type_encodable_type_pydantic_base_model[T: pydantic.BaseModel]( + ty: type[T], +) -> Encodable[T]: + class EncodablePydanticBaseModel(EncodableAs[T, T]): + t: type[T] = ty + + @classmethod + def decode(cls, vl: T) -> T: + return vl + + @classmethod + def encode(cls, vl: T) -> T: + return vl + + @classmethod + def serialize(cls, vl: T) -> list[OpenAIMessageContentListBlock]: + return [{"type": "text", "text": vl.model_dump_json()}] + + return typing.cast(Encodable[T], EncodablePydanticBaseModel()) + + +@type_to_encodable_type.register(Image.Image) +class EncodableImage(EncodableAs[Image.Image, ChatCompletionImageUrlObject]): + t = ChatCompletionImageUrlObject + + @classmethod + def encode(cls, image: Image.Image) -> ChatCompletionImageUrlObject: + return { + "detail": "auto", + "url": _pil_image_to_base64_data_uri(image), + } + + @classmethod + def decode(cls, image: ChatCompletionImageUrlObject) -> Image.Image: + image_url = image["url"] + if not image_url.startswith("data:image/"): + raise RuntimeError( + f"expected base64 encoded image as data uri, received {image_url}" + ) + data = image_url.split(",")[1] + return Image.open(fp=io.BytesIO(base64.b64decode(data))) + + @classmethod + def serialize( + cls, value: ChatCompletionImageUrlObject + ) -> list[OpenAIMessageContentListBlock]: + return [{"type": "image_url", "image_url": value}] + + +@type_to_encodable_type.register(tuple) +def _type_encodable_type_tuple[T](ty: type[T]) -> Encodable[T]: + args = typing.get_args(ty) + + # Handle empty tuple, or tuple with no args + if not args or args == ((),): + return _type_encodable_type_base(ty) + + # Create encoders for each element type + element_encoders = [type_to_encodable_type(arg) for arg in args] + + # Check if any element type is Image.Image + has_image = any(arg is Image.Image for arg in args) + + encoded_ty: type[typing.Any] = typing.cast( + type[typing.Any], + tuple[*(enc.t for enc in element_encoders)], # type: ignore + ) + + class TupleEncodable(EncodableAs[T, typing.Any]): + t: type[typing.Any] = encoded_ty + + @classmethod + def encode(cls, t: T) -> typing.Any: + if not isinstance(t, tuple): + raise TypeError(f"Expected tuple, got {type(t)}") + if len(t) != len(element_encoders): + raise ValueError( + f"Tuple length {len(t)} does not match expected length {len(element_encoders)}" + ) + return tuple([enc.encode(elem) for enc, elem in zip(element_encoders, t)]) + + @classmethod + def decode(cls, t: typing.Any) -> T: + if len(t) != len(element_encoders): + raise ValueError( + f"tuple length {len(t)} does not match expected length {len(element_encoders)}" + ) + decoded_elements: list[typing.Any] = [ + enc.decode(elem) for enc, elem in zip(element_encoders, t) + ] + return typing.cast(T, tuple(decoded_elements)) + + @classmethod + def serialize(cls, value: typing.Any) -> list[OpenAIMessageContentListBlock]: + if has_image: + # If tuple contains images, serialize each element and flatten the results + result: list[OpenAIMessageContentListBlock] = [] + if not isinstance(value, tuple): + raise TypeError(f"Expected tuple, got {type(value)}") + if len(value) != len(element_encoders): + raise ValueError( + f"Tuple length {len(value)} does not match expected length {len(element_encoders)}" + ) + for enc, elem in zip(element_encoders, value): + result.extend(enc.serialize(elem)) + return result + else: + return super().serialize(value) + + return typing.cast(Encodable[T], TupleEncodable()) + + +@type_to_encodable_type.register(list) +def _type_encodable_type_list[T](ty: type[T]) -> Encodable[T]: + args = typing.get_args(ty) + + # Handle unparameterized list (list without type args) + if not args: + return _type_encodable_type_base(ty) + + # Get the element type (first type argument) + element_ty = args[0] + element_encoder = type_to_encodable_type(element_ty) + + # Check if element type is Image.Image + has_image = element_ty is Image.Image + + # Build the encoded type (list of encoded element type) - runtime-created, use Any + encoded_ty: type[typing.Any] = typing.cast( + type[typing.Any], + list[element_encoder.t], # type: ignore + ) + + class ListEncodable(EncodableAs[T, typing.Any]): + t: type[typing.Any] = encoded_ty + + @classmethod + def encode(cls, t: T) -> typing.Any: + if not isinstance(t, list): + raise TypeError(f"Expected list, got {type(t)}") + return [element_encoder.encode(elem) for elem in t] + + @classmethod + def decode(cls, t: typing.Any) -> T: + decoded_elements: list[typing.Any] = [ + element_encoder.decode(elem) for elem in t + ] + return typing.cast(T, decoded_elements) + + @classmethod + def serialize(cls, value: typing.Any) -> list[OpenAIMessageContentListBlock]: + if has_image: + # If list contains images, serialize each element and flatten the results + result: list[OpenAIMessageContentListBlock] = [] + if not isinstance(value, list): + raise TypeError(f"Expected list, got {type(value)}") + for elem in value: + result.extend(element_encoder.serialize(elem)) + return result + else: + return super().serialize(value) + + return typing.cast(Encodable[T], ListEncodable()) diff --git a/effectful/handlers/llm/providers.py b/effectful/handlers/llm/providers.py index d0e9c1aa..27097f2d 100644 --- a/effectful/handlers/llm/providers.py +++ b/effectful/handlers/llm/providers.py @@ -7,19 +7,20 @@ import string import traceback import typing -from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence +from collections.abc import Callable, Hashable, Iterable, Mapping from typing import Any, get_type_hints import litellm import pydantic +from effectful.handlers.llm.encoding import type_to_encodable_type + try: from PIL import Image except ImportError: raise ImportError("'pillow' is required to use effectful.handlers.providers") from litellm import ( - ChatCompletionImageObject, Choices, Message, OpenAIChatCompletionToolParam, @@ -44,84 +45,89 @@ def _pil_image_to_base64_data_uri(pil_image: Image.Image) -> str: return f"data:image/png;base64,{_pil_image_to_base64_data(pil_image)}" -def _pil_image_to_openai_image_param( - pil_image: Image.Image, -) -> ChatCompletionImageObject: - return { - "type": "image_url", - "image_url": { - "detail": "auto", - "url": _pil_image_to_base64_data_uri(pil_image), - }, - } - - -@defop -@functools.singledispatch -def format_value(value: Any) -> OpenAIMessageContent: - """Convert a Python value to internal message part representation. - - This function can be extended by registering handlers for - different types using @format_value.register. - - Returns a OpenAIMessageContent - either a string or a list of OpenAIMessageContentListBlock. - """ - return [{"type": "text", "text": str(value)}] - - -@format_value.register(Image.Image) # type: ignore -def _(value: Image.Image) -> OpenAIMessageContent: - return [_pil_image_to_openai_image_param(value)] - - -@format_value.register(str) # type: ignore -def _(value: str) -> OpenAIMessageContent: - return [{"type": "text", "text": value}] - - -@format_value.register(bytes) # type: ignore -def _(value: bytes) -> OpenAIMessageContent: - return [{"type": "text", "text": str(value)}] - - -@format_value.register(Sequence) # type: ignore -def _(values: Sequence) -> OpenAIMessageContent: - if all(isinstance(value, Image.Image) for value in values): - return [_pil_image_to_openai_image_param(value) for value in values] - else: - return [{"type": "text", "text": str(values)}] - - @dataclasses.dataclass class Tool[**P, T]: - parameter_model: type[pydantic.BaseModel] operation: Operation[P, T] name: str + parameter_annotations: dict[str, type] def serialise_return_value(self, value) -> OpenAIMessageContent: """Serializes a value returned by the function into a json format suitable for the OpenAI API.""" sig = inspect.signature(self.operation) - ret_ty = sig.return_annotation - ret_ty_origin = typing.get_origin(ret_ty) or ret_ty + encoded_ty = type_to_encodable_type(sig.return_annotation) + encoded_value = encoded_ty.encode(value) + return encoded_ty.serialize(encoded_value) + + @functools.cached_property + def parameter_model(self) -> type[pydantic.BaseModel]: + fields = { + param_name: type_to_encodable_type(param_type).t + for param_name, param_type in self.parameter_annotations.items() + } + parameter_model = pydantic.create_model( + "Params", + __config__={"extra": "forbid"}, + **fields, # type: ignore + ) + return parameter_model - return format_value.dispatch(ret_ty_origin)(value) # type: ignore + def call_with_json_args( + self, template: Template, json_str: str + ) -> OpenAIMessageContent: + """Implements a roundtrip call to a python function. Input is a json string representing an LLM tool call request parameters. The output is the serialised response to the model.""" + try: + op = self.operation + # build dict of raw encodable types U + raw_args = self.parameter_model.model_validate_json(json_str) + + # use encoders to decode Us to python types T + params: dict[str, Any] = { + param_name: type_to_encodable_type( + self.parameter_annotations[param_name] + ).decode(getattr(raw_args, param_name)) + for param_name in raw_args.model_fields_set + } + + # call tool with python types + result = tool_call( + template, + self.operation, + **params, + ) + # serialize back to U using encoder for return type + sig = inspect.signature(op) + encoded_ty = type_to_encodable_type(sig.return_annotation) + encoded_value = encoded_ty.encode(result) + # serialise back to Json + return encoded_ty.serialize(encoded_value) + except Exception as exn: + return str({"status": "failure", "exception": str(exn)}) @classmethod def of_operation(cls, op: Operation[P, T], name: str): sig = inspect.signature(op) hints = get_type_hints(op) - fields = { - param_name: hints.get(param_name, str) for param_name in sig.parameters - } - - parameter_model = pydantic.create_model( - "Params", __config__={"extra": "forbid"}, **fields - ) + parameter_annotations: dict[str, type] = {} + + for param_name, param in sig.parameters.items(): + # Check if parameter annotation is missing (inspect.Parameter.empty) + if param.annotation is inspect.Parameter.empty: + raise TypeError( + f"Parameter '{param_name}' in operation '{op.__name__}' " + "does not have a type annotation" + ) + # get_type_hints might not include the parameter if annotation is invalid + if param_name not in hints: + raise TypeError( + f"Parameter '{param_name}' in operation '{op.__name__}' " + "does not have a valid type annotation" + ) + parameter_annotations[param_name] = hints[param_name] return cls( - parameter_model=parameter_model, operation=op, name=name, + parameter_annotations=parameter_annotations, ) @property @@ -177,23 +183,21 @@ def push_current_text(): if field_name is not None: obj, _ = self.get_field(field_name, args, kwargs) - obj = self.convert_field(obj, conversion) - - if isinstance(obj, Image.Image): - assert not format_spec, ( - "image template parameters cannot have format specifiers" + part = self.convert_field(obj, conversion) + # special casing for text + if ( + isinstance(part, list) + and len(part) == 1 + and part[0]["type"] == "text" + ): + current_text += self.format_field( + part[0]["text"], format_spec if format_spec else "" ) + elif isinstance(part, list): push_current_text() - prompt_parts.append( - { - "type": "image_url", - "image_url": _pil_image_to_base64_data_uri(obj), - } - ) + prompt_parts.extend(part) else: - current_text += self.format_field( - obj, format_spec if format_spec else "" - ) + prompt_parts.append(part) push_current_text() return prompt_parts @@ -343,24 +347,6 @@ def _retry_completion(self, template: Template, *args, **kwargs) -> Any: raise Exception("Max retries reached") -def _call_tool_with_json_args( - template: Template, tool: Tool, json_str_args: str -) -> OpenAIMessageContent: - try: - args = tool.parameter_model.model_validate_json(json_str_args) - result = tool_call( - template, - tool.operation, - **{ - field: getattr(args, field) - for field in tool.parameter_model.model_fields - }, - ) - return tool.serialise_return_value(result) - except Exception as exn: - return str({"status": "failure", "exception": str(exn)}) - - def _pydantic_model_from_type(typ: type): return pydantic.create_model("Response", value=typ, __config__={"extra": "forbid"}) @@ -375,13 +361,19 @@ def compute_response(template: Template, model_input: list[Any]) -> ModelRespons tools = _tools_of_operations(template.tools) tool_schemas = [t.function_definition for t in tools.values()] - response_format = _pydantic_model_from_type(ret_type) if ret_type != str else None + response_encoding_type: type | None = type_to_encodable_type(ret_type).t + if response_encoding_type == str: + response_encoding_type = None # loop based on: https://cookbook.openai.com/examples/reasoning_function_calls while True: response: ModelResponse = completion( messages=model_input, - response_format=response_format, + response_format=pydantic.create_model( + "Response", value=response_encoding_type, __config__={"extra": "forbid"} + ) + if response_encoding_type + else None, tools=tool_schemas, ) @@ -395,7 +387,7 @@ def compute_response(template: Template, model_input: list[Any]) -> ModelRespons function = tool_call.function function_name = typing.cast(str, function.name) tool = tools[function_name] - tool_result = _call_tool_with_json_args(template, tool, function.arguments) + tool_result = tool.call_with_json_args(template, function.arguments) model_input.append( { "role": "tool", @@ -406,13 +398,9 @@ def compute_response(template: Template, model_input: list[Any]) -> ModelRespons ) -# Note: typing template as Template[P, T] causes term conversion to fail due to -# unification limitations. -@defop def decode_response[**P, T](template: Callable[P, T], response: ModelResponse) -> T: """Decode an LLM response into an instance of the template return type. This operation should raise if the output cannot be decoded. - """ assert isinstance(template, Template) choice: Choices = typing.cast(Choices, response.choices[0]) @@ -422,13 +410,18 @@ def decode_response[**P, T](template: Callable[P, T], response: ModelResponse) - assert result_str ret_type = template.__signature__.return_annotation - if ret_type == str: - return result_str # type: ignore[return-value] + encodable_ty = type_to_encodable_type(ret_type) + + if encodable_ty.t == str: + # if encoding as a type, value is just directly what the llm returned + value = result_str + else: + Result = pydantic.create_model("Result", value=encodable_ty.t) + result = Result.model_validate_json(result_str) + assert isinstance(result, Result) + value = result.value # type: ignore - Result = _pydantic_model_from_type(ret_type) - result = Result.model_validate_json(result_str) - assert isinstance(result, Result) - return result.value + return encodable_ty.decode(value) # type: ignore @defop @@ -441,8 +434,17 @@ def format_model_input[**P, T]( """ bound_args = template.__signature__.bind(*args, **kwargs) bound_args.apply_defaults() + # encode arguments + arguments = {} + for param in bound_args.arguments: + encoder = type_to_encodable_type( + template.__signature__.parameters[param].annotation + ) + encoded = encoder.encode(bound_args.arguments[param]) + arguments[param] = encoder.serialize(encoded) + prompt = _OpenAIPromptFormatter().format_as_messages( - template.__prompt_template__, **bound_args.arguments + template.__prompt_template__, **arguments ) # Note: The OpenAI api only seems to accept images in the 'user' role. The diff --git a/tests/test_handlers_llm_encoding.py b/tests/test_handlers_llm_encoding.py new file mode 100644 index 00000000..ce50979e --- /dev/null +++ b/tests/test_handlers_llm_encoding.py @@ -0,0 +1,709 @@ +from dataclasses import asdict, dataclass +from typing import NamedTuple, TypedDict + +import pydantic +import pytest +from PIL import Image + +from effectful.handlers.llm.encoding import type_to_encodable_type + + +def test_type_to_encodable_type_str(): + encodable = type_to_encodable_type(str) + encoded = encodable.encode("hello") + decoded = encodable.decode(encoded) + assert decoded == "hello" + Model = pydantic.create_model("Model", value=encodable.t) + decoded = Model.model_validate({"value": "hello"}) + assert decoded.value == "hello" + + +def test_type_to_encodable_type_int(): + encodable = type_to_encodable_type(int) + encoded = encodable.encode(42) + decoded = encodable.decode(encoded) + assert decoded == 42 + assert isinstance(decoded, int) + Model = pydantic.create_model("Model", value=encodable.t) + decoded = Model.model_validate({"value": 42}) + assert decoded.value == 42 + assert isinstance(decoded.value, int) + + +def test_type_to_encodable_type_bool(): + encodable = type_to_encodable_type(bool) + encoded = encodable.encode(True) + decoded = encodable.decode(encoded) + assert decoded is True + assert isinstance(decoded, bool) + encoded_false = encodable.encode(False) + decoded_false = encodable.decode(encoded_false) + assert decoded_false is False + Model = pydantic.create_model("Model", value=encodable.t) + decoded = Model.model_validate({"value": True}) + assert decoded.value is True + assert isinstance(decoded.value, bool) + + +def test_type_to_encodable_type_float(): + encodable = type_to_encodable_type(float) + encoded = encodable.encode(3.14) + decoded = encodable.decode(encoded) + assert decoded == 3.14 + assert isinstance(decoded, float) + Model = pydantic.create_model("Model", value=encodable.t) + decoded = Model.model_validate({"value": 3.14}) + assert decoded.value == 3.14 + assert isinstance(decoded.value, float) + + +def test_type_to_encodable_type_image(): + encodable = type_to_encodable_type(Image.Image) + image = Image.new("RGB", (10, 10), color="red") + encoded = encodable.encode(image) + assert isinstance(encoded, dict) + assert "url" in encoded + assert "detail" in encoded + assert encoded["detail"] == "auto" + assert encoded["url"].startswith("data:image/png;base64,") + decoded = encodable.decode(encoded) + assert isinstance(decoded, Image.Image) + assert decoded.size == (10, 10) + Model = pydantic.create_model("Model", value=encodable.t) + decoded = Model.model_validate({"value": encoded}) + assert decoded.value["url"] == encoded["url"] + assert decoded.value["detail"] == "auto" + + +def test_type_to_encodable_type_image_roundtrip(): + encodable = type_to_encodable_type(Image.Image) + original = Image.new("RGB", (20, 20), color="green") + encoded = encodable.encode(original) + decoded = encodable.decode(encoded) + assert isinstance(decoded, Image.Image) + assert decoded.size == original.size + assert decoded.mode == original.mode + + +def test_type_to_encodable_type_image_decode_invalid_url(): + encodable = type_to_encodable_type(Image.Image) + encoded = {"url": "http://example.com/image.png", "detail": "auto"} + with pytest.raises(RuntimeError, match="expected base64 encoded image as data uri"): + encodable.decode(encoded) + + +def test_type_to_encodable_type_tuple(): + encodable = type_to_encodable_type(tuple[int, str]) + value = (1, "test") + encoded = encodable.encode(value) + decoded = encodable.decode(encoded) + assert decoded == value + assert isinstance(decoded, tuple) + assert decoded[0] == 1 + assert decoded[1] == "test" + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded}) + assert model_instance.value == encoded + assert isinstance(model_instance.value, tuple) + assert model_instance.value[0] == 1 + assert model_instance.value[1] == "test" + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == value + assert isinstance(decoded_from_model, tuple) + + +def test_type_to_encodable_type_tuple_empty(): + encodable = type_to_encodable_type(tuple[()]) + value = () + encoded = encodable.encode(value) + decoded = encodable.decode(encoded) + assert decoded == value + assert isinstance(decoded, tuple) + assert len(decoded) == 0 + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded}) + assert model_instance.value == encoded + assert isinstance(model_instance.value, tuple) + assert len(model_instance.value) == 0 + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == value + assert isinstance(decoded_from_model, tuple) + + +def test_type_to_encodable_type_tuple_three_elements(): + encodable = type_to_encodable_type(tuple[int, str, bool]) + value = (42, "hello", True) + encoded = encodable.encode(value) + decoded = encodable.decode(encoded) + assert decoded == value + assert isinstance(decoded, tuple) + assert decoded[0] == 42 + assert decoded[1] == "hello" + assert decoded[2] is True + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded}) + assert model_instance.value == encoded + assert isinstance(model_instance.value, tuple) + assert model_instance.value[0] == 42 + assert model_instance.value[1] == "hello" + assert model_instance.value[2] is True + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == value + assert isinstance(decoded_from_model, tuple) + + +def test_type_to_encodable_type_list(): + encodable = type_to_encodable_type(list[int]) + value = [1, 2, 3, 4, 5] + encoded = encodable.encode(value) + decoded = encodable.decode(encoded) + assert decoded == value + assert isinstance(decoded, list) + assert all(isinstance(elem, int) for elem in decoded) + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded}) + assert model_instance.value == encoded + assert isinstance(model_instance.value, list) + assert model_instance.value == [1, 2, 3, 4, 5] + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == value + assert isinstance(decoded_from_model, list) + assert all(isinstance(elem, int) for elem in decoded_from_model) + + +def test_type_to_encodable_type_list_str(): + encodable = type_to_encodable_type(list[str]) + value = ["hello", "world", "test"] + encoded = encodable.encode(value) + decoded = encodable.decode(encoded) + assert decoded == value + assert isinstance(decoded, list) + assert all(isinstance(elem, str) for elem in decoded) + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded}) + assert model_instance.value == encoded + assert isinstance(model_instance.value, list) + assert model_instance.value == ["hello", "world", "test"] + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == value + assert isinstance(decoded_from_model, list) + assert all(isinstance(elem, str) for elem in decoded_from_model) + + +def test_type_to_encodable_type_namedtuple(): + class Point(NamedTuple): + x: int + y: int + + encodable = type_to_encodable_type(Point) + point = Point(10, 20) + encoded = encodable.encode(point) + decoded = encodable.decode(encoded) + assert decoded == point + assert isinstance(decoded, Point) + assert decoded.x == 10 + assert decoded.y == 20 + Model = pydantic.create_model("Model", value=encodable.t) + decoded = Model.model_validate({"value": {"x": 10, "y": 20}}) + assert decoded.value == point + assert isinstance(decoded.value, Point) + + +def test_type_to_encodable_type_namedtuple_with_str(): + class Person(NamedTuple): + name: str + age: int + + encodable = type_to_encodable_type(Person) + person = Person("Alice", 30) + encoded = encodable.encode(person) + decoded = encodable.decode(encoded) + assert decoded == person + assert isinstance(decoded, Person) + assert decoded.name == "Alice" + assert decoded.age == 30 + Model = pydantic.create_model("Model", value=encodable.t) + decoded = Model.model_validate({"value": {"name": "Alice", "age": 30}}) + assert decoded.value == person + assert isinstance(decoded.value, Person) + + +def test_type_to_encodable_type_typeddict(): + class User(TypedDict): + name: str + age: int + + encodable = type_to_encodable_type(User) + user = User(name="Bob", age=25) + encoded = encodable.encode(user) + decoded = encodable.decode(encoded) + assert decoded == user + assert isinstance(decoded, dict) + assert decoded["name"] == "Bob" + assert decoded["age"] == 25 + Model = pydantic.create_model("Model", value=encodable.t) + decoded = Model.model_validate({"value": {"name": "Bob", "age": 25}}) + assert decoded.value == user + assert isinstance(decoded.value, dict) + + +def test_type_to_encodable_type_typeddict_optional(): + class Config(TypedDict, total=False): + host: str + port: int + + encodable = type_to_encodable_type(Config) + config = Config(host="localhost", port=8080) + encoded = encodable.encode(config) + decoded = encodable.decode(encoded) + assert decoded == config + assert decoded["host"] == "localhost" + assert decoded["port"] == 8080 + Model = pydantic.create_model("Model", value=encodable.t) + decoded = Model.model_validate({"value": {"host": "localhost", "port": 8080}}) + assert decoded.value == config + assert isinstance(decoded.value, dict) + + +def test_type_to_encodable_type_complex(): + encodable = type_to_encodable_type(complex) + value = 3 + 4j + encoded = encodable.encode(value) + decoded = encodable.decode(encoded) + assert decoded == value + assert isinstance(decoded, complex) + assert decoded.real == 3.0 + assert decoded.imag == 4.0 + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded}) + assert model_instance.value == encoded + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == value + assert isinstance(decoded_from_model, complex) + + +def test_type_to_encodable_type_tuple_of_images(): + encodable = type_to_encodable_type(tuple[Image.Image, Image.Image]) + image1 = Image.new("RGB", (10, 10), color="red") + image2 = Image.new("RGB", (20, 20), color="blue") + value = (image1, image2) + + encoded = encodable.encode(value) + assert isinstance(encoded, tuple) + assert len(encoded) == 2 + assert isinstance(encoded[0], dict) + assert isinstance(encoded[1], dict) + assert "url" in encoded[0] + assert "url" in encoded[1] + assert encoded[0]["url"].startswith("data:image/png;base64,") + assert encoded[1]["url"].startswith("data:image/png;base64,") + + decoded = encodable.decode(encoded) + assert isinstance(decoded, tuple) + assert len(decoded) == 2 + assert isinstance(decoded[0], Image.Image) + assert isinstance(decoded[1], Image.Image) + assert decoded[0].size == (10, 10) + assert decoded[1].size == (20, 20) + + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded}) + assert model_instance.value == encoded + assert isinstance(model_instance.value, tuple) + assert len(model_instance.value) == 2 + assert isinstance(model_instance.value[0], dict) + assert isinstance(model_instance.value[1], dict) + assert model_instance.value[0]["url"] == encoded[0]["url"] + assert model_instance.value[1]["url"] == encoded[1]["url"] + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert isinstance(decoded_from_model, tuple) + assert len(decoded_from_model) == 2 + assert isinstance(decoded_from_model[0], Image.Image) + assert isinstance(decoded_from_model[1], Image.Image) + assert decoded_from_model[0].size == (10, 10) + assert decoded_from_model[1].size == (20, 20) + + # Roundtrip test + original = ( + Image.new("RGB", (15, 15), color="green"), + Image.new("RGB", (25, 25), color="yellow"), + ) + encoded_roundtrip = encodable.encode(original) + decoded_roundtrip = encodable.decode(encoded_roundtrip) + assert isinstance(decoded_roundtrip, tuple) + assert len(decoded_roundtrip) == 2 + assert decoded_roundtrip[0].size == original[0].size + assert decoded_roundtrip[1].size == original[1].size + assert decoded_roundtrip[0].mode == original[0].mode + assert decoded_roundtrip[1].mode == original[1].mode + + +def test_type_to_encodable_type_list_of_images(): + encodable = type_to_encodable_type(list[Image.Image]) + images = [ + Image.new("RGB", (10, 10), color="red"), + Image.new("RGB", (20, 20), color="blue"), + Image.new("RGB", (30, 30), color="green"), + ] + + encoded = encodable.encode(images) + assert isinstance(encoded, list) + assert len(encoded) == 3 + assert all(isinstance(elem, dict) for elem in encoded) + assert all("url" in elem for elem in encoded) + assert all(elem["url"].startswith("data:image/png;base64,") for elem in encoded) + + decoded = encodable.decode(encoded) + assert isinstance(decoded, list) + assert len(decoded) == 3 + assert all(isinstance(elem, Image.Image) for elem in decoded) + assert decoded[0].size == (10, 10) + assert decoded[1].size == (20, 20) + assert decoded[2].size == (30, 30) + + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded}) + assert model_instance.value == encoded + assert isinstance(model_instance.value, list) + assert len(model_instance.value) == 3 + assert all(isinstance(elem, dict) for elem in model_instance.value) + assert all("url" in elem for elem in model_instance.value) + assert model_instance.value[0]["url"] == encoded[0]["url"] + assert model_instance.value[1]["url"] == encoded[1]["url"] + assert model_instance.value[2]["url"] == encoded[2]["url"] + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert isinstance(decoded_from_model, list) + assert len(decoded_from_model) == 3 + assert all(isinstance(elem, Image.Image) for elem in decoded_from_model) + assert decoded_from_model[0].size == (10, 10) + assert decoded_from_model[1].size == (20, 20) + assert decoded_from_model[2].size == (30, 30) + + # Roundtrip test + original = [ + Image.new("RGB", (15, 15), color="yellow"), + Image.new("RGB", (25, 25), color="purple"), + ] + encoded_roundtrip = encodable.encode(original) + decoded_roundtrip = encodable.decode(encoded_roundtrip) + assert isinstance(decoded_roundtrip, list) + assert len(decoded_roundtrip) == 2 + assert decoded_roundtrip[0].size == original[0].size + assert decoded_roundtrip[1].size == original[1].size + assert decoded_roundtrip[0].mode == original[0].mode + assert decoded_roundtrip[1].mode == original[1].mode + + +def test_type_to_encodable_type_dataclass(): + @dataclass + class Point: + x: int + y: int + + encodable = type_to_encodable_type(Point) + point = Point(10, 20) + encoded = encodable.encode(point) + decoded = encodable.decode(encoded) + assert decoded == point + assert isinstance(decoded, Point) + assert decoded.x == 10 + assert decoded.y == 20 + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": asdict(encoded)}) + assert model_instance.value.x == 10 + assert model_instance.value.y == 20 + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == point + assert isinstance(decoded_from_model, Point) + + +def test_type_to_encodable_type_dataclass_with_str(): + @dataclass + class Person: + name: str + age: int + + encodable = type_to_encodable_type(Person) + person = Person("Alice", 30) + encoded = encodable.encode(person) + decoded = encodable.decode(encoded) + assert decoded == person + assert isinstance(decoded, Person) + assert decoded.name == "Alice" + assert decoded.age == 30 + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": asdict(encoded)}) + assert model_instance.value.name == "Alice" + assert model_instance.value.age == 30 + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == person + assert isinstance(decoded_from_model, Person) + + +def test_type_to_encodable_type_dataclass_with_list(): + @dataclass + class Container: + items: list[int] + name: str + + encodable = type_to_encodable_type(Container) + container = Container(items=[1, 2, 3], name="test") + encoded = encodable.encode(container) + decoded = encodable.decode(encoded) + assert decoded == container + assert isinstance(decoded, Container) + assert decoded.items == [1, 2, 3] + assert decoded.name == "test" + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": asdict(encoded)}) + assert model_instance.value.items == [1, 2, 3] + assert model_instance.value.name == "test" + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == container + assert isinstance(decoded_from_model, Container) + + +def test_type_to_encodable_type_dataclass_with_tuple(): + @dataclass + class Pair: + values: tuple[int, str] + count: int + + encodable = type_to_encodable_type(Pair) + pair = Pair(values=(42, "hello"), count=2) + encoded = encodable.encode(pair) + decoded = encodable.decode(encoded) + assert decoded == pair + assert isinstance(decoded, Pair) + assert decoded.values == (42, "hello") + assert decoded.count == 2 + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": asdict(encoded)}) + assert model_instance.value.values == (42, "hello") + assert model_instance.value.count == 2 + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == pair + assert isinstance(decoded_from_model, Pair) + + +def test_type_to_encodable_type_dataclass_with_optional(): + @dataclass + class Config: + host: str + port: int + timeout: float | None = None + + encodable = type_to_encodable_type(Config) + config = Config(host="localhost", port=8080, timeout=5.0) + encoded = encodable.encode(config) + decoded = encodable.decode(encoded) + assert decoded == config + assert isinstance(decoded, Config) + assert decoded.host == "localhost" + assert decoded.port == 8080 + assert decoded.timeout == 5.0 + + # Test with None value + config_none = Config(host="localhost", port=8080, timeout=None) + encoded_none = encodable.encode(config_none) + decoded_none = encodable.decode(encoded_none) + assert decoded_none == config_none + assert decoded_none.timeout is None + + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": asdict(encoded)}) + assert model_instance.value.host == "localhost" + assert model_instance.value.port == 8080 + assert model_instance.value.timeout == 5.0 + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == config + + +def test_type_to_encodable_type_nested_dataclass(): + @dataclass + class Address: + street: str + city: str + + @dataclass + class Person: + name: str + age: int + address: Address + + encodable = type_to_encodable_type(Person) + address = Address(street="123 Main St", city="New York") + person = Person(name="Bob", age=25, address=address) + + encoded = encodable.encode(person) + assert isinstance(encoded, Person) + assert hasattr(encoded, "name") + assert hasattr(encoded, "age") + assert hasattr(encoded, "address") + assert isinstance(encoded.address, Address) + assert encoded.address.street == "123 Main St" + assert encoded.address.city == "New York" + + decoded = encodable.decode(encoded) + assert isinstance(decoded, Person) + assert isinstance(decoded.address, Address) + assert decoded.name == "Bob" + assert decoded.age == 25 + assert decoded.address.street == "123 Main St" + assert decoded.address.city == "New York" + + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": asdict(encoded)}) + assert model_instance.value.name == "Bob" + assert model_instance.value.age == 25 + assert model_instance.value.address.street == "123 Main St" + assert model_instance.value.address.city == "New York" + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == person + assert isinstance(decoded_from_model, Person) + assert isinstance(decoded_from_model.address, Address) + + +def test_type_to_encodable_type_pydantic_model(): + class Point(pydantic.BaseModel): + x: int + y: int + + encodable = type_to_encodable_type(Point) + point = Point(x=10, y=20) + encoded = encodable.encode(point) + decoded = encodable.decode(encoded) + assert decoded == point + assert isinstance(decoded, Point) + assert decoded.x == 10 + assert decoded.y == 20 + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded.model_dump()}) + assert model_instance.value.x == 10 + assert model_instance.value.y == 20 + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == point + assert isinstance(decoded_from_model, Point) + + +def test_type_to_encodable_type_pydantic_model_with_str(): + class Person(pydantic.BaseModel): + name: str + age: int + + encodable = type_to_encodable_type(Person) + person = Person(name="Alice", age=30) + encoded = encodable.encode(person) + decoded = encodable.decode(encoded) + assert decoded == person + assert isinstance(decoded, Person) + assert decoded.name == "Alice" + assert decoded.age == 30 + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded.model_dump()}) + assert model_instance.value.name == "Alice" + assert model_instance.value.age == 30 + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == person + assert isinstance(decoded_from_model, Person) + + +def test_type_to_encodable_type_pydantic_model_with_list(): + class Container(pydantic.BaseModel): + items: list[int] + name: str + + encodable = type_to_encodable_type(Container) + container = Container(items=[1, 2, 3], name="test") + encoded = encodable.encode(container) + decoded = encodable.decode(encoded) + assert decoded == container + assert isinstance(decoded, Container) + assert decoded.items == [1, 2, 3] + assert decoded.name == "test" + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded.model_dump()}) + assert model_instance.value.items == [1, 2, 3] + assert model_instance.value.name == "test" + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == container + assert isinstance(decoded_from_model, Container) + + +def test_type_to_encodable_type_nested_pydantic_model(): + class Address(pydantic.BaseModel): + street: str + city: str + + class Person(pydantic.BaseModel): + name: str + age: int + address: Address + + encodable = type_to_encodable_type(Person) + address = Address(street="123 Main St", city="New York") + person = Person(name="Bob", age=25, address=address) + + encoded = encodable.encode(person) + assert isinstance(encoded, pydantic.BaseModel) + assert hasattr(encoded, "name") + assert hasattr(encoded, "age") + assert hasattr(encoded, "address") + assert isinstance(encoded.address, pydantic.BaseModel) + assert encoded.address.street == "123 Main St" + assert encoded.address.city == "New York" + + decoded = encodable.decode(encoded) + assert isinstance(decoded, Person) + assert isinstance(decoded.address, Address) + assert decoded.name == "Bob" + assert decoded.age == 25 + assert decoded.address.street == "123 Main St" + assert decoded.address.city == "New York" + + # Test with pydantic model validation + Model = pydantic.create_model("Model", value=encodable.t) + model_instance = Model.model_validate({"value": encoded.model_dump()}) + assert model_instance.value.name == "Bob" + assert model_instance.value.age == 25 + assert model_instance.value.address.street == "123 Main St" + assert model_instance.value.address.city == "New York" + # Decode from model + decoded_from_model = encodable.decode(model_instance.value) + assert decoded_from_model == person + assert isinstance(decoded_from_model, Person) + assert isinstance(decoded_from_model.address, Address) diff --git a/tests/test_handlers_llm_provider.py b/tests/test_handlers_llm_provider.py index 3fbd307d..9a0bcc5c 100644 --- a/tests/test_handlers_llm_provider.py +++ b/tests/test_handlers_llm_provider.py @@ -13,7 +13,7 @@ import pytest from PIL import Image -from pydantic import Field +from pydantic import BaseModel, Field from pydantic.dataclasses import dataclass from effectful.handlers.llm import Template @@ -377,3 +377,134 @@ def test_image_input(): handler(LimitLLMCallsHandler(max_calls=3)), ): assert any("smile" in categorise_image(smiley_face()) for _ in range(3)) + + +class BookReview(BaseModel): + """A book review with rating and summary.""" + + title: str = Field(..., description="title of the book") + rating: int = Field(..., description="rating from 1 to 5", ge=1, le=5) + summary: str = Field(..., description="brief summary of the review") + + +@Template.define +def review_book(plot: str) -> BookReview: + """Review a book based on this plot: {plot}""" + raise NotImplementedError + + +class TestPydanticBaseModelReturn: + @requires_openai + def test_pydantic_basemodel_return(self): + plot = "A young wizard discovers he has magical powers and goes to a school for wizards." + + with ( + handler(LiteLLMProvider(model_name="gpt-5-nano")), + handler(LimitLLMCallsHandler(max_calls=1)), + ): + review = review_book(plot) + + assert isinstance(review, BookReview) + assert isinstance(review.title, str) + assert len(review.title) > 0 + assert isinstance(review.rating, int) + assert 1 <= review.rating <= 5 + assert isinstance(review.summary, str) + assert len(review.summary) > 0 + + +class BookRecommendation(BaseModel): + """A book recommendation with details.""" + + title: str = Field(..., description="title of the recommended book") + reason: str = Field(..., description="reason for the recommendation") + + +@defop +def recommend_book_tool(genre: str, explanation: str) -> BookRecommendation: + """Recommend a book based on genre preference. + + Parameters: + - genre: The genre of book to recommend + - explanation: Natural language explanation of the recommendation + """ + raise NotHandled + + +class LoggingBookRecommendationInterpretation(ObjectInterpretation): + """Provides an interpretation for `recommend_book_tool` that tracks recommendations.""" + + recommendation_count: int = 0 + recommendation_results: list[dict] = [] + + @implements(recommend_book_tool) + def _recommend_book_tool(self, genre: str, explanation: str) -> BookRecommendation: + self.recommendation_count += 1 + + # Simple heuristic: recommend based on genre + recommendations = { + "fantasy": BookRecommendation( + title="The Lord of the Rings", reason="Classic fantasy epic" + ), + "sci-fi": BookRecommendation( + title="Dune", reason="Epic science fiction masterpiece" + ), + "mystery": BookRecommendation( + title="The Hound of the Baskervilles", + reason="Classic mystery novel", + ), + } + + recommendation = recommendations.get( + genre.lower(), + BookRecommendation( + title="1984", reason="Thought-provoking dystopian novel" + ), + ) + + self.recommendation_results.append( + { + "genre": genre, + "explanation": explanation, + "recommendation": recommendation, + } + ) + + return recommendation + + +@Template.define(tools=[recommend_book_tool]) +def get_book_recommendation(user_preference: str) -> BookRecommendation: + """Get a book recommendation based on user preference: {user_preference}. + Use the provided tools to make a recommendation. + """ + raise NotHandled + + +class TestPydanticBaseModelToolCalls: + @pytest.mark.parametrize( + "model_name", + [ + pytest.param("gpt-5-nano", marks=requires_openai), + pytest.param("claude-sonnet-4-5-20250929", marks=requires_anthropic), + ], + ) + def test_pydantic_basemodel_tool_calling(self, model_name): + """Test that templates with tools work with Pydantic BaseModel.""" + book_rec_ctx = LoggingBookRecommendationInterpretation() + with ( + handler(LiteLLMProvider(model_name=model_name)), + handler(LimitLLMCallsHandler(max_calls=4)), + handler(book_rec_ctx), + ): + recommendation = get_book_recommendation("I love fantasy novels") + + assert isinstance(recommendation, BookRecommendation) + assert isinstance(recommendation.title, str) + assert len(recommendation.title) > 0 + assert isinstance(recommendation.reason, str) + assert len(recommendation.reason) > 0 + + # Verify the tool was called at least once + assert book_rec_ctx.recommendation_count >= 1 + assert len(book_rec_ctx.recommendation_results) >= 1