diff --git a/dune-project b/dune-project index 74b2144..5f6fef0 100644 --- a/dune-project +++ b/dune-project @@ -37,6 +37,7 @@ (>= 2.0.0))) (dune (>= 1.11)) + hex ipaddr (ocaml (>= 4.08)) diff --git a/pgx.opam b/pgx.opam index 52f135c..41658b8 100644 --- a/pgx.opam +++ b/pgx.opam @@ -14,6 +14,7 @@ depends: [ "base64" {with-test & >= "3.0.0"} "bisect_ppx" {dev & >= "2.0.0"} "dune" {>= "1.11"} + "hex" "ipaddr" "ocaml" {>= "4.08"} "ppx_custom_printf" {>= "v0.13.0"} diff --git a/pgx/src/dune b/pgx/src/dune index 6521e4e..7ee9de8 100644 --- a/pgx/src/dune +++ b/pgx/src/dune @@ -11,6 +11,6 @@ let () = Jbuild_plugin.V1.send @@ {| (library (public_name pgx) (wrapped false) - (libraries ipaddr uuidm re sexplib0) + (libraries hex ipaddr uuidm re sexplib0) (preprocess (pps ppx_custom_printf ppx_sexp_conv |} ^ preprocess ^ {|))) |} diff --git a/pgx/src/pgx.ml b/pgx/src/pgx.ml index b8ffb1e..b7cc6be 100644 --- a/pgx/src/pgx.ml +++ b/pgx/src/pgx.ml @@ -384,87 +384,6 @@ module Message_out = struct ;; end -let is_first_oct_digit c = c >= '0' && c <= '3' -let is_oct_digit c = c >= '0' && c <= '7' -let oct_val c = Char.code c - 0x30 - -let is_hex_digit = function - | '0' .. '9' | 'a' .. 'f' | 'A' .. 'F' -> true - | _ -> false -;; - -let hex_val c = - let offset = - match c with - | '0' .. '9' -> 0x30 - | 'a' .. 'f' -> 0x57 - | 'A' .. 'F' -> 0x37 - | _ -> failwith "hex_val" - in - Char.code c - offset -;; - -(* Deserialiser for the new 'hex' format introduced in PostgreSQL 9.0. *) -let deserialize_hex str = - let len = String.length str in - let buf = Buffer.create ((len - 2) / 2) in - let i = ref 3 in - while !i < len do - let hi_nibble = str.[!i - 1] in - let lo_nibble = str.[!i] in - i := !i + 2; - if is_hex_digit hi_nibble && is_hex_digit lo_nibble - then ( - let byte = (hex_val hi_nibble lsl 4) + hex_val lo_nibble in - Buffer.add_char buf (Char.chr byte)) - done; - Buffer.contents buf -;; - -(* Deserialiser for the old 'escape' format used in PostgreSQL < 9.0. 
*) -let deserialize_string_escape str = - let len = String.length str in - let buf = Buffer.create len in - let i = ref 0 in - while !i < len do - let c = str.[!i] in - if c = '\\' - then ( - incr i; - if !i < len && str.[!i] = '\\' - then ( - Buffer.add_char buf '\\'; - incr i) - else if !i + 2 < len - && is_first_oct_digit str.[!i] - && is_oct_digit str.[!i + 1] - && is_oct_digit str.[!i + 2] - then ( - let byte = oct_val str.[!i] in - incr i; - let byte = (byte lsl 3) + oct_val str.[!i] in - incr i; - let byte = (byte lsl 3) + oct_val str.[!i] in - incr i; - Buffer.add_char buf (Char.chr byte))) - else ( - incr i; - Buffer.add_char buf c) - done; - Buffer.contents buf -;; - -(* PostgreSQL 9.0 introduced the new 'hex' format for binary data. - We must therefore check whether the data begins with a magic sequence - that identifies this new format and if so call the appropriate parser; - if it doesn't, then we invoke the parser for the old 'escape' format. -*) -let deserialize_string str = - if String.starts_with str "\\x" - then deserialize_hex str - else deserialize_string_escape str -;; - module Value = Pgx_value module type Io = Io_intf.S @@ -866,23 +785,7 @@ module Make (Thread : Io) = struct ;; let execute_iter ?(portal = "") { name; conn } ~params ~f = - let encode_unprintable b = - let len = String.length b in - let buf = Buffer.create (len * 2) in - for i = 0 to len - 1 do - let c = b.[i] in - let cc = Char.code c in - if cc < 0x20 || cc > 0x7e - then Buffer.add_string buf (sprintf "\\%03o" cc) (* non-print -> \ooo *) - else if c = '\\' - then Buffer.add_string buf "\\\\" (* \ -> \\ *) - else Buffer.add_char buf c - done; - Buffer.contents buf - in - let params = - List.map (fun s -> Value.to_string s |> Option.map encode_unprintable) params - in + let params = List.map Value.to_string params in Sequencer.enqueue conn (fun conn -> send_message conn (Message_out.Bind { Message_out.portal; name; params }) >>= fun () -> @@ -907,11 +810,7 @@ module Make 
(Thread : Io) = struct | Message_in.CommandComplete _ -> loop () | Message_in.EmptyQueryResponse -> loop () | Message_in.DataRow fields -> - List.map - (Option.bind (fun v -> deserialize_string v |> Value.of_string)) - fields - |> f - >>= loop + List.map (Option.bind Value.of_string) fields |> f >>= loop | Message_in.NoData -> loop () | Message_in.ParameterStatus _ -> (* 43.2.6: ParameterStatus messages will be generated whenever @@ -938,7 +837,7 @@ module Make (Thread : Io) = struct fail_msg "Pgx.iter_execute: CopyOutResponse for binary is not implemented yet") | Message_in.CopyData row -> - f [ row |> deserialize_string |> Value.of_string ] >>= fun () -> loop () + f [ row |> Value.of_string ] >>= fun () -> loop () | Message_in.CopyDone -> loop () | m -> fail_msg "Pgx: unknown response message: %s" (Message_in.to_string m) in @@ -1067,13 +966,10 @@ module Make (Thread : Io) = struct loop acc rows state | Message_in.Binary -> fail_msg "Pgx.query: CopyOutResponse for binary is not implemented yet") - | _, Message_in.CopyData row -> - loop acc ([ row |> deserialize_string |> Value.of_string ] :: rows) state + | _, Message_in.CopyData row -> loop acc ([ row |> Value.of_string ] :: rows) state | _, Message_in.CopyDone -> loop acc rows state | `Rows, Message_in.DataRow row -> - let row = - List.map (Option.bind (fun v -> deserialize_string v |> Value.of_string)) row - in + let row = List.map (Option.bind Value.of_string) row in loop acc (row :: rows) `Rows | (`Row_desc | `Rows), Message_in.CommandComplete _ -> let rows = List.rev rows in diff --git a/pgx/src/pgx_value.ml b/pgx/src/pgx_value.ml index da0e668..69c5b19 100644 --- a/pgx/src/pgx_value.ml +++ b/pgx/src/pgx_value.ml @@ -5,8 +5,14 @@ type t = string option [@@deriving sexp_of] exception Conversion_failure of string [@@deriving sexp] -let convert_failure type_ s = - Conversion_failure (Printf.sprintf "Unable to convert to %s: %s" type_ s) |> raise +let convert_failure ?hint type_ s = + let hint = + match hint 
with + | None -> "" + | Some hint -> Printf.sprintf " (%s)" hint + in + Conversion_failure (Printf.sprintf "Unable to convert to %s%s: %s" type_ hint s) + |> raise ;; let required f = function @@ -17,6 +23,32 @@ let required f = function let opt = Option.bind let null = None +let of_binary b = + match b with + | "" -> Some "" + | _ -> + (try + let (`Hex hex) = Hex.of_string b in + Some ("\\x" ^ hex) + with + | exn -> convert_failure ~hint:(Printexc.to_string exn) "binary" b) +;; + +let to_binary' = function + | "" -> "" + | t -> + ((* Skip if not encoded as hex; anything shorter than the 2-byte prefix is returned as-is *) + try + if String.length t < 2 || String.sub t 0 2 <> "\\x" + then t (* Decode if encoded as hex *) + else `Hex (String.sub t 2 (String.length t - 2)) |> Hex.to_string + with + | exn -> convert_failure ~hint:(Printexc.to_string exn) "binary" t) +;; + +let to_binary_exn = required to_binary' +let to_binary = Option.map to_binary' + let of_bool = function | true -> Some "t" | false -> Some "f" @@ -48,7 +80,7 @@ let to_float' t = | "nan" -> nan | _ -> (try float_of_string t with - | Failure _ -> convert_failure "float" t) + | Failure hint -> convert_failure ~hint "float" t) ;; let to_float_exn = required to_float' @@ -159,7 +191,7 @@ let to_inet' = then addr, if Re.Group.get subs 2 = "."
then 32 else 128 else addr, int_of_string mask with - | _ -> convert_failure "inet" str + | exn -> convert_failure ~hint:(Printexc.to_string exn) "inet" str ;; let to_inet_exn = required to_inet' @@ -168,7 +200,7 @@ let of_int i = Some (string_of_int i) let to_int' t = try int_of_string t with - | Failure _ -> convert_failure "int" t + | Failure hint -> convert_failure ~hint "int" t ;; let to_int_exn = required to_int' @@ -177,7 +209,7 @@ let of_int32 i = Some (Int32.to_string i) let to_int32' t = try Int32.of_string t with - | Failure _ -> convert_failure "int32" t + | Failure hint -> convert_failure ~hint "int32" t ;; let to_int32_exn = required to_int32' @@ -186,7 +218,7 @@ let of_int64 i = Some (Int64.to_string i) let to_int64' t = try Int64.of_string t with - | Failure _ -> convert_failure "int64" t + | Failure hint -> convert_failure ~hint "int64" t ;; let to_int64_exn = required to_int64' @@ -279,9 +311,7 @@ let to_point' = let subs = Re.exec point_re str in float_of_string (Re.Group.get subs 1), float_of_string (Re.Group.get subs 2) with - | e -> - Printexc.to_string e |> print_endline; - convert_failure "point" str + | exn -> convert_failure ~hint:(Printexc.to_string exn) "point" str ;; let to_point_exn = required to_point' diff --git a/pgx/src/pgx_value.mli b/pgx/src/pgx_value.mli index c1edebe..c32df3b 100644 --- a/pgx/src/pgx_value.mli +++ b/pgx/src/pgx_value.mli @@ -3,5 +3,6 @@ include Pgx_value_intf.S (* Exposed for extending this module *) (** [convert_failure type_ str] raises [Convert_failure] with a useful - error message. *) -val convert_failure : string -> string -> _ + error message. Add [~hint] if there's additional info you can give the + user about the error. 
*) +val convert_failure : ?hint:string -> string -> string -> _ diff --git a/pgx/src/pgx_value_intf.ml b/pgx/src/pgx_value_intf.ml index a5888e6..8a067fc 100644 --- a/pgx/src/pgx_value_intf.ml +++ b/pgx/src/pgx_value_intf.ml @@ -7,6 +7,9 @@ module type S = sig val required : ('a -> 'b) -> 'a option -> 'b val opt : ('a -> t) -> 'a option -> t val null : t + val of_binary : string -> t + val to_binary_exn : t -> string + val to_binary : t -> string option val of_bool : bool -> t val to_bool_exn : t -> bool val to_bool : t -> bool option diff --git a/pgx/test/test_pgx_value.ml b/pgx/test/test_pgx_value.ml index d108430..101be4e 100644 --- a/pgx/test/test_pgx_value.ml +++ b/pgx/test/test_pgx_value.ml @@ -44,9 +44,11 @@ let make_test name typ to_value of_value of_value_exn values fail_values = let value = of_string str in Alcotest.test_case test_name `Quick @@ fun () -> - let msg = sprintf "Unable to convert to %s: %s" name str in - Alcotest.check_raises "conversion error" (Conversion_failure msg) (fun () -> - ignore (of_value value))) + try + of_value value |> ignore; + Alcotest.fail "Expected Conversion_failure" + with + | Conversion_failure _ -> ()) fail_values in let success_opt_tests = @@ -78,6 +80,14 @@ let () = Alcotest.run "Pgx.Value" [ make_test + "binary" + Alcotest.string + of_binary + to_binary + to_binary_exn + [ ""; "normal string"; "string with null\x00 in the midddle"; all_chars ] + [] + ; make_test "bool" Alcotest.bool of_bool diff --git a/pgx_test/src/pgx_test.ml b/pgx_test/src/pgx_test.ml index 93fee23..4629baf 100644 --- a/pgx_test/src/pgx_test.ml +++ b/pgx_test/src/pgx_test.ml @@ -368,12 +368,11 @@ struct numeric);" >>= fun _ -> let expect_uuid = Uuidm.create `V4 in - let all_chars = String.init 255 char_of_int in let params = let open Pgx.Value in [ of_uuid expect_uuid ; of_int 12 - ; of_string all_chars + ; of_string "asdf" ; of_string "9223372036854775807" ] in @@ -396,7 +395,7 @@ struct (Some expect_uuid) uuid; Alcotest.(check (option int)) 
"int" (Some 12) int_; - Alcotest.(check (option string)) "string" (Some all_chars) string_; + Alcotest.(check (option string)) "string" (Some "asdf") string_; Alcotest.(check (option string)) "numeric" (Some "9223372036854775807") @@ -406,21 +405,185 @@ struct ; Alcotest_io.test_case "binary string handling" `Quick (fun () -> let all_chars = String.init 255 char_of_int in with_conn (fun db -> - [ "SELECT decode($1, 'base64')", Base64.encode_exn all_chars, all_chars + [ ( "SELECT decode($1, 'base64')::bytea" + , Base64.encode_exn all_chars |> Pgx.Value.of_string + , Pgx.Value.to_binary_exn + , all_chars ) (* Postgres adds whitespace to base64 encodings, so we strip it back out *) - ; ( "SELECT regexp_replace(encode($1, 'base64'), '\\s', '', 'g')" - , all_chars + ; ( "SELECT regexp_replace(encode($1::bytea, 'base64'), '\\s', '', 'g')" + , Pgx.Value.of_binary all_chars + , Pgx.Value.to_string_exn , Base64.encode_exn all_chars ) ] - |> deferred_list_map ~f:(fun (query, param, expect) -> - let params = [ param |> Pgx.Value.of_string ] in + |> deferred_list_map ~f:(fun (query, param, read_f, expect) -> + let params = [ param ] in execute ~params db query >>| function - | [ [ Some actual ] ] -> - Alcotest.(check string) "binary string" expect actual + | [ [ actual ] ] -> + read_f actual |> Alcotest.(check string) "binary string" expect | _ -> assert false)) >>| List.iter (fun () -> ())) + ; Alcotest_io.test_case "binary string round-trip" `Quick (fun () -> + let all_chars = String.init 255 char_of_int in + with_conn (fun db -> + (* This binary string should get encoded as hex and stored as one byte-per-byte of input *) + let params = [ Pgx.Value.of_binary all_chars ] in + (* Checking here that Postgres doesn't throw an exception about null characters in input, since + our encoded input has no null chars *) + execute ~params db "SELECT $1::bytea, octet_length($1::bytea)" + >>| function + | [ [ value; length ] ] -> + Pgx.Value.to_binary_exn value + |> Alcotest.(check 
string) "binary string contents" all_chars; + (* Our string is 255 bytes so it should be stored as 255 bytes, not as 512 (the length of the + encoded hex). What we're testing here is that we're actually storing binary, not hex + encoded binary *) + Pgx.Value.to_int_exn length + |> Alcotest.(check int) "binary string length" 255 + | _ -> assert false)) + ; Alcotest_io.test_case "Non-binary literal hex string round-trip" `Quick (fun () -> + with_conn (fun db -> + (* This hex string should get inserted into the DB as literally "\x0001etc" *) + let input = + "\\x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe" + in + let params = [ Pgx.Value.of_string input ] in + execute ~params db "SELECT $1::varchar, octet_length($1::varchar)" + >>| function + | [ [ value; length ] ] -> + Pgx.Value.to_string_exn value + |> Alcotest.(check string) "string contents" input; + Pgx.Value.to_int_exn length |> Alcotest.(check int) "string length" 512 + | _ -> assert false)) + ; Alcotest_io.test_case "Binary literal hex string round-trip" `Quick (fun () -> + with_conn (fun db -> + (* This hex string should get double encoded so it makes it into the DB as literally "\x0001etc" *) + let input = + 
"\\x000102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f303132333435363738393a3b3c3d3e3f404142434445464748494a4b4c4d4e4f505152535455565758595a5b5c5d5e5f606162636465666768696a6b6c6d6e6f707172737475767778797a7b7c7d7e7f808182838485868788898a8b8c8d8e8f909192939495969798999a9b9c9d9e9fa0a1a2a3a4a5a6a7a8a9aaabacadaeafb0b1b2b3b4b5b6b7b8b9babbbcbdbebfc0c1c2c3c4c5c6c7c8c9cacbcccdcecfd0d1d2d3d4d5d6d7d8d9dadbdcdddedfe0e1e2e3e4e5e6e7e8e9eaebecedeeeff0f1f2f3f4f5f6f7f8f9fafbfcfdfe" + in + let params = [ Pgx.Value.of_binary input ] in + execute ~params db "SELECT $1::bytea, octet_length($1::bytea)" + >>| function + | [ [ value; length ] ] -> + Pgx.Value.to_binary_exn value + |> Alcotest.(check string) "string contents" input; + Pgx.Value.to_int_exn length |> Alcotest.(check int) "string length" 512 + | _ -> assert false)) + ; Alcotest_io.test_case "UTF-8 partial round-trip 1" `Quick (fun () -> + (* Select a literal string *) + let expect = "test-ä-test" in + with_conn (fun db -> + simple_query + db + {| + CREATE TEMPORARY TABLE this_test (id text); + INSERT INTO this_test (id) VALUES ('test-ä-test') + |} + >>= fun _ -> + execute db "SELECT id FROM this_test" + >>| function + | [ [ result ] ] -> + Alcotest.(check (option string)) + "" + (Some expect) + (Pgx.Value.to_string result) + | _ -> assert false)) + ; Alcotest_io.test_case "UTF-8 partial round-trip 1 with where" `Quick (fun () -> + (* Select a literal string *) + let expect = "test-ä-test" in + with_conn (fun db -> + simple_query + db + {| + CREATE TEMPORARY TABLE this_test (id text); + INSERT INTO this_test (id) VALUES ('test-ä-test') + |} + >>= fun _ -> + execute + db + ~params:[ Pgx.Value.of_string expect ] + "SELECT id FROM this_test WHERE id = $1" + >>| function + | [ [ result ] ] -> + Alcotest.(check (option string)) + "" + (Some expect) + (Pgx.Value.to_string result) + | [] -> Alcotest.fail "Expected one row but got zero" + | _ -> assert false)) + ; Alcotest_io.test_case 
"UTF-8 partial round-trip 2" `Quick (fun () -> + (* Insert string as a param, then select back the contents of + the table *) + let expect = "test-ä-test" in + with_conn (fun db -> + simple_query db "CREATE TEMPORARY TABLE this_test (id text)" + >>= fun _ -> + execute + db + ~params:[ Pgx.Value.of_string expect ] + "INSERT INTO this_test (id) VALUES ($1)" + >>= fun _ -> + execute db "SELECT id FROM this_test" + >>| function + | [ [ result ] ] -> + Alcotest.(check (option string)) + "" + (Some expect) + (Pgx.Value.to_string result) + | _ -> assert false)) + ; Alcotest_io.test_case "UTF-8 partial round-trip 3" `Quick (fun () -> + with_conn (fun db -> + simple_query + db + {| + CREATE TEMPORARY TABLE this_test (id text); + INSERT INTO this_test (id) VALUES('test-\303\244-test') + |} + >>= fun _ -> + execute db "SELECT id FROM this_test" + >>| function + | [ [ result ] ] -> + Alcotest.(check string) + "" + {|test-\303\244-test|} + (Pgx.Value.to_string_exn result) + | _ -> assert false)) + ; Alcotest_io.test_case "UTF-8 round-trip" `Quick (fun () -> + (* Select the contents of a param *) + let expect = "test-ä-test" in + with_conn (fun db -> + execute db ~params:[ Pgx.Value.of_string expect ] "SELECT $1::VARCHAR" + >>| function + | [ [ result ] ] -> + Alcotest.(check (option string)) + "" + (Some expect) + (Pgx.Value.to_string result) + | _ -> assert false)) + ; Alcotest_io.test_case "UTF-8 round-trip where" `Quick (fun () -> + (* Insert string as a param, then select back the contents of + the table using a WHERE *) + let expect = "test-ä-test" in + with_conn (fun db -> + simple_query db "CREATE TEMPORARY TABLE this_test (id text)" + >>= fun _ -> + execute + db + ~params:[ Pgx.Value.of_string expect ] + "INSERT INTO this_test (id) VALUES ($1)" + >>= fun _ -> + execute + db + ~params:[ Pgx.Value.of_string expect ] + "SELECT id FROM this_test WHERE id = $1" + >>| function + | [ [ result ] ] -> + Alcotest.(check (option string)) + "" + (Some expect) + (Pgx.Value.to_string 
result) + | _ -> assert false)) ] in if force_tests || have_pg_config