@@ -1143,7 +1143,8 @@ let alias_wrapper env name type_expr codegen_type =
11431143 | _ -> []
11441144
11451145
1146- let case_class env type_name (loc, orig_name, unique_name, an, opt_e) case_classes =
1146+ let case_class env type_name (loc, orig_name, unique_name, an, opt_e)
1147+ case_classes ~json_sum_repr =
11471148 let json_name = Atd.Json.get_json_cons orig_name an in
11481149 match case_classes with
11491150 | Declaration -> (match opt_e with
@@ -1171,6 +1172,7 @@ let case_class env type_name (loc, orig_name, unique_name, an, opt_e) case_clas
11711172 ])
11721173 | Definition -> (match opt_e with
11731174 | None ->
1175+ (* Unit variants are always encoded as plain strings. *)
11741176 [
11751177 Line (sprintf "void %s::to_json(const %s &e, rapidjson::Writer<rapidjson::StringBuffer> &writer){" (trans env orig_name) (trans env orig_name));
11761178 Block [
@@ -1179,14 +1181,30 @@ let case_class env type_name (loc, orig_name, unique_name, an, opt_e) case_clas
11791181 Line (sprintf "};");
11801182 ]
11811183 | Some e ->
1184+ (* Tagged variants (with payload).
1185+ Array repr (default): ["Constructor", payload]
1186+ Object repr: {"Constructor": payload}
1187+ This is the Rust/Serde default externally-tagged encoding
1188+ and also maps naturally to YAML as a single-key mapping. *)
1189+ let body = match json_sum_repr with
1190+ | Atd.Json.Array ->
1191+ [
1192+ Line (sprintf "writer.StartArray();");
1193+ Line (sprintf "writer.String(\"%s\");" (single_esc json_name));
1194+ Line (sprintf "%se.value, writer);" (json_writer env e));
1195+ Line (sprintf "writer.EndArray();");
1196+ ]
1197+ | Atd.Json.Object ->
1198+ [
1199+ Line (sprintf "writer.StartObject();");
1200+ Line (sprintf "writer.Key(\"%s\");" (single_esc json_name));
1201+ Line (sprintf "%se.value, writer);" (json_writer env e));
1202+ Line (sprintf "writer.EndObject();");
1203+ ]
1204+ in
11821205 [
11831206 Line (sprintf "void %s::to_json(const %s &e, rapidjson::Writer<rapidjson::StringBuffer> &writer){" (trans env orig_name) (trans env orig_name));
1184- Block [
1185- Line (sprintf "writer.StartArray();");
1186- Line (sprintf "writer.String(\"%s\");" (single_esc json_name));
1187- Line (sprintf "%se.value, writer);" (json_writer env e));
1188- Line (sprintf "writer.EndArray();");
1189- ];
1207+ Block body;
11901208 Line("}");
11911209 ])
11921210 | _ -> []
@@ -1212,7 +1230,11 @@ let read_cases0 env loc name cases0 sum_repr =
12121230 (struct_name env name |> single_esc))
12131231 ]
12141232
1215- let read_cases1 env loc name cases1 =
1233+ let read_cases1 env loc name cases1 json_sum_repr =
1234+ (* How the payload value is accessed depends on the sum repr:
1235+ Array: x[1] -- value is second element of ["Constructor", payload]
1236+ Object: x["Constructor"] -- value is the field under the constructor key
1237+ (cons is the variable holding the key name) *)
12161238 let ifs =
12171239 cases1
12181240 |> List.map (fun (loc, orig_name, unique_name, an, opt_e) ->
@@ -1222,12 +1244,20 @@ let read_cases1 env loc name cases1 =
12221244 | Some x -> x
12231245 in
12241246 let json_name = Atd.Json.get_json_cons orig_name an in
1247+ let value_expr = match json_sum_repr with
1248+ | Atd.Json.Array ->
1249+ sprintf "%sx[1])}" (json_reader env e)
1250+ | Atd.Json.Object ->
1251+ (* Use the literal key rather than the 'cons' variable for
1252+ clarity, even though both would work after the if-guard. *)
1253+ sprintf "%sx[\"%s\"])}" (json_reader env e) (single_esc json_name)
1254+ in
12251255 Inline [
12261256 Line (sprintf "if (cons == \"%s\")" (single_esc json_name));
12271257 Block [
1228- Line (sprintf "return Types::%s({%sx[1])} );"
1258+ Line (sprintf "return Types::%s({%s );"
12291259 (trans env orig_name)
1230- (json_reader env e) )
1260+ value_expr )
12311261 ]
12321262 ]
12331263 )
@@ -1238,7 +1268,7 @@ let read_cases1 env loc name cases1 =
12381268 (struct_name env name |> single_esc))
12391269 ]
12401270
1241- let sum_container env loc name cases codegen_type =
1271+ let sum_container env loc name cases codegen_type ~json_sum_repr =
12421272 let cpp_struct_name = struct_name env name in
12431273 let cases0, cases1 =
12441274 List.partition (fun (loc, orig_name, unique_name, an, opt_e) ->
@@ -1248,23 +1278,42 @@ let sum_container env loc name cases codegen_type =
12481278 let cases0_block =
12491279 if cases0 <> [] then
12501280 [
1281+ (* Unit variants are always encoded as plain strings. *)
12511282 Line "if (x.IsString()) {";
12521283 Block (read_cases0 env loc name cases0 Cpp_annot.Variant);
12531284 Line "}";
12541285 ]
12551286 else
12561287 []
12571288 in
1289+ (* Determine how tagged variants are decoded based on the sum repr.
1290+ Array (default): ["Constructor", payload]
1291+ The tag is x[0].GetString() and the payload is x[1].
1292+ Object: {"Constructor": payload}
1293+ This is the Rust/Serde default externally-tagged encoding and also
1294+ maps naturally to YAML. The tag is the sole member name, obtained
1295+ via MemberBegin()->name. *)
12581296 let cases1_block =
12591297 if cases1 <> [] then
1260- [
1261- Line "if (x.IsArray() && x.Size() == 2 && x[0].IsString()) {";
1262- Block [
1263- Line "std::string cons = x[0].GetString();";
1264- Inline (read_cases1 env loc name cases1)
1265- ];
1266- Line "}";
1267- ]
1298+ match json_sum_repr with
1299+ | Atd.Json.Array ->
1300+ [
1301+ Line "if (x.IsArray() && x.Size() == 2 && x[0].IsString()) {";
1302+ Block [
1303+ Line "std::string cons = x[0].GetString();";
1304+ Inline (read_cases1 env loc name cases1 Atd.Json.Array)
1305+ ];
1306+ Line "}";
1307+ ]
1308+ | Atd.Json.Object ->
1309+ [
1310+ Line "if (x.IsObject() && x.MemberCount() == 1) {";
1311+ Block [
1312+ Line "std::string cons = x.MemberBegin()->name.GetString();";
1313+ Inline (read_cases1 env loc name cases1 Atd.Json.Object)
1314+ ];
1315+ Line "}";
1316+ ]
12681317 else
12691318 []
12701319 in
@@ -1318,7 +1367,8 @@ let sum_container env loc name cases codegen_type =
13181367 ]
13191368 | _ -> []
13201369
1321- let sum env loc name cases codegen_type =
1370+ let sum env loc name cases an codegen_type =
1371+ let json_sum_repr = (Atd.Json.get_json_sum an).json_sum_repr in
13221372 let cases =
13231373 List.map (fun (x : variant) ->
13241374 match x with
@@ -1328,10 +1378,10 @@ let sum env loc name cases codegen_type =
13281378 | Inherit _ -> assert false
13291379 ) cases
13301380 in
1331- let case_classes =
1332- List.map (fun x -> Inline (case_class env name x codegen_type)) cases
1381+ let case_classes =
1382+ List.map (fun x -> Inline (case_class env name x codegen_type ~json_sum_repr )) cases
13331383 in
1334- let container_class = sum_container env loc name cases codegen_type in
1384+ let container_class = sum_container env loc name cases codegen_type ~json_sum_repr in
13351385 match codegen_type with
13361386 | Declaration ->
13371387 [
@@ -1494,7 +1544,7 @@ let type_def env (def : A.type_def) codegen_type : B.t =
14941544 match e with
14951545 | Sum (loc, cases, an) ->
14961546 (match (Cpp_annot.get_cpp_sumtype_repr an) with
1497- | Variant -> sum env loc name cases codegen_type
1547+ | Variant -> sum env loc name cases an codegen_type
14981548 | Enum -> enum env loc name cases codegen_type)
14991549 | Record (loc, fields, an) ->
15001550 record env loc name fields an codegen_type
0 commit comments