diff --git a/pgx/src/io_intf.ml b/pgx/src/io_intf.ml index e6c55c0..1b1a870 100644 --- a/pgx/src/io_intf.ml +++ b/pgx/src/io_intf.ml @@ -14,7 +14,7 @@ module type S = sig | Inet of string * int val open_connection : sockaddr -> (in_channel * out_channel) t - val upgrade_ssl : in_channel -> out_channel -> (in_channel * out_channel) t + val upgrade_ssl : [ `Not_supported | `Supported of (in_channel -> out_channel -> (in_channel * out_channel) t) ] val output_char : out_channel -> char -> unit t val output_binary_int : out_channel -> int -> unit t val output_string : out_channel -> string -> unit t diff --git a/pgx/src/pgx.ml b/pgx/src/pgx.ml index 8465d74..fa97b6f 100644 --- a/pgx/src/pgx.ml +++ b/pgx/src/pgx.ml @@ -540,19 +540,26 @@ module Make (Thread : Io) = struct StartupMessage. In this case the StartupMessage and all subsequent data will be SSL-encrypted. To continue after N, send the usual StartupMessage and proceed without encryption. See https://www.postgresql.org/docs/9.3/protocol-flow.html#AEN100021 *) - let msg = Message_out.SSLRequest in - send_message conn msg - >>= fun () -> - flush chan - >>= fun () -> - input_char ichan - >>= (function - | 'S' -> - upgrade_ssl ichan chan - >>= fun (ichan, chan) -> - return { conn with ichan ; chan } - | 'N' -> return conn - | _c -> assert false) + match Io.upgrade_ssl with + | `Not_supported -> return conn + | `Supported upgrade_ssl -> + Stdlib.print_string "Attempting STARTLS\n"; + let msg = Message_out.SSLRequest in + send_message conn msg + >>= fun () -> + flush chan + >>= fun () -> + input_char ichan + >>= (function + | 'S' -> + Stdlib.print_string "Upgrading to TLS\n"; + upgrade_ssl ichan chan + >>= fun (ichan, chan) -> + return { conn with ichan ; chan } + | 'N' -> + Stdlib.print_string "Not upgrading\n"; + return conn + | _c -> assert false) let connect ?host @@ -603,6 +610,7 @@ module Make (Thread : Io) = struct (try Inet (Sys.getenv "PGHOST", port) with | Not_found -> (* use Unix domain socket. *) + Stdlib.print_string "Using Unix socket\n"; let path = sprintf "%s/.s.PGSQL.%d" unix_domain_socket_dir port in Unix path) in diff --git a/pgx_async/src/pgx_async.ml b/pgx_async/src/pgx_async.ml index 5b37381..bdfd6be 100644 --- a/pgx_async/src/pgx_async.ml +++ b/pgx_async/src/pgx_async.ml @@ -82,9 +82,15 @@ module Thread = struct >>| fun (_socket, in_channel, out_channel) -> in_channel, out_channel ;; - let upgrade_ssl in_channel out_channel = - let config = Conduit_async.V1.Conduit_async_ssl.Ssl_config.configure () in - Conduit_async.V1.Conduit_async_ssl.ssl_connect config in_channel out_channel + let upgrade_ssl = + try + let config = Conduit_async.V1.Conduit_async_ssl.Ssl_config.configure () in + Stdlib.print_string "TLS supported\n"; + `Supported (fun in_channel out_channel -> + Conduit_async.V1.Conduit_async_ssl.ssl_connect config in_channel out_channel) + with _ -> + Stdlib.print_string "TLS not supported\n"; + `Not_supported (* The unix getlogin syscall can fail *) let getlogin () = Unix.getuid () |> Unix.Passwd.getbyuid_exn >>| fun { name; _ } -> name