Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ repos:
- id: cargo-test
name: cargo-nextest
description: "Run 'cargo nextest' and doc tests"
entry: cargo nextest run
entry: cargo t
language: system
types: [rust]
require_serial: true
Expand Down
27 changes: 26 additions & 1 deletion api/oas_generator/rust_oas_generator/parser/oas_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ class Operation:
request_body_supports_msgpack: bool = False
request_body_supports_text_plain: bool = False
has_optional_string: bool = False
# When the original spec had a query param `format` with enum ['msgpack'] only,
# we don't expose it to callers but still need to set it implicitly on requests
force_msgpack_query: bool = False

def __post_init__(self) -> None:
for param in self.parameters:
Expand Down Expand Up @@ -706,6 +709,19 @@ def _parse_operation(
operation_data,
)

# Detect if the original spec constrained `format` to only msgpack
force_msgpack_query = False
for p in operation_data.get("parameters", []) or []:
p_obj = self._resolve_reference(p["$ref"]) if isinstance(p, dict) and "$ref" in p else p
if not isinstance(p_obj, dict):
continue
if p_obj.get("in", "query") == "query" and p_obj.get("name") == "format":
schema_obj = p_obj.get("schema", {}) or {}
enum_vals = schema_obj.get("enum")
if isinstance(enum_vals, list) and len(enum_vals) == 1 and enum_vals[0] == "msgpack":
force_msgpack_query = True
break

parameters = []
for param_data in operation_data.get("parameters", []):
param = self._parse_parameter(param_data)
Expand All @@ -730,6 +746,7 @@ def _parse_operation(
supports_msgpack=supports_msgpack,
request_body_supports_msgpack=request_body_supports_msgpack,
request_body_supports_text_plain=request_body_supports_text_plain,
force_msgpack_query=force_msgpack_query,
)

