Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TLS support for Pgx_async #108

Merged
merged 12 commits into from
May 11, 2021
8 changes: 5 additions & 3 deletions dune-project
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
(name pgx)
(synopsis "Pure-OCaml PostgreSQL client library")
(description
"PGX is a pure-OCaml PostgreSQL client library, supporting Async, LWT, or synchronous operations.")
"PGX is a pure-OCaml PostgreSQL client library, supporting Async, LWT, or synchronous operations.")
(depends
(alcotest
(and
Expand Down Expand Up @@ -52,9 +52,9 @@
(package
(name pgx_unix)
(synopsis
"PGX using the standard library's Unix module for IO (synchronous)")
"PGX using the standard library's Unix module for IO (synchronous)")
(description
"PGX using the standard library's Unix module for IO (synchronous)")
"PGX using the standard library's Unix module for IO (synchronous)")
(depends
(alcotest
(and
Expand Down Expand Up @@ -82,10 +82,12 @@
(>= "v0.13.0"))
(async_unix
(>= "v0.13.0"))
async_ssl
(base64
(and
:with-test
(>= 3.0.0)))
conduit-async
(ocaml
(>= 4.08))
(pgx
Expand Down
12 changes: 12 additions & 0 deletions pgx/src/io_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,18 @@ module type S = sig
| Inet of string * int

val open_connection : sockaddr -> (in_channel * out_channel) t

type ssl_config

val upgrade_ssl
: [ `Not_supported
| `Supported of
?ssl_config:ssl_config
-> in_channel
-> out_channel
-> (in_channel * out_channel) t
]

Comment on lines +20 to +28
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that the async_ssl dependency is mandatory, will there still be a scenario where attempting to setup a tls connection via conduit will fail because of missing support? I'm guessing with the current setup we might not need the Not_supported branch at all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need the Not_supported branch because Pgx_lwt and Pgx_unix don't support TLS right now. Hopefully at some point we can get this supported for Pgx_lwt, but I doubt we'll ever support it for Pgx_unix.

I'm open to alternative ways of doing this though if there is one?

val output_char : out_channel -> char -> unit t
val output_binary_int : out_channel -> int -> unit t
val output_string : out_channel -> string -> unit t
Expand Down
61 changes: 61 additions & 0 deletions pgx/src/pgx.ml
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ module Message_out = struct
| Describe_portal of portal (* DP *)
| Startup_message of startup
| Simple_query of query
| SSLRequest
[@@deriving sexp]

let add_byte buf i =
Expand Down Expand Up @@ -381,6 +382,10 @@ module Message_out = struct
add_byte msg 0;
None, Buffer.contents msg
| Simple_query q -> Some 'Q', str q
| SSLRequest ->
let msg = Buffer.create 8 in
add_int32 msg 80877103l;
None, Buffer.contents msg
;;
end

Expand Down Expand Up @@ -526,7 +531,59 @@ module Make (Thread : Io) = struct

(*----- Connection. -----*)

let attempt_tls_upgrade ?(ssl = `Auto) ({ ichan; chan; _ } as conn) =
(* To initiate an SSL-encrypted connection, the frontend initially sends an SSLRequest message rather than a
StartupMessage. The server then responds with a single byte containing S or N, indicating that it is willing
or unwilling to perform SSL, respectively. The frontend might close the connection at this point if it is
dissatisfied with the response. To continue after S, perform an SSL startup handshake (not described here,
part of the SSL specification) with the server. If this is successful, continue with sending the usual
StartupMessage. In this case the StartupMessage and all subsequent data will be SSL-encrypted. To continue
after N, send the usual StartupMessage and proceed without encryption.
See https://www.postgresql.org/docs/9.3/protocol-flow.html#AEN100021 *)
match ssl with
| `No -> return conn
| (`Auto | `Always _) as ssl ->
(match Io.upgrade_ssl with
| `Not_supported ->
(match ssl with
| `Always _ ->
failwith
"TLS support is not compiled into this Pgx library but ~ssl was set to \
`Always"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to make this impossible but I'm not sure how.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#suggestion: We can also enforce that the io wrappers for pgx should always pick an ssl option (async_ssl, tls or lwt_ssl), and get rid of this branch. I'd expect that eventually all io wrappers in the pgx repo will support encrypted connections, but if we don't want to support them in a particular backend, we can probably tweak the wrapper itself to not expose any ssl related options in its mli file

| _ -> ());
debug
"TLS-support is not compiled into this Pgx library, not attempting to upgrade"
>>| fun () -> conn
| `Supported upgrade_ssl ->
debug "Request SSL upgrade from server"
>>= fun () ->
let msg = Message_out.SSLRequest in
send_message conn msg
>>= fun () ->
flush chan
>>= fun () ->
input_char ichan
>>= (function
| 'S' ->
debug "Server supports TLS, attempting to upgrade"
>>= fun () ->
let ssl_config =
match ssl with
| `Auto -> None
| `Always ssl_config -> Some ssl_config
Comment on lines +572 to +573
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not completely sure on this interace. My idea was that if people actually care about SSL, they probably want to send a config.

in
upgrade_ssl ?ssl_config ichan chan
>>= fun (ichan, chan) -> return { conn with ichan; chan }
| 'N' -> debug "Server does not support TLS, not upgrading" >>| fun () -> conn
| c ->
fail_msg
"Got unexpected response '%c' from server after SSLRequest message. Response \
should always be 'S' or 'N'."
c))
;;

let connect
?ssl
?host
?port
?user
Expand Down Expand Up @@ -600,6 +657,8 @@ module Make (Thread : Io) = struct
; prepared_num = Int64.of_int 0
}
in
attempt_tls_upgrade ?ssl conn
>>= fun conn ->
(* Send the StartUpMessage. NB. At present we do not support SSL. *)
let msg = Message_out.Startup_message { Message_out.user; database } in
(* Loop around here until the database gives a ReadyForQuery message. *)
Expand Down Expand Up @@ -665,6 +724,7 @@ module Make (Thread : Io) = struct
;;

let with_conn
?ssl
?host
?port
?user
Expand All @@ -676,6 +736,7 @@ module Make (Thread : Io) = struct
f
=
connect
?ssl
?host
?port
?user
Expand Down
3 changes: 2 additions & 1 deletion pgx/src/pgx.mli
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,5 @@ module Value = Pgx_value

module type S = Pgx_intf.S

module Make (Thread : Io) : S with type 'a Io.t = 'a Thread.t
module Make (Thread : Io) :
S with type 'a Io.t = 'a Thread.t and type Io.ssl_config = Thread.ssl_config
7 changes: 5 additions & 2 deletions pgx/src/pgx_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module type S = sig

module Io : sig
type 'a t
type ssl_config

val return : 'a -> 'a t
val ( >>= ) : 'a t -> ('a -> 'b t) -> 'b t
Expand All @@ -22,7 +23,8 @@ module type S = sig
possible denial of service. You may want to set this to a smaller
size to avoid this happening. *)
val connect
: ?host:string
: ?ssl:[ `Auto | `No | `Always of Io.ssl_config ]
-> ?host:string
-> ?port:int
-> ?user:string
-> ?password:string
Expand All @@ -42,7 +44,8 @@ module type S = sig
[close]. This is the preferred way to use this library since it cleans up
after itself. *)
val with_conn
: ?host:string
: ?ssl:[ `Auto | `No | `Always of Io.ssl_config ]
-> ?host:string
-> ?port:int
-> ?user:string
-> ?password:string
Expand Down
2 changes: 2 additions & 0 deletions pgx_async.opam
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@ depends: [
"alcotest-async" {with-test & >= "1.0.0"}
"async_kernel" {>= "v0.13.0"}
"async_unix" {>= "v0.13.0"}
"async_ssl"
"base64" {with-test & >= "3.0.0"}
"conduit-async"
"ocaml" {>= "4.08"}
"pgx" {= version}
"pgx_value_core" {= version}
Expand Down
2 changes: 1 addition & 1 deletion pgx_async/src/dune
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ let () = Jbuild_plugin.V1.send @@ {|
(library
(public_name pgx_async)
(wrapped false)
(libraries async_kernel async_unix pgx_value_core)
(libraries async_kernel async_unix conduit-async pgx_value_core)
|} ^ preprocess ^ {|)
|}
32 changes: 21 additions & 11 deletions pgx_async/src/pgx_async.ml
Original file line number Diff line number Diff line change
Expand Up @@ -73,19 +73,25 @@ module Thread = struct
let close_in = Reader.close

let open_connection sockaddr =
let get_reader_writer socket =
let fd = Socket.fd socket in
Reader.create fd, Writer.create fd
in
match sockaddr with
| Unix path ->
let unix_sockaddr = Tcp.Where_to_connect.of_unix_address (`Unix path) in
Tcp.connect_sock unix_sockaddr >>| get_reader_writer
| Unix path -> Conduit_async.connect (`Unix_domain_socket path)
| Inet (host, port) ->
let inet_sockaddr =
Tcp.Where_to_connect.of_host_and_port (Host_and_port.create ~host ~port)
in
Tcp.connect_sock inet_sockaddr >>| get_reader_writer
Uri.make ~host ~port ()
|> Conduit_async.V3.resolve_uri
>>= Conduit_async.V3.connect
>>| fun (_socket, in_channel, out_channel) -> in_channel, out_channel
;;

type ssl_config = Conduit_async.Ssl.config

let upgrade_ssl =
try
let default_config = Conduit_async.V1.Conduit_async_ssl.Ssl_config.configure () in
`Supported
(fun ?(ssl_config = default_config) in_channel out_channel ->
Conduit_async.V1.Conduit_async_ssl.ssl_connect ssl_config in_channel out_channel)
with
| _ -> `Not_supported
;;

(* The unix getlogin syscall can fail *)
Expand Down Expand Up @@ -130,6 +136,7 @@ let check_pgdatabase =
;;

let connect
?ssl
?host
?port
?user
Expand All @@ -146,6 +153,7 @@ let connect
| None -> Lazy_deferred.force_exn default_unix_domain_socket_dir)
>>= fun unix_domain_socket_dir ->
connect
?ssl
?host
?port
?user
Expand All @@ -158,6 +166,7 @@ let connect
;;

let with_conn
?ssl
?host
?port
?user
Expand All @@ -169,6 +178,7 @@ let with_conn
f
=
connect
?ssl
?host
?port
?user
Expand Down
17 changes: 4 additions & 13 deletions pgx_async/src/pgx_async.mli
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
(** Async based Postgres client based on Pgx. *)
open Async_kernel

include Pgx.S with type 'a Io.t = 'a Deferred.t
include
Pgx.S
with type 'a Io.t = 'a Deferred.t
and type Io.ssl_config = Conduit_async.Ssl.config

(* for testing purposes *)
module Thread : Pgx.Io with type 'a t = 'a Deferred.t

val with_conn
: ?host:string
-> ?port:int
-> ?user:string
-> ?password:string
-> ?database:string
-> ?unix_domain_socket_dir:string
-> ?verbose:int
-> ?max_message_length:int
-> (t -> 'a Deferred.t)
-> 'a Deferred.t
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is unnecessary since it's exactly the same interface that we import above in include Pgx.S


(** Like [execute] but returns a pipe so you can operate on the results before they have all returned.
Note that [execute_iter] and [execute_fold] can perform significantly better because they don't have
as much overhead. *)
Expand Down
2 changes: 2 additions & 0 deletions pgx_lwt/src/pgx_lwt.ml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ module Thread = struct

let close_in = Io.close_in
let open_connection = Io.open_connection
type ssl_config
let upgrade_ssl = `Not_supported
let getlogin = Io.getlogin
let debug s = Logs_lwt.debug (fun m -> m "%s" s)
let protect f ~finally = Lwt.finalize f finally
Expand Down
1 change: 0 additions & 1 deletion pgx_lwt_unix.opam
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ depends: [
"dune" {>= "1.11"}
"alcotest-lwt" {with-test & >= "1.0.0"}
"base64" {with-test & >= "3.0.0"}
"lwt"
"ocaml" {>= "4.08"}
"pgx" {= version}
"pgx_lwt" {= version}
Expand Down
3 changes: 3 additions & 0 deletions pgx_unix/src/pgx_unix.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ module Simple_thread = struct
Unix.open_connection std_socket
;;

type ssl_config

let upgrade_ssl = `Not_supported
let output_char = output_char
let output_binary_int = output_binary_int
let output_string = output_string
Expand Down