Skip to content

Commit

Permalink
👍 support decoding polymorphic variants
Browse files Browse the repository at this point in the history
  • Loading branch information
akabe committed Aug 9, 2021
1 parent 37cb16f commit 2e09016
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 16 deletions.
11 changes: 9 additions & 2 deletions src/ppx/astmisc.ml
Expand Up @@ -30,14 +30,21 @@ let pint ?loc ?suffix x = Pat.constant ?loc (Const.int ?suffix x)
let eint ?loc ?suffix x = Exp.constant ?loc (Const.int ?suffix x)
let estring ?loc x = Exp.constant ?loc (Const.string x)

let attr_base_type ~deriver ~loc attrs =
let open Ppx_deriving in
match attr ~deriver "t" attrs with
| Some { attr_payload = PTyp core_type; _ } -> core_type
| _ -> Ppx_deriving.raise_errorf ~loc
"ppx_deriving_binary requires [@t: base_type] for variants or polymorphic variants"

let attr_length ~deriver attrs =
Ppx_deriving.(attrs |> attr ~deriver "length" |> Arg.(get_attr ~deriver int))

let attr_length_exn ?loc ~deriver attrs =
let attr_length_exn ~deriver ~loc attrs =
match attr_length ~deriver attrs with
| Some x -> x
| None ->
Ppx_deriving.raise_errorf ?loc
Ppx_deriving.raise_errorf ~loc
"ppx_deriving_binary requires [@length] for string, bytes, list and array"