def _check_request_body_msgpack_support(
Expand Down Expand Up @@ -770,13 +787,21 @@ def _parse_parameter(self, param_data: dict[str, Any]) -> Parameter | None:
if not name:
return None

# Skip `format` query parameter when constrained to msgpack only
in_location = param_data.get("in", "query")
if name == "format" and in_location == "query":
schema_obj = param_data.get("schema", {}) or {}
enum_vals = schema_obj.get("enum")
if isinstance(enum_vals, list) and len(enum_vals) == 1 and enum_vals[0] == "msgpack":
return None

schema = param_data.get("schema", {})
rust_type = rust_type_from_openapi(schema, self.schemas, set())
enum_values = schema.get("enum", []) if schema.get("type") == "string" else []

return Parameter(
name=name,
param_type=param_data.get("in", "query"),
param_type=in_location,
rust_type=rust_type,
required=param_data.get("required", False),
description=param_data.get("description"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub async fn {{ operation.rust_function_name }}(
let path = "{{ rust_path }}".to_string();
{% endif %}

let{% if get_query_parameters(operation) %} mut{% endif %} query_params: HashMap<String, String> = HashMap::new();
let{% if get_query_parameters(operation) or operation.force_msgpack_query %} mut{% endif %} query_params: HashMap<String, String> = HashMap::new();
{% for param in get_query_parameters(operation) %}
{% if param.required %}
{% if param.is_array %}
Expand All @@ -111,6 +111,9 @@ pub async fn {{ operation.rust_function_name }}(
}
{% endif %}
{% endfor %}
{% if operation.force_msgpack_query %}
query_params.insert("format".to_string(), "msgpack".to_string());
{% endif %}

{% if operation.request_body_supports_text_plain %}
let mut headers: HashMap<String, String> = HashMap::new();
Expand Down
4 changes: 4 additions & 0 deletions api/oas_generator/ts_oas_generator/generator/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ class OperationContext:
returns_msgpack: bool = False
has_format_param: bool = False
format_var_name: str | None = None
# When the original spec had a query param `format` with enum ['msgpack'] only,
# we don't expose it to callers but still need to set it implicitly on requests
force_msgpack_query: bool = False
error_types: list[ErrorDescriptor] | None = None

def to_dict(self) -> dict[str, Any]:
Expand All @@ -67,6 +70,7 @@ def to_dict(self) -> dict[str, Any]:
"returnsMsgpack": self.returns_msgpack,
"hasFormatParam": self.has_format_param,
"formatVarName": self.format_var_name,
"forceMsgpackQuery": self.force_msgpack_query,
"errorTypes": [self._error_to_dict(e) for e in (self.error_types or [])],
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def _create_operation_context(self, op_input: OperationInput) -> OperationContex

# Compute additional properties
self._compute_format_param(context)
self._compute_force_msgpack_query(context, op_input.operation, op_input.spec)
self._compute_import_types(context)

return context
Expand All @@ -384,6 +385,14 @@ def _process_parameters(self, params: list[Schema], spec: Schema) -> list[Parame

# Extract parameter details
raw_name = str(param.get("name"))
# Skip `format` query param when it's constrained to only msgpack
location_candidate = param.get(constants.OperationKey.IN, constants.ParamLocation.QUERY)
if location_candidate == constants.ParamLocation.QUERY and raw_name == constants.FORMAT_PARAM_NAME:
schema_obj = param.get("schema", {}) or {}
enum_vals = schema_obj.get(constants.SchemaKey.ENUM)
if isinstance(enum_vals, list) and len(enum_vals) == 1 and enum_vals[0] == "msgpack":
# Endpoint only supports msgpack; do not expose/append `format`
continue
var_name = self._sanitize_variable_name(ts_camel_case(raw_name), used_names)
used_names.add(var_name)

Expand All @@ -397,7 +406,7 @@ def _process_parameters(self, params: list[Schema], spec: Schema) -> list[Parame
else:
stringify_bigint = constants.TypeScriptType.BIGINT in ts_type_str

location = param.get(constants.OperationKey.IN, constants.ParamLocation.QUERY)
location = location_candidate
required = param.get(constants.SchemaKey.REQUIRED, False) or location == constants.ParamLocation.PATH

parameters.append(
Expand Down Expand Up @@ -501,6 +510,24 @@ def _compute_format_param(self, context: OperationContext) -> None:
context.format_var_name = param.var_name
break

def _compute_force_msgpack_query(self, context: OperationContext, raw_operation: Schema, spec: Schema) -> None:
"""Detect if the raw spec constrains query format to only 'msgpack' and mark for implicit query injection."""
params = raw_operation.get(constants.OperationKey.PARAMETERS, []) or []
for param_def in params:
param = (
self._resolve_ref(param_def, spec) if isinstance(param_def, dict) and "$ref" in param_def else param_def
)
if not isinstance(param, dict):
continue
name = param.get("name")
location = param.get(constants.OperationKey.IN, constants.ParamLocation.QUERY)
if location == constants.ParamLocation.QUERY and name == constants.FORMAT_PARAM_NAME:
schema_obj = param.get("schema", {}) or {}
enum_vals = schema_obj.get(constants.SchemaKey.ENUM)
if isinstance(enum_vals, list) and len(enum_vals) == 1 and enum_vals[0] == "msgpack":
context.force_msgpack_query = True
return

def _compute_import_types(self, context: OperationContext) -> None:
"""Collect model types that need importing."""
builtin_types = constants.TS_BUILTIN_TYPES
Expand Down
36 changes: 24 additions & 12 deletions api/oas_generator/ts_oas_generator/templates/apis/service.ts.j2
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { BaseHttpRequest, ApiRequestOptions } from '../core/base-http-request';
import { AlgorandSerializer } from '../core/model-runtime';
import type { BodyFormat } from '../core/model-runtime';
{% if import_types and import_types|length > 0 %}
{% set sorted = import_types | sort %}
import type { {{ sorted | join(', ') }} } from '../models/index';
Expand All @@ -8,13 +9,13 @@ import { {% for t in sorted %}{{ t }}Meta{% if not loop.last %}, {% endif %}{% e

{% macro field_type_meta(type_name) -%}
{%- if type_name in import_types -%}
{ kind: 'model', meta: () => {{ type_name }}Meta }
({ kind: 'model', meta: () => {{ type_name }}Meta } as const)
{%- elif type_name == 'SignedTransaction' -%}
{ kind: 'codec', codecKey: 'SignedTransaction' }
({ kind: 'codec', codecKey: 'SignedTransaction' } as const)
{%- elif type_name == 'Uint8Array' -%}
{ kind: 'scalar', isBytes: true }
({ kind: 'scalar', isBytes: true } as const)
{%- elif type_name == 'bigint' -%}
{ kind: 'scalar', isBigint: true }
({ kind: 'scalar', isBigint: true } as const)
{%- else -%}
null
{%- endif -%}
Expand All @@ -33,10 +34,10 @@ null
null
{%- else -%}
{%- set item = field_type_meta(base) -%}
{%- if item == 'null' -%}
{%- if item == 'null' -%}
null
{%- else -%}
({ name: '{{ base }}[]', kind: 'array', arrayItems: {{ item }} })
({ name: '{{ base }}[]', kind: 'array', arrayItems: {{ item }} } as const)
{%- endif -%}
{%- endif -%}
{%- endmacro %}
Expand All @@ -49,11 +50,11 @@ null
{%- if t in import_types -%}
{{ t }}Meta
{%- elif t == 'SignedTransaction' -%}
({ name: 'SignedTransaction', kind: 'passthrough', codecKey: 'SignedTransaction' })
({ name: 'SignedTransaction', kind: 'passthrough', codecKey: 'SignedTransaction' } as const)
{%- elif t == 'Uint8Array' -%}
({ name: 'Uint8Array', kind: 'passthrough', passThrough: { kind: 'scalar', isBytes: true } })
({ name: 'Uint8Array', kind: 'passthrough', passThrough: { kind: 'scalar', isBytes: true } as const } as const)
{%- elif t == 'bigint' -%}
({ name: 'bigint', kind: 'passthrough', passThrough: { kind: 'scalar', isBigint: true } })
({ name: 'bigint', kind: 'passthrough', passThrough: { kind: 'scalar', isBigint: true } as const } as const)
{%- else -%}
null
{%- endif -%}
Expand All @@ -72,6 +73,14 @@ undefined
export class {{ service_class_name }} {
constructor(public readonly httpRequest: BaseHttpRequest) {}

private static acceptFor(format: BodyFormat): string {
return format === 'json' ? 'application/json' : 'application/msgpack';
}

private static mediaFor(format: BodyFormat): string {
return format === 'json' ? 'application/json' : 'application/msgpack';
}

{% for op in operations %}
{% set is_raw_bytes_body = op.requestBody and op.requestBody.tsType == 'Uint8Array' %}
{{ op.description | ts_doc_comment }}
Expand All @@ -93,8 +102,8 @@ export class {{ service_class_name }} {
): Promise<{{ op.responseTsType }}> {
const headers: Record<string, string> = {};
{% set supports_msgpack = op.returnsMsgpack or (op.requestBody and op.requestBody.supportsMsgpack) %}
const responseFormat: 'json' | 'msgpack' = {% if supports_msgpack %}{% if op.hasFormatParam and op.formatVarName %}(params?.{{ op.formatVarName }} as 'json' | 'msgpack' | undefined) ?? 'msgpack'{% else %}'msgpack'{% endif %}{% else %}'json'{% endif %};
headers['Accept'] = responseFormat === 'json' ? 'application/json' : 'application/msgpack';
const responseFormat: BodyFormat = {% if supports_msgpack %}{% if op.hasFormatParam and op.formatVarName %}(params?.{{ op.formatVarName }} as BodyFormat | undefined) ?? 'msgpack'{% else %}'msgpack'{% endif %}{% else %}'json'{% endif %};
headers['Accept'] = {{ service_class_name }}.acceptFor(responseFormat);

{% if op.requestBody and op.method.upper() not in ['GET', 'HEAD', 'DELETE'] %}
{% if is_raw_bytes_body %}
Expand All @@ -103,7 +112,7 @@ export class {{ service_class_name }} {
headers['Content-Type'] = mediaType;
{% else %}
const bodyMeta = {{ meta_expr(op.requestBody.tsType) }};
const mediaType = bodyMeta ? (responseFormat === 'json' ? 'application/json' : 'application/msgpack') : undefined;
const mediaType = bodyMeta ? {{ service_class_name }}.mediaFor(responseFormat) : undefined;
if (mediaType) headers['Content-Type'] = mediaType;
const serializedBody = bodyMeta && params?.body !== undefined
? AlgorandSerializer.encode(params.body, bodyMeta, responseFormat)
Expand Down Expand Up @@ -131,6 +140,9 @@ export class {{ service_class_name }} {
{%- for p in query_params %}
'{{ p.name }}': {% if p.stringifyBigInt %}(typeof params?.{{ p.varName }} === 'bigint' ? (params!.{{ p.varName }} as bigint).toString() : params?.{{ p.varName }}){% else %}params?.{{ p.varName }}{% endif %},
{%- endfor %}
{%- if op.forceMsgpackQuery %}
'format': 'msgpack',
{%- endif %}
},
headers,
{% if op.requestBody and op.method.upper() not in ['GET', 'HEAD', 'DELETE'] %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,25 @@ export async function request<T>(config: ClientConfig, options: {
headers['Authorization'] = `Basic ${btoa(`${config.username}:${config.password}`)}`;
}

let body: BodyValue | undefined = undefined;
let bodyPayload: BodyInit | undefined = undefined;
if (options.body != null) {
if (options.body instanceof Uint8Array || typeof options.body === 'string') {
body = options.body;
if (options.body instanceof Uint8Array) {
bodyPayload = options.body;
} else if (typeof options.body === 'string') {
bodyPayload = options.body;
} else if (options.mediaType?.includes('msgpack')) {
body = encodeMsgPack(options.body);
bodyPayload = encodeMsgPack(options.body);
} else if (options.mediaType?.includes('json')) {
body = JSON.stringify(options.body);
bodyPayload = JSON.stringify(options.body);
} else {
body = options.body;
bodyPayload = JSON.stringify(options.body);
}
}

const response = await fetch(url.toString(), {
method: options.method,
headers,
body,
body: bodyPayload,
credentials: config.credentials,
});

Expand Down
29 changes: 5 additions & 24 deletions crates/algod_client/src/apis/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,12 @@ impl AlgodClient {
&self,
address: &str,
max: Option<u64>,
format: Option<Format>,
) -> Result<GetPendingTransactionsByAddress, Error> {
let result =
super::get_pending_transactions_by_address::get_pending_transactions_by_address(
self.http_client.as_ref(),
address,
max,
format,
)
.await;

Expand All @@ -224,11 +222,9 @@ impl AlgodClient {
&self,
round: u64,
header_only: Option<bool>,
format: Option<Format>,
) -> Result<GetBlock, Error> {
let result =
super::get_block::get_block(self.http_client.as_ref(), round, header_only, format)
.await;
super::get_block::get_block(self.http_client.as_ref(), round, header_only).await;

result
}
Expand Down Expand Up @@ -440,12 +436,10 @@ impl AlgodClient {
pub async fn get_pending_transactions(
&self,
max: Option<u64>,
format: Option<Format>,
) -> Result<GetPendingTransactions, Error> {
let result = super::get_pending_transactions::get_pending_transactions(
self.http_client.as_ref(),
max,
format,
)
.await;

Expand All @@ -456,30 +450,21 @@ impl AlgodClient {
pub async fn pending_transaction_information(
&self,
txid: &str,
format: Option<Format>,
) -> Result<PendingTransactionResponse, Error> {
let result = super::pending_transaction_information::pending_transaction_information(
self.http_client.as_ref(),
txid,
format,
)
.await;

result
}

/// Get a LedgerStateDelta object for a given round
pub async fn get_ledger_state_delta(
&self,
round: u64,
format: Option<Format>,
) -> Result<LedgerStateDelta, Error> {
let result = super::get_ledger_state_delta::get_ledger_state_delta(
self.http_client.as_ref(),
round,
format,
)
.await;
pub async fn get_ledger_state_delta(&self, round: u64) -> Result<LedgerStateDelta, Error> {
let result =
super::get_ledger_state_delta::get_ledger_state_delta(self.http_client.as_ref(), round)
.await;

result
}
Expand All @@ -488,12 +473,10 @@ impl AlgodClient {
pub async fn get_transaction_group_ledger_state_deltas_for_round(
&self,
round: u64,
format: Option<Format>,
) -> Result<GetTransactionGroupLedgerStateDeltasForRound, Error> {
let result = super::get_transaction_group_ledger_state_deltas_for_round::get_transaction_group_ledger_state_deltas_for_round(
self.http_client.as_ref(),
round,
format,
).await;

result
Expand All @@ -503,12 +486,10 @@ impl AlgodClient {
pub async fn get_ledger_state_delta_for_transaction_group(
&self,
id: &str,
format: Option<Format>,
) -> Result<LedgerStateDelta, Error> {
let result = super::get_ledger_state_delta_for_transaction_group::get_ledger_state_delta_for_transaction_group(
self.http_client.as_ref(),
id,
format,
).await;

result
Expand Down
Loading