Skip to content

Commit 948706a

Browse files
mjambon and claude authored
atdcpp: add <json repr="object"> for sum types (#489)
Tagged variants are encoded as single-key JSON objects {"Constructor": payload} instead of the default two-element array ["Constructor", payload]. This matches the default Rust/Serde externally-tagged encoding. It also reads naturally in YAML as a single-key mapping, which is one motivation for the feature.

Unit variants (no payload) are always encoded as plain strings regardless of the repr annotation.

The RapidJSON writer uses StartObject/Key/EndObject instead of StartArray/String/EndArray. The reader checks x.IsObject() and extracts the sole member name via MemberBegin()->name.

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 88e9272 commit 948706a

6 files changed

Lines changed: 226 additions & 26 deletions

File tree

atdcpp/src/lib/Codegen.ml

Lines changed: 74 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,7 +1143,8 @@ let alias_wrapper env name type_expr codegen_type =
11431143
| _ -> []
11441144
11451145
1146-
let case_class env type_name (loc, orig_name, unique_name, an, opt_e) case_classes =
1146+
let case_class env type_name (loc, orig_name, unique_name, an, opt_e)
1147+
case_classes ~json_sum_repr =
11471148
let json_name = Atd.Json.get_json_cons orig_name an in
11481149
match case_classes with
11491150
| Declaration -> (match opt_e with
@@ -1171,6 +1172,7 @@ let case_class env type_name (loc, orig_name, unique_name, an, opt_e) case_clas
11711172
])
11721173
| Definition -> (match opt_e with
11731174
| None ->
1175+
(* Unit variants are always encoded as plain strings. *)
11741176
[
11751177
Line (sprintf "void %s::to_json(const %s &e, rapidjson::Writer<rapidjson::StringBuffer> &writer){" (trans env orig_name) (trans env orig_name));
11761178
Block [
@@ -1179,14 +1181,30 @@ let case_class env type_name (loc, orig_name, unique_name, an, opt_e) case_clas
11791181
Line (sprintf "};");
11801182
]
11811183
| Some e ->
1184+
(* Tagged variants (with payload).
1185+
Array repr (default): ["Constructor", payload]
1186+
Object repr: {"Constructor": payload}
1187+
This is the Rust/Serde default externally-tagged encoding
1188+
and also maps naturally to YAML as a single-key mapping. *)
1189+
let body = match json_sum_repr with
1190+
| Atd.Json.Array ->
1191+
[
1192+
Line (sprintf "writer.StartArray();");
1193+
Line (sprintf "writer.String(\"%s\");" (single_esc json_name));
1194+
Line (sprintf "%se.value, writer);" (json_writer env e));
1195+
Line (sprintf "writer.EndArray();");
1196+
]
1197+
| Atd.Json.Object ->
1198+
[
1199+
Line (sprintf "writer.StartObject();");
1200+
Line (sprintf "writer.Key(\"%s\");" (single_esc json_name));
1201+
Line (sprintf "%se.value, writer);" (json_writer env e));
1202+
Line (sprintf "writer.EndObject();");
1203+
]
1204+
in
11821205
[
11831206
Line (sprintf "void %s::to_json(const %s &e, rapidjson::Writer<rapidjson::StringBuffer> &writer){" (trans env orig_name) (trans env orig_name));
1184-
Block [
1185-
Line (sprintf "writer.StartArray();");
1186-
Line (sprintf "writer.String(\"%s\");" (single_esc json_name));
1187-
Line (sprintf "%se.value, writer);" (json_writer env e));
1188-
Line (sprintf "writer.EndArray();");
1189-
];
1207+
Block body;
11901208
Line("}");
11911209
])
11921210
| _ -> []
@@ -1212,7 +1230,11 @@ let read_cases0 env loc name cases0 sum_repr =
12121230
(struct_name env name |> single_esc))
12131231
]
12141232
1215-
let read_cases1 env loc name cases1 =
1233+
let read_cases1 env loc name cases1 json_sum_repr =
1234+
(* How the payload value is accessed depends on the sum repr:
1235+
Array: x[1] -- value is second element of ["Constructor", payload]
1236+
Object: x["Constructor"] -- value is the field under the constructor key
1237+
(cons is the variable holding the key name) *)
12161238
let ifs =
12171239
cases1
12181240
|> List.map (fun (loc, orig_name, unique_name, an, opt_e) ->
@@ -1222,12 +1244,20 @@ let read_cases1 env loc name cases1 =
12221244
| Some x -> x
12231245
in
12241246
let json_name = Atd.Json.get_json_cons orig_name an in
1247+
let value_expr = match json_sum_repr with
1248+
| Atd.Json.Array ->
1249+
sprintf "%sx[1])}" (json_reader env e)
1250+
| Atd.Json.Object ->
1251+
(* Use the literal key rather than the 'cons' variable for
1252+
clarity, even though both would work after the if-guard. *)
1253+
sprintf "%sx[\"%s\"])}" (json_reader env e) (single_esc json_name)
1254+
in
12251255
Inline [
12261256
Line (sprintf "if (cons == \"%s\")" (single_esc json_name));
12271257
Block [
1228-
Line (sprintf "return Types::%s({%sx[1])});"
1258+
Line (sprintf "return Types::%s({%s);"
12291259
(trans env orig_name)
1230-
(json_reader env e))
1260+
value_expr)
12311261
]
12321262
]
12331263
)
@@ -1238,7 +1268,7 @@ let read_cases1 env loc name cases1 =
12381268
(struct_name env name |> single_esc))
12391269
]
12401270
1241-
let sum_container env loc name cases codegen_type =
1271+
let sum_container env loc name cases codegen_type ~json_sum_repr =
12421272
let cpp_struct_name = struct_name env name in
12431273
let cases0, cases1 =
12441274
List.partition (fun (loc, orig_name, unique_name, an, opt_e) ->
@@ -1248,23 +1278,42 @@ let sum_container env loc name cases codegen_type =
12481278
let cases0_block =
12491279
if cases0 <> [] then
12501280
[
1281+
(* Unit variants are always encoded as plain strings. *)
12511282
Line "if (x.IsString()) {";
12521283
Block (read_cases0 env loc name cases0 Cpp_annot.Variant);
12531284
Line "}";
12541285
]
12551286
else
12561287
[]
12571288
in
1289+
(* Determine how tagged variants are decoded based on the sum repr.
1290+
Array (default): ["Constructor", payload]
1291+
The tag is x[0].GetString() and the payload is x[1].
1292+
Object: {"Constructor": payload}
1293+
This is the Rust/Serde default externally-tagged encoding and also
1294+
maps naturally to YAML. The tag is the sole member name, obtained
1295+
via MemberBegin()->name. *)
12581296
let cases1_block =
12591297
if cases1 <> [] then
1260-
[
1261-
Line "if (x.IsArray() && x.Size() == 2 && x[0].IsString()) {";
1262-
Block [
1263-
Line "std::string cons = x[0].GetString();";
1264-
Inline (read_cases1 env loc name cases1)
1265-
];
1266-
Line "}";
1267-
]
1298+
match json_sum_repr with
1299+
| Atd.Json.Array ->
1300+
[
1301+
Line "if (x.IsArray() && x.Size() == 2 && x[0].IsString()) {";
1302+
Block [
1303+
Line "std::string cons = x[0].GetString();";
1304+
Inline (read_cases1 env loc name cases1 Atd.Json.Array)
1305+
];
1306+
Line "}";
1307+
]
1308+
| Atd.Json.Object ->
1309+
[
1310+
Line "if (x.IsObject() && x.MemberCount() == 1) {";
1311+
Block [
1312+
Line "std::string cons = x.MemberBegin()->name.GetString();";
1313+
Inline (read_cases1 env loc name cases1 Atd.Json.Object)
1314+
];
1315+
Line "}";
1316+
]
12681317
else
12691318
[]
12701319
in
@@ -1318,7 +1367,8 @@ let sum_container env loc name cases codegen_type =
13181367
]
13191368
| _ -> []
13201369
1321-
let sum env loc name cases codegen_type =
1370+
let sum env loc name cases an codegen_type =
1371+
let json_sum_repr = (Atd.Json.get_json_sum an).json_sum_repr in
13221372
let cases =
13231373
List.map (fun (x : variant) ->
13241374
match x with
@@ -1328,10 +1378,10 @@ let sum env loc name cases codegen_type =
13281378
| Inherit _ -> assert false
13291379
) cases
13301380
in
1331-
let case_classes =
1332-
List.map (fun x -> Inline (case_class env name x codegen_type)) cases
1381+
let case_classes =
1382+
List.map (fun x -> Inline (case_class env name x codegen_type ~json_sum_repr)) cases
13331383
in
1334-
let container_class = sum_container env loc name cases codegen_type in
1384+
let container_class = sum_container env loc name cases codegen_type ~json_sum_repr in
13351385
match codegen_type with
13361386
| Declaration ->
13371387
[
@@ -1494,7 +1544,7 @@ let type_def env (def : A.type_def) codegen_type : B.t =
14941544
match e with
14951545
| Sum (loc, cases, an) ->
14961546
(match (Cpp_annot.get_cpp_sumtype_repr an) with
1497-
| Variant -> sum env loc name cases codegen_type
1547+
| Variant -> sum env loc name cases an codegen_type
14981548
| Enum -> enum env loc name cases codegen_type)
14991549
| Record (loc, fields, an) ->
15001550
record env loc name fields an codegen_type

atdcpp/test/atd-input/everything.atd

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,4 +132,15 @@ type null_opt = {
132132
}
133133

134134
type empty_record = {
135-
}
135+
}
136+
137+
(* Test for <json repr="object"> on sum types.
138+
Tagged variants are encoded as single-key JSON objects {"Constructor": payload}
139+
instead of the default two-element array ["Constructor", payload].
140+
This matches the default Rust/Serde externally-tagged encoding and
141+
also maps naturally to YAML (each variant is a single-key mapping). *)
142+
type shape = [
143+
| Circle of float (* radius *)
144+
| Square of float (* side length *)
145+
| Point (* unit variant -- still encoded as a plain string *)
146+
] <json repr="object">

