Skip to content

Commit

Permalink
Make TLS support optional
Browse files Browse the repository at this point in the history
  • Loading branch information
brendanlong committed Apr 29, 2021
1 parent c7829b2 commit fda4530
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 17 deletions.
2 changes: 1 addition & 1 deletion pgx/src/io_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ module type S = sig
| Inet of string * int

val open_connection : sockaddr -> (in_channel * out_channel) t
val upgrade_ssl : in_channel -> out_channel -> (in_channel * out_channel) t
val upgrade_ssl : [ `Not_supported | `Supported of (in_channel -> out_channel -> (in_channel * out_channel) t) ]
val output_char : out_channel -> char -> unit t
val output_binary_int : out_channel -> int -> unit t
val output_string : out_channel -> string -> unit t
Expand Down
34 changes: 21 additions & 13 deletions pgx/src/pgx.ml
Original file line number Diff line number Diff line change
Expand Up @@ -540,19 +540,26 @@ module Make (Thread : Io) = struct
StartupMessage. In this case the StartupMessage and all subsequent data will be SSL-encrypted. To continue
after N, send the usual StartupMessage and proceed without encryption.
See https://www.postgresql.org/docs/9.3/protocol-flow.html#AEN100021 *)
let msg = Message_out.SSLRequest in
send_message conn msg
>>= fun () ->
flush chan
>>= fun () ->
input_char ichan
>>= (function
| 'S' ->
upgrade_ssl ichan chan
>>= fun (ichan, chan) ->
return { conn with ichan ; chan }
| 'N' -> return conn
| _c -> assert false)
match Io.upgrade_ssl with
| `Not_supported -> return conn
| `Supported upgrade_ssl ->
Stdlib.print_string "Attempting STARTLS\n";
let msg = Message_out.SSLRequest in
send_message conn msg
>>= fun () ->
flush chan
>>= fun () ->
input_char ichan
>>= (function
| 'S' ->
Stdlib.print_string "Upgrading to TLS\n";
upgrade_ssl ichan chan
>>= fun (ichan, chan) ->
return { conn with ichan ; chan }
| 'N' ->
Stdlib.print_string "Not upgrading\n";
return conn
| _c -> assert false)

let connect
?host
Expand Down Expand Up @@ -603,6 +610,7 @@ module Make (Thread : Io) = struct
(try Inet (Sys.getenv "PGHOST", port) with
| Not_found ->
(* use Unix domain socket. *)
Stdlib.print_string "Using Unix socket\n";
let path = sprintf "%s/.s.PGSQL.%d" unix_domain_socket_dir port in
Unix path)
in
Expand Down
12 changes: 9 additions & 3 deletions pgx_async/src/pgx_async.ml
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,15 @@ module Thread = struct
>>| fun (_socket, in_channel, out_channel) -> in_channel, out_channel
;;

let upgrade_ssl in_channel out_channel =
let config = Conduit_async.V1.Conduit_async_ssl.Ssl_config.configure () in
Conduit_async.V1.Conduit_async_ssl.ssl_connect config in_channel out_channel
let upgrade_ssl =
try
let config = Conduit_async.V1.Conduit_async_ssl.Ssl_config.configure () in
Stdlib.print_string "TLS supported\n";
`Supported (fun in_channel out_channel ->
Conduit_async.V1.Conduit_async_ssl.ssl_connect config in_channel out_channel)
with _ ->
Stdlib.print_string "TLS not supported\n";
`Not_supported

(* The unix getlogin syscall can fail *)
let getlogin () = Unix.getuid () |> Unix.Passwd.getbyuid_exn >>| fun { name; _ } -> name
Expand Down

0 comments on commit fda4530

Please sign in to comment.