Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions api/oas-generator/src/oas_generator/builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
from typing import Any, ClassVar

from oas_generator import models as ctx
from oas_generator.naming import IdentifierSanitizer
Expand Down Expand Up @@ -33,6 +33,13 @@ class TypeInfo:
imports: set[str] = field(default_factory=set)


LEDGER_STATE_DELTA_MODEL_NAMES: set[str] = {
"LedgerStateDelta",
"LedgerStateDeltaForTransactionGroup",
"GetTransactionGroupLedgerStateDeltasForRoundResponseModel",
}


class SchemaRegistry:
def __init__(self, spec: ctx.ParsedSpec, sanitizer: IdentifierSanitizer) -> None:
self.spec = spec
Expand Down Expand Up @@ -454,6 +461,12 @@ def _collect_alias_imports(self, annotation: str, entry: SchemaEntry) -> list[st


class OperationBuilder:
RAW_LEDGER_STATE_DELTA_OPERATIONS: ClassVar[set[str]] = {
"GetLedgerStateDelta",
"GetLedgerStateDeltaForTransactionGroup",
"GetTransactionGroupLedgerStateDeltasForRound",
}

def __init__(
self, spec: ctx.ParsedSpec, resolver: TypeResolver, sanitizer: IdentifierSanitizer, registry: SchemaRegistry
) -> None:
Expand Down Expand Up @@ -622,6 +635,17 @@ def _build_response(self, responses: dict[str, Any], operation_id: str) -> ctx.R
if media_type in content:
schema = content[media_type].get("schema")
media_types.append(media_type)
if operation_id in self.RAW_LEDGER_STATE_DELTA_OPERATIONS:
if not media_types:
media_types = ["application/msgpack"]
if "application/msgpack" in media_types:
self.uses_msgpack = True
return ctx.ResponseDescriptor(
type_hint="bytes",
media_types=media_types,
description=payload.get("description"),
is_raw_msgpack=True,
)
if operation_id == "GetBlock" and schema is not None:
self.uses_block_models = True
media_types = media_types or ["application/json"]
Expand Down Expand Up @@ -673,6 +697,7 @@ def build_client_descriptor(
groups = operation_builder.build()
model_builder = ModelBuilder(registry, resolver, sanitizer)
models, enums, aliases = model_builder.build()
models = [model for model in models if model.name not in LEDGER_STATE_DELTA_MODEL_NAMES]
uses_signed_txn = model_builder.uses_signed_transaction or operation_builder.uses_signed_transaction
defaults = {
"algod_client": ("http://localhost:4001", "X-Algo-API-Token"),
Expand All @@ -694,5 +719,4 @@ def build_client_descriptor(
uses_signed_transaction=uses_signed_txn,
uses_msgpack=operation_builder.uses_msgpack,
include_block_models=operation_builder.uses_block_models,
include_ledger_state_delta_models="LedgerStateDelta" in registry.entries,
)
2 changes: 1 addition & 1 deletion api/oas-generator/src/oas_generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class ResponseDescriptor:
media_types: list[str]
description: str | None
is_binary: bool = False
is_raw_msgpack: bool = False
model: str | None = None
list_model: str | None = None
enum: str | None = None
Expand Down Expand Up @@ -146,4 +147,3 @@ class ClientDescriptor:
uses_signed_transaction: bool = False
uses_msgpack: bool = False
include_block_models: bool = False
include_ledger_state_delta_models: bool = False
37 changes: 1 addition & 36 deletions api/oas-generator/src/oas_generator/renderer/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,11 @@ class TemplateRenderer:
"BlockAppEvalDelta",
"BlockStateProofTrackingData",
"BlockStateProofTracking",
"ParticipationUpdates",
"SignedTxnInBlock",
"Block",
"GetBlock",
]
LEDGER_STATE_DELTA_EXPORTS: ClassVar[list[str]] = [
"LedgerTealValue",
"LedgerStateSchema",
"LedgerAppParams",
"LedgerAppLocalState",
"LedgerAppLocalStateDelta",
"LedgerAppParamsDelta",
"LedgerAppResourceRecord",
"LedgerAssetHolding",
"LedgerAssetHoldingDelta",
"LedgerAssetParams",
"LedgerAssetParamsDelta",
"LedgerAssetResourceRecord",
"LedgerVotingData",
"LedgerAccountBaseData",
"LedgerAccountData",
"LedgerBalanceRecord",
"LedgerAccountDeltas",
"LedgerKvValueDelta",
"LedgerIncludedTransactions",
"LedgerModifiedCreatable",
"LedgerAlgoCount",
"LedgerAccountTotals",
"LedgerStateDelta",
]

def __init__(self, template_dir: Path | None = None) -> None:
if template_dir:
Expand Down Expand Up @@ -81,8 +57,6 @@ def render(self, client: ctx.ClientDescriptor, config: GeneratorConfig) -> dict[
files[models_dir / "__init__.py"] = self._render_template("models/__init__.py.j2", context)
files[models_dir / "_serde_helpers.py"] = self._render_template("models/_serde_helpers.py.j2", context)
for model in context["client"].models:
if context["client"].include_ledger_state_delta_models and model.name == "LedgerStateDelta":
continue
model_context = {**context, "model": model}
files[models_dir / f"{model.module_name}.py"] = self._render_template("models/model.py.j2", model_context)
for enum in context["client"].enums:
Expand All @@ -95,10 +69,6 @@ def render(self, client: ctx.ClientDescriptor, config: GeneratorConfig) -> dict[
)
if client.include_block_models:
files[models_dir / "_block.py"] = self._render_template("models/block.py.j2", context)
if client.include_ledger_state_delta_models:
files[models_dir / "_ledger_state_delta.py"] = self._render_template(
"models/ledger_state_delta.py.j2", context
)
files[target / "py.typed"] = ""
return files

Expand All @@ -116,10 +86,6 @@ def _build_context(self, client: ctx.ClientDescriptor, config: GeneratorConfig)
for name in self.BLOCK_MODEL_EXPORTS:
if name not in model_exports:
model_exports.append(name)
if client.include_ledger_state_delta_models:
for name in self.LEDGER_STATE_DELTA_EXPORTS:
if name not in model_exports:
model_exports.append(name)
metadata_usage = self._collect_metadata_usage(client)
model_modules = [{"module": model.module_name, "name": model.name} for model in client.models]
enum_modules = [{"module": enum.module_name, "name": enum.name} for enum in client.enums]
Expand All @@ -140,7 +106,6 @@ def _build_context(self, client: ctx.ClientDescriptor, config: GeneratorConfig)
"needs_datetime": any(model.requires_datetime for model in client.models),
"client_needs_datetime": self._client_requires_datetime(client),
"block_exports": self.BLOCK_MODEL_EXPORTS,
"ledger_state_delta_exports": self.LEDGER_STATE_DELTA_EXPORTS,
"needs_literal": needs_literal,
}

Expand Down
18 changes: 16 additions & 2 deletions api/oas-generator/src/oas_generator/renderer/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ def descriptor_literal(descriptor: object, indent: int = 0) -> str:
if descriptor is None:
return "{}"
fields: dict[str, Any] = {}
for key in ("is_binary", "model", "list_model", "enum", "list_enum"):
bool_fields = ("is_binary", "is_raw_msgpack")
for key in bool_fields:
if getattr(descriptor, key, False):
fields[key] = True
for key in ("model", "list_model", "enum", "list_enum"):
value = getattr(descriptor, key, None)
if value is not None:
fields[key] = value
Expand All @@ -41,15 +45,25 @@ def response_decode_arguments(descriptor: object, indent: int = 0) -> str:
model = getattr(descriptor, "model", None) or getattr(descriptor, "enum", None)
list_model = getattr(descriptor, "list_model", None) or getattr(descriptor, "list_enum", None)
is_binary = bool(getattr(descriptor, "is_binary", False))
raw_msgpack = bool(getattr(descriptor, "is_raw_msgpack", False))
type_hint = getattr(descriptor, "type_hint", None)
parts: list[str] = []
if is_binary:
parts.append("is_binary=True")
if raw_msgpack:
parts.append("raw_msgpack=True")
if model:
parts.append(f"model=models.{model}")
if list_model:
parts.append(f"list_model=models.{list_model}")
if not model and not list_model and not is_binary and type_hint and type_hint != "object":
if (
not model
and not list_model
and not is_binary
and not raw_msgpack
and type_hint
and type_hint != "object"
):
parts.append(f"type_={type_hint}")
if not parts:
return ""
Expand Down
115 changes: 108 additions & 7 deletions api/oas-generator/src/oas_generator/renderer/templates/client.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ ModelT = TypeVar("ModelT")
ListModelT = TypeVar("ListModelT")
PrimitiveT = TypeVar("PrimitiveT")

# Prefixed markers used when converting unhashable msgpack map keys into hashable tuples
_UNHASHABLE_PREFIXES: dict[str, str] = {
"dict": "__dict_key__",
"list": "__list_key__",
"set": "__set_key__",
"generic": "__unhashable__",
}


class {{ client.class_name }}:
def __init__(self, config: ClientConfig | None = None, *, http_client: httpx.Client | None = None) -> None:
Expand Down Expand Up @@ -193,6 +201,7 @@ class {{ client.class_name }}:
*,
model: type[ModelT],
is_binary: bool = False,
raw_msgpack: bool = False,
) -> ModelT:
...

Expand All @@ -203,6 +212,7 @@ class {{ client.class_name }}:
*,
list_model: type[ListModelT],
is_binary: bool = False,
raw_msgpack: bool = False,
) -> list[ListModelT]:
...

