Skip to content

Commit aa4bd06

Browse files
mjambonclaude
andauthored
atdts: add <json repr="object"> for sum types (#488)
Tagged variants are encoded as single-key JSON objects {"Constructor": payload} instead of the default two-element array ["Constructor", payload]. This matches the default Rust/Serde externally-tagged encoding. It also reads naturally in YAML as a single-key mapping, which is one motivation for the feature. Unit variants (no payload) are always encoded as plain strings regardless of the repr annotation. Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent e6e7c7a commit aa4bd06

5 files changed

Lines changed: 157 additions & 13 deletions

File tree

atdts/src/lib/Codegen.ml

Lines changed: 70 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -920,43 +920,70 @@ let make_type_def env (def : A.type_def) : B.t =
920920
| Wrap (loc, e, an) -> assert false
921921
| Tvar _ -> assert false
922922

923-
let read_case env loc orig_name an opt_e =
923+
let read_case env loc orig_name an opt_e ~json_sum_repr =
924924
let json_name = Atd.Json.get_json_cons orig_name an in
925925
match opt_e with
926926
| None ->
927+
(* Unit variants are always plain strings, regardless of sum repr. *)
927928
[
928929
Line (sprintf "case '%s':" (single_esc json_name));
929930
Block [
930931
Line (sprintf "return { kind: '%s' }" (single_esc orig_name))
931932
]
932933
]
933934
| Some e ->
935+
(* Tagged variants (with payload).
936+
Array repr: the payload is at x[1] in the two-element array.
937+
Object repr: the payload is the value under the constructor's key
938+
e.g. {"Circle": 3.14} -> x['Circle']
939+
The object encoding matches the Rust/Serde default externally-tagged
940+
format and is also natural YAML syntax. *)
941+
let value_expr = match json_sum_repr with
942+
| Atd.Json.Array ->
943+
sprintf "%s(x[1], x)" (json_reader env e)
944+
| Atd.Json.Object ->
945+
sprintf "%s(x['%s'], x)" (json_reader env e) (single_esc json_name)
946+
in
934947
[
935948
Line (sprintf "case '%s':" (single_esc json_name));
936949
Block [
937-
Line (sprintf "return { kind: '%s', value: %s(x[1], x) }"
950+
Line (sprintf "return { kind: '%s', value: %s }"
938951
(single_esc orig_name)
939-
(json_reader env e))
952+
value_expr)
940953
]
941954
]
942955

943-
let write_case env loc orig_name an opt_e =
956+
let write_case env loc orig_name an opt_e ~json_sum_repr =
944957
let json_name = Atd.Json.get_json_cons orig_name an in
945958
match opt_e with
946959
| None ->
960+
(* Unit variants are always plain strings, regardless of sum repr. *)
947961
[
948962
Line (sprintf "case '%s':" (single_esc orig_name));
949963
Block [
950964
Line (sprintf "return '%s'" (single_esc json_name))
951965
]
952966
]
953967
| Some e ->
968+
(* Tagged variants (with payload).
969+
Array repr (default): ["Constructor", payload]
970+
Object repr: {"Constructor": payload}
971+
This is the Rust/Serde externally-tagged default encoding,
972+
and also reads naturally as a YAML single-key mapping. *)
973+
let return_expr = match json_sum_repr with
974+
| Atd.Json.Array ->
975+
sprintf "return ['%s', %s(x.value, x)]"
976+
(single_esc json_name)
977+
(json_writer env e)
978+
| Atd.Json.Object ->
979+
sprintf "return { '%s': %s(x.value, x) }"
980+
(single_esc json_name)
981+
(json_writer env e)
982+
in
954983
[
955984
Line (sprintf "case '%s':" (single_esc orig_name));
956985
Block [
957-
Line (sprintf "return ['%s', %s(x.value, x)]"
958-
(single_esc json_name)
959-
(json_writer env e))
986+
Line return_expr
960987
]
961988
]
962989

@@ -967,13 +994,15 @@ let read_root_expr env ~ts_type_name e =
967994
let cases0, cases1 =
968995
List.partition (fun (loc, orig_name, an, opt_e) -> opt_e = None) cases
969996
in
997+
(* Determine the encoding for tagged (payload-carrying) variants. *)
998+
let json_sum_repr = (Atd.Json.get_json_sum an).json_sum_repr in
970999
let part0 =
9711000
[
9721001
Line "switch (x) {";
9731002
Block (
9741003
List.map
9751004
(fun (loc, orig_name, an, opt_e) ->
976-
read_case env loc orig_name an opt_e
1005+
read_case env loc orig_name an opt_e ~json_sum_repr
9771006
) cases0
9781007
|> List.flatten
9791008
);
@@ -988,14 +1017,18 @@ let read_root_expr env ~ts_type_name e =
9881017
Line "}";
9891018
]
9901019
in
1020+
(* Build the block that reads tagged variants, switching on encoding. *)
9911021
let part1 =
992-
[
1022+
match json_sum_repr with
1023+
| Atd.Json.Array ->
1024+
(* Default: ["Constructor", payload] *)
1025+
[
9931026
Line "_atd_check_json_tuple(2, x, context)";
9941027
Line "switch (x[0]) {";
9951028
Block (
9961029
List.map
9971030
(fun (loc, orig_name, an, opt_e) ->
998-
read_case env loc orig_name an opt_e
1031+
read_case env loc orig_name an opt_e ~json_sum_repr
9991032
) cases1
10001033
|> List.flatten
10011034
);
@@ -1008,7 +1041,31 @@ let read_root_expr env ~ts_type_name e =
10081041
]
10091042
];
10101043
Line "}";
1011-
]
1044+
]
1045+
| Atd.Json.Object ->
1046+
(* Object encoding: {"Constructor": payload}
1047+
This is the Rust/Serde default externally-tagged encoding
1048+
and reads naturally as a YAML single-key mapping. *)
1049+
[
1050+
Line "const key = Object.keys(x)[0];";
1051+
Line "switch (key) {";
1052+
Block (
1053+
List.map
1054+
(fun (loc, orig_name, an, opt_e) ->
1055+
read_case env loc orig_name an opt_e ~json_sum_repr
1056+
) cases1
1057+
|> List.flatten
1058+
);
1059+
Block [
1060+
Line "default:";
1061+
Block [
1062+
Line (sprintf "_atd_bad_json('%s', x, context)"
1063+
(single_esc ts_type_name));
1064+
Line impossible
1065+
]
1066+
];
1067+
Line "}";
1068+
]
10121069
in
10131070
(match cases0, cases1 with
10141071
| _, [] -> (* pure enum *)
@@ -1090,10 +1147,11 @@ let write_root_expr env ~ts_type_name e =
10901147
match e with
10911148
| Sum (loc, variants, an) ->
10921149
let cases = flatten_variants variants in
1150+
let json_sum_repr = (Atd.Json.get_json_sum an).json_sum_repr in
10931151
[
10941152
Line "switch (x.kind) {";
10951153
Block (List.map (fun (loc, orig_name, an, opt_e) ->
1096-
Inline (write_case env loc orig_name an opt_e)
1154+
Inline (write_case env loc orig_name an opt_e ~json_sum_repr)
10971155
) cases);
10981156
Line "}";
10991157
]

atdts/test/gen-expect-tests/everything.atd

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,17 @@ type root = {
5656

5757
type alias = int list
5858

59+
(* Test for <json repr="object"> on sum types.
60+
Tagged variants are encoded as single-key JSON objects {"Constructor": payload}
61+
instead of the default two-element array ["Constructor", payload].
62+
This matches the default Rust/Serde externally-tagged encoding and
63+
also maps naturally to YAML (each variant is a single-key mapping). *)
64+
type shape = [
65+
| Circle of float (* radius *)
66+
| Square of float (* side length *)
67+
| Point (* unit variant -- still encoded as a plain string *)
68+
] <json repr="object">
69+
5970
type pair = (string * int)
6071

6172
type foo = {

atdts/test/gen-expect-tests/everything.ts.expected

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@ export type Root = {
5858

5959
export type Alias = number /*int*/[]
6060

61+
export type Shape =
62+
| { kind: 'Circle'; value: number }
63+
| { kind: 'Square'; value: number }
64+
| { kind: 'Point' }
65+
6166
export type Pair = [string, number /*int*/]
6267

6368
export type Foo = {
@@ -210,6 +215,41 @@ export function readAlias(x: any, context: any = x): Alias {
210215
return _atd_read_array(_atd_read_int)(x, context);
211216
}
212217

218+
export function writeShape(x: Shape, context: any = x): any {
219+
switch (x.kind) {
220+
case 'Circle':
221+
return { 'Circle': _atd_write_float(x.value, x) }
222+
case 'Square':
223+
return { 'Square': _atd_write_float(x.value, x) }
224+
case 'Point':
225+
return 'Point'
226+
}
227+
}
228+
229+
export function readShape(x: any, context: any = x): Shape {
230+
if (typeof x === 'string') {
231+
switch (x) {
232+
case 'Point':
233+
return { kind: 'Point' }
234+
default:
235+
_atd_bad_json('Shape', x, context)
236+
throw new Error('impossible')
237+
}
238+
}
239+
else {
240+
const key = Object.keys(x)[0];
241+
switch (key) {
242+
case 'Circle':
243+
return { kind: 'Circle', value: _atd_read_float(x['Circle'], x) }
244+
case 'Square':
245+
return { kind: 'Square', value: _atd_read_float(x['Square'], x) }
246+
default:
247+
_atd_bad_json('Shape', x, context)
248+
throw new Error('impossible')
249+
}
250+
}
251+
}
252+
213253
export function writePair(x: Pair, context: any = x): any {
214254
return ((x, context) => [_atd_write_string(x[0], x), _atd_write_int(x[1], x)])(x, context);
215255
}

atdts/test/ts-tests/test_atdts.ts

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,5 +248,41 @@ function test_import_alias() {
248248
assert(ext_types.writeTag(obj2.tag) === "renamed", "tag round-trip failed")
249249
}
250250

251+
function test_sum_repr_object() {
252+
// With <json repr="object">, tagged variants (those carrying a payload)
253+
// are encoded as single-key JSON objects {"Constructor": payload} instead
254+
// of the default two-element array ["Constructor", payload].
255+
// This matches the default Rust/Serde externally-tagged encoding and is
256+
// also natural YAML syntax (a single-key mapping per variant).
257+
// Unit variants (no payload) remain plain strings in all cases.
258+
259+
// Encoding
260+
const circle: API.Shape = { kind: 'Circle', value: 3.14 }
261+
const square: API.Shape = { kind: 'Square', value: 2.0 }
262+
const point: API.Shape = { kind: 'Point' }
263+
264+
assert(JSON.stringify(API.writeShape(circle)) === '{"Circle":3.14}',
265+
'Circle encoding failed')
266+
assert(JSON.stringify(API.writeShape(square)) === '{"Square":2}',
267+
'Square encoding failed')
268+
assert(JSON.stringify(API.writeShape(point)) === '"Point"',
269+
'Point (unit variant) should be a plain string')
270+
271+
// Round-trip decoding
272+
const c2 = API.readShape(JSON.parse('{"Circle":1.0}'))
273+
assert(c2.kind === 'Circle', 'Circle decode: wrong kind')
274+
assert((c2 as {kind: 'Circle'; value: number}).value === 1.0, 'Circle decode: wrong value')
275+
276+
const p2 = API.readShape(JSON.parse('"Point"'))
277+
assert(p2.kind === 'Point', 'Point decode: wrong kind')
278+
279+
// Error on unknown constructor
280+
let threw = false
281+
try { API.readShape(JSON.parse('{"Triangle":3}')) }
282+
catch (_) { threw = true }
283+
assert(threw, 'Expected error for unknown constructor')
284+
}
285+
251286
test_everything()
252287
test_import_alias()
288+
test_sum_repr_object()

internal/support_matrix.ml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@ let languages : (string * lang_support) list = [
115115
binary_serialization = No;
116116
};
117117
"atdts (TypeScript)", { all_yes with
118-
sum_repr_object = Planned;
119118
json_adapter = Planned;
120119
open_enums = Planned;
121120
binary_serialization = No;

0 commit comments

Comments
 (0)