Skip to content

Commit e6e7c7a

Browse files
mjambonclaude
andauthored
atdpy: add <json repr="object"> for sum types (externally-tagged encoding) (#487)
* atdpy: add <json repr="object"> for sum types Tagged variants are encoded as single-key JSON objects {"Constructor": payload} instead of the default two-element array ["Constructor", payload]. This matches the default Rust/Serde externally-tagged encoding. It also reads naturally in YAML as a single-key mapping, which is one motivation for the feature. Unit variants (no payload) are always encoded as plain strings regardless of the repr annotation. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Update gitignore --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 0fc67a5 commit e6e7c7a

8 files changed

Lines changed: 208 additions & 18 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,6 @@ tmp
3333
# Where we install binaries and other things locally to make them easily
3434
# available for testing (used by atdml)
3535
local/
36+
37+
# Local Claude Code settings
38+
/.claude/settings.local.json

Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
#
2-
# Makefile for developer's convenience.
1+
2+
## Makefile for developer's convenience.
33
# Build logic is implemented with dune.
44
#
55

atd-jsonlike.opam

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ bug-reports: "https://github.com/ahrefs/atd/issues"
9191
depends: [
9292
"dune" {>= "3.18"}
9393
"ocaml" {>= "4.08"}
94+
"re" {>= "1.9.0"}
95+
"testo" {>= "0.3.0" & with-test}
9496
"odoc" {with-doc}
9597
]
9698
dev-repo: "git+https://github.com/ahrefs/atd.git"

atdpy/src/lib/Codegen.ml

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,7 @@ let alias_wrapper env ~class_decorators ~class_doc name type_expr =
11031103
]
11041104
]
11051105
1106-
let case_class env ~class_decorators type_name
1106+
let case_class env ~class_decorators ~json_sum_repr type_name
11071107
(loc, orig_name, unique_name, an, opt_e) =
11081108
let json_name = Atd.Json.get_json_cons orig_name an in
11091109
let case_doc = trans_case_doc_to_docstring loc an in
@@ -1130,6 +1130,8 @@ let case_class env ~class_decorators type_name
11301130
Line "@staticmethod";
11311131
Line "def to_json() -> Any:";
11321132
Block [
1133+
(* Unit variants (no payload) are always encoded as a plain string,
1134+
regardless of the sum repr. *)
11331135
Line (sprintf "return '%s'" (single_esc json_name))
11341136
];
11351137
Line "";
@@ -1140,6 +1142,24 @@ let case_class env ~class_decorators type_name
11401142
]
11411143
]
11421144
| Some e ->
1145+
(* Tagged variants (with payload).
1146+
The encoding depends on the sum-level <json repr="..."> annotation:
1147+
- Array (default): ["Constructor", payload]
1148+
e.g. ["Circle", 3.14]
1149+
- Object: {"Constructor": payload}
1150+
e.g. {"Circle": 3.14}
1151+
This is the default Rust/Serde externally-tagged encoding.
1152+
It also reads naturally in YAML as a single-key mapping. *)
1153+
let to_json_line = match json_sum_repr with
1154+
| Atd.Json.Array ->
1155+
sprintf "return ['%s', %s(self.value)]"
1156+
(single_esc json_name)
1157+
(json_writer env e)
1158+
| Atd.Json.Object ->
1159+
sprintf "return {'%s': %s(self.value)}"
1160+
(single_esc json_name)
1161+
(json_writer env e)
1162+
in
11431163
[
11441164
Inline class_decorators;
11451165
Line (sprintf "class %s:" (trans env unique_name));
@@ -1162,9 +1182,7 @@ let case_class env ~class_decorators type_name
11621182
Line "";
11631183
Line "def to_json(self) -> Any:";
11641184
Block [
1165-
Line (sprintf "return ['%s', %s(self.value)]"
1166-
(single_esc json_name)
1167-
(json_writer env e))
1185+
Line to_json_line
11681186
];
11691187
Line "";
11701188
Line "def to_json_string(self, **kw: Any) -> str:";
@@ -1193,7 +1211,15 @@ let read_cases0 env loc name cases0 =
11931211
(class_name env name |> single_esc))
11941212
]
11951213
1196-
let read_cases1 env loc name cases1 =
1214+
let read_cases1 env loc name cases1 json_sum_repr =
1215+
(* How we read the payload value depends on the sum repr:
1216+
- Array: the value is x[1] (e.g. ["Circle", 3.14])
1217+
- Object: the value is x[cons] (e.g. {"Circle": 3.14}, where
1218+
'cons' holds the single key extracted from the dict) *)
1219+
let value_expr = match json_sum_repr with
1220+
| Atd.Json.Array -> "x[1]"
1221+
| Atd.Json.Object -> "x[cons]"
1222+
in
11971223
let ifs =
11981224
cases1
11991225
|> List.map (fun (loc, orig_name, unique_name, an, opt_e) ->
@@ -1206,9 +1232,10 @@ let read_cases1 env loc name cases1 =
12061232
Inline [
12071233
Line (sprintf "if cons == '%s':" (single_esc json_name));
12081234
Block [
1209-
Line (sprintf "return cls(%s(%s(x[1])))"
1235+
Line (sprintf "return cls(%s(%s(%s)))"
12101236
(trans env unique_name)
1211-
(json_reader env e))
1237+
(json_reader env e)
1238+
value_expr)
12121239
]
12131240
]
12141241
)
@@ -1235,21 +1262,39 @@ let sum_container env ~class_decorators ~class_doc loc name cases an =
12351262
let cases0_block =
12361263
if cases0 <> [] then
12371264
[
1265+
(* Unit variants are always encoded as plain strings regardless of
1266+
the sum repr annotation. *)
12381267
Line "if isinstance(x, str):";
12391268
Block (read_cases0 env loc name cases0)
12401269
]
12411270
else
12421271
[]
12431272
in
1273+
(* Determine how tagged variants (those with a payload) are encoded.
1274+
The <json repr="object"> annotation selects the Rust/Serde-style
1275+
externally-tagged object encoding {"Constructor": payload}, which is
1276+
also clean in YAML. The default is the two-element array encoding
1277+
["Constructor", payload]. *)
1278+
let json_sum_repr = (Atd.Json.get_json_sum an).json_sum_repr in
12441279
let cases1_block =
12451280
if cases1 <> [] then
1246-
[
1247-
Line "if isinstance(x, List) and len(x) == 2:";
1248-
Block [
1249-
Line "cons = x[0]";
1250-
Inline (read_cases1 env loc name cases1)
1251-
]
1252-
]
1281+
match json_sum_repr with
1282+
| Atd.Json.Array ->
1283+
[
1284+
Line "if isinstance(x, List) and len(x) == 2:";
1285+
Block [
1286+
Line "cons = x[0]";
1287+
Inline (read_cases1 env loc name cases1 Atd.Json.Array)
1288+
]
1289+
]
1290+
| Atd.Json.Object ->
1291+
[
1292+
Line "if isinstance(x, dict) and len(x) == 1:";
1293+
Block [
1294+
Line "cons = next(iter(x))";
1295+
Inline (read_cases1 env loc name cases1 Atd.Json.Object)
1296+
]
1297+
]
12531298
else
12541299
[]
12551300
in
@@ -1306,6 +1351,7 @@ let sum_container env ~class_decorators ~class_doc loc name cases an =
13061351
]
13071352
13081353
let sum env ~class_decorators ~class_doc loc name cases an =
1354+
let json_sum_repr = (Atd.Json.get_json_sum an).json_sum_repr in
13091355
let cases =
13101356
List.map (fun (x : variant) ->
13111357
match x with
@@ -1316,7 +1362,7 @@ let sum env ~class_decorators ~class_doc loc name cases an =
13161362
) cases
13171363
in
13181364
let case_classes =
1319-
List.map (fun x -> Inline (case_class env ~class_decorators name x)) cases
1365+
List.map (fun x -> Inline (case_class env ~class_decorators ~json_sum_repr name x)) cases
13201366
|> double_spaced
13211367
in
13221368
let container_class =