Expand All @@ -213,6 +223,7 @@ class {{ client.class_name }}:
*,
type_: type[PrimitiveT],
is_binary: bool = False,
raw_msgpack: bool = False,
) -> PrimitiveT:
...

Expand All @@ -222,6 +233,16 @@ class {{ client.class_name }}:
response: httpx.Response,
*,
is_binary: Literal[True],
raw_msgpack: bool = False,
) -> bytes:
...

@overload
def _decode_response(
self,
response: httpx.Response,
*,
raw_msgpack: Literal[True],
) -> bytes:
...

Expand All @@ -232,6 +253,7 @@ class {{ client.class_name }}:
*,
type_: None = None,
is_binary: bool = False,
raw_msgpack: bool = False,
) -> object:
...

Expand All @@ -243,12 +265,28 @@ class {{ client.class_name }}:
list_model: type[Any] | None = None,
type_: type[Any] | None = None,
is_binary: bool = False,
raw_msgpack: bool = False,
) -> object:
if is_binary:
if is_binary or raw_msgpack:
return response.content
content_type = response.headers.get("content-type", "application/json")
if "msgpack" in content_type:
data = msgpack.unpackb(response.content, raw=True, strict_map_key=False)
# Handle msgpack unpacking with support for unhashable keys
# Use Unpacker for more control over the unpacking process
unpacker = msgpack.Unpacker(
raw=True,
strict_map_key=False,
object_pairs_hook=self._msgpack_pairs_hook,
)
unpacker.feed(response.content)
try:
data = unpacker.unpack()
except TypeError:
# If unpacking fails due to unhashable keys, try without the hook
# and handle in normalization
unpacker = msgpack.Unpacker(raw=True, strict_map_key=False)
unpacker.feed(response.content)
data = unpacker.unpack()
data = self._normalize_msgpack(data)
elif content_type.startswith("application/json"):
data = response.json()
Expand All @@ -262,12 +300,42 @@ class {{ client.class_name }}:
return data
return data

