Skip to content

Commit

Permalink
Pgx binary value (#83)
Browse files Browse the repository at this point in the history
* Add Pgx_value binary types and remove binary conversion by default

This will allow people to pass binary strings without breaking normal strings.

* Show hints when type conversions fail

* Add tests for #38
  • Loading branch information
brendanlong committed May 8, 2020
1 parent d48deda commit 25a907a
Show file tree
Hide file tree
Showing 9 changed files with 240 additions and 135 deletions.
1 change: 1 addition & 0 deletions dune-project
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
(>= 2.0.0)))
(dune
(>= 1.11))
hex
ipaddr
(ocaml
(>= 4.08))
Expand Down
1 change: 1 addition & 0 deletions pgx.opam
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ depends: [
"base64" {with-test & >= "3.0.0"}
"bisect_ppx" {dev & >= "2.0.0"}
"dune" {>= "1.11"}
"hex"
"ipaddr"
"ocaml" {>= "4.08"}
"ppx_custom_printf" {>= "v0.13.0"}
Expand Down
2 changes: 1 addition & 1 deletion pgx/src/dune
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ let () = Jbuild_plugin.V1.send @@ {|
(library
(public_name pgx)
(wrapped false)
(libraries ipaddr uuidm re sexplib0)
(libraries hex ipaddr uuidm re sexplib0)
(preprocess (pps ppx_custom_printf ppx_sexp_conv |} ^ preprocess ^ {|)))
|}
114 changes: 5 additions & 109 deletions pgx/src/pgx.ml
Original file line number Diff line number Diff line change
Expand Up @@ -384,87 +384,6 @@ module Message_out = struct
;;
end

let is_first_oct_digit c = c >= '0' && c <= '3'
let is_oct_digit c = c >= '0' && c <= '7'
let oct_val c = Char.code c - 0x30

let is_hex_digit = function
| '0' .. '9' | 'a' .. 'f' | 'A' .. 'F' -> true
| _ -> false
;;

let hex_val c =
let offset =
match c with
| '0' .. '9' -> 0x30
| 'a' .. 'f' -> 0x57
| 'A' .. 'F' -> 0x37
| _ -> failwith "hex_val"
in
Char.code c - offset
;;

(* Deserialiser for the new 'hex' format introduced in PostgreSQL 9.0. *)
let deserialize_hex str =
let len = String.length str in
let buf = Buffer.create ((len - 2) / 2) in
let i = ref 3 in
while !i < len do
let hi_nibble = str.[!i - 1] in
let lo_nibble = str.[!i] in
i := !i + 2;
if is_hex_digit hi_nibble && is_hex_digit lo_nibble
then (
let byte = (hex_val hi_nibble lsl 4) + hex_val lo_nibble in
Buffer.add_char buf (Char.chr byte))
done;
Buffer.contents buf
;;

(* Deserialiser for the old 'escape' format used in PostgreSQL < 9.0. *)
let deserialize_string_escape str =
let len = String.length str in
let buf = Buffer.create len in
let i = ref 0 in
while !i < len do
let c = str.[!i] in
if c = '\\'
then (
incr i;
if !i < len && str.[!i] = '\\'
then (
Buffer.add_char buf '\\';
incr i)
else if !i + 2 < len
&& is_first_oct_digit str.[!i]
&& is_oct_digit str.[!i + 1]
&& is_oct_digit str.[!i + 2]
then (
let byte = oct_val str.[!i] in
incr i;
let byte = (byte lsl 3) + oct_val str.[!i] in
incr i;
let byte = (byte lsl 3) + oct_val str.[!i] in
incr i;
Buffer.add_char buf (Char.chr byte)))
else (
incr i;
Buffer.add_char buf c)
done;
Buffer.contents buf
;;

(* PostgreSQL 9.0 introduced the new 'hex' format for binary data.
We must therefore check whether the data begins with a magic sequence
that identifies this new format and if so call the appropriate parser;
if it doesn't, then we invoke the parser for the old 'escape' format.
*)
let deserialize_string str =
if String.starts_with str "\\x"
then deserialize_hex str
else deserialize_string_escape str
;;

module Value = Pgx_value

module type Io = Io_intf.S
Expand Down Expand Up @@ -866,23 +785,7 @@ module Make (Thread : Io) = struct
;;

let execute_iter ?(portal = "") { name; conn } ~params ~f =
let encode_unprintable b =
let len = String.length b in
let buf = Buffer.create (len * 2) in
for i = 0 to len - 1 do
let c = b.[i] in
let cc = Char.code c in
if cc < 0x20 || cc > 0x7e
then Buffer.add_string buf (sprintf "\\%03o" cc) (* non-print -> \ooo *)
else if c = '\\'
then Buffer.add_string buf "\\\\" (* \ -> \\ *)
else Buffer.add_char buf c
done;
Buffer.contents buf
in
let params =
List.map (fun s -> Value.to_string s |> Option.map encode_unprintable) params
in
let params = List.map Value.to_string params in
Sequencer.enqueue conn (fun conn ->
send_message conn (Message_out.Bind { Message_out.portal; name; params })
>>= fun () ->
Expand All @@ -907,11 +810,7 @@ module Make (Thread : Io) = struct
| Message_in.CommandComplete _ -> loop ()
| Message_in.EmptyQueryResponse -> loop ()
| Message_in.DataRow fields ->
List.map
(Option.bind (fun v -> deserialize_string v |> Value.of_string))
fields
|> f
>>= loop
List.map (Option.bind Value.of_string) fields |> f >>= loop
| Message_in.NoData -> loop ()
| Message_in.ParameterStatus _ ->
(* 43.2.6: ParameterStatus messages will be generated whenever
Expand All @@ -938,7 +837,7 @@ module Make (Thread : Io) = struct
fail_msg
"Pgx.iter_execute: CopyOutResponse for binary is not implemented yet")
| Message_in.CopyData row ->
f [ row |> deserialize_string |> Value.of_string ] >>= fun () -> loop ()
f [ row |> Value.of_string ] >>= fun () -> loop ()
| Message_in.CopyDone -> loop ()
| m -> fail_msg "Pgx: unknown response message: %s" (Message_in.to_string m)
in
Expand Down Expand Up @@ -1067,13 +966,10 @@ module Make (Thread : Io) = struct
loop acc rows state
| Message_in.Binary ->
fail_msg "Pgx.query: CopyOutResponse for binary is not implemented yet")
| _, Message_in.CopyData row ->
loop acc ([ row |> deserialize_string |> Value.of_string ] :: rows) state
| _, Message_in.CopyData row -> loop acc ([ row |> Value.of_string ] :: rows) state
| _, Message_in.CopyDone -> loop acc rows state
| `Rows, Message_in.DataRow row ->
let row =
List.map (Option.bind (fun v -> deserialize_string v |> Value.of_string)) row
in
let row = List.map (Option.bind Value.of_string) row in
loop acc (row :: rows) `Rows
| (`Row_desc | `Rows), Message_in.CommandComplete _ ->
let rows = List.rev rows in
Expand Down
50 changes: 40 additions & 10 deletions pgx/src/pgx_value.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@ type t = string option [@@deriving sexp_of]

exception Conversion_failure of string [@@deriving sexp]

let convert_failure type_ s =
Conversion_failure (Printf.sprintf "Unable to convert to %s: %s" type_ s) |> raise
let convert_failure ?hint type_ s =
let hint =
match hint with
| None -> ""
| Some hint -> Printf.sprintf " (%s)" hint
in
Conversion_failure (Printf.sprintf "Unable to convert to %s%s: %s" type_ hint s)
|> raise
;;

let required f = function
Expand All @@ -17,6 +23,32 @@ let required f = function
let opt = Option.bind
let null = None

let of_binary b =
match b with
| "" -> Some ""
| _ ->
(try
let (`Hex hex) = Hex.of_string b in
Some ("\\x" ^ hex)
with
| exn -> convert_failure ~hint:(Printexc.to_string exn) "binary" b)
;;

let to_binary' = function
| "" -> ""
| t ->
((* Skip if not encoded as hex *)
try
if String.sub t 0 2 <> "\\x"
then t (* Decode if encoded as hex *)
else `Hex (String.sub t 2 (String.length t - 2)) |> Hex.to_string
with
| exn -> convert_failure ~hint:(Printexc.to_string exn) "binary" t)
;;

let to_binary_exn = required to_binary'
let to_binary = Option.map to_binary'

let of_bool = function
| true -> Some "t"
| false -> Some "f"
Expand Down Expand Up @@ -48,7 +80,7 @@ let to_float' t =
| "nan" -> nan
| _ ->
(try float_of_string t with
| Failure _ -> convert_failure "float" t)
| Failure hint -> convert_failure ~hint "float" t)
;;

let to_float_exn = required to_float'
Expand Down Expand Up @@ -159,7 +191,7 @@ let to_inet' =
then addr, if Re.Group.get subs 2 = "." then 32 else 128
else addr, int_of_string mask
with
| _ -> convert_failure "inet" str
| exn -> convert_failure ~hint:(Printexc.to_string exn) "inet" str
;;

let to_inet_exn = required to_inet'
Expand All @@ -168,7 +200,7 @@ let of_int i = Some (string_of_int i)

let to_int' t =
try int_of_string t with
| Failure _ -> convert_failure "int" t
| Failure hint -> convert_failure ~hint "int" t
;;

let to_int_exn = required to_int'
Expand All @@ -177,7 +209,7 @@ let of_int32 i = Some (Int32.to_string i)

let to_int32' t =
try Int32.of_string t with
| Failure _ -> convert_failure "int32" t
| Failure hint -> convert_failure ~hint "int32" t
;;

let to_int32_exn = required to_int32'
Expand All @@ -186,7 +218,7 @@ let of_int64 i = Some (Int64.to_string i)

let to_int64' t =
try Int64.of_string t with
| Failure _ -> convert_failure "int64" t
| Failure hint -> convert_failure ~hint "int64" t
;;

let to_int64_exn = required to_int64'
Expand Down Expand Up @@ -279,9 +311,7 @@ let to_point' =
let subs = Re.exec point_re str in
float_of_string (Re.Group.get subs 1), float_of_string (Re.Group.get subs 2)
with
| e ->
Printexc.to_string e |> print_endline;
convert_failure "point" str
| exn -> convert_failure ~hint:(Printexc.to_string exn) "point" str
;;

let to_point_exn = required to_point'
Expand Down
5 changes: 3 additions & 2 deletions pgx/src/pgx_value.mli
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ include Pgx_value_intf.S
(* Exposed for extending this module *)

(** [convert_failure type_ str] raises [Convert_failure] with a useful
error message. *)
val convert_failure : string -> string -> _
error message. Add [~hint] if there's additional info you can give the
user about the error. *)
val convert_failure : ?hint:string -> string -> string -> _
3 changes: 3 additions & 0 deletions pgx/src/pgx_value_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ module type S = sig
val required : ('a -> 'b) -> 'a option -> 'b
val opt : ('a -> t) -> 'a option -> t
val null : t
val of_binary : string -> t
val to_binary_exn : t -> string
val to_binary : t -> string option
val of_bool : bool -> t
val to_bool_exn : t -> bool
val to_bool : t -> bool option
Expand Down
16 changes: 13 additions & 3 deletions pgx/test/test_pgx_value.ml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,11 @@ let make_test name typ to_value of_value of_value_exn values fail_values =
let value = of_string str in
Alcotest.test_case test_name `Quick
@@ fun () ->
let msg = sprintf "Unable to convert to %s: %s" name str in
Alcotest.check_raises "conversion error" (Conversion_failure msg) (fun () ->
ignore (of_value value)))
try
of_value value |> ignore;
Alcotest.fail "Expected Conversion_failure"
with
| Conversion_failure _ -> ())
fail_values
in
let success_opt_tests =
Expand Down Expand Up @@ -78,6 +80,14 @@ let () =
Alcotest.run
"Pgx.Value"
[ make_test
"binary"
Alcotest.string
of_binary
to_binary
to_binary_exn
[ ""; "normal string"; "string with null\x00 in the midddle"; all_chars ]
[]
; make_test
"bool"
Alcotest.bool
of_bool
Expand Down
Loading

0 comments on commit 25a907a

Please sign in to comment.