Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions api/oas-generator/src/oas_generator/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,8 @@ def _build_model(self, entry: SchemaEntry) -> ctx.ModelDescriptor: # noqa: C901

for prop_name in sorted(properties):
prop_schema = properties[prop_name] or {}
wire_name = prop_schema.get("x-algokit-field-rename") or prop_name
wire_name = prop_name
python_name_hint = prop_schema.get("x-algokit-field-rename") or prop_name
type_info = self.resolver.resolve(prop_schema, hint=entry.python_name + self.sanitizer.pascal(prop_name))
if type_info.is_signed_transaction:
self.uses_signed_transaction = True
Expand All @@ -250,7 +251,7 @@ def _build_model(self, entry: SchemaEntry) -> ctx.ModelDescriptor: # noqa: C901
annotation = f"{annotation} | None"
annotation = self._apply_forward_reference_annotation(annotation, entry, type_info)
field = ctx.ModelField(
name=self.sanitizer.snake(wire_name),
name=self.sanitizer.snake(python_name_hint),
wire_name=wire_name,
type_hint=annotation,
required=prop_name in required,
Expand Down Expand Up @@ -693,4 +694,5 @@ def build_client_descriptor(
uses_signed_transaction=uses_signed_txn,
uses_msgpack=operation_builder.uses_msgpack,
include_block_models=operation_builder.uses_block_models,
include_ledger_state_delta_models="LedgerStateDelta" in registry.entries,
)
1 change: 1 addition & 0 deletions api/oas-generator/src/oas_generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@ class ClientDescriptor:
uses_signed_transaction: bool = False
uses_msgpack: bool = False
include_block_models: bool = False
include_ledger_state_delta_models: bool = False
38 changes: 37 additions & 1 deletion api/oas-generator/src/oas_generator/renderer/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ class TemplateRenderer:
"Block",
"GetBlock",
]
LEDGER_STATE_DELTA_EXPORTS: ClassVar[list[str]] = [
"LedgerTealValue",
"LedgerStateSchema",
"LedgerAppParams",
"LedgerAppLocalState",
"LedgerAppLocalStateDelta",
"LedgerAppParamsDelta",
"LedgerAppResourceRecord",
"LedgerAssetHolding",
"LedgerAssetHoldingDelta",
"LedgerAssetParams",
"LedgerAssetParamsDelta",
"LedgerAssetResourceRecord",
"LedgerVotingData",
"LedgerAccountBaseData",
"LedgerAccountData",
"LedgerBalanceRecord",
"LedgerAccountDeltas",
"LedgerKvValueDelta",
"LedgerIncludedTransactions",
"LedgerModifiedCreatable",
"LedgerAlgoCount",
"LedgerAccountTotals",
"LedgerStateDelta",
]

