@@ -1103,7 +1103,7 @@ let alias_wrapper env ~class_decorators ~class_doc name type_expr =
11031103 ]
11041104 ]
11051105
1106- let case_class env ~class_decorators type_name
1106+ let case_class env ~class_decorators ~json_sum_repr type_name
11071107 (loc, orig_name, unique_name, an, opt_e) =
11081108 let json_name = Atd.Json.get_json_cons orig_name an in
11091109 let case_doc = trans_case_doc_to_docstring loc an in
@@ -1130,6 +1130,8 @@ let case_class env ~class_decorators type_name
11301130 Line " @ staticmethod";
11311131 Line " def to_json () -> Any :";
11321132 Block [
1133+ (* Unit variants (no payload) are always encoded as a plain string,
1134+ regardless of the sum repr. *)
11331135 Line (sprintf " return '% s'" (single_esc json_name))
11341136 ];
11351137 Line "";
@@ -1140,6 +1142,24 @@ let case_class env ~class_decorators type_name
11401142 ]
11411143 ]
11421144 | Some e ->
1145+ (* Tagged variants (with payload).
1146+ The encoding depends on the sum-level <json repr=" ..."> annotation:
1147+ - Array (default): [" Constructor ", payload]
1148+ e.g. [" Circle ", 3.14]
1149+ - Object: {" Constructor ": payload}
1150+ e.g. {" Circle ": 3.14}
1151+ This is the default Rust/Serde externally-tagged encoding.
1152+ It also reads naturally in YAML as a single-key mapping. *)
1153+ let to_json_line = match json_sum_repr with
1154+ | Atd.Json.Array ->
1155+ sprintf " return ['% s', % s(self .value)]"
1156+ (single_esc json_name)
1157+ (json_writer env e)
1158+ | Atd.Json.Object ->
1159+ sprintf " return {'% s' : % s(self .value)}"
1160+ (single_esc json_name)
1161+ (json_writer env e)
1162+ in
11431163 [
11441164 Inline class_decorators;
11451165 Line (sprintf " class % s:" (trans env unique_name));
@@ -1162,9 +1182,7 @@ let case_class env ~class_decorators type_name
11621182 Line "";
11631183 Line " def to_json (self ) -> Any :";
11641184 Block [
1165- Line (sprintf " return ['% s', % s(self .value)]"
1166- (single_esc json_name)
1167- (json_writer env e))
1185+ Line to_json_line
11681186 ];
11691187 Line "";
11701188 Line " def to_json_string (self , ** kw: Any ) -> str :";
@@ -1193,7 +1211,15 @@ let read_cases0 env loc name cases0 =
11931211 (class_name env name |> single_esc))
11941212 ]
11951213
1196- let read_cases1 env loc name cases1 =
1214+ let read_cases1 env loc name cases1 json_sum_repr =
1215+ (* How we read the payload value depends on the sum repr:
1216+ - Array: the value is x[1] (e.g. [" Circle ", 3.14])
1217+ - Object: the value is x[cons] (e.g. {" Circle ": 3.14}, where
1218+ 'cons' holds the single key extracted from the dict) *)
1219+ let value_expr = match json_sum_repr with
1220+ | Atd.Json.Array -> " x[1 ]"
1221+ | Atd.Json.Object -> " x[cons]"
1222+ in
11971223 let ifs =
11981224 cases1
11991225 |> List.map (fun (loc, orig_name, unique_name, an, opt_e) ->
@@ -1206,9 +1232,10 @@ let read_cases1 env loc name cases1 =
12061232 Inline [
12071233 Line (sprintf " if cons == '% s':" (single_esc json_name));
12081234 Block [
1209- Line (sprintf " return cls (% s(% s(x [ 1 ] )))"
1235+ Line (sprintf " return cls (% s(% s(% s )))"
12101236 (trans env unique_name)
1211- (json_reader env e))
1237+ (json_reader env e)
1238+ value_expr)
12121239 ]
12131240 ]
12141241 )
@@ -1235,21 +1262,39 @@ let sum_container env ~class_decorators ~class_doc loc name cases an =
12351262 let cases0_block =
12361263 if cases0 <> [] then
12371264 [
1265+ (* Unit variants are always encoded as plain strings regardless of
1266+ the sum repr annotation. *)
12381267 Line " if isinstance (x , str ):";
12391268 Block (read_cases0 env loc name cases0)
12401269 ]
12411270 else
12421271 []
12431272 in
1273+ (* Determine how tagged variants (those with a payload) are encoded.
1274+ The <json repr=" object "> annotation selects the Rust/Serde-style
1275+ externally-tagged object encoding {" Constructor ": payload}, which is
1276+ also clean in YAML. The default is the two-element array encoding
1277+ [" Constructor ", payload]. *)
1278+ let json_sum_repr = (Atd.Json.get_json_sum an).json_sum_repr in
12441279 let cases1_block =
12451280 if cases1 <> [] then
1246- [
1247- Line " if isinstance (x , List ) and len (x ) == 2 :";
1248- Block [
1249- Line " cons = x [0 ]";
1250- Inline (read_cases1 env loc name cases1)
1251- ]
1252- ]
1281+ match json_sum_repr with
1282+ | Atd.Json.Array ->
1283+ [
1284+ Line " if isinstance (x , List ) and len (x ) == 2 :";
1285+ Block [
1286+ Line " cons = x [0 ]";
1287+ Inline (read_cases1 env loc name cases1 Atd.Json.Array)
1288+ ]
1289+ ]
1290+ | Atd.Json.Object ->
1291+ [
1292+ Line " if isinstance (x , dict ) and len (x ) == 1 :";
1293+ Block [
1294+ Line " cons = next (iter (x ))";
1295+ Inline (read_cases1 env loc name cases1 Atd.Json.Object)
1296+ ]
1297+ ]
12531298 else
12541299 []
12551300 in
@@ -1306,6 +1351,7 @@ let sum_container env ~class_decorators ~class_doc loc name cases an =
13061351 ]
13071352
13081353let sum env ~class_decorators ~class_doc loc name cases an =
1354+ let json_sum_repr = (Atd.Json.get_json_sum an).json_sum_repr in
13091355 let cases =
13101356 List.map (fun (x : variant) ->
13111357 match x with
@@ -1316,7 +1362,7 @@ let sum env ~class_decorators ~class_doc loc name cases an =
13161362 ) cases
13171363 in
13181364 let case_classes =
1319- List.map (fun x -> Inline (case_class env ~class_decorators name x)) cases
1365+ List.map (fun x -> Inline (case_class env ~class_decorators ~json_sum_repr name x)) cases
13201366 |> double_spaced
13211367 in
13221368 let container_class =
0 commit comments