atdpy/test/atd-input/everything.atd

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,17 @@ first line of the preformatted block
8686

8787
type alias = int list
8888

89+
(* Test for <json repr="object"> on sum types.
90+
Tagged variants are encoded as single-key JSON objects {"Constructor": payload}
91+
instead of the default two-element array ["Constructor", payload].
92+
This matches the default Rust/Serde externally-tagged encoding and
93+
also maps naturally to YAML (each variant is a single-key mapping). *)
94+
type shape = [
95+
| Circle of float (* radius *)
96+
| Square of float (* side length *)
97+
| Point (* unit variant -- still encoded as a plain string *)
98+
] <json repr="object">
99+
89100
type pair <doc text="Def-level doc"> =
90101
(string * int) <doc text="Type-level doc">
91102

atdpy/test/python-expected/everything.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,100 @@ def to_json_string(self, **kw: Any) -> str:
340340
return json.dumps(self.to_json(), **kw)
341341

342342

343+
@dataclass
344+
class Circle:
345+
"""Original type: shape = [ ... | Circle of ... | ... ]
346+
"""
347+
348+
value: float
349+
350+
@property
351+
def kind(self) -> str:
352+
"""Name of the class representing this variant."""
353+
return 'Circle'
354+
355+
def to_json(self) -> Any:
356+
return {'Circle': _atd_write_float(self.value)}
357+
358+
def to_json_string(self, **kw: Any) -> str:
359+
return json.dumps(self.to_json(), **kw)
360+
361+
362+
@dataclass
363+
class Square:
364+
"""Original type: shape = [ ... | Square of ... | ... ]
365+
"""
366+
367+
value: float
368+
369+
@property
370+
def kind(self) -> str:
371+
"""Name of the class representing this variant."""
372+
return 'Square'
373+
374+
def to_json(self) -> Any:
375+
return {'Square': _atd_write_float(self.value)}
376+
377+
def to_json_string(self, **kw: Any) -> str:
378+
return json.dumps(self.to_json(), **kw)
379+
380+
381+
@dataclass
382+
class Point:
383+
"""Original type: shape = [ ... | Point | ... ]
384+
"""
385+
386+
@property
387+
def kind(self) -> str:
388+
"""Name of the class representing this variant."""
389+
return 'Point'
390+
391+
@staticmethod
392+
def to_json() -> Any:
393+
return 'Point'
394+
395+
def to_json_string(self, **kw: Any) -> str:
396+
return json.dumps(self.to_json(), **kw)
397+
398+
399+
@dataclass
400+
class Shape:
401+
"""Original type: shape = [ ... ]
402+
"""
403+
404+
value: Union[Circle, Square, Point]
405+
406+
@property
407+
def kind(self) -> str:
408+
"""Name of the class representing this variant."""
409+
return self.value.kind
410+
411+
@classmethod
412+
def from_json(cls, x: Any) -> 'Shape':
413+
if isinstance(x, str):
414+
if x == 'Point':
415+
return cls(Point())
416+
_atd_bad_json('Shape', x)
417+
if isinstance(x, dict) and len(x) == 1:
418+
cons = next(iter(x))
419+
if cons == 'Circle':
420+
return cls(Circle(_atd_read_float(x[cons])))
421+
if cons == 'Square':
422+
return cls(Square(_atd_read_float(x[cons])))
423+
_atd_bad_json('Shape', x)
424+
_atd_bad_json('Shape', x)
425+
426+
def to_json(self) -> Any:
427+
return self.value.to_json()
428+
429+
@classmethod
430+
def from_json_string(cls, x: str) -> 'Shape':
431+
return cls.from_json(json.loads(x))
432+
433+
def to_json_string(self, **kw: Any) -> str:
434+
return json.dumps(self.to_json(), **kw)
435+
436+
343437
@dataclass
344438
class Root_:
345439
"""Original type: kind = [ ... | Root | ... ]

atdpy/test/python-tests/test_atdpy.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,5 +289,40 @@ def test_imported_types() -> None:
289289
assert b_obj.priority.value == 5
290290

291291

292+
def test_sum_repr_object() -> None:
293+
"""
294+
Test sum types with <json repr="object">.
295+
296+
With this annotation, tagged variants (those carrying a payload) are
297+
encoded as single-key JSON objects {"Constructor": payload} instead
298+
of the default two-element array ["Constructor", payload].
299+
This matches the default Rust/Serde externally-tagged encoding and
300+
is also natural YAML syntax (a single-key mapping).
301+
Unit variants (no payload) remain plain strings regardless.
302+
"""
303+
# Encoding
304+
assert e.Shape(e.Circle(3.14)).to_json() == {'Circle': 3.14}
305+
assert e.Shape(e.Square(2.0)).to_json() == {'Square': 2.0}
306+
# unit variant: always a plain string, regardless of repr
307+
assert e.Shape(e.Point()).to_json() == 'Point'
308+
309+
# Round-trip via JSON string
310+
for json_str, expected_kind in [
311+
('{"Circle": 1.0}', 'Circle'),
312+
('{"Square": 2.5}', 'Square'),
313+
('"Point"', 'Point'),
314+
]:
315+
obj = e.Shape.from_json_string(json_str)
316+
assert obj.kind == expected_kind
317+
assert obj.to_json_string() == json_str
318+
319+
# Error on unknown constructor
320+
try:
321+
e.Shape.from_json_string('{"Triangle": 3}')
322+
assert False, "Expected ValueError"
323+
except ValueError:
324+
pass
325+
326+
292327
# print updated json
293328
test_everything_to_json()

internal/support_matrix.ml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,6 @@ let languages : (string * lang_support) list = [
110110
};
111111
"atdpy (Python)", { all_yes with
112112
wrap = Planned;
113-
sum_repr_object = Planned;
114113
json_adapter = Planned;
115114
open_enums = Planned;
116115
binary_serialization = No;

0 commit comments

Comments
 (0)