def __init__(self, template_dir: Path | None = None) -> None:
if template_dir:
Expand Down Expand Up @@ -56,6 +81,8 @@ def render(self, client: ctx.ClientDescriptor, config: GeneratorConfig) -> dict[
files[models_dir / "__init__.py"] = self._render_template("models/__init__.py.j2", context)
files[models_dir / "_serde_helpers.py"] = self._render_template("models/_serde_helpers.py.j2", context)
for model in context["client"].models:
if context["client"].include_ledger_state_delta_models and model.name == "LedgerStateDelta":
continue
model_context = {**context, "model": model}
files[models_dir / f"{model.module_name}.py"] = self._render_template("models/model.py.j2", model_context)
for enum in context["client"].enums:
Expand All @@ -67,7 +94,11 @@ def render(self, client: ctx.ClientDescriptor, config: GeneratorConfig) -> dict[
"models/type_alias.py.j2", alias_context
)
if client.include_block_models:
files[models_dir / "block.py"] = self._render_template("models/block.py.j2", context)
files[models_dir / "_block.py"] = self._render_template("models/block.py.j2", context)
if client.include_ledger_state_delta_models:
files[models_dir / "_ledger_state_delta.py"] = self._render_template(
"models/ledger_state_delta.py.j2", context
)
files[target / "py.typed"] = ""
return files

Expand All @@ -85,6 +116,10 @@ def _build_context(self, client: ctx.ClientDescriptor, config: GeneratorConfig)
for name in self.BLOCK_MODEL_EXPORTS:
if name not in model_exports:
model_exports.append(name)
if client.include_ledger_state_delta_models:
for name in self.LEDGER_STATE_DELTA_EXPORTS:
if name not in model_exports:
model_exports.append(name)
metadata_usage = self._collect_metadata_usage(client)
model_modules = [{"module": model.module_name, "name": model.name} for model in client.models]
enum_modules = [{"module": enum.module_name, "name": enum.name} for enum in client.enums]
Expand All @@ -105,6 +140,7 @@ def _build_context(self, client: ctx.ClientDescriptor, config: GeneratorConfig)
"needs_datetime": any(model.requires_datetime for model in client.models),
"client_needs_datetime": self._client_requires_datetime(client),
"block_exports": self.BLOCK_MODEL_EXPORTS,
"ledger_state_delta_exports": self.LEDGER_STATE_DELTA_EXPORTS,
"needs_literal": needs_literal,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class {{ client.class_name }}:
return response.content
content_type = response.headers.get("content-type", "application/json")
if "msgpack" in content_type:
data = msgpack.unpackb(response.content, raw=False, strict_map_key=False)
data = msgpack.unpackb(response.content, raw=True, strict_map_key=False)
data = self._normalize_msgpack(data)
elif content_type.startswith("application/json"):
data = response.json()
Expand All @@ -266,13 +266,13 @@ class {{ client.class_name }}:
if isinstance(value, dict):
normalized: dict[object, object] = {}
for key, item in value.items():
normalized[self._ensure_str_key(key)] = self._normalize_msgpack(item)
normalized[self._coerce_msgpack_key(key)] = self._normalize_msgpack(item)
return normalized
if isinstance(value, list):
return [self._normalize_msgpack(item) for item in value]
return value

def _ensure_str_key(self, key: object) -> object:
def _coerce_msgpack_key(self, key: object) -> object:
if isinstance(key, bytes):
try:
return key.decode("utf-8")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@


{% if client.uses_signed_transaction %}from algokit_transact.models.signed_transaction import SignedTransaction
{% endif %}{% for item in model_modules %}from .{{ item.module }} import {{ item.name }}
{% endfor %}{% for item in enum_modules %}from .{{ item.module }} import {{ item.name }}
{% endif %}{% for item in model_modules %}{% if not (client.include_ledger_state_delta_models and item.name == "LedgerStateDelta") %}from .{{ item.module }} import {{ item.name }}
{% endif %}{% endfor %}{% for item in enum_modules %}from .{{ item.module }} import {{ item.name }}
{% endfor %}{% for item in alias_modules %}from .{{ item.module }} import {{ item.name }}
{% endfor %}{% if client.include_block_models %}from .block import (
{% endfor %}{% if client.include_block_models %}from ._block import (
{{ block_exports | join(',\n ') }}
)
{% endif %}{% if client.include_ledger_state_delta_models %}from ._ledger_state_delta import (
{{ ledger_state_delta_exports | join(',\n ') }}
)
{% endif %}

__all__ = [
{% for name in model_exports %}"{{ name }}",
{% endfor %}
]

Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ from typing import Callable, TypeAlias, TypeVar

from algokit_common.serde import from_wire, to_wire

T = TypeVar("T")
E = TypeVar("E", bound=Enum)
DecodedT = TypeVar("DecodedT")
EnumValueT = TypeVar("EnumValueT", bound=Enum)
MapKeyT = TypeVar("MapKeyT")
BytesLike: TypeAlias = bytes | bytearray | memoryview


Expand All @@ -36,6 +37,20 @@ def decode_bytes_base64(raw: object) -> bytes:
raise TypeError(f"Unsupported value for bytes field: {type(raw)!r}")


def decode_bytes_map_key(raw: object) -> bytes:
if isinstance(raw, bytes | bytearray | memoryview):
return bytes(raw)
if isinstance(raw, str):
try:
return decode_bytes_base64(raw)
except ValueError:
try:
return raw.encode("utf-8")
except UnicodeEncodeError as fallback_exc:
raise ValueError("Invalid bytes map key") from fallback_exc
raise TypeError(f"Unsupported map key for bytes field: {type(raw)!r}")


def encode_bytes_sequence(values: Iterable[BytesLike | None] | None) -> list[str | None] | None:
if values is None:
return None
Expand Down Expand Up @@ -73,11 +88,11 @@ def encode_model_sequence(values: Iterable[object] | None) -> list[dict[str, obj
return encoded or None


def decode_model_sequence(cls_factory: Callable[[], type[T]], raw: object) -> list[T] | None:
def decode_model_sequence(cls_factory: Callable[[], type[DecodedT]], raw: object) -> list[DecodedT] | None:
if not isinstance(raw, list):
return None
cls = cls_factory()
decoded: list[T] = []
decoded: list[DecodedT] = []
for item in raw:
if isinstance(item, Mapping):
decoded.append(from_wire(cls, item))
Expand All @@ -95,11 +110,11 @@ def encode_enum_sequence(values: Iterable[object] | None) -> list[object] | None
return encoded or None


def decode_enum_sequence(enum_factory: Callable[[], type[E]], raw: object) -> list[E] | None:
def decode_enum_sequence(enum_factory: Callable[[], type[EnumValueT]], raw: object) -> list[EnumValueT] | None:
if not isinstance(raw, list):
return None
enum_cls = enum_factory()
decoded: list[E] = []
decoded: list[EnumValueT] = []
for item in raw:
try:
decoded.append(enum_cls(item))
Expand All @@ -109,7 +124,10 @@ def decode_enum_sequence(enum_factory: Callable[[], type[E]], raw: object) -> li


def encode_model_mapping(
factory: Callable[[], type[T]], mapping: Mapping[str, object] | None
factory: Callable[[], type[DecodedT]],
mapping: Mapping[object, object] | None,
*,
key_encoder: Callable[[object], str] | None = None,
) -> dict[str, object] | None:
if mapping is None:
return None
Expand All @@ -118,35 +136,60 @@ def encode_model_mapping(
for key, value in mapping.items():
if value is None:
continue
encoded_key: str
if key_encoder is not None:
encoded_key = key_encoder(key)
elif isinstance(key, str):
encoded_key = key
else:
encoded_key = str(key)
if isinstance(value, cls) or is_dataclass(value):
encoded[str(key)] = to_wire(value)
encoded[encoded_key] = to_wire(value)
else:
encoded[str(key)] = value
encoded[encoded_key] = value
return encoded or None


def decode_model_mapping(factory: Callable[[], type[T]], raw: object) -> dict[str, T] | None:
def decode_model_mapping(
factory: Callable[[], type[DecodedT]],
raw: object,
*,
key_decoder: Callable[[object], MapKeyT] | None = None,
) -> dict[MapKeyT, DecodedT] | None:
if not isinstance(raw, Mapping):
return None
cls = factory()
decoded: dict[str, T] = {}
decoded: dict[MapKeyT, DecodedT] = {}
for key, value in raw.items():
if isinstance(value, Mapping):
decoded[str(key)] = from_wire(cls, value)
decoded_key = key_decoder(key) if key_decoder is not None else key
decoded[decoded_key] = from_wire(cls, value)
return decoded or None


def decode_optional_bool(raw: object) -> bool | None:
if raw is None:
return None
return bool(raw)


def mapping_encoder(
factory: Callable[[], type[T]],
) -> Callable[[Mapping[str, object] | None], dict[str, object] | None]:
def _encode(mapping: Mapping[str, object] | None) -> dict[str, object] | None:
return encode_model_mapping(factory, mapping)
factory: Callable[[], type[DecodedT]],
*,
key_encoder: Callable[[object], str] | None = None,
) -> Callable[[Mapping[object, object] | None], dict[str, object] | None]:
def _encode(mapping: Mapping[object, object] | None) -> dict[str, object] | None:
return encode_model_mapping(factory, mapping, key_encoder=key_encoder)

return _encode


def mapping_decoder(factory: Callable[[], type[T]]) -> Callable[[object], dict[str, T] | None]:
def _decode(raw: object) -> dict[str, T] | None:
return decode_model_mapping(factory, raw)
def mapping_decoder(
factory: Callable[[], type[DecodedT]],
*,
key_decoder: Callable[[object], MapKeyT] | None = None,
) -> Callable[[object], dict[MapKeyT, DecodedT] | None]:
def _decode(raw: object) -> dict[MapKeyT, DecodedT] | None:
return decode_model_mapping(factory, raw, key_decoder=key_decoder)

return _decode
Loading
Loading