def _normalize_msgpack(self, value: object) -> object:
def _normalize_msgpack(self, value: object) -> object: # noqa: C901, PLR0912
# Handle pairs returned from msgpack_pairs_hook when keys are unhashable
_pair_length = 2
if (
isinstance(value, list)
and value
and isinstance(value[0], tuple | list)
and len(value[0]) == _pair_length
):
# Convert to dict with normalized keys
pairs_dict: dict[object, object] = {}
for pair in value:
if isinstance(pair, tuple | list) and len(pair) == _pair_length:
k, v = pair
# For unhashable keys (like dict keys), use a tuple representation
try:
normalized_key = self._coerce_msgpack_key(k)
pairs_dict[normalized_key] = self._normalize_msgpack(v)
except TypeError:
# Key is unhashable - use tuple representation
normalized_key = ("__unhashable__", id(k), str(k))
pairs_dict[normalized_key] = self._normalize_msgpack(v)
return pairs_dict
if isinstance(value, dict):
normalized: dict[object, object] = {}
for key, item in value.items():
normalized[self._coerce_msgpack_key(key)] = self._normalize_msgpack(item)
return normalized
# Safely normalize maps: coerce string/bytes keys, but tolerate complex/unhashable keys
try:
normalized_dict: dict[object, object] = {}
for key, item in value.items():
normalized_dict[self._coerce_msgpack_key(key)] = self._normalize_msgpack(item)
return normalized_dict
except TypeError:
# Some maps can decode to object/dict keys; keep original keys and
# only normalize values to avoid "unhashable type: 'dict'" errors.
for k, item in list(value.items()):
value[k] = self._normalize_msgpack(item)
return value
if isinstance(value, list):
return [self._normalize_msgpack(item) for item in value]
return value
Expand All @@ -279,3 +347,36 @@ class {{ client.class_name }}:
except UnicodeDecodeError:
return key
return key