(** Collect labelled arguments from an expression of
Expand Down
80 changes: 67 additions & 13 deletions src/ppx/decoder.ml
Expand Up @@ -43,7 +43,7 @@ let erecord labels =
labels in
Exp.record label_bindings None

let rec decoder_of_core_type ~deriver typ =
let rec decoder_of_core_type ~deriver ~path typ =
let loc = typ.ptyp_loc in
match typ with
| [%type: unit] -> [%expr fun _ _i -> ((), _i)]
Expand All @@ -56,24 +56,26 @@ let rec decoder_of_core_type ~deriver typ =
decoder_of_string_like_type ~deriver typ
~func:[%expr Ppx_deriving_binary_runtime.Std.bytes_of_binary_bytes]
| [%type: [%t? elt] list] ->
decoder_of_list_like_type ~deriver typ elt
decoder_of_list_like_type ~deriver ~path typ elt
~func:[%expr Ppx_deriving_binary_runtime.Std.list_of_binary_bytes]
| [%type: [%t? elt] array] ->
decoder_of_list_like_type ~deriver typ elt
decoder_of_list_like_type ~deriver ~path typ elt
~func:[%expr Ppx_deriving_binary_runtime.Std.array_of_binary_bytes]
| [%type: [%t? typ] ref] ->
[%expr fun _b _i ->
let _x, _i = [%e decoder_of_core_type ~deriver typ] _b _i in
let _x, _i = [%e decoder_of_core_type ~deriver ~path typ] _b _i in
(ref _x, _i)]
| { ptyp_desc = Ptyp_constr (lid, args); _ } ->
let f = Exp.ident (mknoloc (Ppx_deriving.mangle_lid affix lid.txt)) in
let args' = List.map (decoder_of_core_type ~deriver) args in
let args' = List.map (decoder_of_core_type ~deriver ~path) args in
let fwd = app f args' in
[%expr fun _b _i -> [%e fwd] _b _i]
| { ptyp_desc = Ptyp_tuple typs; _ } ->
decoder_of_tuple ~deriver ~constructor:(fun exps -> Exp.tuple exps) ~loc typs
decoder_of_tuple ~deriver ~path ~constructor:(fun exps -> Exp.tuple exps) ~loc typs
| { ptyp_desc = Ptyp_var name; _ } ->
[%expr ([%e evar ("poly_" ^ name)] : bytes -> int -> _ * int)]
| { ptyp_desc = Ptyp_variant (rows, Closed, None); _ } ->
decoder_of_polymorphic_variant ~deriver ~path ~loc rows typ.ptyp_attributes
(* Errors *)
| _ ->
Ppx_deriving.raise_errorf ~loc "%s cannot be derived for %s"
Expand All @@ -84,34 +86,86 @@ and decoder_of_string_like_type ~deriver ~func t =
let len = Astmisc.attr_length_exn ~loc ~deriver t.ptyp_attributes in
[%expr [%e func] ~n:[%e Astmisc.eint len]]

and decoder_of_list_like_type ~deriver ~func t elt =
let decoder = decoder_of_core_type ~deriver elt in
and decoder_of_list_like_type ~deriver ~path ~func t elt =
let decoder = decoder_of_core_type ~deriver ~path elt in
let loc = t.ptyp_loc in
let len = Astmisc.attr_length_exn ~loc ~deriver t.ptyp_attributes in
[%expr [%e func] ~n:[%e Astmisc.eint len] [%e decoder]]

and decoder_of_tuple ~deriver ~constructor ~loc typs =
and decoder_of_tuple ~deriver ~path ~constructor ~loc typs =
let vars_typs = List.mapi (fun i t -> (sprintf "_%d" i, t)) typs in
let tuple = constructor (List.map (fun (s, _) -> evar s) vars_typs) in
let body =
List.fold_right
(fun (x, t) k' ->
let loc = t.ptyp_loc in
let decoder = decoder_of_core_type ~deriver t in
let decoder = decoder_of_core_type ~deriver ~path t in
[%expr let ([%p pvar x], _i) = [%e decoder] _b _i in [%e k']])
vars_typs [%expr ([%e tuple], _i)] in
[%expr fun _b _i -> [%e body]]

and decoder_of_record ~deriver ~constructor ~loc labels =
and decoder_of_polymorphic_variant
~deriver
~path
~loc
row_fields attrs
=
let base_type = Astmisc.attr_base_type ~deriver ~loc attrs in
Variant.constructors_of_ocaml_row_fields ~deriver row_fields
|> decoder_of_constructors
~deriver ~path ~base_type ~loc
~constructor:Exp.variant

and decoder_of_constructors
~deriver
~path
~base_type
~loc
~constructor
(constrs : Variant.constructor list)
=
let case_of_constructor c =
let loc = c.con_loc in
let name = c.Variant.con_name in
let tag = Pat.constant c.Variant.con_value in
let path' = path ^ "." ^ c.con_name in
let decoder =
match c.con_args with
| `TUPLE typs -> (* C (arg1, ..., argN) *)
decoder_of_tuple
~deriver ~path:path' ~loc typs
~constructor:(function
| [] -> constructor name None
| [_] -> constructor name (Some [%expr _0])
| exps -> constructor name (Some (Exp.tuple exps)))
| `RECORD labels -> (* C { label1: arg1; ...; labelN: argN; } *)
decoder_of_record
~deriver ~path:path' ~loc:c.con_loc labels
~constructor:(fun args -> constructor name (Some args))
in
Exp.case tag [%expr [%e decoder] _b _i]
in
let cases = List.map case_of_constructor constrs in
let err_mesg = Astmisc.estring ~loc path in
let raise_ = [%expr raise (Ppx_deriving_binary_runtime.Std.Parse_error [%e err_mesg])] in
let cases = cases @ [Exp.case [%pat? _] raise_] in
let base_decoder = decoder_of_core_type ~deriver ~path base_type in
[%expr
fun _b _i ->
let _x, _i = [%e base_decoder] _b _i in
[%e Exp.match_ [%expr _x] cases]]

and decoder_of_record ~deriver ~path ~constructor ~loc labels =
let record = constructor (erecord labels) in
let body =
List.fold_right
(fun ld k' ->
let loc = ld.pld_type.ptyp_loc in
let decoder =
match attr_decoder ~deriver ld.pld_attributes with
| None -> decoder_of_core_type ~deriver ld.pld_type
| Some (labels, decoder) ->
| None ->
decoder_of_core_type ~deriver ~path ld.pld_type
| Some (labels, decoder) -> (* a decoder function is given by a user. *)
let args = List.map (fun s -> Labelled s, evar s) labels in
Exp.apply decoder args (* apply pre-decoded record fields *)
in
Expand Down
4 changes: 3 additions & 1 deletion src/ppx/ppx_deriving_binary.ml
Expand Up @@ -23,4 +23,6 @@
let () =
let open Ppx_deriving in
register (create "of_binary_bytes" ()
~core_type:(Decoder.decoder_of_core_type ~deriver:"of_binary_bytes"))
~core_type:(Decoder.decoder_of_core_type
~deriver:"of_binary_bytes"
~path:"<abstract>"))
21 changes: 21 additions & 0 deletions src/runtime/std.ml
Expand Up @@ -63,6 +63,17 @@ let binary_bytes_of_int32le = BytesBuffer.add_int32le
let binary_bytes_of_int64be = BytesBuffer.add_int64be
let binary_bytes_of_int64le = BytesBuffer.add_int64le

let pp_int8 = Format.pp_print_int
let pp_uint8 = Format.pp_print_int
let pp_int16be = Format.pp_print_int
let pp_int16le = Format.pp_print_int
let pp_uint16be = Format.pp_print_int
let pp_uint16le = Format.pp_print_int
let pp_int32be ppf = Format.fprintf ppf "%ld"
let pp_int32le ppf = Format.fprintf ppf "%ld"
let pp_int64be ppf = Format.fprintf ppf "%Ld"
let pp_int64le ppf = Format.fprintf ppf "%Ld"

(** {3 Aliases of [int]} *)

type int32bei = int
Expand All @@ -80,6 +91,11 @@ let binary_bytes_of_int32lei b x = BytesBuffer.add_int32le b (Int32.of_int x)
let binary_bytes_of_int64bei b x = BytesBuffer.add_int64be b (Int64.of_int x)
let binary_bytes_of_int64lei b x = BytesBuffer.add_int64le b (Int64.of_int x)

let pp_int32bei = Format.pp_print_int
let pp_int32lei = Format.pp_print_int
let pp_int64bei = Format.pp_print_int
let pp_int64lei = Format.pp_print_int

(** {2 Floating-point values} *)

type float32be = float
Expand All @@ -103,6 +119,11 @@ let float64le_of_binary_bytes cs i =
let n, i = int64le_of_binary_bytes cs i in
Int64.float_of_bits n, i

let pp_float32be = Format.pp_print_float
let pp_float32le = Format.pp_print_float
let pp_float64be = Format.pp_print_float
let pp_float64le = Format.pp_print_float

(** {2 Constant-length string-like types} *)

let string_of_binary_bytes ~n b i =
Expand Down
12 changes: 12 additions & 0 deletions tests/ppx/test_of_binary_bytes.ml
Expand Up @@ -82,6 +82,17 @@ let test_of_binary_bytes_list ctxt =
let expected = (['H'; 'e'; 'l'; 'l'; 'o'], 6) in
assert_equal ~ctxt ~printer:[%show: char list * int] expected actual

let test_of_binary_bytes_polymorphic_variant ctxt =
let b = b_ "\x00\x1c" in
let actual = [%of_binary_bytes: [ `A [@value 0x1c] | `B of uint8 * uint16le ] [@t: uint8]] b 1 in
let expected = (`A, 2) in
assert_equal ~ctxt ~printer:[%show: [ `A | `B of uint8 * uint16le ] * int] expected actual
;
let b = b_ "\x00\x01\x11\x22\x33" in
let actual = [%of_binary_bytes: [ `A [@value 0x1c] | `B of uint8 * uint16le ] [@t: uint8]] b 1 in
let expected = (`B (0x11, 0x3322), 5) in
assert_equal ~ctxt ~printer:[%show: [ `A | `B of uint8 * uint16le ] * int] expected actual

let suite =
"of_binary_bytes driver" >::: [
"[%of_binary_bytes: core-type]" >::: [
Expand All @@ -96,5 +107,6 @@ let suite =
"tuple" >:: test_of_binary_bytes_tuple;
"string" >:: test_of_binary_bytes_string;
"list" >:: test_of_binary_bytes_list;
"polymorphic_variant" >:: test_of_binary_bytes_polymorphic_variant;
]
]

0 comments on commit 2e09016

Please sign in to comment.