atdcpp/test/cpp-expected/everything_atd.cpp

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,75 @@ namespace St {
602602
}
603603

604604

605+
namespace Shape::Types {
606+
607+
608+
void Circle::to_json(const Circle &e, rapidjson::Writer<rapidjson::StringBuffer> &writer){
609+
writer.StartObject();
610+
writer.Key("Circle");
611+
_atd_write_float(e.value, writer);
612+
writer.EndObject();
613+
}
614+
615+
616+
void Square::to_json(const Square &e, rapidjson::Writer<rapidjson::StringBuffer> &writer){
617+
writer.StartObject();
618+
writer.Key("Square");
619+
_atd_write_float(e.value, writer);
620+
writer.EndObject();
621+
}
622+
623+
624+
void Point::to_json(const Point &e, rapidjson::Writer<rapidjson::StringBuffer> &writer){
625+
writer.String("Point");
626+
};
627+
628+
629+
}
630+
631+
632+
namespace Shape {
633+
typedefs::Shape from_json(const rapidjson::Value &x) {
634+
if (x.IsString()) {
635+
if (std::string_view(x.GetString()) == "Point")
636+
return Types::Point();
637+
throw _atd_bad_json("Shape", x);
638+
}
639+
if (x.IsObject() && x.MemberCount() == 1) {
640+
std::string cons = x.MemberBegin()->name.GetString();
641+
if (cons == "Circle")
642+
return Types::Circle({_atd_read_float(x["Circle"])});
643+
if (cons == "Square")
644+
return Types::Square({_atd_read_float(x["Square"])});
645+
throw _atd_bad_json("Shape", x);
646+
}
647+
throw _atd_bad_json("Shape", x);
648+
}
649+
typedefs::Shape from_json_string(const std::string &s) {
650+
rapidjson::Document doc;
651+
doc.Parse(s.c_str());
652+
if (doc.HasParseError()) {
653+
throw AtdException("Failed to parse JSON");
654+
}
655+
return from_json(doc);
656+
}
657+
void to_json(const typedefs::Shape &x, rapidjson::Writer<rapidjson::StringBuffer> &writer) {
658+
std::visit([&writer](auto &&arg) {
659+
using T = std::decay_t<decltype(arg)>;
660+
if constexpr (std::is_same_v<T, Types::Circle>) Types::Circle::to_json(arg, writer);
661+
if constexpr (std::is_same_v<T, Types::Square>) Types::Square::to_json(arg, writer);
662+
if constexpr (std::is_same_v<T, Types::Point>) Types::Point::to_json(arg, writer);
663+
}, x);
664+
}
665+
std::string to_json_string(const typedefs::Shape &x) {
666+
rapidjson::StringBuffer buffer;
667+
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
668+
to_json(x, writer);
669+
return buffer.GetString();
670+
}
671+
}
672+
673+
605674
namespace Kind::Types {
606675

607676

atdcpp/test/cpp-expected/everything_atd.hpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,11 @@ struct RecursiveRecord2;
3333
struct RecursiveClass;
3434
struct ThreeLevelNestedListRecord;
3535
struct StructWithRecursiveVariant;
36+
namespace Shape::Types {
37+
struct Circle;
38+
struct Square;
39+
struct Point;
40+
}
3641
namespace Kind::Types {
3742
struct Root;
3843
struct Thing;
@@ -75,6 +80,7 @@ namespace typedefs {
7580
typedef Credentials Credentials;
7681

7782
typedef std::variant<RecursiveVariant::Types::Integer, RecursiveVariant::Types::Rec> RecursiveVariant;
83+
typedef std::variant<Shape::Types::Circle, Shape::Types::Square, Shape::Types::Point> Shape;
7884
typedef std::variant<Kind::Types::Root, Kind::Types::Thing, Kind::Types::WOW, Kind::Types::Amaze> Kind;
7985
typedef std::variant<Frozen::Types::A, Frozen::Types::B> Frozen;
8086

@@ -170,6 +176,33 @@ namespace St {
170176
}
171177

172178

179+
namespace Shape {
180+
namespace Types {
181+
// Original type: shape = [ ... | Circle of ... | ... ]
182+
struct Circle
183+
{
184+
double value;
185+
static void to_json(const Circle &e, rapidjson::Writer<rapidjson::StringBuffer> &writer);
186+
};
187+
// Original type: shape = [ ... | Square of ... | ... ]
188+
struct Square
189+
{
190+
double value;
191+
static void to_json(const Square &e, rapidjson::Writer<rapidjson::StringBuffer> &writer);
192+
};
193+
// Original type: shape = [ ... | Point | ... ]
194+
struct Point {
195+
static void to_json(const Point &e, rapidjson::Writer<rapidjson::StringBuffer> &writer);
196+
};
197+
}
198+
199+
typedefs::Shape from_json(const rapidjson::Value &x);
200+
typedefs::Shape from_json_string(const std::string &s);
201+
void to_json(const typedefs::Shape &x, rapidjson::Writer<rapidjson::StringBuffer> &writer);
202+
std::string to_json_string(const typedefs::Shape &x);
203+
}
204+
205+
173206
namespace Kind {
174207
namespace Types {
175208
// Original type: kind = [ ... | Root | ... ]

atdcpp/test/cpp-tests/test_atdd.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,44 @@ int main() {
138138
}
139139
};
140140

141+
// Test for <json repr="object"> on sum types.
142+
// Tagged variants are encoded as single-key JSON objects {"Constructor": payload}
143+
// instead of the default two-element array ["Constructor", payload].
144+
// This matches the default Rust/Serde externally-tagged encoding and
145+
// also maps naturally to YAML (each variant is a single-key mapping).
146+
tests["sum repr object"] = []() {
147+
// Encoding: tagged variants use {"Constructor": payload}
148+
typedefs::Shape circle = Shape::Types::Circle{3.14};
149+
typedefs::Shape square = Shape::Types::Square{2.0};
150+
typedefs::Shape point = Shape::Types::Point{};
151+
152+
auto circle_json = Shape::to_json_string(circle);
153+
auto square_json = Shape::to_json_string(square);
154+
auto point_json = Shape::to_json_string(point);
155+
156+
if (circle_json != R"({"Circle":3.14})")
157+
throw std::runtime_error("Circle encoding failed: " + circle_json);
158+
if (square_json != R"({"Square":2.0})")
159+
throw std::runtime_error("Square encoding failed: " + square_json);
160+
// Unit variants remain plain strings regardless of repr
161+
if (point_json != R"("Point")")
162+
throw std::runtime_error("Point encoding failed: " + point_json);
163+
164+
// Decoding: round-trip
165+
auto c2 = Shape::from_json_string(circle_json);
166+
auto s2 = Shape::from_json_string(square_json);
167+
auto p2 = Shape::from_json_string(point_json);
168+
169+
if (Shape::to_json_string(c2) != circle_json)
170+
throw std::runtime_error("Circle round-trip failed");
171+
if (Shape::to_json_string(s2) != square_json)
172+
throw std::runtime_error("Square round-trip failed");
173+
if (Shape::to_json_string(p2) != point_json)
174+
throw std::runtime_error("Point round-trip failed");
175+
176+
std::cout << "Test passed: sum repr object" << std::endl;
177+
};
178+
141179
tests["empty record"] = []() {
142180
typedefs::EmptyRecord emptyRecord;
143181
std::string json = "{}";

internal/support_matrix.ml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ let languages : (string * lang_support) list = [
147147
"atdcpp (C++)", { all_yes with
148148
doc_comments = Planned;
149149
json_repr_object = Planned;
150-
sum_repr_object = Planned;
151150
json_adapter = Planned;
152151
imports = Planned;
153152
open_enums = Planned;

0 commit comments

Comments
 (0)