def _msgpack_pairs_hook(self, pairs: list[tuple[object, object]] | list[list[object]]) -> dict[object, object]:
# Convert pairs to dict, handling unhashable keys by converting them to hashable tuples
out: dict[object, object] = {}
_hashable_type_tuple = (str, int, float, bool, type(None), bytes)

for k, v in pairs:
if isinstance(k, dict | list | set):
# Convert unhashable key to hashable tuple
hashable_key: tuple[str, object]
if isinstance(k, dict):
try:
hashable_key = (_UNHASHABLE_PREFIXES["dict"], tuple(sorted(k.items())))
except TypeError:
hashable_key = (_UNHASHABLE_PREFIXES["dict"], str(k))
elif isinstance(k, list):
prefix = _UNHASHABLE_PREFIXES["list"]
hashable_key = (prefix, tuple(k) if all(isinstance(x, _hashable_type_tuple) for x in k) else str(k))
else: # set
prefix = _UNHASHABLE_PREFIXES["set"]
if all(isinstance(x, _hashable_type_tuple) for x in k):
hashable_key = (prefix, tuple(sorted(k)))
else:
hashable_key = (prefix, str(k))
out[hashable_key] = v
else:
# Key should be hashable, use as-is
try:
out[k] = v
except TypeError:
# Unexpected unhashable type, convert to tuple
out[(_UNHASHABLE_PREFIXES["generic"], str(type(k).__name__), str(k))] = v
return out
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@


{% if client.uses_signed_transaction %}from algokit_transact.models.signed_transaction import SignedTransaction
{% endif %}{% for item in model_modules %}{% if not (client.include_ledger_state_delta_models and item.name == "LedgerStateDelta") %}from .{{ item.module }} import {{ item.name }}
{% endif %}{% endfor %}{% for item in enum_modules %}from .{{ item.module }} import {{ item.name }}
{% endif %}{% for item in model_modules %}from .{{ item.module }} import {{ item.name }}
{% endfor %}{% for item in enum_modules %}from .{{ item.module }} import {{ item.name }}
{% endfor %}{% for item in alias_modules %}from .{{ item.module }} import {{ item.name }}
{% endfor %}{% if client.include_block_models %}from ._block import (
{{ block_exports | join(',\n ') }}
)
{% endif %}{% if client.include_ledger_state_delta_models %}from ._ledger_state_delta import (
{{ ledger_state_delta_exports | join(',\n ') }}
)
{% endif %}

__all__ = [
Expand Down
Loading
Loading